diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 268078944..a35c9fae4 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -11,6 +11,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/stretchr/testify/assert" ) var ( @@ -305,6 +306,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) { t.Fatalf("unable to make test db: %v", err) } + _, err = db.InvoicesAddedSince(0) + assert.Nil(t, err) + // We'll start off by creating 20 random invoices, and inserting them // into the database. const numInvoices = 20 @@ -372,6 +376,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) { } } + _, err = db.InvoicesSettledSince(0) + assert.Nil(t, err) + var settledInvoices []Invoice var settleIndex uint64 = 1 // We'll now only settle the latter half of each of those invoices. diff --git a/channeldb/invoices.go b/channeldb/invoices.go index aea8ae306..6c3516621 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -488,12 +488,12 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { err := kvdb.View(d, func(tx kvdb.RTx) error { invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { - return ErrNoInvoicesCreated + return nil } addIndex := invoices.NestedReadBucket(addIndexBucket) if addIndex == nil { - return ErrNoInvoicesCreated + return nil } // We'll now run through each entry in the add index starting @@ -520,12 +520,7 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { return nil }) - switch { - // If no invoices have been created, then we'll return the empty set of - // invoices. - case err == ErrNoInvoicesCreated: - - case err != nil: + if err != nil { return nil, err } @@ -886,12 +881,12 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { err := kvdb.View(d, func(tx kvdb.RTx) error { invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { - return ErrNoInvoicesCreated + return nil } settleIndex := invoices.NestedReadBucket(settleIndexBucket) if settleIndex == nil { - return ErrNoInvoicesCreated + return nil } // We'll now run through each entry in the add index starting diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 662a8a82e..520e75823 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -234,15 +234,6 @@ func (i *InvoiceRegistry) invoiceEventLoop() { // We'll query for any backlog notifications, then add it to the // set of clients. case newClient := <-i.newSubscriptions: - // Before we add the client to our set of active - // clients, we'll first attempt to deliver any backlog - // invoice events. - err := i.deliverBacklogEvents(newClient) - if err != nil { - log.Errorf("unable to deliver backlog invoice "+ - "notifications: %v", err) - } - log.Infof("New invoice subscription "+ "client: id=%v", newClient.id) @@ -410,9 +401,6 @@ func (i *InvoiceRegistry) dispatchToClients(event *invoiceEvent) { // deliverBacklogEvents will attempts to query the invoice database for any // notifications that the client has missed since it reconnected last. func (i *InvoiceRegistry) deliverBacklogEvents(client *InvoiceSubscription) error { - // First, we'll query the database to see if based on the provided - // addIndex and settledIndex we need to deliver any backlog - // notifications. addEvents, err := i.cdb.InvoicesAddedSince(client.addIndex) if err != nil { return err @@ -1182,7 +1170,9 @@ func (i *invoiceSubscriptionKit) notify(event *invoiceEvent) error { // added. The invoiceIndex parameter is a streaming "checkpoint". We'll start // by first sending out all new events with an invoice index _greater_ than // this value. Afterwards, we'll send out real-time notifications. -func (i *InvoiceRegistry) SubscribeNotifications(addIndex, settleIndex uint64) *InvoiceSubscription { +func (i *InvoiceRegistry) SubscribeNotifications( + addIndex, settleIndex uint64) (*InvoiceSubscription, error) { + client := &InvoiceSubscription{ NewInvoices: make(chan *channeldb.Invoice), SettledInvoices: make(chan *channeldb.Invoice), @@ -1251,12 +1241,23 @@ func (i *InvoiceRegistry) SubscribeNotifications(addIndex, settleIndex uint64) * } }() + i.Lock() + defer i.Unlock() + + // Query the database to see if based on the provided addIndex and + // settledIndex we need to deliver any backlog notifications. + err := i.deliverBacklogEvents(client) + if err != nil { + return nil, err + } + select { case i.newSubscriptions <- client: case <-i.quit: + return nil, ErrShuttingDown } - return client + return client, nil } // SubscribeSingleInvoice returns an SingleInvoiceSubscription which allows the diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 319c30cf0..aeb521250 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -9,6 +9,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/stretchr/testify/assert" ) // TestSettleInvoice tests settling of an invoice and related notifications. @@ -16,7 +17,8 @@ func TestSettleInvoice(t *testing.T) { ctx := newTestContext(t) defer ctx.cleanup() - allSubscriptions := ctx.registry.SubscribeNotifications(0, 0) + allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) + assert.Nil(t, err) defer allSubscriptions.Cancel() // Subscribe to the not yet existing invoice. @@ -221,11 +223,12 @@ func TestCancelInvoice(t *testing.T) { ctx := newTestContext(t) defer ctx.cleanup() - allSubscriptions := ctx.registry.SubscribeNotifications(0, 0) + allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) + assert.Nil(t, err) defer allSubscriptions.Cancel() // Try to cancel the not yet existing invoice. This should fail. - err := ctx.registry.CancelInvoice(testInvoicePaymentHash) + err = ctx.registry.CancelInvoice(testInvoicePaymentHash) if err != channeldb.ErrInvoiceNotFound { t.Fatalf("expected ErrInvoiceNotFound, but got %v", err) } @@ -352,7 +355,8 @@ func TestSettleHoldInvoice(t *testing.T) { } defer registry.Stop() - allSubscriptions := registry.SubscribeNotifications(0, 0) + allSubscriptions, err := registry.SubscribeNotifications(0, 0) + assert.Nil(t, err) defer allSubscriptions.Cancel() // Subscribe to the not yet existing invoice. @@ -651,7 +655,8 @@ func testKeySend(t *testing.T, keySendEnabled bool) { ctx.registry.cfg.AcceptKeySend = keySendEnabled - allSubscriptions := ctx.registry.SubscribeNotifications(0, 0) + allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) + assert.Nil(t, err) defer allSubscriptions.Cancel() hodlChan := make(chan interface{}, 1) diff --git a/rpcserver.go b/rpcserver.go index 69abbc17e..178595bd1 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -4535,9 +4535,12 @@ func (r *rpcServer) ListInvoices(ctx context.Context, func (r *rpcServer) SubscribeInvoices(req *lnrpc.InvoiceSubscription, updateStream lnrpc.Lightning_SubscribeInvoicesServer) error { - invoiceClient := r.server.invoices.SubscribeNotifications( + invoiceClient, err := r.server.invoices.SubscribeNotifications( req.AddIndex, req.SettleIndex, ) + if err != nil { + return err + } defer invoiceClient.Cancel() for {