diff --git a/htlcswitch/failure_detail.go b/htlcswitch/failure_detail.go index 92b25510e..976015b8c 100644 --- a/htlcswitch/failure_detail.go +++ b/htlcswitch/failure_detail.go @@ -29,6 +29,11 @@ const ( // FailureDetailInsufficientBalance is returned when we cannot route a // htlc due to insufficient outgoing capacity. FailureDetailInsufficientBalance + + // FailureDetailCircularRoute is returned when an attempt is made + // to forward a htlc through our node which arrives and leaves on the + // same channel. + FailureDetailCircularRoute ) // String returns the string representation of a failure detail. @@ -52,6 +57,9 @@ func (fd FailureDetail) String() string { case FailureDetailInsufficientBalance: return "insufficient bandwidth to route htlc" + case FailureDetailCircularRoute: + return "same incoming and outgoing channel" + default: return "unknown failure detail" } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index c08bdc0cf..760aadf00 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -167,6 +167,10 @@ type Config struct { // fails in forwarding packages. AckEventTicker ticker.Ticker + // AllowCircularRoute is true if the user has configured their node to + // allow forwards that arrive and depart our node over the same channel. + AllowCircularRoute bool + // RejectHTLC is a flag that instructs the htlcswitch to reject any // HTLCs that are not from the source hop. RejectHTLC bool @@ -986,6 +990,22 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { return s.handleLocalDispatch(packet) } + // Before we attempt to find a non-strict forwarding path for + // this htlc, check whether the htlc is being routed over the + // same incoming and outgoing channel. If our node does not + // allow forwards of this nature, we fail the htlc early. This + // check is in place to disallow inefficiently routed htlcs from + // locking up our balance. + linkErr := checkCircularForward( + packet.incomingChanID, packet.outgoingChanID, + s.cfg.AllowCircularRoute, htlc.PaymentHash, + ) + if linkErr != nil { + return s.failAddPacket( + packet, linkErr.WireMessage(), linkErr, + ) + } + s.indexMtx.RLock() targetLink, err := s.getLinkByShortID(packet.outgoingChanID) if err != nil { @@ -1170,6 +1190,37 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { } } +// checkCircularForward checks whether a forward is circular (arrives and +// departs on the same link) and returns a link error if the switch is +// configured to disallow this behaviour. +func checkCircularForward(incoming, outgoing lnwire.ShortChannelID, + allowCircular bool, paymentHash lntypes.Hash) *LinkError { + + // If the route is not circular we do not need to perform any further + // checks. + if incoming != outgoing { + return nil + } + + // If the incoming and outgoing link are equal, the htlc is part of a + // circular route which may be used to lock up our liquidity. If the + // switch is configured to allow circular routes, log that we are + // allowing the route then return nil. + if allowCircular { + log.Debugf("allowing circular route over link: %v "+ + "(payment hash: %x)", incoming, paymentHash) + return nil + } + + // If our node disallows circular routes, return a temporary channel + // failure. There is nothing wrong with the policy used by the remote + // node, so we do not include a channel update. + return NewDetailedLinkError( + lnwire.NewTemporaryChannelFailure(nil), + FailureDetailCircularRoute, + ) +} + // failAddPacket encrypts a fail packet back to an add packet's source. // The ciphertext will be derived from the failure message proivded by context. // This method returns the failErr if all other steps complete successfully. diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 1497f0098..828ef232c 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "reflect" "testing" "time" @@ -1324,6 +1325,171 @@ type multiHopFwdTest struct { expectedReply lnwire.FailCode } +// TestCircularForwards tests the allowing/disallowing of circular payments +// through the same channel in the case where the switch is configured to allow +// and disallow same channel circular forwards. +func TestCircularForwards(t *testing.T) { + chanID1, aliceChanID := genID() + preimage := [sha256.Size]byte{1} + hash := fastsha256.Sum256(preimage[:]) + + tests := []struct { + name string + allowCircularPayment bool + expectedErr error + }{ + { + name: "circular payment allowed", + allowCircularPayment: true, + expectedErr: nil, + }, + { + name: "circular payment disallowed", + allowCircularPayment: false, + expectedErr: NewDetailedLinkError( + lnwire.NewTemporaryChannelFailure(nil), + FailureDetailCircularRoute, + ), + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, + testDefaultDelta, + ) + if err != nil { + t.Fatalf("unable to create alice server: %v", + err) + } + + s, err := initSwitchWithDB(testStartingHeight, nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer func() { _ = s.Stop() }() + + // Set the switch to allow or disallow circular routes + // according to the test's requirements. + s.cfg.AllowCircularRoute = test.allowCircularPayment + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + + if err := s.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + + // Create a new packet that loops through alice's link + // in a circle. + obfuscator := NewMockObfuscator() + packet := &htlcPacket{ + incomingChanID: aliceChannelLink.ShortChanID(), + outgoingChanID: aliceChannelLink.ShortChanID(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: hash, + Amount: 1, + }, + obfuscator: obfuscator, + } + + // Attempt to forward the packet and check for the expected + // error. + err = s.forward(packet) + if !reflect.DeepEqual(err, test.expectedErr) { + t.Fatalf("expected: %v, got: %v", + test.expectedErr, err) + } + + // Ensure that no circuits were opened. + if s.circuits.NumOpen() > 0 { + t.Fatal("do not expect any open circuits") + } + }) + } +} + +// TestCheckCircularForward tests the error returned by checkCircularForward +// in cases where we allow and disallow same channel circular forwards. +func TestCheckCircularForward(t *testing.T) { + tests := []struct { + name string + + // allowCircular determines whether we should allow circular + // forwards. + allowCircular bool + + // incomingLink is the link that the htlc arrived on. + incomingLink lnwire.ShortChannelID + + // outgoingLink is the link that the htlc forward + // is destined to leave on. + outgoingLink lnwire.ShortChannelID + + // expectedErr is the error we expect to be returned. + expectedErr *LinkError + }{ + { + name: "not circular, allowed in config", + allowCircular: true, + incomingLink: lnwire.NewShortChanIDFromInt(123), + outgoingLink: lnwire.NewShortChanIDFromInt(321), + expectedErr: nil, + }, + { + name: "not circular, not allowed in config", + allowCircular: false, + incomingLink: lnwire.NewShortChanIDFromInt(123), + outgoingLink: lnwire.NewShortChanIDFromInt(321), + expectedErr: nil, + }, + { + name: "circular, allowed in config", + allowCircular: true, + incomingLink: lnwire.NewShortChanIDFromInt(123), + outgoingLink: lnwire.NewShortChanIDFromInt(123), + expectedErr: nil, + }, + { + name: "circular, not allowed in config", + allowCircular: false, + incomingLink: lnwire.NewShortChanIDFromInt(123), + outgoingLink: lnwire.NewShortChanIDFromInt(123), + expectedErr: NewDetailedLinkError( + lnwire.NewTemporaryChannelFailure(nil), + FailureDetailCircularRoute, + ), + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + // Check for a circular forward, the hash passed can + // be nil because it is only used for logging. + err := checkCircularForward( + test.incomingLink, test.outgoingLink, + test.allowCircular, lntypes.Hash{}, + ) + if !reflect.DeepEqual(err, test.expectedErr) { + t.Fatalf("expected: %v, got: %v", + test.expectedErr, err) + } + }) + } +} + // TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes // along, then we won't attempt to froward it down al ink that isn't yet able // to forward any HTLC's.