channeldb: modify SettleInvoice to also return the invoice being settled

This commit is contained in:
Olaoluwa Osuntokun
2018-06-28 21:43:18 -07:00
parent 2c08a22ed3
commit e5c579120e
2 changed files with 29 additions and 12 deletions

View File

@@ -99,7 +99,7 @@ func TestInvoiceWorkflow(t *testing.T) {
// now have the settled bit toggle to true and a non-default // now have the settled bit toggle to true and a non-default
// SettledDate // SettledDate
payAmt := fakeInvoice.Terms.Value * 2 payAmt := fakeInvoice.Terms.Value * 2
if err := db.SettleInvoice(paymentHash, payAmt); err != nil { if _, err := db.SettleInvoice(paymentHash, payAmt); err != nil {
t.Fatalf("unable to settle invoice: %v", err) t.Fatalf("unable to settle invoice: %v", err)
} }
dbInvoice2, err := db.LookupInvoice(paymentHash) dbInvoice2, err := db.LookupInvoice(paymentHash)
@@ -260,7 +260,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
invoice.Terms.PaymentPreimage[:], invoice.Terms.PaymentPreimage[:],
) )
err := db.SettleInvoice(paymentHash, 0) _, err := db.SettleInvoice(paymentHash, 0)
if err != nil { if err != nil {
t.Fatalf("unable to settle invoice: %v", err) t.Fatalf("unable to settle invoice: %v", err)
} }

View File

@@ -393,9 +393,11 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) {
// payment hash as fully settled. If an invoice matching the passed payment // payment hash as fully settled. If an invoice matching the passed payment
// hash doesn't existing within the database, then the action will fail with a // hash doesn't existing within the database, then the action will fail with a
// "not found" error. // "not found" error.
func (d *DB) SettleInvoice(paymentHash [32]byte, amtPaid lnwire.MilliSatoshi) error { func (d *DB) SettleInvoice(paymentHash [32]byte,
amtPaid lnwire.MilliSatoshi) (*Invoice, error) {
return d.Update(func(tx *bolt.Tx) error { var settledInvoice *Invoice
err := d.Update(func(tx *bolt.Tx) error {
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
if err != nil { if err != nil {
return err return err
@@ -420,10 +422,21 @@ func (d *DB) SettleInvoice(paymentHash [32]byte, amtPaid lnwire.MilliSatoshi) er
return ErrInvoiceNotFound return ErrInvoiceNotFound
} }
return settleInvoice( invoice, err := settleInvoice(
invoices, settleIndex, invoiceNum, amtPaid, invoices, settleIndex, invoiceNum, amtPaid,
) )
if err != nil {
return err
}
settledInvoice = invoice
return nil
}) })
if err != nil {
return nil, err
}
return settledInvoice, nil
} }
// InvoicesSettledSince can be used by callers to catch up any settled invoices // InvoicesSettledSince can be used by callers to catch up any settled invoices
@@ -670,17 +683,17 @@ func deserializeInvoice(r io.Reader) (Invoice, error) {
} }
func settleInvoice(invoices, settleIndex *bolt.Bucket, invoiceNum []byte, func settleInvoice(invoices, settleIndex *bolt.Bucket, invoiceNum []byte,
amtPaid lnwire.MilliSatoshi) error { amtPaid lnwire.MilliSatoshi) (*Invoice, error) {
invoice, err := fetchInvoice(invoiceNum, invoices) invoice, err := fetchInvoice(invoiceNum, invoices)
if err != nil { if err != nil {
return err return nil, err
} }
// Add idempotency to duplicate settles, return here to avoid // Add idempotency to duplicate settles, return here to avoid
// overwriting the previous info. // overwriting the previous info.
if invoice.Terms.Settled { if invoice.Terms.Settled {
return nil return nil, nil
} }
// Now that we know the invoice hasn't already been settled, we'll // Now that we know the invoice hasn't already been settled, we'll
@@ -688,13 +701,13 @@ func settleInvoice(invoices, settleIndex *bolt.Bucket, invoiceNum []byte,
// proper location within our time series. // proper location within our time series.
nextSettleSeqNo, err := settleIndex.NextSequence() nextSettleSeqNo, err := settleIndex.NextSequence()
if err != nil { if err != nil {
return err return nil, err
} }
var seqNoBytes [8]byte var seqNoBytes [8]byte
byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo) byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
if err := settleIndex.Put(seqNoBytes[:], invoiceNum); err != nil { if err := settleIndex.Put(seqNoBytes[:], invoiceNum); err != nil {
return err return nil, err
} }
invoice.AmtPaid = amtPaid invoice.AmtPaid = amtPaid
@@ -704,8 +717,12 @@ func settleInvoice(invoices, settleIndex *bolt.Bucket, invoiceNum []byte,
var buf bytes.Buffer var buf bytes.Buffer
if err := serializeInvoice(&buf, &invoice); err != nil { if err := serializeInvoice(&buf, &invoice); err != nil {
return nil return nil, err
} }
return invoices.Put(invoiceNum[:], buf.Bytes()) if err := invoices.Put(invoiceNum[:], buf.Bytes()); err != nil {
return nil, err
}
return &invoice, nil
} }