From da8f1c084af39c19c622fa699a302d633764c925 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 13 Feb 2023 13:58:52 +0800 Subject: [PATCH] channeldb+routing: add new interface method `TerminalInfo` This commit adds a new interface method `TerminalInfo` and changes its implementation to return an `*HTLCAttempt` so it includes the route for a successful payment. Method `GetFailureReason` is now removed as its returned value can be found in the above method. --- channeldb/mp_payment.go | 9 ++------- routing/control_tower.go | 7 ++++--- routing/control_tower_test.go | 14 ++++++-------- routing/mock_test.go | 34 +++++++++++++++++++++++----------- routing/payment_lifecycle.go | 18 ++++++++---------- routing/router.go | 3 ++- routing/router_test.go | 34 ++++++++++++++-------------------- 7 files changed, 59 insertions(+), 60 deletions(-) diff --git a/channeldb/mp_payment.go b/channeldb/mp_payment.go index a4b323349..cf5669a50 100644 --- a/channeldb/mp_payment.go +++ b/channeldb/mp_payment.go @@ -219,10 +219,10 @@ func (m *MPPayment) Terminated() bool { // TerminalInfo returns any HTLC settle info recorded. If no settle info is // recorded, any payment level failure will be returned. If neither a settle // nor a failure is recorded, both return values will be nil. -func (m *MPPayment) TerminalInfo() (*HTLCSettleInfo, *FailureReason) { +func (m *MPPayment) TerminalInfo() (*HTLCAttempt, *FailureReason) { for _, h := range m.HTLCs { if h.Settle != nil { - return h.Settle, nil + return &h, nil } } @@ -464,11 +464,6 @@ func (m *MPPayment) GetHTLCs() []HTLCAttempt { return m.HTLCs } -// GetFailureReason returns the failure reason. -func (m *MPPayment) GetFailureReason() *FailureReason { - return m.FailureReason -} - // AllowMoreAttempts is used to decide whether we can safely attempt more HTLCs // for a given payment state. Return an error if the payment is in an // unexpected state. diff --git a/routing/control_tower.go b/routing/control_tower.go index 80ffbe5d9..c064a5b4f 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -31,13 +31,14 @@ type dbMPPayment interface { // InFlightHTLCs returns all HTLCs that are in flight. InFlightHTLCs() []channeldb.HTLCAttempt - // GetFailureReason returns the reason the payment failed. - GetFailureReason() *channeldb.FailureReason - // AllowMoreAttempts is used to decide whether we can safely attempt // more HTLCs for a given payment state. Return an error if the payment // is in an unexpected state. AllowMoreAttempts() (bool, error) + + // TerminalInfo returns the settled HTLC attempt or the payment's + // failure reason. + TerminalInfo() (*channeldb.HTLCAttempt, *channeldb.FailureReason) } // ControlTower tracks all outgoing payments made, whose primary purpose is to diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index f14c18b81..42303dc55 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -134,8 +134,8 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { "subscriber %v failed, want %s, got %s", i, channeldb.StatusSucceeded, result.GetStatus()) - settle, _ := result.TerminalInfo() - if settle.Preimage != preimg { + attempt, _ := result.TerminalInfo() + if attempt.Settle.Preimage != preimg { t.Fatal("unexpected preimage") } if len(result.HTLCs) != 1 { @@ -264,9 +264,8 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { ) settle1, _ := result1.TerminalInfo() - require.Equal( - t, preimg1, settle1.Preimage, "unexpected preimage payment 1", - ) + require.Equal(t, preimg1, settle1.Settle.Preimage, + "unexpected preimage payment 1") require.Len( t, result1.HTLCs, 1, "expect 1 htlc for payment 1, got %d", @@ -283,9 +282,8 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { ) settle2, _ := result2.TerminalInfo() - require.Equal( - t, preimg2, settle2.Preimage, "unexpected preimage payment 2", - ) + require.Equal(t, preimg2, settle2.Settle.Preimage, + "unexpected preimage payment 2") require.Len( t, result2.HTLCs, 1, "expect 1 htlc for payment 2, got %d", len(result2.HTLCs), diff --git a/routing/mock_test.go b/routing/mock_test.go index 5a2ee4206..2458f538b 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -828,22 +828,34 @@ func (m *mockMPPayment) InFlightHTLCs() []channeldb.HTLCAttempt { return args.Get(0).([]channeldb.HTLCAttempt) } -func (m *mockMPPayment) GetFailureReason() *channeldb.FailureReason { - args := m.Called() - - reason := args.Get(0) - if reason == nil { - return nil - } - - return reason.(*channeldb.FailureReason) -} - func (m *mockMPPayment) AllowMoreAttempts() (bool, error) { args := m.Called() return args.Bool(0), args.Error(1) } +func (m *mockMPPayment) TerminalInfo() (*channeldb.HTLCAttempt, + *channeldb.FailureReason) { + + args := m.Called() + + var ( + settleInfo *channeldb.HTLCAttempt + failureInfo *channeldb.FailureReason + ) + + settle := args.Get(0) + if settle != nil { + settleInfo = settle.(*channeldb.HTLCAttempt) + } + + reason := args.Get(1) + if reason != nil { + failureInfo = reason.(*channeldb.FailureReason) + } + + return settleInfo, failureInfo +} + type mockLink struct { htlcswitch.ChannelLink bandwidth lnwire.MilliSatoshi diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index d3e4289e6..511ef3596 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -201,10 +201,10 @@ lifecycle: ps := payment.GetState() remainingFees := p.calcFeeBudget(ps.FeesPaid) - log.Debugf("Payment %v in state terminate=%v, "+ - "active_shards=%v, rem_value=%v, fee_limit=%v", - p.identifier, payment.Terminated(), - ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees) + log.Debugf("Payment %v: status=%v, active_shards=%v, "+ + "rem_value=%v, fee_limit=%v", p.identifier, + payment.GetStatus(), ps.NumAttemptsInFlight, + ps.RemainingAmt, remainingFees) // We now proceed our lifecycle with the following tasks in // order, @@ -291,15 +291,13 @@ lifecycle: "%v: %v", p.identifier, err) } - // Find the first successful shard and return the preimage and route. - for _, a := range payment.GetHTLCs() { - if a.Settle != nil { - return a.Settle.Preimage, &a.Route, nil - } + htlc, failure := payment.TerminalInfo() + if htlc != nil { + return htlc.Settle.Preimage, &htlc.Route, nil } // Otherwise return the payment failure reason. - return [32]byte{}, nil, *payment.GetFailureReason() + return [32]byte{}, nil, *failure } // checkTimeout checks whether the payment has reached its timeout. diff --git a/routing/router.go b/routing/router.go index 0b442da15..f7b20a376 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2532,7 +2532,8 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // Exit if the above error has caused the payment to be failed, we also // return the error from sending attempt to mimic the old behavior of // this method. - if payment.GetFailureReason() != nil { + _, failedReason := payment.TerminalInfo() + if failedReason != nil { return result.attempt, result.err } diff --git a/routing/router_test.go b/routing/router_test.go index 6073ee2be..eda49dd9d 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3482,7 +3482,7 @@ func TestSendMPPaymentSucceed(t *testing.T) { payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs). On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). - On("Terminated").Return(false) + On("GetStatus").Return(channeldb.StatusInFlight) controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock FetchPayment to return the payment. @@ -3518,9 +3518,6 @@ func TestSendMPPaymentSucceed(t *testing.T) { controlTower.On("SettleAttempt", identifier, mock.Anything, mock.Anything, ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { - payment.On("GetHTLCs").Return( - []channeldb.HTLCAttempt{settledAttempt}, - ) // We want to at least wait for one settlement. if numAttempts.Load() > 1 { settled.Store(true) @@ -3566,6 +3563,8 @@ func TestSendMPPaymentSucceed(t *testing.T) { controlTower.On("DeleteFailedAttempts", identifier).Return(nil) + payment.On("TerminalInfo").Return(&settledAttempt, nil) + // Call the actual method SendPayment on router. This is place inside a // goroutine so we can set a timeout for the whole test, in case // anything goes wrong and the test never finishes. @@ -3683,7 +3682,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs). On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). - On("Terminated").Return(false) + On("GetStatus").Return(channeldb.StatusInFlight) controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock FetchPayment to return the payment. @@ -3787,12 +3786,6 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { controlTower.On("SettleAttempt", identifier, mock.Anything, mock.Anything, ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mock the payment's - // GetHTLCs() to return the settled htlc. - payment.On("GetHTLCs").Return( - []channeldb.HTLCAttempt{settledAttempt}, - ) - if numAttempts.Load() > 1 { settled.Store(true) } @@ -3800,6 +3793,8 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { controlTower.On("DeleteFailedAttempts", identifier).Return(nil) + payment.On("TerminalInfo").Return(&settledAttempt, nil) + // Call the actual method SendPayment on router. This is place inside a // goroutine so we can set a timeout for the whole test, in case // anything goes wrong and the test never finishes. @@ -3913,8 +3908,8 @@ func TestSendMPPaymentFailed(t *testing.T) { // Make a mock MPPayment. payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs).Once() - payment.On("GetState").Return(&channeldb.MPPaymentState{}) - payment.On("Terminated").Return(false) + payment.On("GetState").Return(&channeldb.MPPaymentState{}). + On("GetStatus").Return(channeldb.StatusInFlight) controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock the sequential FetchPayment to return the payment. @@ -3935,7 +3930,6 @@ func TestSendMPPaymentFailed(t *testing.T) { } payment.On("AllowMoreAttempts").Return(false, nil). - On("GetHTLCs").Return(htlcs).Once(). On("NeedWaitAttempts").Return(false, nil).Once() }) @@ -4011,12 +4005,12 @@ func TestSendMPPaymentFailed(t *testing.T) { }) // Mock the payment to return the failure reason. - payment.On("GetFailureReason").Return(&failureReason) - payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, ).Return(nil) + payment.On("TerminalInfo").Return(nil, &failureReason) + controlTower.On("DeleteFailedAttempts", identifier).Return(nil) // Call the actual method SendPayment on router. This is place inside a @@ -4194,7 +4188,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() // Mock the payment to return nil failrue reason. - payment.On("GetFailureReason").Return(nil) + payment.On("TerminalInfo").Return(nil, nil) // Expect a successful send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt) @@ -4289,7 +4283,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { ).Return(nil, nil) // Mock the payment to return nil failrue reason. - payment.On("GetFailureReason").Return(nil) + payment.On("TerminalInfo").Return(nil, nil) // Expect a failed send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt) @@ -4374,7 +4368,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() // Mock the payment to return a failrue reason. - payment.On("GetFailureReason").Return(&failureReason) + payment.On("TerminalInfo").Return(nil, &failureReason) // Expect a failed send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt) @@ -4452,7 +4446,7 @@ func TestSendToRouteTempFailure(t *testing.T) { controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() // Mock the payment to return nil failrue reason. - payment.On("GetFailureReason").Return(nil) + payment.On("TerminalInfo").Return(nil, nil) // Return a nil reason to mock a temporary failure. missionControl.On("ReportPaymentFail",