diff --git a/channeldb/error.go b/channeldb/error.go index 97e06a14f..f3b89bbab 100644 --- a/channeldb/error.go +++ b/channeldb/error.go @@ -31,30 +31,6 @@ var ( // created. ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas") - // ErrInvoiceNotFound is returned when a targeted invoice can't be - // found. - ErrInvoiceNotFound = fmt.Errorf("unable to locate invoice") - - // ErrNoInvoicesCreated is returned when we don't have invoices in - // our database to return. - ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices") - - // ErrDuplicateInvoice is returned when an invoice with the target - // payment hash already exists. - ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists") - - // ErrDuplicatePayAddr is returned when an invoice with the target - // payment addr already exists. - ErrDuplicatePayAddr = fmt.Errorf("invoice with payemnt addr already exists") - - // ErrInvRefEquivocation is returned when an InvoiceRef targets - // multiple, distinct invoices. - ErrInvRefEquivocation = errors.New("inv ref matches multiple invoices") - - // ErrNoPaymentsCreated is returned when bucket of payments hasn't been - // created. - ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments") - // ErrNodeNotFound is returned when node bucket exists, but node with // specific identity can't be found. ErrNodeNotFound = fmt.Errorf("link node with target identity not found") diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index b869f5838..38674b6be 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -11,6 +11,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/feature" + invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -31,7 +32,7 @@ var ( testNow = time.Unix(1, 0) ) -func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { +func randInvoice(value lnwire.MilliSatoshi) (*invpkg.Invoice, error) { var ( pre lntypes.Preimage payAddr [32]byte @@ -43,24 +44,24 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { return nil, err } - i := &Invoice{ + i := &invpkg.Invoice{ CreationDate: testNow, - Terms: ContractTerm{ + Terms: invpkg.ContractTerm{ Expiry: 4000, PaymentPreimage: &pre, PaymentAddr: payAddr, Value: value, Features: emptyFeatures, }, - Htlcs: map[models.CircuitKey]*InvoiceHTLC{}, - AMPState: map[SetID]InvoiceStateAMP{}, + Htlcs: map[models.CircuitKey]*invpkg.InvoiceHTLC{}, + AMPState: map[invpkg.SetID]invpkg.InvoiceStateAMP{}, } i.Memo = []byte("memo") // Create a random byte slice of MaxPaymentRequestSize bytes to be used // as a dummy paymentrequest, and determine if it should be set based // on one of the random bytes. - var r [MaxPaymentRequestSize]byte + var r [invpkg.MaxPaymentRequestSize]byte if _, err := rand.Read(r[:]); err != nil { return nil, err } @@ -74,15 +75,15 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { } // settleTestInvoice settles a test invoice. -func settleTestInvoice(invoice *Invoice, settleIndex uint64) { +func settleTestInvoice(invoice *invpkg.Invoice, settleIndex uint64) { invoice.SettleDate = testNow invoice.AmtPaid = invoice.Terms.Value - invoice.State = ContractSettled - invoice.Htlcs[models.CircuitKey{}] = &InvoiceHTLC{ + invoice.State = invpkg.ContractSettled + invoice.Htlcs[models.CircuitKey{}] = &invpkg.InvoiceHTLC{ Amt: invoice.Terms.Value, AcceptTime: testNow, ResolveTime: testNow, - State: HtlcStateSettled, + State: invpkg.HtlcStateSettled, CustomRecords: make(record.CustomSet), } invoice.SettleIndex = settleIndex @@ -91,23 +92,23 @@ func settleTestInvoice(invoice *Invoice, settleIndex uint64) { // Tests that pending invoices are those which are either in ContractOpen or // in ContractAccepted state. func TestInvoiceIsPending(t *testing.T) { - contractStates := []ContractState{ - ContractOpen, ContractSettled, ContractCanceled, ContractAccepted, + contractStates := []invpkg.ContractState{ + invpkg.ContractOpen, invpkg.ContractSettled, + invpkg.ContractCanceled, invpkg.ContractAccepted, } for _, state := range contractStates { - invoice := Invoice{ + invoice := invpkg.Invoice{ State: state, } - // We expect that an invoice is pending if it's either in ContractOpen - // or ContractAccepted state. - pending := (state == ContractOpen || state == ContractAccepted) + // We expect that an invoice is pending if it's either in + // ContractOpen or ContractAccepted state. + open := invpkg.ContractOpen + accepted := invpkg.ContractAccepted + pending := (state == open || state == accepted) - if invoice.IsPending() != pending { - t.Fatalf("expected pending: %v, got: %v, invoice: %v", - pending, invoice.IsPending(), invoice) - } + require.Equal(t, pending, invoice.IsPending()) } } @@ -164,16 +165,16 @@ func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { var ( payHash lntypes.Hash payAddr *[32]byte - ref InvoiceRef + ref invpkg.InvoiceRef ) switch { case test.queryPayHash && test.queryPayAddr: payHash = invPayHash payAddr = &fakeInvoice.Terms.PaymentAddr - ref = InvoiceRefByHashAndAddr(payHash, *payAddr) + ref = invpkg.InvoiceRefByHashAndAddr(payHash, *payAddr) case test.queryPayHash: payHash = invPayHash - ref = InvoiceRefByHash(payHash) + ref = invpkg.InvoiceRefByHash(payHash) } // Add the invoice to the database, this should succeed as there aren't @@ -188,9 +189,7 @@ func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { // identical to the one created above. dbInvoice, err := db.LookupInvoice(ref) if !test.queryPayAddr && !test.queryPayHash { - if err != ErrInvoiceNotFound { - t.Fatalf("invoice should not exist: %v", err) - } + require.ErrorIs(t, err, invpkg.ErrInvoiceNotFound) return } @@ -215,7 +214,7 @@ func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { require.NoError(t, err, "unable to settle invoice") dbInvoice2, err := db.LookupInvoice(ref) require.NoError(t, err, "unable to fetch invoice") - if dbInvoice2.State != ContractSettled { + if dbInvoice2.State != invpkg.ContractSettled { t.Fatalf("invoice should now be settled but isn't") } if dbInvoice2.SettleDate.IsZero() { @@ -235,24 +234,20 @@ func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { // Attempt to insert generated above again, this should fail as // duplicates are rejected by the processing logic. - if _, err := db.AddInvoice(fakeInvoice, payHash); err != ErrDuplicateInvoice { - t.Fatalf("invoice insertion should fail due to duplication, "+ - "instead %v", err) - } + _, err = db.AddInvoice(fakeInvoice, payHash) + require.ErrorIs(t, err, invpkg.ErrDuplicateInvoice) // Attempt to look up a non-existent invoice, this should also fail but // with a "not found" error. var fakeHash [32]byte - fakeRef := InvoiceRefByHash(fakeHash) + fakeRef := invpkg.InvoiceRefByHash(fakeHash) _, err = db.LookupInvoice(fakeRef) - if err != ErrInvoiceNotFound { - t.Fatalf("lookup should have failed, instead %v", err) - } + require.ErrorIs(t, err, invpkg.ErrInvoiceNotFound) // Add 10 random invoices. const numInvoices = 10 amt := lnwire.NewMSatFromSatoshis(1000) - invoices := make([]*Invoice, numInvoices+1) + invoices := make([]*invpkg.Invoice, numInvoices+1) invoices[0] = &dbInvoice2 for i := 1; i < len(invoices); i++ { invoice, err := randInvoice(amt) @@ -269,7 +264,7 @@ func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { } // Perform a scan to collect all the active invoices. - query := InvoiceQuery{ + query := invpkg.InvoiceQuery{ IndexOffset: 0, NumMaxInvoices: math.MaxUint64, PendingOnly: false, @@ -312,7 +307,7 @@ func TestAddDuplicatePayAddr(t *testing.T) { // Second insert should fail with duplicate payment addr. inv2Hash := invoice2.Terms.PaymentPreimage.Hash() _, err = db.AddInvoice(invoice2, inv2Hash) - require.Error(t, err, ErrDuplicatePayAddr) + require.Error(t, err, invpkg.ErrDuplicatePayAddr) } // TestAddDuplicateKeysendPayAddr asserts that we permit duplicate payment @@ -325,11 +320,11 @@ func TestAddDuplicateKeysendPayAddr(t *testing.T) { // Create two invoices with the same _blank_ payment addr. invoice1, err := randInvoice(1000) require.NoError(t, err) - invoice1.Terms.PaymentAddr = BlankPayAddr + invoice1.Terms.PaymentAddr = invpkg.BlankPayAddr invoice2, err := randInvoice(20000) require.NoError(t, err) - invoice2.Terms.PaymentAddr = BlankPayAddr + invoice2.Terms.PaymentAddr = invpkg.BlankPayAddr // Inserting both should succeed without a duplicate payment address // failure. @@ -345,12 +340,12 @@ func TestAddDuplicateKeysendPayAddr(t *testing.T) { // the lookup will fail if the hash and addr point to different // invoices, so if both succeed we can be assured they aren't included // in the payment address index. - ref1 := InvoiceRefByHashAndAddr(inv1Hash, BlankPayAddr) + ref1 := invpkg.InvoiceRefByHashAndAddr(inv1Hash, invpkg.BlankPayAddr) dbInv1, err := db.LookupInvoice(ref1) require.NoError(t, err) require.Equal(t, invoice1, &dbInv1) - ref2 := InvoiceRefByHashAndAddr(inv2Hash, BlankPayAddr) + ref2 := invpkg.InvoiceRefByHashAndAddr(inv2Hash, invpkg.BlankPayAddr) dbInv2, err := db.LookupInvoice(ref2) require.NoError(t, err) require.Equal(t, invoice2, &dbInv2) @@ -380,9 +375,9 @@ func TestFailInvoiceLookupMPPPayAddrOnly(t *testing.T) { // lookup should fail since we require the payment hash to match for // legacy/MPP invoices, as this guarantees that the preimage is valid // for the given HTLC. - ref := InvoiceRefByHashAndAddr(payHash, payAddr) + ref := invpkg.InvoiceRefByHashAndAddr(payHash, payAddr) _, err = db.LookupInvoice(ref) - require.Equal(t, ErrInvoiceNotFound, err) + require.Equal(t, invpkg.ErrInvoiceNotFound, err) } // TestInvRefEquivocation asserts that retrieving or updating an invoice using @@ -409,17 +404,19 @@ func TestInvRefEquivocation(t *testing.T) { // Now, query using invoice 1's payment address, but invoice 2's payment // hash. We expect an error since the invref points to multiple // invoices. - ref := InvoiceRefByHashAndAddr(inv2Hash, invoice1.Terms.PaymentAddr) + ref := invpkg.InvoiceRefByHashAndAddr( + inv2Hash, invoice1.Terms.PaymentAddr, + ) _, err = db.LookupInvoice(ref) - require.Error(t, err, ErrInvRefEquivocation) + require.Error(t, err, invpkg.ErrInvRefEquivocation) // The same error should be returned when updating an equivocating // reference. - nop := func(_ *Invoice) (*InvoiceUpdateDesc, error) { + nop := func(_ *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, error) { return nil, nil } _, err = db.UpdateInvoice(ref, nil, nop) - require.Error(t, err, ErrInvRefEquivocation) + require.Error(t, err, invpkg.ErrInvRefEquivocation) } // TestInvoiceCancelSingleHtlc tests that a single htlc can be canceled on the @@ -433,9 +430,9 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) { preimage := lntypes.Preimage{1} paymentHash := preimage.Hash() - testInvoice := &Invoice{ - Htlcs: map[models.CircuitKey]*InvoiceHTLC{}, - Terms: ContractTerm{ + testInvoice := &invpkg.Invoice{ + Htlcs: map[models.CircuitKey]*invpkg.InvoiceHTLC{}, + Terms: invpkg.ContractTerm{ Value: lnwire.NewMSatFromSatoshis(10000), Features: emptyFeatures, PaymentPreimage: &preimage, @@ -451,42 +448,50 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) { ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 4, } - htlc := HtlcAcceptDesc{ + htlc := invpkg.HtlcAcceptDesc{ Amt: 500, CustomRecords: make(record.CustomSet), } - ref := InvoiceRefByHash(paymentHash) - invoice, err := db.UpdateInvoice(ref, nil, - func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - return &InvoiceUpdateDesc{ - AddHtlcs: map[models.CircuitKey]*HtlcAcceptDesc{ - key: &htlc, - }, - }, nil - }) + callback := func( + invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, error) { + + htlcs := map[models.CircuitKey]*invpkg.HtlcAcceptDesc{ + key: &htlc, + } + + return &invpkg.InvoiceUpdateDesc{ + AddHtlcs: htlcs, + }, nil + } + + ref := invpkg.InvoiceRefByHash(paymentHash) + invoice, err := db.UpdateInvoice(ref, nil, callback) require.NoError(t, err, "unable to add invoice htlc") if len(invoice.Htlcs) != 1 { t.Fatalf("expected the htlc to be added") } - if invoice.Htlcs[key].State != HtlcStateAccepted { + if invoice.Htlcs[key].State != invpkg.HtlcStateAccepted { t.Fatalf("expected htlc in state accepted") } + callback = func( + invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, error) { + + return &invpkg.InvoiceUpdateDesc{ + CancelHtlcs: map[models.CircuitKey]struct{}{ + key: {}, + }, + }, nil + } + // Cancel the htlc again. - invoice, err = db.UpdateInvoice(ref, nil, - func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - return &InvoiceUpdateDesc{ - CancelHtlcs: map[models.CircuitKey]struct{}{ - key: {}, - }, - }, nil - }) + invoice, err = db.UpdateInvoice(ref, nil, callback) require.NoError(t, err, "unable to cancel htlc") if len(invoice.Htlcs) != 1 { t.Fatalf("expected the htlc to be present") } - if invoice.Htlcs[key].State != HtlcStateCanceled { + if invoice.Htlcs[key].State != invpkg.HtlcStateCanceled { t.Fatalf("expected htlc in state canceled") } } @@ -518,21 +523,26 @@ func TestInvoiceCancelSingleHtlcAMP(t *testing.T) { setID1 := &[32]byte{1} setID2 := &[32]byte{2} - ref := InvoiceRefByHashAndAddr(payHash, invoice.Terms.PaymentAddr) + ref := invpkg.InvoiceRefByHashAndAddr( + payHash, invoice.Terms.PaymentAddr, + ) // The first set ID with a single HTLC added. _, err = db.UpdateInvoice( - ref, (*SetID)(setID1), updateAcceptAMPHtlc(0, amt, setID1, true), + ref, (*invpkg.SetID)(setID1), + updateAcceptAMPHtlc(0, amt, setID1, true), ) require.Nil(t, err) // The second set ID with two HTLCs added. _, err = db.UpdateInvoice( - ref, (*SetID)(setID2), updateAcceptAMPHtlc(1, amt, setID2, true), + ref, (*invpkg.SetID)(setID2), + updateAcceptAMPHtlc(1, amt, setID2, true), ) require.Nil(t, err) dbInvoice, err := db.UpdateInvoice( - ref, (*SetID)(setID2), updateAcceptAMPHtlc(2, amt, setID2, true), + ref, (*invpkg.SetID)(setID2), + updateAcceptAMPHtlc(2, amt, setID2, true), ) require.Nil(t, err) @@ -540,18 +550,21 @@ func TestInvoiceCancelSingleHtlcAMP(t *testing.T) { // paid. require.Equal(t, dbInvoice.AmtPaid, amt*3) + callback := func( + invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, error) { + + return &invpkg.InvoiceUpdateDesc{ + CancelHtlcs: map[models.CircuitKey]struct{}{ + {HtlcID: 0}: {}, + }, + SetID: (*invpkg.SetID)(setID1), + }, nil + } + // Now we'll cancel a single invoice, and assert that the amount paid // is decremented, and the state for that HTLC set reflects that is // been cancelled. - _, err = db.UpdateInvoice(ref, (*SetID)(setID1), - func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - return &InvoiceUpdateDesc{ - CancelHtlcs: map[models.CircuitKey]struct{}{ - {HtlcID: 0}: {}, - }, - SetID: (*SetID)(setID1), - }, nil - }) + _, err = db.UpdateInvoice(ref, (*invpkg.SetID)(setID1), callback) require.NoError(t, err, "unable to cancel htlc") freshInvoice, err := db.LookupInvoice(ref) @@ -563,22 +576,27 @@ func TestInvoiceCancelSingleHtlcAMP(t *testing.T) { // The HTLC and AMP state should also show that only one HTLC set is // left. - invoice.State = ContractOpen + invoice.State = invpkg.ContractOpen invoice.AmtPaid = 2 * amt invoice.SettleDate = dbInvoice.SettleDate - invoice.Htlcs = map[models.CircuitKey]*InvoiceHTLC{ - {HtlcID: 0}: makeAMPInvoiceHTLC(amt, *setID1, payHash, &preimage), - {HtlcID: 1}: makeAMPInvoiceHTLC(amt, *setID2, payHash, &preimage), - {HtlcID: 2}: makeAMPInvoiceHTLC(amt, *setID2, payHash, &preimage), + + htlc0 := models.CircuitKey{HtlcID: 0} + htlc1 := models.CircuitKey{HtlcID: 1} + htlc2 := models.CircuitKey{HtlcID: 2} + + invoice.Htlcs = map[models.CircuitKey]*invpkg.InvoiceHTLC{ + htlc0: makeAMPInvoiceHTLC(amt, *setID1, payHash, &preimage), + htlc1: makeAMPInvoiceHTLC(amt, *setID2, payHash, &preimage), + htlc2: makeAMPInvoiceHTLC(amt, *setID2, payHash, &preimage), } - invoice.AMPState[*setID1] = InvoiceStateAMP{ - State: HtlcStateCanceled, + invoice.AMPState[*setID1] = invpkg.InvoiceStateAMP{ + State: invpkg.HtlcStateCanceled, InvoiceKeys: map[models.CircuitKey]struct{}{ {HtlcID: 0}: {}, }, } - invoice.AMPState[*setID2] = InvoiceStateAMP{ - State: HtlcStateAccepted, + invoice.AMPState[*setID2] = invpkg.InvoiceStateAMP{ + State: invpkg.HtlcStateAccepted, AmtPaid: amt * 2, InvoiceKeys: map[models.CircuitKey]struct{}{ {HtlcID: 1}: {}, @@ -586,21 +604,22 @@ func TestInvoiceCancelSingleHtlcAMP(t *testing.T) { }, } - htlc0 := models.CircuitKey{HtlcID: 0} - invoice.Htlcs[htlc0].State = HtlcStateCanceled + invoice.Htlcs[htlc0].State = invpkg.HtlcStateCanceled invoice.Htlcs[htlc0].ResolveTime = time.Unix(1, 0) require.Equal(t, invoice, dbInvoice) // Next, we'll cancel the _other_ HTLCs active, but we'll do them one // by one. - _, err = db.UpdateInvoice(ref, (*SetID)(setID2), - func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - return &InvoiceUpdateDesc{ + _, err = db.UpdateInvoice(ref, (*invpkg.SetID)(setID2), + func(invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, + error) { + + return &invpkg.InvoiceUpdateDesc{ CancelHtlcs: map[models.CircuitKey]struct{}{ {HtlcID: 1}: {}, }, - SetID: (*SetID)(setID2), + SetID: (*invpkg.SetID)(setID2), }, nil }) require.NoError(t, err, "unable to cancel htlc") @@ -609,29 +628,31 @@ func TestInvoiceCancelSingleHtlcAMP(t *testing.T) { require.Nil(t, err) dbInvoice = &freshInvoice - htlc1 := models.CircuitKey{HtlcID: 1} - invoice.Htlcs[htlc1].State = HtlcStateCanceled + invoice.Htlcs[htlc1].State = invpkg.HtlcStateCanceled invoice.Htlcs[htlc1].ResolveTime = time.Unix(1, 0) invoice.AmtPaid = amt ampState := invoice.AMPState[*setID2] - ampState.State = HtlcStateCanceled + ampState.State = invpkg.HtlcStateCanceled ampState.AmtPaid = amt invoice.AMPState[*setID2] = ampState require.Equal(t, invoice, dbInvoice) + callback = func( + invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, error) { + + return &invpkg.InvoiceUpdateDesc{ + CancelHtlcs: map[models.CircuitKey]struct{}{ + {HtlcID: 2}: {}, + }, + SetID: (*invpkg.SetID)(setID2), + }, nil + } + // Now we'll cancel the final HTLC, which should cause all the active // HTLCs to transition to the cancelled state. - _, err = db.UpdateInvoice(ref, (*SetID)(setID2), - func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - return &InvoiceUpdateDesc{ - CancelHtlcs: map[models.CircuitKey]struct{}{ - {HtlcID: 2}: {}, - }, - SetID: (*SetID)(setID2), - }, nil - }) + _, err = db.UpdateInvoice(ref, (*invpkg.SetID)(setID2), callback) require.NoError(t, err, "unable to cancel htlc") freshInvoice, err = db.LookupInvoice(ref) @@ -642,8 +663,7 @@ func TestInvoiceCancelSingleHtlcAMP(t *testing.T) { ampState.AmtPaid = 0 invoice.AMPState[*setID2] = ampState - htlc2 := models.CircuitKey{HtlcID: 2} - invoice.Htlcs[htlc2].State = HtlcStateCanceled + invoice.Htlcs[htlc2].State = invpkg.HtlcStateCanceled invoice.Htlcs[htlc2].ResolveTime = time.Unix(1, 0) invoice.AmtPaid = 0 @@ -666,7 +686,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { // into the database. const numInvoices = 20 amt := lnwire.NewMSatFromSatoshis(1000) - invoices := make([]Invoice, numInvoices) + invoices := make([]invpkg.Invoice, numInvoices) for i := 0; i < len(invoices); i++ { invoice, err := randInvoice(amt) if err != nil { @@ -688,7 +708,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { addQueries := []struct { sinceAddIndex uint64 - resp []Invoice + resp []invpkg.Invoice }{ // If we specify a value of zero, we shouldn't get any invoices // back. @@ -736,7 +756,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { _, err = db.InvoicesSettledSince(0) require.NoError(t, err) - var settledInvoices []Invoice + var settledInvoices []invpkg.Invoice var settleIndex uint64 = 1 // We'll now only settle the latter half of each of those invoices. for i := 10; i < len(invoices); i++ { @@ -744,7 +764,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { paymentHash := invoice.Terms.PaymentPreimage.Hash() - ref := InvoiceRefByHash(paymentHash) + ref := invpkg.InvoiceRefByHash(paymentHash) _, err := db.UpdateInvoice( ref, nil, getUpdateInvoice(invoice.Terms.Value), ) @@ -764,7 +784,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { settleQueries := []struct { sinceSettleIndex uint64 - resp []Invoice + resp []invpkg.Invoice }{ // If we specify a value of zero, we shouldn't get any settled // invoices back. @@ -835,17 +855,22 @@ func TestSettleIndexAmpPayments(t *testing.T) { setID2 := &[32]byte{2} setID3 := &[32]byte{3} - ref := InvoiceRefByHashAndAddr(payHash, testInvoice.Terms.PaymentAddr) + ref := invpkg.InvoiceRefByHashAndAddr( + payHash, testInvoice.Terms.PaymentAddr, + ) _, err = db.UpdateInvoice( - ref, (*SetID)(setID1), updateAcceptAMPHtlc(1, amt, setID1, true), + ref, (*invpkg.SetID)(setID1), + updateAcceptAMPHtlc(1, amt, setID1, true), ) require.Nil(t, err) _, err = db.UpdateInvoice( - ref, (*SetID)(setID2), updateAcceptAMPHtlc(2, amt, setID2, true), + ref, (*invpkg.SetID)(setID2), + updateAcceptAMPHtlc(2, amt, setID2, true), ) require.Nil(t, err) _, err = db.UpdateInvoice( - ref, (*SetID)(setID3), updateAcceptAMPHtlc(3, amt, setID3, true), + ref, (*invpkg.SetID)(setID3), + updateAcceptAMPHtlc(3, amt, setID3, true), ) require.Nil(t, err) @@ -855,7 +880,9 @@ func TestSettleIndexAmpPayments(t *testing.T) { // // First, we'll query for the invoice with just the payment addr, but // specify no HTLcs are to be included. - refNoHtlcs := InvoiceRefByAddrBlankHtlc(testInvoice.Terms.PaymentAddr) + refNoHtlcs := invpkg.InvoiceRefByAddrBlankHtlc( + testInvoice.Terms.PaymentAddr, + ) invoiceNoHTLCs, err := db.LookupInvoice(refNoHtlcs) require.Nil(t, err) @@ -864,7 +891,7 @@ func TestSettleIndexAmpPayments(t *testing.T) { // We'll now look up the HTLCs based on the individual setIDs added // above. for i, setID := range []*[32]byte{setID1, setID2, setID3} { - refFiltered := InvoiceRefBySetIDFiltered(*setID) + refFiltered := invpkg.InvoiceRefBySetIDFiltered(*setID) invoiceFiltered, err := db.LookupInvoice(refFiltered) require.Nil(t, err) @@ -877,27 +904,27 @@ func TestSettleIndexAmpPayments(t *testing.T) { require.Equal(t, *setID, htlc.AMP.Record.SetID()) // The HTLC should show that it's in the accepted state. - require.Equal(t, htlc.State, HtlcStateAccepted) + require.Equal(t, htlc.State, invpkg.HtlcStateAccepted) } // Now that we know the invoices are in the proper state, we'll settle // them on by one in distinct updates. _, err = db.UpdateInvoice( - ref, (*SetID)(setID1), + ref, (*invpkg.SetID)(setID1), getUpdateInvoiceAMPSettle( setID1, preimage, models.CircuitKey{HtlcID: 1}, ), ) require.Nil(t, err) _, err = db.UpdateInvoice( - ref, (*SetID)(setID2), + ref, (*invpkg.SetID)(setID2), getUpdateInvoiceAMPSettle( setID2, preimage, models.CircuitKey{HtlcID: 2}, ), ) require.Nil(t, err) _, err = db.UpdateInvoice( - ref, (*SetID)(setID3), + ref, (*invpkg.SetID)(setID3), getUpdateInvoiceAMPSettle( setID3, preimage, models.CircuitKey{HtlcID: 3}, ), @@ -914,9 +941,13 @@ func TestSettleIndexAmpPayments(t *testing.T) { // To get around the settle index quirk, we'll fetch the very first // invoice in the HTLC filtered mode and append it to the set of // invoices. - firstInvoice, err := db.LookupInvoice(InvoiceRefBySetIDFiltered(*setID1)) + firstInvoice, err := db.LookupInvoice( + invpkg.InvoiceRefBySetIDFiltered(*setID1), + ) require.Nil(t, err) - settledInvoices = append([]Invoice{firstInvoice}, settledInvoices...) + settledInvoices = append( + []invpkg.Invoice{firstInvoice}, settledInvoices..., + ) // There should be 3 invoices settled, as we created 3 "sub-invoices" // above. @@ -935,7 +966,7 @@ func TestSettleIndexAmpPayments(t *testing.T) { subInvoiceState, ok := settledInvoice.AMPState[*invSetID] require.True(t, ok) - require.Equal(t, subInvoiceState.State, HtlcStateSettled) + require.Equal(t, subInvoiceState.State, invpkg.HtlcStateSettled) require.Equal(t, int(subInvoiceState.SettleIndex), i+1) invoiceKey := models.CircuitKey{HtlcID: uint64(i + 1)} @@ -945,14 +976,14 @@ func TestSettleIndexAmpPayments(t *testing.T) { // If we attempt to look up the invoice by the payment addr, with all // the HTLCs, the main invoice should have 3 HTLCs present. - refWithHtlcs := InvoiceRefByAddr(testInvoice.Terms.PaymentAddr) + refWithHtlcs := invpkg.InvoiceRefByAddr(testInvoice.Terms.PaymentAddr) invoiceWithHTLCs, err := db.LookupInvoice(refWithHtlcs) require.Nil(t, err) require.Equal(t, numInvoices, len(invoiceWithHTLCs.Htlcs)) // Finally, delete the invoice. If we query again, then nothing should // be found. - err = db.DeleteInvoice([]InvoiceDeleteRef{ + err = db.DeleteInvoice([]invpkg.InvoiceDeleteRef{ { PayHash: payHash, PayAddr: &testInvoice.Terms.PaymentAddr, @@ -970,7 +1001,7 @@ func TestScanInvoices(t *testing.T) { db, err := MakeTestDB(t) require.NoError(t, err, "unable to make test db") - var invoices map[lntypes.Hash]*Invoice + var invoices map[lntypes.Hash]*invpkg.Invoice callCount := 0 resetCount := 0 @@ -978,12 +1009,14 @@ func TestScanInvoices(t *testing.T) { // upon calling ScanInvoices and when the underlying transaction is // retried. reset := func() { - invoices = make(map[lntypes.Hash]*Invoice) + invoices = make(map[lntypes.Hash]*invpkg.Invoice) callCount = 0 resetCount++ } - scanFunc := func(paymentHash lntypes.Hash, invoice *Invoice) error { + scanFunc := func(paymentHash lntypes.Hash, + invoice *invpkg.Invoice) error { + invoices[paymentHash] = invoice callCount++ @@ -997,7 +1030,7 @@ func TestScanInvoices(t *testing.T) { require.Equal(t, 1, resetCount) numInvoices := 5 - testInvoices := make(map[lntypes.Hash]*Invoice) + testInvoices := make(map[lntypes.Hash]*invpkg.Invoice) // Now populate the DB and check if we can get all invoices with their // payment hashes as expected. @@ -1040,22 +1073,22 @@ func TestDuplicateSettleInvoice(t *testing.T) { } // With the invoice in the DB, we'll now attempt to settle the invoice. - ref := InvoiceRefByHash(payHash) + ref := invpkg.InvoiceRefByHash(payHash) dbInvoice, err := db.UpdateInvoice(ref, nil, getUpdateInvoice(amt)) require.NoError(t, err, "unable to settle invoice") // We'll update what we expect the settle invoice to be so that our // comparison below has the correct assumption. invoice.SettleIndex = 1 - invoice.State = ContractSettled + invoice.State = invpkg.ContractSettled invoice.AmtPaid = amt invoice.SettleDate = dbInvoice.SettleDate - invoice.Htlcs = map[models.CircuitKey]*InvoiceHTLC{ + invoice.Htlcs = map[models.CircuitKey]*invpkg.InvoiceHTLC{ {}: { Amt: amt, AcceptTime: time.Unix(1, 0), ResolveTime: time.Unix(1, 0), - State: HtlcStateSettled, + State: invpkg.HtlcStateSettled, CustomRecords: make(record.CustomSet), }, } @@ -1066,16 +1099,16 @@ func TestDuplicateSettleInvoice(t *testing.T) { // If we try to settle the invoice again, then we should get the very // same invoice back, but with an error this time. dbInvoice, err = db.UpdateInvoice(ref, nil, getUpdateInvoice(amt)) - if err != ErrInvoiceAlreadySettled { - t.Fatalf("expected ErrInvoiceAlreadySettled") - } + require.ErrorIs(t, err, invpkg.ErrInvoiceAlreadySettled) if dbInvoice == nil { t.Fatalf("invoice from db is nil after settle!") } invoice.SettleDate = dbInvoice.SettleDate - require.Equal(t, invoice, dbInvoice, "wrong invoice after second settle") + require.Equal( + t, invoice, dbInvoice, "wrong invoice after second settle", + ) } // TestQueryInvoices ensures that we can properly query the invoice database for @@ -1091,8 +1124,8 @@ func TestQueryInvoices(t *testing.T) { // as the amount of the invoice itself. const numInvoices = 50 var settleIndex uint64 = 1 - var invoices []Invoice - var pendingInvoices []Invoice + var invoices []invpkg.Invoice + var pendingInvoices []invpkg.Invoice for i := 1; i <= numInvoices; i++ { amt := lnwire.MilliSatoshi(i) @@ -1112,8 +1145,10 @@ func TestQueryInvoices(t *testing.T) { // We'll only settle half of all invoices created. if i%2 == 0 { - ref := InvoiceRefByHash(paymentHash) - _, err := db.UpdateInvoice(ref, nil, getUpdateInvoice(amt)) + ref := invpkg.InvoiceRefByHash(paymentHash) + _, err := db.UpdateInvoice( + ref, nil, getUpdateInvoice(amt), + ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } @@ -1131,19 +1166,19 @@ func TestQueryInvoices(t *testing.T) { // The test will consist of several queries along with their respective // expected response. Each query response should match its expected one. testCases := []struct { - query InvoiceQuery - expected []Invoice + query invpkg.InvoiceQuery + expected []invpkg.Invoice }{ // Fetch all invoices with a single query. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ NumMaxInvoices: numInvoices, }, expected: invoices, }, // Fetch all invoices with a single query, reversed. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ Reversed: true, NumMaxInvoices: numInvoices, }, @@ -1151,7 +1186,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch the first 25 invoices. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ NumMaxInvoices: numInvoices / 2, }, expected: invoices[:numInvoices/2], @@ -1159,7 +1194,7 @@ func TestQueryInvoices(t *testing.T) { // Fetch the first 10 invoices, but this time iterating // backwards. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 11, Reversed: true, NumMaxInvoices: numInvoices, @@ -1168,7 +1203,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch the last 40 invoices. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 10, NumMaxInvoices: numInvoices, }, @@ -1176,7 +1211,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch all but the first invoice. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 1, NumMaxInvoices: numInvoices, }, @@ -1185,7 +1220,7 @@ func TestQueryInvoices(t *testing.T) { // Fetch one invoice, reversed, with index offset 3. This // should give us the second invoice in the array. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 3, Reversed: true, NumMaxInvoices: 1, @@ -1194,7 +1229,7 @@ func TestQueryInvoices(t *testing.T) { }, // Same as above, at index 2. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 2, Reversed: true, NumMaxInvoices: 1, @@ -1205,7 +1240,7 @@ func TestQueryInvoices(t *testing.T) { // the very first, there won't be any left in a reverse search, // so we expect no invoices to be returned. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 1, Reversed: true, NumMaxInvoices: 1, @@ -1215,7 +1250,7 @@ func TestQueryInvoices(t *testing.T) { // Same as above, but don't restrict the number of invoices to // 1. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 1, Reversed: true, NumMaxInvoices: numInvoices, @@ -1225,7 +1260,7 @@ func TestQueryInvoices(t *testing.T) { // Fetch one invoice, reversed, with no offset set. We expect // the last invoice in the response. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ Reversed: true, NumMaxInvoices: 1, }, @@ -1234,7 +1269,7 @@ func TestQueryInvoices(t *testing.T) { // Fetch one invoice, reversed, the offset set at numInvoices+1. // We expect this to return the last invoice. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: numInvoices + 1, Reversed: true, NumMaxInvoices: 1, @@ -1243,7 +1278,7 @@ func TestQueryInvoices(t *testing.T) { }, // Same as above, at offset numInvoices. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: numInvoices, Reversed: true, NumMaxInvoices: 1, @@ -1253,14 +1288,14 @@ func TestQueryInvoices(t *testing.T) { // Fetch one invoice, at no offset (same as offset 0). We // expect the first invoice only in the response. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ NumMaxInvoices: 1, }, expected: invoices[:1], }, // Same as above, at offset 1. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 1, NumMaxInvoices: 1, }, @@ -1268,7 +1303,7 @@ func TestQueryInvoices(t *testing.T) { }, // Same as above, at offset 2. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 2, NumMaxInvoices: 1, }, @@ -1277,7 +1312,7 @@ func TestQueryInvoices(t *testing.T) { // Same as above, at offset numInvoices-1. Expect the last // invoice to be returned. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: numInvoices - 1, NumMaxInvoices: 1, }, @@ -1286,7 +1321,7 @@ func TestQueryInvoices(t *testing.T) { // Same as above, at offset numInvoices. No invoices should be // returned, as there are no invoices after this offset. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: numInvoices, NumMaxInvoices: 1, }, @@ -1294,7 +1329,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch all pending invoices with a single query. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ PendingOnly: true, NumMaxInvoices: numInvoices, }, @@ -1302,7 +1337,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch the first 12 pending invoices. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ PendingOnly: true, NumMaxInvoices: numInvoices / 4, }, @@ -1311,7 +1346,7 @@ func TestQueryInvoices(t *testing.T) { // Fetch the first 5 pending invoices, but this time iterating // backwards. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 10, PendingOnly: true, Reversed: true, @@ -1325,7 +1360,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch the last 15 invoices. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 20, PendingOnly: true, NumMaxInvoices: numInvoices, @@ -1339,7 +1374,7 @@ func TestQueryInvoices(t *testing.T) { // that is beyond our last offset. We expect all invoices to be // returned. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: numInvoices * 2, PendingOnly: false, Reversed: true, @@ -1349,7 +1384,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch invoices <= 25 by creation date. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ NumMaxInvoices: numInvoices, CreationDateEnd: time.Unix(25, 0), }, @@ -1357,7 +1392,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch invoices >= 26 creation date. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ NumMaxInvoices: numInvoices, CreationDateStart: time.Unix(26, 0), }, @@ -1365,7 +1400,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch pending invoices <= 25 by creation date. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ PendingOnly: true, NumMaxInvoices: numInvoices, CreationDateEnd: time.Unix(25, 0), @@ -1374,7 +1409,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch pending invoices >= 26 creation date. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ PendingOnly: true, NumMaxInvoices: numInvoices, CreationDateStart: time.Unix(26, 0), @@ -1383,7 +1418,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch pending invoices with offset and end creation date. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 20, NumMaxInvoices: numInvoices, CreationDateEnd: time.Unix(30, 0), @@ -1395,7 +1430,7 @@ func TestQueryInvoices(t *testing.T) { // Fetch pending invoices with offset and start creation date // in reversed order. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ IndexOffset: 21, Reversed: true, NumMaxInvoices: numInvoices, @@ -1407,7 +1442,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch invoices with start and end creation date. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ NumMaxInvoices: numInvoices, CreationDateStart: time.Unix(11, 0), CreationDateEnd: time.Unix(20, 0), @@ -1416,7 +1451,7 @@ func TestQueryInvoices(t *testing.T) { }, // Fetch pending invoices with start and end creation date. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ PendingOnly: true, NumMaxInvoices: numInvoices, CreationDateStart: time.Unix(11, 0), @@ -1427,7 +1462,7 @@ func TestQueryInvoices(t *testing.T) { // Fetch invoices with start and end creation date in reverse // order. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ Reversed: true, NumMaxInvoices: numInvoices, CreationDateStart: time.Unix(11, 0), @@ -1438,7 +1473,7 @@ func TestQueryInvoices(t *testing.T) { // Fetch pending invoices with start and end creation date in // reverse order. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ PendingOnly: true, Reversed: true, NumMaxInvoices: numInvoices, @@ -1450,7 +1485,7 @@ func TestQueryInvoices(t *testing.T) { // Fetch invoices with a start date greater than end date // should result in an empty slice. { - query: InvoiceQuery{ + query: invpkg.InvoiceQuery{ NumMaxInvoices: numInvoices, CreationDateStart: time.Unix(20, 0), CreationDateEnd: time.Unix(11, 0), @@ -1478,25 +1513,27 @@ func TestQueryInvoices(t *testing.T) { // getUpdateInvoice returns an invoice update callback that, when called, // settles the invoice with the given amount. -func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback { - return func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - if invoice.State == ContractSettled { - return nil, ErrInvoiceAlreadySettled +func getUpdateInvoice(amt lnwire.MilliSatoshi) invpkg.InvoiceUpdateCallback { + return func(invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, + error) { + + if invoice.State == invpkg.ContractSettled { + return nil, invpkg.ErrInvoiceAlreadySettled } noRecords := make(record.CustomSet) - - update := &InvoiceUpdateDesc{ - State: &InvoiceStateUpdateDesc{ + htlcs := map[models.CircuitKey]*invpkg.HtlcAcceptDesc{ + {}: { + Amt: amt, + CustomRecords: noRecords, + }, + } + update := &invpkg.InvoiceUpdateDesc{ + State: &invpkg.InvoiceStateUpdateDesc{ Preimage: invoice.Terms.PaymentPreimage, - NewState: ContractSettled, - }, - AddHtlcs: map[models.CircuitKey]*HtlcAcceptDesc{ - {}: { - Amt: amt, - CustomRecords: noRecords, - }, + NewState: invpkg.ContractSettled, }, + AddHtlcs: htlcs, } return update, nil @@ -1514,9 +1551,9 @@ func TestCustomRecords(t *testing.T) { preimage := lntypes.Preimage{1} paymentHash := preimage.Hash() - testInvoice := &Invoice{ - Htlcs: map[models.CircuitKey]*InvoiceHTLC{}, - Terms: ContractTerm{ + testInvoice := &invpkg.Invoice{ + Htlcs: map[models.CircuitKey]*invpkg.InvoiceHTLC{}, + Terms: invpkg.ContractTerm{ Value: lnwire.NewMSatFromSatoshis(10000), Features: emptyFeatures, PaymentPreimage: &preimage, @@ -1538,19 +1575,21 @@ func TestCustomRecords(t *testing.T) { 100001: []byte{1, 2}, } - ref := InvoiceRefByHash(paymentHash) - _, err = db.UpdateInvoice(ref, nil, - func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - htlcs := map[models.CircuitKey]*HtlcAcceptDesc{ - key: { - Amt: 500, - CustomRecords: records, - }, - } + ref := invpkg.InvoiceRefByHash(paymentHash) + callback := func(invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, + error) { - return &InvoiceUpdateDesc{AddHtlcs: htlcs}, nil - }, - ) + htlcs := map[models.CircuitKey]*invpkg.HtlcAcceptDesc{ + key: { + Amt: 500, + CustomRecords: records, + }, + } + + return &invpkg.InvoiceUpdateDesc{AddHtlcs: htlcs}, nil + } + + _, err = db.UpdateInvoice(ref, nil, callback) require.NoError(t, err, "unable to add invoice htlc") // Retrieve the invoice from that database and verify that the custom @@ -1601,31 +1640,34 @@ func testInvoiceHtlcAMPFields(t *testing.T, isAMP bool) { } records := make(map[uint64][]byte) - var ampData *InvoiceHtlcAMPData + var ampData *invpkg.InvoiceHtlcAMPData if isAMP { amp := record.NewAMP([32]byte{1}, [32]byte{2}, 3) preimage := &lntypes.Preimage{4} - ampData = &InvoiceHtlcAMPData{ + ampData = &invpkg.InvoiceHtlcAMPData{ Record: *amp, Hash: preimage.Hash(), Preimage: preimage, } } - ref := InvoiceRefByHash(payHash) - _, err = db.UpdateInvoice(ref, nil, - func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - htlcs := map[models.CircuitKey]*HtlcAcceptDesc{ - key: { - Amt: 500, - AMP: ampData, - CustomRecords: records, - }, - } - return &InvoiceUpdateDesc{AddHtlcs: htlcs}, nil - }, - ) + callback := func(invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, + error) { + + htlcs := map[models.CircuitKey]*invpkg.HtlcAcceptDesc{ + key: { + Amt: 500, + AMP: ampData, + CustomRecords: records, + }, + } + + return &invpkg.InvoiceUpdateDesc{AddHtlcs: htlcs}, nil + } + + ref := invpkg.InvoiceRefByHash(payHash) + _, err = db.UpdateInvoice(ref, nil, callback) require.Nil(t, err) // Retrieve the invoice from that database and verify that the AMP @@ -1646,28 +1688,28 @@ func TestInvoiceRef(t *testing.T) { // An InvoiceRef by hash should return the provided hash and a nil // payment addr. - refByHash := InvoiceRefByHash(payHash) + refByHash := invpkg.InvoiceRefByHash(payHash) require.Equal(t, &payHash, refByHash.PayHash()) require.Equal(t, (*[32]byte)(nil), refByHash.PayAddr()) require.Equal(t, (*[32]byte)(nil), refByHash.SetID()) // An InvoiceRef by hash and addr should return the payment hash and // payment addr passed to the constructor. - refByHashAndAddr := InvoiceRefByHashAndAddr(payHash, payAddr) + refByHashAndAddr := invpkg.InvoiceRefByHashAndAddr(payHash, payAddr) require.Equal(t, &payHash, refByHashAndAddr.PayHash()) require.Equal(t, &payAddr, refByHashAndAddr.PayAddr()) require.Equal(t, (*[32]byte)(nil), refByHashAndAddr.SetID()) // An InvoiceRef by set id should return an empty pay hash, a nil pay // addr, and a reference to the given set id. - refBySetID := InvoiceRefBySetID(setID) + refBySetID := invpkg.InvoiceRefBySetID(setID) require.Equal(t, (*lntypes.Hash)(nil), refBySetID.PayHash()) require.Equal(t, (*[32]byte)(nil), refBySetID.PayAddr()) require.Equal(t, &setID, refBySetID.SetID()) // An InvoiceRef by pay addr should only return a pay addr, but nil for // pay hash and set id. - refByAddr := InvoiceRefByAddr(payAddr) + refByAddr := invpkg.InvoiceRefByAddr(payAddr) require.Equal(t, (*lntypes.Hash)(nil), refByAddr.PayHash()) require.Equal(t, &payAddr, refByAddr.PayAddr()) require.Equal(t, (*[32]byte)(nil), refByAddr.SetID()) @@ -1678,8 +1720,8 @@ func TestInvoiceRef(t *testing.T) { // not comingle, and also that HTLCs with disjoint set ids appear in different // sets. func TestHTLCSet(t *testing.T) { - inv := &Invoice{ - Htlcs: make(map[models.CircuitKey]*InvoiceHTLC), + inv := &invpkg.Invoice{ + Htlcs: make(map[models.CircuitKey]*invpkg.InvoiceHTLC), } // Construct two distinct set id's, in this test we'll also track the @@ -1689,14 +1731,24 @@ func TestHTLCSet(t *testing.T) { // Create the expected htlc sets for each group, these will be updated // as the invoice is modified. - expSetNil := make(map[models.CircuitKey]*InvoiceHTLC) - expSet1 := make(map[models.CircuitKey]*InvoiceHTLC) - expSet2 := make(map[models.CircuitKey]*InvoiceHTLC) + + expSetNil := make(map[models.CircuitKey]*invpkg.InvoiceHTLC) + expSet1 := make(map[models.CircuitKey]*invpkg.InvoiceHTLC) + expSet2 := make(map[models.CircuitKey]*invpkg.InvoiceHTLC) checkHTLCSets := func() { - require.Equal(t, expSetNil, inv.HTLCSet(nil, HtlcStateAccepted)) - require.Equal(t, expSet1, inv.HTLCSet(setID1, HtlcStateAccepted)) - require.Equal(t, expSet2, inv.HTLCSet(setID2, HtlcStateAccepted)) + require.Equal( + t, expSetNil, + inv.HTLCSet(nil, invpkg.HtlcStateAccepted), + ) + require.Equal( + t, expSet1, + inv.HTLCSet(setID1, invpkg.HtlcStateAccepted), + ) + require.Equal( + t, expSet2, + inv.HTLCSet(setID2, invpkg.HtlcStateAccepted), + ) } // All HTLC sets should be empty initially. @@ -1709,33 +1761,35 @@ func TestHTLCSet(t *testing.T) { // - only accepted htlcs are returned as part of the set. htlcs := []struct { setID *[32]byte - state HtlcState + state invpkg.HtlcState }{ - {nil, HtlcStateAccepted}, - {nil, HtlcStateAccepted}, - {setID1, HtlcStateAccepted}, - {setID1, HtlcStateAccepted}, - {setID2, HtlcStateAccepted}, - {setID2, HtlcStateAccepted}, - {nil, HtlcStateCanceled}, - {setID1, HtlcStateCanceled}, - {setID2, HtlcStateCanceled}, - {nil, HtlcStateSettled}, - {setID1, HtlcStateSettled}, - {setID2, HtlcStateSettled}, + {nil, invpkg.HtlcStateAccepted}, + {nil, invpkg.HtlcStateAccepted}, + {setID1, invpkg.HtlcStateAccepted}, + {setID1, invpkg.HtlcStateAccepted}, + {setID2, invpkg.HtlcStateAccepted}, + {setID2, invpkg.HtlcStateAccepted}, + {nil, invpkg.HtlcStateCanceled}, + {setID1, invpkg.HtlcStateCanceled}, + {setID2, invpkg.HtlcStateCanceled}, + {nil, invpkg.HtlcStateSettled}, + {setID1, invpkg.HtlcStateSettled}, + {setID2, invpkg.HtlcStateSettled}, } for i, h := range htlcs { - var ampData *InvoiceHtlcAMPData + var ampData *invpkg.InvoiceHtlcAMPData if h.setID != nil { - ampData = &InvoiceHtlcAMPData{ - Record: *record.NewAMP([32]byte{0}, *h.setID, 0), + ampData = &invpkg.InvoiceHtlcAMPData{ + Record: *record.NewAMP( + [32]byte{0}, *h.setID, 0, + ), } } // Add the HTLC to the invoice's set of HTLCs. key := models.CircuitKey{HtlcID: uint64(i)} - htlc := &InvoiceHTLC{ + htlc := &invpkg.InvoiceHTLC{ AMP: ampData, State: h.state, } @@ -1743,7 +1797,7 @@ func TestHTLCSet(t *testing.T) { // Update our expected htlc set if the htlc is accepted, // otherwise it shouldn't be reflected. - if h.state == HtlcStateAccepted { + if h.state == invpkg.HtlcStateAccepted { switch h.setID { case nil: expSetNil[key] = htlc @@ -1770,11 +1824,11 @@ func TestAddInvoiceWithHTLCs(t *testing.T) { require.Nil(t, err) key := models.CircuitKey{HtlcID: 1} - testInvoice.Htlcs[key] = &InvoiceHTLC{} + testInvoice.Htlcs[key] = &invpkg.InvoiceHTLC{} payHash := testInvoice.Terms.PaymentPreimage.Hash() _, err = db.AddInvoice(testInvoice, payHash) - require.Equal(t, ErrInvoiceHasHtlcs, err) + require.Equal(t, invpkg.ErrInvoiceHasHtlcs, err) } // TestSetIDIndex asserts that the set id index properly adds new invoices as we @@ -1803,26 +1857,30 @@ func TestSetIDIndex(t *testing.T) { // Update the invoice with an accepted HTLC that also accepts the // invoice. - ref := InvoiceRefByHashAndAddr(payHash, invoice.Terms.PaymentAddr) + ref := invpkg.InvoiceRefByHashAndAddr( + payHash, invoice.Terms.PaymentAddr, + ) dbInvoice, err := db.UpdateInvoice( - ref, (*SetID)(setID), updateAcceptAMPHtlc(0, amt, setID, true), + ref, (*invpkg.SetID)(setID), + updateAcceptAMPHtlc(0, amt, setID, true), ) require.Nil(t, err) // We'll update what we expect the accepted invoice to be so that our // comparison below has the correct assumption. - invoice.State = ContractOpen + invoice.State = invpkg.ContractOpen invoice.AmtPaid = amt invoice.SettleDate = dbInvoice.SettleDate - invoice.Htlcs = map[models.CircuitKey]*InvoiceHTLC{ - {HtlcID: 0}: makeAMPInvoiceHTLC(amt, *setID, payHash, &preimage), + htlc0 := models.CircuitKey{HtlcID: 0} + invoice.Htlcs = map[models.CircuitKey]*invpkg.InvoiceHTLC{ + htlc0: makeAMPInvoiceHTLC(amt, *setID, payHash, &preimage), } - invoice.AMPState = map[SetID]InvoiceStateAMP{} - invoice.AMPState[*setID] = InvoiceStateAMP{ - State: HtlcStateAccepted, + invoice.AMPState = map[invpkg.SetID]invpkg.InvoiceStateAMP{} + invoice.AMPState[*setID] = invpkg.InvoiceStateAMP{ + State: invpkg.HtlcStateAccepted, AmtPaid: amt, InvoiceKeys: map[models.CircuitKey]struct{}{ - {HtlcID: 0}: {}, + htlc0: {}, }, } @@ -1830,7 +1888,7 @@ func TestSetIDIndex(t *testing.T) { require.Equal(t, invoice, dbInvoice) // Now lookup the invoice by set id and see that we get the same one. - refBySetID := InvoiceRefBySetID(*setID) + refBySetID := invpkg.InvoiceRefBySetID(*setID) dbInvoiceBySetID, err := db.LookupInvoice(refBySetID) require.Nil(t, err) require.Equal(t, invoice, &dbInvoiceBySetID) @@ -1848,48 +1906,55 @@ func TestSetIDIndex(t *testing.T) { _, err = db.AddInvoice(invoice2, payHash2) require.Nil(t, err) - ref2 := InvoiceRefByHashAndAddr(payHash2, invoice2.Terms.PaymentAddr) - _, err = db.UpdateInvoice( - ref2, (*SetID)(setID), updateAcceptAMPHtlc(0, amt, setID, true), + ref2 := invpkg.InvoiceRefByHashAndAddr( + payHash2, invoice2.Terms.PaymentAddr, ) - require.Equal(t, ErrDuplicateSetID{setID: *setID}, err) + _, err = db.UpdateInvoice( + ref2, (*invpkg.SetID)(setID), + updateAcceptAMPHtlc(0, amt, setID, true), + ) + require.Equal(t, invpkg.ErrDuplicateSetID{SetID: *setID}, err) // Now, begin constructing a second htlc set under a different set id. // This set will contain two distinct HTLCs. setID2 := &[32]byte{2} _, err = db.UpdateInvoice( - ref, (*SetID)(setID2), updateAcceptAMPHtlc(1, amt, setID2, false), + ref, (*invpkg.SetID)(setID2), + updateAcceptAMPHtlc(1, amt, setID2, false), ) require.Nil(t, err) dbInvoice, err = db.UpdateInvoice( - ref, (*SetID)(setID2), updateAcceptAMPHtlc(2, amt, setID2, false), + ref, (*invpkg.SetID)(setID2), + updateAcceptAMPHtlc(2, amt, setID2, false), ) require.Nil(t, err) // We'll update what we expect the settle invoice to be so that our // comparison below has the correct assumption. - invoice.State = ContractOpen + invoice.State = invpkg.ContractOpen invoice.AmtPaid += 2 * amt invoice.SettleDate = dbInvoice.SettleDate - invoice.Htlcs = map[models.CircuitKey]*InvoiceHTLC{ - {HtlcID: 0}: makeAMPInvoiceHTLC(amt, *setID, payHash, &preimage), - {HtlcID: 1}: makeAMPInvoiceHTLC(amt, *setID2, payHash, nil), - {HtlcID: 2}: makeAMPInvoiceHTLC(amt, *setID2, payHash, nil), + htlc1 := models.CircuitKey{HtlcID: 1} + htlc2 := models.CircuitKey{HtlcID: 2} + invoice.Htlcs = map[models.CircuitKey]*invpkg.InvoiceHTLC{ + htlc0: makeAMPInvoiceHTLC(amt, *setID, payHash, &preimage), + htlc1: makeAMPInvoiceHTLC(amt, *setID2, payHash, nil), + htlc2: makeAMPInvoiceHTLC(amt, *setID2, payHash, nil), } - invoice.AMPState[*setID] = InvoiceStateAMP{ - State: HtlcStateAccepted, + invoice.AMPState[*setID] = invpkg.InvoiceStateAMP{ + State: invpkg.HtlcStateAccepted, AmtPaid: amt, InvoiceKeys: map[models.CircuitKey]struct{}{ - {HtlcID: 0}: {}, + htlc0: {}, }, } - invoice.AMPState[*setID2] = InvoiceStateAMP{ - State: HtlcStateAccepted, + invoice.AMPState[*setID2] = invpkg.InvoiceStateAMP{ + State: invpkg.HtlcStateAccepted, AmtPaid: amt * 2, InvoiceKeys: map[models.CircuitKey]struct{}{ - {HtlcID: 1}: {}, - {HtlcID: 2}: {}, + htlc1: {}, + htlc2: {}, }, } @@ -1904,7 +1969,7 @@ func TestSetIDIndex(t *testing.T) { // Now lookup the invoice by second set id and see that we get the same // index, including the htlcs under the first set id. - refBySetID = InvoiceRefBySetID(*setID2) + refBySetID = invpkg.InvoiceRefBySetID(*setID2) dbInvoiceBySetID, err = db.LookupInvoice(refBySetID) require.Nil(t, err) require.Equal(t, invoice, &dbInvoiceBySetID) @@ -1918,13 +1983,13 @@ func TestSetIDIndex(t *testing.T) { models.CircuitKey{HtlcID: 99}, ), ) - require.Equal(t, ErrEmptyHTLCSet, err) + require.Equal(t, invpkg.ErrEmptyHTLCSet, err) // Now settle the first htlc set. The existing HTLCs should remain in // the accepted state and shouldn't be canceled, since we permit an // invoice to be settled multiple times. _, err = db.UpdateInvoice( - ref, (*SetID)(setID), + ref, (*invpkg.SetID)(setID), getUpdateInvoiceAMPSettle( setID, preimage, models.CircuitKey{HtlcID: 0}, ), @@ -1935,21 +2000,20 @@ func TestSetIDIndex(t *testing.T) { require.Nil(t, err) dbInvoice = &freshInvoice - invoice.State = ContractOpen + invoice.State = invpkg.ContractOpen // The amount paid should reflect that we have 3 present HTLCs, each // with an amount of the original invoice. invoice.AmtPaid = amt * 3 ampState := invoice.AMPState[*setID] - ampState.State = HtlcStateSettled + ampState.State = invpkg.HtlcStateSettled ampState.SettleDate = testNow ampState.SettleIndex = 1 invoice.AMPState[*setID] = ampState - htlc0 := models.CircuitKey{HtlcID: 0} - invoice.Htlcs[htlc0].State = HtlcStateSettled + invoice.Htlcs[htlc0].State = invpkg.HtlcStateSettled invoice.Htlcs[htlc0].ResolveTime = time.Unix(1, 0) require.Equal(t, invoice, dbInvoice) @@ -1957,19 +2021,19 @@ func TestSetIDIndex(t *testing.T) { // If we try to settle the same set ID again, then we should get an // error, as it's already been settled. _, err = db.UpdateInvoice( - ref, (*SetID)(setID), + ref, (*invpkg.SetID)(setID), getUpdateInvoiceAMPSettle( setID, preimage, models.CircuitKey{HtlcID: 0}, ), ) - require.Equal(t, ErrEmptyHTLCSet, err) + require.Equal(t, invpkg.ErrEmptyHTLCSet, err) // Next, let's attempt to settle the other active set ID for this // invoice. This will allow us to exercise the case where we go to // settle an invoice with a new setID after one has already been fully // settled. _, err = db.UpdateInvoice( - ref, (*SetID)(setID2), + ref, (*invpkg.SetID)(setID2), getUpdateInvoiceAMPSettle( setID2, preimage, models.CircuitKey{HtlcID: 1}, models.CircuitKey{HtlcID: 2}, @@ -1983,40 +2047,38 @@ func TestSetIDIndex(t *testing.T) { // Now the rest of the HTLCs should show as fully settled. ampState = invoice.AMPState[*setID2] - ampState.State = HtlcStateSettled + ampState.State = invpkg.HtlcStateSettled ampState.SettleDate = testNow ampState.SettleIndex = 2 invoice.AMPState[*setID2] = ampState - htlc1 := models.CircuitKey{HtlcID: 1} - htlc2 := models.CircuitKey{HtlcID: 2} - invoice.Htlcs[htlc1].State = HtlcStateSettled + invoice.Htlcs[htlc1].State = invpkg.HtlcStateSettled invoice.Htlcs[htlc1].ResolveTime = time.Unix(1, 0) invoice.Htlcs[htlc1].AMP.Preimage = &preimage - invoice.Htlcs[htlc2].State = HtlcStateSettled + invoice.Htlcs[htlc2].State = invpkg.HtlcStateSettled invoice.Htlcs[htlc2].ResolveTime = time.Unix(1, 0) invoice.Htlcs[htlc2].AMP.Preimage = &preimage require.Equal(t, invoice, dbInvoice) // Lastly, querying for an unknown set id should fail. - refUnknownSetID := InvoiceRefBySetID([32]byte{}) + refUnknownSetID := invpkg.InvoiceRefBySetID([32]byte{}) _, err = db.LookupInvoice(refUnknownSetID) - require.Equal(t, ErrInvoiceNotFound, err) + require.Equal(t, invpkg.ErrInvoiceNotFound, err) } func makeAMPInvoiceHTLC(amt lnwire.MilliSatoshi, setID [32]byte, - hash lntypes.Hash, preimage *lntypes.Preimage) *InvoiceHTLC { + hash lntypes.Hash, preimage *lntypes.Preimage) *invpkg.InvoiceHTLC { - return &InvoiceHTLC{ + return &invpkg.InvoiceHTLC{ Amt: amt, AcceptTime: testNow, ResolveTime: time.Time{}, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *record.NewAMP([32]byte{}, setID, 0), Hash: hash, Preimage: preimage, @@ -2027,54 +2089,61 @@ func makeAMPInvoiceHTLC(amt lnwire.MilliSatoshi, setID [32]byte, // updateAcceptAMPHtlc returns an invoice update callback that, when called, // settles the invoice with the given amount. func updateAcceptAMPHtlc(id uint64, amt lnwire.MilliSatoshi, - setID *[32]byte, accept bool) InvoiceUpdateCallback { + setID *[32]byte, accept bool) invpkg.InvoiceUpdateCallback { - return func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - if invoice.State == ContractSettled { - return nil, ErrInvoiceAlreadySettled + return func(invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, + error) { + + if invoice.State == invpkg.ContractSettled { + return nil, invpkg.ErrInvoiceAlreadySettled } noRecords := make(record.CustomSet) var ( - state *InvoiceStateUpdateDesc + state *invpkg.InvoiceStateUpdateDesc preimage *lntypes.Preimage ) if accept { - state = &InvoiceStateUpdateDesc{ - NewState: ContractAccepted, + state = &invpkg.InvoiceStateUpdateDesc{ + NewState: invpkg.ContractAccepted, SetID: setID, } pre := *invoice.Terms.PaymentPreimage preimage = &pre } - ampData := &InvoiceHtlcAMPData{ + ampData := &invpkg.InvoiceHtlcAMPData{ Record: *record.NewAMP([32]byte{}, *setID, 0), Hash: invoice.Terms.PaymentPreimage.Hash(), Preimage: preimage, } - update := &InvoiceUpdateDesc{ - State: state, - AddHtlcs: map[models.CircuitKey]*HtlcAcceptDesc{ - {HtlcID: id}: { - Amt: amt, - CustomRecords: noRecords, - AMP: ampData, - }, + + htlcs := map[models.CircuitKey]*invpkg.HtlcAcceptDesc{ + {HtlcID: id}: { + Amt: amt, + CustomRecords: noRecords, + AMP: ampData, }, } + update := &invpkg.InvoiceUpdateDesc{ + State: state, + AddHtlcs: htlcs, + } + return update, nil } } func getUpdateInvoiceAMPSettle(setID *[32]byte, preimage [32]byte, - circuitKeys ...models.CircuitKey) InvoiceUpdateCallback { + circuitKeys ...models.CircuitKey) invpkg.InvoiceUpdateCallback { - return func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - if invoice.State == ContractSettled { - return nil, ErrInvoiceAlreadySettled + return func(invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, + error) { + + if invoice.State == invpkg.ContractSettled { + return nil, invpkg.ErrInvoiceAlreadySettled } preImageSet := make(map[models.CircuitKey]lntypes.Preimage) @@ -2082,10 +2151,10 @@ func getUpdateInvoiceAMPSettle(setID *[32]byte, preimage [32]byte, preImageSet[key] = preimage } - update := &InvoiceUpdateDesc{ - State: &InvoiceStateUpdateDesc{ + update := &invpkg.InvoiceUpdateDesc{ + State: &invpkg.InvoiceStateUpdateDesc{ Preimage: nil, - NewState: ContractSettled, + NewState: invpkg.ContractSettled, SetID: setID, HTLCPreimages: preImageSet, }, @@ -2118,12 +2187,12 @@ func TestUnexpectedInvoicePreimage(t *testing.T) { // in order to settle an MPP invoice, the InvoiceRef must present a // payment hash against which to validate the preimage. _, err = db.UpdateInvoice( - InvoiceRefByAddr(invoice.Terms.PaymentAddr), nil, + invpkg.InvoiceRefByAddr(invoice.Terms.PaymentAddr), nil, getUpdateInvoice(invoice.Terms.Value), ) // Assert that we get ErrUnexpectedInvoicePreimage. - require.Error(t, ErrUnexpectedInvoicePreimage, err) + require.Error(t, invpkg.ErrUnexpectedInvoicePreimage, err) } type updateHTLCPreimageTestCase struct { @@ -2146,7 +2215,7 @@ func TestUpdateHTLCPreimages(t *testing.T) { { name: "diff preimage on settle", settleSamePreimage: false, - expError: ErrHTLCPreimageAlreadyExists, + expError: invpkg.ErrHTLCPreimageAlreadyExists, }, } @@ -2181,9 +2250,10 @@ func testUpdateHTLCPreimages(t *testing.T, test updateHTLCPreimageTestCase) { // Update the invoice with an accepted HTLC that also accepts the // invoice. - ref := InvoiceRefByAddr(invoice.Terms.PaymentAddr) + ref := invpkg.InvoiceRefByAddr(invoice.Terms.PaymentAddr) dbInvoice, err := db.UpdateInvoice( - ref, (*SetID)(setID), updateAcceptAMPHtlc(0, amt, setID, true), + ref, (*invpkg.SetID)(setID), + updateAcceptAMPHtlc(0, amt, setID, true), ) require.Nil(t, err) @@ -2198,11 +2268,13 @@ func testUpdateHTLCPreimages(t *testing.T, test updateHTLCPreimageTestCase) { htlcPreimages[key] = pre } - updateInvoice := func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - update := &InvoiceUpdateDesc{ - State: &InvoiceStateUpdateDesc{ + updateInvoice := func( + invoice *invpkg.Invoice) (*invpkg.InvoiceUpdateDesc, error) { + + update := &invpkg.InvoiceUpdateDesc{ + State: &invpkg.InvoiceStateUpdateDesc{ Preimage: nil, - NewState: ContractSettled, + NewState: invpkg.ContractSettled, HTLCPreimages: htlcPreimages, SetID: setID, }, @@ -2212,16 +2284,16 @@ func testUpdateHTLCPreimages(t *testing.T, test updateHTLCPreimageTestCase) { } // Now settle the HTLC set and assert the resulting error. - _, err = db.UpdateInvoice(ref, (*SetID)(setID), updateInvoice) + _, err = db.UpdateInvoice(ref, (*invpkg.SetID)(setID), updateInvoice) require.Equal(t, test.expError, err) } type updateHTLCTest struct { name string - input InvoiceHTLC - invState ContractState + input invpkg.InvoiceHTLC + invState invpkg.ContractState setID *[32]byte - output InvoiceHTLC + output invpkg.InvoiceHTLC expErr error } @@ -2242,27 +2314,27 @@ func TestUpdateHTLC(t *testing.T) { tests := []updateHTLCTest{ { name: "MPP accept", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), AMP: nil, }, - invState: ContractAccepted, + invState: invpkg.ContractAccepted, setID: nil, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), AMP: nil, }, @@ -2270,27 +2342,27 @@ func TestUpdateHTLC(t *testing.T) { }, { name: "MPP settle", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), AMP: nil, }, - invState: ContractSettled, + invState: invpkg.ContractSettled, setID: nil, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testNow, Expiry: 40, - State: HtlcStateSettled, + State: invpkg.HtlcStateSettled, CustomRecords: make(record.CustomSet), AMP: nil, }, @@ -2298,27 +2370,27 @@ func TestUpdateHTLC(t *testing.T) { }, { name: "MPP cancel", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), AMP: nil, }, - invState: ContractCanceled, + invState: invpkg.ContractCanceled, setID: nil, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testNow, Expiry: 40, - State: HtlcStateCanceled, + State: invpkg.HtlcStateCanceled, CustomRecords: make(record.CustomSet), AMP: nil, }, @@ -2326,105 +2398,105 @@ func TestUpdateHTLC(t *testing.T) { }, { name: "AMP accept missing preimage", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: nil, }, }, - invState: ContractAccepted, + invState: invpkg.ContractAccepted, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: nil, }, }, - expErr: ErrHTLCPreimageMissing, + expErr: invpkg.ErrHTLCPreimageMissing, }, { name: "AMP accept invalid preimage", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &fakePreimage, }, }, - invState: ContractAccepted, + invState: invpkg.ContractAccepted, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &fakePreimage, }, }, - expErr: ErrHTLCPreimageMismatch, + expErr: invpkg.ErrHTLCPreimageMismatch, }, { name: "AMP accept valid preimage", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractAccepted, + invState: invpkg.ContractAccepted, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, @@ -2434,33 +2506,33 @@ func TestUpdateHTLC(t *testing.T) { }, { name: "AMP accept valid preimage different htlc set", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractAccepted, + invState: invpkg.ContractAccepted, setID: &diffSetID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, @@ -2470,105 +2542,105 @@ func TestUpdateHTLC(t *testing.T) { }, { name: "AMP settle missing preimage", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: nil, }, }, - invState: ContractSettled, + invState: invpkg.ContractSettled, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: nil, }, }, - expErr: ErrHTLCPreimageMissing, + expErr: invpkg.ErrHTLCPreimageMissing, }, { name: "AMP settle invalid preimage", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &fakePreimage, }, }, - invState: ContractSettled, + invState: invpkg.ContractSettled, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &fakePreimage, }, }, - expErr: ErrHTLCPreimageMismatch, + expErr: invpkg.ErrHTLCPreimageMismatch, }, { name: "AMP settle valid preimage", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractSettled, + invState: invpkg.ContractSettled, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testNow, Expiry: 40, - State: HtlcStateSettled, + State: invpkg.HtlcStateSettled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, @@ -2582,33 +2654,33 @@ func TestUpdateHTLC(t *testing.T) { // to a given pay_addr. In this case, the HTLC should // remain in the accepted state. name: "AMP settle valid preimage different htlc set", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractSettled, + invState: invpkg.ContractSettled, setID: &diffSetID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, @@ -2618,105 +2690,105 @@ func TestUpdateHTLC(t *testing.T) { }, { name: "accept invoice htlc already settled", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateSettled, + State: invpkg.HtlcStateSettled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractAccepted, + invState: invpkg.ContractAccepted, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateSettled, + State: invpkg.HtlcStateSettled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - expErr: ErrHTLCAlreadySettled, + expErr: invpkg.ErrHTLCAlreadySettled, }, { name: "cancel invoice htlc already settled", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateSettled, + State: invpkg.HtlcStateSettled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractCanceled, + invState: invpkg.ContractCanceled, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateSettled, + State: invpkg.HtlcStateSettled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - expErr: ErrHTLCAlreadySettled, + expErr: invpkg.ErrHTLCAlreadySettled, }, { name: "settle invoice htlc already settled", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateSettled, + State: invpkg.HtlcStateSettled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractSettled, + invState: invpkg.ContractSettled, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateSettled, + State: invpkg.HtlcStateSettled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, @@ -2726,33 +2798,33 @@ func TestUpdateHTLC(t *testing.T) { }, { name: "cancel invoice", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: time.Time{}, Expiry: 40, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractCanceled, + invState: invpkg.ContractCanceled, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testNow, Expiry: 40, - State: HtlcStateCanceled, + State: invpkg.HtlcStateCanceled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, @@ -2762,33 +2834,33 @@ func TestUpdateHTLC(t *testing.T) { }, { name: "accept invoice htlc already canceled", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateCanceled, + State: invpkg.HtlcStateCanceled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractAccepted, + invState: invpkg.ContractAccepted, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateCanceled, + State: invpkg.HtlcStateCanceled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, @@ -2798,33 +2870,33 @@ func TestUpdateHTLC(t *testing.T) { }, { name: "cancel invoice htlc already canceled", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateCanceled, + State: invpkg.HtlcStateCanceled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractCanceled, + invState: invpkg.ContractCanceled, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateCanceled, + State: invpkg.HtlcStateCanceled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, @@ -2834,33 +2906,33 @@ func TestUpdateHTLC(t *testing.T) { }, { name: "settle invoice htlc already canceled", - input: InvoiceHTLC{ + input: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateCanceled, + State: invpkg.HtlcStateCanceled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, }, }, - invState: ContractSettled, + invState: invpkg.ContractSettled, setID: &setID, - output: InvoiceHTLC{ + output: invpkg.InvoiceHTLC{ Amt: 5000, MppTotalAmt: 5000, AcceptHeight: 100, AcceptTime: testNow, ResolveTime: testAlreadyNow, Expiry: 40, - State: HtlcStateCanceled, + State: invpkg.HtlcStateCanceled, CustomRecords: make(record.CustomSet), - AMP: &InvoiceHtlcAMPData{ + AMP: &invpkg.InvoiceHtlcAMPData{ Record: *ampRecord, Hash: hash, Preimage: &preimage, @@ -2895,7 +2967,7 @@ func TestDeleteInvoices(t *testing.T) { // Add some invoices to the test db. numInvoices := 3 - invoicesToDelete := make([]InvoiceDeleteRef, numInvoices) + invoicesToDelete := make([]invpkg.InvoiceDeleteRef, numInvoices) for i := 0; i < numInvoices; i++ { invoice, err := randInvoice(lnwire.MilliSatoshi(i + 1)) @@ -2908,14 +2980,14 @@ func TestDeleteInvoices(t *testing.T) { // Settle the second invoice. if i == 1 { invoice, err = db.UpdateInvoice( - InvoiceRefByHash(paymentHash), nil, + invpkg.InvoiceRefByHash(paymentHash), nil, getUpdateInvoice(invoice.Terms.Value), ) require.NoError(t, err, "unable to settle invoice") } // store the delete ref for later. - invoicesToDelete[i] = InvoiceDeleteRef{ + invoicesToDelete[i] = invpkg.InvoiceDeleteRef{ PayHash: paymentHash, PayAddr: &invoice.Terms.PaymentAddr, AddIndex: addIndex, @@ -2927,7 +2999,7 @@ func TestDeleteInvoices(t *testing.T) { // to the passed count. assertInvoiceCount := func(count int) { // Query to collect all invoices. - query := InvoiceQuery{ + query := invpkg.InvoiceQuery{ IndexOffset: 0, NumMaxInvoices: math.MaxUint64, } @@ -3015,9 +3087,9 @@ func TestEncodeDecodeAmpInvoiceState(t *testing.T) { // Make a sample invoice state map that we'll encode then decode to // assert equality of. - ampState := AMPInvoiceState{ - setID1: InvoiceStateAMP{ - State: HtlcStateSettled, + ampState := invpkg.AMPInvoiceState{ + setID1: invpkg.InvoiceStateAMP{ + State: invpkg.HtlcStateSettled, SettleDate: testNow, SettleIndex: 1, InvoiceKeys: map[models.CircuitKey]struct{}{ @@ -3026,8 +3098,8 @@ func TestEncodeDecodeAmpInvoiceState(t *testing.T) { }, AmtPaid: 5, }, - setID2: InvoiceStateAMP{ - State: HtlcStateCanceled, + setID2: invpkg.InvoiceStateAMP{ + State: invpkg.HtlcStateCanceled, SettleDate: testNow, SettleIndex: 2, InvoiceKeys: map[models.CircuitKey]struct{}{ @@ -3035,8 +3107,8 @@ func TestEncodeDecodeAmpInvoiceState(t *testing.T) { }, AmtPaid: 6, }, - setID3: InvoiceStateAMP{ - State: HtlcStateAccepted, + setID3: invpkg.InvoiceStateAMP{ + State: invpkg.HtlcStateAccepted, SettleDate: testNow, SettleIndex: 3, InvoiceKeys: map[models.CircuitKey]struct{}{ @@ -3052,8 +3124,9 @@ func TestEncodeDecodeAmpInvoiceState(t *testing.T) { // amp state we created above. tlvStream, err := tlv.NewStream( tlv.MakeDynamicRecord( - invoiceAmpStateType, &State, ampState.recordSize, - ampStateEncoder, ampStateDecoder, + invoiceAmpStateType, &State, + ampRecordSize(&State), ampStateEncoder, + ampStateDecoder, ), ) require.Nil(t, err) @@ -3065,7 +3138,7 @@ func TestEncodeDecodeAmpInvoiceState(t *testing.T) { // Now create a new blank ampState map, which we'll use to decode the // bytes into. - ampState2 := make(AMPInvoiceState) + ampState2 := make(invpkg.AMPInvoiceState) // Decode from the raw stream into this blank mpa. tlvStream, err = tlv.NewStream( diff --git a/channeldb/invoices.go b/channeldb/invoices.go index b7174914b..c05776eb5 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -6,12 +6,11 @@ import ( "errors" "fmt" "io" - "strings" "time" "github.com/lightningnetwork/lnd/channeldb/models" - "github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/htlcswitch/hop" + invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -20,16 +19,6 @@ import ( ) var ( - // unknownPreimage is an all-zeroes preimage that indicates that the - // preimage for this invoice is not yet known. - unknownPreimage lntypes.Preimage - - // BlankPayAddr is a sentinel payment address for legacy invoices. - // Invoices with this payment address are special-cased in the insertion - // logic to prevent being indexed in the payment address index, - // otherwise they would cause collisions after the first insertion. - BlankPayAddr [32]byte - // invoiceBucket is the name of the bucket within the database that // stores all data related to invoices no matter their final state. // Within the invoice bucket, each invoice is keyed by its invoice ID @@ -90,91 +79,9 @@ var ( // // settleIndexNo => invoiceKey settleIndexBucket = []byte("invoice-settle-index") - - // ErrInvoiceAlreadySettled is returned when the invoice is already - // settled. - ErrInvoiceAlreadySettled = errors.New("invoice already settled") - - // ErrInvoiceAlreadyCanceled is returned when the invoice is already - // canceled. - ErrInvoiceAlreadyCanceled = errors.New("invoice already canceled") - - // ErrInvoiceAlreadyAccepted is returned when the invoice is already - // accepted. - ErrInvoiceAlreadyAccepted = errors.New("invoice already accepted") - - // ErrInvoiceStillOpen is returned when the invoice is still open. - ErrInvoiceStillOpen = errors.New("invoice still open") - - // ErrInvoiceCannotOpen is returned when an attempt is made to move an - // invoice to the open state. - ErrInvoiceCannotOpen = errors.New("cannot move invoice to open") - - // ErrInvoiceCannotAccept is returned when an attempt is made to accept - // an invoice while the invoice is not in the open state. - ErrInvoiceCannotAccept = errors.New("cannot accept invoice") - - // ErrInvoicePreimageMismatch is returned when the preimage doesn't - // match the invoice hash. - ErrInvoicePreimageMismatch = errors.New("preimage does not match") - - // ErrHTLCPreimageMissing is returned when trying to accept/settle an - // AMP HTLC but the HTLC-level preimage has not been set. - ErrHTLCPreimageMissing = errors.New("AMP htlc missing preimage") - - // ErrHTLCPreimageMismatch is returned when trying to accept/settle an - // AMP HTLC but the HTLC-level preimage does not satisfying the - // HTLC-level payment hash. - ErrHTLCPreimageMismatch = errors.New("htlc preimage mismatch") - - // ErrHTLCAlreadySettled is returned when trying to settle an invoice - // but HTLC already exists in the settled state. - ErrHTLCAlreadySettled = errors.New("htlc already settled") - - // ErrInvoiceHasHtlcs is returned when attempting to insert an invoice - // that already has HTLCs. - ErrInvoiceHasHtlcs = errors.New("cannot add invoice with htlcs") - - // ErrEmptyHTLCSet is returned when attempting to accept or settle and - // HTLC set that has no HTLCs. - ErrEmptyHTLCSet = errors.New("cannot settle/accept empty HTLC set") - - // ErrUnexpectedInvoicePreimage is returned when an invoice-level - // preimage is provided when trying to settle an invoice that shouldn't - // have one, e.g. an AMP invoice. - ErrUnexpectedInvoicePreimage = errors.New( - "unexpected invoice preimage provided on settle", - ) - - // ErrHTLCPreimageAlreadyExists is returned when trying to set an - // htlc-level preimage but one is already known. - ErrHTLCPreimageAlreadyExists = errors.New( - "htlc-level preimage already exists", - ) ) -// ErrDuplicateSetID is an error returned when attempting to adding an AMP HTLC -// to an invoice, but another invoice is already indexed by the same set id. -type ErrDuplicateSetID struct { - setID [32]byte -} - -// Error returns a human-readable description of ErrDuplicateSetID. -func (e ErrDuplicateSetID) Error() string { - return fmt.Sprintf("invoice with set_id=%x already exists", e.setID) -} - const ( - // MaxMemoSize is maximum size of the memo field within invoices stored - // in the database. - MaxMemoSize = 1024 - - // MaxPaymentRequestSize is the max size of a payment request for - // this invoice. - // TODO(halseth): determine the max length payment request when field - // lengths are final. - MaxPaymentRequestSize = 4096 - // A set of tlv type definitions used to serialize invoice htlcs to the // database. // @@ -228,715 +135,15 @@ const ( ampStateAmtPaidType tlv.Type = 5 ) -// RefModifier is a modification on top of a base invoice ref. It allows the -// caller to opt to skip out on HTLCs for a given payAddr, or only return the -// set of specified HTLCs for a given setID. -type RefModifier uint8 - -const ( - // DefaultModifier is the base modifier that doesn't change any behavior. - DefaultModifier RefModifier = iota - - // HtlcSetOnlyModifier can only be used with a setID based invoice ref, and - // specifies that only the set of HTLCs related to that setID are to be - // returned. - HtlcSetOnlyModifier - - // HtlcSetOnlyModifier can only be used with a payAddr based invoice ref, - // and specifies that the returned invoice shouldn't include any HTLCs at - // all. - HtlcSetBlankModifier -) - -// InvoiceRef is a composite identifier for invoices. Invoices can be referenced -// by various combinations of payment hash and payment addr, in certain contexts -// only some of these are known. An InvoiceRef and its constructors thus -// encapsulate the valid combinations of query parameters that can be supplied -// to LookupInvoice and UpdateInvoice. -type InvoiceRef struct { - // payHash is the payment hash of the target invoice. All invoices are - // currently indexed by payment hash. This value will be used as a - // fallback when no payment address is known. - payHash *lntypes.Hash - - // payAddr is the payment addr of the target invoice. Newer invoices - // (0.11 and up) are indexed by payment address in addition to payment - // hash, but pre 0.8 invoices do not have one at all. When this value is - // known it will be used as the primary identifier, falling back to - // payHash if no value is known. - payAddr *[32]byte - - // setID is the optional set id for an AMP payment. This can be used to - // lookup or update the invoice knowing only this value. Queries by set - // id are only used to facilitate user-facing requests, e.g. lookup, - // settle or cancel an AMP invoice. The regular update flow from the - // invoice registry will always query for the invoice by - // payHash+payAddr. - setID *[32]byte - - // refModifier allows an invoice ref to include or exclude specific - // HTLC sets based on the payAddr or setId. - refModifier RefModifier -} - -// InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by -// its payment hash. -func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef { - return InvoiceRef{ - payHash: &payHash, - } -} - -// InvoiceRefByHashAndAddr creates an InvoiceRef that first queries for an -// invoice by the provided payment address, falling back to the payment hash if -// the payment address is unknown. -func InvoiceRefByHashAndAddr(payHash lntypes.Hash, - payAddr [32]byte) InvoiceRef { - - return InvoiceRef{ - payHash: &payHash, - payAddr: &payAddr, - } -} - -// InvoiceRefByAddr creates an InvoiceRef that queries the payment addr index -// for an invoice with the provided payment address. -func InvoiceRefByAddr(addr [32]byte) InvoiceRef { - return InvoiceRef{ - payAddr: &addr, - } -} - -// InvoiceRefByAddrBlankHtlc creates an InvoiceRef that queries the payment addr index -// for an invoice with the provided payment address, but excludes any of the -// core HTLC information. -func InvoiceRefByAddrBlankHtlc(addr [32]byte) InvoiceRef { - return InvoiceRef{ - payAddr: &addr, - refModifier: HtlcSetBlankModifier, - } -} - -// InvoiceRefBySetID creates an InvoiceRef that queries the set id index for an -// invoice with the provided setID. If the invoice is not found, the query will -// not fallback to payHash or payAddr. -func InvoiceRefBySetID(setID [32]byte) InvoiceRef { - return InvoiceRef{ - setID: &setID, - } -} - -// InvoiceRefBySetIDFiltered is similar to the InvoiceRefBySetID identifier, -// but it specifies that the returned set of HTLCs should be filtered to only -// include HTLCs that are part of that set. -func InvoiceRefBySetIDFiltered(setID [32]byte) InvoiceRef { - return InvoiceRef{ - setID: &setID, - refModifier: HtlcSetOnlyModifier, - } -} - -// PayHash returns the optional payment hash of the target invoice. -// -// NOTE: This value may be nil. -func (r InvoiceRef) PayHash() *lntypes.Hash { - if r.payHash != nil { - hash := *r.payHash - return &hash - } - return nil -} - -// PayAddr returns the optional payment address of the target invoice. -// -// NOTE: This value may be nil. -func (r InvoiceRef) PayAddr() *[32]byte { - if r.payAddr != nil { - addr := *r.payAddr - return &addr - } - return nil -} - -// SetID returns the optional set id of the target invoice. -// -// NOTE: This value may be nil. -func (r InvoiceRef) SetID() *[32]byte { - if r.setID != nil { - id := *r.setID - return &id - } - return nil -} - -// Modifier defines the set of available modifications to the base invoice ref -// look up that are available. -func (r InvoiceRef) Modifier() RefModifier { - return r.refModifier -} - -// String returns a human-readable representation of an InvoiceRef. -func (r InvoiceRef) String() string { - var ids []string - if r.payHash != nil { - ids = append(ids, fmt.Sprintf("pay_hash=%v", *r.payHash)) - } - if r.payAddr != nil { - ids = append(ids, fmt.Sprintf("pay_addr=%x", *r.payAddr)) - } - if r.setID != nil { - ids = append(ids, fmt.Sprintf("set_id=%x", *r.setID)) - } - return fmt.Sprintf("(%s)", strings.Join(ids, ", ")) -} - -// ContractState describes the state the invoice is in. -type ContractState uint8 - -const ( - // ContractOpen means the invoice has only been created. - ContractOpen ContractState = 0 - - // ContractSettled means the htlc is settled and the invoice has been paid. - ContractSettled ContractState = 1 - - // ContractCanceled means the invoice has been canceled. - ContractCanceled ContractState = 2 - - // ContractAccepted means the HTLC has been accepted but not settled yet. - ContractAccepted ContractState = 3 -) - -// String returns a human readable identifier for the ContractState type. -func (c ContractState) String() string { - switch c { - case ContractOpen: - return "Open" - case ContractSettled: - return "Settled" - case ContractCanceled: - return "Canceled" - case ContractAccepted: - return "Accepted" - } - - return "Unknown" -} - -// IsFinal returns a boolean indicating whether an invoice state is final. -func (c ContractState) IsFinal() bool { - return c == ContractSettled || c == ContractCanceled -} - -// ContractTerm is a companion struct to the Invoice struct. This struct houses -// the necessary conditions required before the invoice can be considered fully -// settled by the payee. -type ContractTerm struct { - // FinalCltvDelta is the minimum required number of blocks before htlc - // expiry when the invoice is accepted. - FinalCltvDelta int32 - - // Expiry defines how long after creation this invoice should expire. - Expiry time.Duration - - // PaymentPreimage is the preimage which is to be revealed in the - // occasion that an HTLC paying to the hash of this preimage is - // extended. Set to nil if the preimage isn't known yet. - PaymentPreimage *lntypes.Preimage - - // Value is the expected amount of milli-satoshis to be paid to an HTLC - // which can be satisfied by the above preimage. - Value lnwire.MilliSatoshi - - // PaymentAddr is a randomly generated value include in the MPP record - // by the sender to prevent probing of the receiver. - PaymentAddr [32]byte - - // Features is the feature vectors advertised on the payment request. - Features *lnwire.FeatureVector -} - -// String returns a human-readable description of the prominent contract terms. -func (c ContractTerm) String() string { - return fmt.Sprintf("amt=%v, expiry=%v, final_cltv_delta=%v", c.Value, - c.Expiry, c.FinalCltvDelta) -} - -// SetID is the extra unique tuple item for AMP invoices. In addition to -// setting a payment address, each repeated payment to an AMP invoice will also -// contain a set ID as well. -type SetID [32]byte - -// InvoiceStateAMP is a struct that associates the current state of an AMP -// invoice identified by its set ID along with the set of invoices identified -// by the circuit key. This allows callers to easily look up the latest state -// of an AMP "sub-invoice" and also look up the invoice HLTCs themselves in the -// greater HTLC map index. -type InvoiceStateAMP struct { - // State is the state of this sub-AMP invoice. - State HtlcState - - // SettleIndex indicates the location in the settle index that - // references this instance of InvoiceStateAMP, but only if - // this value is set (non-zero), and State is HtlcStateSettled. - SettleIndex uint64 - - // SettleDate is the date that the setID was settled. - SettleDate time.Time - - // InvoiceKeys is the set of circuit keys that can be used to locate - // the invoices for a given set ID. - InvoiceKeys map[models.CircuitKey]struct{} - - // AmtPaid is the total amount that was paid in the AMP sub-invoice. - // Fetching the full HTLC/invoice state allows one to extract the - // custom records as well as the break down of the payment splits used - // when paying. - AmtPaid lnwire.MilliSatoshi -} - -// copy makes a deep copy of the underlying InvoiceStateAMP. -func (i *InvoiceStateAMP) copy() (InvoiceStateAMP, error) { - result := *i - - // Make a copy of the InvoiceKeys map. - result.InvoiceKeys = make(map[models.CircuitKey]struct{}) - for k := range i.InvoiceKeys { - result.InvoiceKeys[k] = struct{}{} - } - - // As a safety measure, copy SettleDate. time.Time is concurrency safe - // except when using any of the (un)marshalling methods. - settleDateBytes, err := i.SettleDate.MarshalBinary() - if err != nil { - return InvoiceStateAMP{}, err - } - - err = result.SettleDate.UnmarshalBinary(settleDateBytes) - if err != nil { - return InvoiceStateAMP{}, err - } - - return result, nil -} - -// AMPInvoiceState represents a type that stores metadata related to the set of -// settled AMP "sub-invoices". -type AMPInvoiceState map[SetID]InvoiceStateAMP - -// recordSize returns the amount of bytes this TLV record will occupy when -// encoded. -func (a *AMPInvoiceState) recordSize() uint64 { - var ( - b bytes.Buffer - buf [8]byte - ) - - // We know that encoding works since the tests pass in the build this file - // is checked into, so we'll simplify things and simply encode it ourselves - // then report the total amount of bytes used. - if err := ampStateEncoder(&b, a, &buf); err != nil { - // This should never error out, but we log it just in case it - // does. - log.Errorf("encoding the amp invoice state failed: %v", err) - } - - return uint64(len(b.Bytes())) -} - -// Invoice is a payment invoice generated by a payee in order to request -// payment for some good or service. The inclusion of invoices within Lightning -// creates a payment work flow for merchants very similar to that of the -// existing financial system within PayPal, etc. Invoices are added to the -// database when a payment is requested, then can be settled manually once the -// payment is received at the upper layer. For record keeping purposes, -// invoices are never deleted from the database, instead a bit is toggled -// denoting the invoice has been fully settled. Within the database, all -// invoices must have a unique payment hash which is generated by taking the -// sha256 of the payment preimage. -type Invoice struct { - // Memo is an optional memo to be stored along side an invoice. The - // memo may contain further details pertaining to the invoice itself, - // or any other message which fits within the size constraints. - Memo []byte - - // PaymentRequest is the encoded payment request for this invoice. For - // spontaneous (keysend) payments, this field will be empty. - PaymentRequest []byte - - // CreationDate is the exact time the invoice was created. - CreationDate time.Time - - // SettleDate is the exact time the invoice was settled. - SettleDate time.Time - - // Terms are the contractual payment terms of the invoice. Once all the - // terms have been satisfied by the payer, then the invoice can be - // considered fully fulfilled. - // - // TODO(roasbeef): later allow for multiple terms to fulfill the final - // invoice: payment fragmentation, etc. - Terms ContractTerm - - // AddIndex is an auto-incrementing integer that acts as a - // monotonically increasing sequence number for all invoices created. - // Clients can then use this field as a "checkpoint" of sorts when - // implementing a streaming RPC to notify consumers of instances where - // an invoice has been added before they re-connected. - // - // NOTE: This index starts at 1. - AddIndex uint64 - - // SettleIndex is an auto-incrementing integer that acts as a - // monotonically increasing sequence number for all settled invoices. - // Clients can then use this field as a "checkpoint" of sorts when - // implementing a streaming RPC to notify consumers of instances where - // an invoice has been settled before they re-connected. - // - // NOTE: This index starts at 1. - SettleIndex uint64 - - // State describes the state the invoice is in. This is the global - // state of the invoice which may remain open even when a series of - // sub-invoices for this invoice has been settled. - State ContractState - - // AmtPaid is the final amount that we ultimately accepted for pay for - // this invoice. We specify this value independently as it's possible - // that the invoice originally didn't specify an amount, or the sender - // overpaid. - AmtPaid lnwire.MilliSatoshi - - // Htlcs records all htlcs that paid to this invoice. Some of these - // htlcs may have been marked as canceled. - Htlcs map[models.CircuitKey]*InvoiceHTLC - - // AMPState describes the state of any related sub-invoices AMP to this - // greater invoice. A sub-invoice is defined by a set of HTLCs with the - // same set ID that attempt to make one time or recurring payments to - // this greater invoice. It's possible for a sub-invoice to be canceled - // or settled, but the greater invoice still open. - AMPState AMPInvoiceState - - // HodlInvoice indicates whether the invoice should be held in the - // Accepted state or be settled right away. - HodlInvoice bool -} - -// HTLCSet returns the set of HTLCs belonging to setID and in the provided -// state. Passing a nil setID will return all HTLCs in the provided state in the -// case of legacy or MPP, and no HTLCs in the case of AMP. Otherwise, the -// returned set will be filtered by the populated setID which is used to -// retrieve AMP HTLC sets. -func (i *Invoice) HTLCSet(setID *[32]byte, - state HtlcState) map[models.CircuitKey]*InvoiceHTLC { - - htlcSet := make(map[models.CircuitKey]*InvoiceHTLC) - for key, htlc := range i.Htlcs { - // Only add HTLCs that are in the requested HtlcState. - if htlc.State != state { - continue - } - - if !htlc.IsInHTLCSet(setID) { - continue - } - - htlcSet[key] = htlc - } - - return htlcSet -} - -// HTLCSetCompliment returns the set of all HTLCs not belonging to setID that -// are in the target state. Passing a nil setID will return no invoices, since -// all MPP HTLCs are part of the same HTLC set. -func (i *Invoice) HTLCSetCompliment(setID *[32]byte, - state HtlcState) map[models.CircuitKey]*InvoiceHTLC { - - htlcSet := make(map[models.CircuitKey]*InvoiceHTLC) - for key, htlc := range i.Htlcs { - // Only add HTLCs that are in the requested HtlcState. - if htlc.State != state { - continue - } - - // We are constructing the compliment, so filter anything that - // matches this set id. - if htlc.IsInHTLCSet(setID) { - continue - } - - htlcSet[key] = htlc - } - - return htlcSet -} - -// HtlcState defines the states an htlc paying to an invoice can be in. -type HtlcState uint8 - -const ( - // HtlcStateAccepted indicates the htlc is locked-in, but not resolved. - HtlcStateAccepted HtlcState = iota - - // HtlcStateCanceled indicates the htlc is canceled back to the - // sender. - HtlcStateCanceled - - // HtlcStateSettled indicates the htlc is settled. - HtlcStateSettled -) - -// InvoiceHTLC contains details about an htlc paying to this invoice. -type InvoiceHTLC struct { - // Amt is the amount that is carried by this htlc. - Amt lnwire.MilliSatoshi - - // MppTotalAmt is a field for mpp that indicates the expected total - // amount. - MppTotalAmt lnwire.MilliSatoshi - - // AcceptHeight is the block height at which the invoice registry - // decided to accept this htlc as a payment to the invoice. At this - // height, the invoice cltv delay must have been met. - AcceptHeight uint32 - - // AcceptTime is the wall clock time at which the invoice registry - // decided to accept the htlc. - AcceptTime time.Time - - // ResolveTime is the wall clock time at which the invoice registry - // decided to settle the htlc. - ResolveTime time.Time - - // Expiry is the expiry height of this htlc. - Expiry uint32 - - // State indicates the state the invoice htlc is currently in. A - // canceled htlc isn't just removed from the invoice htlcs map, because - // we need AcceptHeight to properly cancel the htlc back. - State HtlcState - - // CustomRecords contains the custom key/value pairs that accompanied - // the htlc. - CustomRecords record.CustomSet - - // AMP encapsulates additional data relevant to AMP HTLCs. This includes - // the AMP onion record, in addition to the HTLC's payment hash and - // preimage since these are unique to each AMP HTLC, and not the invoice - // as a whole. - // - // NOTE: This value will only be set for AMP HTLCs. - AMP *InvoiceHtlcAMPData -} - -// Copy makes a deep copy of the target InvoiceHTLC. -func (h *InvoiceHTLC) Copy() *InvoiceHTLC { - result := *h - - // Make a copy of the CustomSet map. - result.CustomRecords = make(record.CustomSet) - for k, v := range h.CustomRecords { - result.CustomRecords[k] = v - } - - result.AMP = h.AMP.Copy() - - return &result -} - -// IsInHTLCSet returns true if this HTLC is part an HTLC set. If nil is passed, -// this method returns true if this is an MPP HTLC. Otherwise, it only returns -// true if the AMP HTLC's set id matches the populated setID. -func (h *InvoiceHTLC) IsInHTLCSet(setID *[32]byte) bool { - wantAMPSet := setID != nil - isAMPHtlc := h.AMP != nil - - // Non-AMP HTLCs cannot be part of AMP HTLC sets, and vice versa. - if wantAMPSet != isAMPHtlc { - return false - } - - // Skip AMP HTLCs that have differing set ids. - if isAMPHtlc && *setID != h.AMP.Record.SetID() { - return false - } - - return true -} - -// InvoiceHtlcAMPData is a struct hodling the additional metadata stored for -// each received AMP HTLC. This includes the AMP onion record, in addition to -// the HTLC's payment hash and preimage. -type InvoiceHtlcAMPData struct { - // AMP is a copy of the AMP record presented in the onion payload - // containing the information necessary to correlate and settle a - // spontaneous HTLC set. Newly accepted legacy keysend payments will - // also have this field set as we automatically promote them into an AMP - // payment for internal processing. - Record record.AMP - - // Hash is an HTLC-level payment hash that is stored only for AMP - // payments. This is done because an AMP HTLC will carry a different - // payment hash from the invoice it might be satisfying, so we track the - // payment hashes individually to able to compute whether or not the - // reconstructed preimage correctly matches the HTLC's hash. - Hash lntypes.Hash - - // Preimage is an HTLC-level preimage that satisfies the AMP HTLC's - // Hash. The preimage will be be derived either from secret share - // reconstruction of the shares in the AMP payload. - // - // NOTE: Preimage will only be present once the HTLC is in - // HtlcStateSettled. - Preimage *lntypes.Preimage -} - -// Copy returns a deep copy of the InvoiceHtlcAMPData. -func (d *InvoiceHtlcAMPData) Copy() *InvoiceHtlcAMPData { - if d == nil { - return nil - } - - var preimage *lntypes.Preimage - if d.Preimage != nil { - pimg := *d.Preimage - preimage = &pimg - } - - return &InvoiceHtlcAMPData{ - Record: d.Record, - Hash: d.Hash, - Preimage: preimage, - } -} - -// HtlcAcceptDesc describes the details of a newly accepted htlc. -type HtlcAcceptDesc struct { - // AcceptHeight is the block height at which this htlc was accepted. - AcceptHeight int32 - - // Amt is the amount that is carried by this htlc. - Amt lnwire.MilliSatoshi - - // MppTotalAmt is a field for mpp that indicates the expected total - // amount. - MppTotalAmt lnwire.MilliSatoshi - - // Expiry is the expiry height of this htlc. - Expiry uint32 - - // CustomRecords contains the custom key/value pairs that accompanied - // the htlc. - CustomRecords record.CustomSet - - // AMP encapsulates additional data relevant to AMP HTLCs. This includes - // the AMP onion record, in addition to the HTLC's payment hash and - // preimage since these are unique to each AMP HTLC, and not the invoice - // as a whole. - // - // NOTE: This value will only be set for AMP HTLCs. - AMP *InvoiceHtlcAMPData -} - -// InvoiceUpdateDesc describes the changes that should be applied to the -// invoice. -type InvoiceUpdateDesc struct { - // State is the new state that this invoice should progress to. If nil, - // the state is left unchanged. - State *InvoiceStateUpdateDesc - - // CancelHtlcs describes the htlcs that need to be canceled. - CancelHtlcs map[models.CircuitKey]struct{} - - // AddHtlcs describes the newly accepted htlcs that need to be added to - // the invoice. - AddHtlcs map[models.CircuitKey]*HtlcAcceptDesc - - // SetID is an optional set ID for AMP invoices that allows operations - // to be more efficient by ensuring we don't need to read out the - // entire HTLC set each timee an HTLC is to be cancelled. - SetID *SetID -} - -// InvoiceStateUpdateDesc describes an invoice-level state transition. -type InvoiceStateUpdateDesc struct { - // NewState is the new state that this invoice should progress to. - NewState ContractState - - // Preimage must be set to the preimage when NewState is settled. - Preimage *lntypes.Preimage - - // HTLCPreimages set the HTLC-level preimages stored for AMP HTLCs. - // These are only learned when settling the invoice as a whole. Must be - // set when settling an invoice with non-nil SetID. - HTLCPreimages map[models.CircuitKey]lntypes.Preimage - - // SetID identifies a specific set of HTLCs destined for the same - // invoice as part of a larger AMP payment. This value will be nil for - // legacy or MPP payments. - SetID *[32]byte -} - -// InvoiceUpdateCallback is a callback used in the db transaction to update the -// invoice. -type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error) - -func validateInvoice(i *Invoice, paymentHash lntypes.Hash) error { - // Avoid conflicts with all-zeroes magic value in the database. - if paymentHash == unknownPreimage.Hash() { - return fmt.Errorf("cannot use hash of all-zeroes preimage") - } - - if len(i.Memo) > MaxMemoSize { - return fmt.Errorf("max length a memo is %v, and invoice "+ - "of length %v was provided", MaxMemoSize, len(i.Memo)) - } - if len(i.PaymentRequest) > MaxPaymentRequestSize { - return fmt.Errorf("max length of payment request is %v, length "+ - "provided was %v", MaxPaymentRequestSize, - len(i.PaymentRequest)) - } - if i.Terms.Features == nil { - return errors.New("invoice must have a feature vector") - } - - err := feature.ValidateDeps(i.Terms.Features) - if err != nil { - return err - } - - // AMP invoices and hodl invoices are allowed to have no preimage - // specified. - isAMP := i.Terms.Features.HasFeature( - lnwire.AMPOptional, - ) - if i.Terms.PaymentPreimage == nil && !(i.HodlInvoice || isAMP) { - return errors.New("non-hodl invoices must have a preimage") - } - - if len(i.Htlcs) > 0 { - return ErrInvoiceHasHtlcs - } - - return nil -} - -// IsPending returns true if the invoice is in ContractOpen state. -func (i *Invoice) IsPending() bool { - return i.State == ContractOpen || i.State == ContractAccepted -} - // AddInvoice inserts the targeted invoice into the database. If the invoice has // *any* payment hashes which already exists within the database, then the // insertion will be aborted and rejected due to the strict policy banning any // duplicate payment hashes. A side effect of this function is that it sets // AddIndex on newInvoice. -func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( +func (d *DB) AddInvoice(newInvoice *invpkg.Invoice, paymentHash lntypes.Hash) ( uint64, error) { - if err := validateInvoice(newInvoice, paymentHash); err != nil { + if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil { return 0, err } @@ -963,7 +170,7 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( // Ensure that an invoice an identical payment hash doesn't // already exist within the index. if invoiceIndex.Get(paymentHash[:]) != nil { - return ErrDuplicateInvoice + return invpkg.ErrDuplicateInvoice } // Check that we aren't inserting an invoice with a duplicate @@ -972,9 +179,10 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( // assign one. This is safe since later we also will avoid // indexing them and avoid collisions. payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket) - if newInvoice.Terms.PaymentAddr != BlankPayAddr { - if payAddrIndex.Get(newInvoice.Terms.PaymentAddr[:]) != nil { - return ErrDuplicatePayAddr + if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr { + paymentAddr := newInvoice.Terms.PaymentAddr[:] + if payAddrIndex.Get(paymentAddr) != nil { + return invpkg.ErrDuplicatePayAddr } } @@ -1021,8 +229,10 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( // // NOTE: The index starts from 1, as a result. We enforce that specifying a // value below the starting index value is a noop. -func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { - var newInvoices []Invoice +func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]invpkg.Invoice, + error) { + + var newInvoices []invpkg.Invoice // If an index of zero was specified, then in order to maintain // backwards compat, we won't send out any new invoices. @@ -1083,16 +293,16 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { // full invoice is returned. Before setting the incoming HTLC, the values // SHOULD be checked to ensure the payer meets the agreed upon contractual // terms of the payment. -func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { - var invoice Invoice +func (d *DB) LookupInvoice(ref invpkg.InvoiceRef) (invpkg.Invoice, error) { + var invoice invpkg.Invoice err := kvdb.View(d, func(tx kvdb.RTx) error { invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { - return ErrNoInvoicesCreated + return invpkg.ErrNoInvoicesCreated } invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket) if invoiceIndex == nil { - return ErrNoInvoicesCreated + return invpkg.ErrNoInvoicesCreated } payAddrIndex := tx.ReadBucket(payAddrIndexBucket) setIDIndex := tx.ReadBucket(setIDIndexBucket) @@ -1106,20 +316,24 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { return err } - var setID *SetID + var setID *invpkg.SetID switch { // If this is a payment address ref, and the blank modified was // specified, then we'll use the zero set ID to indicate that // we won't want any HTLCs returned. - case ref.PayAddr() != nil && ref.Modifier() == HtlcSetBlankModifier: - var zeroSetID SetID + case ref.PayAddr() != nil && + ref.Modifier() == invpkg.HtlcSetBlankModifier: + + var zeroSetID invpkg.SetID setID = &zeroSetID // If this is a set ID ref, and the htlc set only modified was // specified, then we'll pass through the specified setID so // only that will be returned. - case ref.SetID() != nil && ref.Modifier() == HtlcSetOnlyModifier: - setID = (*SetID)(ref.SetID()) + case ref.SetID() != nil && + ref.Modifier() == invpkg.HtlcSetOnlyModifier: + + setID = (*invpkg.SetID)(ref.SetID()) } // An invoice was found, retrieve the remainder of the invoice @@ -1144,7 +358,7 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { // back to the payment hash if nothing is found for the payment address. An // error is returned if the invoice is not found. func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket, - ref InvoiceRef) ([]byte, error) { + ref invpkg.InvoiceRef) ([]byte, error) { // If the set id is present, we only consult the set id index for this // invoice. This type of query is only used to facilitate user-facing @@ -1153,7 +367,7 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket, if setID != nil { invoiceNumBySetID := setIDIndex.Get(setID[:]) if invoiceNumBySetID == nil { - return nil, ErrInvoiceNotFound + return nil, invpkg.ErrInvoiceNotFound } return invoiceNumBySetID, nil @@ -1174,7 +388,7 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket, // Only allow lookups for payment address if it is not a // blank payment address, which is a special-cased value // for legacy keysend invoices. - if *payAddr != BlankPayAddr { + if *payAddr != invpkg.BlankPayAddr { return payAddrIndex.Get(payAddr[:]) } } @@ -1188,7 +402,7 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket, // invoice, ensure they reference the _same_ invoice. case invoiceNumByAddr != nil && invoiceNumByHash != nil: if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) { - return nil, ErrInvRefEquivocation + return nil, invpkg.ErrInvRefEquivocation } return invoiceNumByAddr, nil @@ -1212,7 +426,7 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket, // Otherwise we don't know of the target invoice. default: - return nil, ErrInvoiceNotFound + return nil, invpkg.ErrInvoiceNotFound } } @@ -1220,13 +434,11 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket, // for each invoice with its respective payment hash. Additionally a reset() // closure is passed which is used to reset/initialize partial results and also // to signal if the kvdb.View transaction has been retried. -func (d *DB) ScanInvoices( - scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error { - +func (d *DB) ScanInvoices(scanFunc invpkg.InvScanFunc, reset func()) error { return kvdb.View(d, func(tx kvdb.RTx) error { invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { - return ErrNoInvoicesCreated + return invpkg.ErrNoInvoicesCreated } invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket) @@ -1263,65 +475,13 @@ func (d *DB) ScanInvoices( }, reset) } -// InvoiceQuery represents a query to the invoice database. The query allows a -// caller to retrieve all invoices starting from a particular add index and -// limit the number of results returned. -type InvoiceQuery struct { - // IndexOffset is the offset within the add indices to start at. This - // can be used to start the response at a particular invoice. - IndexOffset uint64 - - // NumMaxInvoices is the maximum number of invoices that should be - // starting from the add index. - NumMaxInvoices uint64 - - // PendingOnly, if set, returns unsettled invoices starting from the - // add index. - PendingOnly bool - - // Reversed, if set, indicates that the invoices returned should start - // from the IndexOffset and go backwards. - Reversed bool - - // CreationDateStart, if set, filters out all invoices with a creation - // date greater than or euqal to it. - CreationDateStart time.Time - - // CreationDateEnd, if set, filters out all invoices with a creation - // date less than or euqal to it. - CreationDateEnd time.Time -} - -// InvoiceSlice is the response to a invoice query. It includes the original -// query, the set of invoices that match the query, and an integer which -// represents the offset index of the last item in the set of returned invoices. -// This integer allows callers to resume their query using this offset in the -// event that the query's response exceeds the maximum number of returnable -// invoices. -type InvoiceSlice struct { - InvoiceQuery - - // Invoices is the set of invoices that matched the query above. - Invoices []Invoice - - // FirstIndexOffset is the index of the first element in the set of - // returned Invoices above. Callers can use this to resume their query - // in the event that the slice has too many events to fit into a single - // response. - FirstIndexOffset uint64 - - // LastIndexOffset is the index of the last element in the set of - // returned Invoices above. Callers can use this to resume their query - // in the event that the slice has too many events to fit into a single - // response. - LastIndexOffset uint64 -} - // QueryInvoices allows a caller to query the invoice database for invoices // within the specified add index range. -func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { +func (d *DB) QueryInvoices(q invpkg.InvoiceQuery) (invpkg.InvoiceSlice, + error) { + var ( - resp InvoiceSlice + resp invpkg.InvoiceSlice startDateSet = !q.CreationDateStart.IsZero() endDateSet = !q.CreationDateEnd.IsZero() ) @@ -1331,14 +491,14 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { // within the database yet, so we can simply exit. invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { - return ErrNoInvoicesCreated + return invpkg.ErrNoInvoicesCreated } // Get the add index bucket which we will use to iterate through // our indexed invoices. invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket) if invoiceAddIndex == nil { - return ErrNoInvoicesCreated + return invpkg.ErrNoInvoicesCreated } // Create a paginator which reads from our add index bucket with @@ -1400,19 +560,19 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { if q.Reversed { numInvoices := len(resp.Invoices) for i := 0; i < numInvoices/2; i++ { - opposite := numInvoices - i - 1 - resp.Invoices[i], resp.Invoices[opposite] = - resp.Invoices[opposite], resp.Invoices[i] + reverse := numInvoices - i - 1 + resp.Invoices[i], resp.Invoices[reverse] = + resp.Invoices[reverse], resp.Invoices[i] } } return nil }, func() { - resp = InvoiceSlice{ + resp = invpkg.InvoiceSlice{ InvoiceQuery: q, } }) - if err != nil && err != ErrNoInvoicesCreated { + if err != nil && !errors.Is(err, invpkg.ErrNoInvoicesCreated) { return resp, err } @@ -1420,7 +580,8 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { // so that the caller can resume from this point later on. if len(resp.Invoices) > 0 { resp.FirstIndexOffset = resp.Invoices[0].AddIndex - resp.LastIndexOffset = resp.Invoices[len(resp.Invoices)-1].AddIndex + lastIdx := len(resp.Invoices) - 1 + resp.LastIndexOffset = resp.Invoices[lastIdx].AddIndex } return resp, nil @@ -1433,10 +594,10 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { // The update is performed inside the same database transaction that fetches the // invoice and is therefore atomic. The fields to update are controlled by the // supplied callback. -func (d *DB) UpdateInvoice(ref InvoiceRef, setIDHint *SetID, - callback InvoiceUpdateCallback) (*Invoice, error) { +func (d *DB) UpdateInvoice(ref invpkg.InvoiceRef, setIDHint *invpkg.SetID, + callback invpkg.InvoiceUpdateCallback) (*invpkg.Invoice, error) { - var updatedInvoice *Invoice + var updatedInvoice *invpkg.Invoice err := kvdb.Update(d, func(tx kvdb.RwTx) error { invoices, err := tx.CreateTopLevelBucket(invoiceBucket) if err != nil { @@ -1487,8 +648,10 @@ func (d *DB) UpdateInvoice(ref InvoiceRef, setIDHint *SetID, // // NOTE: The index starts from 1, as a result. We enforce that specifying a // value below the starting index value is a noop. -func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { - var settledInvoices []Invoice +func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]invpkg.Invoice, + error) { + + var settledInvoices []invpkg.Invoice // If an index of zero was specified, then in order to maintain // backwards compat, we won't send out any new invoices. @@ -1527,18 +690,20 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { // and the setID (may not be there). var ( invoiceKey [4]byte - setID *SetID + setID *invpkg.SetID ) valueLen := copy(invoiceKey[:], indexValue) if len(indexValue) == invoiceSetIDKeyLen { - setID = new(SetID) + setID = new(invpkg.SetID) copy(setID[:], indexValue[valueLen:]) } // For each key found, we'll look up the actual // invoice, then accumulate it into our return value. - invoice, err := fetchInvoice(invoiceKey[:], invoices, setID) + invoice, err := fetchInvoice( + invoiceKey[:], invoices, setID, + ) if err != nil { return err } @@ -1558,7 +723,7 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { } func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket, - i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) ( + i *invpkg.Invoice, invoiceNum uint32, paymentHash lntypes.Hash) ( uint64, error) { // Create the invoice key which is just the big-endian representation @@ -1587,7 +752,7 @@ func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket, // has a non-zero payment address. The all-zero payment address is still // in use by legacy keysend, so we special-case here to avoid // collisions. - if i.Terms.PaymentAddr != BlankPayAddr { + if i.Terms.PaymentAddr != invpkg.BlankPayAddr { err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:]) if err != nil { return 0, err @@ -1626,12 +791,34 @@ func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket, return nextAddSeqNo, nil } +// recordSize returns the amount of bytes this TLV record will occupy when +// encoded. +func ampRecordSize(a *invpkg.AMPInvoiceState) func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + + // We know that encoding works since the tests pass in the build this + // file is checked into, so we'll simplify things and simply encode it + // ourselves then report the total amount of bytes used. + if err := ampStateEncoder(&b, a, &buf); err != nil { + // This should never error out, but we log it just in case it + // does. + log.Errorf("encoding the amp invoice state failed: %v", err) + } + + return func() uint64 { + return uint64(len(b.Bytes())) + } +} + // serializeInvoice serializes an invoice to a writer. // // Note: this function is in use for a migration. Before making changes that // would modify the on disk format, make a copy of the original code and store // it with the migration. -func serializeInvoice(w io.Writer, i *Invoice) error { +func serializeInvoice(w io.Writer, i *invpkg.Invoice) error { creationDateBytes, err := i.CreationDate.MarshalBinary() if err != nil { return err @@ -1649,10 +836,10 @@ func serializeInvoice(w io.Writer, i *Invoice) error { } featureBytes := fb.Bytes() - preimage := [32]byte(unknownPreimage) + preimage := [32]byte(invpkg.UnknownPreimage) if i.Terms.PaymentPreimage != nil { preimage = *i.Terms.PaymentPreimage - if preimage == unknownPreimage { + if preimage == invpkg.UnknownPreimage { return errors.New("cannot use all-zeroes preimage") } } @@ -1696,7 +883,7 @@ func serializeInvoice(w io.Writer, i *Invoice) error { // Invoice AMP state. tlv.MakeDynamicRecord( invoiceAmpStateType, &i.AMPState, - i.AMPState.recordSize, + ampRecordSize(&i.AMPState), ampStateEncoder, ampStateDecoder, ), ) @@ -1733,7 +920,7 @@ func serializeInvoice(w io.Writer, i *Invoice) error { // serializeHtlcs serializes a map containing circuit keys and invoice htlcs to // a writer. func serializeHtlcs(w io.Writer, - htlcs map[models.CircuitKey]*InvoiceHTLC) error { + htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) error { for key, htlc := range htlcs { // Encode the htlc in a tlv stream. @@ -1837,9 +1024,10 @@ func getNanoTime(ns uint64) time.Time { // fetchFilteredAmpInvoices retrieves only a select set of AMP invoices // identified by the setID value. func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte, - setIDs ...*SetID) (map[models.CircuitKey]*InvoiceHTLC, error) { + setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC, + error) { - htlcs := make(map[models.CircuitKey]*InvoiceHTLC) + htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC) for _, setID := range setIDs { invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:]) @@ -1848,7 +1036,7 @@ func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte, // A set ID was passed in, but we don't have this // stored yet, meaning that the setID is being added // for the first time. - return htlcs, ErrInvoiceNotFound + return htlcs, invpkg.ErrInvoiceNotFound } htlcSetReader := bytes.NewReader(htlcSetBytes) @@ -1905,7 +1093,8 @@ func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte, // given invoice. If a list of set IDs are specified, then only HTLCs // associated with that setID will be retrieved. func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte, - setIDs ...*SetID) (map[models.CircuitKey]*InvoiceHTLC, error) { + setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC, + error) { // If a set of setIDs was specified, then we can skip the cursor and // just read out exactly what we need. @@ -1917,7 +1106,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte, // Otherwise, iterate over all the htlc sets that are prefixed beside // this invoice in the main invoice bucket. - htlcs := make(map[models.CircuitKey]*InvoiceHTLC) + htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC) err := forEachAMPInvoice(invoiceBucket, invoiceNum, func(key, htlcSet []byte) error { htlcSetReader := bytes.NewReader(htlcSet) @@ -1931,7 +1120,9 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte, } return nil - }) + }, + ) + if err != nil { return nil, err } @@ -1942,17 +1133,19 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte, // fetchInvoice attempts to read out the relevant state for the invoice as // specified by the invoice number. If the setID fields are set, then only the // HTLC information pertaining to those set IDs is returned. -func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, setIDs ...*SetID) (Invoice, error) { +func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, + setIDs ...*invpkg.SetID) (invpkg.Invoice, error) { + invoiceBytes := invoices.Get(invoiceNum) if invoiceBytes == nil { - return Invoice{}, ErrInvoiceNotFound + return invpkg.Invoice{}, invpkg.ErrInvoiceNotFound } invoiceReader := bytes.NewReader(invoiceBytes) invoice, err := deserializeInvoice(invoiceReader) if err != nil { - return Invoice{}, err + return invpkg.Invoice{}, err } // If this is an AMP invoice, then we'll also attempt to read out the @@ -1974,7 +1167,7 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, setIDs ...*SetID) (I // If the "zero" setID was specified, then this means that no HTLC data // should be returned alongside of it. case invoiceIsAMP && len(setIDs) != 0 && setIDs[0] != nil && - *setIDs[0] == BlankPayAddr: + *setIDs[0] == invpkg.BlankPayAddr: return invoice, nil } @@ -1993,12 +1186,12 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, setIDs ...*SetID) (I // an AMP invoice. This methods only decode the relevant state vs the entire // invoice. func fetchInvoiceStateAMP(invoiceNum []byte, - invoices kvdb.RBucket) (AMPInvoiceState, error) { + invoices kvdb.RBucket) (invpkg.AMPInvoiceState, error) { // Fetch the raw invoice bytes. invoiceBytes := invoices.Get(invoiceNum) if invoiceBytes == nil { - return nil, ErrInvoiceNotFound + return nil, invpkg.ErrInvoiceNotFound } r := bytes.NewReader(invoiceBytes) @@ -2011,7 +1204,7 @@ func fetchInvoiceStateAMP(invoiceNum []byte, // Next, we'll make a new TLV stream that only attempts to decode the // bytes we actually need. - ampState := make(AMPInvoiceState) + ampState := make(invpkg.AMPInvoiceState) tlvStream, err := tlv.NewStream( // Invoice AMP state. tlv.MakeDynamicRecord( @@ -2031,7 +1224,7 @@ func fetchInvoiceStateAMP(invoiceNum []byte, return ampState, nil } -func deserializeInvoice(r io.Reader) (Invoice, error) { +func deserializeInvoice(r io.Reader) (invpkg.Invoice, error) { var ( preimageBytes [32]byte value uint64 @@ -2046,8 +1239,8 @@ func deserializeInvoice(r io.Reader) (Invoice, error) { featureBytes []byte ) - var i Invoice - i.AMPState = make(AMPInvoiceState) + var i invpkg.Invoice + i.AMPState = make(invpkg.AMPInvoiceState) tlvStream, err := tlv.NewStream( // Memo and payreq. tlv.MakePrimitiveRecord(memoType, &i.Memo), @@ -2095,7 +1288,7 @@ func deserializeInvoice(r io.Reader) (Invoice, error) { } preimage := lntypes.Preimage(preimageBytes) - if preimage != unknownPreimage { + if preimage != invpkg.UnknownPreimage { i.Terms.PaymentPreimage = &preimage } @@ -2103,7 +1296,7 @@ func deserializeInvoice(r io.Reader) (Invoice, error) { i.Terms.FinalCltvDelta = int32(cltvDelta) i.Terms.Expiry = time.Duration(expiry) i.AmtPaid = lnwire.MilliSatoshi(amtPaid) - i.State = ContractState(state) + i.State = invpkg.ContractState(state) if hodlInvoice != 0 { i.HodlInvoice = true @@ -2173,8 +1366,8 @@ func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte, return err } - // Now that we know how many keys to expect, iterate reading each - // one until we're done. + // Now that we know how many keys to expect, iterate reading + // each one until we're done. for i := uint64(0); i < numKeys; i++ { var ( key models.CircuitKey @@ -2187,7 +1380,8 @@ func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte, key.ChanID = lnwire.NewShortChanIDFromInt(scid) - if err := tlv.DUint64(r, &key.HtlcID, buf, 8); err != nil { + err := tlv.DUint64(r, &key.HtlcID, buf, 8) //nolint:gomnd,lll + if err != nil { return err } @@ -2202,7 +1396,7 @@ func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte, // ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record. func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error { - if v, ok := val.(*AMPInvoiceState); ok { + if v, ok := val.(*invpkg.AMPInvoiceState); ok { // We'll encode the AMP state as a series of KV pairs on the // wire with a length prefix. numRecords := uint64(len(*v)) @@ -2220,7 +1414,8 @@ func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error { ampState := ampState htlcState := uint8(ampState.State) - settleDateBytes, err := ampState.SettleDate.MarshalBinary() + settleDate := ampState.SettleDate + settleDateBytes, err := settleDate.MarshalBinary() if err != nil { return err } @@ -2236,21 +1431,28 @@ func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error { ampStateHtlcStateType, &htlcState, ), tlv.MakePrimitiveRecord( - ampStateSettleIndexType, &State.SettleIndex, + ampStateSettleIndexType, + &State.SettleIndex, ), tlv.MakePrimitiveRecord( - ampStateSettleDateType, &settleDateBytes, + ampStateSettleDateType, + &settleDateBytes, ), tlv.MakeDynamicRecord( ampStateCircuitKeysType, &State.InvoiceKeys, func() uint64 { - // The record takes 8 bytes to encode the - // set of circuits, 8 bytes for the scid - // for the key, and 8 bytes for the HTLC + // The record takes 8 bytes to + // encode the set of circuits, + // 8 bytes for the scid for the + // key, and 8 bytes for the HTLC // index. - numKeys := uint64(len(ampState.InvoiceKeys)) - return tlv.VarIntSize(numKeys) + (numKeys * 16) + keys := ampState.InvoiceKeys + numKeys := uint64(len(keys)) + size := tlv.VarIntSize(numKeys) + dataSize := (numKeys * 16) //nolint:gomnd,lll + + return size + dataSize }, encodeCircuitKeys, decodeCircuitKeys, ), @@ -2262,7 +1464,8 @@ func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error { return err } - if err := tlvStream.Encode(&StateTlvBytes); err != nil { + err = tlvStream.Encode(&StateTlvBytes) + if err != nil { return err } @@ -2273,7 +1476,8 @@ func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error { return err } - if _, err := w.Write(ampStateTlvBytes.Bytes()); err != nil { + _, err = w.Write(ampStateTlvBytes.Bytes()) + if err != nil { return err } } @@ -2285,8 +1489,10 @@ func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error { } // ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record. -func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { - if v, ok := val.(*AMPInvoiceState); ok { +func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*invpkg.AMPInvoiceState); ok { // First, we'll decode the varint that encodes how many set IDs // are encoded within the greater map. numRecords, err := tlv.ReadVarInt(r, buf) @@ -2297,8 +1503,8 @@ func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error // Now that we know how many records we'll need to read, we can // iterate and read them all out in series. for i := uint64(0); i < numRecords; i++ { - // Read out the varint that encodes the size of this inner - // TLV record + // Read out the varint that encodes the size of this + // inner TLV record. stateRecordSize, err := tlv.ReadVarInt(r, buf) if err != nil { return err @@ -2333,7 +1539,8 @@ func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error ampStateSettleIndexType, &settleIndex, ), tlv.MakePrimitiveRecord( - ampStateSettleDateType, &settleDateBytes, + ampStateSettleDateType, + &settleDateBytes, ), tlv.MakeDynamicRecord( ampStateCircuitKeysType, @@ -2348,7 +1555,8 @@ func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error return err } - if err := tlvStream.Decode(&innerTlvReader); err != nil { + err = tlvStream.Decode(&innerTlvReader) + if err != nil { return err } @@ -2358,8 +1566,8 @@ func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error return err } - (*v)[setID] = InvoiceStateAMP{ - State: HtlcState(htlcState), + (*v)[setID] = invpkg.InvoiceStateAMP{ + State: invpkg.HtlcState(htlcState), SettleIndex: settleIndex, SettleDate: settleDate, InvoiceKeys: invoiceKeys, @@ -2377,9 +1585,10 @@ func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error // deserializeHtlcs reads a list of invoice htlcs from a reader and returns it // as a map. -func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*InvoiceHTLC, error) { - htlcs := make(map[models.CircuitKey]*InvoiceHTLC) +func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC, + error) { + htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC) for { // Read the length of the tlv stream for this htlc. var streamLen int64 @@ -2397,7 +1606,7 @@ func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*InvoiceHTLC, error) { // Decode the contents into the htlc fields. var ( - htlc InvoiceHTLC + htlc invpkg.InvoiceHTLC key models.CircuitKey chanID uint64 state uint8 @@ -2454,11 +1663,11 @@ func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*InvoiceHTLC, error) { key.ChanID = lnwire.NewShortChanIDFromInt(chanID) htlc.AcceptTime = getNanoTime(acceptTime) htlc.ResolveTime = getNanoTime(resolveTime) - htlc.State = HtlcState(state) + htlc.State = invpkg.HtlcState(state) htlc.Amt = lnwire.MilliSatoshi(amt) htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt) if amp != nil && hash != nil { - htlc.AMP = &InvoiceHtlcAMPData{ + htlc.AMP = &invpkg.InvoiceHtlcAMPData{ Record: *amp, Hash: *hash, Preimage: preimage, @@ -2475,56 +1684,6 @@ func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*InvoiceHTLC, error) { return htlcs, nil } -// copySlice allocates a new slice and copies the source into it. -func copySlice(src []byte) []byte { - dest := make([]byte, len(src)) - copy(dest, src) - return dest -} - -// copyInvoice makes a deep copy of the supplied invoice. -func copyInvoice(src *Invoice) (*Invoice, error) { - dest := Invoice{ - Memo: copySlice(src.Memo), - PaymentRequest: copySlice(src.PaymentRequest), - CreationDate: src.CreationDate, - SettleDate: src.SettleDate, - Terms: src.Terms, - AddIndex: src.AddIndex, - SettleIndex: src.SettleIndex, - State: src.State, - AmtPaid: src.AmtPaid, - Htlcs: make( - map[models.CircuitKey]*InvoiceHTLC, len(src.Htlcs), - ), - AMPState: make(map[SetID]InvoiceStateAMP), - HodlInvoice: src.HodlInvoice, - } - - dest.Terms.Features = src.Terms.Features.Clone() - - if src.Terms.PaymentPreimage != nil { - preimage := *src.Terms.PaymentPreimage - dest.Terms.PaymentPreimage = &preimage - } - - for k, v := range src.Htlcs { - dest.Htlcs[k] = v.Copy() - } - - // Lastly, copy the amp invoice state. - for k, v := range src.AMPState { - ampInvState, err := v.copy() - if err != nil { - return nil, err - } - - dest.AMPState[k] = ampInvState - } - - return &dest, nil -} - // invoiceSetIDKeyLen is the length of the key that's used to store the // individual HTLCs prefixed by their ID along side the main invoice within the // invoiceBytes. We use 4 bytes for the invoice number, and 32 bytes for the @@ -2550,11 +1709,11 @@ func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte { // potentially massive HTLC set, and also allows us to quickly find the HLTCs // associated with a particular HTLC set. func updateAMPInvoices(invoiceBucket kvdb.RwBucket, invoiceNum []byte, - htlcsToUpdate map[SetID]map[models.CircuitKey]*InvoiceHTLC) error { + htlcsToUpdate map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC) error { //nolint:lll for setID, htlcSet := range htlcsToUpdate { - // First write out the set of HTLCs including all the relevant TLV - // values. + // First write out the set of HTLCs including all the relevant + // TLV values. var b bytes.Buffer if err := serializeHtlcs(&b, htlcSet); err != nil { return err @@ -2576,16 +1735,17 @@ func updateAMPInvoices(invoiceBucket kvdb.RwBucket, invoiceNum []byte, // updateHtlcsAmp takes an invoice, and a new HTLC to be added (along with its // set ID), and update sthe internal AMP state of an invoice, and also tallies // the set of HTLCs to be updated on disk. -func updateHtlcsAmp(invoice *Invoice, - updateMap map[SetID]map[models.CircuitKey]*InvoiceHTLC, - htlc *InvoiceHTLC, setID SetID, circuitKey models.CircuitKey) { +func updateHtlcsAmp(invoice *invpkg.Invoice, + updateMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC, + htlc *invpkg.InvoiceHTLC, setID invpkg.SetID, + circuitKey models.CircuitKey) { ampState, ok := invoice.AMPState[setID] if !ok { // If an entry for this set ID doesn't already exist, then // we'll need to create it. - ampState = InvoiceStateAMP{ - State: HtlcStateAccepted, + ampState = invpkg.InvoiceStateAMP{ + State: invpkg.HtlcStateAccepted, InvoiceKeys: make(map[models.CircuitKey]struct{}), } } @@ -2604,7 +1764,7 @@ func updateHtlcsAmp(invoice *Invoice, // also pull in the existing HTLCs are part of this set, so we // can write them all to disk together (same value) updateMap[setID] = invoice.HTLCSet( - (*[32]byte)(&setID), HtlcStateAccepted, + (*[32]byte)(&setID), invpkg.HtlcStateAccepted, ) } updateMap[setID][circuitKey] = htlc @@ -2614,15 +1774,15 @@ func updateHtlcsAmp(invoice *Invoice, // HTLC set. We'll need to update the meta data in the main invoice, and also // apply the new update to the update MAP, since all the HTLCs for a given HTLC // set need to be written in-line with each other. -func cancelHtlcsAmp(invoice *Invoice, - updateMap map[SetID]map[models.CircuitKey]*InvoiceHTLC, - htlc *InvoiceHTLC, circuitKey models.CircuitKey) { +func cancelHtlcsAmp(invoice *invpkg.Invoice, + updateMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC, + htlc *invpkg.InvoiceHTLC, circuitKey models.CircuitKey) { setID := htlc.AMP.Record.SetID() // First, we'll update the state of the entire HTLC set to cancelled. ampState := invoice.AMPState[setID] - ampState.State = HtlcStateCanceled + ampState.State = invpkg.HtlcStateCanceled ampState.InvoiceKeys[circuitKey] = struct{}{} ampState.AmtPaid -= htlc.Amt @@ -2635,9 +1795,13 @@ func cancelHtlcsAmp(invoice *Invoice, // Only HTLCs in the accepted state, can be cancelled, but we // also want to merge that with HTLCs that may be canceled as // well since it can be cancelled one by one. - updateMap[setID] = invoice.HTLCSet(&setID, HtlcStateAccepted) + updateMap[setID] = invoice.HTLCSet( + &setID, invpkg.HtlcStateAccepted, + ) - cancelledHtlcs := invoice.HTLCSet(&setID, HtlcStateCanceled) + cancelledHtlcs := invoice.HTLCSet( + &setID, invpkg.HtlcStateCanceled, + ) for htlcKey, htlc := range cancelledHtlcs { updateMap[setID][htlcKey] = htlc } @@ -2657,10 +1821,10 @@ func cancelHtlcsAmp(invoice *Invoice, // settleHtlcsAmp processes a new settle operation on an HTLC set for an AMP // invoice. We'll update some meta data in the main invoice, and also signal // that this HTLC set needs to be re-written back to disk. -func settleHtlcsAmp(invoice *Invoice, - settledSetIDs map[SetID]struct{}, - updateMap map[SetID]map[models.CircuitKey]*InvoiceHTLC, - htlc *InvoiceHTLC, circuitKey models.CircuitKey) { +func settleHtlcsAmp(invoice *invpkg.Invoice, + settledSetIDs map[invpkg.SetID]struct{}, + updateMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC, + htlc *invpkg.InvoiceHTLC, circuitKey models.CircuitKey) { // First, add the set ID to the set that was settled in this invoice // update. We'll use this later to update the settle index. @@ -2670,7 +1834,7 @@ func settleHtlcsAmp(invoice *Invoice, // Next update the main AMP meta-data to indicate that this HTLC set // has been fully settled. ampState := invoice.AMPState[setID] - ampState.State = HtlcStateSettled + ampState.State = invpkg.HtlcStateSettled ampState.InvoiceKeys[circuitKey] = struct{}{} @@ -2678,22 +1842,23 @@ func settleHtlcsAmp(invoice *Invoice, // Finally, we'll add this to the set of HTLCs that need to be updated. if _, ok := updateMap[setID]; !ok { - updateMap[setID] = make(map[models.CircuitKey]*InvoiceHTLC) + mapEntry := make(map[models.CircuitKey]*invpkg.InvoiceHTLC) + updateMap[setID] = mapEntry } updateMap[setID][circuitKey] = htlc } // updateInvoice fetches the invoice, obtains the update descriptor from the // callback and applies the updates in a single db transaction. -func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, +func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *invpkg.SetID, invoices, //nolint:lll,funlen settleIndex, setIDIndex kvdb.RwBucket, invoiceNum []byte, - callback InvoiceUpdateCallback) (*Invoice, error) { + callback invpkg.InvoiceUpdateCallback) (*invpkg.Invoice, error) { // If the set ID is non-nil, then we'll use that to filter out the // HTLCs for AMP invoice so we don't need to read them all out to // satisfy the invoice callback below. If it's nil, then we pass in the // zero set ID which means no HTLCs will be read out. - var invSetID SetID + var invSetID invpkg.SetID if refSetID != nil { invSetID = *refSetID } @@ -2704,7 +1869,7 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, // Create deep copy to prevent any accidental modification in the // callback. - invoiceCopy, err := copyInvoice(&invoice) + invoiceCopy, err := invpkg.CopyInvoice(&invoice) if err != nil { return nil, err } @@ -2745,7 +1910,7 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, ) // Process add actions from update descriptor. - htlcsAmpUpdate := make(map[SetID]map[models.CircuitKey]*InvoiceHTLC) + htlcsAmpUpdate := make(map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC) //nolint:lll for key, htlcUpdate := range update.AddHtlcs { if _, exists := invoice.Htlcs[key]; exists { return nil, fmt.Errorf("duplicate add of htlc %v", key) @@ -2770,17 +1935,18 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, return nil, err } } else if !bytes.Equal(setIDInvNum, invoiceNum) { - return nil, ErrDuplicateSetID{setID: setID} + err = invpkg.ErrDuplicateSetID{SetID: setID} + return nil, err } } - htlc := &InvoiceHTLC{ + htlc := &invpkg.InvoiceHTLC{ Amt: htlcUpdate.Amt, MppTotalAmt: htlcUpdate.MppTotalAmt, Expiry: htlcUpdate.Expiry, AcceptHeight: uint32(htlcUpdate.AcceptHeight), AcceptTime: now, - State: HtlcStateAccepted, + State: invpkg.HtlcStateAccepted, CustomRecords: htlcUpdate.CustomRecords, AMP: htlcUpdate.AMP.Copy(), } @@ -2857,14 +2023,15 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, // update the state of each _htlc set_ instead. However, we'll // allow the invoice to transition to the cancelled state // regardless. - if !invoiceIsAMP || *newState == ContractCanceled { + if !invoiceIsAMP || *newState == invpkg.ContractCanceled { invoice.State = *newState } // If this is a non-AMP invoice, then the state can eventually // go to ContractSettled, so we pass in nil value as part of // setSettleMetaFields. - if !invoiceIsAMP && update.State.NewState == ContractSettled { + isSettled := update.State.NewState == invpkg.ContractSettled + if !invoiceIsAMP && isSettled { err := setSettleMetaFields( settleIndex, invoiceNum, &invoice, now, nil, ) @@ -2885,7 +2052,7 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, // finalize the process by updating the state transitions for // individual HTLCs var ( - settledSetIDs = make(map[SetID]struct{}) + settledSetIDs = make(map[invpkg.SetID]struct{}) amtPaid lnwire.MilliSatoshi ) for key, htlc := range invoice.Htlcs { @@ -2902,7 +2069,7 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, // preimage. Ignore the case where the preimage is // identical. case ok && *htlc.AMP.Preimage != preimage: - return nil, ErrHTLCPreimageAlreadyExists + return nil, invpkg.ErrHTLCPreimageAlreadyExists } } @@ -2917,7 +2084,7 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, // state. htlcContextState := invoice.State if settleEligibleAMP { - htlcContextState = ContractSettled + htlcContextState = invpkg.ContractSettled } htlcSettled, err := updateHtlc( now, htlc, htlcContextState, setID, @@ -2931,17 +2098,22 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, // meta data state. if htlcSettled && invoiceIsAMP { settleHtlcsAmp( - &invoice, settledSetIDs, htlcsAmpUpdate, htlc, key, + &invoice, settledSetIDs, htlcsAmpUpdate, htlc, + key, ) } - invoiceStateReady := (htlc.State == HtlcStateAccepted || - htlc.State == HtlcStateSettled) + accepted := htlc.State == invpkg.HtlcStateAccepted + settled := htlc.State == invpkg.HtlcStateSettled + invoiceStateReady := accepted || settled + if !invoiceIsAMP { // Update the running amount paid to this invoice. We // don't include accepted htlcs when the invoice is // still open. - if invoice.State != ContractOpen && invoiceStateReady { + if invoice.State != invpkg.ContractOpen && + invoiceStateReady { + amtPaid += htlc.Amt } } else { @@ -2956,7 +2128,9 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, // Update the running amount paid to this invoice. AMP // invoices never go to the settled state, so if it's // open, then we tally the HTLC. - if invoice.State == ContractOpen && invoiceStateReady { + if invoice.State == invpkg.ContractOpen && + invoiceStateReady { + amtPaid += htlc.Amt } } @@ -3009,12 +2183,12 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices, // updateInvoiceState validates and processes an invoice state update. The new // state to transition to is returned, so the caller is able to select exactly // how the invoice state is updated. -func updateInvoiceState(invoice *Invoice, hash *lntypes.Hash, - update InvoiceStateUpdateDesc) (*ContractState, error) { +func updateInvoiceState(invoice *invpkg.Invoice, hash *lntypes.Hash, + update invpkg.InvoiceStateUpdateDesc) (*invpkg.ContractState, error) { // Returning to open is never allowed from any state. - if update.NewState == ContractOpen { - return nil, ErrInvoiceCannotOpen + if update.NewState == invpkg.ContractOpen { + return nil, invpkg.ErrInvoiceCannotOpen } switch invoice.State { @@ -3022,9 +2196,9 @@ func updateInvoiceState(invoice *Invoice, hash *lntypes.Hash, // canceled. Forbid transitioning back into this state. Otherwise this // state is identical to ContractOpen, so we fallthrough to apply the // same checks that we apply to open invoices. - case ContractAccepted: - if update.NewState == ContractAccepted { - return nil, ErrInvoiceCannotAccept + case invpkg.ContractAccepted: + if update.NewState == invpkg.ContractAccepted { + return nil, invpkg.ErrInvoiceCannotAccept } fallthrough @@ -3032,15 +2206,16 @@ func updateInvoiceState(invoice *Invoice, hash *lntypes.Hash, // If a contract is open, permit a state transition to accepted, settled // or canceled. The only restriction is on transitioning to settled // where we ensure the preimage is valid. - case ContractOpen: - if update.NewState == ContractCanceled { + case invpkg.ContractOpen: + if update.NewState == invpkg.ContractCanceled { return &update.NewState, nil } // Sanity check that the user isn't trying to settle or accept a // non-existent HTLC set. - if len(invoice.HTLCSet(update.SetID, HtlcStateAccepted)) == 0 { - return nil, ErrEmptyHTLCSet + set := invoice.HTLCSet(update.SetID, invpkg.HtlcStateAccepted) + if len(set) == 0 { + return nil, invpkg.ErrEmptyHTLCSet } // For AMP invoices, there are no invoice-level preimage checks. @@ -3058,23 +2233,23 @@ func updateInvoiceState(invoice *Invoice, hash *lntypes.Hash, // If an invoice-level preimage was supplied, but the InvoiceRef // doesn't specify a hash (e.g. AMP invoices) we fail. case update.Preimage != nil && hash == nil: - return nil, ErrUnexpectedInvoicePreimage + return nil, invpkg.ErrUnexpectedInvoicePreimage // Validate the supplied preimage for non-AMP invoices. case update.Preimage != nil: if update.Preimage.Hash() != *hash { - return nil, ErrInvoicePreimageMismatch + return nil, invpkg.ErrInvoicePreimageMismatch } invoice.Terms.PaymentPreimage = update.Preimage // Permit non-AMP invoices to be accepted without knowing the // preimage. When trying to settle we'll have to pass through // the above check in order to not hit the one below. - case update.NewState == ContractAccepted: + case update.NewState == invpkg.ContractAccepted: // Fail if we still don't have a preimage when transitioning to // settle the non-AMP invoice. - case update.NewState == ContractSettled && + case update.NewState == invpkg.ContractSettled && invoice.Terms.PaymentPreimage == nil: return nil, errors.New("unknown preimage") @@ -3083,35 +2258,36 @@ func updateInvoiceState(invoice *Invoice, hash *lntypes.Hash, return &update.NewState, nil // Once settled, we are in a terminal state. - case ContractSettled: - return nil, ErrInvoiceAlreadySettled + case invpkg.ContractSettled: + return nil, invpkg.ErrInvoiceAlreadySettled // Once canceled, we are in a terminal state. - case ContractCanceled: - return nil, ErrInvoiceAlreadyCanceled + case invpkg.ContractCanceled: + return nil, invpkg.ErrInvoiceAlreadyCanceled default: return nil, errors.New("unknown state transition") } } -// cancelSingleHtlc validates cancellation of a single htlc and update its state. -func cancelSingleHtlc(resolveTime time.Time, htlc *InvoiceHTLC, - invState ContractState) error { +// cancelSingleHtlc validates cancellation of a single htlc and update its +// state. +func cancelSingleHtlc(resolveTime time.Time, htlc *invpkg.InvoiceHTLC, + invState invpkg.ContractState) error { // It is only possible to cancel individual htlcs on an open invoice. - if invState != ContractOpen { + if invState != invpkg.ContractOpen { return fmt.Errorf("htlc canceled on invoice in "+ "state %v", invState) } // It is only possible if the htlc is still pending. - if htlc.State != HtlcStateAccepted { + if htlc.State != invpkg.HtlcStateAccepted { return fmt.Errorf("htlc canceled in state %v", htlc.State) } - htlc.State = HtlcStateCanceled + htlc.State = invpkg.HtlcStateCanceled htlc.ResolveTime = resolveTime return nil @@ -3119,11 +2295,11 @@ func cancelSingleHtlc(resolveTime time.Time, htlc *InvoiceHTLC, // updateHtlc aligns the state of an htlc with the given invoice state. A // boolean is returned if the HTLC was settled. -func updateHtlc(resolveTime time.Time, htlc *InvoiceHTLC, - invState ContractState, setID *[32]byte) (bool, error) { +func updateHtlc(resolveTime time.Time, htlc *invpkg.InvoiceHTLC, + invState invpkg.ContractState, setID *[32]byte) (bool, error) { trySettle := func(persist bool) (bool, error) { - if htlc.State != HtlcStateAccepted { + if htlc.State != invpkg.HtlcStateAccepted { return false, nil } @@ -3131,7 +2307,7 @@ func updateHtlc(resolveTime time.Time, htlc *InvoiceHTLC, // there're other HTLCs with distinct setIDs, then we'll leave // them, as they may eventually be settled as we permit // multiple settles to a single pay_addr for AMP. - var htlcState HtlcState + var htlcState invpkg.HtlcState if htlc.IsInHTLCSet(setID) { // Non-AMP HTLCs can be settled immediately since we // already know the preimage is valid due to checks at @@ -3149,29 +2325,29 @@ func updateHtlc(resolveTime time.Time, htlc *InvoiceHTLC, // // Fail if an accepted AMP HTLC has no preimage. case htlc.AMP.Preimage == nil: - return false, ErrHTLCPreimageMissing + return false, invpkg.ErrHTLCPreimageMissing // Fail if the accepted AMP HTLC has an invalid // preimage. case !htlc.AMP.Preimage.Matches(htlc.AMP.Hash): - return false, ErrHTLCPreimageMismatch + return false, invpkg.ErrHTLCPreimageMismatch } - htlcState = HtlcStateSettled + htlcState = invpkg.HtlcStateSettled } // Only persist the changes if the invoice is moving to the // settled state, and we're actually updating the state to // settled. - if persist && htlcState == HtlcStateSettled { + if persist && htlcState == invpkg.HtlcStateSettled { htlc.State = htlcState htlc.ResolveTime = resolveTime } - return persist && htlcState == HtlcStateSettled, nil + return persist && htlcState == invpkg.HtlcStateSettled, nil } - if invState == ContractSettled { + if invState == invpkg.ContractSettled { // Check that we can settle the HTLCs. For legacy and MPP HTLCs // this will be a NOP, but for AMP HTLCs this asserts that we // have a valid hash/preimage pair. Passing true permits the @@ -3181,20 +2357,20 @@ func updateHtlc(resolveTime time.Time, htlc *InvoiceHTLC, // We should never find a settled HTLC on an invoice that isn't in // ContractSettled. - if htlc.State == HtlcStateSettled { - return false, ErrHTLCAlreadySettled + if htlc.State == invpkg.HtlcStateSettled { + return false, invpkg.ErrHTLCAlreadySettled } switch invState { - case ContractCanceled: - if htlc.State == HtlcStateAccepted { - htlc.State = HtlcStateCanceled + case invpkg.ContractCanceled: + if htlc.State == invpkg.HtlcStateAccepted { + htlc.State = invpkg.HtlcStateCanceled htlc.ResolveTime = resolveTime } return false, nil // TODO(roasbeef): never fully passed thru now? - case ContractAccepted: + case invpkg.ContractAccepted: // Check that we can settle the HTLCs. For legacy and MPP HTLCs // this will be a NOP, but for AMP HTLCs this asserts that we // have a valid hash/preimage pair. Passing false prevents the @@ -3202,7 +2378,7 @@ func updateHtlc(resolveTime time.Time, htlc *InvoiceHTLC, // in HtlcStateAccepted. return trySettle(false) - case ContractOpen: + case invpkg.ContractOpen: return false, nil default: @@ -3215,7 +2391,7 @@ func updateHtlc(resolveTime time.Time, htlc *InvoiceHTLC, // the invoice number as well, in order to allow us to detect repeated payments // to the same AMP invoices "across time". func setSettleMetaFields(settleIndex kvdb.RwBucket, invoiceNum []byte, - invoice *Invoice, now time.Time, setID *SetID) error { + invoice *invpkg.Invoice, now time.Time, setID *invpkg.SetID) error { // Now that we know the invoice hasn't already been settled, we'll // update the settle index so we can place this settle event in the @@ -3238,7 +2414,8 @@ func setSettleMetaFields(settleIndex kvdb.RwBucket, invoiceNum []byte, var seqNoBytes [8]byte byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo) - if err := settleIndex.Put(seqNoBytes[:], indexKey[:valueLen]); err != nil { + err = settleIndex.Put(seqNoBytes[:], indexKey[:valueLen]) + if err != nil { return err } @@ -3270,10 +2447,13 @@ func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error { // cursor simply to collect the set of keys we need to delete, _then_ // delete them in another pass. var keysToDel [][]byte - err := forEachAMPInvoice(invoiceBucket, invoiceNum, func(cursorKey, v []byte) error { - keysToDel = append(keysToDel, cursorKey) - return nil - }) + err := forEachAMPInvoice( + invoiceBucket, invoiceNum, + func(cursorKey, v []byte) error { + keysToDel = append(keysToDel, cursorKey) + return nil + }, + ) if err != nil { return err } @@ -3290,7 +2470,9 @@ func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error { // delAMPSettleIndex removes all the entries in the settle index associated // with a given AMP invoice. -func delAMPSettleIndex(invoiceNum []byte, invoices, settleIndex kvdb.RwBucket) error { +func delAMPSettleIndex(invoiceNum []byte, invoices, + settleIndex kvdb.RwBucket) error { + // First, we need to grab the AMP invoice state to see if there's // anything that we even need to delete. ampState, err := fetchInvoiceStateAMP(invoiceNum, invoices) @@ -3298,7 +2480,8 @@ func delAMPSettleIndex(invoiceNum []byte, invoices, settleIndex kvdb.RwBucket) e return err } - // If there's no AMP state at all (non-AMP invoice), then we can return early. + // If there's no AMP state at all (non-AMP invoice), then we can return + // early. if len(ampState) == 0 { return nil } @@ -3319,46 +2502,28 @@ func delAMPSettleIndex(invoiceNum []byte, invoices, settleIndex kvdb.RwBucket) e return nil } -// InvoiceDeleteRef holds a reference to an invoice to be deleted. -type InvoiceDeleteRef struct { - // PayHash is the payment hash of the target invoice. All invoices are - // currently indexed by payment hash. - PayHash lntypes.Hash - - // PayAddr is the payment addr of the target invoice. Newer invoices - // (0.11 and up) are indexed by payment address in addition to payment - // hash, but pre 0.8 invoices do not have one at all. - PayAddr *[32]byte - - // AddIndex is the add index of the invoice. - AddIndex uint64 - - // SettleIndex is the settle index of the invoice. - SettleIndex uint64 -} - // DeleteInvoice attempts to delete the passed invoices from the database in // one transaction. The passed delete references hold all keys required to // delete the invoices without also needing to deserialze them. -func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error { +func (d *DB) DeleteInvoice(invoicesToDelete []invpkg.InvoiceDeleteRef) error { err := kvdb.Update(d, func(tx kvdb.RwTx) error { invoices := tx.ReadWriteBucket(invoiceBucket) if invoices == nil { - return ErrNoInvoicesCreated + return invpkg.ErrNoInvoicesCreated } invoiceIndex := invoices.NestedReadWriteBucket( invoiceIndexBucket, ) if invoiceIndex == nil { - return ErrNoInvoicesCreated + return invpkg.ErrNoInvoicesCreated } invoiceAddIndex := invoices.NestedReadWriteBucket( addIndexBucket, ) if invoiceAddIndex == nil { - return ErrNoInvoicesCreated + return invpkg.ErrNoInvoicesCreated } // settleIndex can be nil, as the bucket is created lazily @@ -3369,10 +2534,11 @@ func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error { for _, ref := range invoicesToDelete { // Fetch the invoice key for using it to check for - // consistency and also to delete from the invoice index. + // consistency and also to delete from the invoice + // index. invoiceKey := invoiceIndex.Get(ref.PayHash[:]) if invoiceKey == nil { - return ErrInvoiceNotFound + return invpkg.ErrInvoiceNotFound } err := invoiceIndex.Delete(ref.PayHash[:]) @@ -3388,12 +2554,12 @@ func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error { // payment address index. key := payAddrIndex.Get(ref.PayAddr[:]) if bytes.Equal(key, invoiceKey) { - // Delete from the payment address index. - // Note that since the payment address - // index has been introduced with an - // empty migration it may be possible - // that the index doesn't have an entry - // for this invoice. + // Delete from the payment address + // index. Note that since the payment + // address index has been introduced + // with an empty migration it may be + // possible that the index doesn't have + // an entry for this invoice. // ref: https://github.com/lightningnetwork/lnd/pull/4285/commits/cbf71b5452fa1d3036a43309e490787c5f7f08dc#r426368127 if err := payAddrIndex.Delete( ref.PayAddr[:], diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index b69b068f0..57d6675a3 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -17,6 +17,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lntypes" @@ -1748,7 +1749,7 @@ func (c *ChannelArbitrator) isPreimageAvailable(hash lntypes.Hash) (bool, invoice, err := c.cfg.Registry.LookupInvoice(hash) switch err { case nil: - case channeldb.ErrInvoiceNotFound, channeldb.ErrNoInvoicesCreated: + case invoices.ErrInvoiceNotFound, invoices.ErrNoInvoicesCreated: return false, nil default: return false, err diff --git a/contractcourt/interfaces.go b/contractcourt/interfaces.go index d8c603228..49e1f913a 100644 --- a/contractcourt/interfaces.go +++ b/contractcourt/interfaces.go @@ -19,7 +19,7 @@ import ( type Registry interface { // LookupInvoice attempts to look up an invoice according to its 32 // byte payment hash. - LookupInvoice(lntypes.Hash) (channeldb.Invoice, error) + LookupInvoice(lntypes.Hash) (invoices.Invoice, error) // NotifyExitHopHtlc attempts to mark an invoice as settled. If the // invoice is a debug invoice, then this method is a noop as debug diff --git a/contractcourt/mock_registry_test.go b/contractcourt/mock_registry_test.go index e3fb45202..19105a77e 100644 --- a/contractcourt/mock_registry_test.go +++ b/contractcourt/mock_registry_test.go @@ -1,7 +1,6 @@ package contractcourt import ( - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" @@ -40,8 +39,8 @@ func (r *mockRegistry) NotifyExitHopHtlc(payHash lntypes.Hash, func (r *mockRegistry) HodlUnsubscribeAll(subscriber chan<- interface{}) {} -func (r *mockRegistry) LookupInvoice(lntypes.Hash) (channeldb.Invoice, +func (r *mockRegistry) LookupInvoice(lntypes.Hash) (invoices.Invoice, error) { - return channeldb.Invoice{}, channeldb.ErrInvoiceNotFound + return invoices.Invoice{}, invoices.ErrInvoiceNotFound } diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 85c219d77..d24f48193 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -18,7 +18,7 @@ import ( type InvoiceDatabase interface { // LookupInvoice attempts to look up an invoice according to its 32 // byte payment hash. - LookupInvoice(lntypes.Hash) (channeldb.Invoice, error) + LookupInvoice(lntypes.Hash) (invoices.Invoice, error) // NotifyExitHopHtlc attempts to mark an invoice as settled. If the // invoice is a debug invoice, then this method is a noop as debug diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index bf3916651..d0aac8551 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -27,6 +27,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" + invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lntest/wait" @@ -479,7 +480,7 @@ func TestChannelLinkSingleHopPayment(t *testing.T) { // links was changed. invoice, err := receiver.registry.LookupInvoice(rhash) require.NoError(t, err, "unable to get invoice") - if invoice.State != channeldb.ContractSettled { + if invoice.State != invpkg.ContractSettled { t.Fatal("alice invoice wasn't settled") } @@ -597,7 +598,7 @@ func testChannelLinkMultiHopPayment(t *testing.T, // links were changed. invoice, err := receiver.registry.LookupInvoice(rhash) require.NoError(t, err, "unable to get invoice") - if invoice.State != channeldb.ContractSettled { + if invoice.State != invpkg.ContractSettled { t.Fatal("carol invoice haven't been settled") } @@ -1080,7 +1081,7 @@ func TestUpdateForwardingPolicy(t *testing.T) { // succeeded. invoice, err := n.carolServer.registry.LookupInvoice(payResp) require.NoError(t, err, "unable to get invoice") - if invoice.State != channeldb.ContractSettled { + if invoice.State != invpkg.ContractSettled { t.Fatal("carol invoice haven't been settled") } @@ -1234,7 +1235,7 @@ func TestChannelLinkMultiHopInsufficientPayment(t *testing.T) { // links hasn't been changed. invoice, err := receiver.registry.LookupInvoice(rhash) require.NoError(t, err, "unable to get invoice") - if invoice.State == channeldb.ContractSettled { + if invoice.State == invpkg.ContractSettled { t.Fatal("carol invoice have been settled") } @@ -1412,7 +1413,7 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) { // links hasn't been changed. invoice, err := receiver.registry.LookupInvoice(rhash) require.NoError(t, err, "unable to get invoice") - if invoice.State == channeldb.ContractSettled { + if invoice.State == invpkg.ContractSettled { t.Fatal("carol invoice have been settled") } @@ -1520,7 +1521,7 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) { // links hasn't been changed. invoice, err := receiver.registry.LookupInvoice(rhash) require.NoError(t, err, "unable to get invoice") - if invoice.State == channeldb.ContractSettled { + if invoice.State == invpkg.ContractSettled { t.Fatal("carol invoice have been settled") } @@ -3498,7 +3499,7 @@ func TestChannelRetransmission(t *testing.T) { // TODO(andrew.shvv) Will be removed if we move the notification center // to the channel link itself. - var invoice channeldb.Invoice + var invoice invpkg.Invoice for i := 0; i < 20; i++ { select { case <-time.After(time.Millisecond * 200): @@ -3513,7 +3514,7 @@ func TestChannelRetransmission(t *testing.T) { err = errors.Errorf("unable to get invoice: %v", err) continue } - if invoice.State != channeldb.ContractSettled { + if invoice.State != invpkg.ContractSettled { err = errors.Errorf("alice invoice haven't been settled") continue } @@ -4059,7 +4060,7 @@ func TestChannelLinkAcceptOverpay(t *testing.T) { // accepted the payment and marked it as settled. invoice, err := receiver.registry.LookupInvoice(rhash) require.NoError(t, err, "unable to get invoice") - if invoice.State != channeldb.ContractSettled { + if invoice.State != invpkg.ContractSettled { t.Fatal("carol invoice haven't been settled") } @@ -4383,7 +4384,7 @@ func generateHtlc(t *testing.T, coreLink *channelLink, // generateHtlcAndInvoice generates an invoice and a single hop htlc to send to // the receiver. func generateHtlcAndInvoice(t *testing.T, - id uint64) (*lnwire.UpdateAddHTLC, *channeldb.Invoice) { + id uint64) (*lnwire.UpdateAddHTLC, *invpkg.Invoice) { t.Helper() @@ -4778,7 +4779,7 @@ func testChannelLinkBatchPreimageWrite(t *testing.T, disconnect bool) { // We will send 10 HTLCs in total, from Bob to Alice. numHtlcs := 10 var htlcs []*lnwire.UpdateAddHTLC - var invoices []*channeldb.Invoice + var invoices []*invpkg.Invoice for i := 0; i < numHtlcs; i++ { htlc, invoice := generateHtlcAndInvoice(t, uint64(i)) htlcs = append(htlcs, htlc) diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index f8eedd857..866e7d36f 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -980,7 +980,7 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { } func (i *mockInvoiceRegistry) LookupInvoice(rHash lntypes.Hash) ( - channeldb.Invoice, error) { + invoices.Invoice, error) { return i.registry.LookupInvoice(rHash) } @@ -1014,7 +1014,7 @@ func (i *mockInvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { return i.registry.CancelInvoice(payHash) } -func (i *mockInvoiceRegistry) AddInvoice(invoice channeldb.Invoice, +func (i *mockInvoiceRegistry) AddInvoice(invoice invoices.Invoice, paymentHash lntypes.Hash) error { _, err := i.registry.AddInvoice(&invoice, paymentHash) diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index b9ba325ab..6e8defe3f 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -26,6 +26,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnpeer" @@ -510,7 +511,7 @@ func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) { func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, blob [lnwire.OnionPacketSize]byte, preimage *lntypes.Preimage, rhash, payAddr [32]byte) ( - *channeldb.Invoice, *lnwire.UpdateAddHTLC, uint64, error) { + *invoices.Invoice, *lnwire.UpdateAddHTLC, uint64, error) { // Create the db invoice. Normally the payment requests needs to be set, // because it is decoded in InvoiceRegistry to obtain the cltv expiry. @@ -519,9 +520,9 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, // don't need to bother here with creating and signing a payment // request. - invoice := &channeldb.Invoice{ + invoice := &invoices.Invoice{ CreationDate: time.Now(), - Terms: channeldb.ContractTerm{ + Terms: invoices.ContractTerm{ FinalCltvDelta: testInvoiceCltvExpiry, Value: invoiceAmt, PaymentPreimage: preimage, @@ -552,7 +553,7 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, // generatePayment generates the htlc add request by given path blob and // invoice which should be added by destination peer. func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, - blob [lnwire.OnionPacketSize]byte) (*channeldb.Invoice, + blob [lnwire.OnionPacketSize]byte) (*invoices.Invoice, *lnwire.UpdateAddHTLC, uint64, error) { var preimage lntypes.Preimage @@ -753,7 +754,7 @@ func makePayment(sendingPeer, receivingPeer lnpeer.Peer, func preparePayment(sendingPeer, receivingPeer lnpeer.Peer, firstHop lnwire.ShortChannelID, hops []*hop.Payload, invoiceAmt, htlcAmt lnwire.MilliSatoshi, - timelock uint32) (*channeldb.Invoice, func() error, error) { + timelock uint32) (*invoices.Invoice, func() error, error) { sender := sendingPeer.(*mockServer) receiver := receivingPeer.(*mockServer) diff --git a/invoices/errors.go b/invoices/errors.go new file mode 100644 index 000000000..4cc592357 --- /dev/null +++ b/invoices/errors.go @@ -0,0 +1,108 @@ +package invoices + +import ( + "errors" + "fmt" +) + +var ( + // ErrInvoiceAlreadySettled is returned when the invoice is already + // settled. + ErrInvoiceAlreadySettled = errors.New("invoice already settled") + + // ErrInvoiceAlreadyCanceled is returned when the invoice is already + // canceled. + ErrInvoiceAlreadyCanceled = errors.New("invoice already canceled") + + // ErrInvoiceAlreadyAccepted is returned when the invoice is already + // accepted. + ErrInvoiceAlreadyAccepted = errors.New("invoice already accepted") + + // ErrInvoiceStillOpen is returned when the invoice is still open. + ErrInvoiceStillOpen = errors.New("invoice still open") + + // ErrInvoiceCannotOpen is returned when an attempt is made to move an + // invoice to the open state. + ErrInvoiceCannotOpen = errors.New("cannot move invoice to open") + + // ErrInvoiceCannotAccept is returned when an attempt is made to accept + // an invoice while the invoice is not in the open state. + ErrInvoiceCannotAccept = errors.New("cannot accept invoice") + + // ErrInvoicePreimageMismatch is returned when the preimage doesn't + // match the invoice hash. + ErrInvoicePreimageMismatch = errors.New("preimage does not match") + + // ErrHTLCPreimageMissing is returned when trying to accept/settle an + // AMP HTLC but the HTLC-level preimage has not been set. + ErrHTLCPreimageMissing = errors.New("AMP htlc missing preimage") + + // ErrHTLCPreimageMismatch is returned when trying to accept/settle an + // AMP HTLC but the HTLC-level preimage does not satisfying the + // HTLC-level payment hash. + ErrHTLCPreimageMismatch = errors.New("htlc preimage mismatch") + + // ErrHTLCAlreadySettled is returned when trying to settle an invoice + // but HTLC already exists in the settled state. + ErrHTLCAlreadySettled = errors.New("htlc already settled") + + // ErrInvoiceHasHtlcs is returned when attempting to insert an invoice + // that already has HTLCs. + ErrInvoiceHasHtlcs = errors.New("cannot add invoice with htlcs") + + // ErrEmptyHTLCSet is returned when attempting to accept or settle and + // HTLC set that has no HTLCs. + ErrEmptyHTLCSet = errors.New("cannot settle/accept empty HTLC set") + + // ErrUnexpectedInvoicePreimage is returned when an invoice-level + // preimage is provided when trying to settle an invoice that shouldn't + // have one, e.g. an AMP invoice. + ErrUnexpectedInvoicePreimage = errors.New( + "unexpected invoice preimage provided on settle", + ) + + // ErrHTLCPreimageAlreadyExists is returned when trying to set an + // htlc-level preimage but one is already known. + ErrHTLCPreimageAlreadyExists = errors.New( + "htlc-level preimage already exists", + ) + + // ErrInvoiceNotFound is returned when a targeted invoice can't be + // found. + ErrInvoiceNotFound = errors.New("unable to locate invoice") + + // ErrNoInvoicesCreated is returned when we don't have invoices in + // our database to return. + ErrNoInvoicesCreated = errors.New("there are no existing invoices") + + // ErrDuplicateInvoice is returned when an invoice with the target + // payment hash already exists. + ErrDuplicateInvoice = errors.New( + "invoice with payment hash already exists", + ) + + // ErrDuplicatePayAddr is returned when an invoice with the target + // payment addr already exists. + ErrDuplicatePayAddr = errors.New( + "invoice with payemnt addr already exists", + ) + + // ErrInvRefEquivocation is returned when an InvoiceRef targets + // multiple, distinct invoices. + ErrInvRefEquivocation = errors.New("inv ref matches multiple invoices") + + // ErrNoPaymentsCreated is returned when bucket of payments hasn't been + // created. + ErrNoPaymentsCreated = errors.New("there are no existing payments") +) + +// ErrDuplicateSetID is an error returned when attempting to adding an AMP HTLC +// to an invoice, but another invoice is already indexed by the same set id. +type ErrDuplicateSetID struct { + SetID [32]byte +} + +// Error returns a human-readable description of ErrDuplicateSetID. +func (e ErrDuplicateSetID) Error() string { + return fmt.Sprintf("invoice with set_id=%x already exists", e.SetID) +} diff --git a/invoices/interface.go b/invoices/interface.go index 804817dd8..f717afd4e 100644 --- a/invoices/interface.go +++ b/invoices/interface.go @@ -1,9 +1,91 @@ package invoices import ( + "time" + + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/record" ) +// InvScanFunc is a helper type used to specify the type used in the +// ScanInvoices methods (part of the InvoiceDB interface). +type InvScanFunc func(lntypes.Hash, *Invoice) error + +// InvoiceDB is the database that stores the information about invoices. +type InvoiceDB interface { + // AddInvoice inserts the targeted invoice into the database. + // If the invoice has *any* payment hashes which already exists within + // the database, then the insertion will be aborted and rejected due to + // the strict policy banning any duplicate payment hashes. + // + // NOTE: A side effect of this function is that it sets AddIndex on + // newInvoice. + AddInvoice(invoice *Invoice, paymentHash lntypes.Hash) (uint64, error) + + // InvoicesAddedSince can be used by callers to seek into the event + // time series of all the invoices added in the database. The specified + // sinceAddIndex should be the highest add index that the caller knows + // of. This method will return all invoices with an add index greater + // than the specified sinceAddIndex. + // + // NOTE: The index starts from 1, as a result. We enforce that + // specifying a value below the starting index value is a noop. + InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) + + // LookupInvoice attempts to look up an invoice according to its 32 byte + // payment hash. If an invoice which can settle the HTLC identified by + // the passed payment hash isn't found, then an error is returned. + // Otherwise, the full invoice is returned. + // Before setting the incoming HTLC, the values SHOULD be checked to + // ensure the payer meets the agreed upon contractual terms of the + // payment. + LookupInvoice(ref InvoiceRef) (Invoice, error) + + // ScanInvoices scans through all invoices and calls the passed scanFunc + // for each invoice with its respective payment hash. Additionally a + // reset() closure is passed which is used to reset/initialize partial + // results and also to signal if the kvdb.View transaction has been + // retried. + // + // TODO(positiveblue): abstract this functionality so it makes sense for + // other backends like sql. + ScanInvoices(scanFunc InvScanFunc, reset func()) error + + // QueryInvoices allows a caller to query the invoice database for + // invoices within the specified add index range. + QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) + + // UpdateInvoice attempts to update an invoice corresponding to the + // passed payment hash. If an invoice matching the passed payment hash + // doesn't exist within the database, then the action will fail with a + // "not found" error. + // + // The update is performed inside the same database transaction that + // fetches the invoice and is therefore atomic. The fields to update + // are controlled by the supplied callback. + // + // TODO(positiveblue): abstract this functionality so it makes sense for + // other backends like sql. + UpdateInvoice(ref InvoiceRef, setIDHint *SetID, + callback InvoiceUpdateCallback) (*Invoice, error) + + // InvoicesSettledSince can be used by callers to catch up any settled + // invoices they missed within the settled invoice time series. We'll + // return all known settled invoice that have a settle index higher than + // the passed sinceSettleIndex. + // + // NOTE: The index starts from 1, as a result. We enforce that + // specifying a value below the starting index value is a noop. + InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) + + // DeleteInvoice attempts to delete the passed invoices from the + // database in one transaction. The passed delete references hold all + // keys required to delete the invoices without also needing to + // deserialze them. + DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error +} + // Payload abstracts access to any additional fields provided in the final hop's // TLV onion payload. type Payload interface { @@ -23,3 +105,61 @@ type Payload interface { // payment to the payee. Metadata() []byte } + +// InvoiceQuery represents a query to the invoice database. The query allows a +// caller to retrieve all invoices starting from a particular add index and +// limit the number of results returned. +type InvoiceQuery struct { + // IndexOffset is the offset within the add indices to start at. This + // can be used to start the response at a particular invoice. + IndexOffset uint64 + + // NumMaxInvoices is the maximum number of invoices that should be + // starting from the add index. + NumMaxInvoices uint64 + + // PendingOnly, if set, returns unsettled invoices starting from the + // add index. + PendingOnly bool + + // Reversed, if set, indicates that the invoices returned should start + // from the IndexOffset and go backwards. + Reversed bool + + // CreationDateStart, if set, filters out all invoices with a creation + // date greater than or euqal to it. + CreationDateStart time.Time + + // CreationDateEnd, if set, filters out all invoices with a creation + // date less than or euqal to it. + CreationDateEnd time.Time +} + +// InvoiceSlice is the response to a invoice query. It includes the original +// query, the set of invoices that match the query, and an integer which +// represents the offset index of the last item in the set of returned invoices. +// This integer allows callers to resume their query using this offset in the +// event that the query's response exceeds the maximum number of returnable +// invoices. +type InvoiceSlice struct { + InvoiceQuery + + // Invoices is the set of invoices that matched the query above. + Invoices []Invoice + + // FirstIndexOffset is the index of the first element in the set of + // returned Invoices above. Callers can use this to resume their query + // in the event that the slice has too many events to fit into a single + // response. + FirstIndexOffset uint64 + + // LastIndexOffset is the index of the last element in the set of + // returned Invoices above. Callers can use this to resume their query + // in the event that the slice has too many events to fit into a single + // response. + LastIndexOffset uint64 +} + +// CircuitKey is a tuple of channel ID and HTLC ID, used to uniquely identify +// HTLCs in a circuit. +type CircuitKey = models.CircuitKey diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index 5038e5e84..f16a7a18d 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -7,7 +7,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/queue" @@ -178,18 +177,18 @@ func (ew *InvoiceExpiryWatcher) Stop() { // makeInvoiceExpiry checks if the passed invoice may be canceled and calculates // the expiry time and creates a slimmer invoiceExpiry implementation. func makeInvoiceExpiry(paymentHash lntypes.Hash, - invoice *channeldb.Invoice) invoiceExpiry { + invoice *Invoice) invoiceExpiry { switch invoice.State { // If we have an open invoice with no htlcs, we want to expire the // invoice based on timestamp - case channeldb.ContractOpen: + case ContractOpen: return makeTimestampExpiry(paymentHash, invoice) // If an invoice has active htlcs, we want to expire it based on block // height. We only do this for hodl invoices, since regular invoices // should resolve themselves automatically. - case channeldb.ContractAccepted: + case ContractAccepted: if !invoice.HodlInvoice { log.Debugf("Invoice in accepted state not added to "+ "expiry watcher: %v", paymentHash) @@ -201,7 +200,7 @@ func makeInvoiceExpiry(paymentHash lntypes.Hash, for _, htlc := range invoice.Htlcs { // We only care about accepted htlcs, since they will // trigger force-closes. - if htlc.State != channeldb.HtlcStateAccepted { + if htlc.State != HtlcStateAccepted { continue } @@ -222,9 +221,9 @@ func makeInvoiceExpiry(paymentHash lntypes.Hash, // makeTimestampExpiry creates a timestamp-based expiry entry. func makeTimestampExpiry(paymentHash lntypes.Hash, - invoice *channeldb.Invoice) *invoiceExpiryTs { + invoice *Invoice) *invoiceExpiryTs { - if invoice.State != channeldb.ContractOpen { + if invoice.State != ContractOpen { return nil } @@ -349,11 +348,11 @@ func (ew *InvoiceExpiryWatcher) expireInvoice(hash lntypes.Hash, force bool) { switch err { case nil: - case channeldb.ErrInvoiceAlreadyCanceled: + case ErrInvoiceAlreadyCanceled: - case channeldb.ErrInvoiceAlreadySettled: + case ErrInvoiceAlreadySettled: - case channeldb.ErrInvoiceNotFound: + case ErrInvoiceNotFound: // It's possible that the user has manually canceled the invoice // which will then be deleted by the garbage collector resulting // in an ErrInvoiceNotFound error. diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index 9f32ecd48..5242ed004 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -5,8 +5,6 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" @@ -22,29 +20,6 @@ type invoiceExpiryWatcherTest struct { canceledInvoices []lntypes.Hash } -type mockChainNotifier struct { - chainntnfs.ChainNotifier - - blockChan chan *chainntnfs.BlockEpoch -} - -func newMockNotifier() *mockChainNotifier { - return &mockChainNotifier{ - blockChan: make(chan *chainntnfs.BlockEpoch), - } -} - -// RegisterBlockEpochNtfn mocks a block epoch notification, using the mock's -// block channel to deliver blocks to the client. -func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) ( - *chainntnfs.BlockEpochEvent, error) { - - return &chainntnfs.BlockEpochEvent{ - Epochs: m.blockChan, - Cancel: func() {}, - }, nil -} - // newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture // and sets up the test environment. func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time, @@ -213,7 +188,7 @@ func TestExpiredHodlInv(t *testing.T) { expiry := time.Hour test := setupHodlExpiry( - t, creationDate, expiry, 0, channeldb.ContractOpen, nil, + t, creationDate, expiry, 0, ContractOpen, nil, ) test.assertCanceled(t, test.hash) @@ -231,7 +206,7 @@ func TestAcceptedHodlNotExpired(t *testing.T) { expiry := time.Hour test := setupHodlExpiry( - t, creationDate, expiry, 0, channeldb.ContractAccepted, nil, + t, creationDate, expiry, 0, ContractAccepted, nil, ) defer test.watcher.Stop() @@ -255,15 +230,15 @@ func TestAcceptedHodlNotExpired(t *testing.T) { func TestHeightAlreadyExpired(t *testing.T) { t.Parallel() - expiredHtlc := []*channeldb.InvoiceHTLC{ + expiredHtlc := []*InvoiceHTLC{ { - State: channeldb.HtlcStateAccepted, + State: HtlcStateAccepted, Expiry: uint32(testCurrentHeight), }, } test := setupHodlExpiry( - t, testTime, time.Hour, 0, channeldb.ContractAccepted, + t, testTime, time.Hour, 0, ContractAccepted, expiredHtlc, ) defer test.watcher.Stop() @@ -284,7 +259,7 @@ func TestExpiryHeightArrives(t *testing.T) { // Start out with a hodl invoice that is open, and has no htlcs. test := setupHodlExpiry( - t, creationDate, expiry, delta, channeldb.ContractOpen, nil, + t, creationDate, expiry, delta, ContractOpen, nil, ) defer test.watcher.Stop() @@ -293,7 +268,7 @@ func TestExpiryHeightArrives(t *testing.T) { // Add htlcs to our invoice and progress its state to accepted. test.watcher.AddInvoices(expiry1) - test.setState(channeldb.ContractAccepted) + test.setState(ContractAccepted) // Progress time so that our expiry has elapsed. We no longer expect // this invoice to be canceled because it has been accepted. diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index e96bc53e4..a80a85c8d 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -7,8 +7,6 @@ import ( "sync/atomic" "time" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -17,13 +15,15 @@ import ( ) var ( - // ErrInvoiceExpiryTooSoon is returned when an invoice is attempted to be - // accepted or settled with not enough blocks remaining. + // ErrInvoiceExpiryTooSoon is returned when an invoice is attempted to + // be accepted or settled with not enough blocks remaining. ErrInvoiceExpiryTooSoon = errors.New("invoice expiry too soon") - // ErrInvoiceAmountTooLow is returned when an invoice is attempted to be - // accepted or settled with an amount that is too low. - ErrInvoiceAmountTooLow = errors.New("paid amount less than invoice amount") + // ErrInvoiceAmountTooLow is returned when an invoice is attempted to + // be accepted or settled with an amount that is too low. + ErrInvoiceAmountTooLow = errors.New( + "paid amount less than invoice amount", + ) // ErrShuttingDown is returned when an operation failed because the // invoice registry is shutting down. @@ -79,10 +79,10 @@ type RegistryConfig struct { // mpp htlcs for which the complete set didn't arrive in time. type htlcReleaseEvent struct { // invoiceRef identifiers the invoice this htlc belongs to. - invoiceRef channeldb.InvoiceRef + invoiceRef InvoiceRef // key is the circuit key of the htlc to release. - key models.CircuitKey + key CircuitKey // releaseTime is the time at which to release the htlc. releaseTime time.Time @@ -104,7 +104,7 @@ type InvoiceRegistry struct { nextClientID uint32 // must be used atomically - cdb *channeldb.DB + idb InvoiceDB // cfg contains the registry's configuration parameters. cfg *RegistryConfig @@ -130,13 +130,14 @@ type InvoiceRegistry struct { // necessary to avoid deadlocks in the registry when processing invoice // events. hodlSubscriptionsMux sync.RWMutex - // subscriptions is a map from a circuit key to a list of subscribers. - // It is used for efficient notification of links. - hodlSubscriptions map[models.CircuitKey]map[chan<- interface{}]struct{} + + // hodlSubscriptions is a map from a circuit key to a list of + // subscribers. It is used for efficient notification of links. + hodlSubscriptions map[CircuitKey]map[chan<- interface{}]struct{} // reverseSubscriptions tracks circuit keys subscribed to per // subscriber. This is used to unsubscribe from all hashes efficiently. - hodlReverseSubscriptions map[chan<- interface{}]map[models.CircuitKey]struct{} + hodlReverseSubscriptions map[chan<- interface{}]map[CircuitKey]struct{} // htlcAutoReleaseChan contains the new htlcs that need to be // auto-released. @@ -152,19 +153,21 @@ type InvoiceRegistry struct { // wraps the persistent on-disk invoice storage with an additional in-memory // layer. The in-memory layer is in place such that debug invoices can be added // which are volatile yet available system wide within the daemon. -func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, +func NewRegistry(idb InvoiceDB, expiryWatcher *InvoiceExpiryWatcher, cfg *RegistryConfig) *InvoiceRegistry { + notificationClients := make(map[uint32]*InvoiceSubscription) + singleNotificationClients := make(map[uint32]*SingleInvoiceSubscription) return &InvoiceRegistry{ - cdb: cdb, - notificationClients: make(map[uint32]*InvoiceSubscription), - singleNotificationClients: make(map[uint32]*SingleInvoiceSubscription), + idb: idb, + notificationClients: notificationClients, + singleNotificationClients: singleNotificationClients, invoiceEvents: make(chan *invoiceEvent, 100), hodlSubscriptions: make( - map[models.CircuitKey]map[chan<- interface{}]struct{}, + map[CircuitKey]map[chan<- interface{}]struct{}, ), hodlReverseSubscriptions: make( - map[chan<- interface{}]map[models.CircuitKey]struct{}, + map[chan<- interface{}]map[CircuitKey]struct{}, ), cfg: cfg, htlcAutoReleaseChan: make(chan *htlcReleaseEvent), @@ -179,7 +182,7 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, func (i *InvoiceRegistry) scanInvoicesOnStart() error { var ( pending []invoiceExpiry - removable []channeldb.InvoiceDeleteRef + removable []InvoiceDeleteRef ) reset := func() { @@ -189,11 +192,11 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { // using the etcd driver, where all transactions are allowed // to retry for serializability). pending = nil - removable = make([]channeldb.InvoiceDeleteRef, 0) + removable = make([]InvoiceDeleteRef, 0) } scanFunc := func(paymentHash lntypes.Hash, - invoice *channeldb.Invoice) error { + invoice *Invoice) error { if invoice.IsPending() { expiryRef := makeInvoiceExpiry(paymentHash, invoice) @@ -201,19 +204,19 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { pending = append(pending, expiryRef) } } else if i.cfg.GcCanceledInvoicesOnStartup && - invoice.State == channeldb.ContractCanceled { + invoice.State == ContractCanceled { // Consider invoice for removal if it is already // canceled. Invoices that are expired but not yet // canceled, will be queued up for cancellation after // startup and will be deleted afterwards. - ref := channeldb.InvoiceDeleteRef{ + ref := InvoiceDeleteRef{ PayHash: paymentHash, AddIndex: invoice.AddIndex, SettleIndex: invoice.SettleIndex, } - if invoice.Terms.PaymentAddr != channeldb.BlankPayAddr { + if invoice.Terms.PaymentAddr != BlankPayAddr { ref.PayAddr = &invoice.Terms.PaymentAddr } @@ -222,7 +225,7 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { return nil } - err := i.cdb.ScanInvoices(scanFunc, reset) + err := i.idb.ScanInvoices(scanFunc, reset) if err != nil { return err } @@ -234,7 +237,7 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { if len(removable) > 0 { log.Infof("Attempting to delete %v canceled invoices", len(removable)) - if err := i.cdb.DeleteInvoice(removable); err != nil { + if err := i.idb.DeleteInvoice(removable); err != nil { log.Warnf("Deleting canceled invoices failed: %v", err) } else { log.Infof("Deleted %v canceled invoices", @@ -287,7 +290,7 @@ func (i *InvoiceRegistry) Stop() error { // instance where invoices are settled. type invoiceEvent struct { hash lntypes.Hash - invoice *channeldb.Invoice + invoice *Invoice setID *[32]byte } @@ -323,8 +326,8 @@ func (i *InvoiceRegistry) invoiceEventLoop() { // For backwards compatibility, do not notify all // invoice subscribers of cancel and accept events. state := event.invoice.State - if state != channeldb.ContractCanceled && - state != channeldb.ContractAccepted { + if state != ContractCanceled && + state != ContractAccepted { i.dispatchToClients(event) } @@ -401,7 +404,7 @@ func (i *InvoiceRegistry) dispatchToClients(event *invoiceEvent) { switch { // If we've already sent this settle event to // the client, then we can skip this. - case state == channeldb.ContractSettled && + case state == ContractSettled && client.settleIndex >= invoice.SettleIndex: continue @@ -409,14 +412,14 @@ func (i *InvoiceRegistry) dispatchToClients(event *invoiceEvent) { // the client then we can skip this one, but only if this isn't // an AMP invoice. AMP invoices always remain in the settle // state as a base invoice. - case event.setID == nil && state == channeldb.ContractOpen && + case event.setID == nil && state == ContractOpen && client.addIndex >= invoice.AddIndex: continue // These two states should never happen, but we // log them just in case so we can detect this // instance. - case state == channeldb.ContractOpen && + case state == ContractOpen && client.addIndex+1 != invoice.AddIndex: log.Warnf("client=%v for invoice "+ "notifications missed an update, "+ @@ -424,7 +427,7 @@ func (i *InvoiceRegistry) dispatchToClients(event *invoiceEvent) { clientID, client.addIndex, invoice.AddIndex) - case state == channeldb.ContractSettled && + case state == ContractSettled && client.settleIndex+1 != invoice.SettleIndex: log.Warnf("client=%v for invoice "+ "notifications missed an update, "+ @@ -456,10 +459,10 @@ func (i *InvoiceRegistry) dispatchToClients(event *invoiceEvent) { // event is added while we're catching up a new client. invState := event.invoice.State switch { - case invState == channeldb.ContractSettled: + case invState == ContractSettled: client.settleIndex = invoice.SettleIndex - case invState == channeldb.ContractOpen && event.setID == nil: + case invState == ContractOpen && event.setID == nil: client.addIndex = invoice.AddIndex // If this is an AMP invoice, then we'll need to use the set ID @@ -467,7 +470,7 @@ func (i *InvoiceRegistry) dispatchToClients(event *invoiceEvent) { // invoices never go to the open state, but if a setID is // passed, then we know it was just settled and will track the // highest settle index so far. - case invState == channeldb.ContractOpen && event.setID != nil: + case invState == ContractOpen && event.setID != nil: setID := *event.setID client.settleIndex = invoice.AMPState[setID].SettleIndex @@ -480,13 +483,15 @@ func (i *InvoiceRegistry) dispatchToClients(event *invoiceEvent) { // deliverBacklogEvents will attempts to query the invoice database for any // notifications that the client has missed since it reconnected last. -func (i *InvoiceRegistry) deliverBacklogEvents(client *InvoiceSubscription) error { - addEvents, err := i.cdb.InvoicesAddedSince(client.addIndex) +func (i *InvoiceRegistry) deliverBacklogEvents( + client *InvoiceSubscription) error { + + addEvents, err := i.idb.InvoicesAddedSince(client.addIndex) if err != nil { return err } - settleEvents, err := i.cdb.InvoicesSettledSince(client.settleIndex) + settleEvents, err := i.idb.InvoicesSettledSince(client.settleIndex) if err != nil { return err } @@ -533,13 +538,13 @@ func (i *InvoiceRegistry) deliverBacklogEvents(client *InvoiceSubscription) erro func (i *InvoiceRegistry) deliverSingleBacklogEvents( client *SingleInvoiceSubscription) error { - invoice, err := i.cdb.LookupInvoice(client.invoiceRef) + invoice, err := i.idb.LookupInvoice(client.invoiceRef) // It is possible that the invoice does not exist yet, but the client is // already watching it in anticipation. - if err == channeldb.ErrInvoiceNotFound || - err == channeldb.ErrNoInvoicesCreated { - + isNotFound := errors.Is(err, ErrInvoiceNotFound) + isNotCreated := errors.Is(err, ErrNoInvoicesCreated) + if isNotFound || isNotCreated { return nil } if err != nil { @@ -573,15 +578,15 @@ func (i *InvoiceRegistry) deliverSingleBacklogEvents( // addIndex of the newly created invoice which monotonically increases for each // new invoice added. A side effect of this function is that it also sets // AddIndex on the invoice argument. -func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice, +func (i *InvoiceRegistry) AddInvoice(invoice *Invoice, paymentHash lntypes.Hash) (uint64, error) { i.Lock() - ref := channeldb.InvoiceRefByHash(paymentHash) + ref := InvoiceRefByHash(paymentHash) log.Debugf("Invoice%v: added with terms %v", ref, invoice.Terms) - addIndex, err := i.cdb.AddInvoice(invoice, paymentHash) + addIndex, err := i.idb.AddInvoice(invoice, paymentHash) if err != nil { i.Unlock() return 0, err @@ -607,27 +612,23 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice, // then we're able to pull the funds pending within an HTLC. // // TODO(roasbeef): ignore if settled? -func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice, - error) { - +func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (Invoice, error) { // We'll check the database to see if there's an existing matching // invoice. - ref := channeldb.InvoiceRefByHash(rHash) - return i.cdb.LookupInvoice(ref) + ref := InvoiceRefByHash(rHash) + return i.idb.LookupInvoice(ref) } // LookupInvoiceByRef looks up an invoice by the given reference, if found // then we're able to pull the funds pending within an HTLC. -func (i *InvoiceRegistry) LookupInvoiceByRef( - ref channeldb.InvoiceRef) (channeldb.Invoice, error) { - - return i.cdb.LookupInvoice(ref) +func (i *InvoiceRegistry) LookupInvoiceByRef(ref InvoiceRef) (Invoice, error) { + return i.idb.LookupInvoice(ref) } // startHtlcTimer starts a new timer via the invoice registry main loop that // cancels a single htlc on an invoice when the htlc hold duration has passed. -func (i *InvoiceRegistry) startHtlcTimer(invoiceRef channeldb.InvoiceRef, - key models.CircuitKey, acceptTime time.Time) error { +func (i *InvoiceRegistry) startHtlcTimer(invoiceRef InvoiceRef, + key CircuitKey, acceptTime time.Time) error { releaseTime := acceptTime.Add(i.cfg.HtlcHoldDuration) event := &htlcReleaseEvent{ @@ -648,14 +649,12 @@ func (i *InvoiceRegistry) startHtlcTimer(invoiceRef channeldb.InvoiceRef, // cancelSingleHtlc cancels a single accepted htlc on an invoice. It takes // a resolution result which will be used to notify subscribed links and // resolvers of the details of the htlc cancellation. -func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef, - key models.CircuitKey, result FailResolutionResult) error { - - updateInvoice := func(invoice *channeldb.Invoice) ( - *channeldb.InvoiceUpdateDesc, error) { +func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef InvoiceRef, + key CircuitKey, result FailResolutionResult) error { + updateInvoice := func(invoice *Invoice) (*InvoiceUpdateDesc, error) { // Only allow individual htlc cancellation on open invoices. - if invoice.State != channeldb.ContractOpen { + if invoice.State != ContractOpen { log.Debugf("cancelSingleHtlc: invoice %v no longer "+ "open", invoiceRef) @@ -664,8 +663,8 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef, // Lookup the current status of the htlc in the database. var ( - htlcState channeldb.HtlcState - setID *channeldb.SetID + htlcState HtlcState + setID *SetID ) htlc, ok := invoice.Htlcs[key] if !ok { @@ -695,7 +694,7 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef, // Cancellation is only possible if the htlc wasn't already // resolved. - if htlcState != channeldb.HtlcStateAccepted { + if htlcState != HtlcStateAccepted { log.Debugf("cancelSingleHtlc: htlc %v on invoice %v "+ "is already resolved", key, invoiceRef) @@ -707,11 +706,11 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef, // Return an update descriptor that cancels htlc and keeps // invoice open. - canceledHtlcs := map[models.CircuitKey]struct{}{ + canceledHtlcs := map[CircuitKey]struct{}{ key: {}, } - return &channeldb.InvoiceUpdateDesc{ + return &InvoiceUpdateDesc{ CancelHtlcs: canceledHtlcs, SetID: setID, }, nil @@ -720,11 +719,11 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef, // Try to mark the specified htlc as canceled in the invoice database. // Intercept the update descriptor to set the local updated variable. If // no invoice update is performed, we can return early. - setID := (*channeldb.SetID)(invoiceRef.SetID()) + setID := (*SetID)(invoiceRef.SetID()) var updated bool - invoice, err := i.cdb.UpdateInvoice(invoiceRef, setID, - func(invoice *channeldb.Invoice) ( - *channeldb.InvoiceUpdateDesc, error) { + invoice, err := i.idb.UpdateInvoice(invoiceRef, setID, + func(invoice *Invoice) ( + *InvoiceUpdateDesc, error) { updateDesc, err := updateInvoice(invoice) if err != nil { @@ -748,7 +747,7 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef, if !ok { return fmt.Errorf("htlc %v not found", key) } - if htlc.State == channeldb.HtlcStateCanceled { + if htlc.State == HtlcStateCanceled { resolution := NewFailResolution( key, int32(htlc.AcceptHeight), result, ) @@ -808,12 +807,12 @@ func (i *InvoiceRegistry) processKeySend(ctx invoiceUpdateCtx) error { // to not be indexed. In the future, once AMP is merged, this should be // replaced by generating a random payment address on the behalf of the // sender. - payAddr := channeldb.BlankPayAddr + payAddr := BlankPayAddr // Create placeholder invoice. - invoice := &channeldb.Invoice{ + invoice := &Invoice{ CreationDate: i.cfg.Clock.Now(), - Terms: channeldb.ContractTerm{ + Terms: ContractTerm{ FinalCltvDelta: finalCltvDelta, Value: amt, PaymentPreimage: &preimage, @@ -830,7 +829,7 @@ func (i *InvoiceRegistry) processKeySend(ctx invoiceUpdateCtx) error { // Insert invoice into database. Ignore duplicates, because this // may be a replay. _, err = i.AddInvoice(invoice, ctx.hash) - if err != nil && err != channeldb.ErrDuplicateInvoice { + if err != nil && !errors.Is(err, ErrDuplicateInvoice) { return err } @@ -873,9 +872,9 @@ func (i *InvoiceRegistry) processAMP(ctx invoiceUpdateCtx) error { payAddr := ctx.mpp.PaymentAddr() // Create placeholder invoice. - invoice := &channeldb.Invoice{ + invoice := &Invoice{ CreationDate: i.cfg.Clock.Now(), - Terms: channeldb.ContractTerm{ + Terms: ContractTerm{ FinalCltvDelta: finalCltvDelta, Value: amt, PaymentPreimage: nil, @@ -888,10 +887,10 @@ func (i *InvoiceRegistry) processAMP(ctx invoiceUpdateCtx) error { // payment addrs, this may be a replay or a different HTLC for the AMP // invoice. _, err := i.AddInvoice(invoice, ctx.hash) + isDuplicatedInvoice := errors.Is(err, ErrDuplicateInvoice) + isDuplicatedPayAddr := errors.Is(err, ErrDuplicatePayAddr) switch { - case err == channeldb.ErrDuplicateInvoice: - return nil - case err == channeldb.ErrDuplicatePayAddr: + case isDuplicatedInvoice || isDuplicatedPayAddr: return nil default: return err @@ -915,7 +914,7 @@ func (i *InvoiceRegistry) processAMP(ctx invoiceUpdateCtx) error { // held htlc. func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, amtPaid lnwire.MilliSatoshi, expiry uint32, currentHeight int32, - circuitKey models.CircuitKey, hodlChan chan<- interface{}, + circuitKey CircuitKey, hodlChan chan<- interface{}, payload Payload) (HtlcResolution, error) { // Create the update context containing the relevant details of the @@ -982,9 +981,9 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, // main event loop. case *htlcAcceptResolution: if r.autoRelease { - var invRef channeldb.InvoiceRef + var invRef InvoiceRef if ctx.amp != nil { - invRef = channeldb.InvoiceRefBySetID(*ctx.setID()) + invRef = InvoiceRefBySetID(*ctx.setID()) } else { invRef = ctx.invoiceRef() } @@ -1025,29 +1024,29 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( resolution HtlcResolution updateSubscribers bool ) - invoice, err := i.cdb.UpdateInvoice( - ctx.invoiceRef(), - (*channeldb.SetID)(ctx.setID()), - func(inv *channeldb.Invoice) ( - *channeldb.InvoiceUpdateDesc, error) { - updateDesc, res, err := updateInvoice(ctx, inv) - if err != nil { - return nil, err - } + callback := func(inv *Invoice) (*InvoiceUpdateDesc, error) { + updateDesc, res, err := updateInvoice(ctx, inv) + if err != nil { + return nil, err + } - // Only send an update if the invoice state was changed. - updateSubscribers = updateDesc != nil && - updateDesc.State != nil + // Only send an update if the invoice state was changed. + updateSubscribers = updateDesc != nil && + updateDesc.State != nil - // Assign resolution to outer scope variable. - resolution = res + // Assign resolution to outer scope variable. + resolution = res - return updateDesc, nil - }, - ) + return updateDesc, nil + } - if _, ok := err.(channeldb.ErrDuplicateSetID); ok { + invoiceRef := ctx.invoiceRef() + setID := (*SetID)(ctx.setID()) + invoice, err := i.idb.UpdateInvoice(invoiceRef, setID, callback) + + var duplicateSetIDErr ErrDuplicateSetID + if errors.As(err, &duplicateSetIDErr) { return NewFailResolution( ctx.circuitKey, ctx.currentHeight, ResultInvoiceNotFound, @@ -1055,7 +1054,7 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( } switch err { - case channeldb.ErrInvoiceNotFound: + case ErrInvoiceNotFound: // If the invoice was not found, return a failure resolution // with an invoice not found result. return NewFailResolution( @@ -1063,7 +1062,7 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( ResultInvoiceNotFound, ), nil, nil - case channeldb.ErrInvRefEquivocation: + case ErrInvRefEquivocation: return NewFailResolution( ctx.circuitKey, ctx.currentHeight, ResultInvoiceNotFound, @@ -1105,7 +1104,7 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( // Also cancel any HTLCs in the HTLC set that are also in the // canceled state with the same failure result. setID := ctx.setID() - canceledHtlcSet := invoice.HTLCSet(setID, channeldb.HtlcStateCanceled) + canceledHtlcSet := invoice.HTLCSet(setID, HtlcStateCanceled) for key, htlc := range canceledHtlcSet { htlcFailResolution := NewFailResolution( key, int32(htlc.AcceptHeight), res.Outcome, @@ -1125,7 +1124,7 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( // marked as settled, we should follow now and settle the htlc // with our peer. setID := ctx.setID() - settledHtlcSet := invoice.HTLCSet(setID, channeldb.HtlcStateSettled) + settledHtlcSet := invoice.HTLCSet(setID, HtlcStateSettled) for key, htlc := range settledHtlcSet { preimage := res.Preimage if htlc.AMP != nil && htlc.AMP.Preimage != nil { @@ -1155,7 +1154,7 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( // // TODO(roasbeef): can remove now?? canceledHtlcSet := invoice.HTLCSetCompliment( - setID, channeldb.HtlcStateCanceled, + setID, HtlcStateCanceled, ) for key, htlc := range canceledHtlcSet { htlcFailResolution := NewFailResolution( @@ -1189,7 +1188,7 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( // Auto-release the htlc if the invoice is still open. It can // only happen for mpp payments that there are htlcs in state // Accepted while the invoice is Open. - if invoice.State == channeldb.ContractOpen { + if invoice.State == ContractOpen { res.acceptTime = invoiceHtlc.AcceptTime res.autoRelease = true } @@ -1231,29 +1230,31 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error { i.Lock() defer i.Unlock() - updateInvoice := func(invoice *channeldb.Invoice) ( - *channeldb.InvoiceUpdateDesc, error) { + updateInvoice := func(invoice *Invoice) ( + *InvoiceUpdateDesc, error) { switch invoice.State { - case channeldb.ContractOpen: - return nil, channeldb.ErrInvoiceStillOpen - case channeldb.ContractCanceled: - return nil, channeldb.ErrInvoiceAlreadyCanceled - case channeldb.ContractSettled: - return nil, channeldb.ErrInvoiceAlreadySettled + case ContractOpen: + return nil, ErrInvoiceStillOpen + + case ContractCanceled: + return nil, ErrInvoiceAlreadyCanceled + + case ContractSettled: + return nil, ErrInvoiceAlreadySettled } - return &channeldb.InvoiceUpdateDesc{ - State: &channeldb.InvoiceStateUpdateDesc{ - NewState: channeldb.ContractSettled, + return &InvoiceUpdateDesc{ + State: &InvoiceStateUpdateDesc{ + NewState: ContractSettled, Preimage: &preimage, }, }, nil } hash := preimage.Hash() - invoiceRef := channeldb.InvoiceRefByHash(hash) - invoice, err := i.cdb.UpdateInvoice(invoiceRef, nil, updateInvoice) + invoiceRef := InvoiceRefByHash(hash) + invoice, err := i.idb.UpdateInvoice(invoiceRef, nil, updateInvoice) if err != nil { log.Errorf("SettleHodlInvoice with preimage %v: %v", preimage, err) @@ -1271,7 +1272,7 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error { // that were already settled before, will be notified again. This isn't // necessary but doesn't hurt either. for key, htlc := range invoice.Htlcs { - if htlc.State != channeldb.HtlcStateSettled { + if htlc.State != HtlcStateSettled { continue } @@ -1296,8 +1297,8 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { // cancel already accepted invoices, taking our force cancel boolean into // account. This is pulled out into its own function so that tests that mock // cancelInvoiceImpl can reuse this logic. -func shouldCancel(state channeldb.ContractState, cancelAccepted bool) bool { - if state != channeldb.ContractAccepted { +func shouldCancel(state ContractState, cancelAccepted bool) bool { + if state != ContractAccepted { return true } @@ -1316,12 +1317,10 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, i.Lock() defer i.Unlock() - ref := channeldb.InvoiceRefByHash(payHash) + ref := InvoiceRefByHash(payHash) log.Debugf("Invoice%v: canceling invoice", ref) - updateInvoice := func(invoice *channeldb.Invoice) ( - *channeldb.InvoiceUpdateDesc, error) { - + updateInvoice := func(invoice *Invoice) (*InvoiceUpdateDesc, error) { if !shouldCancel(invoice.State, cancelAccepted) { return nil, nil } @@ -1329,19 +1328,19 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, // Move invoice to the canceled state. Rely on validation in // channeldb to return an error if the invoice is already // settled or canceled. - return &channeldb.InvoiceUpdateDesc{ - State: &channeldb.InvoiceStateUpdateDesc{ - NewState: channeldb.ContractCanceled, + return &InvoiceUpdateDesc{ + State: &InvoiceStateUpdateDesc{ + NewState: ContractCanceled, }, }, nil } - invoiceRef := channeldb.InvoiceRefByHash(payHash) - invoice, err := i.cdb.UpdateInvoice(invoiceRef, nil, updateInvoice) + invoiceRef := InvoiceRefByHash(payHash) + invoice, err := i.idb.UpdateInvoice(invoiceRef, nil, updateInvoice) // Implement idempotency by returning success if the invoice was already // canceled. - if err == channeldb.ErrInvoiceAlreadyCanceled { + if errors.Is(err, ErrInvoiceAlreadyCanceled) { log.Debugf("Invoice%v: already canceled", ref) return nil } @@ -1350,7 +1349,7 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, } // Return without cancellation if the invoice state is ContractAccepted. - if invoice.State == channeldb.ContractAccepted { + if invoice.State == ContractAccepted { log.Debugf("Invoice%v: remains accepted as cancel wasn't"+ "explicitly requested.", ref) return nil @@ -1364,7 +1363,7 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, // before, will be notified again. This isn't necessary but doesn't hurt // either. for key, htlc := range invoice.Htlcs { - if htlc.State != channeldb.HtlcStateCanceled { + if htlc.State != HtlcStateCanceled { continue } @@ -1381,24 +1380,22 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, if i.cfg.GcCanceledInvoicesOnTheFly { // Assemble the delete reference and attempt to delete through // the invocice from the DB. - deleteRef := channeldb.InvoiceDeleteRef{ + deleteRef := InvoiceDeleteRef{ PayHash: payHash, AddIndex: invoice.AddIndex, SettleIndex: invoice.SettleIndex, } - if invoice.Terms.PaymentAddr != channeldb.BlankPayAddr { + if invoice.Terms.PaymentAddr != BlankPayAddr { deleteRef.PayAddr = &invoice.Terms.PaymentAddr } - err = i.cdb.DeleteInvoice( - []channeldb.InvoiceDeleteRef{deleteRef}, - ) + err = i.idb.DeleteInvoice([]InvoiceDeleteRef{deleteRef}) // If by any chance deletion failed, then log it instead of - // returning the error, as the invoice itsels has already been + // returning the error, as the invoice itself has already been // canceled. if err != nil { - log.Warnf("Invoice%v could not be deleted: %v", - ref, err) + log.Warnf("Invoice %v could not be deleted: %v", ref, + err) } } @@ -1408,7 +1405,7 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, // notifyClients notifies all currently registered invoice notification clients // of a newly added/settled invoice. func (i *InvoiceRegistry) notifyClients(hash lntypes.Hash, - invoice *channeldb.Invoice, setID *[32]byte) { + invoice *Invoice, setID *[32]byte) { event := &invoiceEvent{ invoice: invoice, @@ -1451,12 +1448,12 @@ type InvoiceSubscription struct { // NewInvoices is a channel that we'll use to send all newly created // invoices with an invoice index greater than the specified // StartingInvoiceIndex field. - NewInvoices chan *channeldb.Invoice + NewInvoices chan *Invoice // SettledInvoices is a channel that we'll use to send all settled // invoices with an invoices index greater than the specified // StartingInvoiceIndex field. - SettledInvoices chan *channeldb.Invoice + SettledInvoices chan *Invoice // addIndex is the highest add index the caller knows of. We'll use // this information to send out an event backlog to the notifications @@ -1477,11 +1474,19 @@ type InvoiceSubscription struct { type SingleInvoiceSubscription struct { invoiceSubscriptionKit - invoiceRef channeldb.InvoiceRef + invoiceRef InvoiceRef // Updates is a channel that we'll use to send all invoice events for // the invoice that is subscribed to. - Updates chan *channeldb.Invoice + Updates chan *Invoice +} + +// PayHash returns the optional payment hash of the target invoice. +// +// TODO(positiveblue): This method is only supposed to be used in tests. It will +// be deleted as soon as invoiceregistery_test is in the same module. +func (s *SingleInvoiceSubscription) PayHash() *lntypes.Hash { + return s.invoiceRef.PayHash() } // Cancel unregisters the InvoiceSubscription, freeing any previously allocated @@ -1498,6 +1503,7 @@ func (i *invoiceSubscriptionKit) Cancel() { func (i *invoiceSubscriptionKit) notify(event *invoiceEvent) error { select { case i.ntfnQueue.ChanIn() <- event: + case <-i.cancelChan: // This can only be triggered by delivery of non-backlog // events. @@ -1518,8 +1524,8 @@ func (i *InvoiceRegistry) SubscribeNotifications( addIndex, settleIndex uint64) (*InvoiceSubscription, error) { client := &InvoiceSubscription{ - NewInvoices: make(chan *channeldb.Invoice), - SettledInvoices: make(chan *channeldb.Invoice), + NewInvoices: make(chan *Invoice), + SettledInvoices: make(chan *Invoice), addIndex: addIndex, settleIndex: settleIndex, invoiceSubscriptionKit: invoiceSubscriptionKit{ @@ -1556,24 +1562,25 @@ func (i *InvoiceRegistry) SubscribeNotifications( case ntfn := <-client.ntfnQueue.ChanOut(): invoiceEvent := ntfn.(*invoiceEvent) - var targetChan chan *channeldb.Invoice + var targetChan chan *Invoice state := invoiceEvent.invoice.State switch { // AMP invoices never move to settled, but will // be sent with a set ID if an HTLC set is // being settled. - case state == channeldb.ContractOpen && + case state == ContractOpen && invoiceEvent.setID != nil: fallthrough - case state == channeldb.ContractSettled: + + case state == ContractSettled: targetChan = client.SettledInvoices - case state == channeldb.ContractOpen: + case state == ContractOpen: targetChan = client.NewInvoices default: - log.Errorf("unknown invoice "+ - "state: %v", state) + log.Errorf("unknown invoice state: %v", + state) continue } @@ -1619,14 +1626,14 @@ func (i *InvoiceRegistry) SubscribeSingleInvoice( hash lntypes.Hash) (*SingleInvoiceSubscription, error) { client := &SingleInvoiceSubscription{ - Updates: make(chan *channeldb.Invoice), + Updates: make(chan *Invoice), invoiceSubscriptionKit: invoiceSubscriptionKit{ quit: i.quit, ntfnQueue: queue.NewConcurrentQueue(20), cancelChan: make(chan struct{}), backlogDelivered: make(chan struct{}), }, - invoiceRef: channeldb.InvoiceRefByHash(hash), + invoiceRef: InvoiceRefByHash(hash), } client.ntfnQueue.Start() @@ -1720,7 +1727,7 @@ func (i *InvoiceRegistry) notifyHodlSubscribers(htlcResolution HtlcResolution) { // hodlSubscribe adds a new invoice subscription. func (i *InvoiceRegistry) hodlSubscribe(subscriber chan<- interface{}, - circuitKey models.CircuitKey) { + circuitKey CircuitKey) { i.hodlSubscriptionsMux.Lock() defer i.hodlSubscriptionsMux.Unlock() @@ -1736,7 +1743,7 @@ func (i *InvoiceRegistry) hodlSubscribe(subscriber chan<- interface{}, reverseSubscriptions, ok := i.hodlReverseSubscriptions[subscriber] if !ok { - reverseSubscriptions = make(map[models.CircuitKey]struct{}) + reverseSubscriptions = make(map[CircuitKey]struct{}) i.hodlReverseSubscriptions[subscriber] = reverseSubscriptions } reverseSubscriptions[circuitKey] = struct{}{} @@ -1757,7 +1764,7 @@ func (i *InvoiceRegistry) HodlUnsubscribeAll(subscriber chan<- interface{}) { // copySingleClients copies i.SingleInvoiceSubscription inside a lock. This is // useful when we need to iterate the map to send notifications. -func (i *InvoiceRegistry) copySingleClients() map[uint32]*SingleInvoiceSubscription { +func (i *InvoiceRegistry) copySingleClients() map[uint32]*SingleInvoiceSubscription { //nolint:lll i.notificationClientMux.RLock() defer i.notificationClientMux.RUnlock() diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index c41576033..45bb5712c 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -1,4 +1,4 @@ -package invoices +package invoices_test import ( "crypto/rand" @@ -8,8 +8,8 @@ import ( "github.com/lightningnetwork/lnd/amp" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" + invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -18,23 +18,27 @@ import ( // TestSettleInvoice tests settling of an invoice and related notifications. func TestSettleInvoice(t *testing.T) { - ctx := newTestContext(t) + ctx := newTestContext(t, nil) allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) require.Nil(t, err) defer allSubscriptions.Cancel() // Subscribe to the not yet existing invoice. - subscription, err := ctx.registry.SubscribeSingleInvoice(testInvoicePaymentHash) + subscription, err := ctx.registry.SubscribeSingleInvoice( + testInvoicePaymentHash, + ) if err != nil { t.Fatal(err) } defer subscription.Cancel() - require.Equal(t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash) + require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice. - addIdx, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) + addIdx, err := ctx.registry.AddInvoice( + testInvoice, testInvoicePaymentHash, + ) if err != nil { t.Fatal(err) } @@ -47,7 +51,7 @@ func TestSettleInvoice(t *testing.T) { // We expect the open state to be sent to the single invoice subscriber. select { case update := <-subscription.Updates: - if update.State != channeldb.ContractOpen { + if update.State != invpkg.ContractOpen { t.Fatalf("expected state ContractOpen, but got %v", update.State) } @@ -58,7 +62,7 @@ func TestSettleInvoice(t *testing.T) { // We expect a new invoice notification to be sent out. select { case newInvoice := <-allSubscriptions.NewInvoices: - if newInvoice.State != channeldb.ContractOpen { + if newInvoice.State != invpkg.ContractOpen { t.Fatalf("expected state ContractOpen, but got %v", newInvoice.State) } @@ -79,7 +83,7 @@ func TestSettleInvoice(t *testing.T) { } require.NotNil(t, resolution) failResolution := checkFailResolution( - t, resolution, ResultExpiryTooSoon, + t, resolution, invpkg.ResultExpiryTooSoon, ) require.Equal(t, testCurrentHeight, failResolution.AcceptHeight) @@ -97,13 +101,13 @@ func TestSettleInvoice(t *testing.T) { settleResolution := checkSettleResolution( t, resolution, testInvoicePreimage, ) - require.Equal(t, ResultSettled, settleResolution.Outcome) + require.Equal(t, invpkg.ResultSettled, settleResolution.Outcome) // We expect the settled state to be sent to the single invoice // subscriber. select { case update := <-subscription.Updates: - if update.State != channeldb.ContractSettled { + if update.State != invpkg.ContractSettled { t.Fatalf("expected state ContractOpen, but got %v", update.State) } @@ -117,7 +121,7 @@ func TestSettleInvoice(t *testing.T) { // We expect a settled notification to be sent out. select { case settledInvoice := <-allSubscriptions.SettledInvoices: - if settledInvoice.State != channeldb.ContractSettled { + if settledInvoice.State != invpkg.ContractSettled { t.Fatalf("expected state ContractOpen, but got %v", settledInvoice.State) } @@ -128,39 +132,41 @@ func TestSettleInvoice(t *testing.T) { // Try to settle again with the same htlc id. We need this idempotent // behaviour after a restart. resolution, err = ctx.registry.NotifyExitHopHtlc( - testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight, - getCircuitKey(0), hodlChan, testPayload, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, + testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) require.NoError(t, err, "unexpected NotifyExitHopHtlc error") require.NotNil(t, resolution) settleResolution = checkSettleResolution( t, resolution, testInvoicePreimage, ) - require.Equal(t, ResultReplayToSettled, settleResolution.Outcome) + require.Equal(t, invpkg.ResultReplayToSettled, settleResolution.Outcome) // Try to settle again with a new higher-valued htlc. This payment // should also be accepted, to prevent any change in behaviour for a // paid invoice that may open up a probe vector. resolution, err = ctx.registry.NotifyExitHopHtlc( - testInvoicePaymentHash, amtPaid+600, testHtlcExpiry, testCurrentHeight, - getCircuitKey(1), hodlChan, testPayload, + testInvoicePaymentHash, amtPaid+600, testHtlcExpiry, + testCurrentHeight, getCircuitKey(1), hodlChan, testPayload, ) require.NoError(t, err, "unexpected NotifyExitHopHtlc error") require.NotNil(t, resolution) settleResolution = checkSettleResolution( t, resolution, testInvoicePreimage, ) - require.Equal(t, ResultDuplicateToSettled, settleResolution.Outcome) + require.Equal( + t, invpkg.ResultDuplicateToSettled, settleResolution.Outcome, + ) // Try to settle again with a lower amount. This should fail just as it // would have failed if it were the first payment. resolution, err = ctx.registry.NotifyExitHopHtlc( - testInvoicePaymentHash, amtPaid-600, testHtlcExpiry, testCurrentHeight, - getCircuitKey(2), hodlChan, testPayload, + testInvoicePaymentHash, amtPaid-600, testHtlcExpiry, + testCurrentHeight, getCircuitKey(2), hodlChan, testPayload, ) require.NoError(t, err, "unexpected NotifyExitHopHtlc error") require.NotNil(t, resolution) - checkFailResolution(t, resolution, ResultAmountTooLow) + checkFailResolution(t, resolution, invpkg.ResultAmountTooLow) // Check that settled amount is equal to the sum of values of the htlcs // 0 and 1. @@ -175,9 +181,7 @@ func TestSettleInvoice(t *testing.T) { // Try to cancel. err = ctx.registry.CancelInvoice(testInvoicePaymentHash) - if err != channeldb.ErrInvoiceAlreadySettled { - t.Fatal("expected cancellation of a settled invoice to fail") - } + require.ErrorIs(t, err, invpkg.ErrInvoiceAlreadySettled) // As this is a direct settle, we expect nothing on the hodl chan. select { @@ -188,11 +192,12 @@ func TestSettleInvoice(t *testing.T) { } func testCancelInvoice(t *testing.T, gc bool) { - ctx := newTestContext(t) + cfg := defaultRegistryConfig() // If set to true, then also delete the invoice from the DB after // cancellation. - ctx.registry.cfg.GcCanceledInvoicesOnTheFly = gc + cfg.GcCanceledInvoicesOnTheFly = gc + ctx := newTestContext(t, &cfg) allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) require.Nil(t, err) @@ -200,18 +205,18 @@ func testCancelInvoice(t *testing.T, gc bool) { // Try to cancel the not yet existing invoice. This should fail. err = ctx.registry.CancelInvoice(testInvoicePaymentHash) - if err != channeldb.ErrInvoiceNotFound { - t.Fatalf("expected ErrInvoiceNotFound, but got %v", err) - } + require.ErrorIs(t, err, invpkg.ErrInvoiceNotFound) // Subscribe to the not yet existing invoice. - subscription, err := ctx.registry.SubscribeSingleInvoice(testInvoicePaymentHash) + subscription, err := ctx.registry.SubscribeSingleInvoice( + testInvoicePaymentHash, + ) if err != nil { t.Fatal(err) } defer subscription.Cancel() - require.Equal(t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash) + require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice. amt := lnwire.MilliSatoshi(100000) @@ -223,7 +228,7 @@ func testCancelInvoice(t *testing.T, gc bool) { // We expect the open state to be sent to the single invoice subscriber. select { case update := <-subscription.Updates: - if update.State != channeldb.ContractOpen { + if update.State != invpkg.ContractOpen { t.Fatalf( "expected state ContractOpen, but got %v", update.State, @@ -236,7 +241,7 @@ func testCancelInvoice(t *testing.T, gc bool) { // We expect a new invoice notification to be sent out. select { case newInvoice := <-allSubscriptions.NewInvoices: - if newInvoice.State != channeldb.ContractOpen { + if newInvoice.State != invpkg.ContractOpen { t.Fatalf( "expected state ContractOpen, but got %v", newInvoice.State, @@ -256,7 +261,7 @@ func testCancelInvoice(t *testing.T, gc bool) { // subscriber. select { case update := <-subscription.Updates: - if update.State != channeldb.ContractCanceled { + if update.State != invpkg.ContractCanceled { t.Fatalf( "expected state ContractCanceled, but got %v", update.State, @@ -268,8 +273,8 @@ func testCancelInvoice(t *testing.T, gc bool) { if gc { // Check that the invoice has been deleted from the db. - _, err = ctx.cdb.LookupInvoice( - channeldb.InvoiceRefByHash(testInvoicePaymentHash), + _, err = ctx.idb.LookupInvoice( + invpkg.InvoiceRefByHash(testInvoicePaymentHash), ) require.Error(t, err) } @@ -283,7 +288,7 @@ func testCancelInvoice(t *testing.T, gc bool) { err = ctx.registry.CancelInvoice(testInvoicePaymentHash) if gc { - require.Error(t, err, channeldb.ErrInvoiceNotFound) + require.Error(t, err, invpkg.ErrInvoiceNotFound) } else { require.NoError(t, err) } @@ -303,14 +308,14 @@ func testCancelInvoice(t *testing.T, gc bool) { // If the invoice has been deleted (or not present) then we expect the // outcome to be ResultInvoiceNotFound instead of when the invoice is // in our database in which case we expect ResultInvoiceAlreadyCanceled. - var failResolution *HtlcFailResolution + var failResolution *invpkg.HtlcFailResolution if gc { failResolution = checkFailResolution( - t, resolution, ResultInvoiceNotFound, + t, resolution, invpkg.ResultInvoiceNotFound, ) } else { failResolution = checkFailResolution( - t, resolution, ResultInvoiceAlreadyCanceled, + t, resolution, invpkg.ResultInvoiceAlreadyCanceled, ) } @@ -335,21 +340,21 @@ func TestCancelInvoice(t *testing.T) { func TestSettleHoldInvoice(t *testing.T) { defer timeout()() - cdb, err := newTestChannelDB(t, clock.NewTestClock(time.Time{})) + idb, err := newTestChannelDB(t, clock.NewTestClock(time.Time{})) if err != nil { t.Fatal(err) } // Instantiate and start the invoice ctx.registry. - cfg := RegistryConfig{ + cfg := invpkg.RegistryConfig{ FinalCltvRejectDelta: testFinalCltvRejectDelta, Clock: clock.NewTestClock(testTime), } - expiryWatcher := NewInvoiceExpiryWatcher( + expiryWatcher := invpkg.NewInvoiceExpiryWatcher( cfg.Clock, 0, uint32(testCurrentHeight), nil, newMockNotifier(), ) - registry := NewRegistry(cdb, expiryWatcher, &cfg) + registry := invpkg.NewRegistry(idb, expiryWatcher, &cfg) err = registry.Start() if err != nil { @@ -362,13 +367,15 @@ func TestSettleHoldInvoice(t *testing.T) { defer allSubscriptions.Cancel() // Subscribe to the not yet existing invoice. - subscription, err := registry.SubscribeSingleInvoice(testInvoicePaymentHash) + subscription, err := registry.SubscribeSingleInvoice( + testInvoicePaymentHash, + ) if err != nil { t.Fatal(err) } defer subscription.Cancel() - require.Equal(t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash) + require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice. _, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash) @@ -378,14 +385,14 @@ func TestSettleHoldInvoice(t *testing.T) { // We expect the open state to be sent to the single invoice subscriber. update := <-subscription.Updates - if update.State != channeldb.ContractOpen { + if update.State != invpkg.ContractOpen { t.Fatalf("expected state ContractOpen, but got %v", update.State) } // We expect a new invoice notification to be sent out. newInvoice := <-allSubscriptions.NewInvoices - if newInvoice.State != channeldb.ContractOpen { + if newInvoice.State != invpkg.ContractOpen { t.Fatalf("expected state ContractOpen, but got %v", newInvoice.State) } @@ -398,8 +405,8 @@ func TestSettleHoldInvoice(t *testing.T) { // NotifyExitHopHtlc without a preimage present in the invoice registry // should be possible. resolution, err := registry.NotifyExitHopHtlc( - testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight, - getCircuitKey(0), hodlChan, testPayload, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, + testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { t.Fatalf("expected settle to succeed but got %v", err) @@ -410,8 +417,8 @@ func TestSettleHoldInvoice(t *testing.T) { // Test idempotency. resolution, err = registry.NotifyExitHopHtlc( - testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight, - getCircuitKey(0), hodlChan, testPayload, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, + testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { t.Fatalf("expected settle to succeed but got %v", err) @@ -423,8 +430,8 @@ func TestSettleHoldInvoice(t *testing.T) { // Test replay at a higher height. We expect the same result because it // is a replay. resolution, err = registry.NotifyExitHopHtlc( - testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight+10, - getCircuitKey(0), hodlChan, testPayload, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, + testCurrentHeight+10, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { t.Fatalf("expected settle to succeed but got %v", err) @@ -443,13 +450,13 @@ func TestSettleHoldInvoice(t *testing.T) { t.Fatalf("expected settle to succeed but got %v", err) } require.NotNil(t, resolution) - checkFailResolution(t, resolution, ResultExpiryTooSoon) + checkFailResolution(t, resolution, invpkg.ResultExpiryTooSoon) // We expect the accepted state to be sent to the single invoice // subscriber. For all invoice subscribers, we don't expect an update. // Those only get notified on settle. update = <-subscription.Updates - if update.State != channeldb.ContractAccepted { + if update.State != invpkg.ContractAccepted { t.Fatalf("expected state ContractAccepted, but got %v", update.State) } @@ -463,18 +470,18 @@ func TestSettleHoldInvoice(t *testing.T) { t.Fatal("expected set preimage to succeed") } - htlcResolution := (<-hodlChan).(HtlcResolution) + htlcResolution, _ := (<-hodlChan).(invpkg.HtlcResolution) require.NotNil(t, htlcResolution) settleResolution := checkSettleResolution( t, htlcResolution, testInvoicePreimage, ) require.Equal(t, testCurrentHeight, settleResolution.AcceptHeight) - require.Equal(t, ResultSettled, settleResolution.Outcome) + require.Equal(t, invpkg.ResultSettled, settleResolution.Outcome) // We expect a settled notification to be sent out for both all and // single invoice subscribers. settledInvoice := <-allSubscriptions.SettledInvoices - if settledInvoice.State != channeldb.ContractSettled { + if settledInvoice.State != invpkg.ContractSettled { t.Fatalf("expected state ContractSettled, but got %v", settledInvoice.State) } @@ -484,16 +491,14 @@ func TestSettleHoldInvoice(t *testing.T) { } update = <-subscription.Updates - if update.State != channeldb.ContractSettled { + if update.State != invpkg.ContractSettled { t.Fatalf("expected state ContractSettled, but got %v", update.State) } // Idempotency. err = registry.SettleHodlInvoice(testInvoicePreimage) - if err != channeldb.ErrInvoiceAlreadySettled { - t.Fatalf("expected ErrInvoiceAlreadySettled but got %v", err) - } + require.ErrorIs(t, err, invpkg.ErrInvoiceAlreadySettled) // Try to cancel. err = registry.CancelInvoice(testInvoicePaymentHash) @@ -507,20 +512,19 @@ func TestSettleHoldInvoice(t *testing.T) { func TestCancelHoldInvoice(t *testing.T) { defer timeout()() - cdb, err := newTestChannelDB(t, clock.NewTestClock(time.Time{})) - if err != nil { - t.Fatal(err) - } + testClock := clock.NewTestClock(testTime) + idb, err := newTestChannelDB(t, testClock) + require.NoError(t, err) // Instantiate and start the invoice ctx.registry. - cfg := RegistryConfig{ + cfg := invpkg.RegistryConfig{ FinalCltvRejectDelta: testFinalCltvRejectDelta, - Clock: clock.NewTestClock(testTime), + Clock: testClock, } - expiryWatcher := NewInvoiceExpiryWatcher( + expiryWatcher := invpkg.NewInvoiceExpiryWatcher( cfg.Clock, 0, uint32(testCurrentHeight), nil, newMockNotifier(), ) - registry := NewRegistry(cdb, expiryWatcher, &cfg) + registry := invpkg.NewRegistry(idb, expiryWatcher, &cfg) err = registry.Start() if err != nil { @@ -542,8 +546,8 @@ func TestCancelHoldInvoice(t *testing.T) { // NotifyExitHopHtlc without a preimage present in the invoice registry // should be possible. resolution, err := registry.NotifyExitHopHtlc( - testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight, - getCircuitKey(0), hodlChan, testPayload, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, + testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { t.Fatalf("expected settle to succeed but got %v", err) @@ -558,23 +562,23 @@ func TestCancelHoldInvoice(t *testing.T) { t.Fatal("cancel invoice failed") } - htlcResolution := (<-hodlChan).(HtlcResolution) + htlcResolution, _ := (<-hodlChan).(invpkg.HtlcResolution) require.NotNil(t, htlcResolution) - checkFailResolution(t, htlcResolution, ResultCanceled) + checkFailResolution(t, htlcResolution, invpkg.ResultCanceled) // Offering the same htlc again at a higher height should still result // in a rejection. The accept height is expected to be the original // accept height. resolution, err = registry.NotifyExitHopHtlc( - testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight+1, - getCircuitKey(0), hodlChan, testPayload, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, + testCurrentHeight+1, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { t.Fatalf("expected settle to succeed but got %v", err) } require.NotNil(t, resolution) failResolution := checkFailResolution( - t, resolution, ResultReplayToCanceled, + t, resolution, invpkg.ResultReplayToCanceled, ) require.Equal(t, testCurrentHeight, failResolution.AcceptHeight) } @@ -585,7 +589,7 @@ func TestCancelHoldInvoice(t *testing.T) { // if we are the exit hop, but in htlcIncomingContestResolver it is called with // forwarded htlc hashes as well. func TestUnknownInvoice(t *testing.T) { - ctx := newTestContext(t) + ctx := newTestContext(t, nil) // Notify arrival of a new htlc paying to this invoice. This should // succeed. @@ -599,7 +603,7 @@ func TestUnknownInvoice(t *testing.T) { t.Fatal("unexpected error") } require.NotNil(t, resolution) - checkFailResolution(t, resolution, ResultInvoiceNotFound) + checkFailResolution(t, resolution, invpkg.ResultInvoiceNotFound) } // TestKeySend tests receiving a spontaneous payment with and without keysend @@ -618,9 +622,9 @@ func TestKeySend(t *testing.T) { func testKeySend(t *testing.T, keySendEnabled bool) { defer timeout()() - ctx := newTestContext(t) - - ctx.registry.cfg.AcceptKeySend = keySendEnabled + cfg := defaultRegistryConfig() + cfg.AcceptKeySend = keySendEnabled + ctx := newTestContext(t, &cfg) allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) require.Nil(t, err) @@ -653,9 +657,9 @@ func testKeySend(t *testing.T, keySendEnabled bool) { require.NotNil(t, resolution) if !keySendEnabled { - checkFailResolution(t, resolution, ResultInvoiceNotFound) + checkFailResolution(t, resolution, invpkg.ResultInvoiceNotFound) } else { - checkFailResolution(t, resolution, ResultKeySendError) + checkFailResolution(t, resolution, invpkg.ResultKeySendError) } // Try to settle invoice with a valid keysend htlc. @@ -675,18 +679,18 @@ func testKeySend(t *testing.T, keySendEnabled bool) { // Expect a cancel resolution if keysend is disabled. if !keySendEnabled { - checkFailResolution(t, resolution, ResultInvoiceNotFound) + checkFailResolution(t, resolution, invpkg.ResultInvoiceNotFound) return } checkSubscription := func() { // We expect a new invoice notification to be sent out. newInvoice := <-allSubscriptions.NewInvoices - require.Equal(t, newInvoice.State, channeldb.ContractOpen) + require.Equal(t, newInvoice.State, invpkg.ContractOpen) // We expect a settled notification to be sent out. settledInvoice := <-allSubscriptions.SettledInvoices - require.Equal(t, settledInvoice.State, channeldb.ContractSettled) + require.Equal(t, settledInvoice.State, invpkg.ContractSettled) } checkSettleResolution(t, resolution, preimage) @@ -744,10 +748,10 @@ func testHoldKeysend(t *testing.T, timeoutKeysend bool) { const holdDuration = time.Minute - ctx := newTestContext(t) - - ctx.registry.cfg.AcceptKeySend = true - ctx.registry.cfg.KeysendHoldTime = holdDuration + cfg := defaultRegistryConfig() + cfg.AcceptKeySend = true + cfg.KeysendHoldTime = holdDuration + ctx := newTestContext(t, &cfg) allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) require.Nil(t, err) @@ -782,7 +786,7 @@ func testHoldKeysend(t *testing.T, timeoutKeysend bool) { // We expect a new invoice notification to be sent out. newInvoice := <-allSubscriptions.NewInvoices - if newInvoice.State != channeldb.ContractOpen { + if newInvoice.State != invpkg.ContractOpen { t.Fatalf("expected state ContractOpen, but got %v", newInvoice.State) } @@ -803,13 +807,13 @@ func testHoldKeysend(t *testing.T, timeoutKeysend bool) { // Expect the keysend payment to be failed. res := <-hodlChan - failResolution, ok := res.(*HtlcFailResolution) + failResolution, ok := res.(*invpkg.HtlcFailResolution) require.Truef( t, ok, "expected fail resolution, got: %T", resolution, ) require.Equal( - t, ResultCanceled, failResolution.Outcome, + t, invpkg.ResultCanceled, failResolution.Outcome, "expected keysend payment to be failed", ) @@ -823,7 +827,7 @@ func testHoldKeysend(t *testing.T, timeoutKeysend bool) { // We expect a settled notification to be sent out. settledInvoice := <-allSubscriptions.SettledInvoices - require.Equal(t, settledInvoice.State, channeldb.ContractSettled) + require.Equal(t, settledInvoice.State, invpkg.ContractSettled) } // TestMppPayment tests settling of an invoice with multiple partial payments. @@ -832,7 +836,7 @@ func testHoldKeysend(t *testing.T, timeoutKeysend bool) { func TestMppPayment(t *testing.T) { defer timeout()() - ctx := newTestContext(t) + ctx := newTestContext(t, nil) // Add the invoice. _, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) @@ -861,13 +865,13 @@ func TestMppPayment(t *testing.T) { // Simulate mpp timeout releasing htlc 1. ctx.clock.SetTime(testTime.Add(30 * time.Second)) - htlcResolution := (<-hodlChan1).(HtlcResolution) - failResolution, ok := htlcResolution.(*HtlcFailResolution) + htlcResolution, _ := (<-hodlChan1).(invpkg.HtlcResolution) + failResolution, ok := htlcResolution.(*invpkg.HtlcFailResolution) if !ok { t.Fatalf("expected fail resolution, got: %T", resolution) } - if failResolution.Outcome != ResultMppTimeout { + if failResolution.Outcome != invpkg.ResultMppTimeout { t.Fatalf("expected mpp timeout, got: %v", failResolution.Outcome) } @@ -896,12 +900,12 @@ func TestMppPayment(t *testing.T) { if err != nil { t.Fatal(err) } - settleResolution, ok := resolution.(*HtlcSettleResolution) + settleResolution, ok := resolution.(*invpkg.HtlcSettleResolution) if !ok { t.Fatalf("expected settle resolution, got: %T", htlcResolution) } - if settleResolution.Outcome != ResultSettled { + if settleResolution.Outcome != invpkg.ResultSettled { t.Fatalf("expected result settled, got: %v", settleResolution.Outcome) } @@ -912,7 +916,7 @@ func TestMppPayment(t *testing.T) { if err != nil { t.Fatal(err) } - if inv.State != channeldb.ContractSettled { + if inv.State != invpkg.ContractSettled { t.Fatal("expected invoice to be settled") } if inv.AmtPaid != testInvoice.Terms.Value { @@ -925,22 +929,19 @@ func TestMppPayment(t *testing.T) { func TestInvoiceExpiryWithRegistry(t *testing.T) { t.Parallel() - cdb, err := newTestChannelDB(t, clock.NewTestClock(time.Time{})) - if err != nil { - t.Fatal(err) - } - testClock := clock.NewTestClock(testTime) + idb, err := newTestChannelDB(t, testClock) + require.NoError(t, err) - cfg := RegistryConfig{ + cfg := invpkg.RegistryConfig{ FinalCltvRejectDelta: testFinalCltvRejectDelta, Clock: testClock, } - expiryWatcher := NewInvoiceExpiryWatcher( + expiryWatcher := invpkg.NewInvoiceExpiryWatcher( cfg.Clock, 0, uint32(testCurrentHeight), nil, newMockNotifier(), ) - registry := NewRegistry(cdb, expiryWatcher, &cfg) + registry := invpkg.NewRegistry(idb, expiryWatcher, &cfg) // First prefill the Channel DB with some pre-existing invoices, // half of them still pending, half of them expired. @@ -951,21 +952,22 @@ func TestInvoiceExpiryWithRegistry(t *testing.T) { ) var expectedCancellations []lntypes.Hash - - for paymentHash, expiredInvoice := range existingInvoices.expiredInvoices { - if _, err := cdb.AddInvoice(expiredInvoice, paymentHash); err != nil { - t.Fatalf("cannot add invoice to channel db: %v", err) - } - expectedCancellations = append(expectedCancellations, paymentHash) + expiredInvoices := existingInvoices.expiredInvoices + for paymentHash, expiredInvoice := range expiredInvoices { + _, err := idb.AddInvoice(expiredInvoice, paymentHash) + require.NoError(t, err) + expectedCancellations = append( + expectedCancellations, paymentHash, + ) } - for paymentHash, pendingInvoice := range existingInvoices.pendingInvoices { - if _, err := cdb.AddInvoice(pendingInvoice, paymentHash); err != nil { - t.Fatalf("cannot add invoice to channel db: %v", err) - } + pendingInvoices := existingInvoices.pendingInvoices + for paymentHash, pendingInvoice := range pendingInvoices { + _, err := idb.AddInvoice(pendingInvoice, paymentHash) + require.NoError(t, err) } - if err = registry.Start(); err != nil { + if err := registry.Start(); err != nil { t.Fatalf("cannot start registry: %v", err) } @@ -978,7 +980,9 @@ func TestInvoiceExpiryWithRegistry(t *testing.T) { var invoicesThatWillCancel []lntypes.Hash for paymentHash, pendingInvoice := range newInvoices.pendingInvoices { _, err := registry.AddInvoice(pendingInvoice, paymentHash) - invoicesThatWillCancel = append(invoicesThatWillCancel, paymentHash) + invoicesThatWillCancel = append( + invoicesThatWillCancel, paymentHash, + ) if err != nil { t.Fatal(err) } @@ -987,12 +991,14 @@ func TestInvoiceExpiryWithRegistry(t *testing.T) { // Check that they are really not canceled until before the clock is // advanced. for i := range invoicesThatWillCancel { - invoice, err := registry.LookupInvoice(invoicesThatWillCancel[i]) + invoice, err := registry.LookupInvoice( + invoicesThatWillCancel[i], + ) if err != nil { t.Fatalf("cannot find invoice: %v", err) } - if invoice.State == channeldb.ContractCanceled { + if invoice.State == invpkg.ContractCanceled { t.Fatalf("expected pending invoice, got canceled") } } @@ -1002,7 +1008,7 @@ func TestInvoiceExpiryWithRegistry(t *testing.T) { // Give some time to the watcher to cancel everything. time.Sleep(500 * time.Millisecond) - if err = registry.Stop(); err != nil { + if err := registry.Stop(); err != nil { t.Fatalf("failed to stop invoice registry: %v", err) } @@ -1011,17 +1017,15 @@ func TestInvoiceExpiryWithRegistry(t *testing.T) { expectedCancellations, invoicesThatWillCancel..., ) - // Retrospectively check that all invoices that were expected to be canceled - // are indeed canceled. + // Retrospectively check that all invoices that were expected to be + // canceled are indeed canceled. for i := range expectedCancellations { invoice, err := registry.LookupInvoice(expectedCancellations[i]) if err != nil { t.Fatalf("cannot find invoice: %v", err) } - if invoice.State != channeldb.ContractCanceled { - t.Fatalf("expected canceled invoice, got: %v", invoice.State) - } + require.Equal(t, invpkg.ContractCanceled, invoice.State) } } @@ -1031,19 +1035,19 @@ func TestOldInvoiceRemovalOnStart(t *testing.T) { t.Parallel() testClock := clock.NewTestClock(testTime) - cdb, err := newTestChannelDB(t, testClock) + idb, err := newTestChannelDB(t, testClock) require.NoError(t, err) - cfg := RegistryConfig{ + cfg := invpkg.RegistryConfig{ FinalCltvRejectDelta: testFinalCltvRejectDelta, Clock: testClock, GcCanceledInvoicesOnStartup: true, } - expiryWatcher := NewInvoiceExpiryWatcher( + expiryWatcher := invpkg.NewInvoiceExpiryWatcher( cfg.Clock, 0, uint32(testCurrentHeight), nil, newMockNotifier(), ) - registry := NewRegistry(cdb, expiryWatcher, &cfg) + registry := invpkg.NewRegistry(idb, expiryWatcher, &cfg) // First prefill the Channel DB with some pre-existing expired invoices. const numExpired = 5 @@ -1057,31 +1061,31 @@ func TestOldInvoiceRemovalOnStart(t *testing.T) { // Mark half of the invoices as settled, the other half as // canceled. if i%2 == 0 { - invoice.State = channeldb.ContractSettled + invoice.State = invpkg.ContractSettled } else { - invoice.State = channeldb.ContractCanceled + invoice.State = invpkg.ContractCanceled } - _, err := cdb.AddInvoice(invoice, paymentHash) + _, err := idb.AddInvoice(invoice, paymentHash) require.NoError(t, err) i++ } // Collect all settled invoices for our expectation set. - var expected []channeldb.Invoice + var expected []invpkg.Invoice // Perform a scan query to collect all invoices. - query := channeldb.InvoiceQuery{ + query := invpkg.InvoiceQuery{ IndexOffset: 0, NumMaxInvoices: math.MaxUint64, } - response, err := cdb.QueryInvoices(query) + response, err := idb.QueryInvoices(query) require.NoError(t, err) // Save all settled invoices for our expectation set. for _, invoice := range response.Invoices { - if invoice.State == channeldb.ContractSettled { + if invoice.State == invpkg.ContractSettled { expected = append(expected, invoice) } } @@ -1092,7 +1096,7 @@ func TestOldInvoiceRemovalOnStart(t *testing.T) { require.NoError(t, err, "cannot start the registry") // Perform a scan query to collect all invoices. - response, err = cdb.QueryInvoices(query) + response, err = idb.QueryInvoices(query) require.NoError(t, err) // Check that we really only kept the settled invoices after the @@ -1126,7 +1130,7 @@ func testHeightExpiryWithRegistry(t *testing.T, numParts int, settle bool) { t.Parallel() defer timeout()() - ctx := newTestContext(t) + ctx := newTestContext(t, nil) require.Greater(t, numParts, 0, "test requires at least one part") @@ -1166,7 +1170,7 @@ func testHeightExpiryWithRegistry(t *testing.T, numParts int, settle bool) { inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash) require.NoError(t, err) - return inv.State == channeldb.ContractAccepted + return inv.State == invpkg.ContractAccepted }, time.Second, time.Millisecond*100) // Now that we've added our htlc(s), we tick our test clock to our @@ -1189,12 +1193,13 @@ func testHeightExpiryWithRegistry(t *testing.T, numParts int, settle bool) { require.NoError(t, err) for i := 0; i < numParts; i++ { - htlcResolution := (<-hodlChan).(HtlcResolution) - require.NotNil(t, htlcResolution) + resolution, _ := (<-hodlChan).(invpkg.HtlcResolution) + require.NotNil(t, resolution) settleResolution := checkSettleResolution( - t, htlcResolution, testInvoicePreimage, + t, resolution, testInvoicePreimage, ) - require.Equal(t, ResultSettled, settleResolution.Outcome) + outcome := settleResolution.Outcome + require.Equal(t, invpkg.ResultSettled, outcome) } } @@ -1205,14 +1210,14 @@ func testHeightExpiryWithRegistry(t *testing.T, numParts int, settle bool) { // If we did not settle the invoice before its expiry, we now expect // a cancellation. - expectedState := channeldb.ContractSettled + expectedState := invpkg.ContractSettled if !settle { - expectedState = channeldb.ContractCanceled + expectedState = invpkg.ContractCanceled - htlcResolution := (<-hodlChan).(HtlcResolution) + htlcResolution, _ := (<-hodlChan).(invpkg.HtlcResolution) require.NotNil(t, htlcResolution) checkFailResolution( - t, htlcResolution, ResultCanceled, + t, htlcResolution, invpkg.ResultCanceled, ) } @@ -1232,7 +1237,7 @@ func TestMultipleSetHeightExpiry(t *testing.T) { t.Parallel() defer timeout()() - ctx := newTestContext(t) + ctx := newTestContext(t, nil) // Add a hold invoice. invoice := *testInvoice @@ -1258,10 +1263,10 @@ func TestMultipleSetHeightExpiry(t *testing.T) { // Simulate mpp timeout releasing htlc 1. ctx.clock.SetTime(testTime.Add(30 * time.Second)) - htlcResolution := (<-hodlChan1).(HtlcResolution) - failResolution, ok := htlcResolution.(*HtlcFailResolution) + htlcResolution, _ := (<-hodlChan1).(invpkg.HtlcResolution) + failResolution, ok := htlcResolution.(*invpkg.HtlcFailResolution) require.True(t, ok, "expected fail resolution, got: %T", resolution) - require.Equal(t, ResultMppTimeout, failResolution.Outcome, + require.Equal(t, invpkg.ResultMppTimeout, failResolution.Outcome, "expected MPP Timeout, got: %v", failResolution.Outcome) // Notify the expiry height for our first htlc. We don't expect the @@ -1297,7 +1302,7 @@ func TestMultipleSetHeightExpiry(t *testing.T) { // been paid with a complete set. inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash) require.NoError(t, err) - require.Equal(t, channeldb.ContractAccepted, inv.State, "expected "+ + require.Equal(t, invpkg.ContractAccepted, inv.State, "expected "+ "hold invoice accepted") // Now we will notify the expiry height for the new set of htlcs. We @@ -1310,7 +1315,7 @@ func TestMultipleSetHeightExpiry(t *testing.T) { inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash) require.NoError(t, err) - return inv.State == channeldb.ContractCanceled + return inv.State == invpkg.ContractCanceled }, testTimeout, time.Millisecond*100, "invoice not canceled") } @@ -1320,7 +1325,7 @@ func TestMultipleSetHeightExpiry(t *testing.T) { func TestSettleInvoicePaymentAddrRequired(t *testing.T) { t.Parallel() - ctx := newTestContext(t) + ctx := newTestContext(t, nil) allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) require.Nil(t, err) @@ -1333,9 +1338,7 @@ func TestSettleInvoicePaymentAddrRequired(t *testing.T) { require.NoError(t, err) defer subscription.Cancel() - require.Equal( - t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash, - ) + require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice, which requires the MPP payload to always be // included due to its set of feature bits. @@ -1348,7 +1351,7 @@ func TestSettleInvoicePaymentAddrRequired(t *testing.T) { // We expect the open state to be sent to the single invoice subscriber. select { case update := <-subscription.Updates: - if update.State != channeldb.ContractOpen { + if update.State != invpkg.ContractOpen { t.Fatalf("expected state ContractOpen, but got %v", update.State) } @@ -1359,7 +1362,7 @@ func TestSettleInvoicePaymentAddrRequired(t *testing.T) { // We expect a new invoice notification to be sent out. select { case newInvoice := <-allSubscriptions.NewInvoices: - if newInvoice.State != channeldb.ContractOpen { + if newInvoice.State != invpkg.ContractOpen { t.Fatalf("expected state ContractOpen, but got %v", newInvoice.State) } @@ -1379,13 +1382,13 @@ func TestSettleInvoicePaymentAddrRequired(t *testing.T) { ) require.NoError(t, err) - failResolution, ok := resolution.(*HtlcFailResolution) + failResolution, ok := resolution.(*invpkg.HtlcFailResolution) if !ok { t.Fatalf("expected fail resolution, got: %T", resolution) } require.Equal(t, failResolution.AcceptHeight, testCurrentHeight) - require.Equal(t, failResolution.Outcome, ResultAddressMismatch) + require.Equal(t, failResolution.Outcome, invpkg.ResultAddressMismatch) } // TestSettleInvoicePaymentAddrRequiredOptionalGrace tests that if an invoice @@ -1395,7 +1398,7 @@ func TestSettleInvoicePaymentAddrRequired(t *testing.T) { func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) { t.Parallel() - ctx := newTestContext(t) + ctx := newTestContext(t, nil) allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) require.Nil(t, err) @@ -1408,9 +1411,7 @@ func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) { require.NoError(t, err) defer subscription.Cancel() - require.Equal( - t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash, - ) + require.Equal(t, subscription.PayHash(), &testInvoicePaymentHash) // Add the invoice, which requires the MPP payload to always be // included due to its set of feature bits. @@ -1424,7 +1425,7 @@ func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) { // subscriber. select { case update := <-subscription.Updates: - if update.State != channeldb.ContractOpen { + if update.State != invpkg.ContractOpen { t.Fatalf("expected state ContractOpen, but got %v", update.State) } @@ -1435,7 +1436,7 @@ func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) { // We expect a new invoice notification to be sent out. select { case newInvoice := <-allSubscriptions.NewInvoices: - if newInvoice.State != channeldb.ContractOpen { + if newInvoice.State != invpkg.ContractOpen { t.Fatalf("expected state ContractOpen, but got %v", newInvoice.State) } @@ -1453,18 +1454,18 @@ func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) { ) require.NoError(t, err) - settleResolution, ok := resolution.(*HtlcSettleResolution) + settleResolution, ok := resolution.(*invpkg.HtlcSettleResolution) if !ok { t.Fatalf("expected settle resolution, got: %T", resolution) } - require.Equal(t, settleResolution.Outcome, ResultSettled) + require.Equal(t, settleResolution.Outcome, invpkg.ResultSettled) // We expect the settled state to be sent to the single invoice // subscriber. select { case update := <-subscription.Updates: - if update.State != channeldb.ContractSettled { + if update.State != invpkg.ContractSettled { t.Fatalf("expected state ContractOpen, but got %v", update.State) } @@ -1478,7 +1479,7 @@ func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) { // We expect a settled notification to be sent out. select { case settledInvoice := <-allSubscriptions.SettledInvoices: - if settledInvoice.State != channeldb.ContractSettled { + if settledInvoice.State != invpkg.ContractSettled { t.Fatalf("expected state ContractOpen, but got %v", settledInvoice.State) } @@ -1492,9 +1493,9 @@ func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) { func TestAMPWithoutMPPPayload(t *testing.T) { defer timeout()() - ctx := newTestContext(t) - - ctx.registry.cfg.AcceptAMP = true + cfg := defaultRegistryConfig() + cfg.AcceptAMP = true + ctx := newTestContext(t, &cfg) const ( shardAmt = lnwire.MilliSatoshi(10) @@ -1516,7 +1517,7 @@ func TestAMPWithoutMPPPayload(t *testing.T) { // We should receive the ResultAmpError failure. require.NotNil(t, resolution) - checkFailResolution(t, resolution, ResultAmpError) + checkFailResolution(t, resolution, invpkg.ResultAmpError) } // TestSpontaneousAmpPayment tests receiving a spontaneous AMP payment with both @@ -1577,9 +1578,9 @@ func testSpontaneousAmpPayment( defer timeout()() - ctx := newTestContext(t) - - ctx.registry.cfg.AcceptAMP = ampEnabled + cfg := defaultRegistryConfig() + cfg.AcceptAMP = ampEnabled + ctx := newTestContext(t, &cfg) allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) require.Nil(t, err) @@ -1608,7 +1609,7 @@ func testSpontaneousAmpPayment( checkOpenSubscription := func() { t.Helper() newInvoice := <-allSubscriptions.NewInvoices - require.Equal(t, newInvoice.State, channeldb.ContractOpen) + require.Equal(t, newInvoice.State, invpkg.ContractOpen) } // Asserts that a settled invoice is published on the SettledInvoices @@ -1621,7 +1622,7 @@ func testSpontaneousAmpPayment( // changes, but the AMP state should show that the setID has // been settled. htlcState := settledInvoice.AMPState[setID].State - require.Equal(t, htlcState, channeldb.HtlcStateSettled) + require.Equal(t, htlcState, invpkg.HtlcStateSettled) } // Asserts that no invoice is published on the SettledInvoices channel @@ -1681,7 +1682,9 @@ func testSpontaneousAmpPayment( // UpdateInvoice. if !ampEnabled { require.NotNil(t, resolution) - checkFailResolution(t, resolution, ResultInvoiceNotFound) + checkFailResolution( + t, resolution, invpkg.ResultInvoiceNotFound, + ) continue } @@ -1697,9 +1700,14 @@ func testSpontaneousAmpPayment( // test case. require.NotNil(t, resolution) if failReconstruction { - checkFailResolution(t, resolution, ResultAmpReconstruction) + checkFailResolution( + t, resolution, + invpkg.ResultAmpReconstruction, + ) } else { - checkSettleResolution(t, resolution, child.Preimage) + checkSettleResolution( + t, resolution, child.Preimage, + ) } } @@ -1730,11 +1738,13 @@ func testSpontaneousAmpPayment( // For the non-final hodl chans, assert that they receive the expected // failure or preimage. for preimage, hodlChan := range hodlChans { - resolution, ok := (<-hodlChan).(HtlcResolution) + resolution, ok := (<-hodlChan).(invpkg.HtlcResolution) require.True(t, ok) require.NotNil(t, resolution) if failReconstruction { - checkFailResolution(t, resolution, ResultAmpReconstruction) + checkFailResolution( + t, resolution, invpkg.ResultAmpReconstruction, + ) } else { checkSettleResolution(t, resolution, preimage) } diff --git a/invoices/invoices.go b/invoices/invoices.go new file mode 100644 index 000000000..d239897fe --- /dev/null +++ b/invoices/invoices.go @@ -0,0 +1,795 @@ +package invoices + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/lightningnetwork/lnd/feature" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" +) + +const ( + // MaxMemoSize is maximum size of the memo field within invoices stored + // in the database. + MaxMemoSize = 1024 + + // MaxPaymentRequestSize is the max size of a payment request for + // this invoice. + // TODO(halseth): determine the max length payment request when field + // lengths are final. + MaxPaymentRequestSize = 4096 +) + +var ( + // unknownPreimage is an all-zeroes preimage that indicates that the + // preimage for this invoice is not yet known. + UnknownPreimage lntypes.Preimage + + // BlankPayAddr is a sentinel payment address for legacy invoices. + // Invoices with this payment address are special-cased in the insertion + // logic to prevent being indexed in the payment address index, + // otherwise they would cause collisions after the first insertion. + BlankPayAddr [32]byte +) + +// RefModifier is a modification on top of a base invoice ref. It allows the +// caller to opt to skip out on HTLCs for a given payAddr, or only return the +// set of specified HTLCs for a given setID. +type RefModifier uint8 + +const ( + // DefaultModifier is the base modifier that doesn't change any + // behavior. + DefaultModifier RefModifier = iota + + // HtlcSetOnlyModifier can only be used with a setID based invoice ref, + // and specifies that only the set of HTLCs related to that setID are + // to be returned. + HtlcSetOnlyModifier + + // HtlcSetOnlyModifier can only be used with a payAddr based invoice + // ref, and specifies that the returned invoice shouldn't include any + // HTLCs at all. + HtlcSetBlankModifier +) + +// InvoiceRef is a composite identifier for invoices. Invoices can be referenced +// by various combinations of payment hash and payment addr, in certain contexts +// only some of these are known. An InvoiceRef and its constructors thus +// encapsulate the valid combinations of query parameters that can be supplied +// to LookupInvoice and UpdateInvoice. +type InvoiceRef struct { + // payHash is the payment hash of the target invoice. All invoices are + // currently indexed by payment hash. This value will be used as a + // fallback when no payment address is known. + payHash *lntypes.Hash + + // payAddr is the payment addr of the target invoice. Newer invoices + // (0.11 and up) are indexed by payment address in addition to payment + // hash, but pre 0.8 invoices do not have one at all. When this value is + // known it will be used as the primary identifier, falling back to + // payHash if no value is known. + payAddr *[32]byte + + // setID is the optional set id for an AMP payment. This can be used to + // lookup or update the invoice knowing only this value. Queries by set + // id are only used to facilitate user-facing requests, e.g. lookup, + // settle or cancel an AMP invoice. The regular update flow from the + // invoice registry will always query for the invoice by + // payHash+payAddr. + setID *[32]byte + + // refModifier allows an invoice ref to include or exclude specific + // HTLC sets based on the payAddr or setId. + refModifier RefModifier +} + +// InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by +// its payment hash. +func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef { + return InvoiceRef{ + payHash: &payHash, + } +} + +// InvoiceRefByHashAndAddr creates an InvoiceRef that first queries for an +// invoice by the provided payment address, falling back to the payment hash if +// the payment address is unknown. +func InvoiceRefByHashAndAddr(payHash lntypes.Hash, + payAddr [32]byte) InvoiceRef { + + return InvoiceRef{ + payHash: &payHash, + payAddr: &payAddr, + } +} + +// InvoiceRefByAddr creates an InvoiceRef that queries the payment addr index +// for an invoice with the provided payment address. +func InvoiceRefByAddr(addr [32]byte) InvoiceRef { + return InvoiceRef{ + payAddr: &addr, + } +} + +// InvoiceRefByAddrBlankHtlc creates an InvoiceRef that queries the payment addr +// index for an invoice with the provided payment address, but excludes any of +// the core HTLC information. +func InvoiceRefByAddrBlankHtlc(addr [32]byte) InvoiceRef { + return InvoiceRef{ + payAddr: &addr, + refModifier: HtlcSetBlankModifier, + } +} + +// InvoiceRefBySetID creates an InvoiceRef that queries the set id index for an +// invoice with the provided setID. If the invoice is not found, the query will +// not fallback to payHash or payAddr. +func InvoiceRefBySetID(setID [32]byte) InvoiceRef { + return InvoiceRef{ + setID: &setID, + } +} + +// InvoiceRefBySetIDFiltered is similar to the InvoiceRefBySetID identifier, +// but it specifies that the returned set of HTLCs should be filtered to only +// include HTLCs that are part of that set. +func InvoiceRefBySetIDFiltered(setID [32]byte) InvoiceRef { + return InvoiceRef{ + setID: &setID, + refModifier: HtlcSetOnlyModifier, + } +} + +// PayHash returns the optional payment hash of the target invoice. +// +// NOTE: This value may be nil. +func (r InvoiceRef) PayHash() *lntypes.Hash { + if r.payHash != nil { + hash := *r.payHash + return &hash + } + + return nil +} + +// PayAddr returns the optional payment address of the target invoice. +// +// NOTE: This value may be nil. +func (r InvoiceRef) PayAddr() *[32]byte { + if r.payAddr != nil { + addr := *r.payAddr + return &addr + } + + return nil +} + +// SetID returns the optional set id of the target invoice. +// +// NOTE: This value may be nil. +func (r InvoiceRef) SetID() *[32]byte { + if r.setID != nil { + id := *r.setID + return &id + } + + return nil +} + +// Modifier defines the set of available modifications to the base invoice ref +// look up that are available. +func (r InvoiceRef) Modifier() RefModifier { + return r.refModifier +} + +// String returns a human-readable representation of an InvoiceRef. +func (r InvoiceRef) String() string { + var ids []string + if r.payHash != nil { + ids = append(ids, fmt.Sprintf("pay_hash=%v", *r.payHash)) + } + if r.payAddr != nil { + ids = append(ids, fmt.Sprintf("pay_addr=%x", *r.payAddr)) + } + if r.setID != nil { + ids = append(ids, fmt.Sprintf("set_id=%x", *r.setID)) + } + + return fmt.Sprintf("(%s)", strings.Join(ids, ", ")) +} + +// ContractState describes the state the invoice is in. +type ContractState uint8 + +const ( + // ContractOpen means the invoice has only been created. + ContractOpen ContractState = 0 + + // ContractSettled means the htlc is settled and the invoice has been + // paid. + ContractSettled ContractState = 1 + + // ContractCanceled means the invoice has been canceled. + ContractCanceled ContractState = 2 + + // ContractAccepted means the HTLC has been accepted but not settled + // yet. + ContractAccepted ContractState = 3 +) + +// String returns a human readable identifier for the ContractState type. +func (c ContractState) String() string { + switch c { + case ContractOpen: + return "Open" + + case ContractSettled: + return "Settled" + + case ContractCanceled: + return "Canceled" + + case ContractAccepted: + return "Accepted" + } + + return "Unknown" +} + +// IsFinal returns a boolean indicating whether an invoice state is final. +func (c ContractState) IsFinal() bool { + return c == ContractSettled || c == ContractCanceled +} + +// ContractTerm is a companion struct to the Invoice struct. This struct houses +// the necessary conditions required before the invoice can be considered fully +// settled by the payee. +type ContractTerm struct { + // FinalCltvDelta is the minimum required number of blocks before htlc + // expiry when the invoice is accepted. + FinalCltvDelta int32 + + // Expiry defines how long after creation this invoice should expire. + Expiry time.Duration + + // PaymentPreimage is the preimage which is to be revealed in the + // occasion that an HTLC paying to the hash of this preimage is + // extended. Set to nil if the preimage isn't known yet. + PaymentPreimage *lntypes.Preimage + + // Value is the expected amount of milli-satoshis to be paid to an HTLC + // which can be satisfied by the above preimage. + Value lnwire.MilliSatoshi + + // PaymentAddr is a randomly generated value include in the MPP record + // by the sender to prevent probing of the receiver. + PaymentAddr [32]byte + + // Features is the feature vectors advertised on the payment request. + Features *lnwire.FeatureVector +} + +// String returns a human-readable description of the prominent contract terms. +func (c ContractTerm) String() string { + return fmt.Sprintf("amt=%v, expiry=%v, final_cltv_delta=%v", c.Value, + c.Expiry, c.FinalCltvDelta) +} + +// SetID is the extra unique tuple item for AMP invoices. In addition to +// setting a payment address, each repeated payment to an AMP invoice will also +// contain a set ID as well. +type SetID [32]byte + +// InvoiceStateAMP is a struct that associates the current state of an AMP +// invoice identified by its set ID along with the set of invoices identified +// by the circuit key. This allows callers to easily look up the latest state +// of an AMP "sub-invoice" and also look up the invoice HLTCs themselves in the +// greater HTLC map index. +type InvoiceStateAMP struct { + // State is the state of this sub-AMP invoice. + State HtlcState + + // SettleIndex indicates the location in the settle index that + // references this instance of InvoiceStateAMP, but only if + // this value is set (non-zero), and State is HtlcStateSettled. + SettleIndex uint64 + + // SettleDate is the date that the setID was settled. + SettleDate time.Time + + // InvoiceKeys is the set of circuit keys that can be used to locate + // the invoices for a given set ID. + InvoiceKeys map[CircuitKey]struct{} + + // AmtPaid is the total amount that was paid in the AMP sub-invoice. + // Fetching the full HTLC/invoice state allows one to extract the + // custom records as well as the break down of the payment splits used + // when paying. + AmtPaid lnwire.MilliSatoshi +} + +// copy makes a deep copy of the underlying InvoiceStateAMP. +func (i *InvoiceStateAMP) copy() (InvoiceStateAMP, error) { + result := *i + + // Make a copy of the InvoiceKeys map. + result.InvoiceKeys = make(map[CircuitKey]struct{}) + for k := range i.InvoiceKeys { + result.InvoiceKeys[k] = struct{}{} + } + + // As a safety measure, copy SettleDate. time.Time is concurrency safe + // except when using any of the (un)marshalling methods. + settleDateBytes, err := i.SettleDate.MarshalBinary() + if err != nil { + return InvoiceStateAMP{}, err + } + + err = result.SettleDate.UnmarshalBinary(settleDateBytes) + if err != nil { + return InvoiceStateAMP{}, err + } + + return result, nil +} + +// AMPInvoiceState represents a type that stores metadata related to the set of +// settled AMP "sub-invoices". +type AMPInvoiceState map[SetID]InvoiceStateAMP + +// Invoice is a payment invoice generated by a payee in order to request +// payment for some good or service. The inclusion of invoices within Lightning +// creates a payment work flow for merchants very similar to that of the +// existing financial system within PayPal, etc. Invoices are added to the +// database when a payment is requested, then can be settled manually once the +// payment is received at the upper layer. For record keeping purposes, +// invoices are never deleted from the database, instead a bit is toggled +// denoting the invoice has been fully settled. Within the database, all +// invoices must have a unique payment hash which is generated by taking the +// sha256 of the payment preimage. +type Invoice struct { + // Memo is an optional memo to be stored along side an invoice. The + // memo may contain further details pertaining to the invoice itself, + // or any other message which fits within the size constraints. + Memo []byte + + // PaymentRequest is the encoded payment request for this invoice. For + // spontaneous (keysend) payments, this field will be empty. + PaymentRequest []byte + + // CreationDate is the exact time the invoice was created. + CreationDate time.Time + + // SettleDate is the exact time the invoice was settled. + SettleDate time.Time + + // Terms are the contractual payment terms of the invoice. Once all the + // terms have been satisfied by the payer, then the invoice can be + // considered fully fulfilled. + // + // TODO(roasbeef): later allow for multiple terms to fulfill the final + // invoice: payment fragmentation, etc. + Terms ContractTerm + + // AddIndex is an auto-incrementing integer that acts as a + // monotonically increasing sequence number for all invoices created. + // Clients can then use this field as a "checkpoint" of sorts when + // implementing a streaming RPC to notify consumers of instances where + // an invoice has been added before they re-connected. + // + // NOTE: This index starts at 1. + AddIndex uint64 + + // SettleIndex is an auto-incrementing integer that acts as a + // monotonically increasing sequence number for all settled invoices. + // Clients can then use this field as a "checkpoint" of sorts when + // implementing a streaming RPC to notify consumers of instances where + // an invoice has been settled before they re-connected. + // + // NOTE: This index starts at 1. + SettleIndex uint64 + + // State describes the state the invoice is in. This is the global + // state of the invoice which may remain open even when a series of + // sub-invoices for this invoice has been settled. + State ContractState + + // AmtPaid is the final amount that we ultimately accepted for pay for + // this invoice. We specify this value independently as it's possible + // that the invoice originally didn't specify an amount, or the sender + // overpaid. + AmtPaid lnwire.MilliSatoshi + + // Htlcs records all htlcs that paid to this invoice. Some of these + // htlcs may have been marked as canceled. + Htlcs map[CircuitKey]*InvoiceHTLC + + // AMPState describes the state of any related sub-invoices AMP to this + // greater invoice. A sub-invoice is defined by a set of HTLCs with the + // same set ID that attempt to make one time or recurring payments to + // this greater invoice. It's possible for a sub-invoice to be canceled + // or settled, but the greater invoice still open. + AMPState AMPInvoiceState + + // HodlInvoice indicates whether the invoice should be held in the + // Accepted state or be settled right away. + HodlInvoice bool +} + +// HTLCSet returns the set of HTLCs belonging to setID and in the provided +// state. Passing a nil setID will return all HTLCs in the provided state in the +// case of legacy or MPP, and no HTLCs in the case of AMP. Otherwise, the +// returned set will be filtered by the populated setID which is used to +// retrieve AMP HTLC sets. +func (i *Invoice) HTLCSet(setID *[32]byte, + state HtlcState) map[CircuitKey]*InvoiceHTLC { + + htlcSet := make(map[CircuitKey]*InvoiceHTLC) + for key, htlc := range i.Htlcs { + // Only add HTLCs that are in the requested HtlcState. + if htlc.State != state { + continue + } + + if !htlc.IsInHTLCSet(setID) { + continue + } + + htlcSet[key] = htlc + } + + return htlcSet +} + +// HTLCSetCompliment returns the set of all HTLCs not belonging to setID that +// are in the target state. Passing a nil setID will return no invoices, since +// all MPP HTLCs are part of the same HTLC set. +func (i *Invoice) HTLCSetCompliment(setID *[32]byte, + state HtlcState) map[CircuitKey]*InvoiceHTLC { + + htlcSet := make(map[CircuitKey]*InvoiceHTLC) + for key, htlc := range i.Htlcs { + // Only add HTLCs that are in the requested HtlcState. + if htlc.State != state { + continue + } + + // We are constructing the compliment, so filter anything that + // matches this set id. + if htlc.IsInHTLCSet(setID) { + continue + } + + htlcSet[key] = htlc + } + + return htlcSet +} + +// HtlcState defines the states an htlc paying to an invoice can be in. +type HtlcState uint8 + +const ( + // HtlcStateAccepted indicates the htlc is locked-in, but not resolved. + HtlcStateAccepted HtlcState = iota + + // HtlcStateCanceled indicates the htlc is canceled back to the + // sender. + HtlcStateCanceled + + // HtlcStateSettled indicates the htlc is settled. + HtlcStateSettled +) + +// InvoiceHTLC contains details about an htlc paying to this invoice. +type InvoiceHTLC struct { + // Amt is the amount that is carried by this htlc. + Amt lnwire.MilliSatoshi + + // MppTotalAmt is a field for mpp that indicates the expected total + // amount. + MppTotalAmt lnwire.MilliSatoshi + + // AcceptHeight is the block height at which the invoice registry + // decided to accept this htlc as a payment to the invoice. At this + // height, the invoice cltv delay must have been met. + AcceptHeight uint32 + + // AcceptTime is the wall clock time at which the invoice registry + // decided to accept the htlc. + AcceptTime time.Time + + // ResolveTime is the wall clock time at which the invoice registry + // decided to settle the htlc. + ResolveTime time.Time + + // Expiry is the expiry height of this htlc. + Expiry uint32 + + // State indicates the state the invoice htlc is currently in. A + // canceled htlc isn't just removed from the invoice htlcs map, because + // we need AcceptHeight to properly cancel the htlc back. + State HtlcState + + // CustomRecords contains the custom key/value pairs that accompanied + // the htlc. + CustomRecords record.CustomSet + + // AMP encapsulates additional data relevant to AMP HTLCs. This includes + // the AMP onion record, in addition to the HTLC's payment hash and + // preimage since these are unique to each AMP HTLC, and not the invoice + // as a whole. + // + // NOTE: This value will only be set for AMP HTLCs. + AMP *InvoiceHtlcAMPData +} + +// Copy makes a deep copy of the target InvoiceHTLC. +func (h *InvoiceHTLC) Copy() *InvoiceHTLC { + result := *h + + // Make a copy of the CustomSet map. + result.CustomRecords = make(record.CustomSet) + for k, v := range h.CustomRecords { + result.CustomRecords[k] = v + } + + result.AMP = h.AMP.Copy() + + return &result +} + +// IsInHTLCSet returns true if this HTLC is part an HTLC set. If nil is passed, +// this method returns true if this is an MPP HTLC. Otherwise, it only returns +// true if the AMP HTLC's set id matches the populated setID. +func (h *InvoiceHTLC) IsInHTLCSet(setID *[32]byte) bool { + wantAMPSet := setID != nil + isAMPHtlc := h.AMP != nil + + // Non-AMP HTLCs cannot be part of AMP HTLC sets, and vice versa. + if wantAMPSet != isAMPHtlc { + return false + } + + // Skip AMP HTLCs that have differing set ids. + if isAMPHtlc && *setID != h.AMP.Record.SetID() { + return false + } + + return true +} + +// InvoiceHtlcAMPData is a struct hodling the additional metadata stored for +// each received AMP HTLC. This includes the AMP onion record, in addition to +// the HTLC's payment hash and preimage. +type InvoiceHtlcAMPData struct { + // AMP is a copy of the AMP record presented in the onion payload + // containing the information necessary to correlate and settle a + // spontaneous HTLC set. Newly accepted legacy keysend payments will + // also have this field set as we automatically promote them into an AMP + // payment for internal processing. + Record record.AMP + + // Hash is an HTLC-level payment hash that is stored only for AMP + // payments. This is done because an AMP HTLC will carry a different + // payment hash from the invoice it might be satisfying, so we track the + // payment hashes individually to able to compute whether or not the + // reconstructed preimage correctly matches the HTLC's hash. + Hash lntypes.Hash + + // Preimage is an HTLC-level preimage that satisfies the AMP HTLC's + // Hash. The preimage will be derived either from secret share + // reconstruction of the shares in the AMP payload. + // + // NOTE: Preimage will only be present once the HTLC is in + // HtlcStateSettled. + Preimage *lntypes.Preimage +} + +// Copy returns a deep copy of the InvoiceHtlcAMPData. +func (d *InvoiceHtlcAMPData) Copy() *InvoiceHtlcAMPData { + if d == nil { + return nil + } + + var preimage *lntypes.Preimage + if d.Preimage != nil { + pimg := *d.Preimage + preimage = &pimg + } + + return &InvoiceHtlcAMPData{ + Record: d.Record, + Hash: d.Hash, + Preimage: preimage, + } +} + +// HtlcAcceptDesc describes the details of a newly accepted htlc. +type HtlcAcceptDesc struct { + // AcceptHeight is the block height at which this htlc was accepted. + AcceptHeight int32 + + // Amt is the amount that is carried by this htlc. + Amt lnwire.MilliSatoshi + + // MppTotalAmt is a field for mpp that indicates the expected total + // amount. + MppTotalAmt lnwire.MilliSatoshi + + // Expiry is the expiry height of this htlc. + Expiry uint32 + + // CustomRecords contains the custom key/value pairs that accompanied + // the htlc. + CustomRecords record.CustomSet + + // AMP encapsulates additional data relevant to AMP HTLCs. This includes + // the AMP onion record, in addition to the HTLC's payment hash and + // preimage since these are unique to each AMP HTLC, and not the invoice + // as a whole. + // + // NOTE: This value will only be set for AMP HTLCs. + AMP *InvoiceHtlcAMPData +} + +// InvoiceUpdateDesc describes the changes that should be applied to the +// invoice. +type InvoiceUpdateDesc struct { + // State is the new state that this invoice should progress to. If nil, + // the state is left unchanged. + State *InvoiceStateUpdateDesc + + // CancelHtlcs describes the htlcs that need to be canceled. + CancelHtlcs map[CircuitKey]struct{} + + // AddHtlcs describes the newly accepted htlcs that need to be added to + // the invoice. + AddHtlcs map[CircuitKey]*HtlcAcceptDesc + + // SetID is an optional set ID for AMP invoices that allows operations + // to be more efficient by ensuring we don't need to read out the + // entire HTLC set each timee an HTLC is to be cancelled. + SetID *SetID +} + +// InvoiceStateUpdateDesc describes an invoice-level state transition. +type InvoiceStateUpdateDesc struct { + // NewState is the new state that this invoice should progress to. + NewState ContractState + + // Preimage must be set to the preimage when NewState is settled. + Preimage *lntypes.Preimage + + // HTLCPreimages set the HTLC-level preimages stored for AMP HTLCs. + // These are only learned when settling the invoice as a whole. Must be + // set when settling an invoice with non-nil SetID. + HTLCPreimages map[CircuitKey]lntypes.Preimage + + // SetID identifies a specific set of HTLCs destined for the same + // invoice as part of a larger AMP payment. This value will be nil for + // legacy or MPP payments. + SetID *[32]byte +} + +// InvoiceUpdateCallback is a callback used in the db transaction to update the +// invoice. +type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error) + +func ValidateInvoice(i *Invoice, paymentHash lntypes.Hash) error { + // Avoid conflicts with all-zeroes magic value in the database. + if paymentHash == UnknownPreimage.Hash() { + return fmt.Errorf("cannot use hash of all-zeroes preimage") + } + + if len(i.Memo) > MaxMemoSize { + return fmt.Errorf("max length a memo is %v, and invoice "+ + "of length %v was provided", MaxMemoSize, len(i.Memo)) + } + if len(i.PaymentRequest) > MaxPaymentRequestSize { + return fmt.Errorf("max length of payment request is %v, "+ + "length provided was %v", MaxPaymentRequestSize, + len(i.PaymentRequest)) + } + if i.Terms.Features == nil { + return errors.New("invoice must have a feature vector") + } + + err := feature.ValidateDeps(i.Terms.Features) + if err != nil { + return err + } + + // AMP invoices and hodl invoices are allowed to have no preimage + // specified. + isAMP := i.Terms.Features.HasFeature( + lnwire.AMPOptional, + ) + if i.Terms.PaymentPreimage == nil && !(i.HodlInvoice || isAMP) { + return errors.New("non-hodl invoices must have a preimage") + } + + if len(i.Htlcs) > 0 { + return ErrInvoiceHasHtlcs + } + + return nil +} + +// IsPending returns true if the invoice is in ContractOpen state. +func (i *Invoice) IsPending() bool { + return i.State == ContractOpen || i.State == ContractAccepted +} + +// copySlice allocates a new slice and copies the source into it. +func copySlice(src []byte) []byte { + dest := make([]byte, len(src)) + copy(dest, src) + return dest +} + +// CopyInvoice makes a deep copy of the supplied invoice. +func CopyInvoice(src *Invoice) (*Invoice, error) { + dest := Invoice{ + Memo: copySlice(src.Memo), + PaymentRequest: copySlice(src.PaymentRequest), + CreationDate: src.CreationDate, + SettleDate: src.SettleDate, + Terms: src.Terms, + AddIndex: src.AddIndex, + SettleIndex: src.SettleIndex, + State: src.State, + AmtPaid: src.AmtPaid, + Htlcs: make( + map[CircuitKey]*InvoiceHTLC, len(src.Htlcs), + ), + AMPState: make(map[SetID]InvoiceStateAMP), + HodlInvoice: src.HodlInvoice, + } + + dest.Terms.Features = src.Terms.Features.Clone() + + if src.Terms.PaymentPreimage != nil { + preimage := *src.Terms.PaymentPreimage + dest.Terms.PaymentPreimage = &preimage + } + + for k, v := range src.Htlcs { + dest.Htlcs[k] = v.Copy() + } + + // Lastly, copy the amp invoice state. + for k, v := range src.AMPState { + ampInvState, err := v.copy() + if err != nil { + return nil, err + } + + dest.AMPState[k] = ampInvState + } + + return &dest, nil +} + +// InvoiceDeleteRef holds a reference to an invoice to be deleted. +type InvoiceDeleteRef struct { + // PayHash is the payment hash of the target invoice. All invoices are + // currently indexed by payment hash. + PayHash lntypes.Hash + + // PayAddr is the payment addr of the target invoice. Newer invoices + // (0.11 and up) are indexed by payment address in addition to payment + // hash, but pre 0.8 invoices do not have one at all. + PayAddr *[32]byte + + // AddIndex is the add index of the invoice. + AddIndex uint64 + + // SettleIndex is the settle index of the invoice. + SettleIndex uint64 +} diff --git a/invoices/mock.go b/invoices/mock.go new file mode 100644 index 000000000..8410208bf --- /dev/null +++ b/invoices/mock.go @@ -0,0 +1,78 @@ +package invoices + +import ( + "github.com/lightningnetwork/lnd/lntypes" + "github.com/stretchr/testify/mock" +) + +type MockInvoiceDB struct { + mock.Mock +} + +func NewInvoicesDBMock() *MockInvoiceDB { + return &MockInvoiceDB{} +} + +func (m *MockInvoiceDB) AddInvoice(invoice *Invoice, + paymentHash lntypes.Hash) (uint64, error) { + + args := m.Called(invoice, paymentHash) + + addIndex, _ := args.Get(0).(uint64) + + // NOTE: this is a side effect of the AddInvoice method. + invoice.AddIndex = addIndex + + return addIndex, args.Error(1) +} + +func (m *MockInvoiceDB) InvoicesAddedSince(idx uint64) ([]Invoice, error) { + args := m.Called(idx) + invoices, _ := args.Get(0).([]Invoice) + + return invoices, args.Error(1) +} + +func (m *MockInvoiceDB) InvoicesSettledSince(idx uint64) ([]Invoice, error) { + args := m.Called(idx) + invoices, _ := args.Get(0).([]Invoice) + + return invoices, args.Error(1) +} + +func (m *MockInvoiceDB) LookupInvoice(ref InvoiceRef) (Invoice, error) { + args := m.Called(ref) + invoice, _ := args.Get(0).(Invoice) + + return invoice, args.Error(1) +} + +func (m *MockInvoiceDB) ScanInvoices(scanFunc InvScanFunc, + reset func()) error { + + args := m.Called(scanFunc, reset) + + return args.Error(0) +} + +func (m *MockInvoiceDB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { + args := m.Called(q) + invoiceSlice, _ := args.Get(0).(InvoiceSlice) + + return invoiceSlice, args.Error(1) +} + +func (m *MockInvoiceDB) UpdateInvoice(ref InvoiceRef, setIDHint *SetID, + callback InvoiceUpdateCallback) (*Invoice, error) { + + args := m.Called(ref, setIDHint, callback) + invoice, _ := args.Get(0).(*Invoice) + + return invoice, args.Error(1) +} + +func (m *MockInvoiceDB) DeleteInvoice(invoices []InvoiceDeleteRef) error { + args := m.Called(invoices) + + return args.Error(0) +} diff --git a/invoices/resolution.go b/invoices/resolution.go index e878cda6d..c2868857d 100644 --- a/invoices/resolution.go +++ b/invoices/resolution.go @@ -3,7 +3,6 @@ package invoices import ( "time" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lntypes" ) @@ -11,14 +10,14 @@ import ( type HtlcResolution interface { // CircuitKey returns the circuit key for the htlc that we have a // resolution for. - CircuitKey() models.CircuitKey + CircuitKey() CircuitKey } // HtlcFailResolution is an implementation of the HtlcResolution interface // which is returned when a htlc is failed. type HtlcFailResolution struct { // circuitKey is the key of the htlc for which we have a resolution. - circuitKey models.CircuitKey + circuitKey CircuitKey // AcceptHeight is the original height at which the htlc was accepted. AcceptHeight int32 @@ -28,7 +27,7 @@ type HtlcFailResolution struct { } // NewFailResolution returns a htlc failure resolution. -func NewFailResolution(key models.CircuitKey, acceptHeight int32, +func NewFailResolution(key CircuitKey, acceptHeight int32, outcome FailResolutionResult) *HtlcFailResolution { return &HtlcFailResolution{ @@ -42,7 +41,7 @@ func NewFailResolution(key models.CircuitKey, acceptHeight int32, // resolution for. // // Note: it is part of the HtlcResolution interface. -func (f *HtlcFailResolution) CircuitKey() models.CircuitKey { +func (f *HtlcFailResolution) CircuitKey() CircuitKey { return f.circuitKey } @@ -53,7 +52,7 @@ type HtlcSettleResolution struct { Preimage lntypes.Preimage // circuitKey is the key of the htlc for which we have a resolution. - circuitKey models.CircuitKey + circuitKey CircuitKey // acceptHeight is the original height at which the htlc was accepted. AcceptHeight int32 @@ -64,8 +63,8 @@ type HtlcSettleResolution struct { // NewSettleResolution returns a htlc resolution which is associated with a // settle. -func NewSettleResolution(preimage lntypes.Preimage, - key models.CircuitKey, acceptHeight int32, +func NewSettleResolution(preimage lntypes.Preimage, key CircuitKey, + acceptHeight int32, outcome SettleResolutionResult) *HtlcSettleResolution { return &HtlcSettleResolution{ @@ -80,7 +79,7 @@ func NewSettleResolution(preimage lntypes.Preimage, // resolution for. // // Note: it is part of the HtlcResolution interface. -func (s *HtlcSettleResolution) CircuitKey() models.CircuitKey { +func (s *HtlcSettleResolution) CircuitKey() CircuitKey { return s.circuitKey } @@ -92,7 +91,7 @@ func (s *HtlcSettleResolution) CircuitKey() models.CircuitKey { // acceptResolution, a nil resolution should be surfaced. type htlcAcceptResolution struct { // circuitKey is the key of the htlc for which we have a resolution. - circuitKey models.CircuitKey + circuitKey CircuitKey // autoRelease signals that the htlc should be automatically released // after a timeout. @@ -107,7 +106,7 @@ type htlcAcceptResolution struct { // newAcceptResolution returns a htlc resolution which is associated with a // htlc accept. -func newAcceptResolution(key models.CircuitKey, +func newAcceptResolution(key CircuitKey, outcome acceptResolutionResult) *htlcAcceptResolution { return &htlcAcceptResolution{ @@ -120,6 +119,6 @@ func newAcceptResolution(key models.CircuitKey, // resolution for. // // Note: it is part of the HtlcResolution interface. -func (a *htlcAcceptResolution) CircuitKey() models.CircuitKey { +func (a *htlcAcceptResolution) CircuitKey() CircuitKey { return a.circuitKey } diff --git a/invoices/test_utils.go b/invoices/test_utils.go new file mode 100644 index 000000000..dc83fa253 --- /dev/null +++ b/invoices/test_utils.go @@ -0,0 +1,287 @@ +package invoices + +import ( + "crypto/rand" + "encoding/binary" + "encoding/hex" + "fmt" + "sync" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/zpay32" + "github.com/stretchr/testify/require" +) + +type mockChainNotifier struct { + chainntnfs.ChainNotifier + + blockChan chan *chainntnfs.BlockEpoch +} + +func newMockNotifier() *mockChainNotifier { + return &mockChainNotifier{ + blockChan: make(chan *chainntnfs.BlockEpoch), + } +} + +// RegisterBlockEpochNtfn mocks a block epoch notification, using the mock's +// block channel to deliver blocks to the client. +func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) ( + *chainntnfs.BlockEpochEvent, error) { + + return &chainntnfs.BlockEpochEvent{ + Epochs: m.blockChan, + Cancel: func() {}, + }, nil +} + +const ( + testCurrentHeight = int32(1) +) + +var ( + testTimeout = 5 * time.Second + + testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC) + + testInvoicePreimage = lntypes.Preimage{1} + + testPrivKeyBytes, _ = hex.DecodeString( + "e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2d" + + "b734", + ) + + testPrivKey, _ = btcec.PrivKeyFromBytes(testPrivKeyBytes) + + testInvoiceDescription = "coffee" + + testInvoiceAmount = lnwire.MilliSatoshi(100000) //nolint:gomnd + + testNetParams = &chaincfg.MainNetParams + + testMessageSigner = zpay32.MessageSigner{ + SignCompact: func(msg []byte) ([]byte, error) { + hash := chainhash.HashB(msg) + sig, err := ecdsa.SignCompact(testPrivKey, hash, true) + if err != nil { + return nil, fmt.Errorf("can't sign the "+ + "message: %v", err) + } + + return sig, nil + }, + } + + testFeatures = lnwire.NewFeatureVector( + nil, lnwire.Features, + ) +) + +func newTestInvoice(t *testing.T, preimage lntypes.Preimage, + timestamp time.Time, expiry time.Duration) *Invoice { + + t.Helper() + + if expiry == 0 { + expiry = time.Hour + } + + var payAddr [32]byte + if _, err := rand.Read(payAddr[:]); err != nil { + t.Fatalf("unable to generate payment addr: %v", err) + } + + rawInvoice, err := zpay32.NewInvoice( + testNetParams, + preimage.Hash(), + timestamp, + zpay32.Amount(testInvoiceAmount), + zpay32.Description(testInvoiceDescription), + zpay32.Expiry(expiry), + zpay32.PaymentAddr(payAddr), + ) + require.NoError(t, err, "Error while creating new invoice") + + paymentRequest, err := rawInvoice.Encode(testMessageSigner) + + require.NoError(t, err, "Error while encoding payment request") + + return &Invoice{ + Terms: ContractTerm{ + PaymentPreimage: &preimage, + PaymentAddr: payAddr, + Value: testInvoiceAmount, + Expiry: expiry, + Features: testFeatures, + }, + PaymentRequest: []byte(paymentRequest), + CreationDate: timestamp, + } +} + +// invoiceExpiryTestData simply holds generated expired and pending invoices. +type invoiceExpiryTestData struct { + expiredInvoices map[lntypes.Hash]*Invoice + pendingInvoices map[lntypes.Hash]*Invoice +} + +// generateInvoiceExpiryTestData generates the specified number of fake expired +// and pending invoices anchored to the passed now timestamp. +func generateInvoiceExpiryTestData( + t *testing.T, now time.Time, + offset, numExpired, numPending int) invoiceExpiryTestData { + + t.Helper() + + var testData invoiceExpiryTestData + + testData.expiredInvoices = make(map[lntypes.Hash]*Invoice) + testData.pendingInvoices = make(map[lntypes.Hash]*Invoice) + + expiredCreationDate := now.Add(-24 * time.Hour) + + for i := 1; i <= numExpired; i++ { + var preimage lntypes.Preimage + binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i)) + duration := (i + offset) % 24 //nolint:gomnd + expiry := time.Duration(duration) * time.Hour + invoice := newTestInvoice( + t, preimage, expiredCreationDate, expiry, + ) + testData.expiredInvoices[preimage.Hash()] = invoice + } + + for i := 1; i <= numPending; i++ { + var preimage lntypes.Preimage + binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i)) + duration := (i + offset) % 24 //nolint:gomnd + expiry := time.Duration(duration) * time.Hour + invoice := newTestInvoice(t, preimage, now, expiry) + testData.pendingInvoices[preimage.Hash()] = invoice + } + + return testData +} + +type hodlExpiryTest struct { + hash lntypes.Hash + state ContractState + stateLock sync.Mutex + mockNotifier *mockChainNotifier + mockClock *clock.TestClock + cancelChan chan lntypes.Hash + watcher *InvoiceExpiryWatcher +} + +func (h *hodlExpiryTest) setState(state ContractState) { + h.stateLock.Lock() + defer h.stateLock.Unlock() + + h.state = state +} + +func (h *hodlExpiryTest) announceBlock(t *testing.T, height uint32) { + t.Helper() + + select { + case h.mockNotifier.blockChan <- &chainntnfs.BlockEpoch{ + Height: int32(height), + }: + + case <-time.After(testTimeout): + t.Fatalf("block %v not consumed", height) + } +} + +func (h *hodlExpiryTest) assertCanceled(t *testing.T, expected lntypes.Hash) { + t.Helper() + + select { + case actual := <-h.cancelChan: + require.Equal(t, expected, actual) + + case <-time.After(testTimeout): + t.Fatalf("invoice: %v not canceled", h.hash) + } +} + +// setupHodlExpiry creates a hodl invoice in our expiry watcher and runs an +// arbitrary update function which advances the invoices's state. +func setupHodlExpiry(t *testing.T, creationDate time.Time, + expiry time.Duration, heightDelta uint32, + startState ContractState, + startHtlcs []*InvoiceHTLC) *hodlExpiryTest { + + t.Helper() + + mockNotifier := newMockNotifier() + mockClock := clock.NewTestClock(testTime) + + test := &hodlExpiryTest{ + state: startState, + watcher: NewInvoiceExpiryWatcher( + mockClock, heightDelta, uint32(testCurrentHeight), nil, + mockNotifier, + ), + cancelChan: make(chan lntypes.Hash), + mockNotifier: mockNotifier, + mockClock: mockClock, + } + + // Use an unbuffered channel to block on cancel calls so that the test + // does not exit before we've processed all the invoices we expect. + cancelImpl := func(paymentHash lntypes.Hash, force bool) error { + test.stateLock.Lock() + currentState := test.state + test.stateLock.Unlock() + + if currentState != ContractOpen && !force { + return nil + } + + select { + case test.cancelChan <- paymentHash: + case <-time.After(testTimeout): + } + + return nil + } + + require.NoError(t, test.watcher.Start(cancelImpl)) + + // We set preimage and hash so that we can use our existing test + // helpers. In practice we would only have the hash, but this does not + // affect what we're testing at all. + preimage := lntypes.Preimage{1} + test.hash = preimage.Hash() + + invoice := newTestInvoice(t, preimage, creationDate, expiry) + invoice.State = startState + invoice.HodlInvoice = true + invoice.Htlcs = make(map[CircuitKey]*InvoiceHTLC) + + // If we have any htlcs, add them with unique circult keys. + for i, htlc := range startHtlcs { + key := CircuitKey{ + HtlcID: uint64(i), + } + + invoice.Htlcs[key] = htlc + } + + // Create an expiry entry for our invoice in its starting state. This + // mimics adding invoices to the watcher on start. + entry := makeInvoiceExpiry(test.hash, invoice) + test.watcher.AddInvoices(entry) + + return test +} diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index 6f8da2571..e16b3ecc3 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -1,4 +1,4 @@ -package invoices +package invoices_test import ( "crypto/rand" @@ -17,8 +17,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" + invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -55,6 +55,29 @@ func (p *mockPayload) Metadata() []byte { return p.metadata } +type mockChainNotifier struct { + chainntnfs.ChainNotifier + + blockChan chan *chainntnfs.BlockEpoch +} + +func newMockNotifier() *mockChainNotifier { + return &mockChainNotifier{ + blockChan: make(chan *chainntnfs.BlockEpoch), + } +} + +// RegisterBlockEpochNtfn mocks a block epoch notification, using the mock's +// block channel to deliver blocks to the client. +func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) ( + *chainntnfs.BlockEpochEvent, error) { + + return &chainntnfs.BlockEpochEvent{ + Epochs: m.blockChan, + Cancel: func() {}, + }, nil +} + const ( testHtlcExpiry = uint32(5) @@ -92,7 +115,8 @@ var ( hash := chainhash.HashB(msg) sig, err := ecdsa.SignCompact(testPrivKey, hash, true) if err != nil { - return nil, fmt.Errorf("can't sign the message: %v", err) + return nil, fmt.Errorf("can't sign the "+ + "message: %v", err) } return sig, nil }, @@ -109,8 +133,8 @@ var ( var ( testInvoiceAmt = lnwire.MilliSatoshi(100000) - testInvoice = &channeldb.Invoice{ - Terms: channeldb.ContractTerm{ + testInvoice = &invpkg.Invoice{ + Terms: invpkg.ContractTerm{ PaymentPreimage: &testInvoicePreimage, Value: testInvoiceAmt, Expiry: time.Hour, @@ -119,8 +143,8 @@ var ( CreationDate: testInvoiceCreationDate, } - testPayAddrReqInvoice = &channeldb.Invoice{ - Terms: channeldb.ContractTerm{ + testPayAddrReqInvoice = &invpkg.Invoice{ + Terms: invpkg.ContractTerm{ PaymentPreimage: &testInvoicePreimage, Value: testInvoiceAmt, Expiry: time.Hour, @@ -135,8 +159,8 @@ var ( CreationDate: testInvoiceCreationDate, } - testPayAddrOptionalInvoice = &channeldb.Invoice{ - Terms: channeldb.ContractTerm{ + testPayAddrOptionalInvoice = &invpkg.Invoice{ + Terms: invpkg.ContractTerm{ PaymentPreimage: &testInvoicePreimage, Value: testInvoiceAmt, Expiry: time.Hour, @@ -151,8 +175,8 @@ var ( CreationDate: testInvoiceCreationDate, } - testHodlInvoice = &channeldb.Invoice{ - Terms: channeldb.ContractTerm{ + testHodlInvoice = &invpkg.Invoice{ + Terms: invpkg.ContractTerm{ Value: testInvoiceAmt, Expiry: time.Hour, Features: testFeatures, @@ -163,6 +187,8 @@ var ( ) func newTestChannelDB(t *testing.T, clock clock.Clock) (*channeldb.DB, error) { + t.Helper() + // Create channeldb for the first time. cdb, err := channeldb.Open( t.TempDir(), channeldb.OptionClock(clock), @@ -179,35 +205,47 @@ func newTestChannelDB(t *testing.T, clock clock.Clock) (*channeldb.DB, error) { } type testContext struct { - cdb *channeldb.DB - registry *InvoiceRegistry + idb *channeldb.DB + registry *invpkg.InvoiceRegistry notifier *mockChainNotifier clock *clock.TestClock t *testing.T } -func newTestContext(t *testing.T) *testContext { +func defaultRegistryConfig() invpkg.RegistryConfig { + return invpkg.RegistryConfig{ + FinalCltvRejectDelta: testFinalCltvRejectDelta, + HtlcHoldDuration: 30 * time.Second, + } +} + +func newTestContext(t *testing.T, + registryCfg *invpkg.RegistryConfig) *testContext { + + t.Helper() + clock := clock.NewTestClock(testTime) - cdb, err := newTestChannelDB(t, clock) + idb, err := newTestChannelDB(t, clock) if err != nil { t.Fatal(err) } notifier := newMockNotifier() - expiryWatcher := NewInvoiceExpiryWatcher( + expiryWatcher := invpkg.NewInvoiceExpiryWatcher( clock, 0, uint32(testCurrentHeight), nil, notifier, ) - // Instantiate and start the invoice ctx.registry. - cfg := RegistryConfig{ - FinalCltvRejectDelta: testFinalCltvRejectDelta, - HtlcHoldDuration: 30 * time.Second, - Clock: clock, + cfg := defaultRegistryConfig() + if registryCfg != nil { + cfg = *registryCfg } - registry := NewRegistry(cdb, expiryWatcher, &cfg) + cfg.Clock = clock + + // Instantiate and start the invoice ctx.registry. + registry := invpkg.NewRegistry(idb, expiryWatcher, &cfg) err = registry.Start() if err != nil { @@ -218,7 +256,7 @@ func newTestContext(t *testing.T) *testContext { }) ctx := testContext{ - cdb: cdb, + idb: idb, registry: registry, notifier: notifier, clock: clock, @@ -228,8 +266,8 @@ func newTestContext(t *testing.T) *testContext { return &ctx } -func getCircuitKey(htlcID uint64) models.CircuitKey { - return models.CircuitKey{ +func getCircuitKey(htlcID uint64) invpkg.CircuitKey { + return invpkg.CircuitKey{ ChanID: lnwire.ShortChannelID{ BlockHeight: 1, TxIndex: 2, TxPosition: 3, }, @@ -238,7 +276,7 @@ func getCircuitKey(htlcID uint64) models.CircuitKey { } func newTestInvoice(t *testing.T, preimage lntypes.Preimage, - timestamp time.Time, expiry time.Duration) *channeldb.Invoice { + timestamp time.Time, expiry time.Duration) *invpkg.Invoice { if expiry == 0 { expiry = time.Hour @@ -264,8 +302,8 @@ func newTestInvoice(t *testing.T, preimage lntypes.Preimage, require.NoError(t, err, "Error while encoding payment request") - return &channeldb.Invoice{ - Terms: channeldb.ContractTerm{ + return &invpkg.Invoice{ + Terms: invpkg.ContractTerm{ PaymentPreimage: &preimage, PaymentAddr: payAddr, Value: testInvoiceAmount, @@ -286,7 +324,8 @@ func timeout() func() { case <-time.After(5 * time.Second): err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) if err != nil { - panic(fmt.Sprintf("error writing to std out after timeout: %v", err)) + panic(fmt.Sprintf("error writing to std out "+ + "after timeout: %v", err)) } panic("timeout") case <-done: @@ -300,8 +339,8 @@ func timeout() func() { // invoiceExpiryTestData simply holds generated expired and pending invoices. type invoiceExpiryTestData struct { - expiredInvoices map[lntypes.Hash]*channeldb.Invoice - pendingInvoices map[lntypes.Hash]*channeldb.Invoice + expiredInvoices map[lntypes.Hash]*invpkg.Invoice + pendingInvoices map[lntypes.Hash]*invpkg.Invoice } // generateInvoiceExpiryTestData generates the specified number of fake expired @@ -312,8 +351,8 @@ func generateInvoiceExpiryTestData( var testData invoiceExpiryTestData - testData.expiredInvoices = make(map[lntypes.Hash]*channeldb.Invoice) - testData.pendingInvoices = make(map[lntypes.Hash]*channeldb.Invoice) + testData.expiredInvoices = make(map[lntypes.Hash]*invpkg.Invoice) + testData.pendingInvoices = make(map[lntypes.Hash]*invpkg.Invoice) expiredCreationDate := now.Add(-24 * time.Hour) @@ -321,7 +360,9 @@ func generateInvoiceExpiryTestData( var preimage lntypes.Preimage binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i)) expiry := time.Duration((i+offset)%24) * time.Hour - invoice := newTestInvoice(t, preimage, expiredCreationDate, expiry) + invoice := newTestInvoice( + t, preimage, expiredCreationDate, expiry, + ) testData.expiredInvoices[preimage.Hash()] = invoice } @@ -339,12 +380,12 @@ func generateInvoiceExpiryTestData( // checkSettleResolution asserts the resolution is a settle with the correct // preimage. If successful, the HtlcSettleResolution is returned in case further // checks are desired. -func checkSettleResolution(t *testing.T, res HtlcResolution, - expPreimage lntypes.Preimage) *HtlcSettleResolution { +func checkSettleResolution(t *testing.T, res invpkg.HtlcResolution, + expPreimage lntypes.Preimage) *invpkg.HtlcSettleResolution { t.Helper() - settleResolution, ok := res.(*HtlcSettleResolution) + settleResolution, ok := res.(*invpkg.HtlcSettleResolution) require.True(t, ok) require.Equal(t, expPreimage, settleResolution.Preimage) @@ -354,11 +395,11 @@ func checkSettleResolution(t *testing.T, res HtlcResolution, // checkFailResolution asserts the resolution is a fail with the correct reason. // If successful, the HtlcFailResolution is returned in case further checks are // desired. -func checkFailResolution(t *testing.T, res HtlcResolution, - expOutcome FailResolutionResult) *HtlcFailResolution { +func checkFailResolution(t *testing.T, res invpkg.HtlcResolution, + expOutcome invpkg.FailResolutionResult) *invpkg.HtlcFailResolution { t.Helper() - failResolution, ok := res.(*HtlcFailResolution) + failResolution, ok := res.(*invpkg.HtlcFailResolution) require.True(t, ok) require.Equal(t, expOutcome, failResolution.Outcome) @@ -367,22 +408,17 @@ func checkFailResolution(t *testing.T, res HtlcResolution, type hodlExpiryTest struct { hash lntypes.Hash - state channeldb.ContractState + state invpkg.ContractState stateLock sync.Mutex mockNotifier *mockChainNotifier mockClock *clock.TestClock cancelChan chan lntypes.Hash - watcher *InvoiceExpiryWatcher -} - -func (h *hodlExpiryTest) setState(state channeldb.ContractState) { - h.stateLock.Lock() - defer h.stateLock.Unlock() - - h.state = state + watcher *invpkg.InvoiceExpiryWatcher } func (h *hodlExpiryTest) announceBlock(t *testing.T, height uint32) { + t.Helper() + select { case h.mockNotifier.blockChan <- &chainntnfs.BlockEpoch{ Height: int32(height), @@ -402,73 +438,3 @@ func (h *hodlExpiryTest) assertCanceled(t *testing.T, expected lntypes.Hash) { t.Fatalf("invoice: %v not canceled", h.hash) } } - -// setupHodlExpiry creates a hodl invoice in our expiry watcher and runs an -// arbitrary update function which advances the invoices's state. -func setupHodlExpiry(t *testing.T, creationDate time.Time, - expiry time.Duration, heightDelta uint32, - startState channeldb.ContractState, - startHtlcs []*channeldb.InvoiceHTLC) *hodlExpiryTest { - - mockNotifier := newMockNotifier() - mockClock := clock.NewTestClock(testTime) - - test := &hodlExpiryTest{ - state: startState, - watcher: NewInvoiceExpiryWatcher( - mockClock, heightDelta, uint32(testCurrentHeight), nil, - mockNotifier, - ), - cancelChan: make(chan lntypes.Hash), - mockNotifier: mockNotifier, - mockClock: mockClock, - } - - // Use an unbuffered channel to block on cancel calls so that the test - // does not exit before we've processed all the invoices we expect. - cancelImpl := func(paymentHash lntypes.Hash, force bool) error { - test.stateLock.Lock() - currentState := test.state - test.stateLock.Unlock() - - if currentState != channeldb.ContractOpen && !force { - return nil - } - - select { - case test.cancelChan <- paymentHash: - case <-time.After(testTimeout): - } - - return nil - } - - require.NoError(t, test.watcher.Start(cancelImpl)) - - // We set preimage and hash so that we can use our existing test - // helpers. In practice we would only have the hash, but this does not - // affect what we're testing at all. - preimage := lntypes.Preimage{1} - test.hash = preimage.Hash() - - invoice := newTestInvoice(t, preimage, creationDate, expiry) - invoice.State = startState - invoice.HodlInvoice = true - invoice.Htlcs = make(map[models.CircuitKey]*channeldb.InvoiceHTLC) - - // If we have any htlcs, add them with unique circult keys. - for i, htlc := range startHtlcs { - key := models.CircuitKey{ - HtlcID: uint64(i), - } - - invoice.Htlcs[key] = htlc - } - - // Create an expiry entry for our invoice in its starting state. This - // mimics adding invoices to the watcher on start. - entry := makeInvoiceExpiry(test.hash, invoice) - test.watcher.AddInvoices(entry) - - return test -} diff --git a/invoices/update.go b/invoices/update.go index 527cb8c55..e9e7be6ed 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -5,8 +5,6 @@ import ( "errors" "github.com/lightningnetwork/lnd/amp" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -16,7 +14,7 @@ import ( // update to be carried out. type invoiceUpdateCtx struct { hash lntypes.Hash - circuitKey models.CircuitKey + circuitKey CircuitKey amtPaid lnwire.MilliSatoshi expiry uint32 currentHeight int32 @@ -29,16 +27,18 @@ type invoiceUpdateCtx struct { // invoiceRef returns an identifier that can be used to lookup or update the // invoice this HTLC is targeting. -func (i *invoiceUpdateCtx) invoiceRef() channeldb.InvoiceRef { +func (i *invoiceUpdateCtx) invoiceRef() InvoiceRef { switch { case i.amp != nil && i.mpp != nil: payAddr := i.mpp.PaymentAddr() - return channeldb.InvoiceRefByAddr(payAddr) + return InvoiceRefByAddr(payAddr) + case i.mpp != nil: payAddr := i.mpp.PaymentAddr() - return channeldb.InvoiceRefByHashAndAddr(i.hash, payAddr) + return InvoiceRefByHashAndAddr(i.hash, payAddr) + default: - return channeldb.InvoiceRefByHash(i.hash) + return InvoiceRefByHash(i.hash) } } @@ -95,20 +95,20 @@ func (i invoiceUpdateCtx) acceptRes(outcome acceptResolutionResult) *htlcAcceptR // updateInvoice is a callback for DB.UpdateInvoice that contains the invoice // settlement logic. It returns a hltc resolution that indicates what the // outcome of the update was. -func updateInvoice(ctx *invoiceUpdateCtx, inv *channeldb.Invoice) ( - *channeldb.InvoiceUpdateDesc, HtlcResolution, error) { +func updateInvoice(ctx *invoiceUpdateCtx, inv *Invoice) ( + *InvoiceUpdateDesc, HtlcResolution, error) { // Don't update the invoice when this is a replayed htlc. htlc, ok := inv.Htlcs[ctx.circuitKey] if ok { switch htlc.State { - case channeldb.HtlcStateCanceled: + case HtlcStateCanceled: return nil, ctx.failRes(ResultReplayToCanceled), nil - case channeldb.HtlcStateAccepted: + case HtlcStateAccepted: return nil, ctx.acceptRes(resultReplayToAccepted), nil - case channeldb.HtlcStateSettled: + case HtlcStateSettled: pre := inv.Terms.PaymentPreimage // Terms.PaymentPreimage will be nil for AMP invoices. @@ -139,8 +139,7 @@ func updateInvoice(ctx *invoiceUpdateCtx, inv *channeldb.Invoice) ( // updateMpp is a callback for DB.UpdateInvoice that contains the invoice // settlement logic for mpp payments. -func updateMpp(ctx *invoiceUpdateCtx, - inv *channeldb.Invoice) (*channeldb.InvoiceUpdateDesc, +func updateMpp(ctx *invoiceUpdateCtx, inv *Invoice) (*InvoiceUpdateDesc, HtlcResolution, error) { // Reject HTLCs to AMP invoices if they are missing an AMP payload, and @@ -160,7 +159,7 @@ func updateMpp(ctx *invoiceUpdateCtx, setID := ctx.setID() // Start building the accept descriptor. - acceptDesc := &channeldb.HtlcAcceptDesc{ + acceptDesc := &HtlcAcceptDesc{ Amt: ctx.amtPaid, Expiry: ctx.expiry, AcceptHeight: ctx.currentHeight, @@ -169,7 +168,7 @@ func updateMpp(ctx *invoiceUpdateCtx, } if ctx.amp != nil { - acceptDesc.AMP = &channeldb.InvoiceHtlcAMPData{ + acceptDesc.AMP = &InvoiceHtlcAMPData{ Record: *ctx.amp, Hash: ctx.hash, Preimage: nil, @@ -180,7 +179,7 @@ func updateMpp(ctx *invoiceUpdateCtx, // non-mpp payments that are accepted even after the invoice is settled. // Because non-mpp payments don't have a payment address, this is needed // to thwart probing. - if inv.State != channeldb.ContractOpen { + if inv.State != ContractOpen { return nil, ctx.failRes(ResultInvoiceNotOpen), nil } @@ -200,7 +199,7 @@ func updateMpp(ctx *invoiceUpdateCtx, return nil, ctx.failRes(ResultHtlcSetTotalTooLow), nil } - htlcSet := inv.HTLCSet(setID, channeldb.HtlcStateAccepted) + htlcSet := inv.HTLCSet(setID, HtlcStateAccepted) // Check whether total amt matches other htlcs in the set. var newSetTotal lnwire.MilliSatoshi @@ -229,16 +228,16 @@ func updateMpp(ctx *invoiceUpdateCtx, return nil, ctx.failRes(ResultExpiryTooSoon), nil } - if setID != nil && *setID == channeldb.BlankPayAddr { + if setID != nil && *setID == BlankPayAddr { return nil, ctx.failRes(ResultAmpError), nil } // Record HTLC in the invoice database. - newHtlcs := map[models.CircuitKey]*channeldb.HtlcAcceptDesc{ + newHtlcs := map[CircuitKey]*HtlcAcceptDesc{ ctx.circuitKey: acceptDesc, } - update := channeldb.InvoiceUpdateDesc{ + update := InvoiceUpdateDesc{ AddHtlcs: newHtlcs, } @@ -251,23 +250,23 @@ func updateMpp(ctx *invoiceUpdateCtx, // Check to see if we can settle or this is an hold invoice and // we need to wait for the preimage. if inv.HodlInvoice { - update.State = &channeldb.InvoiceStateUpdateDesc{ - NewState: channeldb.ContractAccepted, + update.State = &InvoiceStateUpdateDesc{ + NewState: ContractAccepted, SetID: setID, } return &update, ctx.acceptRes(resultAccepted), nil } var ( - htlcPreimages map[models.CircuitKey]lntypes.Preimage + htlcPreimages map[CircuitKey]lntypes.Preimage htlcPreimage lntypes.Preimage ) if ctx.amp != nil { var failRes *HtlcFailResolution htlcPreimages, failRes = reconstructAMPPreimages(ctx, htlcSet) if failRes != nil { - update.State = &channeldb.InvoiceStateUpdateDesc{ - NewState: channeldb.ContractCanceled, + update.State = &InvoiceStateUpdateDesc{ + NewState: ContractCanceled, SetID: setID, } return &update, failRes, nil @@ -280,8 +279,8 @@ func updateMpp(ctx *invoiceUpdateCtx, htlcPreimage = *inv.Terms.PaymentPreimage } - update.State = &channeldb.InvoiceStateUpdateDesc{ - NewState: channeldb.ContractSettled, + update.State = &InvoiceStateUpdateDesc{ + NewState: ContractSettled, Preimage: inv.Terms.PaymentPreimage, HTLCPreimages: htlcPreimages, SetID: setID, @@ -291,10 +290,10 @@ func updateMpp(ctx *invoiceUpdateCtx, } // HTLCSet is a map of CircuitKey to InvoiceHTLC. -type HTLCSet = map[models.CircuitKey]*channeldb.InvoiceHTLC +type HTLCSet = map[CircuitKey]*InvoiceHTLC // HTLCPreimages is a map of CircuitKey to preimage. -type HTLCPreimages = map[models.CircuitKey]lntypes.Preimage +type HTLCPreimages = map[CircuitKey]lntypes.Preimage // reconstructAMPPreimages reconstructs the root seed for an AMP HTLC set and // verifies that all derived child hashes match the payment hashes of the HTLCs @@ -317,7 +316,7 @@ func reconstructAMPPreimages(ctx *invoiceUpdateCtx, // Next, construct an index mapping the position in childDescs to a // circuit key for all preexisting HTLCs. - indexToCircuitKey := make(map[int]models.CircuitKey) + indexToCircuitKey := make(map[int]CircuitKey) // Add the child descriptor for each HTLC in the HTLC set, recording // it's position within the slice. @@ -351,7 +350,7 @@ func reconstructAMPPreimages(ctx *invoiceUpdateCtx, // Finally, construct the map of learned preimages indexed by circuit // key, so that they can be persisted along with each HTLC when updating // the invoice. - htlcPreimages := make(map[models.CircuitKey]lntypes.Preimage) + htlcPreimages := make(map[CircuitKey]lntypes.Preimage) htlcPreimages[ctx.circuitKey] = children[0].Preimage for idx, child := range children[1:] { circuitKey := indexToCircuitKey[idx] @@ -368,11 +367,11 @@ func reconstructAMPPreimages(ctx *invoiceUpdateCtx, // send payments and any invoices we created in the past that are valid and // still had the optional mpp bit set. func updateLegacy(ctx *invoiceUpdateCtx, - inv *channeldb.Invoice) (*channeldb.InvoiceUpdateDesc, HtlcResolution, error) { + inv *Invoice) (*InvoiceUpdateDesc, HtlcResolution, error) { // If the invoice is already canceled, there is no further // checking to do. - if inv.State == channeldb.ContractCanceled { + if inv.State == ContractCanceled { return nil, ctx.failRes(ResultInvoiceAlreadyCanceled), nil } @@ -402,7 +401,7 @@ func updateLegacy(ctx *invoiceUpdateCtx, // Don't allow settling the invoice with an old style // htlc if we are already in the process of gathering an // mpp set. - for _, htlc := range inv.HTLCSet(nil, channeldb.HtlcStateAccepted) { + for _, htlc := range inv.HTLCSet(nil, HtlcStateAccepted) { if htlc.MppTotalAmt > 0 { return nil, ctx.failRes(ResultMppInProgress), nil } @@ -418,7 +417,7 @@ func updateLegacy(ctx *invoiceUpdateCtx, } // Record HTLC in the invoice database. - newHtlcs := map[models.CircuitKey]*channeldb.HtlcAcceptDesc{ + newHtlcs := map[CircuitKey]*HtlcAcceptDesc{ ctx.circuitKey: { Amt: ctx.amtPaid, Expiry: ctx.expiry, @@ -427,17 +426,17 @@ func updateLegacy(ctx *invoiceUpdateCtx, }, } - update := channeldb.InvoiceUpdateDesc{ + update := InvoiceUpdateDesc{ AddHtlcs: newHtlcs, } // Don't update invoice state if we are accepting a duplicate payment. // We do accept or settle the HTLC. switch inv.State { - case channeldb.ContractAccepted: + case ContractAccepted: return &update, ctx.acceptRes(resultDuplicateToAccepted), nil - case channeldb.ContractSettled: + case ContractSettled: return &update, ctx.settleRes( *inv.Terms.PaymentPreimage, ResultDuplicateToSettled, ), nil @@ -446,15 +445,15 @@ func updateLegacy(ctx *invoiceUpdateCtx, // Check to see if we can settle or this is an hold invoice and we need // to wait for the preimage. if inv.HodlInvoice { - update.State = &channeldb.InvoiceStateUpdateDesc{ - NewState: channeldb.ContractAccepted, + update.State = &InvoiceStateUpdateDesc{ + NewState: ContractAccepted, } return &update, ctx.acceptRes(resultAccepted), nil } - update.State = &channeldb.InvoiceStateUpdateDesc{ - NewState: channeldb.ContractSettled, + update.State = &InvoiceStateUpdateDesc{ + NewState: ContractSettled, Preimage: inv.Terms.PaymentPreimage, } diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 1c6f33cc0..00e88ecc1 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -17,6 +17,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" @@ -46,7 +47,7 @@ const ( // AddInvoiceConfig contains dependencies for invoice creation. type AddInvoiceConfig struct { // AddInvoice is called to add the invoice to the registry. - AddInvoice func(invoice *channeldb.Invoice, paymentHash lntypes.Hash) ( + AddInvoice func(invoice *invoices.Invoice, paymentHash lntypes.Hash) ( uint64, error) // IsChannelActive is used to generate valid hop hints. @@ -234,7 +235,7 @@ func (d *AddInvoiceData) mppPaymentHashAndPreimage() (*lntypes.Preimage, // duplicated invoices are rejected, therefore all invoices *must* have a // unique payment preimage. func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, - invoice *AddInvoiceData) (*lntypes.Hash, *channeldb.Invoice, error) { + invoice *AddInvoiceData) (*lntypes.Hash, *invoices.Invoice, error) { paymentPreimage, paymentHash, err := invoice.paymentHashAndPreimage() if err != nil { @@ -243,10 +244,10 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, // The size of the memo, receipt and description hash attached must not // exceed the maximum values for either of the fields. - if len(invoice.Memo) > channeldb.MaxMemoSize { + if len(invoice.Memo) > invoices.MaxMemoSize { return nil, nil, fmt.Errorf("memo too large: %v bytes "+ "(maxsize=%v)", len(invoice.Memo), - channeldb.MaxMemoSize) + invoices.MaxMemoSize) } if len(invoice.DescriptionHash) > 0 && len(invoice.DescriptionHash) != 32 { @@ -448,11 +449,11 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, return nil, nil, err } - newInvoice := &channeldb.Invoice{ + newInvoice := &invoices.Invoice{ CreationDate: creationDate, Memo: []byte(invoice.Memo), PaymentRequest: []byte(payReqString), - Terms: channeldb.ContractTerm{ + Terms: invoices.ContractTerm{ FinalCltvDelta: int32(payReq.MinFinalCLTVExpiry()), Expiry: payReq.Expiry(), Value: amtMSat, diff --git a/lnrpc/invoicesrpc/invoices_server.go b/lnrpc/invoicesrpc/invoices_server.go index 60459682d..0c9ddf042 100644 --- a/lnrpc/invoicesrpc/invoices_server.go +++ b/lnrpc/invoicesrpc/invoices_server.go @@ -5,13 +5,14 @@ package invoicesrpc import ( "context" + "errors" "fmt" "io/ioutil" "os" "path/filepath" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/macaroons" @@ -287,7 +288,7 @@ func (s *Server) SettleInvoice(ctx context.Context, } err = s.cfg.InvoiceRegistry.SettleHodlInvoice(preimage) - if err != nil && err != channeldb.ErrInvoiceAlreadySettled { + if err != nil && !errors.Is(err, invoices.ErrInvoiceAlreadySettled) { return nil, err } @@ -380,7 +381,7 @@ func (s *Server) AddHoldInvoice(ctx context.Context, func (s *Server) LookupInvoiceV2(ctx context.Context, req *LookupInvoiceMsg) (*lnrpc.Invoice, error) { - var invoiceRef channeldb.InvoiceRef + var invoiceRef invoices.InvoiceRef // First, we'll attempt to parse out the invoice ref from the proto // oneof. If none of the three currently supported types was @@ -395,7 +396,7 @@ func (s *Server) LookupInvoiceV2(ctx context.Context, ) } - invoiceRef = channeldb.InvoiceRefByHash(payHash) + invoiceRef = invoices.InvoiceRefByHash(payHash) case req.GetPaymentAddr() != nil && req.LookupModifier == LookupModifier_HTLC_SET_BLANK: @@ -403,13 +404,13 @@ func (s *Server) LookupInvoiceV2(ctx context.Context, var payAddr [32]byte copy(payAddr[:], req.GetPaymentAddr()) - invoiceRef = channeldb.InvoiceRefByAddrBlankHtlc(payAddr) + invoiceRef = invoices.InvoiceRefByAddrBlankHtlc(payAddr) case req.GetPaymentAddr() != nil: var payAddr [32]byte copy(payAddr[:], req.GetPaymentAddr()) - invoiceRef = channeldb.InvoiceRefByAddr(payAddr) + invoiceRef = invoices.InvoiceRefByAddr(payAddr) case req.GetSetId() != nil && req.LookupModifier == LookupModifier_HTLC_SET_ONLY: @@ -417,13 +418,13 @@ func (s *Server) LookupInvoiceV2(ctx context.Context, var setID [32]byte copy(setID[:], req.GetSetId()) - invoiceRef = channeldb.InvoiceRefBySetIDFiltered(setID) + invoiceRef = invoices.InvoiceRefBySetIDFiltered(setID) case req.GetSetId() != nil: var setID [32]byte copy(setID[:], req.GetSetId()) - invoiceRef = channeldb.InvoiceRefBySetID(setID) + invoiceRef = invoices.InvoiceRefBySetID(setID) default: return nil, status.Error(codes.InvalidArgument, @@ -434,7 +435,7 @@ func (s *Server) LookupInvoiceV2(ctx context.Context, // we can't find it in the database. invoice, err := s.cfg.InvoiceRegistry.LookupInvoiceByRef(invoiceRef) switch { - case err == channeldb.ErrInvoiceNotFound: + case errors.Is(err, invoices.ErrInvoiceNotFound): return nil, status.Error(codes.NotFound, err.Error()) case err != nil: return nil, err diff --git a/lnrpc/invoicesrpc/utils.go b/lnrpc/invoicesrpc/utils.go index 283e72ce4..d610f9255 100644 --- a/lnrpc/invoicesrpc/utils.go +++ b/lnrpc/invoicesrpc/utils.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/chaincfg" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/zpay32" @@ -16,7 +16,7 @@ import ( // because not all information is stored in dedicated invoice fields. If there // is no payment request present, a dummy request will be returned. This can // happen with just-in-time inserted keysend invoices. -func decodePayReq(invoice *channeldb.Invoice, +func decodePayReq(invoice *invoices.Invoice, activeNetParams *chaincfg.Params) (*zpay32.Invoice, error) { paymentRequest := string(invoice.PaymentRequest) @@ -40,8 +40,8 @@ func decodePayReq(invoice *channeldb.Invoice, return decoded, nil } -// CreateRPCInvoice creates an *lnrpc.Invoice from the *channeldb.Invoice. -func CreateRPCInvoice(invoice *channeldb.Invoice, +// CreateRPCInvoice creates an *lnrpc.Invoice from the *invoices.Invoice. +func CreateRPCInvoice(invoice *invoices.Invoice, activeNetParams *chaincfg.Params) (*lnrpc.Invoice, error) { decoded, err := decodePayReq(invoice, activeNetParams) @@ -76,18 +76,22 @@ func CreateRPCInvoice(invoice *channeldb.Invoice, satAmt := invoice.Terms.Value.ToSatoshis() satAmtPaid := invoice.AmtPaid.ToSatoshis() - isSettled := invoice.State == channeldb.ContractSettled + isSettled := invoice.State == invoices.ContractSettled var state lnrpc.Invoice_InvoiceState switch invoice.State { - case channeldb.ContractOpen: + case invoices.ContractOpen: state = lnrpc.Invoice_OPEN - case channeldb.ContractSettled: + + case invoices.ContractSettled: state = lnrpc.Invoice_SETTLED - case channeldb.ContractCanceled: + + case invoices.ContractCanceled: state = lnrpc.Invoice_CANCELED - case channeldb.ContractAccepted: + + case invoices.ContractAccepted: state = lnrpc.Invoice_ACCEPTED + default: return nil, fmt.Errorf("unknown invoice state %v", invoice.State) @@ -97,11 +101,11 @@ func CreateRPCInvoice(invoice *channeldb.Invoice, for key, htlc := range invoice.Htlcs { var state lnrpc.InvoiceHTLCState switch htlc.State { - case channeldb.HtlcStateAccepted: + case invoices.HtlcStateAccepted: state = lnrpc.InvoiceHTLCState_ACCEPTED - case channeldb.HtlcStateSettled: + case invoices.HtlcStateSettled: state = lnrpc.InvoiceHTLCState_SETTLED - case channeldb.HtlcStateCanceled: + case invoices.HtlcStateCanceled: state = lnrpc.InvoiceHTLCState_CANCELED default: return nil, fmt.Errorf("unknown state %v", htlc.State) @@ -139,7 +143,7 @@ func CreateRPCInvoice(invoice *channeldb.Invoice, } // Only report resolved times if htlc is resolved. - if htlc.State != channeldb.HtlcStateAccepted { + if htlc.State != invoices.HtlcStateAccepted { rpcHtlc.ResolveTime = htlc.ResolveTime.Unix() } @@ -182,11 +186,11 @@ func CreateRPCInvoice(invoice *channeldb.Invoice, var state lnrpc.InvoiceHTLCState switch ampState.State { - case channeldb.HtlcStateAccepted: + case invoices.HtlcStateAccepted: state = lnrpc.InvoiceHTLCState_ACCEPTED - case channeldb.HtlcStateSettled: + case invoices.HtlcStateSettled: state = lnrpc.InvoiceHTLCState_SETTLED - case channeldb.HtlcStateCanceled: + case invoices.HtlcStateCanceled: state = lnrpc.InvoiceHTLCState_CANCELED default: return nil, fmt.Errorf("unknown state %v", ampState.State) @@ -202,7 +206,7 @@ func CreateRPCInvoice(invoice *channeldb.Invoice, // If at least one of the present HTLC sets show up as being // settled, then we'll mark the invoice itself as being // settled. - if ampState.State == channeldb.HtlcStateSettled { + if ampState.State == invoices.HtlcStateSettled { rpcInvoice.Settled = true // nolint:staticcheck rpcInvoice.State = lnrpc.Invoice_SETTLED } diff --git a/rpcserver.go b/rpcserver.go index 95e001022..a5bd31686 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -5507,7 +5507,7 @@ func (r *rpcServer) LookupInvoice(ctx context.Context, invoice, err := r.server.invoices.LookupInvoice(payHash) switch { - case err == channeldb.ErrInvoiceNotFound: + case errors.Is(err, invoices.ErrInvoiceNotFound): return nil, status.Error(codes.NotFound, err.Error()) case err != nil: return nil, err @@ -5551,7 +5551,7 @@ func (r *rpcServer) ListInvoices(ctx context.Context, // Next, we'll map the proto request into a format that is understood by // the database. - q := channeldb.InvoiceQuery{ + q := invoices.InvoiceQuery{ IndexOffset: req.IndexOffset, NumMaxInvoices: req.NumMaxInvoices, PendingOnly: req.PendingOnly,