From 4bf6b521580e9fbe8c2e0d16b1df953a3752c3d1 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 17 Oct 2023 15:42:02 +0200 Subject: [PATCH] channeldb: fetch the invoice before calling into updateInvoice --- channeldb/invoices.go | 53 +++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 9793b2c85..5c3891cd7 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -641,10 +641,25 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, return err } + // If the set ID hint is non-nil, then we'll use that to filter + // out the HTLCs for AMP invoice so we don't need to read them + // all out to satisfy the invoice callback below. If it's nil, + // then we pass in the zero set ID which means no HTLCs will be + // read out. + var invSetID invpkg.SetID + + if setIDHint != nil { + invSetID = *setIDHint + } + invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID) + if err != nil { + return err + } + payHash := ref.PayHash() updatedInvoice, err = d.updateInvoice( - payHash, setIDHint, invoices, settleIndex, setIDIndex, - invoiceNum, callback, + payHash, invoices, settleIndex, setIDIndex, + &invoice, invoiceNum, callback, ) return err @@ -1872,26 +1887,14 @@ func settleHtlcsAmp(invoice *invpkg.Invoice, // updateInvoice fetches the invoice, obtains the update descriptor from the // callback and applies the updates in a single db transaction. -func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *invpkg.SetID, invoices, - settleIndex, setIDIndex kvdb.RwBucket, invoiceNum []byte, - callback invpkg.InvoiceUpdateCallback) (*invpkg.Invoice, error) { - - // If the set ID is non-nil, then we'll use that to filter out the - // HTLCs for AMP invoice so we don't need to read them all out to - // satisfy the invoice callback below. If it's nil, then we pass in the - // zero set ID which means no HTLCs will be read out. - var invSetID invpkg.SetID - if refSetID != nil { - invSetID = *refSetID - } - invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID) - if err != nil { - return nil, err - } +func (d *DB) updateInvoice(hash *lntypes.Hash, invoices, + settleIndex, setIDIndex kvdb.RwBucket, invoice *invpkg.Invoice, + invoiceNum []byte, callback invpkg.InvoiceUpdateCallback) ( + *invpkg.Invoice, error) { // Create deep copy to prevent any accidental modification in the // callback. - invoiceCopy, err := invpkg.CopyInvoice(&invoice) + invoiceCopy, err := invpkg.CopyInvoice(invoice) if err != nil { return nil, err } @@ -1899,33 +1902,33 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *invpkg.SetID, invoices, // Call the callback and obtain the update descriptor. update, err := callback(invoiceCopy) if err != nil { - return &invoice, err + return invoice, err } // If there is nothing to update, return early. if update == nil { - return &invoice, nil + return invoice, nil } switch update.UpdateType { case invpkg.CancelHTLCsUpdate: - return d.cancelHTLCs(invoices, invoiceNum, &invoice, update) + return d.cancelHTLCs(invoices, invoiceNum, invoice, update) case invpkg.AddHTLCsUpdate: return d.addHTLCs( - invoices, settleIndex, setIDIndex, invoiceNum, &invoice, + invoices, settleIndex, setIDIndex, invoiceNum, invoice, hash, update, ) case invpkg.SettleHodlInvoiceUpdate: return d.settleHodlInvoice( - invoices, settleIndex, invoiceNum, &invoice, hash, + invoices, settleIndex, invoiceNum, invoice, hash, update.State, ) case invpkg.CancelInvoiceUpdate: return d.cancelInvoice( - invoices, invoiceNum, &invoice, hash, update.State, + invoices, invoiceNum, invoice, hash, update.State, ) default: