diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index a4d8f1bb8..f1d7dc479 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -368,7 +368,10 @@ func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason) lnwire.OpaqueReason { return reason +} +func (o *mockObfuscator) EncryptMalformedError(reason lnwire.OpaqueReason) lnwire.OpaqueReason { + return reason } // mockDeobfuscator mock implementation of the failure deobfuscator which @@ -400,6 +403,8 @@ type mockIteratorDecoder struct { mu sync.RWMutex responses map[[32]byte][]DecodeHopIteratorResponse + + decodeFail bool } func newMockIteratorDecoder() *mockIteratorDecoder { @@ -451,6 +456,10 @@ func (p *mockIteratorDecoder) DecodeHopIterators(id []byte, req.OnionReader, req.RHash, req.IncomingCltv, ) + if p.decodeFail { + failcode = lnwire.CodeTemporaryChannelFailure + } + resp := DecodeHopIteratorResponse{ HopIterator: iterator, FailCode: failcode, diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 59b0e4edc..0aa435b06 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -2033,3 +2033,76 @@ func TestMultiHopPaymentForwardingEvents(t *testing.T) { } } } + +// TestUpdateFailMalformedHTLCErrorConversion tests that we're able to properly +// convert malformed HTLC errors that originate at the direct link, as well as +// during multi-hop HTLC forwarding. +func TestUpdateFailMalformedHTLCErrorConversion(t *testing.T) { + t.Parallel() + + // First, we'll create our traditional three hop network. + channels, cleanUp, _, err := createClusterChannels( + btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5, + ) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + n := newThreeHopNetwork( + t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, testStartingHeight, + ) + if err := n.start(); err != nil { + t.Fatalf("unable to start three hop network: %v", err) + } + + assertPaymentFailure := func(t *testing.T) { + // With the decoder modified, we'll now attempt to send a + // payment from Alice to carol. + finalAmt := lnwire.NewMSatFromSatoshis(100000) + htlcAmt, totalTimelock, hops := generateHops( + finalAmt, testStartingHeight, n.firstBobChannelLink, + n.carolChannelLink, + ) + firstHop := n.firstBobChannelLink.ShortChanID() + _, err = makePayment( + n.aliceServer, n.carolServer, firstHop, hops, finalAmt, + htlcAmt, totalTimelock, + ).Wait(30 * time.Second) + + // The payment should fail as Carol is unable to decode the + // onion blob sent to her. + if err == nil { + t.Fatalf("unable to send payment: %v", err) + } + + fwdingErr := err.(*ForwardingError) + failureMsg := fwdingErr.FailureMessage + if _, ok := failureMsg.(*lnwire.FailTemporaryChannelFailure); !ok { + t.Fatalf("expected temp chan failure instead got: %v", + fwdingErr.FailureMessage) + } + } + + t.Run("multi-hop error conversion", func(t *testing.T) { + // Now that we have our network up, we'll modify the hop + // iterator for the Bob <-> Carol channel to fail to decode in + // order to simulate either a replay attack or an issue + // decoding the onion. + n.carolOnionDecoder.decodeFail = true + + assertPaymentFailure(t) + }) + + t.Run("direct channel error conversion", func(t *testing.T) { + // Similar to the above test case, we'll now make the Alice <-> + // Bob link always fail to decode an onion. This differs from + // the above test case in that there's no encryption on the + // error at all since Alice will directly receive a + // UpdateFailMalformedHTLC message. + n.bobOnionDecoder.decodeFail = true + + assertPaymentFailure(t) + }) +} diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 89762d755..2ebc1e388 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -589,15 +589,18 @@ func generateRoute(hops ...ForwardingInfo) ([lnwire.OnionPacketSize]byte, error) // threeHopNetwork is used for managing the created cluster of 3 hops. type threeHopNetwork struct { - aliceServer *mockServer - aliceChannelLink *channelLink + aliceServer *mockServer + aliceChannelLink *channelLink + aliceOnionDecoder *mockIteratorDecoder bobServer *mockServer firstBobChannelLink *channelLink secondBobChannelLink *channelLink + bobOnionDecoder *mockIteratorDecoder - carolServer *mockServer - carolChannelLink *channelLink + carolServer *mockServer + carolChannelLink *channelLink + carolOnionDecoder *mockIteratorDecoder hopNetwork } @@ -948,15 +951,18 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, } return &threeHopNetwork{ - aliceServer: aliceServer, - aliceChannelLink: aliceChannelLink.(*channelLink), + aliceServer: aliceServer, + aliceChannelLink: aliceChannelLink.(*channelLink), + aliceOnionDecoder: aliceDecoder, bobServer: bobServer, firstBobChannelLink: firstBobChannelLink.(*channelLink), secondBobChannelLink: secondBobChannelLink.(*channelLink), + bobOnionDecoder: bobDecoder, - carolServer: carolServer, - carolChannelLink: carolChannelLink.(*channelLink), + carolServer: carolServer, + carolChannelLink: carolChannelLink.(*channelLink), + carolOnionDecoder: carolDecoder, hopNetwork: *hopNetwork, }