diff --git a/routing/control_tower.go b/routing/control_tower.go index 8968ca312..174ff6a7f 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -282,6 +282,9 @@ func (p *controlTower) FetchPayment(paymentHash lntypes.Hash) ( // reason the payment failed. After invoking this method, InitPayment should // return nil on its next call for this payment hash, allowing the switch to // make a subsequent payment. +// +// NOTE: This method will overwrite the failure reason if the payment is already +// failed. func (p *controlTower) FailPayment(paymentHash lntypes.Hash, reason channeldb.FailureReason) error { diff --git a/routing/router.go b/routing/router.go index 323fe7616..3fa259ed6 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1048,6 +1048,29 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, firstHopCustomRecords lnwire.CustomRecords) (*channeldb.HTLCAttempt, error) { + // Helper function to fail a payment. It makes sure the payment is only + // failed once so that the failure reason is not overwritten. + failPayment := func(paymentIdentifier lntypes.Hash, + reason channeldb.FailureReason) error { + + payment, fetchErr := r.cfg.Control.FetchPayment( + paymentIdentifier, + ) + if fetchErr != nil { + return fetchErr + } + + // NOTE: We cannot rely on the payment status to be failed here + // because it can still be in-flight although the payment is + // already failed. + _, failedReason := payment.TerminalInfo() + if failedReason != nil { + return nil + } + + return r.cfg.Control.FailPayment(paymentIdentifier, reason) + } + log.Debugf("SendToRoute for payment %v with skipTempErr=%v", htlcHash, skipTempErr) @@ -1148,20 +1171,6 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, return nil, err } - // We now look up the payment to see if it's already failed. - payment, err := p.router.cfg.Control.FetchPayment(p.identifier) - if err != nil { - return result.attempt, err - } - - // Exit if the above error has caused the payment to be failed, we also - // return the error from sending attempt to mimic the old behavior of - // this method. - _, failedReason := payment.TerminalInfo() - if failedReason != nil { - return result.attempt, result.err - } - // Since for SendToRoute we won't retry in case the shard fails, we'll // mark the payment failed with the control tower immediately if the // skipTempErr is false. @@ -1175,8 +1184,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, return result.attempt, result.err } - // Otherwise we need to fail the payment. - err := r.cfg.Control.FailPayment(paymentIdentifier, reason) + err := failPayment(paymentIdentifier, reason) if err != nil { return nil, err } @@ -1199,7 +1207,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // An error returned from collecting the result, we'll mark the payment // as failed if we don't skip temp error. if !skipTempErr { - err := r.cfg.Control.FailPayment(paymentIdentifier, reason) + err := failPayment(paymentIdentifier, reason) if err != nil { return nil, err } diff --git a/routing/router_test.go b/routing/router_test.go index 6efc75df2..62a68b213 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -2215,13 +2215,6 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { mock.Anything, rt, ).Return(nil) - // Mock the control tower to return the mocked payment. - payment := &mockMPPayment{} - controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() - - // Mock the payment to return nil failure reason. - payment.On("TerminalInfo").Return(nil, nil) - // Expect a successful send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt, nil) require.NoError(t, err) @@ -2231,7 +2224,6 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { controlTower.AssertExpectations(t) payer.AssertExpectations(t) missionControl.AssertExpectations(t) - payment.AssertExpectations(t) } // TestSendToRouteSkipTempErrNonMPP checks that an error is return when @@ -2352,19 +2344,12 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, ).Return(tempErr) - // Mock the control tower to return the mocked payment. - payment := &mockMPPayment{} - controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() - // Mock the mission control to return a nil reason from reporting the // attempt failure. missionControl.On("ReportPaymentFail", mock.Anything, rt, mock.Anything, mock.Anything, ).Return(nil, nil) - // Mock the payment to return nil failure reason. - payment.On("TerminalInfo").Return(nil, nil) - // Expect a failed send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt, nil) require.Equal(t, tempErr, err) @@ -2374,7 +2359,6 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { controlTower.AssertExpectations(t) payer.AssertExpectations(t) missionControl.AssertExpectations(t) - payment.AssertExpectations(t) } // TestSendToRouteSkipTempErrPermanentFailure validates a permanent failure @@ -2436,7 +2420,9 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { ).Return(testAttempt, nil) // Expect the payment to be failed. - controlTower.On("FailPayment", payHash, mock.Anything).Return(nil) + controlTower.On( + "FailPayment", payHash, mock.Anything, + ).Return(nil).Once() // Mock an error to be returned from sending the htlc. payer.On("SendHTLC", @@ -2448,13 +2434,6 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { mock.Anything, rt, mock.Anything, mock.Anything, ).Return(&failureReason, nil) - // Mock the control tower to return the mocked payment. - payment := &mockMPPayment{} - controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() - - // Mock the payment to return a failure reason. - payment.On("TerminalInfo").Return(nil, &failureReason) - // Expect a failed send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt, nil) require.Equal(t, permErr, err) @@ -2464,7 +2443,6 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { controlTower.AssertExpectations(t) payer.AssertExpectations(t) missionControl.AssertExpectations(t) - payment.AssertExpectations(t) } // TestSendToRouteTempFailure validates a temporary failure will cause the @@ -2525,7 +2503,9 @@ func TestSendToRouteTempFailure(t *testing.T) { ).Return(testAttempt, nil) // Expect the payment to be failed. - controlTower.On("FailPayment", payHash, mock.Anything).Return(nil) + controlTower.On( + "FailPayment", payHash, mock.Anything, + ).Return(nil).Once() payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, @@ -2536,7 +2516,7 @@ func TestSendToRouteTempFailure(t *testing.T) { controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() // Mock the payment to return nil failure reason. - payment.On("TerminalInfo").Return(nil, nil) + payment.On("TerminalInfo").Return(nil, nil).Once() // Return a nil reason to mock a temporary failure. missionControl.On("ReportPaymentFail",