From 3c5c37b6937321f9b7218b6611e9dc5c329a201d Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 8 Mar 2023 00:58:14 +0800 Subject: [PATCH] routing: introduce `stateStep` to manage payment lifecycle This commit adds a new type, `stateStep`, to decide the workflow inside `resumePayment`. It also refactors `collectResultAsync` introducing a new channel `resultCollected`. This channel is used to signal the payment lifecycle that an HTLC attempt result is ready to be processed. --- routing/payment_lifecycle.go | 293 ++++++++++++++++-------------- routing/payment_lifecycle_test.go | 120 ++++++++++++ routing/router_test.go | 107 +++++++---- 3 files changed, 343 insertions(+), 177 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 20497c62a..d3e4289e6 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -3,7 +3,6 @@ package routing import ( "errors" "fmt" - "sync" "time" "github.com/btcsuite/btcd/btcec/v2" @@ -18,9 +17,6 @@ import ( "github.com/lightningnetwork/lnd/routing/shards" ) -// errShardHandlerExiting is returned from the shardHandler when it exits. -var errShardHandlerExiting = errors.New("shard handler exiting") - // paymentLifecycle holds all information about the current state of a payment // needed to resume if from any point. type paymentLifecycle struct { @@ -32,18 +28,15 @@ type paymentLifecycle struct { timeoutChan <-chan time.Time currentHeight int32 - // shardErrors is a channel where errors collected by calling - // collectResultAsync will be delivered. These results are meant to be - // inspected by calling waitForShard or checkShards, and the channel - // doesn't need to be initiated if the caller is using the sync - // collectResult directly. - // TODO(yy): delete. - shardErrors chan error - // quit is closed to signal the sub goroutines of the payment lifecycle // to stop. 
quit chan struct{} - wg sync.WaitGroup + + // resultCollected is used to signal that the result of an attempt has + // been collected. A nil error means the attempt is either successful + // or failed with temporary error. Otherwise, we should exit the + // lifecycle loop as a terminal error has occurred. + resultCollected chan error } // newPaymentLifecycle initiates a new payment lifecycle and returns it. @@ -53,14 +46,14 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, currentHeight int32) *paymentLifecycle { p := &paymentLifecycle{ - router: r, - feeLimit: feeLimit, - identifier: identifier, - paySession: paySession, - shardTracker: shardTracker, - currentHeight: currentHeight, - shardErrors: make(chan error), - quit: make(chan struct{}), + router: r, + feeLimit: feeLimit, + identifier: identifier, + paySession: paySession, + shardTracker: shardTracker, + currentHeight: currentHeight, + quit: make(chan struct{}), + resultCollected: make(chan error, 1), } // If a timeout is specified, create a timeout channel. If no timeout is @@ -92,6 +85,74 @@ func (p *paymentLifecycle) calcFeeBudget( return budget } +// stateStep defines an action to be taken in our payment lifecycle. We either +// skip, proceed, or exit the lifecycle, see details below. +type stateStep uint8 + +const ( + // stepSkip is used when we need to skip the current lifecycle and jump + // to the next one. + stepSkip stateStep = iota + + // stepProceed is used when we can proceed the current lifecycle. + stepProceed + + // stepExit is used when we need to quit the current lifecycle. + stepExit +) + +// decideNextStep is used to determine the next step in the payment lifecycle. +func (p *paymentLifecycle) decideNextStep( + payment dbMPPayment) (stateStep, error) { + + // Check whether we could make new HTLC attempts. + allow, err := payment.AllowMoreAttempts() + if err != nil { + return stepExit, err + } + + if !allow { + // Check whether we need to wait for results. 
+ wait, err := payment.NeedWaitAttempts() + if err != nil { + return stepExit, err + } + + // If we are not allowed to make new HTLC attempts and there's + // no need to wait, the lifecycle is done and we can exit. + if !wait { + return stepExit, nil + } + + log.Tracef("Waiting for attempt results for payment %v", + p.identifier) + + // Otherwise we wait for one HTLC attempt then continue + // the lifecycle. + // + // NOTE: we don't check `p.quit` since `decideNextStep` is + // running in the same goroutine as `resumePayment`. + select { + case err := <-p.resultCollected: + // If an error is returned, exit with it. + if err != nil { + return stepExit, err + } + + log.Tracef("Received attempt result for payment %v", + p.identifier) + + case <-p.router.quit: + return stepExit, ErrRouterShuttingDown + } + + return stepSkip, nil + } + + // Otherwise we need to make more attempts. + return stepProceed, nil +} + // resumePayment resumes the paymentLifecycle from the current state. func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { // When the payment lifecycle loop exits, we make sure to signal any @@ -127,20 +188,12 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { // critical error during path finding. lifecycle: for { - // Start by quickly checking if there are any outcomes already - // available to handle before we reevaluate our state. - if err := p.checkShards(); err != nil { - return exitWithErr(err) - } - // We update the payment state on every iteration. Since the // payment state is affected by multiple goroutines (ie, // collectResultAsync), it is NOT guaranteed that we always // have the latest state here. This is fine as long as the // state is consistent as a whole. - - // Fetch the latest payment from db. 
- payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + payment, err = p.router.cfg.Control.FetchPayment(p.identifier) if err != nil { return exitWithErr(err) } @@ -153,53 +206,14 @@ lifecycle: p.identifier, payment.Terminated(), ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees) - // TODO(yy): sanity check all the states to make sure - // everything is expected. - // We have a terminal condition and no active shards, we are - // ready to exit. - if payment.Terminated() { - // Find the first successful shard and return - // the preimage and route. - for _, a := range payment.GetHTLCs() { - if a.Settle == nil { - continue - } - - err := p.router.cfg.Control.DeleteFailedAttempts( - p.identifier, - ) - if err != nil { - log.Errorf("Error deleting failed "+ - "payment attempts for "+ - "payment %v: %v", p.identifier, - err) - } - - return a.Settle.Preimage, &a.Route, nil - } - - // Payment failed. - return exitWithErr(*payment.GetFailureReason()) - } - - // If we either reached a terminal error condition (but had - // active shards still) or there is no remaining value to send, - // we'll wait for a shard outcome. - wait, err := payment.NeedWaitAttempts() - if err != nil { - return exitWithErr(err) - } - - if wait { - // We still have outstanding shards, so wait for a new - // outcome to be available before re-evaluating our - // state. - if err := p.waitForShard(); err != nil { - return exitWithErr(err) - } - continue lifecycle - } - + // We now proceed our lifecycle with the following tasks in + // order, + // 1. check timeout. + // 2. request route. + // 3. create HTLC attempt. + // 4. send HTLC attempt. + // 5. collect HTLC attempt result. + // // Before we attempt any new shard, we'll check to see if // either we've gone past the payment attempt timeout, or the // router is exiting. In either case, we'll stop this payment @@ -209,6 +223,30 @@ lifecycle: return exitWithErr(err) } + // Now decide the next step of the current lifecycle. 
+ step, err := p.decideNextStep(payment) + if err != nil { + return exitWithErr(err) + } + + switch step { + // Exit the for loop and return below. + case stepExit: + break lifecycle + + // Continue the for loop and skip the rest. + case stepSkip: + continue lifecycle + + // Continue the for loop and proceed the rest. + case stepProceed: + + // Unknown step received, exit with an error. + default: + err = fmt.Errorf("unknown step: %v", step) + return exitWithErr(err) + } + // Now request a route to be used to create our HTLC attempt. rt, err := p.requestRoute(ps) if err != nil { @@ -241,6 +279,27 @@ lifecycle: p.collectResultAsync(attempt) } } + + // Once we are out the lifecycle loop, it means we've reached a + // terminal condition. We either return the settled preimage or the + // payment's failure reason. + // + // Optionally delete the failed attempts from the database. + err = p.router.cfg.Control.DeleteFailedAttempts(p.identifier) + if err != nil { + log.Errorf("Error deleting failed htlc attempts for payment "+ + "%v: %v", p.identifier, err) + } + + // Find the first successful shard and return the preimage and route. + for _, a := range payment.GetHTLCs() { + if a.Settle != nil { + return a.Settle.Preimage, &a.Route, nil + } + } + + // Otherwise return the payment failure reason. + return [32]byte{}, nil, *payment.GetFailureReason() } // checkTimeout checks whether the payment has reached its timeout. @@ -332,46 +391,9 @@ func (p *paymentLifecycle) requestRoute( return nil, nil } -// stop signals any active shard goroutine to exit and waits for them to exit. +// stop signals any active shard goroutine to exit. func (p *paymentLifecycle) stop() { close(p.quit) - p.wg.Wait() -} - -// waitForShard blocks until any of the outstanding shards return. 
-func (p *paymentLifecycle) waitForShard() error { - select { - case err := <-p.shardErrors: - return err - - case <-p.quit: - return errShardHandlerExiting - - case <-p.router.quit: - return ErrRouterShuttingDown - } -} - -// checkShards is a non-blocking method that check if any shards has finished -// their execution. -func (p *paymentLifecycle) checkShards() error { - for { - select { - case err := <-p.shardErrors: - if err != nil { - return err - } - - case <-p.quit: - return errShardHandlerExiting - - case <-p.router.quit: - return ErrRouterShuttingDown - - default: - return nil - } - } } // attemptResult holds the HTLC attempt and a possible error returned from @@ -388,38 +410,33 @@ type attemptResult struct { } // collectResultAsync launches a goroutine that will wait for the result of the -// given HTLC attempt to be available then handle its result. It will fail the -// payment with the control tower if a terminal error is encountered. +// given HTLC attempt to be available then handle its result. Once received, it +// will send the collected error, nil if none, to channel `resultCollected` +// to signal that a result is ready. func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { - // errToSend is the error to be sent to sh.shardErrors. - var errToSend error - - // handleResultErr is a function closure must be called using defer. It - // finishes collecting result by updating the payment state and send - // the error (or nil) to sh.shardErrors. - handleResultErr := func() { - // Send the error or quit. - select { - case p.shardErrors <- errToSend: - case <-p.router.quit: - case <-p.quit: - } - - p.wg.Done() - } - - p.wg.Add(1) go func() { - defer handleResultErr() - // Block until the result is available. 
_, err := p.collectResult(attempt) if err != nil { log.Errorf("Error collecting result for attempt %v "+ "in payment %v: %v", attempt.AttemptID, p.identifier, err) + } - errToSend = err + log.Debugf("Result collected for attempt %v in payment %v", + attempt.AttemptID, p.identifier) + + // Once the result is collected, we signal it by writing the + // error to `resultCollected`. + select { + // Send the signal or quit. + case p.resultCollected <- err: + + case <-p.quit: + log.Debugf("Lifecycle exiting while collecting "+ + "result for payment %v", p.identifier) + + case <-p.router.quit: return } }() diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index eca18305f..8fee4cfb5 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -21,6 +21,10 @@ import ( const stepTimeout = 5 * time.Second +var ( + dummyErr = errors.New("dummy") +) + // createTestRoute builds a route a->b->c paying the given amt to c. func createTestRoute(amt lnwire.MilliSatoshi, aliasMap map[string]route.Vertex) (*route.Route, error) { @@ -1112,3 +1116,119 @@ func TestRequestRouteFailPaymentError(t *testing.T) { // Assert that `FailPayment` is called as expected. ct.AssertExpectations(t) } + +// TestDecideNextStep checks the method `decideNextStep` behaves as expected. +func TestDecideNextStep(t *testing.T) { + t.Parallel() + + // mockReturn is used to hold the return values from AllowMoreAttempts + // or NeedWaitAttempts. + type mockReturn struct { + allowOrWait bool + err error + } + + testCases := []struct { + name string + allowMoreAttempts *mockReturn + needWaitAttempts *mockReturn + + // When the attemptResultChan has returned. + closeResultChan bool + + // Whether the router has quit. 
+ routerQuit bool + + expectedStep stateStep + expectedErr error + }{ + { + name: "allow more attempts", + allowMoreAttempts: &mockReturn{true, nil}, + expectedStep: stepProceed, + expectedErr: nil, + }, + { + name: "error on allow more attempts", + allowMoreAttempts: &mockReturn{false, dummyErr}, + expectedStep: stepExit, + expectedErr: dummyErr, + }, + { + name: "no wait and exit", + allowMoreAttempts: &mockReturn{false, nil}, + needWaitAttempts: &mockReturn{false, nil}, + expectedStep: stepExit, + expectedErr: nil, + }, + { + name: "wait returns an error", + allowMoreAttempts: &mockReturn{false, nil}, + needWaitAttempts: &mockReturn{false, dummyErr}, + expectedStep: stepExit, + expectedErr: dummyErr, + }, + + { + name: "wait and exit on result chan", + allowMoreAttempts: &mockReturn{false, nil}, + needWaitAttempts: &mockReturn{true, nil}, + closeResultChan: true, + expectedStep: stepSkip, + expectedErr: nil, + }, + { + name: "wait and exit on router quit", + allowMoreAttempts: &mockReturn{false, nil}, + needWaitAttempts: &mockReturn{true, nil}, + routerQuit: true, + expectedStep: stepExit, + expectedErr: ErrRouterShuttingDown, + }, + } + + for _, tc := range testCases { + tc := tc + + // Create a test paymentLifecycle. + p := createTestPaymentLifecycle() + + // Make a mock payment. + payment := &mockMPPayment{} + + // Mock the method AllowMoreAttempts. + payment.On("AllowMoreAttempts").Return( + tc.allowMoreAttempts.allowOrWait, + tc.allowMoreAttempts.err, + ).Once() + + // Mock the method NeedWaitAttempts. + if tc.needWaitAttempts != nil { + payment.On("NeedWaitAttempts").Return( + tc.needWaitAttempts.allowOrWait, + tc.needWaitAttempts.err, + ).Once() + } + + // Send a nil error to the attemptResultChan if requested. + if tc.closeResultChan { + p.resultCollected = make(chan error, 1) + p.resultCollected <- nil + } + + // Quit the router if requested. + if tc.routerQuit { + close(p.router.quit) + } + + // Once the setup is finished, run the test cases. 
+ t.Run(tc.name, func(t *testing.T) { + step, err := p.decideNextStep(payment) + require.Equal(t, tc.expectedStep, step) + require.ErrorIs(t, tc.expectedErr, err) + }) + + // Check the payment's methods are called as expected. + payment.AssertExpectations(t) + } +} diff --git a/routing/router_test.go b/routing/router_test.go index 6adefe12a..6073ee2be 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3470,34 +3470,44 @@ func TestSendMPPaymentSucceed(t *testing.T) { session := &mockPaymentSession{} sessionSource.On("NewPaymentSession", req).Return(session, nil) controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - // Mock the InFlightHTLCs. var ( htlcs []channeldb.HTLCAttempt numAttempts atomic.Uint32 + settled atomic.Bool + numParts = uint32(4) ) // Make a mock MPPayment. payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs). - On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}) + On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). + On("Terminated").Return(false) + controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock FetchPayment to return the payment. - controlTower.On("FetchPayment", - identifier, - ).Return(payment, nil).Run(func(args mock.Arguments) { - // When number of attempts made is less than 4, we will mock - // the payment's methods to allow the lifecycle to continue. - if numAttempts.Load() < 4 { - payment.On("Terminated").Return(false).Times(2). - On("NeedWaitAttempts").Return(false, nil).Once() - return - } + controlTower.On("FetchPayment", identifier).Return(payment, nil). + Run(func(args mock.Arguments) { + // When number of attempts made is less than 4, we will + // mock the payment's methods to allow the lifecycle to + // continue. + if numAttempts.Load() < numParts { + payment.On("AllowMoreAttempts").Return(true, nil).Once() + return + } - // Otherwise, terminate the lifecycle. - payment.On("Terminated").Return(true). 
- On("NeedWaitAttempts").Return(true, nil) - }) + if !settled.Load() { + fmt.Println("wait") + payment.On("AllowMoreAttempts").Return(false, nil).Once() + payment.On("NeedWaitAttempts").Return(true, nil).Once() + // We add another attempt to the counter to + // unblock next time. + return + } + + payment.On("AllowMoreAttempts").Return(false, nil). + On("NeedWaitAttempts").Return(false, nil) + }) // Mock SettleAttempt. preimage := lntypes.Preimage{1, 2, 3} @@ -3511,6 +3521,10 @@ func TestSendMPPaymentSucceed(t *testing.T) { payment.On("GetHTLCs").Return( []channeldb.HTLCAttempt{settledAttempt}, ) + // We want to at least wait for one settlement. + if numAttempts.Load() > 1 { + settled.Store(true) + } }) // Create a route that can send 1/4 of the total amount. This value @@ -3527,7 +3541,6 @@ func TestSendMPPaymentSucceed(t *testing.T) { controlTower.On("RegisterAttempt", identifier, mock.Anything, ).Return(nil).Run(func(args mock.Arguments) { - // Increase the counter whenever an attempt is made. numAttempts.Add(1) }) @@ -3663,29 +3676,40 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { htlcs []channeldb.HTLCAttempt numAttempts atomic.Uint32 failAttemptCount atomic.Uint32 + settled atomic.Bool ) // Make a mock MPPayment. payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs). - On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}) + On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). + On("Terminated").Return(false) + controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock FetchPayment to return the payment. - controlTower.On("FetchPayment", - identifier, - ).Return(payment, nil).Run(func(args mock.Arguments) { - // When number of attempts made is less than 6, we will mock - // the payment's methods to allow the lifecycle to continue. - if numAttempts.Load() < 6 { - payment.On("Terminated").Return(false).Times(2). 
- On("NeedWaitAttempts").Return(false, nil).Once() - return - } + controlTower.On("FetchPayment", identifier).Return(payment, nil). + Run(func(args mock.Arguments) { + // When number of attempts made is less than 6, we will + // mock the payment's methods to allow the lifecycle to + // continue. + attempts := numAttempts.Load() + if attempts < 6 { + payment.On("AllowMoreAttempts").Return(true, nil).Once() + return + } - // Otherwise, terminate the lifecycle. - payment.On("Terminated").Return(true). - On("NeedWaitAttempts").Return(true, nil) - }) + if !settled.Load() { + payment.On("AllowMoreAttempts").Return(false, nil).Once() + payment.On("NeedWaitAttempts").Return(true, nil).Once() + // We add another attempt to the counter to + // unblock next time. + numAttempts.Add(1) + return + } + + payment.On("AllowMoreAttempts").Return(false, nil). + On("NeedWaitAttempts").Return(false, nil) + }) // Create a route that can send 1/4 of the total amount. This value // will be returned by calling RequestRoute. @@ -3768,6 +3792,10 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { payment.On("GetHTLCs").Return( []channeldb.HTLCAttempt{settledAttempt}, ) + + if numAttempts.Load() > 1 { + settled.Store(true) + } }) controlTower.On("DeleteFailedAttempts", identifier).Return(nil) @@ -3885,8 +3913,8 @@ func TestSendMPPaymentFailed(t *testing.T) { // Make a mock MPPayment. payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs).Once() - payment.On("GetStatus").Return(channeldb.StatusInFlight).Once() payment.On("GetState").Return(&channeldb.MPPaymentState{}) + payment.On("Terminated").Return(false) controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock the sequential FetchPayment to return the payment. @@ -3895,21 +3923,20 @@ func TestSendMPPaymentFailed(t *testing.T) { // We want to at least send out all parts in order to // wait for them later. if numAttempts.Load() < numParts { - payment.On("Terminated").Return(false).Times(2). 
- On("NeedWaitAttempts").Return(false, nil).Once() + payment.On("AllowMoreAttempts").Return(true, nil).Once() return } // Wait if the payment wasn't failed yet. if !failed.Load() { - payment.On("Terminated").Return(false).Times(2). + payment.On("AllowMoreAttempts").Return(false, nil).Once(). On("NeedWaitAttempts").Return(true, nil).Once() - return } - payment.On("Terminated").Return(true). - On("GetHTLCs").Return(htlcs).Once() + payment.On("AllowMoreAttempts").Return(false, nil). + On("GetHTLCs").Return(htlcs).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() }) // Create a route that can send 1/4 of the total amount. This value @@ -3990,6 +4017,8 @@ func TestSendMPPaymentFailed(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, ).Return(nil) + controlTower.On("DeleteFailedAttempts", identifier).Return(nil) + // Call the actual method SendPayment on router. This is place inside a // goroutine so we can set a timeout for the whole test, in case // anything goes wrong and the test never finishes.