From 13dff2fb3e29899ab381dd8d6c8e6ced980b4ac2 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Fri, 28 Jan 2022 12:56:17 +0100 Subject: [PATCH] htlcswitch: interceptor expiry check --- docs/release-notes/release-notes-0.15.0.md | 3 + htlcswitch/interceptable_switch.go | 64 +++++++++++++++++- htlcswitch/mock.go | 2 +- htlcswitch/switch_test.go | 79 +++++++++++++++++++--- peer/test_utils.go | 6 +- server.go | 3 +- 6 files changed, 143 insertions(+), 14 deletions(-) diff --git a/docs/release-notes/release-notes-0.15.0.md b/docs/release-notes/release-notes-0.15.0.md index 56eb70ef9..ed485349b 100644 --- a/docs/release-notes/release-notes-0.15.0.md +++ b/docs/release-notes/release-notes-0.15.0.md @@ -180,6 +180,9 @@ then watch it on chain. Taproot script spends are also supported through the * [Add new Peers subserver](https://github.com/lightningnetwork/lnd/pull/5587) with a new endpoint for updating the `NodeAnnouncement` data without having to restart the node. +* Add [htlc expiry protection](https://github.com/lightningnetwork/lnd/pull/6212) +to the htlc interceptor API. + ## Documentation * Improved instructions on [how to build lnd for mobile](https://github.com/lightningnetwork/lnd/pull/6085). diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 87ef1e61d..55345bdc3 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -55,6 +55,10 @@ type InterceptableSwitch struct { // holdForwards keeps track of outstanding intercepted forwards. holdForwards map[channeldb.CircuitKey]InterceptedForward + // cltvRejectDelta defines the number of blocks before the expiry of the + // htlc where we no longer intercept it and instead cancel it back. + cltvRejectDelta uint32 + wg sync.WaitGroup quit chan struct{} } @@ -106,7 +110,7 @@ type fwdResolution struct { } // NewInterceptableSwitch returns an instance of InterceptableSwitch. -func NewInterceptableSwitch(s *Switch, +func NewInterceptableSwitch(s *Switch, cltvRejectDelta uint32, requireInterceptor bool) *InterceptableSwitch { return &InterceptableSwitch{ @@ -116,6 +120,7 @@ func NewInterceptableSwitch(s *Switch, holdForwards: make(map[channeldb.CircuitKey]InterceptedForward), resolutionChan: make(chan *fwdResolution), requireInterceptor: requireInterceptor, + cltvRejectDelta: cltvRejectDelta, quit: make(chan struct{}), } @@ -337,6 +342,27 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, htlcSwitch: s.htlcSwitch, } + // Handle forwards that are too close to expiry. + handled, err := s.handleExpired(intercepted) + if err != nil { + log.Errorf("Error handling intercepted htlc "+ + "that expires too soon: circuit=%v, "+ + "incoming_timeout=%v, err=%v", + packet.inKey(), packet.incomingTimeout, err) + + // Return false so that the packet is offered as normal + // to the switch. This isn't ideal because interception + // may be configured as always-on and is skipped now. + // Returning true isn't great either, because the htlc + // will remain stuck and potentially force-close the + // channel. But in the end, we should never get here, so + // the actual return value doesn't matter that much. + return false + } + if handled { + return true + } + if s.interceptor == nil && !isReplay { // There is no interceptor registered, we are in // interceptor-required mode, and this is a new packet @@ -370,6 +396,32 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, } } +// handleExpired checks that the htlc isn't too close to the channel +// force-close broadcast height. If it is, it is cancelled back. +func (s *InterceptableSwitch) handleExpired(fwd *interceptedForward) ( + bool, error) { + + height := s.htlcSwitch.BestHeight() + if fwd.packet.incomingTimeout >= height+s.cltvRejectDelta { + return false, nil + } + + log.Debugf("Interception rejected because htlc "+ + "expires too soon: circuit=%v, "+ + "height=%v, incoming_timeout=%v", + fwd.packet.inKey(), height, + fwd.packet.incomingTimeout) + + err := fwd.FailWithCode( + lnwire.CodeExpiryTooSoon, + ) + if err != nil { + return false, err + } + + return true, nil +} + // interceptedForward implements the InterceptedForward interface. // It is passed from the switch to external interceptors that are interested // in holding forwards and resolve them manually. @@ -450,6 +502,16 @@ func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error { failureMsg = lnwire.NewTemporaryChannelFailure(update) + case lnwire.CodeExpiryTooSoon: + update, err := f.htlcSwitch.cfg.FetchLastChannelUpdate( + f.packet.incomingChanID, + ) + if err != nil { + return err + } + + failureMsg = lnwire.NewExpiryTooSoon(*update) + default: return ErrUnsupportedFailureCode } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 6af268d56..692a6ff5d 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -185,7 +185,7 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) events: make(map[time.Time]channeldb.ForwardingEvent), }, FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { - return nil, nil + return &lnwire.ChannelUpdate{}, nil }, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 9c10cca44..0b0e1ea67 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" + "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -3167,10 +3168,12 @@ func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket { func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) { if s.circuits.NumPending() != pending { - t.Fatal("wrong amount of half circuits") + t.Fatalf("wrong amount of half circuits, expected %v but "+ + "got %v", pending, s.circuits.NumPending()) } if s.circuits.NumOpen() != opened { - t.Fatal("wrong amount of circuits") + t.Fatalf("wrong amount of circuits, expected %v but got %v", + opened, s.circuits.NumOpen()) } } @@ -3197,6 +3200,16 @@ func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink, return nil } +func assertOutgoingLinkReceiveIntercepted(t *testing.T, + targetLink *mockChannelLink) { + + select { + case <-targetLink.packets: + case <-time.After(time.Second): + t.Fatal("request was not propagated to destination") + } +} + func TestSwitchHoldForward(t *testing.T) { t.Parallel() @@ -3259,14 +3272,17 @@ func TestSwitchHoldForward(t *testing.T) { onionBlob := [1366]byte{4, 5, 6} incomingHtlcID := uint64(0) + const cltvRejectDelta = 13 + createTestPacket := func() *htlcPacket { incomingHtlcID++ return &htlcPacket{ - incomingChanID: aliceChannelLink.ShortChanID(), - incomingHTLCID: incomingHtlcID, - outgoingChanID: bobChannelLink.ShortChanID(), - obfuscator: NewMockObfuscator(), + incomingChanID: aliceChannelLink.ShortChanID(), + incomingHTLCID: incomingHtlcID, + incomingTimeout: testStartingHeight + cltvRejectDelta + 1, + outgoingChanID: bobChannelLink.ShortChanID(), + obfuscator: NewMockObfuscator(), htlc: &lnwire.UpdateAddHTLC{ PaymentHash: rhash, Amount: 1, @@ -3290,12 +3306,55 @@ func TestSwitchHoldForward(t *testing.T) { t: t, interceptedChan: make(chan InterceptedPacket), } - switchForwardInterceptor := NewInterceptableSwitch(s, false) + switchForwardInterceptor := NewInterceptableSwitch( + s, cltvRejectDelta, false, + ) require.NoError(t, switchForwardInterceptor.Start()) switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc) linkQuit := make(chan struct{}) + // Test a forward that expires too soon. + packet := createTestPacket() + packet.incomingTimeout = testStartingHeight + cltvRejectDelta - 1 + + err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet) + if err != nil { + t.Fatalf("can't forward htlc packet: %v", err) + } + assertOutgoingLinkReceive(t, bobChannelLink, false) + assertOutgoingLinkReceiveIntercepted(t, aliceChannelLink) + assertNumCircuits(t, s, 0, 0) + + // Test a forward that expires too soon and can't be failed. + packet = createTestPacket() + packet.incomingTimeout = testStartingHeight + cltvRejectDelta - 1 + + // Simulate an error during the composition of the failure message. + currentCallback := s.cfg.FetchLastChannelUpdate + s.cfg.FetchLastChannelUpdate = func( + lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { + + return nil, errors.New("cannot fetch update") + } + + err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet) + if err != nil { + t.Fatalf("can't forward htlc packet: %v", err) + } + receivedPkt := assertOutgoingLinkReceive(t, bobChannelLink, true) + assertNumCircuits(t, s, 1, 1) + + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, + createSettlePacket(receivedPkt.outgoingHTLCID), + )) + + assertOutgoingLinkReceive(t, aliceChannelLink, true) + assertNumCircuits(t, s, 0, 0) + + s.cfg.FetchLastChannelUpdate = currentCallback + // Test resume a hold forward. assertNumCircuits(t, s, 0, 0) err = switchForwardInterceptor.ForwardPackets( @@ -3310,7 +3369,7 @@ func TestSwitchHoldForward(t *testing.T) { Action: FwdActionResume, Key: forwardInterceptor.getIntercepted().IncomingCircuit, })) - receivedPkt := assertOutgoingLinkReceive(t, bobChannelLink, true) + receivedPkt = assertOutgoingLinkReceive(t, bobChannelLink, true) assertNumCircuits(t, s, 1, 1) // settling the htlc to close the circuit. @@ -3387,7 +3446,7 @@ func TestSwitchHoldForward(t *testing.T) { })) assertOutgoingLinkReceive(t, bobChannelLink, false) - packet := assertOutgoingLinkReceive(t, aliceChannelLink, true) + packet = assertOutgoingLinkReceive(t, aliceChannelLink, true) require.Equal(t, reason, packet.htlc.(*lnwire.UpdateFailHTLC).Reason) @@ -3443,7 +3502,7 @@ func TestSwitchHoldForward(t *testing.T) { require.NoError(t, switchForwardInterceptor.Stop()) // Test always-on interception. - switchForwardInterceptor = NewInterceptableSwitch(s, true) + switchForwardInterceptor = NewInterceptableSwitch(s, cltvRejectDelta, true) require.NoError(t, switchForwardInterceptor.Start()) // Forward a fresh packet. It is expected to be failed immediately, diff --git a/peer/test_utils.go b/peer/test_utils.go index 40f329563..1a9e6f61e 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -38,6 +38,10 @@ const ( // timeout is a timeout value to use for tests which need to wait for // a return value on a channel. timeout = time.Second * 5 + + // testCltvRejectDelta is the minimum delta between expiry and current + // height below which htlcs are rejected. + testCltvRejectDelta = 13 ) var ( @@ -368,7 +372,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, ChanActiveTimeout: chanActiveTimeout, InterceptSwitch: htlcswitch.NewInterceptableSwitch( - nil, false, + nil, testCltvRejectDelta, false, ), ChannelDB: dbAlice.ChannelStateDB(), diff --git a/server.go b/server.go index c4110bcc1..1254c62d9 100644 --- a/server.go +++ b/server.go @@ -655,7 +655,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return nil, err } s.interceptableSwitch = htlcswitch.NewInterceptableSwitch( - s.htlcSwitch, s.cfg.RequireInterceptor, + s.htlcSwitch, lncfg.DefaultFinalCltvRejectDelta, + s.cfg.RequireInterceptor, ) chanStatusMgrCfg := &netann.ChanStatusConfig{