diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 9ffc6aa27..9793b2c85 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -1939,8 +1939,8 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *invpkg.SetID, invoices, // NOTE: cancelHTLCs updates will only use the `CancelHtlcs` field in the // InvoiceUpdateDesc. func (d *DB) cancelHTLCs(invoices kvdb.RwBucket, invoiceNum []byte, - invoice *invpkg.Invoice, - update *invpkg.InvoiceUpdateDesc) (*invpkg.Invoice, error) { + invoice *invpkg.Invoice, update *invpkg.InvoiceUpdateDesc) ( + *invpkg.Invoice, error) { timestamp := d.clock.Now() @@ -1971,9 +1971,7 @@ func (d *DB) cancelHTLCs(invoices kvdb.RwBucket, invoiceNum []byte, // Tally this into the set of HTLCs that need to be updated on // disk, but once again, only if this is an AMP invoice. if invoice.IsAMP() { - cancelHtlcsAmp( - invoice, htlcsAmpUpdate, htlc, key, - ) + cancelHtlcsAmp(invoice, htlcsAmpUpdate, htlc, key) } } @@ -1983,22 +1981,41 @@ func (d *DB) cancelHTLCs(invoices kvdb.RwBucket, invoiceNum []byte, return nil, errors.New("cancel action on non-existent htlc(s)") } - err := d.serializeAndStoreInvoice(invoices, invoiceNum, invoice) + err := d.cancelHTLCsStoreUpdate( + invoices, invoiceNum, invoice, htlcsAmpUpdate, + ) if err != nil { return nil, err } - // If this is an AMP invoice, then we'll actually store the rest of the - // HTLCs in-line with the invoice, using the invoice ID as a prefix, - // and the AMP key as a suffix: invoiceNum || setID. + return invoice, nil +} + +// cancelHTLCsStoreUpdate is a helper function used to store the invoice and +// AMP state after canceling HTLCs. +func (d *DB) cancelHTLCsStoreUpdate(invoices kvdb.RwBucket, invoiceNum []byte, + invoice *invpkg.Invoice, + htlcsAmpUpdate map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC) error { + + err := d.serializeAndStoreInvoice(invoices, invoiceNum, invoice) + if err != nil { + return err + } + + // If this is an AMP invoice, then we'll actually store the rest + // of the HTLCs in-line with the invoice, using the invoice ID + // as a prefix, and the AMP key as a suffix: invoiceNum || + // setID. if invoice.IsAMP() { - err := updateAMPInvoices(invoices, invoiceNum, htlcsAmpUpdate) + err := updateAMPInvoices( + invoices, invoiceNum, htlcsAmpUpdate, + ) if err != nil { - return nil, err + return err } } - return invoice, nil + return nil } // serializeAndStoreInvoice is a helper function used to store invoices. @@ -2039,29 +2056,6 @@ func (d *DB) addHTLCs(invoices, settleIndex, //nolint:funlen return nil, errors.New("nil custom records map") } - if invoiceIsAMP { - if htlcUpdate.AMP == nil { - return nil, fmt.Errorf("unable to add htlc "+ - "without AMP data to AMP invoice(%v)", - invoice.AddIndex) - } - - // Check if this SetID already exist. - htlcSetID := htlcUpdate.AMP.Record.SetID() - setIDInvNum := setIDIndex.Get(htlcSetID[:]) - - if setIDInvNum == nil { - err := setIDIndex.Put(htlcSetID[:], invoiceNum) - if err != nil { - return nil, err - } - } else if !bytes.Equal(setIDInvNum, invoiceNum) { - return nil, invpkg.ErrDuplicateSetID{ - SetID: htlcSetID, - } - } - } - htlc := &invpkg.InvoiceHTLC{ Amt: htlcUpdate.Amt, MppTotalAmt: htlcUpdate.MppTotalAmt, @@ -2070,7 +2064,16 @@ func (d *DB) addHTLCs(invoices, settleIndex, //nolint:funlen AcceptTime: timestamp, State: invpkg.HtlcStateAccepted, CustomRecords: htlcUpdate.CustomRecords, - AMP: htlcUpdate.AMP.Copy(), + } + + if invoiceIsAMP { + if htlcUpdate.AMP == nil { + return nil, fmt.Errorf("unable to add htlc "+ + "without AMP data to AMP invoice(%v)", + invoice.AddIndex) + } + + htlc.AMP = htlcUpdate.AMP.Copy() } invoice.Htlcs[key] = htlc @@ -2099,28 +2102,13 @@ func (d *DB) addHTLCs(invoices, settleIndex, //nolint:funlen } // If this isn't an AMP invoice, then we'll go ahead and update - // the invoice state directly here. For AMP invoices, we - // instead will keep the top-level invoice open, and instead - // update the state of each _htlc set_ instead. However, we'll - // allow the invoice to transition to the cancelled state - // regardless. + // the invoice state directly here. For AMP invoices, we instead + // will keep the top-level invoice open, and update the state of + // each _htlc set_ instead. However, we'll allow the invoice to + // transition to the cancelled state regardless. if !invoiceIsAMP || *newState == invpkg.ContractCanceled { invoice.State = *newState } - - // If this is a non-AMP invoice, then the state can eventually - // go to ContractSettled, so we pass in nil value as part of - // setSettleMetaFields. - isSettled := update.State.NewState == invpkg.ContractSettled - if !invoiceIsAMP && isSettled { - err := setSettleMetaFields( - settleIndex, invoiceNum, invoice, timestamp, - nil, - ) - if err != nil { - return nil, err - } - } } // The set of HTLC pre-images will only be set if we were actually able @@ -2227,6 +2215,55 @@ func (d *DB) addHTLCs(invoices, settleIndex, //nolint:funlen invoice.AmtPaid += amtPaid } + err := d.addHTLCsStoreUpdate( + invoices, settleIndex, setIDIndex, invoiceNum, invoice, + settledSetIDs, htlcsAmpUpdate, timestamp, + ) + if err != nil { + return nil, err + } + + return invoice, nil +} + +// addHTLCsStoreUpdate is a helper function used to store the invoice and +// AMP state after adding HTLCs. +func (d *DB) addHTLCsStoreUpdate(invoices, settleIndex, setIDIndex kvdb.RwBucket, + invoiceNum []byte, invoice *invpkg.Invoice, + settledSetIDs map[invpkg.SetID]struct{}, + htlcsAmpUpdate map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC, + timestamp time.Time) error { + + invoiceIsAMP := invoice.IsAMP() + + for htlcSetID := range htlcsAmpUpdate { + // Check if this SetID already exist. + setIDInvNum := setIDIndex.Get(htlcSetID[:]) + + if setIDInvNum == nil { + err := setIDIndex.Put(htlcSetID[:], invoiceNum) + if err != nil { + return err + } + } else if !bytes.Equal(setIDInvNum, invoiceNum) { + return invpkg.ErrDuplicateSetID{ + SetID: htlcSetID, + } + } + } + + // If this is a non-AMP invoice, then the state can eventually go to + // ContractSettled, so we pass in nil value as part of + // setSettleMetaFields. + if !invoiceIsAMP && invoice.State == invpkg.ContractSettled { + err := setSettleMetaFields( + settleIndex, invoiceNum, invoice, timestamp, nil, + ) + if err != nil { + return err + } + } + // As we don't update the settle index above for AMP invoices, we'll do // it here for each sub-AMP invoice that was settled. for settledSetID := range settledSetIDs { @@ -2236,13 +2273,13 @@ func (d *DB) addHTLCs(invoices, settleIndex, //nolint:funlen &settledSetID, ) if err != nil { - return nil, err + return err } } err := d.serializeAndStoreInvoice(invoices, invoiceNum, invoice) if err != nil { - return nil, err + return err } // If this is an AMP invoice, then we'll actually store the rest of the @@ -2251,11 +2288,11 @@ func (d *DB) addHTLCs(invoices, settleIndex, //nolint:funlen if invoiceIsAMP { err := updateAMPInvoices(invoices, invoiceNum, htlcsAmpUpdate) if err != nil { - return nil, err + return err } } - return invoice, nil + return nil } // settleHodlInvoice marks a hodl invoice as settled. @@ -2299,13 +2336,6 @@ func (d *DB) settleHodlInvoice(invoices, settleIndex kvdb.RwBucket, invoice.State = invpkg.ContractSettled timestamp := d.clock.Now() - err = setSettleMetaFields( - settleIndex, invoiceNum, invoice, timestamp, nil, - ) - if err != nil { - return nil, err - } - // TODO(positiveblue): this logic can be further simplified. var amtPaid lnwire.MilliSatoshi for _, htlc := range invoice.Htlcs { @@ -2323,7 +2353,9 @@ func (d *DB) settleHodlInvoice(invoices, settleIndex kvdb.RwBucket, invoice.AmtPaid = amtPaid - err = d.serializeAndStoreInvoice(invoices, invoiceNum, invoice) + err = d.settleHodlInvoiceStoreUpdate( + invoices, settleIndex, invoiceNum, invoice, timestamp, + ) if err != nil { return nil, err } @@ -2331,6 +2363,26 @@ func (d *DB) settleHodlInvoice(invoices, settleIndex kvdb.RwBucket, return invoice, nil } +// settleHodlInvoiceStoreUpdate is a helper function used to store the settled +// hodl invoice update. +func (d *DB) settleHodlInvoiceStoreUpdate(invoices, settleIndex kvdb.RwBucket, + invoiceNum []byte, invoice *invpkg.Invoice, timestamp time.Time) error { + + err := setSettleMetaFields( + settleIndex, invoiceNum, invoice, timestamp, nil, + ) + if err != nil { + return err + } + + err = d.serializeAndStoreInvoice(invoices, invoiceNum, invoice) + if err != nil { + return err + } + + return nil +} + // cancelInvoice attempts to cancel the given invoice. That includes changing // the invoice state and the state of any relevant HTLC. func (d *DB) cancelInvoice(invoices kvdb.RwBucket, invoiceNum []byte, @@ -2380,7 +2432,7 @@ func (d *DB) cancelInvoice(invoices kvdb.RwBucket, invoiceNum []byte, } } - err = d.serializeAndStoreInvoice(invoices, invoiceNum, invoice) + err = d.cancelInvoiceStoreUpdate(invoices, invoiceNum, invoice) if err != nil { return nil, err } @@ -2388,6 +2440,14 @@ func (d *DB) cancelInvoice(invoices kvdb.RwBucket, invoiceNum []byte, return invoice, nil } +// cancelInvoiceStoreUpdate is a helper function used to store the canceled +// invoice update. +func (d *DB) cancelInvoiceStoreUpdate(invoices kvdb.RwBucket, invoiceNum []byte, + invoice *invpkg.Invoice) error { + + return d.serializeAndStoreInvoice(invoices, invoiceNum, invoice) +} + // updateInvoiceState validates and processes an invoice state update. The new // state to transition to is returned, so the caller is able to select exactly // how the invoice state is updated.