diff --git a/channeldb/invoices.go b/channeldb/invoices.go index df124b632..9da504a5d 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) ( // For each key found, we'll look up the actual // invoice, then accumulate it into our return value. - invoice, err := fetchInvoice(invoiceKey, invoices) + invoice, err := fetchInvoice( + invoiceKey, invoices, nil, false, + ) if err != nil { return err } @@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) ( // An invoice was found, retrieve the remainder of the invoice // body. - i, err := fetchInvoice(invoiceNum, invoices, setID) + i, err := fetchInvoice( + invoiceNum, invoices, []*invpkg.SetID{setID}, true, + ) if err != nil { return err } @@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) ( return nil } - invoice, err := fetchInvoice(v, invoices) + invoice, err := fetchInvoice(v, invoices, nil, false) if err != nil { return err } @@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) ( // characteristics for our query and returns the number of items // we have added to our set of invoices. accumulateInvoices := func(_, indexValue []byte) (bool, error) { - invoice, err := fetchInvoice(indexValue, invoices) + invoice, err := fetchInvoice( + indexValue, invoices, nil, false, + ) if err != nil { return false, err } @@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, if setIDHint != nil { invSetID = *setIDHint } - invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID) + invoice, err := fetchInvoice( + invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false, + ) if err != nil { return err } @@ -676,8 +684,17 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, updatedInvoice, err = invpkg.UpdateInvoice( payHash, updater.invoice, now, callback, updater, ) + if err != nil { + return err + } - return err + // If this is an AMP update, then limit the returned AMP state + // to only the requested set ID. + if setIDHint != nil { + filterInvoiceAMPState(updatedInvoice, &invSetID) + } + + return nil }, func() { updatedInvoice = nil }) @@ -685,6 +702,25 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, return updatedInvoice, err } +// filterInvoiceAMPState filters the AMP state of the invoice to only include +// state for the specified set IDs. +func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) { + filteredAMPState := make(invpkg.AMPInvoiceState) + + for _, setID := range setIDs { + if setID == nil { + return + } + + ampState, ok := invoice.AMPState[*setID] + if ok { + filteredAMPState[*setID] = ampState + } + } + + invoice.AMPState = filteredAMPState +} + // ampHTLCsMap is a map of AMP HTLCs affected by an invoice update. type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC @@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) ( // For each key found, we'll look up the actual // invoice, then accumulate it into our return value. invoice, err := fetchInvoice( - invoiceKey[:], invoices, setID, + invoiceKey[:], invoices, []*invpkg.SetID{setID}, + true, ) if err != nil { return err @@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte, // specified by the invoice number. If the setID fields are set, then only the // HTLC information pertaining to those set IDs is returned. func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, - setIDs ...*invpkg.SetID) (invpkg.Invoice, error) { + setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) { invoiceBytes := invoices.Get(invoiceNum) if invoiceBytes == nil { @@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, log.Errorf("unable to fetch amp htlcs for inv "+ "%v and setIDs %v: %w", invoiceNum, setIDs, err) } + + if filterAMPState { + filterInvoiceAMPState(&invoice, setIDs...) + } } return invoice, nil @@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error { return nil } - invoice, err := fetchInvoice(v, invoices) + invoice, err := fetchInvoice(v, invoices, nil, false) if err != nil { return err } diff --git a/itest/lnd_amp_test.go b/itest/lnd_amp_test.go index 9232a8b87..4b4cfb5a2 100644 --- a/itest/lnd_amp_test.go +++ b/itest/lnd_amp_test.go @@ -303,9 +303,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) { // return the "projected" sub-invoice for a given setID. require.Equal(ht, 1, len(invoiceNtfn.Htlcs)) - // However the AMP state index should show that there've been two - // repeated payments to this invoice so far. - require.Equal(ht, 2, len(invoiceNtfn.AmpInvoiceState)) + // The AMP state should also be restricted to a single entry for the + // "projected" sub-invoice. + require.Equal(ht, 1, len(invoiceNtfn.AmpInvoiceState)) // Now we'll look up the invoice using the new LookupInvoice2 RPC call // by the set ID of each of the invoices. @@ -364,7 +364,7 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) { // through. backlogInv := ht.ReceiveInvoiceUpdate(invSub2) require.Equal(ht, 1, len(backlogInv.Htlcs)) - require.Equal(ht, 2, len(backlogInv.AmpInvoiceState)) + require.Equal(ht, 1, len(backlogInv.AmpInvoiceState)) require.True(ht, backlogInv.Settled) require.Equal(ht, paymentAmt*2, int(backlogInv.AmtPaidSat)) }