channeldb+routing: apply method Terminated to decide a payment's terminal state

This commit applies the new method `Terminated`. A side effect from
using this method is, we can now save one query `fetchPayment` inside
`FetchInFlightPayments`.
This commit is contained in:
yyforyongyu 2022-11-17 10:28:22 +08:00 committed by Olaoluwa Osuntokun
parent fac6044501
commit c175386c4d
4 changed files with 26 additions and 19 deletions

View File

@ -177,6 +177,13 @@ type MPPayment struct {
Status PaymentStatus Status PaymentStatus
} }
// Terminated returns a bool to specify whether the payment is in a terminal
// state.
func (m *MPPayment) Terminated() bool {
// If the payment is in terminal state, it cannot be updated.
return m.Status.updatable() != nil
}
// TerminalInfo returns any HTLC settle info recorded. If no settle info is // TerminalInfo returns any HTLC settle info recorded. If no settle info is
// recorded, any payment level failure will be returned. If neither a settle // recorded, any payment level failure will be returned. If neither a settle
// nor a failure is recorded, both return values will be nil. // nor a failure is recorded, both return values will be nil.

View File

@ -722,21 +722,16 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*MPPayment, error) {
return fmt.Errorf("non bucket element") return fmt.Errorf("non bucket element")
} }
// If the status is not InFlight, we can return early.
paymentStatus, err := fetchPaymentStatus(bucket)
if err != nil {
return err
}
if paymentStatus != StatusInFlight {
return nil
}
p, err := fetchPayment(bucket) p, err := fetchPayment(bucket)
if err != nil { if err != nil {
return err return err
} }
// Skip the payment if it's terminated.
if p.Terminated() {
return nil
}
inFlights = append(inFlights, p) inFlights = append(inFlights, p)
return nil return nil
}) })

View File

@ -280,7 +280,7 @@ func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) (
// updates. Otherwise this update is the final update and the incoming // updates. Otherwise this update is the final update and the incoming
// channel can be closed. This will close the queue's outgoing channel // channel can be closed. This will close the queue's outgoing channel
// when all updates have been written. // when all updates have been written.
if payment.Status == channeldb.StatusInFlight { if !payment.Terminated() {
p.subscribersMtx.Lock() p.subscribersMtx.Lock()
p.subscribers[paymentHash] = append( p.subscribers[paymentHash] = append(
p.subscribers[paymentHash], subscriber, p.subscribers[paymentHash], subscriber,
@ -344,7 +344,7 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash,
// If the payment reached a terminal state, the subscriber list can be // If the payment reached a terminal state, the subscriber list can be
// cleared. There won't be any more updates. // cleared. There won't be any more updates.
terminal := event.Status != channeldb.StatusInFlight terminal := event.Terminated()
if terminal { if terminal {
delete(p.subscribers, paymentHash) delete(p.subscribers, paymentHash)
} }

View File

@ -119,9 +119,9 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
subscriber1, subscriber2, subscriber3, subscriber1, subscriber2, subscriber3,
} }
for _, s := range subscribers { for i, s := range subscribers {
var result *channeldb.MPPayment var result *channeldb.MPPayment
for result == nil || result.Status == channeldb.StatusInFlight { for result == nil || !result.Terminated() {
select { select {
case item := <-s.Updates(): case item := <-s.Updates():
result = item.(*channeldb.MPPayment) result = item.(*channeldb.MPPayment)
@ -130,9 +130,10 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
} }
} }
if result.Status != channeldb.StatusSucceeded { require.Equalf(t, channeldb.StatusSucceeded, result.Status,
t.Fatal("unexpected payment state") "subscriber %v failed, want %s, got %s", i,
} channeldb.StatusSucceeded, result.Status)
settle, _ := result.TerminalInfo() settle, _ := result.TerminalInfo()
if settle.Preimage != preimg { if settle.Preimage != preimg {
t.Fatal("unexpected preimage") t.Fatal("unexpected preimage")
@ -474,9 +475,9 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
subscriber1, subscriber2, subscriber1, subscriber2,
} }
for _, s := range subscribers { for i, s := range subscribers {
var result *channeldb.MPPayment var result *channeldb.MPPayment
for result == nil || result.Status == channeldb.StatusInFlight { for result == nil || !result.Terminated() {
select { select {
case item := <-s.Updates(): case item := <-s.Updates():
result = item.(*channeldb.MPPayment) result = item.(*channeldb.MPPayment)
@ -510,6 +511,10 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
len(result.HTLCs)) len(result.HTLCs))
} }
require.Equalf(t, channeldb.StatusFailed, result.Status,
"subscriber %v failed, want %s, got %s", i,
channeldb.StatusFailed, result.Status)
if *result.FailureReason != channeldb.FailureReasonTimeout { if *result.FailureReason != channeldb.FailureReasonTimeout {
t.Fatal("unexpected failure reason") t.Fatal("unexpected failure reason")
} }