diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 580fe3d3f..7ea19afd4 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -123,9 +123,7 @@ func TestInvoiceWorkflow(t *testing.T) { // now have the settled bit toggle to true and a non-default // SettledDate payAmt := fakeInvoice.Terms.Value * 2 - _, err = db.AcceptOrSettleInvoice( - paymentHash, payAmt, checkHtlcParameters, - ) + _, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt)) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } @@ -288,8 +286,8 @@ func TestInvoiceAddTimeSeries(t *testing.T) { paymentHash := invoice.Terms.PaymentPreimage.Hash() - _, err := db.AcceptOrSettleInvoice( - paymentHash, 0, checkHtlcParameters, + _, err := db.UpdateInvoice( + paymentHash, getUpdateInvoice(0), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) @@ -371,8 +369,8 @@ func TestDuplicateSettleInvoice(t *testing.T) { } // With the invoice in the DB, we'll now attempt to settle the invoice. - dbInvoice, err := db.AcceptOrSettleInvoice( - payHash, amt, checkHtlcParameters, + dbInvoice, err := db.UpdateInvoice( + payHash, getUpdateInvoice(amt), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) @@ -393,8 +391,8 @@ func TestDuplicateSettleInvoice(t *testing.T) { // If we try to settle the invoice again, then we should get the very // same invoice back, but with an error this time. - dbInvoice, err = db.AcceptOrSettleInvoice( - payHash, amt, checkHtlcParameters, + dbInvoice, err = db.UpdateInvoice( + payHash, getUpdateInvoice(amt), ) if err != ErrInvoiceAlreadySettled { t.Fatalf("expected ErrInvoiceAlreadySettled") @@ -440,8 +438,8 @@ func TestQueryInvoices(t *testing.T) { // We'll only settle half of all invoices created. if i%2 == 0 { - _, err := db.AcceptOrSettleInvoice( - paymentHash, i, checkHtlcParameters, + _, err := db.UpdateInvoice( + paymentHash, getUpdateInvoice(i), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) @@ -685,10 +683,19 @@ func TestQueryInvoices(t *testing.T) { } } -func checkHtlcParameters(invoice *Invoice) error { - if invoice.Terms.State == ContractSettled { - return ErrInvoiceAlreadySettled - } +// getUpdateInvoice returns an invoice update callback that, when called, +// settles the invoice with the given amount. +func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback { + return func(invoice *Invoice) (*InvoiceUpdateDesc, error) { + if invoice.Terms.State == ContractSettled { + return nil, ErrInvoiceAlreadySettled + } - return nil + update := &InvoiceUpdateDesc{ + State: ContractSettled, + AmtPaid: amt, + } + + return update, nil + } } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 18dbf75fa..195d82ffe 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -276,6 +276,20 @@ type InvoiceHTLC struct { State HtlcState } +// InvoiceUpdateDesc describes the changes that should be applied to the +// invoice. +type InvoiceUpdateDesc struct { + // State is the new state that this invoice should progress to. + State ContractState + + // AmtPaid is the updated amount that has been paid to this invoice. + AmtPaid lnwire.MilliSatoshi +} + +// InvoiceUpdateCallback is a callback used in the db transaction to update the +// invoice. +type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error) + func validateInvoice(i *Invoice) error { if len(i.Memo) > MaxMemoSize { return fmt.Errorf("max length a memo is %v, and invoice "+ @@ -689,21 +703,17 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { return resp, nil } -// AcceptOrSettleInvoice attempts to mark an invoice corresponding to the passed -// payment hash as settled. If an invoice matching the passed payment hash -// doesn't existing within the database, then the action will fail with a "not -// found" error. +// UpdateInvoice attempts to update an invoice corresponding to the passed +// payment hash. If an invoice matching the passed payment hash doesn't exist +// within the database, then the action will fail with a "not found" error. // -// When the preimage for the invoice is unknown (hold invoice), the invoice is -// marked as accepted. -// -// TODO: Store invoice cltv as separate field in database so that it doesn't -// need to be decoded from the payment request. -func (d *DB) AcceptOrSettleInvoice(paymentHash [32]byte, - amtPaid lnwire.MilliSatoshi, - checkHtlcParameters func(invoice *Invoice) error) (*Invoice, error) { +// The update is performed inside the same database transaction that fetches the +// invoice and is therefore atomic. The fields to update are controlled by the +// supplied callback. +func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, + callback InvoiceUpdateCallback) (*Invoice, error) { - var settledInvoice *Invoice + var updatedInvoice *Invoice err := d.Update(func(tx *bbolt.Tx) error { invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) if err != nil { @@ -729,15 +739,14 @@ func (d *DB) AcceptOrSettleInvoice(paymentHash [32]byte, return ErrInvoiceNotFound } - settledInvoice, err = acceptOrSettleInvoice( - invoices, settleIndex, invoiceNum, amtPaid, - checkHtlcParameters, + updatedInvoice, err = updateInvoice( + invoices, settleIndex, invoiceNum, callback, ) return err }) - return settledInvoice, err + return updatedInvoice, err } // SettleHoldInvoice sets the preimage of a hodl invoice and marks the invoice @@ -1200,35 +1209,75 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) { return htlcs, nil } -func acceptOrSettleInvoice(invoices, settleIndex *bbolt.Bucket, - invoiceNum []byte, amtPaid lnwire.MilliSatoshi, - checkHtlcParameters func(invoice *Invoice) error) ( - *Invoice, error) { +// copySlice allocates a new slice and copies the source into it. +func copySlice(src []byte) []byte { + dest := make([]byte, len(src)) + copy(dest, src) + return dest +} + +// copyInvoice makes a deep copy of the supplied invoice. +func copyInvoice(src *Invoice) *Invoice { + dest := Invoice{ + Memo: copySlice(src.Memo), + Receipt: copySlice(src.Receipt), + PaymentRequest: copySlice(src.PaymentRequest), + FinalCltvDelta: src.FinalCltvDelta, + CreationDate: src.CreationDate, + SettleDate: src.SettleDate, + Terms: src.Terms, + AddIndex: src.AddIndex, + SettleIndex: src.SettleIndex, + AmtPaid: src.AmtPaid, + Htlcs: make( + map[CircuitKey]*InvoiceHTLC, len(src.Htlcs), + ), + } + + for k, v := range src.Htlcs { + dest.Htlcs[k] = v + } + + return &dest +} + +// updateInvoice fetches the invoice, obtains the update descriptor from the +// callback and applies the updates in a single db transaction. +func updateInvoice(invoices, settleIndex *bbolt.Bucket, invoiceNum []byte, + callback InvoiceUpdateCallback) (*Invoice, error) { invoice, err := fetchInvoice(invoiceNum, invoices) if err != nil { return nil, err } - // If the invoice is still open, check the htlc parameters. - if err := checkHtlcParameters(&invoice); err != nil { + preUpdateState := invoice.Terms.State + + // Create deep copy to prevent any accidental modification in the + // callback. + copy := copyInvoice(&invoice) + + // Call the callback and obtain the update descriptor. + update, err := callback(copy) + if err != nil { return &invoice, err } - // Check to see if we can settle or this is an hold invoice and we need - // to wait for the preimage. - holdInvoice := invoice.Terms.PaymentPreimage == UnknownPreimage - if holdInvoice { - invoice.Terms.State = ContractAccepted - } else { + // Update invoice state and amount. + invoice.Terms.State = update.State + invoice.AmtPaid = update.AmtPaid + + // If invoice moved to the settled state, update settle index and settle + // time. + if preUpdateState != invoice.Terms.State && + invoice.Terms.State == ContractSettled { + err := setSettleFields(settleIndex, invoiceNum, &invoice) if err != nil { return nil, err } } - invoice.AmtPaid = amtPaid - var buf bytes.Buffer if err := serializeInvoice(&buf, &invoice); err != nil { return nil, err diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index cfe8cb60b..d2729d154 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -409,22 +409,23 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice, return i.cdb.LookupInvoice(rHash) } -// checkHtlcParameters is a callback used inside invoice db transactions to +// updateInvoice is a callback used inside invoice db transactions to // atomically check-and-update an invoice. -func (i *InvoiceRegistry) checkHtlcParameters(invoice *channeldb.Invoice, - amtPaid lnwire.MilliSatoshi, htlcExpiry uint32, currentHeight int32) error { +func (i *InvoiceRegistry) updateInvoice(invoice *channeldb.Invoice, + amtPaid lnwire.MilliSatoshi, htlcExpiry uint32, currentHeight int32) ( + *channeldb.InvoiceUpdateDesc, error) { // If the invoice is already canceled, there is no further checking to // do. if invoice.Terms.State == channeldb.ContractCanceled { - return channeldb.ErrInvoiceAlreadyCanceled + return nil, channeldb.ErrInvoiceAlreadyCanceled } // If an invoice amount is specified, check that enough is paid. Also // check this for duplicate payments if the invoice is already settled // or accepted. if invoice.Terms.Value > 0 && amtPaid < invoice.Terms.Value { - return ErrInvoiceAmountTooLow + return nil, ErrInvoiceAmountTooLow } // Return early in case the invoice was already accepted or settled. We @@ -432,20 +433,32 @@ func (i *InvoiceRegistry) checkHtlcParameters(invoice *channeldb.Invoice, // just restarting. switch invoice.Terms.State { case channeldb.ContractAccepted: - return channeldb.ErrInvoiceAlreadyAccepted + return nil, channeldb.ErrInvoiceAlreadyAccepted case channeldb.ContractSettled: - return channeldb.ErrInvoiceAlreadySettled + return nil, channeldb.ErrInvoiceAlreadySettled } if htlcExpiry < uint32(currentHeight+i.finalCltvRejectDelta) { - return ErrInvoiceExpiryTooSoon + return nil, ErrInvoiceExpiryTooSoon } if htlcExpiry < uint32(currentHeight+invoice.FinalCltvDelta) { - return ErrInvoiceExpiryTooSoon + return nil, ErrInvoiceExpiryTooSoon } - return nil + update := channeldb.InvoiceUpdateDesc{ + AmtPaid: amtPaid, + } + + // Check to see if we can settle or this is an hold invoice and we need + // to wait for the preimage. + holdInvoice := invoice.Terms.PaymentPreimage == channeldb.UnknownPreimage + if holdInvoice { + update.State = channeldb.ContractAccepted + } else { + update.State = channeldb.ContractSettled + } + return &update, nil } // NotifyExitHopHtlc attempts to mark an invoice as settled. If the invoice is a @@ -474,10 +487,12 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, // If this isn't a debug invoice, then we'll attempt to settle an // invoice matching this rHash on disk (if one exists). - invoice, err := i.cdb.AcceptOrSettleInvoice( - rHash, amtPaid, - func(inv *channeldb.Invoice) error { - return i.checkHtlcParameters( + invoice, err := i.cdb.UpdateInvoice( + rHash, + func(inv *channeldb.Invoice) (*channeldb.InvoiceUpdateDesc, + error) { + + return i.updateInvoice( inv, amtPaid, expiry, currentHeight, ) },