htlcswitch: add dedicated block subscription to interceptable switch

Preparation for making the interceptable switch aware of expiring htlcs.
This commit is contained in:
Joost Jager
2022-08-15 15:28:23 +02:00
parent 4a3e90f4d0
commit a6df9567ba
4 changed files with 92 additions and 5 deletions

View File

@@ -6,6 +6,7 @@ import (
"sync" "sync"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
@@ -20,6 +21,8 @@ var (
// ErrUnsupportedFailureCode when processing of an unsupported failure // ErrUnsupportedFailureCode when processing of an unsupported failure
// code is attempted. // code is attempted.
ErrUnsupportedFailureCode = errors.New("unsupported failure code") ErrUnsupportedFailureCode = errors.New("unsupported failure code")
errBlockStreamStopped = errors.New("block epoch stream stopped")
) )
// InterceptableSwitch is an implementation of ForwardingSwitch interface. // InterceptableSwitch is an implementation of ForwardingSwitch interface.
@@ -61,6 +64,18 @@ type InterceptableSwitch struct {
// htlc where we no longer intercept it and instead cancel it back. // htlc where we no longer intercept it and instead cancel it back.
cltvRejectDelta uint32 cltvRejectDelta uint32
// notifier is an instance of a chain notifier that we'll use to signal
// the switch when a new block has arrived.
notifier chainntnfs.ChainNotifier
// blockEpochStream is an active block epoch event stream backed by an
// active ChainNotifier instance. This will be used to retrieve the
// latest height of the chain.
blockEpochStream *chainntnfs.BlockEpochEvent
// currentHeight is the currently best known height.
currentHeight int32
wg sync.WaitGroup wg sync.WaitGroup
quit chan struct{} quit chan struct{}
} }
@@ -117,6 +132,10 @@ type InterceptableSwitchConfig struct {
// packets get sent to on resume. // packets get sent to on resume.
Switch *Switch Switch *Switch
// Notifier is an instance of a chain notifier that we'll use to signal
// the switch when a new block has arrived.
Notifier chainntnfs.ChainNotifier
// CltvRejectDelta defines the number of blocks before the expiry of the // CltvRejectDelta defines the number of blocks before the expiry of the
// htlc where we no longer intercept it and instead cancel it back. // htlc where we no longer intercept it and instead cancel it back.
CltvRejectDelta uint32 CltvRejectDelta uint32
@@ -137,6 +156,7 @@ func NewInterceptableSwitch(cfg *InterceptableSwitchConfig) *InterceptableSwitch
resolutionChan: make(chan *fwdResolution), resolutionChan: make(chan *fwdResolution),
requireInterceptor: cfg.RequireInterceptor, requireInterceptor: cfg.RequireInterceptor,
cltvRejectDelta: cfg.CltvRejectDelta, cltvRejectDelta: cfg.CltvRejectDelta,
notifier: cfg.Notifier,
quit: make(chan struct{}), quit: make(chan struct{}),
} }
@@ -157,11 +177,20 @@ func (s *InterceptableSwitch) SetInterceptor(
} }
func (s *InterceptableSwitch) Start() error { func (s *InterceptableSwitch) Start() error {
blockEpochStream, err := s.notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
return err
}
s.blockEpochStream = blockEpochStream
s.wg.Add(1) s.wg.Add(1)
go func() { go func() {
defer s.wg.Done() defer s.wg.Done()
s.run() err := s.run()
if err != nil {
log.Errorf("InterceptableSwitch stopped: %v", err)
}
}() }()
return nil return nil
@@ -171,10 +200,28 @@ func (s *InterceptableSwitch) Stop() error {
close(s.quit) close(s.quit)
s.wg.Wait() s.wg.Wait()
s.blockEpochStream.Cancel()
return nil return nil
} }
func (s *InterceptableSwitch) run() { func (s *InterceptableSwitch) run() error {
// The block epoch stream will immediately stream the current height.
// Read it out here.
select {
case currentBlock, ok := <-s.blockEpochStream.Epochs:
if !ok {
return errBlockStreamStopped
}
s.currentHeight = currentBlock.Height
case <-s.quit:
return nil
}
log.Debugf("InterceptableSwitch running: height=%v, "+
"requireInterceptor=%v", s.currentHeight, s.requireInterceptor)
for { for {
select { select {
// An interceptor registration or de-registration came in. // An interceptor registration or de-registration came in.
@@ -210,8 +257,15 @@ func (s *InterceptableSwitch) run() {
case res := <-s.resolutionChan: case res := <-s.resolutionChan:
res.errChan <- s.resolve(res.resolution) res.errChan <- s.resolve(res.resolution)
case currentBlock, ok := <-s.blockEpochStream.Epochs:
if !ok {
return errBlockStreamStopped
}
s.currentHeight = currentBlock.Height
case <-s.quit: case <-s.quit:
return return nil
} }
} }
} }
@@ -448,7 +502,7 @@ func (s *InterceptableSwitch) forward(
func (s *InterceptableSwitch) handleExpired(fwd *interceptedForward) ( func (s *InterceptableSwitch) handleExpired(fwd *interceptedForward) (
bool, error) { bool, error) {
height := s.htlcSwitch.BestHeight() height := uint32(s.currentHeight)
if fwd.packet.incomingTimeout >= height+s.cltvRejectDelta { if fwd.packet.incomingTimeout >= height+s.cltvRejectDelta {
return false, nil return false, nil
} }

View File

@@ -14,10 +14,12 @@ import (
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/ticker"
@@ -3750,6 +3752,8 @@ func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink,
func assertOutgoingLinkReceiveIntercepted(t *testing.T, func assertOutgoingLinkReceiveIntercepted(t *testing.T,
targetLink *mockChannelLink) { targetLink *mockChannelLink) {
t.Helper()
select { select {
case <-targetLink.packets: case <-targetLink.packets:
case <-time.After(time.Second): case <-time.After(time.Second):
@@ -3845,10 +3849,17 @@ func TestSwitchHoldForward(t *testing.T) {
t: t, t: t,
interceptedChan: make(chan InterceptedPacket), interceptedChan: make(chan InterceptedPacket),
} }
notifier := &mock.ChainNotifier{
EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
}
notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}
switchForwardInterceptor := NewInterceptableSwitch( switchForwardInterceptor := NewInterceptableSwitch(
&InterceptableSwitchConfig{ &InterceptableSwitchConfig{
Switch: s, Switch: s,
CltvRejectDelta: cltvRejectDelta, CltvRejectDelta: cltvRejectDelta,
Notifier: notifier,
}, },
) )
require.NoError(t, switchForwardInterceptor.Start()) require.NoError(t, switchForwardInterceptor.Start())
@@ -4040,11 +4051,17 @@ func TestSwitchHoldForward(t *testing.T) {
require.NoError(t, switchForwardInterceptor.Stop()) require.NoError(t, switchForwardInterceptor.Stop())
// Test always-on interception. // Test always-on interception.
notifier = &mock.ChainNotifier{
EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
}
notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}
switchForwardInterceptor = NewInterceptableSwitch( switchForwardInterceptor = NewInterceptableSwitch(
&InterceptableSwitchConfig{ &InterceptableSwitchConfig{
Switch: s, Switch: s,
CltvRejectDelta: cltvRejectDelta, CltvRejectDelta: cltvRejectDelta,
RequireInterceptor: true, RequireInterceptor: true,
Notifier: notifier,
}, },
) )
require.NoError(t, switchForwardInterceptor.Start()) require.NoError(t, switchForwardInterceptor.Start())
@@ -5338,9 +5355,16 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) {
t: t, t: t,
interceptedChan: make(chan InterceptedPacket), interceptedChan: make(chan InterceptedPacket),
} }
notifier := &mock.ChainNotifier{
EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
}
notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}
interceptSwitch := NewInterceptableSwitch( interceptSwitch := NewInterceptableSwitch(
&InterceptableSwitchConfig{ &InterceptableSwitchConfig{
Switch: s, Switch: s,
Notifier: notifier,
}, },
) )
require.NoError(t, interceptSwitch.Start()) require.NoError(t, interceptSwitch.Start())

View File

@@ -359,6 +359,13 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
ChainNet: wire.SimNet, ChainNet: wire.SimNet,
} }
interceptableSwitchNotifier := &mock.ChainNotifier{
EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
}
interceptableSwitchNotifier.EpochChan <- &chainntnfs.BlockEpoch{
Height: 1,
}
cfg := &Config{ cfg := &Config{
Addr: cfgAddr, Addr: cfgAddr,
PubKeyBytes: pubKey, PubKeyBytes: pubKey,
@@ -369,6 +376,7 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
InterceptSwitch: htlcswitch.NewInterceptableSwitch( InterceptSwitch: htlcswitch.NewInterceptableSwitch(
&htlcswitch.InterceptableSwitchConfig{ &htlcswitch.InterceptableSwitchConfig{
CltvRejectDelta: testCltvRejectDelta, CltvRejectDelta: testCltvRejectDelta,
Notifier: interceptableSwitchNotifier,
}, },
), ),
ChannelDB: dbAlice.ChannelStateDB(), ChannelDB: dbAlice.ChannelStateDB(),

View File

@@ -671,6 +671,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
Switch: s.htlcSwitch, Switch: s.htlcSwitch,
CltvRejectDelta: lncfg.DefaultFinalCltvRejectDelta, CltvRejectDelta: lncfg.DefaultFinalCltvRejectDelta,
RequireInterceptor: s.cfg.RequireInterceptor, RequireInterceptor: s.cfg.RequireInterceptor,
Notifier: s.cc.ChainNotifier,
}, },
) )