diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 2b26d5800..41ba4362d 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -1439,6 +1439,7 @@ func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket, if err != nil { return 0, err } + // Add the invoice to the payment address index, but only if the invoice // has a non-zero payment address. The all-zero payment address is still // in use by legacy keysend, so we special-case here to avoid @@ -1574,6 +1575,15 @@ func serializeInvoice(w io.Writer, i *Invoice) error { return err } + // Only if this is a _non_ AMP invoice do we serialize the HTLCs + // in-line with the rest of the invoice. + ampInvoice := i.Terms.Features.HasFeature( + lnwire.AMPOptional, + ) + if ampInvoice { + return nil + } + return serializeHtlcs(w, i.Htlcs) } @@ -1679,7 +1689,114 @@ func getNanoTime(ns uint64) time.Time { return time.Unix(0, int64(ns)) } -func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket) (Invoice, error) { +// fetchFilteredAmpInvoices retrieves only a select set of AMP invoices +// identified by the setID value. +func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, + invoiceNum []byte, setIDs ...*SetID) (map[CircuitKey]*InvoiceHTLC, error) { + + htlcs := make(map[CircuitKey]*InvoiceHTLC) + for _, setID := range setIDs { + invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:]) + + htlcSetBytes := invoiceBucket.Get(invoiceSetIDKey[:]) + if htlcSetBytes == nil { + // A set ID was passed in, but we don't have this + // stored yet, meaning that the setID is being added + // for the frist time. + return htlcs, ErrInvoiceNotFound + } + + htlcSetReader := bytes.NewReader(htlcSetBytes) + htlcsBySetID, err := deserializeHtlcs(htlcSetReader) + if err != nil { + return nil, err + } + + for key, htlc := range htlcsBySetID { + htlcs[key] = htlc + } + } + + return htlcs, nil +} + +// forEachAMPInvoice is a helper function that attempts to iterate over each of +// the HTLC sets (based on their set ID) for the given AMP invoice identified +// by its invoiceNum. The callback closure is called for each key within the +// prefix range. +func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte, + callback func(key, htlcSet []byte) error) error { + + invoiceCursor := invoiceBucket.ReadCursor() + + // Seek to the first key that includes the invoice data itself. + invoiceCursor.Seek(invoiceNum) + + // Advance to the very first key _after_ the invoice data, as this is + // where we'll encounter our first HTLC (if any are present). + cursorKey, htlcSet := invoiceCursor.Next() + + // If at this point, the cursor key doesn't match the invoice num + // prefix, then we know that this HTLC doesn't have any set ID HTLCs + // associated with it. + if !bytes.HasPrefix(cursorKey, invoiceNum) { + return nil + } + + // Otherwise continue to iterate until we no longer match the prefix, + // executing the call back at each step. + for ; cursorKey != nil && bytes.HasPrefix(cursorKey, invoiceNum); cursorKey, htlcSet = invoiceCursor.Next() { + err := callback(cursorKey, htlcSet) + if err != nil { + return err + } + } + + return nil +} + +// fetchAmpSubInvoices attempts to use the invoiceNum as a prefix within the +// AMP bucket to find all the individual HTLCs (by setID) associated with a +// given invoice. If a list of set IDs are specified, then only HTLCs +// associated with that setID will be retrieved. +func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, + invoiceNum []byte, setIDs ...*SetID) (map[CircuitKey]*InvoiceHTLC, error) { + + // If a set of setIDs was specified, then we can skip the cursor and + // just read out exactly what we need. + if len(setIDs) != 0 && setIDs[0] != nil { + return fetchFilteredAmpInvoices( + invoiceBucket, invoiceNum, setIDs..., + ) + } + + // Otherwise, iterate over all the htlc sets that are prefixed beside + // this invoice in the main invoice bucket. + htlcs := make(map[CircuitKey]*InvoiceHTLC) + err := forEachAMPInvoice(invoiceBucket, invoiceNum, func(key, htlcSet []byte) error { + htlcSetReader := bytes.NewReader(htlcSet) + htlcsBySetID, err := deserializeHtlcs(htlcSetReader) + if err != nil { + return err + } + + for key, htlc := range htlcsBySetID { + htlcs[key] = htlc + } + + return nil + }) + if err != nil { + return nil, err + } + + return htlcs, nil +} + +// fetchInvoice attempts to read out the relevant state for the invoice as +// specified by the invoice number. If the setID fields are set, then only the +// HTLC information pertaining to those set IDs is returned. +func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, setIDs ...*SetID) (Invoice, error) { invoiceBytes := invoices.Get(invoiceNum) if invoiceBytes == nil { return Invoice{}, ErrInvoiceNotFound @@ -1688,6 +1805,43 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket) (Invoice, error) { invoiceReader := bytes.NewReader(invoiceBytes) return deserializeInvoice(invoiceReader) + invoice, err := deserializeInvoice(invoiceReader) + if err != nil { + return Invoice{}, err + } + + // If this is an AMP invoice, then we'll also attempt to read out the + // set of HTLCs that were paid to prior set IDs. However, we'll only do + // this is the invoice didn't already have HTLCs stored in-line. + invoiceIsAMP := invoice.Terms.Features.HasFeature( + lnwire.AMPOptional, + ) + switch { + case !invoiceIsAMP: + return invoice, nil + + // For AMP invoice that already have HTLCs populated (created before + // recurring invoices), then we don't need to read from the prefix + // keyed section of the bucket. + case invoiceIsAMP && len(invoice.Htlcs) != 0: + return invoice, nil + + // If the "zero" setID was specified, then this means that no HTLC data + // should be returned alongside of it. + case invoiceIsAMP && len(setIDs) != 0 && setIDs[0] != nil && + *setIDs[0] == BlankPayAddr: + + return invoice, nil + } + + invoice.Htlcs, err = fetchAmpSubInvoices( + invoices, invoiceNum, setIDs..., + ) + if err != nil { + return invoice, nil + } + + return invoice, nil } func deserializeInvoice(r io.Reader) (Invoice, error) { @@ -2167,6 +2321,54 @@ func copyInvoice(src *Invoice) *Invoice { return &dest } +// invoiceSetIDKeyLen is the length of the key that's used to store the +// individual HTLCs prefixed by their ID along side the main invoice within the +// invoiceBytes. We use 4 bytes for the invoice number, and 32 bytes for the +// set ID. +const invoiceSetIDKeyLen = 4 + 32 + +// makeInvoiceSetIDKey returns the prefix key, based on the set ID and invoice +// number where the HTLCs for this setID will be stored udner. +func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte { + // Construct the prefix key we need to obtain the invoice information: + // invoiceNum || setID. + var invoiceSetIDKey [invoiceSetIDKeyLen]byte + copy(invoiceSetIDKey[:], invoiceNum) + copy(invoiceSetIDKey[len(invoiceNum):], setID) + + return invoiceSetIDKey +} + +// updateAMPInvoices updates the set of AMP invoices in-place. For AMP, rather +// then continually write the invoices to the end of the invoice value, we +// instead write the invoices into a new key preifx that follows the main +// invoice number. This ensures that we don't need to continually decode a +// potentially massive HTLC set, and also allows us to quickly find the HLTCs +// associated with a particular HTLC set. +func updateAMPInvoices(invoiceBucket kvdb.RwBucket, invoiceNum []byte, + htlcsToUpdate map[SetID]map[CircuitKey]*InvoiceHTLC) error { + + for setID, htlcSet := range htlcsToUpdate { + // First write out the set of HTLCs including all the relevant TLV + // values. + var b bytes.Buffer + if err := serializeHtlcs(&b, htlcSet); err != nil { + return err + } + + // Next store each HTLC in-line, using a prefix based off the + // invoice number. + invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:]) + + err := invoiceBucket.Put(invoiceSetIDKey[:], b.Bytes()) + if err != nil { + return err + } + } + + return nil +} + // updateInvoice fetches the invoice, obtains the update descriptor from the // callback and applies the updates in a single db transaction. func (d *DB) updateInvoice(hash *lntypes.Hash, invoices, @@ -2586,7 +2788,32 @@ func setSettleMetaFields(settleIndex kvdb.RwBucket, invoiceNum []byte, return nil } -// InvoiceDeleteRef holds a refererence to an invoice to be deleted. +// delAMPInvoices attempts to delete all the "sub" invoices associated with a +// greater AMP invoices. We do this by deleting the set of keys that share the +// invoice number as a prefix. +func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error { + // Since it isn't safe to delete using an active cursor, we'll use the + // cursor simply to collect the set of keys we need to delete, _then_ + // delete them in another pass. + var keysToDel [][]byte + err := forEachAMPInvoice(invoiceBucket, invoiceNum, func(cursorKey, v []byte) error { + keysToDel = append(keysToDel, cursorKey) + return nil + }) + if err != nil { + return err + } + + // In this next phase, we'll then delete all the relevant invoices. + for _, keyToDel := range keysToDel { + if err := invoiceBucket.Delete(keyToDel); err != nil { + return err + } + } + + return nil +} + type InvoiceDeleteRef struct { // PayHash is the payment hash of the target invoice. All invoices are // currently indexed by payment hash. @@ -2627,6 +2854,7 @@ func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error { if invoiceAddIndex == nil { return ErrNoInvoicesCreated } + // settleIndex can be nil, as the bucket is created lazily // when the first invoice is settled. settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket) @@ -2708,6 +2936,16 @@ func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error { if err != nil { return err } + + } + + // In addition to deleting the main invoice state, if + // this is an AMP invoice, then we'll also need to + // delete the set HTLC set stored as a key prefix. For + // non-AMP invoices, this'll be a noop. + err = delAMPInvoices(invoiceKey, invoices) + if err != nil { + return err } // Finally remove the serialized invoice from the