diff --git a/routing/control_tower.go b/routing/control_tower.go index c064a5b4f..274c4190b 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -181,12 +181,29 @@ func NewControlTower(db *channeldb.PaymentControl) ControlTower { // InitPayment checks or records the given PaymentCreationInfo with the DB, // making sure it does not already exist as an in-flight payment. Then this -// method returns successfully, the payment is guaranteed to be in the InFlight -// state. +// method returns successfully, the payment is guaranteed to be in the +// Initiated state. func (p *controlTower) InitPayment(paymentHash lntypes.Hash, info *channeldb.PaymentCreationInfo) error { - return p.db.InitPayment(paymentHash, info) + err := p.db.InitPayment(paymentHash, info) + if err != nil { + return err + } + + // Take lock before querying the db to prevent missing or duplicating + // an update. + p.paymentsMtx.Lock(paymentHash) + defer p.paymentsMtx.Unlock(paymentHash) + + payment, err := p.db.FetchPayment(paymentHash) + if err != nil { + return err + } + + p.notifySubscribers(paymentHash, payment) + + return nil } // DeleteFailedAttempts deletes all failed htlcs if the payment was diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 42303dc55..2baad92f1 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -196,7 +196,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { err = pControl.InitPayment(info1.PaymentIdentifier, info1) require.NoError(t, err) - // Subscription should succeed and immediately report the InFlight + // Subscription should succeed and immediately report the Initiated // status. subscription, err := pControl.SubscribeAllPayments() require.NoError(t, err, "expected subscribe to succeed, but got: %v") @@ -246,8 +246,8 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { // for each payment. results := make(map[lntypes.Hash]*channeldb.MPPayment) - // After exactly 5 updates both payments will/should have completed. - for i := 0; i < 5; i++ { + // After exactly 6 updates both payments will/should have completed. + for i := 0; i < 6; i++ { select { case item := <-subscription.Updates(): id := item.(*channeldb.MPPayment).Info.PaymentIdentifier @@ -354,10 +354,6 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) { err = pControl.InitPayment(info.PaymentIdentifier, info) require.NoError(t, err) - // Register a payment update. - err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt) - require.NoError(t, err) - // Assert all subscriptions receive the update. select { case update1 := <-subscription1.Updates(): @@ -376,14 +372,9 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) { // Close the first subscription. subscription1.Close() - // Register another update. - failInfo := channeldb.HTLCFailInfo{ - Reason: channeldb.HTLCFailInternal, - } - _, err = pControl.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, &failInfo, - ) - require.NoError(t, err, "unable to fail htlc") + // Register a payment update. + err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err) // Assert only subscription 2 receives the update. select { @@ -398,9 +389,14 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) { // Close the second subscription. subscription2.Close() - // Register a last update. - err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt) - require.NoError(t, err) + // Register another update. + failInfo := channeldb.HTLCFailInfo{ + Reason: channeldb.HTLCFailInternal, + } + _, err = pControl.FailAttempt( + info.PaymentIdentifier, attempt.AttemptID, &failInfo, + ) + require.NoError(t, err, "unable to fail htlc") // Assert no subscriptions receive the update. require.Len(t, subscription1.Updates(), 0)