From e65f02034841307e32316434736c8715f7a0a11e Mon Sep 17 00:00:00 2001 From: Jesse de Wit Date: Tue, 15 Mar 2022 12:11:11 +0100 Subject: [PATCH] routing: add SubscribeAllPayments to control tower Add a method 'SubscribeAllPayments' to the control tower, in order to be able to subscribe to any payment, rather than subscribing to a specific payment hash. --- routing/control_tower.go | 82 ++++++++++-- routing/control_tower_test.go | 230 ++++++++++++++++++++++++++++++++++ routing/mock_test.go | 13 ++ 3 files changed, 318 insertions(+), 7 deletions(-) diff --git a/routing/control_tower.go b/routing/control_tower.go index be6e61b4e..5b9577e40 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -62,6 +62,11 @@ type ControlTower interface { // sent out immediately. SubscribePayment(paymentHash lntypes.Hash) (*ControlTowerSubscriber, error) + + // SubscribeAllPayments subscribes to updates for all payments. A first + // update with the current state of every inflight payment is always + // sent out immediately. + SubscribeAllPayments() (*ControlTowerSubscriber, error) } // ControlTowerSubscriber contains the state for a payment update subscriber. @@ -102,8 +107,13 @@ func (s *ControlTowerSubscriber) Close() { type controlTower struct { db *channeldb.PaymentControl - subscribers map[lntypes.Hash][]*ControlTowerSubscriber - subscribersMtx sync.Mutex + // subscriberIndex is used to provide a unique id for each subscriber + // to all payments. This is used to easily remove the subscriber when + // necessary. + subscriberIndex uint64 + subscribersAllPayments map[uint64]*ControlTowerSubscriber + subscribers map[lntypes.Hash][]*ControlTowerSubscriber + subscribersMtx sync.Mutex // paymentsMtx provides synchronization on the payment level to ensure // that no race conditions occur in between updating the database and @@ -114,7 +124,10 @@ type controlTower struct { // NewControlTower creates a new instance of the controlTower. func NewControlTower(db *channeldb.PaymentControl) ControlTower { return &controlTower{ - db: db, + db: db, + subscribersAllPayments: make( + map[uint64]*ControlTowerSubscriber, + ), subscribers: make(map[lntypes.Hash][]*ControlTowerSubscriber), paymentsMtx: multimutex.NewHashMutex(), } @@ -266,6 +279,39 @@ func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) ( return subscriber, nil } +// SubscribeAllPayments subscribes to updates for all inflight payments. A first +// update with the current state of every inflight payment is always sent out +// immediately. +// Note: If payments are in-flight while starting a new subscription, the start +// of the payment stream could produce out-of-order and/or duplicate events. In +// order to get updates for every in-flight payment attempt make sure to +// subscribe to this method before initiating any payments. +func (p *controlTower) SubscribeAllPayments() (*ControlTowerSubscriber, error) { + subscriber := newControlTowerSubscriber() + + // Add the subscriber to the list before fetching in-flight payments, so + // no events are missed. If a payment attempt update occurs after + // appending and before fetching in-flight payments, an out-of-order + // duplicate may be produced, because it is then fetched in below call + // and notified through the subscription. + p.subscribersMtx.Lock() + p.subscribersAllPayments[p.subscriberIndex] = subscriber + p.subscriberIndex++ + p.subscribersMtx.Unlock() + + inflightPayments, err := p.db.FetchInFlightPayments() + if err != nil { + return nil, err + } + + for index := range inflightPayments { + // Always write current payment state to the channel. + subscriber.queue.ChanIn() <- inflightPayments[index] + } + + return subscriber, nil +} + // notifySubscribers sends a final payment event to all subscribers of this // payment. The channel will be closed after this. Note that this function must // be executed atomically (by means of a lock) with the database update to @@ -275,8 +321,9 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash, // Get all subscribers for this payment. p.subscribersMtx.Lock() - list, ok := p.subscribers[paymentHash] - if !ok { + + subscribersPaymentHash, ok := p.subscribers[paymentHash] + if !ok && len(p.subscribersAllPayments) == 0 { p.subscribersMtx.Unlock() return } @@ -287,10 +334,17 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash, if terminal { delete(p.subscribers, paymentHash) } + + // Copy subscribers to all payments locally while holding the lock in + // order to avoid concurrency issues while reading/writing the map. + subscribersAllPayments := make(map[uint64]*ControlTowerSubscriber) + for k, v := range p.subscribersAllPayments { + subscribersAllPayments[k] = v + } p.subscribersMtx.Unlock() - // Notify all subscribers of the event. - for _, subscriber := range list { + // Notify all subscribers that subscribed to the current payment hash. + for _, subscriber := range subscribersPaymentHash { select { case subscriber.queue.ChanIn() <- event: // If this event is the last, close the incoming channel @@ -305,4 +359,18 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash, case <-subscriber.quit: } } + + // Notify all subscribers that subscribed to all payments. + for key, subscriber := range subscribersAllPayments { + select { + case subscriber.queue.ChanIn() <- event: + + // If subscriber disappeared, remove it from the subscribers + // list. + case <-subscriber.quit: + p.subscribersMtx.Lock() + delete(p.subscribersAllPayments, key) + p.subscribersMtx.Unlock() + } + } } diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 80a8fbc5c..19e4465c2 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -178,6 +178,236 @@ func TestPaymentControlSubscribeFail(t *testing.T) { }) } +// TestPaymentControlSubscribeAllSuccess tests that multiple payments are +// properly sent to subscribers of TrackPayments. +func TestPaymentControlSubscribeAllSuccess(t *testing.T) { + t.Parallel() + + db, err := initDB(t, true) + require.NoError(t, err, "unable to init db: %v") + + pControl := NewControlTower(channeldb.NewPaymentControl(db)) + + // Initiate a payment. + info1, attempt1, preimg1, err := genInfo() + require.NoError(t, err) + + err = pControl.InitPayment(info1.PaymentIdentifier, info1) + require.NoError(t, err) + + // Subscription should succeed and immediately report the InFlight + // status. + subscription, err := pControl.SubscribeAllPayments() + require.NoError(t, err, "expected subscribe to succeed, but got: %v") + + // Register an attempt. + err = pControl.RegisterAttempt(info1.PaymentIdentifier, attempt1) + require.NoError(t, err) + + // Initiate a second payment after the subscription is already active. + info2, attempt2, preimg2, err := genInfo() + require.NoError(t, err) + + err = pControl.InitPayment(info2.PaymentIdentifier, info2) + require.NoError(t, err) + + // Register an attempt on the second payment. + err = pControl.RegisterAttempt(info2.PaymentIdentifier, attempt2) + require.NoError(t, err) + + // Mark the first payment as successful. + settleInfo1 := channeldb.HTLCSettleInfo{ + Preimage: preimg1, + } + htlcAttempt1, err := pControl.SettleAttempt( + info1.PaymentIdentifier, attempt1.AttemptID, &settleInfo1, + ) + require.NoError(t, err) + require.Equal( + t, settleInfo1, *htlcAttempt1.Settle, + "unexpected settle info returned", + ) + + // Mark the second payment as successful. + settleInfo2 := channeldb.HTLCSettleInfo{ + Preimage: preimg2, + } + htlcAttempt2, err := pControl.SettleAttempt( + info2.PaymentIdentifier, attempt2.AttemptID, &settleInfo2, + ) + require.NoError(t, err) + require.Equal( + t, settleInfo2, *htlcAttempt2.Settle, + "unexpected fail info returned", + ) + + // The two payments will be asserted individually, store the last update + // for each payment. + results := make(map[lntypes.Hash]*channeldb.MPPayment) + + // After exactly 5 updates both payments will/should have completed. + for i := 0; i < 5; i++ { + select { + case item := <-subscription.Updates: + id := item.(*channeldb.MPPayment).Info.PaymentIdentifier + results[id] = item.(*channeldb.MPPayment) + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for payment result") + } + } + + result1 := results[info1.PaymentIdentifier] + require.Equal( + t, channeldb.StatusSucceeded, result1.Status, + "unexpected payment state payment 1", + ) + + settle1, _ := result1.TerminalInfo() + require.Equal( + t, preimg1, settle1.Preimage, "unexpected preimage payment 1", + ) + + require.Len( + t, result1.HTLCs, 1, "expect 1 htlc for payment 1, got %d", + len(result1.HTLCs), + ) + + htlc1 := result1.HTLCs[0] + require.Equal(t, attempt1.Route, htlc1.Route, "unexpected htlc route.") + + result2 := results[info2.PaymentIdentifier] + require.Equal( + t, channeldb.StatusSucceeded, result2.Status, + "unexpected payment state payment 2", + ) + + settle2, _ := result2.TerminalInfo() + require.Equal( + t, preimg2, settle2.Preimage, "unexpected preimage payment 2", + ) + require.Len( + t, result2.HTLCs, 1, "expect 1 htlc for payment 2, got %d", + len(result2.HTLCs), + ) + + htlc2 := result2.HTLCs[0] + require.Equal(t, attempt2.Route, htlc2.Route, "unexpected htlc route.") +} + +// TestPaymentControlSubscribeAllImmediate tests whether already inflight +// payments are reported at the start of the SubscribeAllPayments subscription. +func TestPaymentControlSubscribeAllImmediate(t *testing.T) { + t.Parallel() + + db, err := initDB(t, true) + require.NoError(t, err, "unable to init db: %v") + + pControl := NewControlTower(channeldb.NewPaymentControl(db)) + + // Initiate a payment. + info, attempt, _, err := genInfo() + require.NoError(t, err) + + err = pControl.InitPayment(info.PaymentIdentifier, info) + require.NoError(t, err) + + // Register a payment update. + err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err) + + subscription, err := pControl.SubscribeAllPayments() + require.NoError(t, err, "expected subscribe to succeed, but got: %v") + + // Assert the new subscription receives the old update. + select { + case update := <-subscription.Updates: + require.NotNil(t, update) + require.Equal( + t, info.PaymentIdentifier, + update.(*channeldb.MPPayment).Info.PaymentIdentifier, + ) + require.Len(t, subscription.Updates, 0) + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for payment result") + } +} + +// TestPaymentControlUnsubscribeSuccess tests that when unsubscribed, there are +// no more notifications to that specific subscription. +func TestPaymentControlUnsubscribeSuccess(t *testing.T) { + t.Parallel() + + db, err := initDB(t, true) + require.NoError(t, err, "unable to init db: %v") + + pControl := NewControlTower(channeldb.NewPaymentControl(db)) + + subscription1, err := pControl.SubscribeAllPayments() + require.NoError(t, err, "expected subscribe to succeed, but got: %v") + + subscription2, err := pControl.SubscribeAllPayments() + require.NoError(t, err, "expected subscribe to succeed, but got: %v") + + // Initiate a payment. + info, attempt, _, err := genInfo() + require.NoError(t, err) + + err = pControl.InitPayment(info.PaymentIdentifier, info) + require.NoError(t, err) + + // Register a payment update. + err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err) + + // Assert all subscriptions receive the update. + select { + case update1 := <-subscription1.Updates: + require.NotNil(t, update1) + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for payment result") + } + + select { + case update2 := <-subscription2.Updates: + require.NotNil(t, update2) + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for payment result") + } + + // Close the first subscription. + subscription1.Close() + + // Register another update. + failInfo := channeldb.HTLCFailInfo{ + Reason: channeldb.HTLCFailInternal, + } + _, err = pControl.FailAttempt( + info.PaymentIdentifier, attempt.AttemptID, &failInfo, + ) + require.NoError(t, err, "unable to fail htlc") + + // Assert only subscription 2 receives the update. + select { + case update2 := <-subscription2.Updates: + require.NotNil(t, update2) + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for payment result") + } + + require.Len(t, subscription1.Updates, 0) + + // Close the second subscription. + subscription2.Close() + + // Register a last update. + err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err) + + // Assert no subscriptions receive the update. + require.Len(t, subscription1.Updates, 0) + require.Len(t, subscription2.Updates, 0) +} + func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, keepFailedPaymentAttempts bool) { diff --git a/routing/mock_test.go b/routing/mock_test.go index 67f6fa432..ad7881cd3 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -557,6 +557,12 @@ func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) ( return nil, errors.New("not implemented") } +func (m *mockControlTowerOld) SubscribeAllPayments() ( + *ControlTowerSubscriber, error) { + + return nil, errors.New("not implemented") +} + type mockPaymentAttemptDispatcher struct { mock.Mock @@ -774,6 +780,13 @@ func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) ( return args.Get(0).(*ControlTowerSubscriber), args.Error(1) } +func (m *mockControlTower) SubscribeAllPayments() ( + *ControlTowerSubscriber, error) { + + args := m.Called() + return args.Get(0).(*ControlTowerSubscriber), args.Error(1) +} + type mockLink struct { htlcswitch.ChannelLink bandwidth lnwire.MilliSatoshi