diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index e744c75fc..0857b650c 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -12,6 +12,7 @@ import ( "runtime" "sync" "testing" + "testing/quick" "time" "github.com/btcsuite/btcd/btcec/v2" @@ -19,7 +20,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" - "github.com/go-errors/errors" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" @@ -125,7 +125,7 @@ func createInterceptorFunc(prefix, receiver string, messages []expectedMessage, if messageChanID == chanID { if len(expectToReceive) == 0 { - return false, errors.Errorf("%v received "+ + return false, fmt.Errorf("%v received "+ "unexpected message out of range: %v", receiver, m.MsgType()) } @@ -134,9 +134,13 @@ func createInterceptorFunc(prefix, receiver string, messages []expectedMessage, expectToReceive = expectToReceive[1:] if expectedMessage.message.MsgType() != m.MsgType() { - return false, errors.Errorf("%v received wrong message: \n"+ - "real: %v\nexpected: %v", receiver, m.MsgType(), - expectedMessage.message.MsgType()) + return false, fmt.Errorf( + "%v received wrong message: \n"+ + "real: %v\nexpected: %v", + receiver, + m.MsgType(), + expectedMessage.message.MsgType(), + ) } if debug { @@ -721,11 +725,10 @@ func TestChannelLinkCancelFullCommitment(t *testing.T) { } } -// TestExitNodeTimelockPayloadMismatch tests that when an exit node receives an -// incoming HTLC, if the time lock encoded in the payload of the forwarded HTLC -// doesn't match the expected payment value, then the HTLC will be rejected -// with the appropriate error. -func TestExitNodeTimelockPayloadMismatch(t *testing.T) { +// TestExitNodeHLTCTimelockExceedsPayload tests that when an exit node receives +// an incoming HTLC, if the timelock of the incoming HTLC is greater than or +// equal to the timelock encoded in the payload, then the HTLC will be accepted. +func TestExitNodeHTLCTimelockExceedsPayload(t *testing.T) { t.Parallel() channels, _, err := createClusterChannels( @@ -733,35 +736,75 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) { ) require.NoError(t, err, "unable to create channel") - n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, - channels.bobToCarol, channels.carolToBob, testStartingHeight) - if err := n.start(); err != nil { - t.Fatal(err) - } + n := newThreeHopNetwork( + t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, testStartingHeight, + ) + require.NoError(t, n.start()) t.Cleanup(n.stop) const amount = btcutil.SatoshiPerBitcoin - htlcAmt, htlcExpiry, hops := generateHops(amount, - testStartingHeight, n.firstBobChannelLink) + htlcAmt, htlcExpiry, hops := generateHops( + amount, testStartingHeight, n.firstBobChannelLink, + ) // In order to exercise this case, we'll now _manually_ modify the - // per-hop payload for outgoing time lock to be the incorrect value. + // per-hop payload for outgoing time lock to be a compatible value that + // differs from the specified expiry. // The proper value of the outgoing CLTV should be the policy set by - // the receiving node, instead we set it to be a random value. - hops[0].FwdInfo.OutgoingCTLV = 500 + // the receiving node, instead we set it to be a value less than the + // incoming HTLC timelock. + hops[0].FwdInfo.OutgoingCTLV = htlcExpiry - 1 firstHop := n.firstBobChannelLink.ShortChanID() _, err = makePayment( n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt, htlcExpiry, ).Wait(30 * time.Second) - if err == nil { - t.Fatalf("payment should have failed but didn't") - } + require.NoError(t, err, "payment should have succeeded but didn't") +} - rtErr, ok := err.(ClearTextError) - if !ok { - t.Fatalf("expected a ClearTextError, instead got: %T", err) - } +// TestExitNodeTimelockPayloadExceedsHTLC tests that when an exit node receives +// an incoming HTLC, if the timelock encoded in the payload of the forwarded +// HTLC exceeds the timelock on the incoming HTLC, then the HTLC will be +// rejected with the appropriate error. +func TestExitNodeTimelockPayloadExceedsHTLC(t *testing.T) { + t.Parallel() + + channels, _, err := createClusterChannels( + t, btcutil.SatoshiPerBitcoin*5, btcutil.SatoshiPerBitcoin*5, + ) + require.NoError(t, err, "unable to create channel") + + n := newThreeHopNetwork( + t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, testStartingHeight, + ) + require.NoError(t, n.start()) + t.Cleanup(n.stop) + + const amount = btcutil.SatoshiPerBitcoin + htlcAmt, htlcExpiry, hops := generateHops( + amount, testStartingHeight, n.firstBobChannelLink, + ) + + // In order to exercise this case, we'll now _manually_ modify the + // per-hop payload for outgoing time lock to be the incorrect value. + // The proper value of the outgoing CLTV should be the policy set by + // the receiving node, instead we set it to be a value greater than the + // incoming HTLC timelock. + hops[0].FwdInfo.OutgoingCTLV = htlcExpiry + 1 + firstHop := n.firstBobChannelLink.ShortChanID() + _, err = makePayment( + n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt, + htlcExpiry, + ).Wait(30 * time.Second) + require.NotNil(t, err, "payment should have failed but didn't") + + rtErr := &ForwardingError{} + require.ErrorAs( + t, err, &rtErr, "expected a ClearTextError, instead got: %T", + err, + ) switch rtErr.WireMessage().(type) { case *lnwire.FailFinalIncorrectCltvExpiry: @@ -771,43 +814,95 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) { } } -// TestExitNodeAmountPayloadMismatch tests that when an exit node receives an -// incoming HTLC, if the amount encoded in the onion payload of the forwarded -// HTLC doesn't match the expected payment value, then the HTLC will be -// rejected. -func TestExitNodeAmountPayloadMismatch(t *testing.T) { +// TestExitNodeHTLCUnderpaysPayloadAmount tests that when an exit node receives +// an incoming HTLC, if the amount offered in the HTLC is less than the amount +// encoded in the onion payload then the HTLC will be rejected with the +// appropriate error. +func TestExitNodeHTLCUnderpaysPayloadAmount(t *testing.T) { t.Parallel() - channels, _, err := createClusterChannels( - t, btcutil.SatoshiPerBitcoin*5, btcutil.SatoshiPerBitcoin*5, - ) - require.NoError(t, err, "unable to create channel") + f := func(underpaymentRand uint64) bool { + underpayment := lnwire.MilliSatoshi( + underpaymentRand%(btcutil.SatoshiPerBitcoin-1) + 1, + ) + channels, _, err := createClusterChannels( + t, btcutil.SatoshiPerBitcoin*5, + btcutil.SatoshiPerBitcoin*5, + ) + require.NoError(t, err, "unable to create channel") - n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, - channels.bobToCarol, channels.carolToBob, testStartingHeight) - if err := n.start(); err != nil { - t.Fatal(err) + n := newThreeHopNetwork( + t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, + testStartingHeight, + ) + require.NoError(t, n.start()) + t.Cleanup(n.stop) + + const amount = btcutil.SatoshiPerBitcoin + htlcAmt, htlcExpiry, hops := generateHops( + amount, testStartingHeight, n.firstBobChannelLink, + ) + + // In order to exercise this case, we'll now _manually_ modify + // the per-hop payload for amount to be the incorrect value. + // The acceptable values of the amount to forward should be less + // than the incoming HTLC value. + hops[0].FwdInfo.AmountToForward = amount + underpayment + firstHop := n.firstBobChannelLink.ShortChanID() + _, err = makePayment( + n.aliceServer, n.bobServer, firstHop, hops, amount, + htlcAmt, htlcExpiry, + ).Wait(30 * time.Second) + assertFailureCode(t, err, lnwire.CodeFinalIncorrectHtlcAmount) + + return err != nil } - t.Cleanup(n.stop) + err := quick.Check(f, &quick.Config{MaxCount: 20}) + require.NoError(t, err, "payment should have failed but didn't") +} - const amount = btcutil.SatoshiPerBitcoin - htlcAmt, htlcExpiry, hops := generateHops(amount, testStartingHeight, - n.firstBobChannelLink) +// TestExitNodeHTLCExceedsAmountPayload tests that when an exit node receives an +// incoming HTLC, if the amount encoded in the onion payload of the forwarded +// HTLC is lower than the incoming HTLC value, then the HTLC will be accepted. +func TestExitNodeHTLCExceedsAmountPayload(t *testing.T) { + t.Parallel() - // In order to exercise this case, we'll now _manually_ modify the - // per-hop payload for amount to be the incorrect value. The proper - // value of the amount to forward should be the amount that the - // receiving node expects to receive. - hops[0].FwdInfo.AmountToForward = 1 - firstHop := n.firstBobChannelLink.ShortChanID() - _, err = makePayment( - n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt, - htlcExpiry, - ).Wait(30 * time.Second) - if err == nil { - t.Fatalf("payment should have failed but didn't") + f := func(overpaymentRand uint64) bool { + overpayment := lnwire.MilliSatoshi( + overpaymentRand%(btcutil.SatoshiPerBitcoin-1) + 1, + ) + channels, _, err := createClusterChannels( + t, btcutil.SatoshiPerBitcoin*5, + btcutil.SatoshiPerBitcoin*5, + ) + require.NoError(t, err, "unable to create channel") + + n := newThreeHopNetwork(t, channels.aliceToBob, + channels.bobToAlice, channels.bobToCarol, + channels.carolToBob, testStartingHeight) + require.NoError(t, n.start()) + t.Cleanup(n.stop) + + const amount = btcutil.SatoshiPerBitcoin + htlcAmt, htlcExpiry, hops := generateHops(amount, + testStartingHeight, n.firstBobChannelLink) + + // In order to exercise this case, we'll now _manually_ modify + // the per-hop payload for amount to be the incorrect value. + // The acceptable values of the amount to forward should be + // lower than the incoming HTLC value. + hops[0].FwdInfo.AmountToForward = amount - overpayment + firstHop := n.firstBobChannelLink.ShortChanID() + _, err = makePayment( + n.aliceServer, n.bobServer, firstHop, hops, amount, + htlcAmt, htlcExpiry, + ).Wait(30 * time.Second) + + return err == nil } - assertFailureCode(t, err, lnwire.CodeFinalIncorrectHtlcAmount) + err := quick.Check(f, &quick.Config{MaxCount: 20}) + require.NoError(t, err, "payment should have succeeded but didn't") } // TestLinkForwardTimelockPolicyMismatch tests that if a node is an @@ -3512,25 +3607,37 @@ func TestChannelRetransmission(t *testing.T) { // bandwidth of htlc links hasn't been changed. invoice, err = receiver.registry.LookupInvoice(rhash) if err != nil { - err = errors.Errorf("unable to get invoice: %v", err) + err = fmt.Errorf( + "unable to get invoice: %w", err, + ) continue } if invoice.State != invpkg.ContractSettled { - err = errors.Errorf("alice invoice haven't been settled") + err = fmt.Errorf( + "alice invoice haven't been settled", + ) continue } aliceExpectedBandwidth := aliceBandwidthBefore - htlcAmt if aliceExpectedBandwidth != n.aliceChannelLink.Bandwidth() { - err = errors.Errorf("expected alice to have %v, instead has %v", - aliceExpectedBandwidth, n.aliceChannelLink.Bandwidth()) + err = fmt.Errorf( + "expected alice to have %v,"+ + " instead has %v", + aliceExpectedBandwidth, + n.aliceChannelLink.Bandwidth(), + ) continue } bobExpectedBandwidth := bobBandwidthBefore + htlcAmt if bobExpectedBandwidth != n.firstBobChannelLink.Bandwidth() { - err = errors.Errorf("expected bob to have %v, instead has %v", - bobExpectedBandwidth, n.firstBobChannelLink.Bandwidth()) + err = fmt.Errorf( + "expected bob to have %v,"+ + " instead has %v", + bobExpectedBandwidth, + n.firstBobChannelLink.Bandwidth(), + ) continue } @@ -5517,8 +5624,10 @@ func TestExpectedFee(t *testing.T) { } fee := ExpectedFee(f, test.htlcAmt) if fee != test.expected { - t.Errorf("expected fee to be (%v), instead got (%v)", test.expected, - fee) + t.Errorf( + "expected fee to be (%v), instead got (%v)", + test.expected, fee, + ) } } } diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 866ad540f..e3ee79d32 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "math" "testing" + "testing/quick" "time" "github.com/lightningnetwork/lnd/amp" @@ -925,6 +926,84 @@ func TestMppPayment(t *testing.T) { } } +// TestMppPaymentWithOverpayment tests settling of an invoice with multiple +// partial payments. It covers the case where the mpp overpays what is in the +// invoice. +func TestMppPaymentWithOverpayment(t *testing.T) { + t.Parallel() + defer timeout()() + + f := func(overpayment_rand uint64) bool { + ctx := newTestContext(t, nil) + + // Add the invoice. + _, err := ctx.registry.AddInvoice( + testInvoice, testInvoicePaymentHash, + ) + if err != nil { + t.Fatal(err) + } + + mppPayload := &mockPayload{ + mpp: record.NewMPP(testInvoiceAmt, [32]byte{}), + } + + // We constrain overpayment amount to be [1,1000]. + overpayment := lnwire.MilliSatoshi((overpayment_rand % 999) + 1) + + // Send htlc 1. + hodlChan1 := make(chan interface{}, 1) + resolution, err := ctx.registry.NotifyExitHopHtlc( + testInvoicePaymentHash, testInvoice.Terms.Value/2, + testHtlcExpiry, testCurrentHeight, getCircuitKey(11), + hodlChan1, mppPayload, + ) + if err != nil { + t.Fatal(err) + } + if resolution != nil { + t.Fatal("expected no direct resolution") + } + + // Send htlc 2. + hodlChan2 := make(chan interface{}, 1) + resolution, err = ctx.registry.NotifyExitHopHtlc( + testInvoicePaymentHash, + testInvoice.Terms.Value/2+overpayment, testHtlcExpiry, + testCurrentHeight, getCircuitKey(12), hodlChan2, + mppPayload, + ) + if err != nil { + t.Fatal(err) + } + settleResolution, ok := + resolution.(*invpkg.HtlcSettleResolution) + if !ok { + t.Fatalf("expected settle resolution, got: %T", + resolution) + } + if settleResolution.Outcome != invpkg.ResultSettled { + t.Fatalf("expected result settled, got: %v", + settleResolution.Outcome) + } + + // Check that settled amount is equal to the sum of values of + // the htlcs 1 and 2. + inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash) + if err != nil { + t.Fatal(err) + } + if inv.State != invpkg.ContractSettled { + t.Fatal("expected invoice to be settled") + } + + return inv.AmtPaid == testInvoice.Terms.Value+overpayment + } + if err := quick.Check(f, &quick.Config{MaxCount: 50}); err != nil { + t.Fatalf("amount incorrect: %v", err) + } +} + // Tests that invoices are canceled after expiration. func TestInvoiceExpiryWithRegistry(t *testing.T) { t.Parallel()