diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 14084b55c..b4ce3ba39 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -907,10 +907,7 @@ func serializeInvoice(w io.Writer, i *invpkg.Invoice) error { // Only if this is a _non_ AMP invoice do we serialize the HTLCs // in-line with the rest of the invoice. - ampInvoice := i.Terms.Features.HasFeature( - lnwire.AMPOptional, - ) - if ampInvoice { + if i.IsAMP() { return nil } @@ -1148,40 +1145,51 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, return invpkg.Invoice{}, err } - // If this is an AMP invoice, then we'll also attempt to read out the - // set of HTLCs that were paid to prior set IDs. However, we'll only do - // this is the invoice didn't already have HTLCs stored in-line. - invoiceIsAMP := invoice.Terms.Features.HasFeature( - lnwire.AMPOptional, - ) - switch { - case !invoiceIsAMP: - return invoice, nil - - // For AMP invoice that already have HTLCs populated (created before - // recurring invoices), then we don't need to read from the prefix - // keyed section of the bucket. - case invoiceIsAMP && len(invoice.Htlcs) != 0: - return invoice, nil - - // If the "zero" setID was specified, then this means that no HTLC data - // should be returned alongside of it. - case invoiceIsAMP && len(setIDs) != 0 && setIDs[0] != nil && - *setIDs[0] == invpkg.BlankPayAddr: - + // If this is an AMP invoice we'll also attempt to read out the set of + // HTLCs that were paid to prior set IDs, if needed. + if !invoice.IsAMP() { return invoice, nil } - invoice.Htlcs, err = fetchAmpSubInvoices( - invoices, invoiceNum, setIDs..., - ) - if err != nil { - return invoice, nil + if shouldFetchAMPHTLCs(invoice, setIDs) { + invoice.Htlcs, err = fetchAmpSubInvoices( + invoices, invoiceNum, setIDs..., + ) + // TODO(positiveblue): we should fail when we are not able to + // fetch all the HTLCs for an AMP invoice. Multiple tests in + // the invoice and channeldb package break if we return this + // error. We need to update them when we migrate this logic to + // the sql implementation. + if err != nil { + log.Errorf("unable to fetch amp htlcs for inv "+ + "%v and setIDs %v: %w", invoiceNum, setIDs, err) + } } return invoice, nil } +// shouldFetchAMPHTLCs returns true if we need to fetch the set of HTLCs that +// were paid to the relevant set IDs. +func shouldFetchAMPHTLCs(invoice invpkg.Invoice, setIDs []*invpkg.SetID) bool { + // For AMP invoice that already have HTLCs populated (created before + // recurring invoices), then we don't need to read from the prefix + // keyed section of the bucket. + if len(invoice.Htlcs) != 0 { + return false + } + + // If the "zero" setID was specified, then this means that no HTLC data + // should be returned alongside of it. + if len(setIDs) != 0 && setIDs[0] != nil && + *setIDs[0] == invpkg.BlankPayAddr { + + return false + } + + return true +} + // fetchInvoiceStateAMP retrieves the state of all the relevant sub-invoice for // an AMP invoice. This methods only decode the relevant state vs the entire // invoice. @@ -1905,9 +1913,7 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *invpkg.SetID, invoices, now := d.clock.Now() - invoiceIsAMP := invoiceCopy.Terms.Features.HasFeature( - lnwire.AMPOptional, - ) + invoiceIsAMP := invoiceCopy.IsAMP() // Process add actions from update descriptor. htlcsAmpUpdate := make(map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC) //nolint:lll diff --git a/invoices/invoices.go b/invoices/invoices.go index d239897fe..7a99a46c6 100644 --- a/invoices/invoices.go +++ b/invoices/invoices.go @@ -471,6 +471,24 @@ func (i *Invoice) HTLCSetCompliment(setID *[32]byte, return htlcSet } +// IsKeysend returns true if the invoice is a Keysend invoice. +func (i *Invoice) IsKeysend() bool { + // TODO(positiveblue): look for a more reliable way to tests if + // an invoice is keysend. + return len(i.PaymentRequest) == 0 && !i.IsAMP() +} + +// IsAMP returns true if the invoice is an AMP invoice. +func (i *Invoice) IsAMP() bool { + if i.Terms.Features == nil { + return false + } + + return i.Terms.Features.HasFeature( + lnwire.AMPRequired, + ) +} + // HtlcState defines the states an htlc paying to an invoice can be in. type HtlcState uint8 @@ -681,6 +699,8 @@ type InvoiceStateUpdateDesc struct { // invoice. type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error) +// ValidateInvoice assures the invoice passes the checks for all the relevant +// constraints. func ValidateInvoice(i *Invoice, paymentHash lntypes.Hash) error { // Avoid conflicts with all-zeroes magic value in the database. if paymentHash == UnknownPreimage.Hash() { @@ -705,13 +725,8 @@ func ValidateInvoice(i *Invoice, paymentHash lntypes.Hash) error { return err } - // AMP invoices and hodl invoices are allowed to have no preimage - // specified. - isAMP := i.Terms.Features.HasFeature( - lnwire.AMPOptional, - ) - if i.Terms.PaymentPreimage == nil && !(i.HodlInvoice || isAMP) { - return errors.New("non-hodl invoices must have a preimage") + if i.requiresPreimage() && i.Terms.PaymentPreimage == nil { + return errors.New("this invoice must have a preimage") } if len(i.Htlcs) > 0 { @@ -721,6 +736,17 @@ func ValidateInvoice(i *Invoice, paymentHash lntypes.Hash) error { return nil } +// requiresPreimage returns true if the invoice requires a preimage to be valid. +func (i *Invoice) requiresPreimage() bool { + // AMP invoices and hodl invoices are allowed to have no preimage + // specified. + if i.HodlInvoice || i.IsAMP() { + return false + } + + return true +} + // IsPending returns true if the invoice is in ContractOpen state. func (i *Invoice) IsPending() bool { return i.State == ContractOpen || i.State == ContractAccepted diff --git a/lnrpc/invoicesrpc/utils.go b/lnrpc/invoicesrpc/utils.go index d610f9255..70d774082 100644 --- a/lnrpc/invoicesrpc/utils.go +++ b/lnrpc/invoicesrpc/utils.go @@ -150,8 +150,6 @@ func CreateRPCInvoice(invoice *invoices.Invoice, rpcHtlcs = append(rpcHtlcs, &rpcHtlc) } - isAmp := invoice.Terms.Features.HasFeature(lnwire.AMPOptional) - rpcInvoice := &lnrpc.Invoice{ Memo: string(invoice.Memo), RHash: rHash, @@ -175,9 +173,9 @@ func CreateRPCInvoice(invoice *invoices.Invoice, State: state, Htlcs: rpcHtlcs, Features: CreateRPCFeatures(invoice.Terms.Features), - IsKeysend: len(invoice.PaymentRequest) == 0 && !isAmp, + IsKeysend: invoice.IsKeysend(), PaymentAddr: invoice.Terms.PaymentAddr[:], - IsAmp: isAmp, + IsAmp: invoice.IsAMP(), } rpcInvoice.AmpInvoiceState = make(map[string]*lnrpc.AMPInvoiceState)