diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 1edd753cd..55bbd4931 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -37,6 +37,7 @@ func TestSettleInvoice(t *testing.T) { require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice. + testInvoice := newInvoice(t, false) addIdx, err := ctx.registry.AddInvoice( testInvoice, testInvoicePaymentHash, ) @@ -220,7 +221,7 @@ func testCancelInvoice(t *testing.T, gc bool) { require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice. - amt := lnwire.MilliSatoshi(100000) + testInvoice := newInvoice(t, false) _, err = ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) if err != nil { t.Fatal(err) @@ -298,8 +299,8 @@ func testCancelInvoice(t *testing.T, gc bool) { // result in a cancel resolution. hodlChan := make(chan interface{}) resolution, err := ctx.registry.NotifyExitHopHtlc( - testInvoicePaymentHash, amt, testHtlcExpiry, testCurrentHeight, - getCircuitKey(0), hodlChan, testPayload, + testInvoicePaymentHash, testInvoiceAmt, testHtlcExpiry, + testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { t.Fatal("expected settlement of a canceled invoice to succeed") @@ -379,7 +380,8 @@ func TestSettleHoldInvoice(t *testing.T) { require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice. - _, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash) + invoice := newInvoice(t, true) + _, err = registry.AddInvoice(invoice, testInvoicePaymentHash) if err != nil { t.Fatal(err) } @@ -536,7 +538,8 @@ func TestCancelHoldInvoice(t *testing.T) { }) // Add the invoice. - _, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash) + invoice := newInvoice(t, true) + _, err = registry.AddInvoice(invoice, testInvoicePaymentHash) if err != nil { t.Fatal(err) } @@ -840,6 +843,7 @@ func TestMppPayment(t *testing.T) { ctx := newTestContext(t, nil) // Add the invoice. + testInvoice := newInvoice(t, false) _, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) if err != nil { t.Fatal(err) @@ -936,9 +940,9 @@ func TestMppPaymentWithOverpayment(t *testing.T) { ctx := newTestContext(t, nil) // Add the invoice. - invoice := *testInvoice + testInvoice := newInvoice(t, false) _, err := ctx.registry.AddInvoice( - &invoice, testInvoicePaymentHash, + testInvoice, testInvoicePaymentHash, ) if err != nil { t.Fatal(err) @@ -954,7 +958,7 @@ func TestMppPaymentWithOverpayment(t *testing.T) { // Send htlc 1. hodlChan1 := make(chan interface{}, 1) resolution, err := ctx.registry.NotifyExitHopHtlc( - testInvoicePaymentHash, invoice.Terms.Value/2, + testInvoicePaymentHash, testInvoice.Terms.Value/2, testHtlcExpiry, testCurrentHeight, getCircuitKey(11), hodlChan1, mppPayload, ) @@ -969,7 +973,7 @@ func TestMppPaymentWithOverpayment(t *testing.T) { hodlChan2 := make(chan interface{}, 1) resolution, err = ctx.registry.NotifyExitHopHtlc( testInvoicePaymentHash, - invoice.Terms.Value/2+overpayment, testHtlcExpiry, + testInvoice.Terms.Value/2+overpayment, testHtlcExpiry, testCurrentHeight, getCircuitKey(12), hodlChan2, mppPayload, ) @@ -997,7 +1001,7 @@ func TestMppPaymentWithOverpayment(t *testing.T) { t.Fatal("expected invoice to be settled") } - return inv.AmtPaid == invoice.Terms.Value+overpayment + return inv.AmtPaid == testInvoice.Terms.Value+overpayment } if err := quick.Check(f, nil); err != nil { t.Fatalf("amount incorrect: %v", err) @@ -1222,11 +1226,11 @@ func testHeightExpiryWithRegistry(t *testing.T, numParts int, settle bool) { // Add a hold invoice, we set a non-nil payment request so that this // invoice is not considered a keysend by the expiry watcher. - invoice := *testInvoice - invoice.HodlInvoice = true - invoice.PaymentRequest = []byte{1, 2, 3} + testInvoice := newInvoice(t, false) + testInvoice.HodlInvoice = true + testInvoice.PaymentRequest = []byte{1, 2, 3} - _, err := ctx.registry.AddInvoice(&invoice, testInvoicePaymentHash) + _, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) require.NoError(t, err) payLoad := testPayload @@ -1236,7 +1240,7 @@ func testHeightExpiryWithRegistry(t *testing.T, numParts int, settle bool) { } } - htlcAmt := invoice.Terms.Value / lnwire.MilliSatoshi(numParts) + htlcAmt := testInvoice.Terms.Value / lnwire.MilliSatoshi(numParts) hodlChan := make(chan interface{}, numParts) for i := 0; i < numParts; i++ { // We bump our expiry height for each htlc so that we can test @@ -1262,7 +1266,9 @@ func testHeightExpiryWithRegistry(t *testing.T, numParts int, settle bool) { // Now that we've added our htlc(s), we tick our test clock to our // invoice expiry time. We don't expect the invoice to be canceled // based on its expiry time now that we have active htlcs. - ctx.clock.SetTime(invoice.CreationDate.Add(invoice.Terms.Expiry + 1)) + ctx.clock.SetTime( + testInvoice.CreationDate.Add(testInvoice.Terms.Expiry + 1), + ) // The expiry watcher loop takes some time to process the new clock // time. We mine the block before our expiry height, our mock will block @@ -1326,10 +1332,9 @@ func TestMultipleSetHeightExpiry(t *testing.T) { ctx := newTestContext(t, nil) // Add a hold invoice. - invoice := *testInvoice - invoice.HodlInvoice = true + testInvoice := newInvoice(t, true) - _, err := ctx.registry.AddInvoice(&invoice, testInvoicePaymentHash) + _, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) require.NoError(t, err) mppPayload := &mockPayload{ @@ -1339,7 +1344,7 @@ func TestMultipleSetHeightExpiry(t *testing.T) { // Send htlc 1. hodlChan1 := make(chan interface{}, 1) resolution, err := ctx.registry.NotifyExitHopHtlc( - testInvoicePaymentHash, invoice.Terms.Value/2, + testInvoicePaymentHash, testInvoice.Terms.Value/2, testHtlcExpiry, testCurrentHeight, getCircuitKey(10), hodlChan1, mppPayload, ) @@ -1369,7 +1374,7 @@ func TestMultipleSetHeightExpiry(t *testing.T) { // Send htlc 2. hodlChan2 := make(chan interface{}, 1) resolution, err = ctx.registry.NotifyExitHopHtlc( - testInvoicePaymentHash, invoice.Terms.Value/2, expiry, + testInvoicePaymentHash, testInvoice.Terms.Value/2, expiry, testCurrentHeight, getCircuitKey(11), hodlChan2, mppPayload, ) require.NoError(t, err) @@ -1378,7 +1383,7 @@ func TestMultipleSetHeightExpiry(t *testing.T) { // Send htlc 3. hodlChan3 := make(chan interface{}, 1) resolution, err = ctx.registry.NotifyExitHopHtlc( - testInvoicePaymentHash, invoice.Terms.Value/2, expiry, + testInvoicePaymentHash, testInvoice.Terms.Value/2, expiry, testCurrentHeight, getCircuitKey(12), hodlChan3, mppPayload, ) require.NoError(t, err) @@ -1462,7 +1467,7 @@ func TestSettleInvoicePaymentAddrRequired(t *testing.T) { // information, so it should be forced to the updateLegacy path then // fail as a required feature bit exists. resolution, err := ctx.registry.NotifyExitHopHtlc( - testInvoicePaymentHash, testInvoice.Terms.Value, + testInvoicePaymentHash, testPayAddrReqInvoice.Terms.Value, uint32(testCurrentHeight)+testInvoiceCltvDelta-1, testCurrentHeight, getCircuitKey(10), hodlChan, testPayload, ) @@ -1555,7 +1560,7 @@ func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) { t.Fatalf("expected state ContractOpen, but got %v", update.State) } - if update.AmtPaid != testInvoice.Terms.Value { + if update.AmtPaid != testPayAddrOptionalInvoice.Terms.Value { t.Fatal("invoice AmtPaid incorrect") } case <-time.After(testTimeout): diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index 8bbcdc3ad..53f85d1a5 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -133,15 +133,6 @@ var ( var ( testInvoiceAmt = lnwire.MilliSatoshi(100000) - testInvoice = &invpkg.Invoice{ - Terms: invpkg.ContractTerm{ - PaymentPreimage: &testInvoicePreimage, - Value: testInvoiceAmt, - Expiry: time.Hour, - Features: testFeatures, - }, - CreationDate: testInvoiceCreationDate, - } testPayAddrReqInvoice = &invpkg.Invoice{ Terms: invpkg.ContractTerm{ @@ -174,16 +165,6 @@ var ( }, CreationDate: testInvoiceCreationDate, } - - testHodlInvoice = &invpkg.Invoice{ - Terms: invpkg.ContractTerm{ - Value: testInvoiceAmt, - Expiry: time.Hour, - Features: testFeatures, - }, - CreationDate: testInvoiceCreationDate, - HodlInvoice: true, - } ) func newTestChannelDB(t *testing.T, clock clock.Clock) (*channeldb.DB, error) { @@ -275,6 +256,37 @@ func getCircuitKey(htlcID uint64) invpkg.CircuitKey { } } +// newInvoice returns an invoice that can be used for testing, using the +// constant values defined above (deep copied if necessary). +// +// Note that this invoice *does not* have a payment address set. It will +// create a regular invoice with a preimage is hodl is false, and a hodl +// invoice with no preimage otherwise. +func newInvoice(t *testing.T, hodl bool) *invpkg.Invoice { + invoice := &invpkg.Invoice{ + Terms: invpkg.ContractTerm{ + Value: testInvoiceAmt, + Expiry: time.Hour, + Features: testFeatures.Clone(), + }, + CreationDate: testInvoiceCreationDate, + } + + // If creating a hodl invoice, we don't include a preimage. + if hodl { + invoice.HodlInvoice = true + return invoice + } + + preimage, err := lntypes.MakePreimage( + testInvoicePreimage[:], + ) + require.NoError(t, err) + invoice.Terms.PaymentPreimage = &preimage + + return invoice +} + // timeout implements a test level timeout. func timeout() func() { done := make(chan struct{})