From 2e192a80004f7b90fb36363b7f49103bfc7e7e42 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 17 Nov 2022 10:28:22 +0800 Subject: [PATCH] channeldb+routing: apply method `Terminated` to decide a payment's terminal state This commit applies the new method `Terminated`. A side effect from using this method is, we can now save one query `fetchPayment` inside `FetchInFlightPayments`. --- channeldb/mp_payment.go | 7 +++++++ channeldb/payment_control.go | 15 +++++---------- routing/control_tower.go | 4 ++-- routing/control_tower_test.go | 19 ++++++++++++------- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/channeldb/mp_payment.go b/channeldb/mp_payment.go index 6b1a6484e..0c1943d66 100644 --- a/channeldb/mp_payment.go +++ b/channeldb/mp_payment.go @@ -177,6 +177,13 @@ type MPPayment struct { Status PaymentStatus } +// Terminated returns a bool to specify whether the payment is in a terminal +// state. +func (m *MPPayment) Terminated() bool { + // If the payment is in terminal state, it cannot be updated. + return m.Status.updatable() != nil +} + // TerminalInfo returns any HTLC settle info recorded. If no settle info is // recorded, any payment level failure will be returned. If neither a settle // nor a failure is recorded, both return values will be nil. diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index e0d7fbcec..46c8a9933 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -722,21 +722,16 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*MPPayment, error) { return fmt.Errorf("non bucket element") } - // If the status is not InFlight, we can return early. - paymentStatus, err := fetchPaymentStatus(bucket) - if err != nil { - return err - } - - if paymentStatus != StatusInFlight { - return nil - } - p, err := fetchPayment(bucket) if err != nil { return err } + // Skip the payment if it's terminated. + if p.Terminated() { + return nil + } + inFlights = append(inFlights, p) return nil }) diff --git a/routing/control_tower.go b/routing/control_tower.go index d2cbc6bbf..0590debee 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -280,7 +280,7 @@ func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) ( // updates. Otherwise this update is the final update and the incoming // channel can be closed. This will close the queue's outgoing channel // when all updates have been written. - if payment.Status == channeldb.StatusInFlight { + if !payment.Terminated() { p.subscribersMtx.Lock() p.subscribers[paymentHash] = append( p.subscribers[paymentHash], subscriber, @@ -344,7 +344,7 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash, // If the payment reached a terminal state, the subscriber list can be // cleared. There won't be any more updates. - terminal := event.Status != channeldb.StatusInFlight + terminal := event.Terminated() if terminal { delete(p.subscribers, paymentHash) } diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 05f049fa2..3681b647d 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -119,9 +119,9 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { subscriber1, subscriber2, subscriber3, } - for _, s := range subscribers { + for i, s := range subscribers { var result *channeldb.MPPayment - for result == nil || result.Status == channeldb.StatusInFlight { + for result == nil || !result.Terminated() { select { case item := <-s.Updates(): result = item.(*channeldb.MPPayment) @@ -130,9 +130,10 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { } } - if result.Status != channeldb.StatusSucceeded { - t.Fatal("unexpected payment state") - } + require.Equalf(t, channeldb.StatusSucceeded, result.Status, + "subscriber %v failed, want %s, got %s", i, + channeldb.StatusSucceeded, result.Status) + settle, _ := result.TerminalInfo() if settle.Preimage != preimg { t.Fatal("unexpected preimage") @@ -474,9 +475,9 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, subscriber1, subscriber2, } - for _, s := range subscribers { + for i, s := range subscribers { var result *channeldb.MPPayment - for result == nil || result.Status == channeldb.StatusInFlight { + for result == nil || !result.Terminated() { select { case item := <-s.Updates(): result = item.(*channeldb.MPPayment) @@ -510,6 +511,10 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, len(result.HTLCs)) } + require.Equalf(t, channeldb.StatusFailed, result.Status, + "subscriber %v failed, want %s, got %s", i, + channeldb.StatusFailed, result.Status) + if *result.FailureReason != channeldb.FailureReasonTimeout { t.Fatal("unexpected failure reason") }