mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-18 05:42:09 +01:00
channeldb: filter AMP state to relevant set IDs
When fetching an AMP invoice we sometimes filter HTLCs to selected set IDs, however we always kept the full AMP state which is irrelevant as it contains state for all AMP payments. This was a side effect of UpdateInvoice needing to serialize the whole invoice when storing after an update but it is an unwanted "feature" as users will need to filter to relevant set when listing an AMP payment or subsribing to an update.
This commit is contained in:
parent
cadce23b47
commit
b3dc3ed5c8
@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
|
||||
|
||||
// For each key found, we'll look up the actual
|
||||
// invoice, then accumulate it into our return value.
|
||||
invoice, err := fetchInvoice(invoiceKey, invoices)
|
||||
invoice, err := fetchInvoice(
|
||||
invoiceKey, invoices, nil, false,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
|
||||
|
||||
// An invoice was found, retrieve the remainder of the invoice
|
||||
// body.
|
||||
i, err := fetchInvoice(invoiceNum, invoices, setID)
|
||||
i, err := fetchInvoice(
|
||||
invoiceNum, invoices, []*invpkg.SetID{setID}, true,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) (
|
||||
return nil
|
||||
}
|
||||
|
||||
invoice, err := fetchInvoice(v, invoices)
|
||||
invoice, err := fetchInvoice(v, invoices, nil, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
|
||||
// characteristics for our query and returns the number of items
|
||||
// we have added to our set of invoices.
|
||||
accumulateInvoices := func(_, indexValue []byte) (bool, error) {
|
||||
invoice, err := fetchInvoice(indexValue, invoices)
|
||||
invoice, err := fetchInvoice(
|
||||
indexValue, invoices, nil, false,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
|
||||
if setIDHint != nil {
|
||||
invSetID = *setIDHint
|
||||
}
|
||||
invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID)
|
||||
invoice, err := fetchInvoice(
|
||||
invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -676,8 +684,17 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
|
||||
updatedInvoice, err = invpkg.UpdateInvoice(
|
||||
payHash, updater.invoice, now, callback, updater,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
// If this is an AMP update, then limit the returned AMP state
|
||||
// to only the requested set ID.
|
||||
if setIDHint != nil {
|
||||
filterInvoiceAMPState(updatedInvoice, &invSetID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, func() {
|
||||
updatedInvoice = nil
|
||||
})
|
||||
@ -685,6 +702,25 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
|
||||
return updatedInvoice, err
|
||||
}
|
||||
|
||||
// filterInvoiceAMPState filters the AMP state of the invoice to only include
|
||||
// state for the specified set IDs.
|
||||
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
|
||||
filteredAMPState := make(invpkg.AMPInvoiceState)
|
||||
|
||||
for _, setID := range setIDs {
|
||||
if setID == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ampState, ok := invoice.AMPState[*setID]
|
||||
if ok {
|
||||
filteredAMPState[*setID] = ampState
|
||||
}
|
||||
}
|
||||
|
||||
invoice.AMPState = filteredAMPState
|
||||
}
|
||||
|
||||
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
|
||||
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
|
||||
|
||||
@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
|
||||
// For each key found, we'll look up the actual
|
||||
// invoice, then accumulate it into our return value.
|
||||
invoice, err := fetchInvoice(
|
||||
invoiceKey[:], invoices, setID,
|
||||
invoiceKey[:], invoices, []*invpkg.SetID{setID},
|
||||
true,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
|
||||
// specified by the invoice number. If the setID fields are set, then only the
|
||||
// HTLC information pertaining to those set IDs is returned.
|
||||
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
|
||||
setIDs ...*invpkg.SetID) (invpkg.Invoice, error) {
|
||||
setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
|
||||
|
||||
invoiceBytes := invoices.Get(invoiceNum)
|
||||
if invoiceBytes == nil {
|
||||
@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
|
||||
log.Errorf("unable to fetch amp htlcs for inv "+
|
||||
"%v and setIDs %v: %w", invoiceNum, setIDs, err)
|
||||
}
|
||||
|
||||
if filterAMPState {
|
||||
filterInvoiceAMPState(&invoice, setIDs...)
|
||||
}
|
||||
}
|
||||
|
||||
return invoice, nil
|
||||
@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
invoice, err := fetchInvoice(v, invoices)
|
||||
invoice, err := fetchInvoice(v, invoices, nil, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -303,9 +303,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
|
||||
// return the "projected" sub-invoice for a given setID.
|
||||
require.Equal(ht, 1, len(invoiceNtfn.Htlcs))
|
||||
|
||||
// However the AMP state index should show that there've been two
|
||||
// repeated payments to this invoice so far.
|
||||
require.Equal(ht, 2, len(invoiceNtfn.AmpInvoiceState))
|
||||
// The AMP state should also be restricted to a single entry for the
|
||||
// "projected" sub-invoice.
|
||||
require.Equal(ht, 1, len(invoiceNtfn.AmpInvoiceState))
|
||||
|
||||
// Now we'll look up the invoice using the new LookupInvoice2 RPC call
|
||||
// by the set ID of each of the invoices.
|
||||
@ -364,7 +364,7 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
|
||||
// through.
|
||||
backlogInv := ht.ReceiveInvoiceUpdate(invSub2)
|
||||
require.Equal(ht, 1, len(backlogInv.Htlcs))
|
||||
require.Equal(ht, 2, len(backlogInv.AmpInvoiceState))
|
||||
require.Equal(ht, 1, len(backlogInv.AmpInvoiceState))
|
||||
require.True(ht, backlogInv.Settled)
|
||||
require.Equal(ht, paymentAmt*2, int(backlogInv.AmtPaidSat))
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user