channeldb: filter AMP state to relevant set IDs

When fetching an AMP invoice we sometimes filter HTLCs to selected set
IDs, however we always kept the full AMP state which is irrelevant as it
contains state for all AMP payments. This was a side effect of
UpdateInvoice needing to serialize the whole invoice when storing after
an update but it is an unwanted "feature" as users will need to filter
to relevant set when listing an AMP payment or subsribing to an update.
This commit is contained in:
Andras Banki-Horvath 2024-08-30 11:19:32 +02:00 committed by Olaoluwa Osuntokun
parent cadce23b47
commit b3dc3ed5c8
2 changed files with 54 additions and 13 deletions

View File

@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
// For each key found, we'll look up the actual
// invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(invoiceKey, invoices)
invoice, err := fetchInvoice(
invoiceKey, invoices, nil, false,
)
if err != nil {
return err
}
@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
// An invoice was found, retrieve the remainder of the invoice
// body.
i, err := fetchInvoice(invoiceNum, invoices, setID)
i, err := fetchInvoice(
invoiceNum, invoices, []*invpkg.SetID{setID}, true,
)
if err != nil {
return err
}
@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) (
return nil
}
invoice, err := fetchInvoice(v, invoices)
invoice, err := fetchInvoice(v, invoices, nil, false)
if err != nil {
return err
}
@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
// characteristics for our query and returns the number of items
// we have added to our set of invoices.
accumulateInvoices := func(_, indexValue []byte) (bool, error) {
invoice, err := fetchInvoice(indexValue, invoices)
invoice, err := fetchInvoice(
indexValue, invoices, nil, false,
)
if err != nil {
return false, err
}
@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
if setIDHint != nil {
invSetID = *setIDHint
}
invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID)
invoice, err := fetchInvoice(
invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false,
)
if err != nil {
return err
}
@ -676,8 +684,17 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
updatedInvoice, err = invpkg.UpdateInvoice(
payHash, updater.invoice, now, callback, updater,
)
if err != nil {
return err
}
return err
// If this is an AMP update, then limit the returned AMP state
// to only the requested set ID.
if setIDHint != nil {
filterInvoiceAMPState(updatedInvoice, &invSetID)
}
return nil
}, func() {
updatedInvoice = nil
})
@ -685,6 +702,25 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
return updatedInvoice, err
}
// filterInvoiceAMPState filters the AMP state of the invoice to only include
// state for the specified set IDs.
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
filteredAMPState := make(invpkg.AMPInvoiceState)
for _, setID := range setIDs {
if setID == nil {
return
}
ampState, ok := invoice.AMPState[*setID]
if ok {
filteredAMPState[*setID] = ampState
}
}
invoice.AMPState = filteredAMPState
}
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
// For each key found, we'll look up the actual
// invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(
invoiceKey[:], invoices, setID,
invoiceKey[:], invoices, []*invpkg.SetID{setID},
true,
)
if err != nil {
return err
@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
// specified by the invoice number. If the setID fields are set, then only the
// HTLC information pertaining to those set IDs is returned.
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
setIDs ...*invpkg.SetID) (invpkg.Invoice, error) {
setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
invoiceBytes := invoices.Get(invoiceNum)
if invoiceBytes == nil {
@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
log.Errorf("unable to fetch amp htlcs for inv "+
"%v and setIDs %v: %w", invoiceNum, setIDs, err)
}
if filterAMPState {
filterInvoiceAMPState(&invoice, setIDs...)
}
}
return invoice, nil
@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
return nil
}
invoice, err := fetchInvoice(v, invoices)
invoice, err := fetchInvoice(v, invoices, nil, false)
if err != nil {
return err
}

View File

@ -303,9 +303,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
// return the "projected" sub-invoice for a given setID.
require.Equal(ht, 1, len(invoiceNtfn.Htlcs))
// However the AMP state index should show that there've been two
// repeated payments to this invoice so far.
require.Equal(ht, 2, len(invoiceNtfn.AmpInvoiceState))
// The AMP state should also be restricted to a single entry for the
// "projected" sub-invoice.
require.Equal(ht, 1, len(invoiceNtfn.AmpInvoiceState))
// Now we'll look up the invoice using the new LookupInvoice2 RPC call
// by the set ID of each of the invoices.
@ -364,7 +364,7 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
// through.
backlogInv := ht.ReceiveInvoiceUpdate(invSub2)
require.Equal(ht, 1, len(backlogInv.Htlcs))
require.Equal(ht, 2, len(backlogInv.AmpInvoiceState))
require.Equal(ht, 1, len(backlogInv.AmpInvoiceState))
require.True(ht, backlogInv.Settled)
require.Equal(ht, paymentAmt*2, int(backlogInv.AmtPaidSat))
}