channeldb: fetch the invoice before calling into updateInvoice

This commit is contained in:
Andras Banki-Horvath
2023-10-17 15:42:02 +02:00
parent eb4198b970
commit 4bf6b52158

View File

@@ -641,10 +641,25 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
return err
}
// If the set ID hint is non-nil, then we'll use that to filter
// out the HTLCs for AMP invoice so we don't need to read them
// all out to satisfy the invoice callback below. If it's nil,
// then we pass in the zero set ID which means no HTLCs will be
// read out.
var invSetID invpkg.SetID
if setIDHint != nil {
invSetID = *setIDHint
}
invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID)
if err != nil {
return err
}
payHash := ref.PayHash()
updatedInvoice, err = d.updateInvoice(
payHash, setIDHint, invoices, settleIndex, setIDIndex,
invoiceNum, callback,
payHash, invoices, settleIndex, setIDIndex,
&invoice, invoiceNum, callback,
)
return err
@@ -1872,26 +1887,14 @@ func settleHtlcsAmp(invoice *invpkg.Invoice,
// updateInvoice fetches the invoice, obtains the update descriptor from the
// callback and applies the updates in a single db transaction.
func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *invpkg.SetID, invoices,
settleIndex, setIDIndex kvdb.RwBucket, invoiceNum []byte,
callback invpkg.InvoiceUpdateCallback) (*invpkg.Invoice, error) {
// If the set ID is non-nil, then we'll use that to filter out the
// HTLCs for AMP invoice so we don't need to read them all out to
// satisfy the invoice callback below. If it's nil, then we pass in the
// zero set ID which means no HTLCs will be read out.
var invSetID invpkg.SetID
if refSetID != nil {
invSetID = *refSetID
}
invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID)
if err != nil {
return nil, err
}
func (d *DB) updateInvoice(hash *lntypes.Hash, invoices,
settleIndex, setIDIndex kvdb.RwBucket, invoice *invpkg.Invoice,
invoiceNum []byte, callback invpkg.InvoiceUpdateCallback) (
*invpkg.Invoice, error) {
// Create deep copy to prevent any accidental modification in the
// callback.
invoiceCopy, err := invpkg.CopyInvoice(&invoice)
invoiceCopy, err := invpkg.CopyInvoice(invoice)
if err != nil {
return nil, err
}
@@ -1899,33 +1902,33 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *invpkg.SetID, invoices,
// Call the callback and obtain the update descriptor.
update, err := callback(invoiceCopy)
if err != nil {
return &invoice, err
return invoice, err
}
// If there is nothing to update, return early.
if update == nil {
return &invoice, nil
return invoice, nil
}
switch update.UpdateType {
case invpkg.CancelHTLCsUpdate:
return d.cancelHTLCs(invoices, invoiceNum, &invoice, update)
return d.cancelHTLCs(invoices, invoiceNum, invoice, update)
case invpkg.AddHTLCsUpdate:
return d.addHTLCs(
invoices, settleIndex, setIDIndex, invoiceNum, &invoice,
invoices, settleIndex, setIDIndex, invoiceNum, invoice,
hash, update,
)
case invpkg.SettleHodlInvoiceUpdate:
return d.settleHodlInvoice(
invoices, settleIndex, invoiceNum, &invoice, hash,
invoices, settleIndex, invoiceNum, invoice, hash,
update.State,
)
case invpkg.CancelInvoiceUpdate:
return d.cancelInvoice(
invoices, invoiceNum, &invoice, hash, update.State,
invoices, invoiceNum, invoice, hash, update.State,
)
default: