diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index f0d667332..fcc959516 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -3146,6 +3146,65 @@ func TestDeleteInvoices(t *testing.T) { assertInvoiceCount(0) } +// TestDeleteCanceledInvoices tests that deleting canceled invoices with the +// specific DeleteCanceledInvoices method works correctly. +func TestDeleteCanceledInvoices(t *testing.T) { + t.Parallel() + + db, err := MakeTestInvoiceDB(t) + require.NoError(t, err, "unable to make test db") + + // Updatefunc is used to cancel an invoice. + updateFunc := func(invoice *invpkg.Invoice) ( + *invpkg.InvoiceUpdateDesc, error) { + + return &invpkg.InvoiceUpdateDesc{ + UpdateType: invpkg.CancelInvoiceUpdate, + State: &invpkg.InvoiceStateUpdateDesc{ + NewState: invpkg.ContractCanceled, + }, + }, nil + } + + // Add some invoices to the test db. + ctxb := context.Background() + var invoices []invpkg.Invoice + for i := 0; i < 10; i++ { + invoice, err := randInvoice(lnwire.MilliSatoshi(i + 1)) + require.NoError(t, err) + + paymentHash := invoice.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(ctxb, invoice, paymentHash) + require.NoError(t, err) + + // Cancel every second invoice. + if i%2 == 0 { + invoice, err = db.UpdateInvoice( + ctxb, invpkg.InvoiceRefByHash(paymentHash), nil, + updateFunc, + ) + require.NoError(t, err) + } else { + invoices = append(invoices, *invoice) + } + } + + // Delete canceled invoices. + require.NoError(t, db.DeleteCanceledInvoices(ctxb)) + + // Query to collect all invoices. + query := invpkg.InvoiceQuery{ + IndexOffset: 0, + NumMaxInvoices: math.MaxUint64, + } + + dbInvoices, err := db.QueryInvoices(ctxb, query) + require.NoError(t, err) + + // Check that we really have the expected invoices. + require.Equal(t, invoices, dbInvoices.Invoices) +} + // TestAddInvoiceInvalidFeatureDeps asserts that inserting an invoice with // invalid transitive feature dependencies fails with the appropriate error. func TestAddInvoiceInvalidFeatureDeps(t *testing.T) { diff --git a/channeldb/invoices.go b/channeldb/invoices.go index cc665d097..608491e69 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -2761,6 +2761,102 @@ func delAMPSettleIndex(invoiceNum []byte, invoices, return nil } +// DeleteCanceledInvoices deletes all canceled invoices from the database. +func (d *DB) DeleteCanceledInvoices(_ context.Context) error { + return kvdb.Update(d, func(tx kvdb.RwTx) error { + invoices := tx.ReadWriteBucket(invoiceBucket) + if invoices == nil { + return nil + } + + invoiceIndex := invoices.NestedReadWriteBucket( + invoiceIndexBucket, + ) + if invoiceIndex == nil { + return invpkg.ErrNoInvoicesCreated + } + + invoiceAddIndex := invoices.NestedReadWriteBucket( + addIndexBucket, + ) + if invoiceAddIndex == nil { + return invpkg.ErrNoInvoicesCreated + } + + payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket) + + return invoiceIndex.ForEach(func(k, v []byte) error { + // Skip the special numInvoicesKey as that does not + // point to a valid invoice. + if bytes.Equal(k, numInvoicesKey) { + return nil + } + + // Skip sub-buckets. + if v == nil { + return nil + } + + invoice, err := fetchInvoice(v, invoices) + if err != nil { + return err + } + + if invoice.State != invpkg.ContractCanceled { + return nil + } + + // Delete the payment hash from the invoice index. + err = invoiceIndex.Delete(k) + if err != nil { + return err + } + + // Delete payment address index reference if there's a + // valid payment address. + if invoice.Terms.PaymentAddr != invpkg.BlankPayAddr { + // To ensure consistency check that the already + // fetched invoice key matches the one in the + // payment address index. + key := payAddrIndex.Get( + invoice.Terms.PaymentAddr[:], + ) + if bytes.Equal(key, k) { + // Delete from the payment address + // index. + if err := payAddrIndex.Delete( + invoice.Terms.PaymentAddr[:], + ); err != nil { + return err + } + } + } + + // Remove from the add index. + var addIndexKey [8]byte + byteOrder.PutUint64(addIndexKey[:], invoice.AddIndex) + err = invoiceAddIndex.Delete(addIndexKey[:]) + if err != nil { + return err + } + + // Note that we don't need to delete the invoice from + // the settle index as it is not added until the + // invoice is settled. + + // Now remove all sub invoices. + err = delAMPInvoices(k, invoices) + if err != nil { + return err + } + + // Finally remove the serialized invoice from the + // invoice bucket. + return invoices.Delete(k) + }) + }, func() {}) +} + // DeleteInvoice attempts to delete the passed invoices from the database in // one transaction. The passed delete references hold all keys required to // delete the invoices without also needing to deserialze them. diff --git a/invoices/interface.go b/invoices/interface.go index d88a96753..7bb39f846 100644 --- a/invoices/interface.go +++ b/invoices/interface.go @@ -95,6 +95,10 @@ type InvoiceDB interface { // deserialze them. DeleteInvoice(ctx context.Context, invoicesToDelete []InvoiceDeleteRef) error + + // DeleteCanceledInvoices removes all canceled invoices from the + // database. + DeleteCanceledInvoices(ctx context.Context) error } // Payload abstracts access to any additional fields provided in the final hop's diff --git a/invoices/mock.go b/invoices/mock.go index 5c419d0b1..f7f6a0bc2 100644 --- a/invoices/mock.go +++ b/invoices/mock.go @@ -85,3 +85,9 @@ func (m *MockInvoiceDB) DeleteInvoice(invoices []InvoiceDeleteRef) error { return args.Error(0) } + +func (m *MockInvoiceDB) DeleteCanceledInvoices(ctx context.Context) error { + args := m.Called(ctx) + + return args.Error(0) +}