From 8458966f02259ae642024b1ba1e2156638e4cd11 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 23 Jun 2022 01:26:38 +0800 Subject: [PATCH 01/27] routing: remove the abstraction `shardHandler` This commit removes the unclear abstraction `shardHandler` that's used in our payment lifecycle. As we'll see in the following commits, `shardHandler` is an unnecessary layer and everything can be cleanly managed inside `paymentLifecycle`. --- routing/payment_lifecycle.go | 85 +++++++++++++++--------------------- routing/router.go | 20 +++++---- 2 files changed, 46 insertions(+), 59 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index b2c3834f2..a07542f9a 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -30,6 +30,19 @@ type paymentLifecycle struct { shardTracker shards.ShardTracker timeoutChan <-chan time.Time currentHeight int32 + + // shardErrors is a channel where errors collected by calling + // collectResultAsync will be delivered. These results are meant to be + // inspected by calling waitForShard or checkShards, and the channel + // doesn't need to be initiated if the caller is using the sync + // collectResult directly. + // TODO(yy): delete. + shardErrors chan error + + // quit is closed to signal the sub goroutines of the payment lifecycle + // to stop. + quit chan struct{} + wg sync.WaitGroup } // newPaymentLifecycle initiates a new payment lifecycle and returns it. @@ -45,6 +58,8 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, paySession: paySession, shardTracker: shardTracker, currentHeight: currentHeight, + shardErrors: make(chan error), + quit: make(chan struct{}), } // If a timeout is specified, create a timeout channel. If no timeout is @@ -78,19 +93,10 @@ func (p *paymentLifecycle) calcFeeBudget( // resumePayment resumes the paymentLifecycle from the current state. func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { - shardHandler := &shardHandler{ - router: p.router, - identifier: p.identifier, - shardTracker: p.shardTracker, - shardErrors: make(chan error), - quit: make(chan struct{}), - paySession: p.paySession, - } - // When the payment lifecycle loop exits, we make sure to signal any - // sub goroutine of the shardHandler to exit, then wait for them to + // sub goroutine of the HTLC attempt to exit, then wait for them to // return. - defer shardHandler.stop() + defer p.stop() // If we had any existing attempts outstanding, we'll start by spinning // up goroutines that'll collect their results and deliver them to the @@ -106,7 +112,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { log.Infof("Resuming payment shard %v for payment %v", a.AttemptID, p.identifier) - shardHandler.collectResultAsync(&a) + p.collectResultAsync(&a) } // exitWithErr is a helper closure that logs and returns an error. @@ -122,7 +128,7 @@ lifecycle: for { // Start by quickly checking if there are any outcomes already // available to handle before we reevaluate our state. - if err := shardHandler.checkShards(); err != nil { + if err := p.checkShards(); err != nil { return exitWithErr(err) } @@ -187,7 +193,7 @@ lifecycle: // We still have outstanding shards, so wait for a new // outcome to be available before re-evaluating our // state. - if err := shardHandler.waitForShard(); err != nil { + if err := p.waitForShard(); err != nil { return exitWithErr(err) } continue lifecycle @@ -259,7 +265,7 @@ lifecycle: // We still have active shards, we'll wait for an // outcome to be available before retrying. - if err := shardHandler.waitForShard(); err != nil { + if err := p.waitForShard(); err != nil { return exitWithErr(err) } continue lifecycle @@ -272,7 +278,7 @@ lifecycle: lastShard := rt.ReceiverAmt() == ps.RemainingAmt // We found a route to try, launch a new shard. - attempt, outcome, err := shardHandler.launchShard(rt, lastShard) + attempt, outcome, err := p.launchShard(rt, lastShard) switch { // We may get a terminal error if we've processed a shard with // a terminal state (settled or permanent failure), while we @@ -298,7 +304,7 @@ lifecycle: // We must inspect the error to know whether it was // critical or not, to decide whether we should // continue trying. - err := shardHandler.handleSwitchErr( + err := p.handleSwitchErr( attempt, outcome.err, ) if err != nil { @@ -312,39 +318,18 @@ lifecycle: // Now that the shard was successfully sent, launch a go // routine that will handle its result when its back. - shardHandler.collectResultAsync(attempt) + p.collectResultAsync(attempt) } } -// shardHandler holds what is necessary to send and collect the result of -// shards. -type shardHandler struct { - identifier lntypes.Hash - router *ChannelRouter - shardTracker shards.ShardTracker - paySession PaymentSession - - // shardErrors is a channel where errors collected by calling - // collectResultAsync will be delivered. These results are meant to be - // inspected by calling waitForShard or checkShards, and the channel - // doesn't need to be initiated if the caller is using the sync - // collectResult directly. - shardErrors chan error - - // quit is closed to signal the sub goroutines of the payment lifecycle - // to stop. - quit chan struct{} - wg sync.WaitGroup -} - // stop signals any active shard goroutine to exit and waits for them to exit. -func (p *shardHandler) stop() { +func (p *paymentLifecycle) stop() { close(p.quit) p.wg.Wait() } // waitForShard blocks until any of the outstanding shards return. -func (p *shardHandler) waitForShard() error { +func (p *paymentLifecycle) waitForShard() error { select { case err := <-p.shardErrors: return err @@ -359,7 +344,7 @@ func (p *shardHandler) waitForShard() error { // checkShards is a non-blocking method that check if any shards has finished // their execution. -func (p *shardHandler) checkShards() error { +func (p *paymentLifecycle) checkShards() error { for { select { case err := <-p.shardErrors: @@ -400,7 +385,7 @@ type launchOutcome struct { // whether the attempt was successfully sent. If the launchOutcome wraps a // non-nil error, it means that the attempt was not sent onto the network, so // no result will be available in the future for it. -func (p *shardHandler) launchShard(rt *route.Route, +func (p *paymentLifecycle) launchShard(rt *route.Route, lastShard bool) (*channeldb.HTLCAttempt, *launchOutcome, error) { // Using the route received from the payment session, create a new @@ -455,7 +440,7 @@ type shardResult struct { // collectResultAsync launches a goroutine that will wait for the result of the // given HTLC attempt to be available then handle its result. It will fail the // payment with the control tower if a terminal error is encountered. -func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttempt) { +func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { // errToSend is the error to be sent to sh.shardErrors. var errToSend error @@ -510,7 +495,7 @@ func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttempt) { // collectResult waits for the result for the given attempt to be available // from the Switch, then records the attempt outcome with the control tower. A // shardResult is returned, indicating the final outcome of this HTLC attempt. -func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttempt) ( +func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( *shardResult, error) { // We'll retrieve the hash specific to this shard from the @@ -632,7 +617,7 @@ func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttempt) ( } // createNewPaymentAttempt creates a new payment attempt from the given route. -func (p *shardHandler) createNewPaymentAttempt(rt *route.Route, +func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, lastShard bool) (*channeldb.HTLCAttempt, error) { // Generate a new key to be used for this attempt. @@ -683,7 +668,7 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route, // sendAttempt attempts to send the current attempt to the switch to complete // the payment. If this attempt fails, then we'll continue on to the next // available route. -func (p *shardHandler) sendAttempt( +func (p *paymentLifecycle) sendAttempt( attempt *channeldb.HTLCAttempt) error { log.Tracef("Attempting to send payment %v (pid=%v), "+ @@ -744,7 +729,7 @@ func (p *shardHandler) sendAttempt( // the error type, the error is either the final outcome of the payment or we // need to continue with an alternative route. A final outcome is indicated by // a non-nil reason value. -func (p *shardHandler) handleSwitchErr(attempt *channeldb.HTLCAttempt, +func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, sendErr error) error { internalErrorReason := channeldb.FailureReasonError @@ -841,7 +826,7 @@ func (p *shardHandler) handleSwitchErr(attempt *channeldb.HTLCAttempt, // handleFailureMessage tries to apply a channel update present in the failure // message if any. -func (p *shardHandler) handleFailureMessage(rt *route.Route, +func (p *paymentLifecycle) handleFailureMessage(rt *route.Route, errorSourceIdx int, failure lnwire.FailureMessage) error { if failure == nil { @@ -913,7 +898,7 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route, } // failAttempt calls control tower to fail the current payment attempt. -func (p *shardHandler) failAttempt(attemptID uint64, +func (p *paymentLifecycle) failAttempt(attemptID uint64, sendError error) (*channeldb.HTLCAttempt, error) { log.Warnf("Attempt %v for payment %v failed: %v", attemptID, diff --git a/routing/router.go b/routing/router.go index fa255544c..2e679b4bb 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2496,15 +2496,17 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // shard we'll now launch. shardTracker := shards.NewSimpleShardTracker(htlcHash, nil) - // Launch a shard along the given route. - sh := &shardHandler{ - router: r, - identifier: paymentIdentifier, - shardTracker: shardTracker, - } + // Create a payment lifecycle using the given route with, + // - zero fee limit as we are not requesting routes. + // - nil payment session (since we already have a route). + // - no payment timeout. + // - no current block height. + p := newPaymentLifecycle( + r, 0, paymentIdentifier, nil, shardTracker, 0, 0, + ) var shardError error - attempt, outcome, err := sh.launchShard(rt, false) + attempt, outcome, err := p.launchShard(rt, false) // With SendToRoute, it can happen that the route exceeds protocol // constraints. Mark the payment as failed with an internal error. @@ -2536,7 +2538,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // Shard successfully launched, wait for the result to be available. default: - result, err := sh.collectResult(attempt) + result, err := p.collectResult(attempt) if err != nil { return nil, err } @@ -2556,7 +2558,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // the error to check if it maps into a terminal error code, if not use // a generic NO_ROUTE error. var failureReason *channeldb.FailureReason - err = sh.handleSwitchErr(attempt, shardError) + err = p.handleSwitchErr(attempt, shardError) switch { // If a non-terminal error is returned and `skipTempErr` is false, then From 4bb8db46dfb37988ef2eed07bb819e5bac14580a Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 7 Mar 2023 18:05:36 +0800 Subject: [PATCH 02/27] routing: fail payment before attempt inside `handleSwitchErr` `handleSwitchErr` is now responsible for failing the given HTLC attempt after deciding to fail the payment or not. This is crucial as previously, we might enter into a state where the payment's HTLC has already been marked as failed, and while we are marking the payment as failed, another HTLC attempt can be made at the same time, leading to potential stuck payments. --- routing/payment_lifecycle.go | 78 ++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index a07542f9a..be223b483 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -721,6 +721,30 @@ func (p *paymentLifecycle) sendAttempt( return nil } +// failAttemptAndPayment fails both the payment and its attempt via the +// router's control tower, which marks the payment as failed in db. +func (p *paymentLifecycle) failPaymentAndAttempt( + attemptID uint64, reason *channeldb.FailureReason, + sendErr error) (*channeldb.HTLCAttempt, error) { + + log.Errorf("Payment %v failed: final_outcome=%v, raw_err=%v", + p.identifier, *reason, sendErr) + + // Fail the payment via control tower. + // + // NOTE: we must fail the payment first before failing the attempt. + // Otherwise, once the attempt is marked as failed, another goroutine + // might make another attempt while we are failing the payment. + err := p.router.cfg.Control.FailPayment(p.identifier, *reason) + if err != nil { + log.Errorf("Unable to fail payment: %v", err) + return nil, err + } + + // Fail the attempt. + return p.failAttempt(attemptID, sendErr) +} + // handleSwitchErr inspects the given error from the Switch and determines // whether we should make another payment attempt, or if it should be // considered a terminal error. Terminal errors will be recorded with the @@ -733,36 +757,18 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, sendErr error) error { internalErrorReason := channeldb.FailureReasonError + attemptID := attempt.AttemptID - // failPayment is a helper closure that fails the payment via the - // router's control tower, which marks the payment as failed in db. - failPayment := func(reason *channeldb.FailureReason, - sendErr error) error { - - log.Infof("Payment %v failed: final_outcome=%v, raw_err=%v", - p.identifier, *reason, sendErr) - - // Fail the payment via control tower. - if err := p.router.cfg.Control.FailPayment( - p.identifier, *reason, - ); err != nil { - log.Errorf("unable to report failure to control "+ - "tower: %v", err) - - return &internalErrorReason - } - - return reason - } - - // reportFail is a helper closure that reports the failure to the + // reportAndFail is a helper closure that reports the failure to the // mission control, which helps us to decide whether we want to retry // the payment or not. If a non nil reason is returned from mission // control, it will further fail the payment via control tower. - reportFail := func(srcIdx *int, msg lnwire.FailureMessage) error { + reportAndFail := func(srcIdx *int, + msg lnwire.FailureMessage) (*channeldb.HTLCAttempt, error) { + // Report outcome to mission control. reason, err := p.router.cfg.MissionControl.ReportPaymentFail( - attempt.AttemptID, &attempt.Route, srcIdx, msg, + attemptID, &attempt.Route, srcIdx, msg, ) if err != nil { log.Errorf("Error reporting payment result to mc: %v", @@ -771,18 +777,21 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, reason = &internalErrorReason } - // Exit early if there's no reason. + // Fail the attempt only if there's no reason. if reason == nil { - return nil + // Fail the attempt. + return p.failAttempt(attemptID, sendErr) } - return failPayment(reason, sendErr) + // Otherwise fail both the payment and the attempt. + return p.failPaymentAndAttempt(attemptID, reason, sendErr) } if sendErr == htlcswitch.ErrUnreadableFailureMessage { log.Tracef("Unreadable failure when sending htlc") - return reportFail(nil, nil) + _, err := reportAndFail(nil, nil) + return err } // If the error is a ClearTextError, we have received a valid wire @@ -792,7 +801,10 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, // occurred. rtErr, ok := sendErr.(htlcswitch.ClearTextError) if !ok { - return failPayment(&internalErrorReason, sendErr) + _, err := p.failPaymentAndAttempt( + attemptID, &internalErrorReason, sendErr, + ) + return err } // failureSourceIdx is the index of the node that the failure occurred @@ -815,13 +827,17 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, &attempt.Route, failureSourceIdx, failureMessage, ) if err != nil { - return failPayment(&internalErrorReason, sendErr) + _, err := p.failPaymentAndAttempt( + attemptID, &internalErrorReason, sendErr, + ) + return err } log.Tracef("Node=%v reported failure when sending htlc", failureSourceIdx) - return reportFail(&failureSourceIdx, failureMessage) + _, err = reportAndFail(&failureSourceIdx, failureMessage) + return err } // handleFailureMessage tries to apply a channel update present in the failure From 071d05e0e3cfd9b6b8e1364a611501419a9a37d2 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 7 Mar 2023 18:18:52 +0800 Subject: [PATCH 03/27] routing: unify `shardResult` and `launchOutcome` to be `attemptResult` This commit removes the `launchOutcome` and `shardResult` and uses `attemptResult` instead. This struct is also used in `failAttempt` so we can future distinguish critical vs non-critical errors when handling HTLC attempts. --- routing/payment_lifecycle.go | 72 ++++++++++++++---------------------- 1 file changed, 28 insertions(+), 44 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index be223b483..6251d9c83 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -364,12 +364,12 @@ func (p *paymentLifecycle) checkShards() error { } } -// launchOutcome is a type returned from launchShard that indicates whether the -// shard was successfully send onto the network. -type launchOutcome struct { +// attemptResult holds the HTLC attempt and a possible error returned from +// sending it. +type attemptResult struct { // err is non-nil if a non-critical error was encountered when trying - // to send the shard, and we successfully updated the control tower to - // reflect this error. This can be errors like not enough local + // to send the attempt, and we successfully updated the control tower + // to reflect this error. This can be errors like not enough local // balance for the given route etc. err error @@ -386,7 +386,7 @@ type launchOutcome struct { // non-nil error, it means that the attempt was not sent onto the network, so // no result will be available in the future for it. func (p *paymentLifecycle) launchShard(rt *route.Route, - lastShard bool) (*channeldb.HTLCAttempt, *launchOutcome, error) { + lastShard bool) (*channeldb.HTLCAttempt, *attemptResult, error) { // Using the route received from the payment session, create a new // shard to send. @@ -419,22 +419,13 @@ func (p *paymentLifecycle) launchShard(rt *route.Route, } // Return a launchOutcome indicating the shard failed. - return attempt, &launchOutcome{ - attempt: htlcAttempt, + return attempt, &attemptResult{ + attempt: htlcAttempt.attempt, err: sendErr, }, nil } - return attempt, &launchOutcome{}, nil -} - -// shardResult holds the resulting outcome of a shard sent. -type shardResult struct { - // attempt is the attempt structure as recorded in the database. - attempt *channeldb.HTLCAttempt - - // err indicates that the shard failed. - err error + return attempt, &attemptResult{}, nil } // collectResultAsync launches a goroutine that will wait for the result of the @@ -493,10 +484,11 @@ func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { } // collectResult waits for the result for the given attempt to be available -// from the Switch, then records the attempt outcome with the control tower. A -// shardResult is returned, indicating the final outcome of this HTLC attempt. +// from the Switch, then records the attempt outcome with the control tower. +// An attemptResult is returned, indicating the final outcome of this HTLC +// attempt. func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( - *shardResult, error) { + *attemptResult, error) { // We'll retrieve the hash specific to this shard from the // shardTracker, since it will be needed to regenerate the circuit @@ -536,15 +528,7 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( "the Switch, retrying.", attempt.AttemptID, p.identifier) - attempt, cErr := p.failAttempt(attempt.AttemptID, err) - if cErr != nil { - return nil, cErr - } - - return &shardResult{ - attempt: attempt, - err: err, - }, nil + return p.failAttempt(attempt.AttemptID, err) // A critical, unexpected error was encountered. case err != nil: @@ -574,15 +558,7 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( // In case of a payment failure, fail the attempt with the control // tower and return. if result.Error != nil { - attempt, err := p.failAttempt(attempt.AttemptID, result.Error) - if err != nil { - return nil, err - } - - return &shardResult{ - attempt: attempt, - err: result.Error, - }, nil + return p.failAttempt(attempt.AttemptID, result.Error) } // We successfully got a payment result back from the switch. @@ -611,7 +587,7 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( return nil, err } - return &shardResult{ + return &attemptResult{ attempt: htlcAttempt, }, nil } @@ -725,7 +701,7 @@ func (p *paymentLifecycle) sendAttempt( // router's control tower, which marks the payment as failed in db. func (p *paymentLifecycle) failPaymentAndAttempt( attemptID uint64, reason *channeldb.FailureReason, - sendErr error) (*channeldb.HTLCAttempt, error) { + sendErr error) (*attemptResult, error) { log.Errorf("Payment %v failed: final_outcome=%v, raw_err=%v", p.identifier, *reason, sendErr) @@ -764,7 +740,7 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, // the payment or not. If a non nil reason is returned from mission // control, it will further fail the payment via control tower. reportAndFail := func(srcIdx *int, - msg lnwire.FailureMessage) (*channeldb.HTLCAttempt, error) { + msg lnwire.FailureMessage) (*attemptResult, error) { // Report outcome to mission control. reason, err := p.router.cfg.MissionControl.ReportPaymentFail( @@ -915,7 +891,7 @@ func (p *paymentLifecycle) handleFailureMessage(rt *route.Route, // failAttempt calls control tower to fail the current payment attempt. func (p *paymentLifecycle) failAttempt(attemptID uint64, - sendError error) (*channeldb.HTLCAttempt, error) { + sendError error) (*attemptResult, error) { log.Warnf("Attempt %v for payment %v failed: %v", attemptID, p.identifier, sendError) @@ -932,9 +908,17 @@ func (p *paymentLifecycle) failAttempt(attemptID uint64, return nil, err } - return p.router.cfg.Control.FailAttempt( + attempt, err := p.router.cfg.Control.FailAttempt( p.identifier, attemptID, failInfo, ) + if err != nil { + return nil, err + } + + return &attemptResult{ + attempt: attempt, + err: sendError, + }, nil } // marshallError marshall an error as received from the switch to a structure From 568b977a1fc659f8f9964555db4af6f6cb925e93 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 15 Feb 2023 03:05:09 +0800 Subject: [PATCH 04/27] routing: add new method `registerAttempt` This commit adds a new method `registerAttempt` to take care of creating and saving an htlc attempt to disk. --- routing/payment_lifecycle.go | 54 ++++++++++++++++++++++-------------- routing/router.go | 2 +- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 6251d9c83..facb9168b 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -273,12 +273,8 @@ lifecycle: log.Tracef("Found route: %s", spew.Sdump(rt.Hops)) - // If this route will consume the last remaining amount to send - // to the receiver, this will be our last shard (for now). - lastShard := rt.ReceiverAmt() == ps.RemainingAmt - // We found a route to try, launch a new shard. - attempt, outcome, err := p.launchShard(rt, lastShard) + attempt, outcome, err := p.launchShard(rt, ps.RemainingAmt) switch { // We may get a terminal error if we've processed a shard with // a terminal state (settled or permanent failure), while we @@ -386,23 +382,10 @@ type attemptResult struct { // non-nil error, it means that the attempt was not sent onto the network, so // no result will be available in the future for it. func (p *paymentLifecycle) launchShard(rt *route.Route, - lastShard bool) (*channeldb.HTLCAttempt, *attemptResult, error) { + remainingAmt lnwire.MilliSatoshi) (*channeldb.HTLCAttempt, + *attemptResult, error) { - // Using the route received from the payment session, create a new - // shard to send. - attempt, err := p.createNewPaymentAttempt(rt, lastShard) - if err != nil { - return nil, nil, err - } - - // Before sending this HTLC to the switch, we checkpoint the fresh - // paymentID and route to the DB. This lets us know on startup the ID - // of the payment that we attempted to send, such that we can query the - // Switch for its whereabouts. The route is needed to handle the result - // when it eventually comes back. - err = p.router.cfg.Control.RegisterAttempt( - p.identifier, &attempt.HTLCAttemptInfo, - ) + attempt, err := p.registerAttempt(rt, remainingAmt) if err != nil { return nil, nil, err } @@ -592,6 +575,35 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( }, nil } +// registerAttempt is responsible for creating and saving an HTLC attempt in db +// by using the route info provided. The `remainingAmt` is used to decide +// whether this is the last attempt. +func (p *paymentLifecycle) registerAttempt(rt *route.Route, + remainingAmt lnwire.MilliSatoshi) (*channeldb.HTLCAttempt, error) { + + // If this route will consume the last remaining amount to send + // to the receiver, this will be our last shard (for now). + isLastAttempt := rt.ReceiverAmt() == remainingAmt + + // Using the route received from the payment session, create a new + // shard to send. + attempt, err := p.createNewPaymentAttempt(rt, isLastAttempt) + if err != nil { + return nil, err + } + + // Before sending this HTLC to the switch, we checkpoint the fresh + // paymentID and route to the DB. This lets us know on startup the ID + // of the payment that we attempted to send, such that we can query the + // Switch for its whereabouts. The route is needed to handle the result + // when it eventually comes back. + err = p.router.cfg.Control.RegisterAttempt( + p.identifier, &attempt.HTLCAttemptInfo, + ) + + return attempt, err +} + // createNewPaymentAttempt creates a new payment attempt from the given route. func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, lastShard bool) (*channeldb.HTLCAttempt, error) { diff --git a/routing/router.go b/routing/router.go index 2e679b4bb..d20932c28 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2506,7 +2506,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, ) var shardError error - attempt, outcome, err := p.launchShard(rt, false) + attempt, outcome, err := p.launchShard(rt, 0) // With SendToRoute, it can happen that the route exceeds protocol // constraints. Mark the payment as failed with an internal error. From 49bafc02078788c11d8db5c80e601b3141136914 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 7 Mar 2023 18:39:26 +0800 Subject: [PATCH 05/27] routing: handle switch error when `sendAttempt` fails This commit starts handling switch error inside `sendAttempt` when an error is returned from sending the HTLC. To make sure the updated `HTLCAttempt` is always returned to the callsite, `handleSwitchErr` now also returns a `attemptResult`. --- routing/payment_lifecycle.go | 41 +++++++++++++++++++++--------------- routing/router.go | 2 +- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index facb9168b..7f64f0eeb 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -300,7 +300,7 @@ lifecycle: // We must inspect the error to know whether it was // critical or not, to decide whether we should // continue trying. - err := p.handleSwitchErr( + _, err := p.handleSwitchErr( attempt, outcome.err, ) if err != nil { @@ -392,7 +392,7 @@ func (p *paymentLifecycle) launchShard(rt *route.Route, // Now that the attempt is created and checkpointed to the DB, we send // it. - sendErr := p.sendAttempt(attempt) + _, sendErr := p.sendAttempt(attempt) if sendErr != nil { // TODO(joostjager): Distinguish unexpected internal errors // from real send errors. @@ -460,7 +460,7 @@ func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { // Overwrite the param errToSend and return so that the // defer function will use the param to proceed. Notice // that the errToSend could be nil here. - errToSend = p.handleSwitchErr(attempt, result.err) + _, errToSend = p.handleSwitchErr(attempt, result.err) return } }() @@ -657,7 +657,7 @@ func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, // the payment. If this attempt fails, then we'll continue on to the next // available route. func (p *paymentLifecycle) sendAttempt( - attempt *channeldb.HTLCAttempt) error { + attempt *channeldb.HTLCAttempt) (*attemptResult, error) { log.Tracef("Attempting to send payment %v (pid=%v), "+ "using route: %v", p.identifier, attempt.AttemptID, @@ -687,8 +687,13 @@ func (p *paymentLifecycle) sendAttempt( &rt, attempt.Hash[:], attempt.SessionKey(), ) if err != nil { - return err + log.Errorf("Failed to create onion blob: attempt=%d in "+ + "payment=%v, err:%v", attempt.AttemptID, + p.identifier, err) + + return p.failAttempt(attempt.AttemptID, err) } + copy(htlcAdd.OnionBlob[:], onionBlob) // Send it to the Switch. When this method returns we assume @@ -700,13 +705,15 @@ func (p *paymentLifecycle) sendAttempt( log.Errorf("Failed sending attempt %d for payment %v to "+ "switch: %v", attempt.AttemptID, p.identifier, err) - return err + return p.handleSwitchErr(attempt, err) } log.Debugf("Payment %v (pid=%v) successfully sent to switch, route: %v", p.identifier, attempt.AttemptID, &attempt.Route) - return nil + return &attemptResult{ + attempt: attempt, + }, nil } // failAttemptAndPayment fails both the payment and its attempt via the @@ -742,7 +749,7 @@ func (p *paymentLifecycle) failPaymentAndAttempt( // need to continue with an alternative route. A final outcome is indicated by // a non-nil reason value. func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, - sendErr error) error { + sendErr error) (*attemptResult, error) { internalErrorReason := channeldb.FailureReasonError attemptID := attempt.AttemptID @@ -776,10 +783,13 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, } if sendErr == htlcswitch.ErrUnreadableFailureMessage { - log.Tracef("Unreadable failure when sending htlc") + log.Warn("Unreadable failure when sending htlc: id=%v, hash=%v", + attempt.AttemptID, attempt.Hash) - _, err := reportAndFail(nil, nil) - return err + // Since this error message cannot be decrypted, we will send a + // nil error message to our mission controller and fail the + // payment. + return reportAndFail(nil, nil) } // If the error is a ClearTextError, we have received a valid wire @@ -789,10 +799,9 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, // occurred. rtErr, ok := sendErr.(htlcswitch.ClearTextError) if !ok { - _, err := p.failPaymentAndAttempt( + return p.failPaymentAndAttempt( attemptID, &internalErrorReason, sendErr, ) - return err } // failureSourceIdx is the index of the node that the failure occurred @@ -815,17 +824,15 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, &attempt.Route, failureSourceIdx, failureMessage, ) if err != nil { - _, err := p.failPaymentAndAttempt( + return p.failPaymentAndAttempt( attemptID, &internalErrorReason, sendErr, ) - return err } log.Tracef("Node=%v reported failure when sending htlc", failureSourceIdx) - _, err = reportAndFail(&failureSourceIdx, failureMessage) - return err + return reportAndFail(&failureSourceIdx, failureMessage) } // handleFailureMessage tries to apply a channel update present in the failure diff --git a/routing/router.go b/routing/router.go index d20932c28..08700526e 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2558,7 +2558,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // the error to check if it maps into a terminal error code, if not use // a generic NO_ROUTE error. var failureReason *channeldb.FailureReason - err = p.handleSwitchErr(attempt, shardError) + _, err = p.handleSwitchErr(attempt, shardError) switch { // If a non-terminal error is returned and `skipTempErr` is false, then From 7209c65ccfab46ede9296a1b141982cf06d21097 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 7 Mar 2023 18:43:14 +0800 Subject: [PATCH 06/27] routing: split `launchShard` into registerAttempt and sendAttempt This commit removes the method `launchShard` and splits its original functionality into two steps - first create the attempt, second send the attempt. This enables us to have finer control over "which error is returned from which system and how to handle it". --- channeldb/mp_payment.go | 1 + routing/payment_lifecycle.go | 95 ++++------------------- routing/payment_lifecycle_test.go | 9 +++ routing/router.go | 122 +++++++++++++++--------------- routing/router_test.go | 113 ++++++++++++++------------- 5 files changed, 146 insertions(+), 194 deletions(-) diff --git a/channeldb/mp_payment.go b/channeldb/mp_payment.go index 1765b94ad..30c595049 100644 --- a/channeldb/mp_payment.go +++ b/channeldb/mp_payment.go @@ -264,6 +264,7 @@ func (m *MPPayment) InFlightHTLCs() []HTLCAttempt { // GetAttempt returns the specified htlc attempt on the payment. func (m *MPPayment) GetAttempt(id uint64) (*HTLCAttempt, error) { + // TODO(yy): iteration can be slow, make it into a tree or use BS. for _, htlc := range m.HTLCs { htlc := htlc if htlc.AttemptID == id { diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 7f64f0eeb..6539d1a60 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -273,48 +273,23 @@ lifecycle: log.Tracef("Found route: %s", spew.Sdump(rt.Hops)) - // We found a route to try, launch a new shard. - attempt, outcome, err := p.launchShard(rt, ps.RemainingAmt) - switch { - // We may get a terminal error if we've processed a shard with - // a terminal state (settled or permanent failure), while we - // were pathfinding. We know we're in a terminal state here, - // so we can continue and wait for our last shards to return. - case err == channeldb.ErrPaymentTerminal: - log.Infof("Payment %v in terminal state, abandoning "+ - "shard", p.identifier) - - continue lifecycle - - case err != nil: + // We found a route to try, create a new HTLC attempt to try. + attempt, err := p.registerAttempt(rt, ps.RemainingAmt) + if err != nil { return exitWithErr(err) } - // If we encountered a non-critical error when launching the - // shard, handle it. - if outcome.err != nil { - log.Warnf("Failed to launch shard %v for "+ - "payment %v: %v", attempt.AttemptID, - p.identifier, outcome.err) - - // We must inspect the error to know whether it was - // critical or not, to decide whether we should - // continue trying. - _, err := p.handleSwitchErr( - attempt, outcome.err, - ) - if err != nil { - return exitWithErr(err) - } - - // Error was handled successfully, continue to make a - // new attempt. - continue lifecycle + // Once the attempt is created, send it to the htlcswitch. + result, err := p.sendAttempt(attempt) + if err != nil { + return exitWithErr(err) } // Now that the shard was successfully sent, launch a go // routine that will handle its result when its back. - p.collectResultAsync(attempt) + if result.err == nil { + p.collectResultAsync(attempt) + } } } @@ -373,44 +348,6 @@ type attemptResult struct { attempt *channeldb.HTLCAttempt } -// launchShard creates and sends an HTLC attempt along the given route, -// registering it with the control tower before sending it. The lastShard -// argument should be true if this shard will consume the remainder of the -// amount to send. It returns the HTLCAttemptInfo that was created for the -// shard, along with a launchOutcome. The launchOutcome is used to indicate -// whether the attempt was successfully sent. If the launchOutcome wraps a -// non-nil error, it means that the attempt was not sent onto the network, so -// no result will be available in the future for it. -func (p *paymentLifecycle) launchShard(rt *route.Route, - remainingAmt lnwire.MilliSatoshi) (*channeldb.HTLCAttempt, - *attemptResult, error) { - - attempt, err := p.registerAttempt(rt, remainingAmt) - if err != nil { - return nil, nil, err - } - - // Now that the attempt is created and checkpointed to the DB, we send - // it. - _, sendErr := p.sendAttempt(attempt) - if sendErr != nil { - // TODO(joostjager): Distinguish unexpected internal errors - // from real send errors. - htlcAttempt, err := p.failAttempt(attempt.AttemptID, sendErr) - if err != nil { - return nil, nil, err - } - - // Return a launchOutcome indicating the shard failed. - return attempt, &attemptResult{ - attempt: htlcAttempt.attempt, - err: sendErr, - }, nil - } - - return attempt, &attemptResult{}, nil -} - // collectResultAsync launches a goroutine that will wait for the result of the // given HTLC attempt to be available then handle its result. It will fail the // payment with the control tower if a terminal error is encountered. @@ -659,12 +596,8 @@ func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, func (p *paymentLifecycle) sendAttempt( attempt *channeldb.HTLCAttempt) (*attemptResult, error) { - log.Tracef("Attempting to send payment %v (pid=%v), "+ - "using route: %v", p.identifier, attempt.AttemptID, - newLogClosure(func() string { - return spew.Sdump(attempt.Route) - }), - ) + log.Debugf("Attempting to send payment %v (pid=%v)", p.identifier, + attempt.AttemptID) rt := attempt.Route @@ -708,8 +641,8 @@ func (p *paymentLifecycle) sendAttempt( return p.handleSwitchErr(attempt, err) } - log.Debugf("Payment %v (pid=%v) successfully sent to switch, route: %v", - p.identifier, attempt.AttemptID, &attempt.Route) + log.Debugf("Attempt %v for payment %v successfully sent to switch, "+ + "route: %v", attempt.AttemptID, p.identifier, &attempt.Route) return &attemptResult{ attempt: attempt, diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index c1dfe4e44..d65ba78d5 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -800,6 +800,15 @@ func makeSettledAttempt(total, fee int, } } +func makeFailedAttempt(total, fee int) *channeldb.HTLCAttempt { + return &channeldb.HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + Failure: &channeldb.HTLCFailInfo{ + Reason: channeldb.HTLCFailInternal, + }, + } +} + func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo { hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)} return channeldb.HTLCAttemptInfo{ diff --git a/routing/router.go b/routing/router.go index 08700526e..0b442da15 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2,7 +2,6 @@ package routing import ( "bytes" - goErrors "errors" "fmt" "math" "runtime" @@ -2505,81 +2504,82 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, r, 0, paymentIdentifier, nil, shardTracker, 0, 0, ) - var shardError error - attempt, outcome, err := p.launchShard(rt, 0) - - // With SendToRoute, it can happen that the route exceeds protocol - // constraints. Mark the payment as failed with an internal error. - if err == route.ErrMaxRouteHopsExceeded || - err == sphinx.ErrMaxRoutingInfoSizeExceeded { - - log.Debugf("Invalid route provided for payment %x: %v", - paymentIdentifier, err) - - controlErr := r.cfg.Control.FailPayment( - paymentIdentifier, channeldb.FailureReasonError, - ) - if controlErr != nil { - return nil, controlErr - } - } - - // In any case, don't continue if there is an error. + // We found a route to try, create a new HTLC attempt to try. + // + // NOTE: we use zero `remainingAmt` here to simulate the same effect of + // setting the lastShard to be false, which is used by previous + // implementation. + attempt, err := p.registerAttempt(rt, 0) if err != nil { return nil, err } - var htlcAttempt *channeldb.HTLCAttempt - switch { - // Failed to launch shard. - case outcome.err != nil: - shardError = outcome.err - htlcAttempt = outcome.attempt + // Once the attempt is created, send it to the htlcswitch. Notice that + // the `err` returned here has already been processed by + // `handleSwitchErr`, which means if there's a terminal failure, the + // payment has been failed. + result, err := p.sendAttempt(attempt) + if err != nil { + return nil, err + } - // Shard successfully launched, wait for the result to be available. - default: - result, err := p.collectResult(attempt) + // We now lookup the payment to see if it's already failed. + payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + if err != nil { + return result.attempt, err + } + + // Exit if the above error has caused the payment to be failed, we also + // return the error from sending attempt to mimic the old behavior of + // this method. + if payment.GetFailureReason() != nil { + return result.attempt, result.err + } + + // Since for SendToRoute we won't retry in case the shard fails, we'll + // mark the payment failed with the control tower immediately if the + // skipTempErr is false. + reason := channeldb.FailureReasonError + + // If we failed to send the HTLC, we need to further decide if we want + // to fail the payment. + if result.err != nil { + // If skipTempErr, we'll return the attempt and the temp error. + if skipTempErr { + return result.attempt, result.err + } + + // Otherwise we need to fail the payment. + err := r.cfg.Control.FailPayment(paymentIdentifier, reason) if err != nil { return nil, err } - // We got a successful result. - if result.err == nil { - return result.attempt, nil - } - - // The shard failed, break switch to handle it. - shardError = result.err - htlcAttempt = result.attempt + return result.attempt, result.err } - // Since for SendToRoute we won't retry in case the shard fails, we'll - // mark the payment failed with the control tower immediately. Process - // the error to check if it maps into a terminal error code, if not use - // a generic NO_ROUTE error. - var failureReason *channeldb.FailureReason - _, err = p.handleSwitchErr(attempt, shardError) - - switch { - // If a non-terminal error is returned and `skipTempErr` is false, then - // we'll use the normal no route error. - case err == nil && !skipTempErr: - err = r.cfg.Control.FailPayment( - paymentIdentifier, channeldb.FailureReasonNoRoute, - ) - - // If this is a failure reason, then we'll apply the failure directly - // to the control tower, and return the normal response to the caller. - case goErrors.As(err, &failureReason): - err = r.cfg.Control.FailPayment( - paymentIdentifier, *failureReason, - ) - } + // The attempt was successfully sent, wait for the result to be + // available. + result, err = p.collectResult(attempt) if err != nil { return nil, err } - return htlcAttempt, shardError + // We got a successful result. + if result.err == nil { + return result.attempt, nil + } + + // An error returned from collecting the result, we'll mark the payment + // as failed if we don't skip temp error. + if !skipTempErr { + err := r.cfg.Control.FailPayment(paymentIdentifier, reason) + if err != nil { + return nil, err + } + } + + return result.attempt, result.err } // sendPayment attempts to send a payment to the passed payment hash. This diff --git a/routing/router_test.go b/routing/router_test.go index 14e2ae839..f18e80024 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3025,8 +3025,8 @@ func TestSendToRouteMaxHops(t *testing.T) { // Send off the payment request to the router. We expect an error back // indicating that the route is too long. - var payment lntypes.Hash - _, err = ctx.router.SendToRoute(payment, rt) + var payHash lntypes.Hash + _, err = ctx.router.SendToRoute(payHash, rt) if err != route.ErrMaxRouteHopsExceeded { t.Fatalf("expected ErrMaxRouteHopsExceeded, but got %v", err) } @@ -4272,11 +4272,13 @@ func TestBlockDifferenceFix(t *testing.T) { // TestSendToRouteSkipTempErrSuccess validates a successful payment send. func TestSendToRouteSkipTempErrSuccess(t *testing.T) { var ( - payHash lntypes.Hash - payAmt = lnwire.MilliSatoshi(10000) - testAttempt = &channeldb.HTLCAttempt{} + payHash lntypes.Hash + payAmt = lnwire.MilliSatoshi(10000) ) + preimage := lntypes.Preimage{1} + testAttempt := makeSettledAttempt(int(payAmt), 0, preimage) + node, err := createTestNode() require.NoError(t, err) @@ -4313,7 +4315,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) controlTower.On("SettleAttempt", payHash, mock.Anything, mock.Anything, - ).Return(testAttempt, nil) + ).Return(&testAttempt, nil) payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, @@ -4332,15 +4334,23 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { mock.Anything, rt, ).Return(nil) + // Mock the control tower to return the mocked payment. + payment := &mockMPPayment{} + controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() + + // Mock the payment to return nil failrue reason. + payment.On("GetFailureReason").Return(nil) + // Expect a successful send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt) require.NoError(t, err) - require.Equal(t, testAttempt, attempt) + require.Equal(t, &testAttempt, attempt) // Assert the above methods are called as expected. controlTower.AssertExpectations(t) payer.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestSendToRouteSkipTempErrTempFailure validates a temporary failure won't @@ -4413,11 +4423,19 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { } }) - // Return a nil reason to mock a temporary failure. + // Mock the control tower to return the mocked payment. + payment := &mockMPPayment{} + controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() + + // Mock the mission control to return a nil reason from reporting the + // attempt failure. missionControl.On("ReportPaymentFail", mock.Anything, rt, mock.Anything, mock.Anything, ).Return(nil, nil) + // Mock the payment to return nil failrue reason. + payment.On("GetFailureReason").Return(nil) + // Expect a failed send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt) require.Equal(t, tempErr, err) @@ -4427,17 +4445,18 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { controlTower.AssertExpectations(t) payer.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestSendToRouteSkipTempErrPermanentFailure validates a permanent failure // will fail the payment. func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { var ( - payHash lntypes.Hash - payAmt = lnwire.MilliSatoshi(10000) - testAttempt = &channeldb.HTLCAttempt{} + payHash lntypes.Hash + payAmt = lnwire.MilliSatoshi(10000) ) + testAttempt := makeFailedAttempt(int(payAmt), 0) node, err := createTestNode() require.NoError(t, err) @@ -4469,9 +4488,15 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { }, }} + // Create the error to be returned. + permErr := htlcswitch.NewForwardingError( + &lnwire.FailIncorrectDetails{}, 1, + ) + // Register mockers with the expected method calls. controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) + controlTower.On("FailAttempt", payHash, mock.Anything, mock.Anything, ).Return(testAttempt, nil) @@ -4479,34 +4504,23 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { // Expect the payment to be failed. controlTower.On("FailPayment", payHash, mock.Anything).Return(nil) + // Mock an error to be returned from sending the htlc. payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, - ).Return(nil) + ).Return(permErr) - // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 1) - - // Create the error to be returned. - permErr := htlcswitch.NewForwardingError( - &lnwire.FailIncorrectDetails{}, 1, - ) - - // Mock GetAttemptResult to return a failure. - payer.On("GetAttemptResult", - mock.Anything, mock.Anything, mock.Anything, - ).Run(func(_ mock.Arguments) { - // Send a permanent failure. - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: permErr, - } - }) - - // Return a reason to mock a permanent failure. failureReason := channeldb.FailureReasonPaymentDetails missionControl.On("ReportPaymentFail", mock.Anything, rt, mock.Anything, mock.Anything, ).Return(&failureReason, nil) + // Mock the control tower to return the mocked payment. + payment := &mockMPPayment{} + controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() + + // Mock the payment to return a failrue reason. + payment.On("GetFailureReason").Return(&failureReason) + // Expect a failed send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt) require.Equal(t, permErr, err) @@ -4516,17 +4530,18 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { controlTower.AssertExpectations(t) payer.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestSendToRouteTempFailure validates a temporary failure will cause the // payment to be failed. func TestSendToRouteTempFailure(t *testing.T) { var ( - payHash lntypes.Hash - payAmt = lnwire.MilliSatoshi(10000) - testAttempt = &channeldb.HTLCAttempt{} + payHash lntypes.Hash + payAmt = lnwire.MilliSatoshi(10000) ) + testAttempt := makeFailedAttempt(int(payAmt), 0) node, err := createTestNode() require.NoError(t, err) @@ -4558,6 +4573,11 @@ func TestSendToRouteTempFailure(t *testing.T) { }, }} + // Create the error to be returned. + tempErr := htlcswitch.NewForwardingError( + &lnwire.FailTemporaryChannelFailure{}, 1, + ) + // Register mockers with the expected method calls. controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) @@ -4570,26 +4590,14 @@ func TestSendToRouteTempFailure(t *testing.T) { payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, - ).Return(nil) + ).Return(tempErr) - // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 1) + // Mock the control tower to return the mocked payment. + payment := &mockMPPayment{} + controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() - // Create the error to be returned. - tempErr := htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, - 1, - ) - - // Mock GetAttemptResult to return a failure. - payer.On("GetAttemptResult", - mock.Anything, mock.Anything, mock.Anything, - ).Run(func(_ mock.Arguments) { - // Send an attempt failure. - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: tempErr, - } - }) + // Mock the payment to return nil failrue reason. + payment.On("GetFailureReason").Return(nil) // Return a nil reason to mock a temporary failure. missionControl.On("ReportPaymentFail", @@ -4605,6 +4613,7 @@ func TestSendToRouteTempFailure(t *testing.T) { controlTower.AssertExpectations(t) payer.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestNewRouteRequest tests creation of route requests for blinded and From 703ea0831677135057f3fefb2cc1f4f1941e08be Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 27 Jun 2022 05:02:30 +0800 Subject: [PATCH 07/27] routing: add methods `checkTimeout` and `requestRoute` This commit refactors the `resumePayment` method by adding the methods `checkTimeout` and `requestRoute` so it's easier to understand the flow and reason about the error handling. --- routing/mock_test.go | 6 + routing/payment_lifecycle.go | 161 +++++++++------- routing/payment_lifecycle_test.go | 294 ++++++++++++++++++++++++++++++ 3 files changed, 400 insertions(+), 61 deletions(-) diff --git a/routing/mock_test.go b/routing/mock_test.go index 6ab8f7083..6db22797e 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -673,6 +673,12 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, activeShards, height uint32) (*route.Route, error) { args := m.Called(maxAmt, feeLimit, activeShards, height) + + // Type assertion on nil will fail, so we check and return here. + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*route.Route), args.Error(1) } diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 6539d1a60..6cbe5d69c 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -1,12 +1,13 @@ package routing import ( + "errors" + "fmt" "sync" "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" - "github.com/go-errors/errors" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" @@ -204,70 +205,19 @@ lifecycle: // router is exiting. In either case, we'll stop this payment // attempt short. If a timeout is not applicable, timeoutChan // will be nil. - select { - case <-p.timeoutChan: - log.Warnf("payment attempt not completed before " + - "timeout") - - // By marking the payment failed with the control - // tower, no further shards will be launched and we'll - // return with an error the moment all active shards - // have finished. - saveErr := p.router.cfg.Control.FailPayment( - p.identifier, channeldb.FailureReasonTimeout, - ) - if saveErr != nil { - return exitWithErr(saveErr) - } - - continue lifecycle - - case <-p.router.quit: - return exitWithErr(ErrRouterShuttingDown) - - // Fall through if we haven't hit our time limit. - default: + if err := p.checkTimeout(); err != nil { + return exitWithErr(err) } - // Create a new payment attempt from the given payment session. - rt, err := p.paySession.RequestRoute( - ps.RemainingAmt, remainingFees, - uint32(ps.NumAttemptsInFlight), - uint32(p.currentHeight), - ) + // Now request a route to be used to create our HTLC attempt. + rt, err := p.requestRoute(ps) if err != nil { - log.Warnf("Failed to find route for payment %v: %v", - p.identifier, err) + return exitWithErr(err) + } - routeErr, ok := err.(noRouteError) - if !ok { - return exitWithErr(err) - } - - // There is no route to try, and we have no active - // shards. This means that there is no way for us to - // send the payment, so mark it failed with no route. - if ps.NumAttemptsInFlight == 0 { - failureCode := routeErr.FailureReason() - log.Debugf("Marking payment %v permanently "+ - "failed with no route: %v", - p.identifier, failureCode) - - saveErr := p.router.cfg.Control.FailPayment( - p.identifier, failureCode, - ) - if saveErr != nil { - return exitWithErr(saveErr) - } - - continue lifecycle - } - - // We still have active shards, we'll wait for an - // outcome to be available before retrying. - if err := p.waitForShard(); err != nil { - return exitWithErr(err) - } + // NOTE: might cause an infinite loop, see notes in + // `requestRoute` for details. + if rt == nil { continue lifecycle } @@ -293,6 +243,95 @@ lifecycle: } } +// checkTimeout checks whether the payment has reached its timeout. +func (p *paymentLifecycle) checkTimeout() error { + select { + case <-p.timeoutChan: + log.Warnf("payment attempt not completed before timeout") + + // By marking the payment failed, depending on whether it has + // inflight HTLCs or not, its status will now either be + // `StatusInflight` or `StatusFailed`. In either case, no more + // HTLCs will be attempted. + err := p.router.cfg.Control.FailPayment( + p.identifier, channeldb.FailureReasonTimeout, + ) + if err != nil { + return fmt.Errorf("FailPayment got %w", err) + } + + case <-p.router.quit: + return fmt.Errorf("check payment timeout got: %w", + ErrRouterShuttingDown) + + // Fall through if we haven't hit our time limit. + default: + } + + return nil +} + +// requestRoute is responsible for finding a route to be used to create an HTLC +// attempt. +func (p *paymentLifecycle) requestRoute( + ps *channeldb.MPPaymentState) (*route.Route, error) { + + remainingFees := p.calcFeeBudget(ps.FeesPaid) + + // Query our payment session to construct a route. + rt, err := p.paySession.RequestRoute( + ps.RemainingAmt, remainingFees, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), + ) + + // Exit early if there's no error. + if err == nil { + return rt, nil + } + + // Otherwise we need to handle the error. + log.Warnf("Failed to find route for payment %v: %v", p.identifier, err) + + // If the error belongs to `noRouteError` set, it means a non-critical + // error has happened during path finding and we might be able to find + // another route during next HTLC attempt. Otherwise, we'll return the + // critical error found. + var routeErr noRouteError + if !errors.As(err, &routeErr) { + return nil, fmt.Errorf("requestRoute got: %w", err) + } + + // There is no route to try, and we have no active shards. This means + // that there is no way for us to send the payment, so mark it failed + // with no route. + // + // NOTE: if we have zero `numShardsInFlight`, it means all the HTLC + // attempts have failed. Otherwise, if there are still inflight + // attempts, we might enter an infinite loop in our lifecycle if + // there's still remaining amount since we will keep adding new HTLC + // attempts and they all fail with `noRouteError`. + // + // TODO(yy): further check the error returned here. It's the + // `paymentSession`'s responsibility to find a route for us with best + // effort. When it cannot find a path, we need to treat it as a + // terminal condition and fail the payment no matter it has inflight + // HTLCs or not. + if ps.NumAttemptsInFlight == 0 { + failureCode := routeErr.FailureReason() + log.Debugf("Marking payment %v permanently failed with no "+ + "route: %v", p.identifier, failureCode) + + err := p.router.cfg.Control.FailPayment( + p.identifier, failureCode, + ) + if err != nil { + return nil, fmt.Errorf("FailPayment got: %w", err) + } + } + + return nil, nil +} + // stop signals any active shard goroutine to exit and waits for them to exit. func (p *paymentLifecycle) stop() { close(p.quit) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index d65ba78d5..8e1389687 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -818,3 +819,296 @@ func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo { }, } } + +// TestCheckTimeoutTimedOut checks that when the payment times out, it is +// marked as failed. +func TestCheckTimeoutTimedOut(t *testing.T) { + t.Parallel() + + p := createTestPaymentLifecycle() + + // Mock the control tower's `FailPayment` method. + ct := &mockControlTower{} + ct.On("FailPayment", + p.identifier, channeldb.FailureReasonTimeout).Return(nil) + + // Mount the mocked control tower. + p.router.cfg.Control = ct + + // Make the timeout happens instantly. + p.timeoutChan = time.After(1 * time.Nanosecond) + + // Sleep one millisecond to make sure it timed out. + time.Sleep(1 * time.Millisecond) + + // Call the function and expect no error. + err := p.checkTimeout() + require.NoError(t, err) + + // Assert that `FailPayment` is called as expected. + ct.AssertExpectations(t) + + // We now test that when `FailPayment` returns an error, it's returned + // by the function too. + // + // Mock `FailPayment` to return a dummy error. + dummyErr := errors.New("dummy") + ct = &mockControlTower{} + ct.On("FailPayment", + p.identifier, channeldb.FailureReasonTimeout).Return(dummyErr) + + // Mount the mocked control tower. + p.router.cfg.Control = ct + + // Make the timeout happens instantly. + p.timeoutChan = time.After(1 * time.Nanosecond) + + // Sleep one millisecond to make sure it timed out. + time.Sleep(1 * time.Millisecond) + + // Call the function and expect an error. + err = p.checkTimeout() + require.ErrorIs(t, err, dummyErr) + + // Assert that `FailPayment` is called as expected. + ct.AssertExpectations(t) +} + +// TestCheckTimeoutOnRouterQuit checks that when the router has quit, an error +// is returned from checkTimeout. +func TestCheckTimeoutOnRouterQuit(t *testing.T) { + t.Parallel() + + p := createTestPaymentLifecycle() + + close(p.router.quit) + err := p.checkTimeout() + require.ErrorIs(t, err, ErrRouterShuttingDown) +} + +// createTestPaymentLifecycle creates a `paymentLifecycle` using the mocks. +func createTestPaymentLifecycle() *paymentLifecycle { + paymentHash := lntypes.Hash{1, 2, 3} + quitChan := make(chan struct{}) + rt := &ChannelRouter{ + cfg: &Config{}, + quit: quitChan, + } + + return &paymentLifecycle{ + router: rt, + identifier: paymentHash, + } +} + +// TestRequestRouteSucceed checks that `requestRoute` can successfully request +// a route. +func TestRequestRouteSucceed(t *testing.T) { + t.Parallel() + + p := createTestPaymentLifecycle() + + // Create a mock payment session and a dummy route. + paySession := &mockPaymentSession{} + dummyRoute := &route.Route{} + + // Mount the mocked payment session. + p.paySession = paySession + + // Create a dummy payment state. + ps := &channeldb.MPPaymentState{ + NumAttemptsInFlight: 1, + RemainingAmt: 1, + FeesPaid: 100, + } + + // Mock remainingFees to be 1. + p.feeLimit = ps.FeesPaid + 1 + + // Mock the paySession's `RequestRoute` method to return no error. + paySession.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(dummyRoute, nil) + + result, err := p.requestRoute(ps) + require.NoError(t, err, "expect no error") + require.Equal(t, dummyRoute, result, "returned route not matched") + + // Assert that `RequestRoute` is called as expected. + paySession.AssertExpectations(t) +} + +// TestRequestRouteHandleCriticalErr checks that `requestRoute` can +// successfully handle a critical error returned from payment session. +func TestRequestRouteHandleCriticalErr(t *testing.T) { + t.Parallel() + + p := createTestPaymentLifecycle() + + // Create a mock payment session. + paySession := &mockPaymentSession{} + + // Mount the mocked payment session. + p.paySession = paySession + + // Create a dummy payment state. + ps := &channeldb.MPPaymentState{ + NumAttemptsInFlight: 1, + RemainingAmt: 1, + FeesPaid: 100, + } + + // Mock remainingFees to be 1. + p.feeLimit = ps.FeesPaid + 1 + + // Mock the paySession's `RequestRoute` method to return an error. + dummyErr := errors.New("dummy") + paySession.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(nil, dummyErr) + + result, err := p.requestRoute(ps) + + // Expect an error is returned since it's critical. + require.ErrorIs(t, err, dummyErr, "error not matched") + require.Nil(t, result, "expected no route returned") + + // Assert that `RequestRoute` is called as expected. + paySession.AssertExpectations(t) +} + +// TestRequestRouteHandleNoRouteErr checks that `requestRoute` can successfully +// handle the `noRouteError` returned from payment session. +func TestRequestRouteHandleNoRouteErr(t *testing.T) { + t.Parallel() + + p := createTestPaymentLifecycle() + + // Create a mock payment session. + paySession := &mockPaymentSession{} + + // Mount the mocked payment session. + p.paySession = paySession + + // Create a dummy payment state. + ps := &channeldb.MPPaymentState{ + NumAttemptsInFlight: 1, + RemainingAmt: 1, + FeesPaid: 100, + } + + // Mock remainingFees to be 1. + p.feeLimit = ps.FeesPaid + 1 + + // Mock the paySession's `RequestRoute` method to return an error. + paySession.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(nil, errNoTlvPayload) + + result, err := p.requestRoute(ps) + + // Expect no error is returned since it's not critical. + require.NoError(t, err, "expected no error") + require.Nil(t, result, "expected no route returned") + + // Assert that `RequestRoute` is called as expected. + paySession.AssertExpectations(t) +} + +// TestRequestRouteFailPaymentSucceed checks that `requestRoute` fails the +// payment when received an `noRouteError` returned from payment session while +// it has no inflight attempts. +func TestRequestRouteFailPaymentSucceed(t *testing.T) { + t.Parallel() + + p := createTestPaymentLifecycle() + + // Create a mock payment session. + paySession := &mockPaymentSession{} + + // Mock the control tower's `FailPayment` method. + ct := &mockControlTower{} + ct.On("FailPayment", + p.identifier, errNoTlvPayload.FailureReason(), + ).Return(nil) + + // Mount the mocked control tower and payment session. + p.router.cfg.Control = ct + p.paySession = paySession + + // Create a dummy payment state with zero inflight attempts. + ps := &channeldb.MPPaymentState{ + NumAttemptsInFlight: 0, + RemainingAmt: 1, + FeesPaid: 100, + } + + // Mock remainingFees to be 1. + p.feeLimit = ps.FeesPaid + 1 + + // Mock the paySession's `RequestRoute` method to return an error. + paySession.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(nil, errNoTlvPayload) + + result, err := p.requestRoute(ps) + + // Expect no error is returned since it's not critical. + require.NoError(t, err, "expected no error") + require.Nil(t, result, "expected no route returned") + + // Assert that `RequestRoute` is called as expected. + paySession.AssertExpectations(t) + + // Assert that `FailPayment` is called as expected. + ct.AssertExpectations(t) +} + +// TestRequestRouteFailPaymentError checks that `requestRoute` returns the +// error from calling `FailPayment`. +func TestRequestRouteFailPaymentError(t *testing.T) { + t.Parallel() + + p := createTestPaymentLifecycle() + + // Create a mock payment session. + paySession := &mockPaymentSession{} + + // Mock the control tower's `FailPayment` method. + ct := &mockControlTower{} + dummyErr := errors.New("dummy") + ct.On("FailPayment", + p.identifier, errNoTlvPayload.FailureReason(), + ).Return(dummyErr) + + // Mount the mocked control tower and payment session. + p.router.cfg.Control = ct + p.paySession = paySession + + // Create a dummy payment state with zero inflight attempts. + ps := &channeldb.MPPaymentState{ + NumAttemptsInFlight: 0, + RemainingAmt: 1, + FeesPaid: 100, + } + + // Mock remainingFees to be 1. + p.feeLimit = ps.FeesPaid + 1 + + // Mock the paySession's `RequestRoute` method to return an error. + paySession.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(nil, errNoTlvPayload) + + result, err := p.requestRoute(ps) + + // Expect an error is returned. + require.ErrorIs(t, err, dummyErr, "error not matched") + require.Nil(t, result, "expected no route returned") + + // Assert that `RequestRoute` is called as expected. + paySession.AssertExpectations(t) + + // Assert that `FailPayment` is called as expected. + ct.AssertExpectations(t) +} From 173900c8dcb7829348cdbb7f638ff6ba7d37c20b Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 13 Feb 2023 18:32:02 +0800 Subject: [PATCH 08/27] routing: only fail attempt inside `handleSwitchErr` This commit makes sure we only fail attempt inside `handleSwitchErr` to ensure the orders in failing payment and attempts. It refactors `collectResult` to return `attemptResult`, and expands `handleSwitchErr` to also handle the case where the attemptID is not found. --- routing/payment_lifecycle.go | 56 ++++++++++++------------------------ 1 file changed, 19 insertions(+), 37 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 6cbe5d69c..20497c62a 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -413,32 +413,15 @@ func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { defer handleResultErr() // Block until the result is available. - result, err := p.collectResult(attempt) + _, err := p.collectResult(attempt) if err != nil { - if err != ErrRouterShuttingDown && - err != htlcswitch.ErrSwitchExiting && - err != errShardHandlerExiting { + log.Errorf("Error collecting result for attempt %v "+ + "in payment %v: %v", attempt.AttemptID, + p.identifier, err) - log.Errorf("Error collecting result for "+ - "shard %v for payment %v: %v", - attempt.AttemptID, p.identifier, err) - } - - // Overwrite the param errToSend and return so that the - // defer function will use the param to proceed. errToSend = err return } - - // If a non-critical error was encountered handle it and mark - // the payment failed if the failure was terminal. - if result.err != nil { - // Overwrite the param errToSend and return so that the - // defer function will use the param to proceed. Notice - // that the errToSend could be nil here. - _, errToSend = p.handleSwitchErr(attempt, result.err) - return - } }() } @@ -477,24 +460,12 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( resultChan, err := p.router.cfg.Payer.GetAttemptResult( attempt.AttemptID, p.identifier, errorDecryptor, ) - switch { - // If this attempt ID is unknown to the Switch, it means it was never - // checkpointed and forwarded by the switch before a restart. In this - // case we can safely send a new payment attempt, and wait for its - // result to be available. - case err == htlcswitch.ErrPaymentIDNotFound: - log.Debugf("Attempt ID %v for payment %v not found in "+ - "the Switch, retrying.", attempt.AttemptID, - p.identifier) - - return p.failAttempt(attempt.AttemptID, err) - - // A critical, unexpected error was encountered. - case err != nil: + // Handle the switch error. + if err != nil { log.Errorf("Failed getting result for attemptID %d "+ "from switch: %v", attempt.AttemptID, err) - return nil, err + return p.handleSwitchErr(attempt, err) } // The switch knows about this payment, we'll wait for a result to be @@ -517,7 +488,7 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( // In case of a payment failure, fail the attempt with the control // tower and return. if result.Error != nil { - return p.failAttempt(attempt.AttemptID, result.Error) + return p.handleSwitchErr(attempt, result.Error) } // We successfully got a payment result back from the switch. @@ -754,6 +725,17 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, return p.failPaymentAndAttempt(attemptID, reason, sendErr) } + // If this attempt ID is unknown to the Switch, it means it was never + // checkpointed and forwarded by the switch before a restart. In this + // case we can safely send a new payment attempt, and wait for its + // result to be available. + if errors.Is(sendErr, htlcswitch.ErrPaymentIDNotFound) { + log.Debugf("Attempt ID %v for payment %v not found in the "+ + "Switch, retrying.", attempt.AttemptID, p.identifier) + + return p.failAttempt(attemptID, sendErr) + } + if sendErr == htlcswitch.ErrUnreadableFailureMessage { log.Warn("Unreadable failure when sending htlc: id=%v, hash=%v", attempt.AttemptID, attempt.Hash) From 9a0db291b5ba147b6eb4a86836377660ec343085 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Wed, 27 Sep 2023 09:36:55 -0400 Subject: [PATCH 09/27] routing: fix tests after main refactor Delete TestSendMPPaymentFailedWithShardsInFlight as it seems to be the same test as TestSendMPPaymentFailed. --- routing/payment_lifecycle_test.go | 2 +- routing/router_test.go | 264 +++++------------------------- 2 files changed, 46 insertions(+), 220 deletions(-) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 8e1389687..eca18305f 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -732,7 +732,7 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, select { case err := <-paymentResult: - require.Equal(t, test.paymentErr, err) + require.ErrorIs(t, err, test.paymentErr) case <-time.After(stepTimeout): fatal("got no payment result") diff --git a/routing/router_test.go b/routing/router_test.go index f18e80024..6adefe12a 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3874,18 +3874,43 @@ func TestSendMPPaymentFailed(t *testing.T) { controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) // Mock the InFlightHTLCs. - var htlcs []channeldb.HTLCAttempt + var ( + htlcs []channeldb.HTLCAttempt + numAttempts atomic.Uint32 + failAttemptCount atomic.Uint32 + failed atomic.Bool + numParts = uint32(4) + ) // Make a mock MPPayment. payment := &mockMPPayment{} - payment.On("InFlightHTLCs").Return(htlcs). - On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). - On("GetStatus").Return(channeldb.StatusInFlight). - On("Terminated").Return(false). - On("NeedWaitAttempts").Return(false, nil) + payment.On("InFlightHTLCs").Return(htlcs).Once() + payment.On("GetStatus").Return(channeldb.StatusInFlight).Once() + payment.On("GetState").Return(&channeldb.MPPaymentState{}) + controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() - // Mock FetchPayment to return the payment. - controlTower.On("FetchPayment", identifier).Return(payment, nil) + // Mock the sequential FetchPayment to return the payment. + controlTower.On("FetchPayment", identifier).Return(payment, nil).Run( + func(_ mock.Arguments) { + // We want to at least send out all parts in order to + // wait for them later. + if numAttempts.Load() < numParts { + payment.On("Terminated").Return(false).Times(2). + On("NeedWaitAttempts").Return(false, nil).Once() + return + } + + // Wait if the payment wasn't failed yet. + if !failed.Load() { + payment.On("Terminated").Return(false).Times(2). + On("NeedWaitAttempts").Return(true, nil).Once() + + return + } + + payment.On("Terminated").Return(true). + On("GetHTLCs").Return(htlcs).Once() + }) // Create a route that can send 1/4 of the total amount. This value // will be returned by calling RequestRoute. @@ -3899,7 +3924,9 @@ func TestSendMPPaymentFailed(t *testing.T) { // HTLCs when calling RegisterAttempt. controlTower.On("RegisterAttempt", identifier, mock.Anything, - ).Return(nil) + ).Return(nil).Run(func(args mock.Arguments) { + numAttempts.Add(1) + }) // Create a buffered chan and it will be returned by GetAttemptResult. payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) @@ -3907,18 +3934,17 @@ func TestSendMPPaymentFailed(t *testing.T) { // We use the failAttemptCount to track how many attempts we want to // fail. Each time the following mock method is called, the count gets // updated. - failAttemptCount := 0 payer.On("GetAttemptResult", mock.Anything, identifier, mock.Anything, - ).Run(func(args mock.Arguments) { + ).Run(func(_ mock.Arguments) { // Before the mock method is returned, we send the result to // the read-only chan. // Update the counter. - failAttemptCount++ + failAttemptCount.Add(1) // We fail the first attempt with terminal error. - if failAttemptCount == 1 { + if failAttemptCount.Load() == 1 { payer.resultChan <- &htlcswitch.PaymentResult{ Error: htlcswitch.NewForwardingError( &lnwire.FailIncorrectDetails{}, @@ -3953,12 +3979,13 @@ func TestSendMPPaymentFailed(t *testing.T) { // Simple mocking the rest. controlTower.On("FailPayment", identifier, failureReason, - ).Return(nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mock the payment's - // Terminated() to be True. - payment.On("Terminated").Return(true) + ).Return(nil).Run(func(_ mock.Arguments) { + failed.Store(true) }) + // Mock the payment to return the failure reason. + payment.On("GetFailureReason").Return(&failureReason) + payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, ).Return(nil) @@ -3983,208 +4010,7 @@ func TestSendMPPaymentFailed(t *testing.T) { // methods are called as expected. require.Error(t, err, "expected send payment error") require.EqualValues(t, [32]byte{}, p, "preimage not match") - - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - sessionSource.AssertExpectations(t) - session.AssertExpectations(t) - missionControl.AssertExpectations(t) - payment.AssertExpectations(t) -} - -// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in -// terminal state, even if we have shards in flight, we still fail the payment -// and exit. This test mainly focuses on testing the logic of the method -// resumePayment is implemented as expected. -func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { - const startingBlockHeight = 101 - - // Create mockers to initialize the router. - controlTower := &mockControlTower{} - sessionSource := &mockPaymentSessionSource{} - missionControl := &mockMissionControl{} - payer := &mockPaymentAttemptDispatcher{} - chain := newMockChain(startingBlockHeight) - chainView := newMockChainView(chain) - testGraph := createDummyTestGraph(t) - - // Define the behavior of the mockers to the point where we can - // successfully start the router. - controlTower.On("FetchInFlightPayments").Return( - []*channeldb.MPPayment{}, nil, - ) - payer.On("CleanStore", mock.Anything).Return(nil) - - // Create and start the router. - router, err := New(Config{ - Control: controlTower, - SessionSource: sessionSource, - MissionControl: missionControl, - Payer: payer, - - // TODO(yy): create new mocks for the chain and chainview. - Chain: chain, - ChainView: chainView, - - // TODO(yy): mock the graph once it's changed into interface. - Graph: testGraph.graph, - - Clock: clock.NewTestClock(time.Unix(1, 0)), - GraphPruneInterval: time.Hour * 2, - NextPaymentID: func() (uint64, error) { - next := atomic.AddUint64(&uniquePaymentID, 1) - return next, nil - }, - - IsAlias: func(scid lnwire.ShortChannelID) bool { - return false - }, - }) - require.NoError(t, err, "failed to create router") - - // Make sure the router can start and stop without error. - require.NoError(t, router.Start(), "router failed to start") - t.Cleanup(func() { - require.NoError(t, router.Stop(), "router failed to stop") - }) - - // Once the router is started, check that the mocked methods are called - // as expected. - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - - // Mock the methods to the point where we are inside the function - // resumePayment. - paymentAmt := lnwire.MilliSatoshi(10000) - req := createDummyLightningPayment( - t, testGraph.aliasMap["c"], paymentAmt, - ) - identifier := lntypes.Hash(req.Identifier()) - session := &mockPaymentSession{} - sessionSource.On("NewPaymentSession", req).Return(session, nil) - controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - - // Mock the InFlightHTLCs. - var htlcs []channeldb.HTLCAttempt - - // Make a mock MPPayment. - payment := &mockMPPayment{} - payment.On("InFlightHTLCs").Return(htlcs). - On("GetStatus").Return(channeldb.StatusInFlight). - On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). - On("Terminated").Return(false). - On("NeedWaitAttempts").Return(false, nil) - - // Mock FetchPayment to return the payment. - controlTower.On("FetchPayment", identifier).Return(payment, nil) - - // Create a route that can send 1/4 of the total amount. This value - // will be returned by calling RequestRoute. - shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) - require.NoError(t, err, "failed to create route") - session.On("RequestRoute", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(shard, nil) - - // Make a new htlc attempt with zero fee and append it to the payment's - // HTLCs when calling RegisterAttempt. - controlTower.On("RegisterAttempt", - identifier, mock.Anything, - ).Return(nil) - - // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) - - // We use the getPaymentResultCnt to track how many times we called - // GetAttemptResult. As shard launch is sequential, and we fail the - // first shard that calls GetAttemptResult, we may end up with different - // counts since the lifecycle itself is asynchronous. To avoid flakes - // due to this undeterminsitic behavior, we'll compare the final - // getPaymentResultCnt with other counters to create a final test - // expectation. - getPaymentResultCnt := 0 - payer.On("GetAttemptResult", - mock.Anything, identifier, mock.Anything, - ).Run(func(args mock.Arguments) { - // Before the mock method is returned, we send the result to - // the read-only chan. - - // Update the counter. - getPaymentResultCnt++ - - // We fail the first attempt with terminal error. - if getPaymentResultCnt == 1 { - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: htlcswitch.NewForwardingError( - &lnwire.FailIncorrectDetails{}, - 1, - ), - } - return - } - - // For the rest of the attempts we'll simulate that a network - // result update_fail_htlc has been received. This way the - // payment will fail cleanly. - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, - 1, - ), - } - }) - - // Mock the FailAttempt method to fail (at least once). - var failedAttempt channeldb.HTLCAttempt - controlTower.On("FailAttempt", - identifier, mock.Anything, mock.Anything, - ).Return(&failedAttempt, nil) - - // Setup ReportPaymentFail to return nil reason and error so the - // payment won't fail. - failureReason := channeldb.FailureReasonPaymentDetails - missionControl.On("ReportPaymentFail", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(&failureReason, nil) - - // Simple mocking the rest. - cntFail := 0 - controlTower.On("FailPayment", - identifier, failureReason, - ).Return(nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mock the payment's - // Terminated() to be True. - payment.On("Terminated").Return(true) - }) - - payer.On("SendHTLC", - mock.Anything, mock.Anything, mock.Anything, - ).Return(nil).Run(func(args mock.Arguments) { - cntFail++ - }) - - // Call the actual method SendPayment on router. This is place inside a - // goroutine so we can set a timeout for the whole test, in case - // anything goes wrong and the test never finishes. - done := make(chan struct{}) - var p lntypes.Hash - go func() { - p, _, err = router.SendPayment(req) - close(done) - }() - - select { - case <-done: - case <-time.After(testTimeout): - t.Fatalf("SendPayment didn't exit") - } - - // Finally, validate the returned values and check that the mock - // methods are called as expected. - require.Error(t, err, "expected send payment error") - require.EqualValues(t, [32]byte{}, p, "preimage not match") - require.GreaterOrEqual(t, getPaymentResultCnt, 1) - require.Equal(t, getPaymentResultCnt, cntFail) + require.GreaterOrEqual(t, failAttemptCount.Load(), uint32(1)) controlTower.AssertExpectations(t) payer.AssertExpectations(t) From e8c0226e1cb2437befb551e1b80d271f8760486f Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 8 Mar 2023 00:48:22 +0800 Subject: [PATCH 10/27] routing: add `AllowMoreAttempts` to decide whether more attempts are allowed --- channeldb/mp_payment.go | 45 +++++++++ channeldb/mp_payment_test.go | 177 +++++++++++++++++++++++++++++++++++ routing/control_tower.go | 5 + routing/mock_test.go | 5 + 4 files changed, 232 insertions(+) diff --git a/channeldb/mp_payment.go b/channeldb/mp_payment.go index 30c595049..a4b323349 100644 --- a/channeldb/mp_payment.go +++ b/channeldb/mp_payment.go @@ -469,6 +469,51 @@ func (m *MPPayment) GetFailureReason() *FailureReason { return m.FailureReason } +// AllowMoreAttempts is used to decide whether we can safely attempt more HTLCs +// for a given payment state. Return an error if the payment is in an +// unexpected state. +func (m *MPPayment) AllowMoreAttempts() (bool, error) { + // Now check whether the remainingAmt is zero or not. If we don't have + // any remainingAmt, no more HTLCs should be made. + if m.State.RemainingAmt == 0 { + // If the payment is newly created, yet we don't have any + // remainingAmt, return an error. + if m.Status == StatusInitiated { + return false, fmt.Errorf("%w: initiated payment has "+ + "zero remainingAmt", ErrPaymentInternal) + } + + // Otherwise, exit early since all other statuses with zero + // remainingAmt indicate no more HTLCs can be made. + return false, nil + } + + // Otherwise, the remaining amount is not zero, we now decide whether + // to make more attempts based on the payment's current status. + // + // If at least one of the payment's attempts is settled, yet we haven't + // sent all the amount, it indicates something is wrong with the peer + // as the preimage is received. In this case, return an error state. + if m.Status == StatusSucceeded { + return false, fmt.Errorf("%w: payment already succeeded but "+ + "still have remaining amount %v", ErrPaymentInternal, + m.State.RemainingAmt) + } + + // Now check if we can register a new HTLC. + err := m.Registrable() + if err != nil { + log.Warnf("Payment(%v): cannot register HTLC attempt: %v, "+ + "current status: %s", m.Info.PaymentIdentifier, + err, m.Status) + + return false, nil + } + + // Now we know we can register new HTLCs. + return true, nil +} + // serializeHTLCSettleInfo serializes the details of a settled htlc. func serializeHTLCSettleInfo(w io.Writer, s *HTLCSettleInfo) error { if _, err := w.Write(s.Preimage[:]); err != nil { diff --git a/channeldb/mp_payment_test.go b/channeldb/mp_payment_test.go index ef00d90cc..51eda72bb 100644 --- a/channeldb/mp_payment_test.go +++ b/channeldb/mp_payment_test.go @@ -368,6 +368,183 @@ func TestNeedWaitAttempts(t *testing.T) { } } +// TestAllowMoreAttempts checks whether more attempts can be created against +// ALL possible payment statuses. +func TestAllowMoreAttempts(t *testing.T) { + t.Parallel() + + testCases := []struct { + status PaymentStatus + remainingAmt lnwire.MilliSatoshi + hasSettledHTLC bool + paymentFailed bool + allowMore bool + expectedErr error + }{ + { + // A newly created payment with zero remainingAmt + // indicates an error. + status: StatusInitiated, + remainingAmt: 0, + allowMore: false, + expectedErr: ErrPaymentInternal, + }, + { + // With zero remainingAmt we don't allow more HTLC + // attempts. + status: StatusInFlight, + remainingAmt: 0, + allowMore: false, + expectedErr: nil, + }, + { + // With zero remainingAmt we don't allow more HTLC + // attempts. + status: StatusSucceeded, + remainingAmt: 0, + allowMore: false, + expectedErr: nil, + }, + { + // With zero remainingAmt we don't allow more HTLC + // attempts. + status: StatusFailed, + remainingAmt: 0, + allowMore: false, + expectedErr: nil, + }, + { + // With zero remainingAmt and settled HTLCs we don't + // allow more HTLC attempts. + status: StatusInFlight, + remainingAmt: 0, + hasSettledHTLC: true, + allowMore: false, + expectedErr: nil, + }, + { + // With zero remainingAmt and failed payment we don't + // allow more HTLC attempts. + status: StatusInFlight, + remainingAmt: 0, + paymentFailed: true, + allowMore: false, + expectedErr: nil, + }, + { + // With zero remainingAmt and both settled HTLCs and + // failed payment, we don't allow more HTLC attempts. + status: StatusInFlight, + remainingAmt: 0, + hasSettledHTLC: true, + paymentFailed: true, + allowMore: false, + expectedErr: nil, + }, + { + // A newly created payment can have more attempts. + status: StatusInitiated, + remainingAmt: 1000, + allowMore: true, + expectedErr: nil, + }, + { + // With HTLCs inflight we can have more attempts when + // the remainingAmt is not zero and we have neither + // failed payment or settled HTLCs. + status: StatusInFlight, + remainingAmt: 1000, + allowMore: true, + expectedErr: nil, + }, + { + // With HTLCs inflight we cannot have more attempts + // though the remainingAmt is not zero but we have + // settled HTLCs. + status: StatusInFlight, + remainingAmt: 1000, + hasSettledHTLC: true, + allowMore: false, + expectedErr: nil, + }, + { + // With HTLCs inflight we cannot have more attempts + // though the remainingAmt is not zero but we have + // failed payment. + status: StatusInFlight, + remainingAmt: 1000, + paymentFailed: true, + allowMore: false, + expectedErr: nil, + }, + { + // With HTLCs inflight we cannot have more attempts + // though the remainingAmt is not zero but we have + // settled HTLCs and failed payment. + status: StatusInFlight, + remainingAmt: 1000, + hasSettledHTLC: true, + paymentFailed: true, + allowMore: false, + expectedErr: nil, + }, + { + // With the payment settled, but the remainingAmt is + // not zero, we have an error state. + status: StatusSucceeded, + remainingAmt: 1000, + hasSettledHTLC: true, + allowMore: false, + expectedErr: ErrPaymentInternal, + }, + { + // With the payment failed with no inflight HTLCs, we + // don't allow more attempts to be made. + status: StatusFailed, + remainingAmt: 1000, + paymentFailed: true, + allowMore: false, + expectedErr: nil, + }, + { + // With the payment in an unknown state, we don't allow + // more attempts to be made. + status: 0, + remainingAmt: 1000, + allowMore: false, + expectedErr: nil, + }, + } + + for i, tc := range testCases { + tc := tc + + p := &MPPayment{ + Info: &PaymentCreationInfo{ + PaymentIdentifier: [32]byte{1, 2, 3}, + }, + Status: tc.status, + State: &MPPaymentState{ + RemainingAmt: tc.remainingAmt, + HasSettledHTLC: tc.hasSettledHTLC, + PaymentFailed: tc.paymentFailed, + }, + } + + name := fmt.Sprintf("test_%d|status=%s|remainingAmt=%v", i, + tc.status, tc.remainingAmt) + + t.Run(name, func(t *testing.T) { + t.Parallel() + + result, err := p.AllowMoreAttempts() + require.ErrorIs(t, err, tc.expectedErr) + require.Equalf(t, tc.allowMore, result, "status=%v, "+ + "remainingAmt=%v", tc.status, tc.remainingAmt) + }) + } +} + func makeActiveAttempt(total, fee int) HTLCAttempt { return HTLCAttempt{ HTLCAttemptInfo: makeAttemptInfo(total, total-fee), diff --git a/routing/control_tower.go b/routing/control_tower.go index b23b7df5c..80ffbe5d9 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -33,6 +33,11 @@ type dbMPPayment interface { // GetFailureReason returns the reason the payment failed. GetFailureReason() *channeldb.FailureReason + + // AllowMoreAttempts is used to decide whether we can safely attempt + // more HTLCs for a given payment state. Return an error if the payment + // is in an unexpected state. + AllowMoreAttempts() (bool, error) } // ControlTower tracks all outgoing payments made, whose primary purpose is to diff --git a/routing/mock_test.go b/routing/mock_test.go index 6db22797e..5a2ee4206 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -839,6 +839,11 @@ func (m *mockMPPayment) GetFailureReason() *channeldb.FailureReason { return reason.(*channeldb.FailureReason) } +func (m *mockMPPayment) AllowMoreAttempts() (bool, error) { + args := m.Called() + return args.Bool(0), args.Error(1) +} + type mockLink struct { htlcswitch.ChannelLink bandwidth lnwire.MilliSatoshi From 3c5c37b6937321f9b7218b6611e9dc5c329a201d Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 8 Mar 2023 00:58:14 +0800 Subject: [PATCH 11/27] routing: introduce `stateStep` to manage payment lifecycle This commit adds a new struct, `stateStep`, to decide the workflow inside `resumePayment`. It also refactors `collectResultAsync` introducing a new channel `resultCollected`. This channel is used to signal the payment lifecycle that an HTLC attempt result is ready to be processed. --- routing/payment_lifecycle.go | 293 ++++++++++++++++-------------- routing/payment_lifecycle_test.go | 120 ++++++++++++ routing/router_test.go | 107 +++++++---- 3 files changed, 343 insertions(+), 177 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 20497c62a..d3e4289e6 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -3,7 +3,6 @@ package routing import ( "errors" "fmt" - "sync" "time" "github.com/btcsuite/btcd/btcec/v2" @@ -18,9 +17,6 @@ import ( "github.com/lightningnetwork/lnd/routing/shards" ) -// errShardHandlerExiting is returned from the shardHandler when it exits. -var errShardHandlerExiting = errors.New("shard handler exiting") - // paymentLifecycle holds all information about the current state of a payment // needed to resume if from any point. type paymentLifecycle struct { @@ -32,18 +28,15 @@ type paymentLifecycle struct { timeoutChan <-chan time.Time currentHeight int32 - // shardErrors is a channel where errors collected by calling - // collectResultAsync will be delivered. These results are meant to be - // inspected by calling waitForShard or checkShards, and the channel - // doesn't need to be initiated if the caller is using the sync - // collectResult directly. - // TODO(yy): delete. - shardErrors chan error - // quit is closed to signal the sub goroutines of the payment lifecycle // to stop. quit chan struct{} - wg sync.WaitGroup + + // resultCollected is used to signal that the result of an attempt has + // been collected. A nil error means the attempt is either successful + // or failed with temporary error. Otherwise, we should exit the + // lifecycle loop as a terminal error has occurred. + resultCollected chan error } // newPaymentLifecycle initiates a new payment lifecycle and returns it. @@ -53,14 +46,14 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, currentHeight int32) *paymentLifecycle { p := &paymentLifecycle{ - router: r, - feeLimit: feeLimit, - identifier: identifier, - paySession: paySession, - shardTracker: shardTracker, - currentHeight: currentHeight, - shardErrors: make(chan error), - quit: make(chan struct{}), + router: r, + feeLimit: feeLimit, + identifier: identifier, + paySession: paySession, + shardTracker: shardTracker, + currentHeight: currentHeight, + quit: make(chan struct{}), + resultCollected: make(chan error, 1), } // If a timeout is specified, create a timeout channel. If no timeout is @@ -92,6 +85,74 @@ func (p *paymentLifecycle) calcFeeBudget( return budget } +// stateStep defines an action to be taken in our payment lifecycle. We either +// quit, continue, or exit the lifecycle, see details below. +type stateStep uint8 + +const ( + // stepSkip is used when we need to skip the current lifecycle and jump + // to the next one. + stepSkip stateStep = iota + + // stepProceed is used when we can proceed the current lifecycle. + stepProceed + + // stepExit is used when we need to quit the current lifecycle. + stepExit +) + +// decideNextStep is used to determine the next step in the payment lifecycle. +func (p *paymentLifecycle) decideNextStep( + payment dbMPPayment) (stateStep, error) { + + // Check whether we could make new HTLC attempts. + allow, err := payment.AllowMoreAttempts() + if err != nil { + return stepExit, err + } + + if !allow { + // Check whether we need to wait for results. + wait, err := payment.NeedWaitAttempts() + if err != nil { + return stepExit, err + } + + // If we are not allowed to make new HTLC attempts and there's + // no need to wait, the lifecycle is done and we can exit. + if !wait { + return stepExit, nil + } + + log.Tracef("Waiting for attempt results for payment %v", + p.identifier) + + // Otherwise we wait for one HTLC attempt then continue + // the lifecycle. + // + // NOTE: we don't check `p.quit` since `decideNextStep` is + // running in the same goroutine as `resumePayment`. + select { + case err := <-p.resultCollected: + // If an error is returned, exit with it. + if err != nil { + return stepExit, err + } + + log.Tracef("Received attempt result for payment %v", + p.identifier) + + case <-p.router.quit: + return stepExit, ErrRouterShuttingDown + } + + return stepSkip, nil + } + + // Otherwise we need to make more attempts. + return stepProceed, nil +} + // resumePayment resumes the paymentLifecycle from the current state. func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { // When the payment lifecycle loop exits, we make sure to signal any @@ -127,20 +188,12 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { // critical error during path finding. lifecycle: for { - // Start by quickly checking if there are any outcomes already - // available to handle before we reevaluate our state. - if err := p.checkShards(); err != nil { - return exitWithErr(err) - } - // We update the payment state on every iteration. Since the // payment state is affected by multiple goroutines (ie, // collectResultAsync), it is NOT guaranteed that we always // have the latest state here. This is fine as long as the // state is consistent as a whole. - - // Fetch the latest payment from db. - payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + payment, err = p.router.cfg.Control.FetchPayment(p.identifier) if err != nil { return exitWithErr(err) } @@ -153,53 +206,14 @@ lifecycle: p.identifier, payment.Terminated(), ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees) - // TODO(yy): sanity check all the states to make sure - // everything is expected. - // We have a terminal condition and no active shards, we are - // ready to exit. - if payment.Terminated() { - // Find the first successful shard and return - // the preimage and route. - for _, a := range payment.GetHTLCs() { - if a.Settle == nil { - continue - } - - err := p.router.cfg.Control.DeleteFailedAttempts( - p.identifier, - ) - if err != nil { - log.Errorf("Error deleting failed "+ - "payment attempts for "+ - "payment %v: %v", p.identifier, - err) - } - - return a.Settle.Preimage, &a.Route, nil - } - - // Payment failed. - return exitWithErr(*payment.GetFailureReason()) - } - - // If we either reached a terminal error condition (but had - // active shards still) or there is no remaining value to send, - // we'll wait for a shard outcome. - wait, err := payment.NeedWaitAttempts() - if err != nil { - return exitWithErr(err) - } - - if wait { - // We still have outstanding shards, so wait for a new - // outcome to be available before re-evaluating our - // state. - if err := p.waitForShard(); err != nil { - return exitWithErr(err) - } - continue lifecycle - } - + // We now proceed our lifecycle with the following tasks in + // order, + // 1. check timeout. + // 2. request route. + // 3. create HTLC attempt. + // 4. send HTLC attempt. + // 5. collect HTLC attempt result. + // // Before we attempt any new shard, we'll check to see if // either we've gone past the payment attempt timeout, or the // router is exiting. In either case, we'll stop this payment @@ -209,6 +223,30 @@ lifecycle: return exitWithErr(err) } + // Now decide the next step of the current lifecycle. + step, err := p.decideNextStep(payment) + if err != nil { + return exitWithErr(err) + } + + switch step { + // Exit the for loop and return below. + case stepExit: + break lifecycle + + // Continue the for loop and skip the rest. + case stepSkip: + continue lifecycle + + // Continue the for loop and proceed the rest. + case stepProceed: + + // Unknown step received, exit with an error. + default: + err = fmt.Errorf("unknown step: %v", step) + return exitWithErr(err) + } + // Now request a route to be used to create our HTLC attempt. rt, err := p.requestRoute(ps) if err != nil { @@ -241,6 +279,27 @@ lifecycle: p.collectResultAsync(attempt) } } + + // Once we are out the lifecycle loop, it means we've reached a + // terminal condition. We either return the settled preimage or the + // payment's failure reason. + // + // Optionally delete the failed attempts from the database. + err = p.router.cfg.Control.DeleteFailedAttempts(p.identifier) + if err != nil { + log.Errorf("Error deleting failed htlc attempts for payment "+ + "%v: %v", p.identifier, err) + } + + // Find the first successful shard and return the preimage and route. + for _, a := range payment.GetHTLCs() { + if a.Settle != nil { + return a.Settle.Preimage, &a.Route, nil + } + } + + // Otherwise return the payment failure reason. + return [32]byte{}, nil, *payment.GetFailureReason() } // checkTimeout checks whether the payment has reached its timeout. @@ -332,46 +391,9 @@ func (p *paymentLifecycle) requestRoute( return nil, nil } -// stop signals any active shard goroutine to exit and waits for them to exit. +// stop signals any active shard goroutine to exit. func (p *paymentLifecycle) stop() { close(p.quit) - p.wg.Wait() -} - -// waitForShard blocks until any of the outstanding shards return. -func (p *paymentLifecycle) waitForShard() error { - select { - case err := <-p.shardErrors: - return err - - case <-p.quit: - return errShardHandlerExiting - - case <-p.router.quit: - return ErrRouterShuttingDown - } -} - -// checkShards is a non-blocking method that check if any shards has finished -// their execution. -func (p *paymentLifecycle) checkShards() error { - for { - select { - case err := <-p.shardErrors: - if err != nil { - return err - } - - case <-p.quit: - return errShardHandlerExiting - - case <-p.router.quit: - return ErrRouterShuttingDown - - default: - return nil - } - } } // attemptResult holds the HTLC attempt and a possible error returned from @@ -388,38 +410,33 @@ type attemptResult struct { } // collectResultAsync launches a goroutine that will wait for the result of the -// given HTLC attempt to be available then handle its result. It will fail the -// payment with the control tower if a terminal error is encountered. +// given HTLC attempt to be available then handle its result. Once received, it +// will send a nil error to channel `resultCollected` to indicate there's an +// result. func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { - // errToSend is the error to be sent to sh.shardErrors. - var errToSend error - - // handleResultErr is a function closure must be called using defer. It - // finishes collecting result by updating the payment state and send - // the error (or nil) to sh.shardErrors. - handleResultErr := func() { - // Send the error or quit. - select { - case p.shardErrors <- errToSend: - case <-p.router.quit: - case <-p.quit: - } - - p.wg.Done() - } - - p.wg.Add(1) go func() { - defer handleResultErr() - // Block until the result is available. _, err := p.collectResult(attempt) if err != nil { log.Errorf("Error collecting result for attempt %v "+ "in payment %v: %v", attempt.AttemptID, p.identifier, err) + } - errToSend = err + log.Debugf("Result collected for attempt %v in payment %v", + attempt.AttemptID, p.identifier) + + // Once the result is collected, we signal it by writing the + // error to `resultCollected`. + select { + // Send the signal or quit. + case p.resultCollected <- err: + + case <-p.quit: + log.Debugf("Lifecycle exiting while collecting "+ + "result for payment %v", p.identifier) + + case <-p.router.quit: return } }() diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index eca18305f..8fee4cfb5 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -21,6 +21,10 @@ import ( const stepTimeout = 5 * time.Second +var ( + dummyErr = errors.New("dummy") +) + // createTestRoute builds a route a->b->c paying the given amt to c. func createTestRoute(amt lnwire.MilliSatoshi, aliasMap map[string]route.Vertex) (*route.Route, error) { @@ -1112,3 +1116,119 @@ func TestRequestRouteFailPaymentError(t *testing.T) { // Assert that `FailPayment` is called as expected. ct.AssertExpectations(t) } + +// TestDecideNextStep checks the method `decideNextStep` behaves as expected. +func TestDecideNextStep(t *testing.T) { + t.Parallel() + + // mockReturn is used to hold the return values from AllowMoreAttempts + // or NeedWaitAttempts. + type mockReturn struct { + allowOrWait bool + err error + } + + testCases := []struct { + name string + allowMoreAttempts *mockReturn + needWaitAttempts *mockReturn + + // When the attemptResultChan has returned. + closeResultChan bool + + // Whether the router has quit. + routerQuit bool + + expectedStep stateStep + expectedErr error + }{ + { + name: "allow more attempts", + allowMoreAttempts: &mockReturn{true, nil}, + expectedStep: stepProceed, + expectedErr: nil, + }, + { + name: "error on allow more attempts", + allowMoreAttempts: &mockReturn{false, dummyErr}, + expectedStep: stepExit, + expectedErr: dummyErr, + }, + { + name: "no wait and exit", + allowMoreAttempts: &mockReturn{false, nil}, + needWaitAttempts: &mockReturn{false, nil}, + expectedStep: stepExit, + expectedErr: nil, + }, + { + name: "wait returns an error", + allowMoreAttempts: &mockReturn{false, nil}, + needWaitAttempts: &mockReturn{false, dummyErr}, + expectedStep: stepExit, + expectedErr: dummyErr, + }, + + { + name: "wait and exit on result chan", + allowMoreAttempts: &mockReturn{false, nil}, + needWaitAttempts: &mockReturn{true, nil}, + closeResultChan: true, + expectedStep: stepSkip, + expectedErr: nil, + }, + { + name: "wait and exit on router quit", + allowMoreAttempts: &mockReturn{false, nil}, + needWaitAttempts: &mockReturn{true, nil}, + routerQuit: true, + expectedStep: stepExit, + expectedErr: ErrRouterShuttingDown, + }, + } + + for _, tc := range testCases { + tc := tc + + // Create a test paymentLifecycle. + p := createTestPaymentLifecycle() + + // Make a mock payment. + payment := &mockMPPayment{} + + // Mock the method AllowMoreAttempts. + payment.On("AllowMoreAttempts").Return( + tc.allowMoreAttempts.allowOrWait, + tc.allowMoreAttempts.err, + ).Once() + + // Mock the method NeedWaitAttempts. + if tc.needWaitAttempts != nil { + payment.On("NeedWaitAttempts").Return( + tc.needWaitAttempts.allowOrWait, + tc.needWaitAttempts.err, + ).Once() + } + + // Send a nil error to the attemptResultChan if requested. + if tc.closeResultChan { + p.resultCollected = make(chan error, 1) + p.resultCollected <- nil + } + + // Quit the router if requested. + if tc.routerQuit { + close(p.router.quit) + } + + // Once the setup is finished, run the test cases. + t.Run(tc.name, func(t *testing.T) { + step, err := p.decideNextStep(payment) + require.Equal(t, tc.expectedStep, step) + require.ErrorIs(t, tc.expectedErr, err) + }) + + // Check the payment's methods are called as expected. + payment.AssertExpectations(t) + } +} diff --git a/routing/router_test.go b/routing/router_test.go index 6adefe12a..6073ee2be 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3470,34 +3470,44 @@ func TestSendMPPaymentSucceed(t *testing.T) { session := &mockPaymentSession{} sessionSource.On("NewPaymentSession", req).Return(session, nil) controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - // Mock the InFlightHTLCs. var ( htlcs []channeldb.HTLCAttempt numAttempts atomic.Uint32 + settled atomic.Bool + numParts = uint32(4) ) // Make a mock MPPayment. payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs). - On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}) + On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). + On("Terminated").Return(false) + controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock FetchPayment to return the payment. - controlTower.On("FetchPayment", - identifier, - ).Return(payment, nil).Run(func(args mock.Arguments) { - // When number of attempts made is less than 4, we will mock - // the payment's methods to allow the lifecycle to continue. - if numAttempts.Load() < 4 { - payment.On("Terminated").Return(false).Times(2). - On("NeedWaitAttempts").Return(false, nil).Once() - return - } + controlTower.On("FetchPayment", identifier).Return(payment, nil). + Run(func(args mock.Arguments) { + // When number of attempts made is less than 4, we will + // mock the payment's methods to allow the lifecycle to + // continue. + if numAttempts.Load() < numParts { + payment.On("AllowMoreAttempts").Return(true, nil).Once() + return + } - // Otherwise, terminate the lifecycle. - payment.On("Terminated").Return(true). - On("NeedWaitAttempts").Return(true, nil) - }) + if !settled.Load() { + fmt.Println("wait") + payment.On("AllowMoreAttempts").Return(false, nil).Once() + payment.On("NeedWaitAttempts").Return(true, nil).Once() + // We add another attempt to the counter to + // unblock next time. + return + } + + payment.On("AllowMoreAttempts").Return(false, nil). + On("NeedWaitAttempts").Return(false, nil) + }) // Mock SettleAttempt. preimage := lntypes.Preimage{1, 2, 3} @@ -3511,6 +3521,10 @@ func TestSendMPPaymentSucceed(t *testing.T) { payment.On("GetHTLCs").Return( []channeldb.HTLCAttempt{settledAttempt}, ) + // We want to at least wait for one settlement. + if numAttempts.Load() > 1 { + settled.Store(true) + } }) // Create a route that can send 1/4 of the total amount. This value @@ -3527,7 +3541,6 @@ func TestSendMPPaymentSucceed(t *testing.T) { controlTower.On("RegisterAttempt", identifier, mock.Anything, ).Return(nil).Run(func(args mock.Arguments) { - // Increase the counter whenever an attempt is made. numAttempts.Add(1) }) @@ -3663,29 +3676,40 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { htlcs []channeldb.HTLCAttempt numAttempts atomic.Uint32 failAttemptCount atomic.Uint32 + settled atomic.Bool ) // Make a mock MPPayment. payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs). - On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}) + On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). + On("Terminated").Return(false) + controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock FetchPayment to return the payment. - controlTower.On("FetchPayment", - identifier, - ).Return(payment, nil).Run(func(args mock.Arguments) { - // When number of attempts made is less than 6, we will mock - // the payment's methods to allow the lifecycle to continue. - if numAttempts.Load() < 6 { - payment.On("Terminated").Return(false).Times(2). - On("NeedWaitAttempts").Return(false, nil).Once() - return - } + controlTower.On("FetchPayment", identifier).Return(payment, nil). + Run(func(args mock.Arguments) { + // When number of attempts made is less than 4, we will + // mock the payment's methods to allow the lifecycle to + // continue. + attempts := numAttempts.Load() + if attempts < 6 { + payment.On("AllowMoreAttempts").Return(true, nil).Once() + return + } - // Otherwise, terminate the lifecycle. - payment.On("Terminated").Return(true). - On("NeedWaitAttempts").Return(true, nil) - }) + if !settled.Load() { + payment.On("AllowMoreAttempts").Return(false, nil).Once() + payment.On("NeedWaitAttempts").Return(true, nil).Once() + // We add another attempt to the counter to + // unblock next time. + numAttempts.Add(1) + return + } + + payment.On("AllowMoreAttempts").Return(false, nil). + On("NeedWaitAttempts").Return(false, nil) + }) // Create a route that can send 1/4 of the total amount. This value // will be returned by calling RequestRoute. @@ -3768,6 +3792,10 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { payment.On("GetHTLCs").Return( []channeldb.HTLCAttempt{settledAttempt}, ) + + if numAttempts.Load() > 1 { + settled.Store(true) + } }) controlTower.On("DeleteFailedAttempts", identifier).Return(nil) @@ -3885,8 +3913,8 @@ func TestSendMPPaymentFailed(t *testing.T) { // Make a mock MPPayment. payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs).Once() - payment.On("GetStatus").Return(channeldb.StatusInFlight).Once() payment.On("GetState").Return(&channeldb.MPPaymentState{}) + payment.On("Terminated").Return(false) controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock the sequential FetchPayment to return the payment. @@ -3895,21 +3923,20 @@ func TestSendMPPaymentFailed(t *testing.T) { // We want to at least send out all parts in order to // wait for them later. if numAttempts.Load() < numParts { - payment.On("Terminated").Return(false).Times(2). - On("NeedWaitAttempts").Return(false, nil).Once() + payment.On("AllowMoreAttempts").Return(true, nil).Once() return } // Wait if the payment wasn't failed yet. if !failed.Load() { - payment.On("Terminated").Return(false).Times(2). + payment.On("AllowMoreAttempts").Return(false, nil).Once(). On("NeedWaitAttempts").Return(true, nil).Once() - return } - payment.On("Terminated").Return(true). - On("GetHTLCs").Return(htlcs).Once() + payment.On("AllowMoreAttempts").Return(false, nil). + On("GetHTLCs").Return(htlcs).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() }) // Create a route that can send 1/4 of the total amount. This value @@ -3990,6 +4017,8 @@ func TestSendMPPaymentFailed(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, ).Return(nil) + controlTower.On("DeleteFailedAttempts", identifier).Return(nil) + // Call the actual method SendPayment on router. This is place inside a // goroutine so we can set a timeout for the whole test, in case // anything goes wrong and the test never finishes. From da8f1c084af39c19c622fa699a302d633764c925 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 13 Feb 2023 13:58:52 +0800 Subject: [PATCH 12/27] channeldb+routing: add new interface method `TerminalInfo` This commit adds a new interface method `TerminalInfo` and changes its implementation to return an `*HTLCAttempt` so it includes the route for a successful payment. Method `GetFailureReason` is now removed as its returned value can be found in the above method. --- channeldb/mp_payment.go | 9 ++------- routing/control_tower.go | 7 ++++--- routing/control_tower_test.go | 14 ++++++-------- routing/mock_test.go | 34 +++++++++++++++++++++++----------- routing/payment_lifecycle.go | 18 ++++++++---------- routing/router.go | 3 ++- routing/router_test.go | 34 ++++++++++++++-------------------- 7 files changed, 59 insertions(+), 60 deletions(-) diff --git a/channeldb/mp_payment.go b/channeldb/mp_payment.go index a4b323349..cf5669a50 100644 --- a/channeldb/mp_payment.go +++ b/channeldb/mp_payment.go @@ -219,10 +219,10 @@ func (m *MPPayment) Terminated() bool { // TerminalInfo returns any HTLC settle info recorded. If no settle info is // recorded, any payment level failure will be returned. If neither a settle // nor a failure is recorded, both return values will be nil. -func (m *MPPayment) TerminalInfo() (*HTLCSettleInfo, *FailureReason) { +func (m *MPPayment) TerminalInfo() (*HTLCAttempt, *FailureReason) { for _, h := range m.HTLCs { if h.Settle != nil { - return h.Settle, nil + return &h, nil } } @@ -464,11 +464,6 @@ func (m *MPPayment) GetHTLCs() []HTLCAttempt { return m.HTLCs } -// GetFailureReason returns the failure reason. -func (m *MPPayment) GetFailureReason() *FailureReason { - return m.FailureReason -} - // AllowMoreAttempts is used to decide whether we can safely attempt more HTLCs // for a given payment state. Return an error if the payment is in an // unexpected state. diff --git a/routing/control_tower.go b/routing/control_tower.go index 80ffbe5d9..c064a5b4f 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -31,13 +31,14 @@ type dbMPPayment interface { // InFlightHTLCs returns all HTLCs that are in flight. InFlightHTLCs() []channeldb.HTLCAttempt - // GetFailureReason returns the reason the payment failed. - GetFailureReason() *channeldb.FailureReason - // AllowMoreAttempts is used to decide whether we can safely attempt // more HTLCs for a given payment state. Return an error if the payment // is in an unexpected state. AllowMoreAttempts() (bool, error) + + // TerminalInfo returns the settled HTLC attempt or the payment's + // failure reason. + TerminalInfo() (*channeldb.HTLCAttempt, *channeldb.FailureReason) } // ControlTower tracks all outgoing payments made, whose primary purpose is to diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index f14c18b81..42303dc55 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -134,8 +134,8 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { "subscriber %v failed, want %s, got %s", i, channeldb.StatusSucceeded, result.GetStatus()) - settle, _ := result.TerminalInfo() - if settle.Preimage != preimg { + attempt, _ := result.TerminalInfo() + if attempt.Settle.Preimage != preimg { t.Fatal("unexpected preimage") } if len(result.HTLCs) != 1 { @@ -264,9 +264,8 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { ) settle1, _ := result1.TerminalInfo() - require.Equal( - t, preimg1, settle1.Preimage, "unexpected preimage payment 1", - ) + require.Equal(t, preimg1, settle1.Settle.Preimage, + "unexpected preimage payment 1") require.Len( t, result1.HTLCs, 1, "expect 1 htlc for payment 1, got %d", @@ -283,9 +282,8 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { ) settle2, _ := result2.TerminalInfo() - require.Equal( - t, preimg2, settle2.Preimage, "unexpected preimage payment 2", - ) + require.Equal(t, preimg2, settle2.Settle.Preimage, + "unexpected preimage payment 2") require.Len( t, result2.HTLCs, 1, "expect 1 htlc for payment 2, got %d", len(result2.HTLCs), diff --git a/routing/mock_test.go b/routing/mock_test.go index 5a2ee4206..2458f538b 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -828,22 +828,34 @@ func (m *mockMPPayment) InFlightHTLCs() []channeldb.HTLCAttempt { return args.Get(0).([]channeldb.HTLCAttempt) } -func (m *mockMPPayment) GetFailureReason() *channeldb.FailureReason { - args := m.Called() - - reason := args.Get(0) - if reason == nil { - return nil - } - - return reason.(*channeldb.FailureReason) -} - func (m *mockMPPayment) AllowMoreAttempts() (bool, error) { args := m.Called() return args.Bool(0), args.Error(1) } +func (m *mockMPPayment) TerminalInfo() (*channeldb.HTLCAttempt, + *channeldb.FailureReason) { + + args := m.Called() + + var ( + settleInfo *channeldb.HTLCAttempt + failureInfo *channeldb.FailureReason + ) + + settle := args.Get(0) + if settle != nil { + settleInfo = settle.(*channeldb.HTLCAttempt) + } + + reason := args.Get(1) + if reason != nil { + failureInfo = reason.(*channeldb.FailureReason) + } + + return settleInfo, failureInfo +} + type mockLink struct { htlcswitch.ChannelLink bandwidth lnwire.MilliSatoshi diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index d3e4289e6..511ef3596 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -201,10 +201,10 @@ lifecycle: ps := payment.GetState() remainingFees := p.calcFeeBudget(ps.FeesPaid) - log.Debugf("Payment %v in state terminate=%v, "+ - "active_shards=%v, rem_value=%v, fee_limit=%v", - p.identifier, payment.Terminated(), - ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees) + log.Debugf("Payment %v: status=%v, active_shards=%v, "+ + "rem_value=%v, fee_limit=%v", p.identifier, + payment.GetStatus(), ps.NumAttemptsInFlight, + ps.RemainingAmt, remainingFees) // We now proceed our lifecycle with the following tasks in // order, @@ -291,15 +291,13 @@ lifecycle: "%v: %v", p.identifier, err) } - // Find the first successful shard and return the preimage and route. - for _, a := range payment.GetHTLCs() { - if a.Settle != nil { - return a.Settle.Preimage, &a.Route, nil - } + htlc, failure := payment.TerminalInfo() + if htlc != nil { + return htlc.Settle.Preimage, &htlc.Route, nil } // Otherwise return the payment failure reason. - return [32]byte{}, nil, *payment.GetFailureReason() + return [32]byte{}, nil, *failure } // checkTimeout checks whether the payment has reached its timeout. diff --git a/routing/router.go b/routing/router.go index 0b442da15..f7b20a376 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2532,7 +2532,8 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // Exit if the above error has caused the payment to be failed, we also // return the error from sending attempt to mimic the old behavior of // this method. - if payment.GetFailureReason() != nil { + _, failedReason := payment.TerminalInfo() + if failedReason != nil { return result.attempt, result.err } diff --git a/routing/router_test.go b/routing/router_test.go index 6073ee2be..eda49dd9d 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3482,7 +3482,7 @@ func TestSendMPPaymentSucceed(t *testing.T) { payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs). On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). - On("Terminated").Return(false) + On("GetStatus").Return(channeldb.StatusInFlight) controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock FetchPayment to return the payment. @@ -3518,9 +3518,6 @@ func TestSendMPPaymentSucceed(t *testing.T) { controlTower.On("SettleAttempt", identifier, mock.Anything, mock.Anything, ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { - payment.On("GetHTLCs").Return( - []channeldb.HTLCAttempt{settledAttempt}, - ) // We want to at least wait for one settlement. if numAttempts.Load() > 1 { settled.Store(true) @@ -3566,6 +3563,8 @@ func TestSendMPPaymentSucceed(t *testing.T) { controlTower.On("DeleteFailedAttempts", identifier).Return(nil) + payment.On("TerminalInfo").Return(&settledAttempt, nil) + // Call the actual method SendPayment on router. This is place inside a // goroutine so we can set a timeout for the whole test, in case // anything goes wrong and the test never finishes. @@ -3683,7 +3682,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs). On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). - On("Terminated").Return(false) + On("GetStatus").Return(channeldb.StatusInFlight) controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock FetchPayment to return the payment. @@ -3787,12 +3786,6 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { controlTower.On("SettleAttempt", identifier, mock.Anything, mock.Anything, ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mock the payment's - // GetHTLCs() to return the settled htlc. - payment.On("GetHTLCs").Return( - []channeldb.HTLCAttempt{settledAttempt}, - ) - if numAttempts.Load() > 1 { settled.Store(true) } @@ -3800,6 +3793,8 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { controlTower.On("DeleteFailedAttempts", identifier).Return(nil) + payment.On("TerminalInfo").Return(&settledAttempt, nil) + // Call the actual method SendPayment on router. This is place inside a // goroutine so we can set a timeout for the whole test, in case // anything goes wrong and the test never finishes. @@ -3913,8 +3908,8 @@ func TestSendMPPaymentFailed(t *testing.T) { // Make a mock MPPayment. payment := &mockMPPayment{} payment.On("InFlightHTLCs").Return(htlcs).Once() - payment.On("GetState").Return(&channeldb.MPPaymentState{}) - payment.On("Terminated").Return(false) + payment.On("GetState").Return(&channeldb.MPPaymentState{}). + On("GetStatus").Return(channeldb.StatusInFlight) controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() // Mock the sequential FetchPayment to return the payment. @@ -3935,7 +3930,6 @@ func TestSendMPPaymentFailed(t *testing.T) { } payment.On("AllowMoreAttempts").Return(false, nil). - On("GetHTLCs").Return(htlcs).Once(). On("NeedWaitAttempts").Return(false, nil).Once() }) @@ -4011,12 +4005,12 @@ func TestSendMPPaymentFailed(t *testing.T) { }) // Mock the payment to return the failure reason. - payment.On("GetFailureReason").Return(&failureReason) - payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, ).Return(nil) + payment.On("TerminalInfo").Return(nil, &failureReason) + controlTower.On("DeleteFailedAttempts", identifier).Return(nil) // Call the actual method SendPayment on router. This is place inside a @@ -4194,7 +4188,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() // Mock the payment to return nil failrue reason. - payment.On("GetFailureReason").Return(nil) + payment.On("TerminalInfo").Return(nil, nil) // Expect a successful send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt) @@ -4289,7 +4283,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { ).Return(nil, nil) // Mock the payment to return nil failrue reason. - payment.On("GetFailureReason").Return(nil) + payment.On("TerminalInfo").Return(nil, nil) // Expect a failed send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt) @@ -4374,7 +4368,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() // Mock the payment to return a failrue reason. - payment.On("GetFailureReason").Return(&failureReason) + payment.On("TerminalInfo").Return(nil, &failureReason) // Expect a failed send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt) @@ -4452,7 +4446,7 @@ func TestSendToRouteTempFailure(t *testing.T) { controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() // Mock the payment to return nil failrue reason. - payment.On("GetFailureReason").Return(nil) + payment.On("TerminalInfo").Return(nil, nil) // Return a nil reason to mock a temporary failure. missionControl.On("ReportPaymentFail", From eda24ec871e29cdf0604155bc2cce6fb7b8db018 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 7 Mar 2023 20:27:02 +0800 Subject: [PATCH 13/27] routing: catch lifecycle quit signal in `collectResult` --- routing/payment_lifecycle.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 511ef3596..d0f43d7b0 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -17,6 +17,10 @@ import ( "github.com/lightningnetwork/lnd/routing/shards" ) +// ErrPaymentLifecycleExiting is used when waiting for htlc attempt result, but +// the payment lifecycle is exiting . +var ErrPaymentLifecycleExiting = errors.New("payment lifecycle exiting") + // paymentLifecycle holds all information about the current state of a payment // needed to resume if from any point. type paymentLifecycle struct { @@ -496,6 +500,9 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( return nil, htlcswitch.ErrSwitchExiting } + case <-p.quit: + return nil, ErrPaymentLifecycleExiting + case <-p.router.quit: return nil, ErrRouterShuttingDown } From 09a5d235ec01a00d8d2c42ee05f6956a130ebe03 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 7 Mar 2023 18:58:48 +0800 Subject: [PATCH 14/27] routing: fail attempt when no shard is found or circuit generation fails --- htlcswitch/switch.go | 2 ++ routing/payment_lifecycle.go | 25 +++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 592d03a1c..dab25aaeb 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -1000,6 +1000,8 @@ func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult, // We've received a fail update which means we can finalize the // user payment and return fail response. case *lnwire.UpdateFailHTLC: + // TODO(yy): construct deobfuscator here to avoid creating it + // in paymentLifecycle even for settled HTLCs. paymentErr := s.parseFailedPayment( deobfuscator, attemptID, paymentHash, n.unencrypted, n.isResolution, htlc, diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index d0f43d7b0..38489b141 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -456,15 +456,27 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( // below. hash, err := p.shardTracker.GetHash(attempt.AttemptID) if err != nil { - return nil, err + return p.failAttempt(attempt.AttemptID, err) } // Regenerate the circuit for this attempt. _, circuit, err := generateSphinxPacket( &attempt.Route, hash[:], attempt.SessionKey(), ) + // TODO(yy): We generate this circuit to create the error decryptor, + // which is then used in htlcswitch as the deobfuscator to decode the + // error from `UpdateFailHTLC`. However, suppose it's an + // `UpdateFulfillHTLC` message yet for some reason the sphinx packet is + // failed to be generated, we'd miss settling it. This means we should + // give it a second chance to try the settlement path in case + // `GetAttemptResult` gives us back the preimage. And move the circuit + // creation into htlcswitch so it's only constructed when there's a + // failure message we need to decode. if err != nil { - return nil, err + log.Debugf("Unable to generate circuit for attempt %v: %v", + attempt.AttemptID, err) + + return p.failAttempt(attempt.AttemptID, err) } // Using the created circuit, initialize the error decrypter so we can @@ -476,6 +488,12 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( // Now ask the switch to return the result of the payment when // available. + // + // TODO(yy): consider using htlcswitch to create the `errorDecryptor` + // since the htlc is already in db. This will also make the interface + // `PaymentAttemptDispatcher` deeper and easier to use. Moreover, we'd + // only create the decryptor when received a failure, further saving us + // a few CPU cycles. resultChan, err := p.router.cfg.Payer.GetAttemptResult( attempt.AttemptID, p.identifier, errorDecryptor, ) @@ -536,6 +554,9 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( ) if err != nil { log.Errorf("Unable to settle payment attempt: %v", err) + + // We won't mark the attempt as failed since we already have + // the preimage. return nil, err } From 01e3bd87ab1a1579aa217df35a7a1a0d13b8b078 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 8 Mar 2023 01:09:55 +0800 Subject: [PATCH 15/27] routing: delete old payment lifecycle related unit tests The old payment lifecycle is removed due to it's not "unit" - maintaining these tests probably takes as much work as the actual methods being tested, if not more so. Moreover, the usage of the old mockers in current payment lifecycle test is removed as it re-implements other interfaces and sometimes implements it uniquely just for the tests. This is bad as, not only we need to work on the actual interface implementations and test them , but also re-implement them again in the test without testing them! --- routing/payment_lifecycle_test.go | 779 ------------------------------ routing/router_test.go | 643 ------------------------ 2 files changed, 1422 deletions(-) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 8fee4cfb5..3f6dc1e71 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -1,17 +1,11 @@ package routing import ( - "crypto/rand" - "fmt" - "sync/atomic" "testing" "time" - "github.com/btcsuite/btcd/btcutil" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -19,783 +13,10 @@ import ( "github.com/stretchr/testify/require" ) -const stepTimeout = 5 * time.Second - var ( dummyErr = errors.New("dummy") ) -// createTestRoute builds a route a->b->c paying the given amt to c. -func createTestRoute(amt lnwire.MilliSatoshi, - aliasMap map[string]route.Vertex) (*route.Route, error) { - - hopFee := lnwire.NewMSatFromSatoshis(3) - hop1 := aliasMap["b"] - hop2 := aliasMap["c"] - hops := []*route.Hop{ - { - ChannelID: 1, - PubKeyBytes: hop1, - LegacyPayload: true, - AmtToForward: amt + hopFee, - }, - { - ChannelID: 2, - PubKeyBytes: hop2, - LegacyPayload: true, - AmtToForward: amt, - }, - } - - // We create a simple route that we will supply every time the router - // requests one. - return route.NewRouteFromHops( - amt+2*hopFee, 100, aliasMap["a"], hops, - ) -} - -// paymentLifecycleTestCase contains the steps that we expect for a payment -// lifecycle test, and the routes that pathfinding should deliver. -type paymentLifecycleTestCase struct { - name string - - // steps is a list of steps to perform during the testcase. - steps []string - - // routes is the sequence of routes we will provide to the - // router when it requests a new route. - routes []*route.Route - - // paymentErr is the error we expect our payment to fail with. This - // should be nil for tests with paymentSuccess steps and non-nil for - // payments with paymentError steps. - paymentErr error -} - -const ( - // routerInitPayment is a test step where we expect the router - // to call the InitPayment method on the control tower. - routerInitPayment = "Router:init-payment" - - // routerRegisterAttempt is a test step where we expect the - // router to call the RegisterAttempt method on the control - // tower. - routerRegisterAttempt = "Router:register-attempt" - - // routerSettleAttempt is a test step where we expect the - // router to call the SettleAttempt method on the control - // tower. - routerSettleAttempt = "Router:settle-attempt" - - // routerFailAttempt is a test step where we expect the router - // to call the FailAttempt method on the control tower. - routerFailAttempt = "Router:fail-attempt" - - // routerFailPayment is a test step where we expect the router - // to call the Fail method on the control tower. - routerFailPayment = "Router:fail-payment" - - // routeRelease is a test step where we unblock pathfinding and - // allow it to respond to our test with a route. - routeRelease = "PaymentSession:release" - - // sendToSwitchSuccess is a step where we expect the router to - // call send the payment attempt to the switch, and we will - // respond with a non-error, indicating that the payment - // attempt was successfully forwarded. - sendToSwitchSuccess = "SendToSwitch:success" - - // sendToSwitchResultFailure is a step where we expect the - // router to send the payment attempt to the switch, and we - // will respond with a forwarding error. This can happen when - // forwarding fail on our local links. - sendToSwitchResultFailure = "SendToSwitch:failure" - - // getPaymentResultSuccess is a test step where we expect the - // router to call the GetAttemptResult method, and we will - // respond with a successful payment result. - getPaymentResultSuccess = "GetAttemptResult:success" - - // getPaymentResultTempFailure is a test step where we expect the - // router to call the GetAttemptResult method, and we will - // respond with a forwarding error, expecting the router to retry. - getPaymentResultTempFailure = "GetAttemptResult:temp-failure" - - // getPaymentResultTerminalFailure is a test step where we - // expect the router to call the GetAttemptResult method, and - // we will respond with a terminal error, expecting the router - // to stop making payment attempts. - getPaymentResultTerminalFailure = "GetAttemptResult:terminal-failure" - - // resendPayment is a test step where we manually try to resend - // the same payment, making sure the router responds with an - // error indicating that it is already in flight. - resendPayment = "ResendPayment" - - // startRouter is a step where we manually start the router, - // used to test that it automatically will resume payments at - // startup. - startRouter = "StartRouter" - - // stopRouter is a test step where we manually make the router - // shut down. - stopRouter = "StopRouter" - - // paymentSuccess is a step where assert that we receive a - // successful result for the original payment made. - paymentSuccess = "PaymentSuccess" - - // paymentError is a step where assert that we receive an error - // for the original payment made. - paymentError = "PaymentError" - - // resentPaymentSuccess is a step where assert that we receive - // a successful result for a payment that was resent. - resentPaymentSuccess = "ResentPaymentSuccess" - - // resentPaymentError is a step where assert that we receive an - // error for a payment that was resent. - resentPaymentError = "ResentPaymentError" -) - -// TestRouterPaymentStateMachine tests that the router interacts as expected -// with the ControlTower during a payment lifecycle, such that it payment -// attempts are not sent twice to the switch, and results are handled after a -// restart. -func TestRouterPaymentStateMachine(t *testing.T) { - t.Parallel() - - const startingBlockHeight = 101 - - // Setup two simple channels such that we can mock sending along this - // route. - chanCapSat := btcutil.Amount(100000) - testChannels := []*testChannel{ - symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{ - Expiry: 144, - FeeRate: 400, - MinHTLC: 1, - MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), - }, 1), - symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{ - Expiry: 144, - FeeRate: 400, - MinHTLC: 1, - MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), - }, 2), - } - - testGraph, err := createTestGraphFromChannels(t, true, testChannels, "a") - require.NoError(t, err, "unable to create graph") - - paymentAmt := lnwire.NewMSatFromSatoshis(1000) - - // We create a simple route that we will supply every time the router - // requests one. - rt, err := createTestRoute(paymentAmt, testGraph.aliasMap) - require.NoError(t, err, "unable to create route") - - tests := []paymentLifecycleTestCase{ - { - // Tests a normal payment flow that succeeds. - name: "single shot success", - - steps: []string{ - routerInitPayment, - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - getPaymentResultSuccess, - routerSettleAttempt, - paymentSuccess, - }, - routes: []*route.Route{rt}, - }, - { - // A payment flow with a failure on the first attempt, - // but that succeeds on the second attempt. - name: "single shot retry", - - steps: []string{ - routerInitPayment, - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // Make the first sent attempt fail. - getPaymentResultTempFailure, - routerFailAttempt, - - // The router should retry. - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // Make the second sent attempt succeed. - getPaymentResultSuccess, - routerSettleAttempt, - paymentSuccess, - }, - routes: []*route.Route{rt, rt}, - }, - { - // A payment flow with a forwarding failure first time - // sending to the switch, but that succeeds on the - // second attempt. - name: "single shot switch failure", - - steps: []string{ - routerInitPayment, - routeRelease, - routerRegisterAttempt, - - // Make the first sent attempt fail. - sendToSwitchResultFailure, - routerFailAttempt, - - // The router should retry. - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // Make the second sent attempt succeed. - getPaymentResultSuccess, - routerSettleAttempt, - paymentSuccess, - }, - routes: []*route.Route{rt, rt}, - }, - { - // A payment that fails on the first attempt, and has - // only one route available to try. It will therefore - // fail permanently. - name: "single shot route fails", - - steps: []string{ - routerInitPayment, - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // Make the first sent attempt fail. - getPaymentResultTempFailure, - routerFailAttempt, - - routeRelease, - - // Since there are no more routes to try, the - // payment should fail. - routerFailPayment, - paymentError, - }, - routes: []*route.Route{rt}, - paymentErr: channeldb.FailureReasonNoRoute, - }, - { - // We expect the payment to fail immediately if we have - // no routes to try. - name: "single shot no route", - - steps: []string{ - routerInitPayment, - routeRelease, - routerFailPayment, - paymentError, - }, - routes: []*route.Route{}, - paymentErr: channeldb.FailureReasonNoRoute, - }, - { - // A normal payment flow, where we attempt to resend - // the same payment after each step. This ensures that - // the router don't attempt to resend a payment already - // in flight. - name: "single shot resend", - - steps: []string{ - routerInitPayment, - routeRelease, - routerRegisterAttempt, - - // Manually resend the payment, the router - // should attempt to init with the control - // tower, but fail since it is already in - // flight. - resendPayment, - routerInitPayment, - resentPaymentError, - - // The original payment should proceed as - // normal. - sendToSwitchSuccess, - - // Again resend the payment and assert it's not - // allowed. - resendPayment, - routerInitPayment, - resentPaymentError, - - // Notify about a success for the original - // payment. - getPaymentResultSuccess, - routerSettleAttempt, - - // Now that the original payment finished, - // resend it again to ensure this is not - // allowed. - resendPayment, - routerInitPayment, - resentPaymentError, - paymentSuccess, - }, - routes: []*route.Route{rt}, - }, - { - // Tests that the router is able to handle the - // received payment result after a restart. - name: "single shot restart", - - steps: []string{ - routerInitPayment, - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // Shut down the router. The original caller - // should get notified about this. - stopRouter, - paymentError, - - // Start the router again, and ensure the - // router registers the success with the - // control tower. - startRouter, - getPaymentResultSuccess, - routerSettleAttempt, - }, - routes: []*route.Route{rt}, - paymentErr: ErrRouterShuttingDown, - }, - { - // Tests that we are allowed to resend a payment after - // it has permanently failed. - name: "single shot resend fail", - - steps: []string{ - routerInitPayment, - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // Resending the payment at this stage should - // not be allowed. - resendPayment, - routerInitPayment, - resentPaymentError, - - // Make the first attempt fail. - getPaymentResultTempFailure, - routerFailAttempt, - - // Since we have no more routes to try, the - // original payment should fail. - routeRelease, - routerFailPayment, - paymentError, - - // Now resend the payment again. This should be - // allowed, since the payment has failed. - resendPayment, - routerInitPayment, - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - getPaymentResultSuccess, - routerSettleAttempt, - resentPaymentSuccess, - }, - routes: []*route.Route{rt}, - paymentErr: channeldb.FailureReasonNoRoute, - }, - } - - for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - testPaymentLifecycle( - t, test, paymentAmt, startingBlockHeight, - testGraph, - ) - }) - } -} - -func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, - paymentAmt lnwire.MilliSatoshi, startingBlockHeight uint32, - testGraph *testGraphInstance) { - - // Create a mock control tower with channels set up, that we use to - // synchronize and listen for events. - control := makeMockControlTower() - control.init = make(chan initArgs) - control.registerAttempt = make(chan registerAttemptArgs) - control.settleAttempt = make(chan settleAttemptArgs) - control.failAttempt = make(chan failAttemptArgs) - control.failPayment = make(chan failPaymentArgs) - control.fetchInFlight = make(chan struct{}) - - // setupRouter is a helper method that creates and starts the router in - // the desired configuration for this test. - setupRouter := func() (*ChannelRouter, chan error, - chan *htlcswitch.PaymentResult) { - - chain := newMockChain(startingBlockHeight) - chainView := newMockChainView(chain) - - // We set uo the use the following channels and a mock Payer to - // synchronize with the interaction to the Switch. - sendResult := make(chan error) - paymentResult := make(chan *htlcswitch.PaymentResult) - - payer := &mockPayerOld{ - sendResult: sendResult, - paymentResult: paymentResult, - } - - router, err := New(Config{ - Graph: testGraph.graph, - Chain: chain, - ChainView: chainView, - Control: control, - SessionSource: &mockPaymentSessionSourceOld{}, - MissionControl: &mockMissionControlOld{}, - Payer: payer, - ChannelPruneExpiry: time.Hour * 24, - GraphPruneInterval: time.Hour * 2, - NextPaymentID: func() (uint64, error) { - next := atomic.AddUint64(&uniquePaymentID, 1) - return next, nil - }, - Clock: clock.NewTestClock(time.Unix(1, 0)), - IsAlias: func(scid lnwire.ShortChannelID) bool { - return false - }, - }) - if err != nil { - t.Fatalf("unable to create router %v", err) - } - - // On startup, the router should fetch all pending payments - // from the ControlTower, so assert that here. - errCh := make(chan error) - go func() { - close(errCh) - select { - case <-control.fetchInFlight: - return - case <-time.After(1 * time.Second): - errCh <- errors.New("router did not fetch in flight " + - "payments") - } - }() - - if err := router.Start(); err != nil { - t.Fatalf("unable to start router: %v", err) - } - - select { - case err := <-errCh: - if err != nil { - t.Fatalf("error in anonymous goroutine: %s", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("did not fetch in flight payments at startup") - } - - return router, sendResult, paymentResult - } - - router, sendResult, getPaymentResult := setupRouter() - t.Cleanup(func() { - require.NoError(t, router.Stop()) - }) - - // Craft a LightningPayment struct. - var preImage lntypes.Preimage - if _, err := rand.Read(preImage[:]); err != nil { - t.Fatalf("unable to generate preimage") - } - - payHash := preImage.Hash() - - payment := LightningPayment{ - Target: testGraph.aliasMap["c"], - Amount: paymentAmt, - FeeLimit: noFeeLimit, - paymentHash: &payHash, - } - - // Setup our payment session source to block on release of - // routes. - routeChan := make(chan struct{}) - router.cfg.SessionSource = &mockPaymentSessionSourceOld{ - routes: test.routes, - routeRelease: routeChan, - } - - router.cfg.MissionControl = &mockMissionControlOld{} - - // Send the payment. Since this is new payment hash, the - // information should be registered with the ControlTower. - paymentResult := make(chan error) - done := make(chan struct{}) - go func() { - _, _, err := router.SendPayment(&payment) - paymentResult <- err - close(done) - }() - - var resendResult chan error - for i, step := range test.steps { - i, step := i, step - - // fatal is a helper closure that wraps the step info. - fatal := func(err string, args ...interface{}) { - if args != nil { - err = fmt.Sprintf(err, args) - } - t.Fatalf( - "test case: %s failed on step [%v:%s], err: %s", - test.name, i, step, err, - ) - } - - switch step { - case routerInitPayment: - var args initArgs - select { - case args = <-control.init: - case <-time.After(stepTimeout): - fatal("no init payment with control") - } - - if args.c == nil { - fatal("expected non-nil CreationInfo") - } - - case routeRelease: - select { - case <-routeChan: - case <-time.After(stepTimeout): - fatal("no route requested") - } - - // In this step we expect the router to make a call to - // register a new attempt with the ControlTower. - case routerRegisterAttempt: - var args registerAttemptArgs - select { - case args = <-control.registerAttempt: - case <-time.After(stepTimeout): - fatal("attempt not registered with control") - } - - if args.a == nil { - fatal("expected non-nil AttemptInfo") - } - - // In this step we expect the router to call the - // ControlTower's SettleAttempt method with the preimage. - case routerSettleAttempt: - select { - case <-control.settleAttempt: - case <-time.After(stepTimeout): - fatal("attempt settle not " + - "registered with control") - } - - // In this step we expect the router to call the - // ControlTower's FailAttempt method with a HTLC fail - // info. - case routerFailAttempt: - select { - case <-control.failAttempt: - case <-time.After(stepTimeout): - fatal("attempt fail not " + - "registered with control") - } - - // In this step we expect the router to call the - // ControlTower's Fail method, to indicate that the - // payment failed. - case routerFailPayment: - select { - case <-control.failPayment: - case <-time.After(stepTimeout): - fatal("payment fail not " + - "registered with control") - } - - // In this step we expect the SendToSwitch method to be - // called, and we respond with a nil-error. - case sendToSwitchSuccess: - select { - case sendResult <- nil: - case <-time.After(stepTimeout): - fatal("unable to send result") - } - - // In this step we expect the SendToSwitch method to be - // called, and we respond with a forwarding error - case sendToSwitchResultFailure: - select { - case sendResult <- htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, - 1, - ): - case <-time.After(stepTimeout): - fatal("unable to send result") - } - - // In this step we expect the GetAttemptResult method - // to be called, and we respond with the preimage to - // complete the payment. - case getPaymentResultSuccess: - select { - case getPaymentResult <- &htlcswitch.PaymentResult{ - Preimage: preImage, - }: - case <-time.After(stepTimeout): - fatal("unable to send result") - } - - // In this state we expect the GetAttemptResult method - // to be called, and we respond with a forwarding - // error, indicating that the router should retry. - case getPaymentResultTempFailure: - failure := htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, - 1, - ) - - select { - case getPaymentResult <- &htlcswitch.PaymentResult{ - Error: failure, - }: - case <-time.After(stepTimeout): - fatal("unable to get result") - } - - // In this state we expect the router to call the - // GetAttemptResult method, and we will respond with a - // terminal error, indicating the router should stop - // making payment attempts. - case getPaymentResultTerminalFailure: - failure := htlcswitch.NewForwardingError( - &lnwire.FailIncorrectDetails{}, - 1, - ) - - select { - case getPaymentResult <- &htlcswitch.PaymentResult{ - Error: failure, - }: - case <-time.After(stepTimeout): - fatal("unable to get result") - } - - // In this step we manually try to resend the same - // payment, making sure the router responds with an - // error indicating that it is already in flight. - case resendPayment: - resendResult = make(chan error) - go func() { - _, _, err := router.SendPayment(&payment) - resendResult <- err - }() - - // In this step we manually stop the router. - case stopRouter: - // On shutdown, the switch closes our result channel. - // Mimic this behavior in our mock. - close(getPaymentResult) - - if err := router.Stop(); err != nil { - fatal("unable to restart: %v", err) - } - - // In this step we manually start the router. - case startRouter: - router, sendResult, getPaymentResult = setupRouter() - - // In this state we expect to receive an error for the - // original payment made. - case paymentError: - require.Error(t, test.paymentErr, - "paymentError not set") - - select { - case err := <-paymentResult: - require.ErrorIs(t, err, test.paymentErr) - - case <-time.After(stepTimeout): - fatal("got no payment result") - } - - // In this state we expect the original payment to - // succeed. - case paymentSuccess: - require.Nil(t, test.paymentErr) - - select { - case err := <-paymentResult: - if err != nil { - t.Fatalf("did not expect "+ - "error %v", err) - } - - case <-time.After(stepTimeout): - fatal("got no payment result") - } - - // In this state we expect to receive an error for the - // resent payment made. - case resentPaymentError: - select { - case err := <-resendResult: - if err == nil { - t.Fatalf("expected error") - } - - case <-time.After(stepTimeout): - fatal("got no payment result") - } - - // In this state we expect the resent payment to - // succeed. - case resentPaymentSuccess: - select { - case err := <-resendResult: - if err != nil { - t.Fatalf("did not expect error %v", err) - } - - case <-time.After(stepTimeout): - fatal("got no payment result") - } - - default: - fatal("unknown step %v", step) - } - } - - select { - case <-done: - case <-time.After(testTimeout): - t.Fatalf("SendPayment didn't exit") - } -} - func makeSettledAttempt(total, fee int, preimage lntypes.Preimage) channeldb.HTLCAttempt { diff --git a/routing/router_test.go b/routing/router_test.go index eda49dd9d..79f0b026c 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3400,649 +3400,6 @@ func createDummyLightningPayment(t *testing.T, } } -// TestSendMPPaymentSucceed tests that we can successfully send a MPPayment via -// router.SendPayment. This test mainly focuses on testing the logic of the -// method resumePayment is implemented as expected. -func TestSendMPPaymentSucceed(t *testing.T) { - const startingBlockHeight = 101 - - // Create mockers to initialize the router. - controlTower := &mockControlTower{} - sessionSource := &mockPaymentSessionSource{} - missionControl := &mockMissionControl{} - payer := &mockPaymentAttemptDispatcher{} - chain := newMockChain(startingBlockHeight) - chainView := newMockChainView(chain) - testGraph := createDummyTestGraph(t) - - // Define the behavior of the mockers to the point where we can - // successfully start the router. - controlTower.On("FetchInFlightPayments").Return( - []*channeldb.MPPayment{}, nil, - ) - payer.On("CleanStore", mock.Anything).Return(nil) - - // Create and start the router. - router, err := New(Config{ - Control: controlTower, - SessionSource: sessionSource, - MissionControl: missionControl, - Payer: payer, - - // TODO(yy): create new mocks for the chain and chainview. - Chain: chain, - ChainView: chainView, - - // TODO(yy): mock the graph once it's changed into interface. - Graph: testGraph.graph, - - Clock: clock.NewTestClock(time.Unix(1, 0)), - GraphPruneInterval: time.Hour * 2, - NextPaymentID: func() (uint64, error) { - next := atomic.AddUint64(&uniquePaymentID, 1) - return next, nil - }, - - IsAlias: func(scid lnwire.ShortChannelID) bool { - return false - }, - }) - require.NoError(t, err, "failed to create router") - - // Make sure the router can start and stop without error. - require.NoError(t, router.Start(), "router failed to start") - t.Cleanup(func() { - require.NoError(t, router.Stop(), "router failed to stop") - }) - - // Once the router is started, check that the mocked methods are called - // as expected. - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - - // Mock the methods to the point where we are inside the function - // resumePayment. - paymentAmt := lnwire.MilliSatoshi(10000) - req := createDummyLightningPayment( - t, testGraph.aliasMap["c"], paymentAmt, - ) - identifier := lntypes.Hash(req.Identifier()) - session := &mockPaymentSession{} - sessionSource.On("NewPaymentSession", req).Return(session, nil) - controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - // Mock the InFlightHTLCs. - var ( - htlcs []channeldb.HTLCAttempt - numAttempts atomic.Uint32 - settled atomic.Bool - numParts = uint32(4) - ) - - // Make a mock MPPayment. - payment := &mockMPPayment{} - payment.On("InFlightHTLCs").Return(htlcs). - On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). - On("GetStatus").Return(channeldb.StatusInFlight) - controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() - - // Mock FetchPayment to return the payment. - controlTower.On("FetchPayment", identifier).Return(payment, nil). - Run(func(args mock.Arguments) { - // When number of attempts made is less than 4, we will - // mock the payment's methods to allow the lifecycle to - // continue. - if numAttempts.Load() < numParts { - payment.On("AllowMoreAttempts").Return(true, nil).Once() - return - } - - if !settled.Load() { - fmt.Println("wait") - payment.On("AllowMoreAttempts").Return(false, nil).Once() - payment.On("NeedWaitAttempts").Return(true, nil).Once() - // We add another attempt to the counter to - // unblock next time. - return - } - - payment.On("AllowMoreAttempts").Return(false, nil). - On("NeedWaitAttempts").Return(false, nil) - }) - - // Mock SettleAttempt. - preimage := lntypes.Preimage{1, 2, 3} - settledAttempt := makeSettledAttempt( - int(paymentAmt/4), 0, preimage, - ) - - controlTower.On("SettleAttempt", - identifier, mock.Anything, mock.Anything, - ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { - // We want to at least wait for one settlement. - if numAttempts.Load() > 1 { - settled.Store(true) - } - }) - - // Create a route that can send 1/4 of the total amount. This value - // will be returned by calling RequestRoute. - shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) - require.NoError(t, err, "failed to create route") - - session.On("RequestRoute", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(shard, nil) - - // Make a new htlc attempt with zero fee and append it to the payment's - // HTLCs when calling RegisterAttempt. - controlTower.On("RegisterAttempt", - identifier, mock.Anything, - ).Return(nil).Run(func(args mock.Arguments) { - numAttempts.Add(1) - }) - - // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) - payer.On("GetAttemptResult", - mock.Anything, identifier, mock.Anything, - ).Run(func(args mock.Arguments) { - // Before the mock method is returned, we send the result to - // the read-only chan. - payer.resultChan <- &htlcswitch.PaymentResult{} - }) - - // Simple mocking the rest. - payer.On("SendHTLC", - mock.Anything, mock.Anything, mock.Anything, - ).Return(nil) - - missionControl.On("ReportPaymentSuccess", - mock.Anything, mock.Anything, - ).Return(nil).Run(func(args mock.Arguments) { - }) - - controlTower.On("DeleteFailedAttempts", identifier).Return(nil) - - payment.On("TerminalInfo").Return(&settledAttempt, nil) - - // Call the actual method SendPayment on router. This is place inside a - // goroutine so we can set a timeout for the whole test, in case - // anything goes wrong and the test never finishes. - done := make(chan struct{}) - var p lntypes.Hash - go func() { - p, _, err = router.SendPayment(req) - close(done) - }() - - select { - case <-done: - case <-time.After(testTimeout): - t.Fatalf("SendPayment didn't exit") - } - - // Finally, validate the returned values and check that the mock - // methods are called as expected. - require.NoError(t, err, "send payment failed") - require.EqualValues(t, preimage, p, "preimage not match") - - // Note that we also implicitly check the methods such as FailAttempt, - // ReportPaymentFail, etc, are not called because we never mocked them - // in this test. If any of the unexpected methods was called, the test - // would fail. - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - sessionSource.AssertExpectations(t) - session.AssertExpectations(t) - missionControl.AssertExpectations(t) - payment.AssertExpectations(t) -} - -// TestSendMPPaymentSucceedOnExtraShards tests that we need extra attempts if -// there are failed ones,so that a payment is successfully sent. This test -// mainly focuses on testing the logic of the method resumePayment is -// implemented as expected. -func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { - const startingBlockHeight = 101 - - // Create mockers to initialize the router. - controlTower := &mockControlTower{} - sessionSource := &mockPaymentSessionSource{} - missionControl := &mockMissionControl{} - payer := &mockPaymentAttemptDispatcher{} - chain := newMockChain(startingBlockHeight) - chainView := newMockChainView(chain) - testGraph := createDummyTestGraph(t) - - // Define the behavior of the mockers to the point where we can - // successfully start the router. - controlTower.On("FetchInFlightPayments").Return( - []*channeldb.MPPayment{}, nil, - ) - payer.On("CleanStore", mock.Anything).Return(nil) - - // Create and start the router. - router, err := New(Config{ - Control: controlTower, - SessionSource: sessionSource, - MissionControl: missionControl, - Payer: payer, - - // TODO(yy): create new mocks for the chain and chainview. - Chain: chain, - ChainView: chainView, - - // TODO(yy): mock the graph once it's changed into interface. - Graph: testGraph.graph, - - Clock: clock.NewTestClock(time.Unix(1, 0)), - GraphPruneInterval: time.Hour * 2, - NextPaymentID: func() (uint64, error) { - next := atomic.AddUint64(&uniquePaymentID, 1) - return next, nil - }, - - IsAlias: func(scid lnwire.ShortChannelID) bool { - return false - }, - }) - require.NoError(t, err, "failed to create router") - - // Make sure the router can start and stop without error. - require.NoError(t, router.Start(), "router failed to start") - t.Cleanup(func() { - require.NoError(t, router.Stop(), "router failed to stop") - }) - - // Once the router is started, check that the mocked methods are called - // as expected. - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - - // Mock the methods to the point where we are inside the function - // resumePayment. - paymentAmt := lnwire.MilliSatoshi(20000) - req := createDummyLightningPayment( - t, testGraph.aliasMap["c"], paymentAmt, - ) - identifier := lntypes.Hash(req.Identifier()) - session := &mockPaymentSession{} - sessionSource.On("NewPaymentSession", req).Return(session, nil) - controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - - // Mock the InFlightHTLCs. - var ( - htlcs []channeldb.HTLCAttempt - numAttempts atomic.Uint32 - failAttemptCount atomic.Uint32 - settled atomic.Bool - ) - - // Make a mock MPPayment. - payment := &mockMPPayment{} - payment.On("InFlightHTLCs").Return(htlcs). - On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). - On("GetStatus").Return(channeldb.StatusInFlight) - controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() - - // Mock FetchPayment to return the payment. - controlTower.On("FetchPayment", identifier).Return(payment, nil). - Run(func(args mock.Arguments) { - // When number of attempts made is less than 4, we will - // mock the payment's methods to allow the lifecycle to - // continue. - attempts := numAttempts.Load() - if attempts < 6 { - payment.On("AllowMoreAttempts").Return(true, nil).Once() - return - } - - if !settled.Load() { - payment.On("AllowMoreAttempts").Return(false, nil).Once() - payment.On("NeedWaitAttempts").Return(true, nil).Once() - // We add another attempt to the counter to - // unblock next time. - numAttempts.Add(1) - return - } - - payment.On("AllowMoreAttempts").Return(false, nil). - On("NeedWaitAttempts").Return(false, nil) - }) - - // Create a route that can send 1/4 of the total amount. This value - // will be returned by calling RequestRoute. - shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) - require.NoError(t, err, "failed to create route") - session.On("RequestRoute", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(shard, nil) - - // Make a new htlc attempt with zero fee and append it to the payment's - // HTLCs when calling RegisterAttempt. - controlTower.On("RegisterAttempt", - identifier, mock.Anything, - ).Return(nil).Run(func(args mock.Arguments) { - // Increase the counter whenever an attempt is made. - numAttempts.Add(1) - }) - - // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) - - // We use the failAttemptCount to track how many attempts we want to - // fail. Each time the following mock method is called, the count gets - // updated. - payer.On("GetAttemptResult", - mock.Anything, identifier, mock.Anything, - ).Run(func(args mock.Arguments) { - // Before the mock method is returned, we send the result to - // the read-only chan. - - // Update the counter. - failAttemptCount.Add(1) - - // We will make the first two attempts failed with temporary - // error. - if failAttemptCount.Load() <= 2 { - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, - 1, - ), - } - return - } - - // Otherwise we will mark the attempt succeeded. - payer.resultChan <- &htlcswitch.PaymentResult{} - }) - - // Mock the FailAttempt method to fail one of the attempts. - var failedAttempt channeldb.HTLCAttempt - controlTower.On("FailAttempt", - identifier, mock.Anything, mock.Anything, - ).Return(&failedAttempt, nil) - - // Setup ReportPaymentFail to return nil reason and error so the - // payment won't fail. - missionControl.On("ReportPaymentFail", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(nil, nil) - - // Simple mocking the rest. - payer.On("SendHTLC", - mock.Anything, mock.Anything, mock.Anything, - ).Return(nil) - missionControl.On("ReportPaymentSuccess", - mock.Anything, mock.Anything, - ).Return(nil) - - // Mock SettleAttempt by changing one of the HTLCs to be settled. - preimage := lntypes.Preimage{1, 2, 3} - settledAttempt := makeSettledAttempt( - int(paymentAmt/4), 0, preimage, - ) - controlTower.On("SettleAttempt", - identifier, mock.Anything, mock.Anything, - ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { - if numAttempts.Load() > 1 { - settled.Store(true) - } - }) - - controlTower.On("DeleteFailedAttempts", identifier).Return(nil) - - payment.On("TerminalInfo").Return(&settledAttempt, nil) - - // Call the actual method SendPayment on router. This is place inside a - // goroutine so we can set a timeout for the whole test, in case - // anything goes wrong and the test never finishes. - done := make(chan struct{}) - var p lntypes.Hash - go func() { - p, _, err = router.SendPayment(req) - close(done) - }() - - select { - case <-done: - case <-time.After(testTimeout): - t.Fatalf("SendPayment didn't exit") - } - - // Finally, validate the returned values and check that the mock - // methods are called as expected. - require.NoError(t, err, "send payment failed") - require.EqualValues(t, preimage, p, "preimage not match") - - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - sessionSource.AssertExpectations(t) - session.AssertExpectations(t) - missionControl.AssertExpectations(t) - payment.AssertExpectations(t) -} - -// TestSendMPPaymentFailed tests that when one of the shard fails with a -// terminal error, the router will stop attempting and the payment will fail. -// This test mainly focuses on testing the logic of the method resumePayment -// is implemented as expected. -func TestSendMPPaymentFailed(t *testing.T) { - const startingBlockHeight = 101 - - // Create mockers to initialize the router. - controlTower := &mockControlTower{} - sessionSource := &mockPaymentSessionSource{} - missionControl := &mockMissionControl{} - payer := &mockPaymentAttemptDispatcher{} - chain := newMockChain(startingBlockHeight) - chainView := newMockChainView(chain) - testGraph := createDummyTestGraph(t) - - // Define the behavior of the mockers to the point where we can - // successfully start the router. - controlTower.On("FetchInFlightPayments").Return( - []*channeldb.MPPayment{}, nil, - ) - payer.On("CleanStore", mock.Anything).Return(nil) - - // Create and start the router. - router, err := New(Config{ - Control: controlTower, - SessionSource: sessionSource, - MissionControl: missionControl, - Payer: payer, - - // TODO(yy): create new mocks for the chain and chainview. - Chain: chain, - ChainView: chainView, - - // TODO(yy): mock the graph once it's changed into interface. - Graph: testGraph.graph, - - Clock: clock.NewTestClock(time.Unix(1, 0)), - GraphPruneInterval: time.Hour * 2, - NextPaymentID: func() (uint64, error) { - next := atomic.AddUint64(&uniquePaymentID, 1) - return next, nil - }, - - IsAlias: func(scid lnwire.ShortChannelID) bool { - return false - }, - }) - require.NoError(t, err, "failed to create router") - - // Make sure the router can start and stop without error. - require.NoError(t, router.Start(), "router failed to start") - t.Cleanup(func() { - require.NoError(t, router.Stop(), "router failed to stop") - }) - - // Once the router is started, check that the mocked methods are called - // as expected. - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - - // Mock the methods to the point where we are inside the function - // resumePayment. - paymentAmt := lnwire.MilliSatoshi(10000) - req := createDummyLightningPayment( - t, testGraph.aliasMap["c"], paymentAmt, - ) - identifier := lntypes.Hash(req.Identifier()) - session := &mockPaymentSession{} - sessionSource.On("NewPaymentSession", req).Return(session, nil) - controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - - // Mock the InFlightHTLCs. - var ( - htlcs []channeldb.HTLCAttempt - numAttempts atomic.Uint32 - failAttemptCount atomic.Uint32 - failed atomic.Bool - numParts = uint32(4) - ) - - // Make a mock MPPayment. - payment := &mockMPPayment{} - payment.On("InFlightHTLCs").Return(htlcs).Once() - payment.On("GetState").Return(&channeldb.MPPaymentState{}). - On("GetStatus").Return(channeldb.StatusInFlight) - controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() - - // Mock the sequential FetchPayment to return the payment. - controlTower.On("FetchPayment", identifier).Return(payment, nil).Run( - func(_ mock.Arguments) { - // We want to at least send out all parts in order to - // wait for them later. - if numAttempts.Load() < numParts { - payment.On("AllowMoreAttempts").Return(true, nil).Once() - return - } - - // Wait if the payment wasn't failed yet. - if !failed.Load() { - payment.On("AllowMoreAttempts").Return(false, nil).Once(). - On("NeedWaitAttempts").Return(true, nil).Once() - return - } - - payment.On("AllowMoreAttempts").Return(false, nil). - On("NeedWaitAttempts").Return(false, nil).Once() - }) - - // Create a route that can send 1/4 of the total amount. This value - // will be returned by calling RequestRoute. - shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) - require.NoError(t, err, "failed to create route") - session.On("RequestRoute", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(shard, nil) - - // Make a new htlc attempt with zero fee and append it to the payment's - // HTLCs when calling RegisterAttempt. - controlTower.On("RegisterAttempt", - identifier, mock.Anything, - ).Return(nil).Run(func(args mock.Arguments) { - numAttempts.Add(1) - }) - - // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) - - // We use the failAttemptCount to track how many attempts we want to - // fail. Each time the following mock method is called, the count gets - // updated. - payer.On("GetAttemptResult", - mock.Anything, identifier, mock.Anything, - ).Run(func(_ mock.Arguments) { - // Before the mock method is returned, we send the result to - // the read-only chan. - - // Update the counter. - failAttemptCount.Add(1) - - // We fail the first attempt with terminal error. - if failAttemptCount.Load() == 1 { - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: htlcswitch.NewForwardingError( - &lnwire.FailIncorrectDetails{}, - 1, - ), - } - return - } - - // We will make the rest attempts failed with temporary error. - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, - 1, - ), - } - }) - - // Mock the FailAttempt method to fail one of the attempts. - var failedAttempt channeldb.HTLCAttempt - controlTower.On("FailAttempt", - identifier, mock.Anything, mock.Anything, - ).Return(&failedAttempt, nil) - - // Setup ReportPaymentFail to return nil reason and error so the - // payment won't fail. - failureReason := channeldb.FailureReasonPaymentDetails - missionControl.On("ReportPaymentFail", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(&failureReason, nil) - - // Simple mocking the rest. - controlTower.On("FailPayment", - identifier, failureReason, - ).Return(nil).Run(func(_ mock.Arguments) { - failed.Store(true) - }) - - // Mock the payment to return the failure reason. - payer.On("SendHTLC", - mock.Anything, mock.Anything, mock.Anything, - ).Return(nil) - - payment.On("TerminalInfo").Return(nil, &failureReason) - - controlTower.On("DeleteFailedAttempts", identifier).Return(nil) - - // Call the actual method SendPayment on router. This is place inside a - // goroutine so we can set a timeout for the whole test, in case - // anything goes wrong and the test never finishes. - done := make(chan struct{}) - var p lntypes.Hash - go func() { - p, _, err = router.SendPayment(req) - close(done) - }() - - select { - case <-done: - case <-time.After(testTimeout): - t.Fatalf("SendPayment didn't exit") - } - - // Finally, validate the returned values and check that the mock - // methods are called as expected. - require.Error(t, err, "expected send payment error") - require.EqualValues(t, [32]byte{}, p, "preimage not match") - require.GreaterOrEqual(t, failAttemptCount.Load(), uint32(1)) - - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - sessionSource.AssertExpectations(t) - session.AssertExpectations(t) - missionControl.AssertExpectations(t) - payment.AssertExpectations(t) -} - // TestBlockDifferenceFix tests if when the router is behind on blocks, the // router catches up to the best block head. func TestBlockDifferenceFix(t *testing.T) { From ddad6ad4c4d67e9ef813f8bcead101ab05e03716 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 13 Feb 2023 20:57:18 +0800 Subject: [PATCH 16/27] routing: update mockers in unit test This commit adds more mockers to be used in coming unit tests and simplified the mockers to be more straightforward. --- routing/mock_test.go | 109 +++++++++++++++++++++++++++++++++-------- routing/router_test.go | 32 ++++-------- 2 files changed, 98 insertions(+), 43 deletions(-) diff --git a/routing/mock_test.go b/routing/mock_test.go index 2458f538b..f712c420d 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -12,7 +12,9 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/routing/shards" "github.com/stretchr/testify/mock" ) @@ -572,8 +574,6 @@ func (m *mockControlTowerOld) SubscribeAllPayments() ( type mockPaymentAttemptDispatcher struct { mock.Mock - - resultChan chan *htlcswitch.PaymentResult } var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) @@ -589,11 +589,14 @@ func (m *mockPaymentAttemptDispatcher) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) ( <-chan *htlcswitch.PaymentResult, error) { - m.Called(attemptID, paymentHash, deobfuscator) + args := m.Called(attemptID, paymentHash, deobfuscator) - // Instead of returning the mocked returned values, we need to return - // the chan resultChan so it can be converted into a read-only chan. - return m.resultChan, nil + resultChan := args.Get(0) + if resultChan == nil { + return nil, args.Error(1) + } + + return args.Get(0).(chan *htlcswitch.PaymentResult), args.Error(1) } func (m *mockPaymentAttemptDispatcher) CleanStore( @@ -698,7 +701,6 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, type mockControlTower struct { mock.Mock - sync.Mutex } var _ ControlTower = (*mockControlTower)(nil) @@ -718,9 +720,6 @@ func (m *mockControlTower) DeleteFailedAttempts(phash lntypes.Hash) error { func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, a *channeldb.HTLCAttemptInfo) error { - m.Lock() - defer m.Unlock() - args := m.Called(phash, a) return args.Error(0) } @@ -729,29 +728,32 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, pid uint64, settleInfo *channeldb.HTLCSettleInfo) ( *channeldb.HTLCAttempt, error) { - m.Lock() - defer m.Unlock() - args := m.Called(phash, pid, settleInfo) - return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) + + attempt := args.Get(0) + if attempt == nil { + return nil, args.Error(1) + } + + return attempt.(*channeldb.HTLCAttempt), args.Error(1) } func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) { - m.Lock() - defer m.Unlock() - args := m.Called(phash, pid, failInfo) + + attempt := args.Get(0) + if attempt == nil { + return nil, args.Error(1) + } + return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) } func (m *mockControlTower) FailPayment(phash lntypes.Hash, reason channeldb.FailureReason) error { - m.Lock() - defer m.Unlock() - args := m.Called(phash, reason) return args.Error(0) } @@ -877,3 +879,70 @@ func (m *mockLink) EligibleToForward() bool { func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error { return m.mayAddOutgoingErr } + +type mockShardTracker struct { + mock.Mock +} + +var _ shards.ShardTracker = (*mockShardTracker)(nil) + +func (m *mockShardTracker) NewShard(attemptID uint64, + lastShard bool) (shards.PaymentShard, error) { + + args := m.Called(attemptID, lastShard) + + shard := args.Get(0) + if shard == nil { + return nil, args.Error(1) + } + + return shard.(shards.PaymentShard), args.Error(1) +} + +func (m *mockShardTracker) GetHash(attemptID uint64) (lntypes.Hash, error) { + args := m.Called(attemptID) + return args.Get(0).(lntypes.Hash), args.Error(1) +} + +func (m *mockShardTracker) CancelShard(attemptID uint64) error { + args := m.Called(attemptID) + return args.Error(0) +} + +type mockShard struct { + mock.Mock +} + +var _ shards.PaymentShard = (*mockShard)(nil) + +// Hash returns the hash used for the HTLC representing this shard. +func (m *mockShard) Hash() lntypes.Hash { + args := m.Called() + return args.Get(0).(lntypes.Hash) +} + +// MPP returns any extra MPP records that should be set for the final +// hop on the route used by this shard. +func (m *mockShard) MPP() *record.MPP { + args := m.Called() + + r := args.Get(0) + if r == nil { + return nil + } + + return r.(*record.MPP) +} + +// AMP returns any extra AMP records that should be set for the final +// hop on the route used by this shard. +func (m *mockShard) AMP() *record.AMP { + args := m.Called() + + r := args.Get(0) + if r == nil { + return nil + } + + return r.(*record.AMP) +} diff --git a/routing/router_test.go b/routing/router_test.go index 79f0b026c..3f1f5dcb1 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3528,12 +3528,12 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { ).Return(nil) // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 1) + resultChan := make(chan *htlcswitch.PaymentResult, 1) payer.On("GetAttemptResult", mock.Anything, mock.Anything, mock.Anything, - ).Run(func(_ mock.Arguments) { + ).Return(resultChan, nil).Run(func(_ mock.Arguments) { // Send a successful payment result. - payer.resultChan <- &htlcswitch.PaymentResult{} + resultChan <- &htlcswitch.PaymentResult{} }) missionControl.On("ReportPaymentSuccess", @@ -3599,6 +3599,11 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { }, }} + // Create the error to be returned. + tempErr := htlcswitch.NewForwardingError( + &lnwire.FailTemporaryChannelFailure{}, 1, + ) + // Register mockers with the expected method calls. controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) @@ -3608,26 +3613,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, - ).Return(nil) - - // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 1) - - // Create the error to be returned. - tempErr := htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, - 1, - ) - - // Mock GetAttemptResult to return a failure. - payer.On("GetAttemptResult", - mock.Anything, mock.Anything, mock.Anything, - ).Run(func(_ mock.Arguments) { - // Send an attempt failure. - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: tempErr, - } - }) + ).Return(tempErr) // Mock the control tower to return the mocked payment. payment := &mockMPPayment{} From e46c689bf1501c32841f8e837051a644b653b349 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 8 Mar 2023 01:14:32 +0800 Subject: [PATCH 17/27] routing: refactor attempt makers to return pointers Thus adding following unit tests can be a bit easier. --- routing/payment_lifecycle_test.go | 62 +++++++++++++++++++++++++------ routing/router_test.go | 16 ++++---- 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 3f6dc1e71..e90f162fc 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/btcsuite/btcd/btcec/v2" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lntypes" @@ -17,31 +18,68 @@ var ( dummyErr = errors.New("dummy") ) -func makeSettledAttempt(total, fee int, - preimage lntypes.Preimage) channeldb.HTLCAttempt { +// createDummyRoute builds a route a->b->c paying the given amt to c. +func createDummyRoute(t *testing.T, amt lnwire.MilliSatoshi) *route.Route { + t.Helper() - return channeldb.HTLCAttempt{ - HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + priv, err := btcec.NewPrivateKey() + require.NoError(t, err, "failed to create private key") + hop1 := route.NewVertex(priv.PubKey()) + + priv, err = btcec.NewPrivateKey() + require.NoError(t, err, "failed to create private key") + hop2 := route.NewVertex(priv.PubKey()) + + hopFee := lnwire.NewMSatFromSatoshis(3) + hops := []*route.Hop{ + { + ChannelID: 1, + PubKeyBytes: hop1, + LegacyPayload: true, + AmtToForward: amt + hopFee, + }, + { + ChannelID: 2, + PubKeyBytes: hop2, + LegacyPayload: true, + AmtToForward: amt, + }, + } + + priv, err = btcec.NewPrivateKey() + require.NoError(t, err, "failed to create private key") + source := route.NewVertex(priv.PubKey()) + + // We create a simple route that we will supply every time the router + // requests one. + rt, err := route.NewRouteFromHops(amt+2*hopFee, 100, source, hops) + require.NoError(t, err, "failed to create route") + + return rt +} + +func makeSettledAttempt(t *testing.T, total int, + preimage lntypes.Preimage) *channeldb.HTLCAttempt { + + return &channeldb.HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(t, total), Settle: &channeldb.HTLCSettleInfo{Preimage: preimage}, } } -func makeFailedAttempt(total, fee int) *channeldb.HTLCAttempt { +func makeFailedAttempt(t *testing.T, total int) *channeldb.HTLCAttempt { return &channeldb.HTLCAttempt{ - HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + HTLCAttemptInfo: makeAttemptInfo(t, total), Failure: &channeldb.HTLCFailInfo{ Reason: channeldb.HTLCFailInternal, }, } } -func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo { - hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)} +func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo { + rt := createDummyRoute(t, lnwire.MilliSatoshi(amt)) return channeldb.HTLCAttemptInfo{ - Route: route.Route{ - TotalAmount: lnwire.MilliSatoshi(total), - Hops: []*route.Hop{hop}, - }, + Route: *rt, } } diff --git a/routing/router_test.go b/routing/router_test.go index 3f1f5dcb1..0625d430d 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3483,7 +3483,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { ) preimage := lntypes.Preimage{1} - testAttempt := makeSettledAttempt(int(payAmt), 0, preimage) + testAttempt := makeSettledAttempt(t, int(payAmt), preimage) node, err := createTestNode() require.NoError(t, err) @@ -3521,7 +3521,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) controlTower.On("SettleAttempt", payHash, mock.Anything, mock.Anything, - ).Return(&testAttempt, nil) + ).Return(testAttempt, nil) payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, @@ -3550,7 +3550,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { // Expect a successful send to route. attempt, err := router.SendToRouteSkipTempErr(payHash, rt) require.NoError(t, err) - require.Equal(t, &testAttempt, attempt) + require.Equal(t, testAttempt, attempt) // Assert the above methods are called as expected. controlTower.AssertExpectations(t) @@ -3563,11 +3563,11 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { // cause the payment to be failed. func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { var ( - payHash lntypes.Hash - payAmt = lnwire.MilliSatoshi(10000) - testAttempt = &channeldb.HTLCAttempt{} + payHash lntypes.Hash + payAmt = lnwire.MilliSatoshi(10000) ) + testAttempt := makeFailedAttempt(t, int(payAmt)) node, err := createTestNode() require.NoError(t, err) @@ -3648,7 +3648,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { payAmt = lnwire.MilliSatoshi(10000) ) - testAttempt := makeFailedAttempt(int(payAmt), 0) + testAttempt := makeFailedAttempt(t, int(payAmt)) node, err := createTestNode() require.NoError(t, err) @@ -3733,7 +3733,7 @@ func TestSendToRouteTempFailure(t *testing.T) { payAmt = lnwire.MilliSatoshi(10000) ) - testAttempt := makeFailedAttempt(int(payAmt), 0) + testAttempt := makeFailedAttempt(t, int(payAmt)) node, err := createTestNode() require.NoError(t, err) From 10052ff4f5564b8ce2382316ad99a9064a9d43c2 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 8 Mar 2023 02:50:33 +0800 Subject: [PATCH 18/27] routing: patch unit tests for payment lifecycle This commit adds unit tests for `resumePayment`. In addition, the `resumePayment` has been split into two parts so it's easier to be tested, 1) sending the htlc, and 2) collecting results. As seen in the new tests, this split largely reduces the complexity involved and makes the unit test flow sequential. This commit also makes full use of `mock.Mock` in the unit tests to provide a more clear testing flow. --- routing/payment_lifecycle.go | 16 +- routing/payment_lifecycle_test.go | 1164 ++++++++++++++++++++++++++++- 2 files changed, 1156 insertions(+), 24 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 38489b141..20c9c9cde 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -41,6 +41,12 @@ type paymentLifecycle struct { // or failed with temporary error. Otherwise, we should exit the // lifecycle loop as a terminal error has occurred. resultCollected chan error + + // resultCollector is a function that is used to collect the result of + // an HTLC attempt, which is always mounted to `p.collectResultAsync` + // except in unit test, where we use a much simpler resultCollector to + // decouple the test flow for the payment lifecycle. + resultCollector func(attempt *channeldb.HTLCAttempt) } // newPaymentLifecycle initiates a new payment lifecycle and returns it. @@ -60,6 +66,9 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, resultCollected: make(chan error, 1), } + // Mount the result collector. + p.resultCollector = p.collectResultAsync + // If a timeout is specified, create a timeout channel. If no timeout is // specified, the channel is left nil and will never abort the payment // loop. @@ -178,7 +187,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { log.Infof("Resuming payment shard %v for payment %v", a.AttemptID, p.identifier) - p.collectResultAsync(&a) + p.resultCollector(&a) } // exitWithErr is a helper closure that logs and returns an error. @@ -280,7 +289,7 @@ lifecycle: // Now that the shard was successfully sent, launch a go // routine that will handle its result when its back. if result.err == nil { - p.collectResultAsync(attempt) + p.resultCollector(attempt) } } @@ -416,6 +425,9 @@ type attemptResult struct { // will send a nil error to channel `resultCollected` to indicate there's an // result. func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { + log.Debugf("Collecting result for attempt %v in payment %v", + attempt.AttemptID, p.identifier) + go func() { // Block until the result is available. _, err := p.collectResult(attempt) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index e90f162fc..f4b637ee4 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -1,22 +1,211 @@ package routing import ( + "sync/atomic" "testing" "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/htlcswitch" + "github.com/lightningnetwork/lnd/lnmock" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) -var ( - dummyErr = errors.New("dummy") -) +var errDummy = errors.New("dummy") + +// createTestPaymentLifecycle creates a `paymentLifecycle` without mocks. +func createTestPaymentLifecycle() *paymentLifecycle { + paymentHash := lntypes.Hash{1, 2, 3} + quitChan := make(chan struct{}) + rt := &ChannelRouter{ + cfg: &Config{}, + quit: quitChan, + } + + return &paymentLifecycle{ + router: rt, + identifier: paymentHash, + } +} + +// mockers wraps a list of mocked interfaces used inside payment lifecycle. +type mockers struct { + shard *mockShard + shardTracker *mockShardTracker + control *mockControlTower + paySession *mockPaymentSession + payer *mockPaymentAttemptDispatcher + clock *lnmock.MockClock + missionControl *mockMissionControl + + // collectResultsCount is the number of times the collectResultAsync + // has been called. + collectResultsCount int + + // payment is the mocked `dbMPPayment` used in the test. + payment *mockMPPayment +} + +// newTestPaymentLifecycle creates a `paymentLifecycle` using the mocks. It +// also asserts the mockers are called as expected when the test is finished. +func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { + paymentHash := lntypes.Hash{1, 2, 3} + quitChan := make(chan struct{}) + + // Create a mock shard to be return from `NewShard`. + mockShard := &mockShard{} + + // Create a list of mocks and add it to the router config. + mockControlTower := &mockControlTower{} + mockPayer := &mockPaymentAttemptDispatcher{} + mockClock := &lnmock.MockClock{} + mockMissionControl := &mockMissionControl{} + + // Make a channel router. + rt := &ChannelRouter{ + cfg: &Config{ + Control: mockControlTower, + Payer: mockPayer, + Clock: mockClock, + MissionControl: mockMissionControl, + }, + quit: quitChan, + } + + // Create mockers to init a payment lifecycle. + mockPaymentSession := &mockPaymentSession{} + mockShardTracker := &mockShardTracker{} + + // Create a test payment lifecycle with no fee limit and no timeout. + p := newPaymentLifecycle( + rt, noFeeLimit, paymentHash, mockPaymentSession, + mockShardTracker, 0, 0, + ) + + // Create a mock payment which is returned from mockControlTower. + mockPayment := &mockMPPayment{} + + mockers := &mockers{ + shard: mockShard, + shardTracker: mockShardTracker, + control: mockControlTower, + paySession: mockPaymentSession, + payer: mockPayer, + clock: mockClock, + missionControl: mockMissionControl, + payment: mockPayment, + } + + // Overwrite the collectResultAsync to focus on testing the payment + // lifecycle within the goroutine. + resultCollector := func(attempt *channeldb.HTLCAttempt) { + mockers.collectResultsCount++ + } + p.resultCollector = resultCollector + + // Validate the mockers are called as expected before exiting the test. + t.Cleanup(func() { + mockShard.AssertExpectations(t) + mockShardTracker.AssertExpectations(t) + mockControlTower.AssertExpectations(t) + mockPaymentSession.AssertExpectations(t) + mockPayer.AssertExpectations(t) + mockClock.AssertExpectations(t) + mockMissionControl.AssertExpectations(t) + mockPayment.AssertExpectations(t) + }) + + return p, mockers +} + +// setupTestPaymentLifecycle creates a new `paymentLifecycle` and mocks the +// initial steps of the payment lifecycle so we can enter into the loop +// directly. +func setupTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { + p, m := newTestPaymentLifecycle(t) + + // Mock the first two calls. + m.control.On("FetchPayment", p.identifier).Return( + m.payment, nil, + ).Once() + + htlcs := []channeldb.HTLCAttempt{} + m.payment.On("InFlightHTLCs").Return(htlcs).Once() + + return p, m +} + +// resumePaymentResult is used to hold the returned values from +// `resumePayment`. +type resumePaymentResult struct { + preimage lntypes.Hash + err error +} + +// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error +// is returned. +func sendPaymentAndAssertFailed(t *testing.T, + p *paymentLifecycle, errExpected error) { + + resultChan := make(chan *resumePaymentResult, 1) + + // We now make a call to `resumePayment` and expect it to return the + // error. + go func() { + preimage, _, err := p.resumePayment() + resultChan <- &resumePaymentResult{ + preimage: preimage, + err: err, + } + }() + + // Validate the returned values or timeout. + select { + case r := <-resultChan: + require.ErrorIs(t, r.err, errExpected, "expected error") + require.Empty(t, r.preimage, "preimage should be empty") + + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for result") + } +} + +// sendPaymentAndAssertSucceeded calls `resumePayment` and asserts that the +// returned preimage is correct. +func sendPaymentAndAssertSucceeded(t *testing.T, + p *paymentLifecycle, expected lntypes.Preimage) { + + resultChan := make(chan *resumePaymentResult, 1) + + // We now make a call to `resumePayment` and expect it to return the + // preimage. + go func() { + preimage, _, err := p.resumePayment() + resultChan <- &resumePaymentResult{ + preimage: preimage, + err: err, + } + }() + + // Validate the returned values or timeout. + select { + case r := <-resultChan: + require.NoError(t, r.err, "unexpected error") + require.EqualValues(t, expected, r.preimage, + "preimage not match") + + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for result") + } +} // createDummyRoute builds a route a->b->c paying the given amt to c. func createDummyRoute(t *testing.T, amt lnwire.MilliSatoshi) *route.Route { @@ -149,21 +338,6 @@ func TestCheckTimeoutOnRouterQuit(t *testing.T) { require.ErrorIs(t, err, ErrRouterShuttingDown) } -// createTestPaymentLifecycle creates a `paymentLifecycle` using the mocks. -func createTestPaymentLifecycle() *paymentLifecycle { - paymentHash := lntypes.Hash{1, 2, 3} - quitChan := make(chan struct{}) - rt := &ChannelRouter{ - cfg: &Config{}, - quit: quitChan, - } - - return &paymentLifecycle{ - router: rt, - identifier: paymentHash, - } -} - // TestRequestRouteSucceed checks that `requestRoute` can successfully request // a route. func TestRequestRouteSucceed(t *testing.T) { @@ -409,9 +583,9 @@ func TestDecideNextStep(t *testing.T) { }, { name: "error on allow more attempts", - allowMoreAttempts: &mockReturn{false, dummyErr}, + allowMoreAttempts: &mockReturn{false, errDummy}, expectedStep: stepExit, - expectedErr: dummyErr, + expectedErr: errDummy, }, { name: "no wait and exit", @@ -423,9 +597,9 @@ func TestDecideNextStep(t *testing.T) { { name: "wait returns an error", allowMoreAttempts: &mockReturn{false, nil}, - needWaitAttempts: &mockReturn{false, dummyErr}, + needWaitAttempts: &mockReturn{false, errDummy}, expectedStep: stepExit, - expectedErr: dummyErr, + expectedErr: errDummy, }, { @@ -491,3 +665,949 @@ func TestDecideNextStep(t *testing.T) { payment.AssertExpectations(t) } } + +// TestResumePaymentFailOnFetchPayment checks when we fail to fetch the +// payment, the error is returned. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnFetchPayment(t *testing.T) { + // Create a test paymentLifecycle. + p, m := newTestPaymentLifecycle(t) + + // Mock an error returned. + m.control.On("FetchPayment", p.identifier).Return(nil, errDummy) + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnTimeout checks that when timeout is reached, the +// payment is failed. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnTimeout(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. make the timeout happens instantly and sleep one millisecond to + // make sure it timed out. + p.timeoutChan = time.After(1 * time.Nanosecond) + time.Sleep(1 * time.Millisecond) + + // 4. the payment should be failed with reason timeout. + m.control.On("FailPayment", + p.identifier, channeldb.FailureReasonTimeout, + ).Return(nil).Once() + + // 5. decideNextStep now returns stepExit. + m.payment.On("AllowMoreAttempts").Return(false, nil).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() + + // 6. control tower deletes failed attempts. + m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() + + // 7. the payment returns the failed reason. + reason := channeldb.FailureReasonTimeout + m.payment.On("TerminalInfo").Return(nil, &reason) + + // Send the payment and assert it failed with the timeout reason. + sendPaymentAndAssertFailed(t, p, reason) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnTimeoutErr checks that the lifecycle fails when an +// error is returned from `checkTimeout`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnTimeoutErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. quit the router to return an error. + close(p.router.quit) + + // Send the payment and assert it failed when router is shutting down. + sendPaymentAndAssertFailed(t, p, ErrRouterShuttingDown) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnStepErr checks that the lifecycle fails when an +// error is returned from `decideNextStep`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnStepErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns an error. + m.payment.On("AllowMoreAttempts").Return(false, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnRequestRouteErr checks that the lifecycle fails when +// an error is returned from `requestRoute`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnRequestRouteErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 4. mock requestRoute to return an error. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(nil, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnRegisterAttemptErr checks that the lifecycle fails +// when an error is returned from `registerAttempt`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 5. mock shardTracker used in `createNewPaymentAttempt` to return an + // error. + // + // Mock NextPaymentID to always return the attemptID. + attemptID := uint64(1) + p.router.cfg.NextPaymentID = func() (uint64, error) { + return attemptID, nil + } + + // Return an error to end the lifecycle. + m.shardTracker.On("NewShard", + attemptID, true, + ).Return(nil, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnSendAttemptErr checks that the lifecycle fails when +// an error is returned from `sendAttempt`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnSendAttemptErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 5. mock `registerAttempt` to return an attempt. + // + // Mock NextPaymentID to always return the attemptID. + attemptID := uint64(1) + p.router.cfg.NextPaymentID = func() (uint64, error) { + return attemptID, nil + } + + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID, true, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // Mock the time and expect it to be call twice. + m.clock.On("Now").Return(time.Now()).Twice() + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Once() + + // 6. mock `sendAttempt` to return an error. + m.payer.On("SendHTLC", + mock.Anything, attemptID, mock.Anything, + ).Return(errDummy).Once() + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd fail the payment, cancel the shard and fail the attempt. + // + // `FailPayment` should be called with an internal reason. + reason := channeldb.FailureReasonError + m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() + + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a dummy error to exit the loop. + m.control.On("FailAttempt", + p.identifier, attemptID, mock.Anything, + ).Return(nil, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentSuccess checks that a normal payment flow that is +// succeeded. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentSuccess(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt) + + // We now enter the payment lifecycle loop. + // + // 1.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 1.2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 1.3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 1.4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 1.5. mock `registerAttempt` to return an attempt. + // + // Mock NextPaymentID to always return the attemptID. + attemptID := uint64(1) + p.router.cfg.NextPaymentID = func() (uint64, error) { + return attemptID, nil + } + + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID, true, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // Mock the time and expect it to be called. + m.clock.On("Now").Return(time.Now()) + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Once() + + // 1.6. mock `sendAttempt` to succeed, which brings us into the next + // iteration of the lifecycle. + m.payer.On("SendHTLC", + mock.Anything, attemptID, mock.Anything, + ).Return(nil).Once() + + // We now enter the second iteration of the lifecycle loop. + // + // 2.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2.2. calls `GetState` and return the state. + m.payment.On("GetState").Return(ps).Run(func(args mock.Arguments) { + ps.RemainingAmt = 0 + }).Once() + + // 2.3. decideNextStep now returns stepExit and exits the loop. + m.payment.On("AllowMoreAttempts").Return(false, nil).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() + + // We should perform an optional deletion over failed attempts. + m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() + + // Finally, mock the `TerminalInfo` to return the settled attempt. + // Create a SettleAttempt. + testPreimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt(t, int(paymentAmt), testPreimage) + m.payment.On("TerminalInfo").Return(settledAttempt, nil).Once() + + // Send the payment and assert the preimage is matched. + sendPaymentAndAssertSucceeded(t, p, testPreimage) + + // Expected collectResultAsync to called. + require.Equal(t, 1, m.collectResultsCount) +} + +// TestResumePaymentSuccessWithTwoAttempts checks a successful payment flow +// with two HTLC attempts. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt/2) + + // We now enter the payment lifecycle loop. + // + // 1.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 1.2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 1.3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 1.4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // Create two attempt IDs here. + attemptID1 := uint64(1) + attemptID2 := uint64(2) + + // 1.5. mock `registerAttempt` to return an attempt. + // + // Mock NextPaymentID to return the first attemptID on the first call + // and the second attemptID on the second call. + var numAttempts atomic.Uint64 + p.router.cfg.NextPaymentID = func() (uint64, error) { + numAttempts.Add(1) + if numAttempts.Load() == 1 { + return attemptID1, nil + } + + return attemptID2, nil + } + + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID1, false, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // Mock the time and expect it to be called. + m.clock.On("Now").Return(time.Now()) + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + ps.RemainingAmt = paymentAmt / 2 + }).Once() + + // 1.6. mock `sendAttempt` to succeed, which brings us into the next + // iteration of the lifecycle where we mock a temp failure. + m.payer.On("SendHTLC", + mock.Anything, attemptID1, mock.Anything, + ).Return(nil).Once() + + // We now enter the second iteration of the lifecycle loop. + // + // 2.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2.2. calls `GetState` and return the state. + m.payment.On("GetState").Return(ps).Once() + + // 2.3. decideNextStep now returns stepProceed so we can send the + // second attempt. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 2.4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt/2, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 2.5. mock `registerAttempt` to return an attempt. + // + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID2, true, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Once() + + // 2.6. mock `sendAttempt` to succeed, which brings us into the next + // iteration of the lifecycle. + m.payer.On("SendHTLC", + mock.Anything, attemptID2, mock.Anything, + ).Return(nil).Once() + + // We now enter the third iteration of the lifecycle loop. + // + // 3.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 3.2. calls `GetState` and return the state. + m.payment.On("GetState").Return(ps).Once() + + // 3.3. decideNextStep now returns stepExit to exit the loop. + m.payment.On("AllowMoreAttempts").Return(false, nil).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() + + // We should perform an optional deletion over failed attempts. + m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() + + // Finally, mock the `TerminalInfo` to return the settled attempt. + // Create a SettleAttempt. + testPreimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt(t, int(paymentAmt), testPreimage) + m.payment.On("TerminalInfo").Return(settledAttempt, nil).Once() + + // Send the payment and assert the preimage is matched. + sendPaymentAndAssertSucceeded(t, p, testPreimage) + + // Expected collectResultAsync to called. + require.Equal(t, 2, m.collectResultsCount) +} + +// TestCollectResultExitOnErr checks that when there's an error returned from +// htlcswitch via `GetAttemptResult`, it's handled and returned. +func TestCollectResultExitOnErr(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a dummy error. + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(nil, errDummy).Once() + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd fail the payment, cancel the shard and fail the attempt. + // + // `FailPayment` should be called with an internal reason. + reason := channeldb.FailureReasonError + m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() + + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a switch error. + switchErr := errors.New("switch err") + m.control.On("FailAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, switchErr).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, switchErr, "expected switch error") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnResultErr checks that when there's an error returned +// from htlcswitch via the result channel, it's handled and returned. +func TestCollectResultExitOnResultErr(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send an error to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Error: errDummy, + } + }) + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd fail the payment, cancel the shard and fail the attempt. + // + // `FailPayment` should be called with an internal reason. + reason := channeldb.FailureReasonError + m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() + + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a switch error. + switchErr := errors.New("switch err") + m.control.On("FailAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, switchErr).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, switchErr, "expected switch error") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnSwitcQuit checks that when the htlcswitch is shutting +// down an error is returned. +func TestCollectResultExitOnSwitchQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Close the result chan to simulate a htlcswitch quit. + close(resultChan) + }) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, htlcswitch.ErrSwitchExiting, + "expected switch exit") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnRouterQuit checks that when the channel router is +// shutting down an error is returned. +func TestCollectResultExitOnRouterQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Close the channel router. + close(p.router.quit) + }) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, ErrRouterShuttingDown, "expected router exit") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnLifecycleQuit checks that when the payment lifecycle +// is shutting down an error is returned. +func TestCollectResultExitOnLifecycleQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Stop the lifecycle. + p.stop() + }) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, ErrPaymentLifecycleExiting, + "expected lifecycle exit") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnSettleErr checks that when settling the attempt +// fails an error is returned. +func TestCollectResultExitOnSettleErr(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send the preimage to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Preimage: preimage, + } + }) + + // Once the result is received, `ReportPaymentSuccess` should be + // called. + m.missionControl.On("ReportPaymentSuccess", + attempt.AttemptID, &attempt.Route, + ).Return(nil).Once() + + // Now mock an error being returned from `SettleAttempt`. + m.control.On("SettleAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, errDummy).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, errDummy, "expected settle error") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultSuccess checks a successful htlc settlement. +func TestCollectResultSuccess(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send the preimage to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Preimage: preimage, + } + }) + + // Once the result is received, `ReportPaymentSuccess` should be + // called. + m.missionControl.On("ReportPaymentSuccess", + attempt.AttemptID, &attempt.Route, + ).Return(nil).Once() + + // Now the settled htlc being returned from `SettleAttempt`. + m.control.On("SettleAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(attempt, nil).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.NoError(t, err, "expected no error") + require.Equal(t, preimage, result.attempt.Settle.Preimage, + "preimage mismatch") +} + +// TestCollectResultAsyncSuccess checks a successful htlc settlement. +func TestCollectResultAsyncSuccess(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send the preimage to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Preimage: preimage, + } + }) + + // Once the result is received, `ReportPaymentSuccess` should be + // called. + m.missionControl.On("ReportPaymentSuccess", + attempt.AttemptID, &attempt.Route, + ).Return(nil).Once() + + // Now the settled htlc being returned from `SettleAttempt`. + m.control.On("SettleAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(attempt, nil).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + p.collectResultAsync(attempt) + + // Assert the result is returned within 5 seconds. + var err error + waitErr := wait.NoError(func() error { + err = <-p.resultCollected + return nil + }, testTimeout) + require.NoError(t, waitErr, "timeout waiting for result") + + // Assert that a nil error is received. + require.NoError(t, err, "expected no error") +} From 6e93764bc125e6ff5c529c049868e0b261d0610f Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 9 Mar 2023 02:21:39 +0800 Subject: [PATCH 19/27] routing: make sure payment hash is random in unit tests This commit makes sure a testing payment is created via `createDummyLightningPayment` to ensure the payment hash is unique to avoid collision of the same payment hash being used in uint tests. Since the tests are running in parallel and accessing db, if two difference tests are using the same payment hash, no clean test state can be guaranteed. --- routing/router_test.go | 191 ++++++++++++++++++----------------------- 1 file changed, 84 insertions(+), 107 deletions(-) diff --git a/routing/router_test.go b/routing/router_test.go index 0625d430d..1eb39777f 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -309,14 +309,10 @@ func TestSendPaymentRouteFailureFallback(t *testing.T) { // Craft a LightningPayment struct that'll send a payment from roasbeef // to luo ji for 1000 satoshis, with a maximum of 1000 satoshis in fees. - var payHash lntypes.Hash paymentAmt := lnwire.NewMSatFromSatoshis(1000) - payment := LightningPayment{ - Target: ctx.aliases["sophon"], - Amount: paymentAmt, - FeeLimit: noFeeLimit, - paymentHash: &payHash, - } + payment := createDummyLightningPayment( + t, ctx.aliases["sophon"], paymentAmt, + ) var preImage [32]byte copy(preImage[:], bytes.Repeat([]byte{9}, 32)) @@ -347,8 +343,9 @@ func TestSendPaymentRouteFailureFallback(t *testing.T) { // Send off the payment request to the router, route through pham nuwen // should've been selected as a fall back and succeeded correctly. - paymentPreImage, route, err := ctx.router.SendPayment(&payment) - require.NoError(t, err, "unable to send payment") + paymentPreImage, route, err := ctx.router.SendPayment(payment) + require.NoErrorf(t, err, "unable to send payment: %v", + payment.paymentHash) // The route selected should have two hops require.Equal(t, 2, len(route.Hops), "incorrect route length") @@ -386,22 +383,18 @@ func TestSendPaymentRouteInfiniteLoopWithBadHopHint(t *testing.T) { // Craft a LightningPayment struct that'll send a payment from roasbeef // to songoku for 1000 satoshis. - var payHash lntypes.Hash paymentAmt := lnwire.NewMSatFromSatoshis(1000) - payment := LightningPayment{ - Target: ctx.aliases["songoku"], - Amount: paymentAmt, - FeeLimit: noFeeLimit, - paymentHash: &payHash, - RouteHints: [][]zpay32.HopHint{{ - zpay32.HopHint{ - NodeID: sourceNodeID, - ChannelID: badChannelID, - FeeBaseMSat: uint32(50), - CLTVExpiryDelta: uint16(200), - }, - }}, - } + payment := createDummyLightningPayment( + t, ctx.aliases["songoku"], paymentAmt, + ) + payment.RouteHints = [][]zpay32.HopHint{{ + zpay32.HopHint{ + NodeID: sourceNodeID, + ChannelID: badChannelID, + FeeBaseMSat: uint32(50), + CLTVExpiryDelta: uint16(200), + }, + }} var preImage [32]byte copy(preImage[:], bytes.Repeat([]byte{9}, 32)) @@ -429,8 +422,9 @@ func TestSendPaymentRouteInfiniteLoopWithBadHopHint(t *testing.T) { // Send off the payment request to the router, should succeed // ignoring the bad channel id hint. - paymentPreImage, route, paymentErr := ctx.router.SendPayment(&payment) - require.NoError(t, paymentErr, "payment returned an error") + paymentPreImage, route, paymentErr := ctx.router.SendPayment(payment) + require.NoErrorf(t, paymentErr, "unable to send payment: %v", + payment.paymentHash) // The preimage should match up with the one created above. require.Equal(t, preImage[:], paymentPreImage[:], "incorrect preimage") @@ -593,14 +587,10 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { // Craft a LightningPayment struct that'll send a payment from roasbeef // to sophon for 1000 satoshis. - var payHash lntypes.Hash amt := lnwire.NewMSatFromSatoshis(1000) - payment := LightningPayment{ - Target: ctx.aliases["sophon"], - Amount: amt, - FeeLimit: noFeeLimit, - paymentHash: &payHash, - } + payment := createDummyLightningPayment( + t, ctx.aliases["sophon"], amt, + ) var preImage [32]byte copy(preImage[:], bytes.Repeat([]byte{9}, 32)) @@ -655,8 +645,9 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { // Send off the payment request to the router, route through phamnuwen // should've been selected as a fall back and succeeded correctly. - paymentPreImage, route, err := ctx.router.SendPayment(&payment) - require.NoError(t, err, "unable to send payment") + paymentPreImage, route, err := ctx.router.SendPayment(payment) + require.NoErrorf(t, err, "unable to send payment: %v", + payment.paymentHash) // The route selected should have two hops require.Equal(t, 2, len(route.Hops), "incorrect route length") @@ -696,7 +687,6 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { ) var ( - payHash lntypes.Hash preImage [32]byte amt = lnwire.NewMSatFromSatoshis(1000) privateChannelID = uint64(55555) @@ -713,21 +703,18 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { // 1000 satoshis. This route has lowest fees compared with the rest. // This also holds when the private channel fee is updated to a higher // value. - payment := LightningPayment{ - Target: ctx.aliases["elst"], - Amount: amt, - FeeLimit: noFeeLimit, - paymentHash: &payHash, - RouteHints: [][]zpay32.HopHint{{ - // Add a private channel between songoku and elst. - zpay32.HopHint{ - NodeID: sgNodeID, - ChannelID: privateChannelID, - FeeBaseMSat: feeBaseMSat, - CLTVExpiryDelta: expiryDelta, - }, - }}, - } + payment := createDummyLightningPayment( + t, ctx.aliases["elst"], amt, + ) + payment.RouteHints = [][]zpay32.HopHint{{ + // Add a private channel between songoku and elst. + zpay32.HopHint{ + NodeID: sgNodeID, + ChannelID: privateChannelID, + FeeBaseMSat: feeBaseMSat, + CLTVExpiryDelta: expiryDelta, + }, + }} // Prepare an error update for the private channel, with twice the // original fee. @@ -765,8 +752,9 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { // Send off the payment request to the router, route through son // goku and then across the private channel to elst. - paymentPreImage, route, err := ctx.router.SendPayment(&payment) - require.NoError(t, err, "unable to send payment") + paymentPreImage, route, err := ctx.router.SendPayment(payment) + require.NoErrorf(t, err, "unable to send payment: %v", + payment.paymentHash) require.True(t, errorReturned, "failed to simulate error in the first payment attempt", @@ -826,7 +814,6 @@ func TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit(t *testing.T) { ) var ( - payHash lntypes.Hash preImage [32]byte amt = lnwire.NewMSatFromSatoshis(1000) privateChannelID = uint64(55555) @@ -842,21 +829,18 @@ func TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit(t *testing.T) { // Craft a LightningPayment struct that'll send a payment from roasbeef // to elst, through a private channel between songoku and elst for // 1000 satoshis. This route has lowest fees compared with the rest. - payment := LightningPayment{ - Target: ctx.aliases["elst"], - Amount: amt, - FeeLimit: feeLimit, - paymentHash: &payHash, - RouteHints: [][]zpay32.HopHint{{ - // Add a private channel between songoku and elst. - zpay32.HopHint{ - NodeID: sgNodeID, - ChannelID: privateChannelID, - FeeBaseMSat: feeBaseMSat, - CLTVExpiryDelta: expiryDelta, - }, - }}, - } + payment := createDummyLightningPayment( + t, ctx.aliases["elst"], amt, + ) + payment.RouteHints = [][]zpay32.HopHint{{ + // Add a private channel between songoku and elst. + zpay32.HopHint{ + NodeID: sgNodeID, + ChannelID: privateChannelID, + FeeBaseMSat: feeBaseMSat, + CLTVExpiryDelta: expiryDelta, + }, + }} // Prepare an error update for the private channel. The updated fee // will exceeds the feeLimit. @@ -894,8 +878,9 @@ func TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit(t *testing.T) { // Send off the payment request to the router, route through son // goku and then across the private channel to elst. - paymentPreImage, route, err := ctx.router.SendPayment(&payment) - require.NoError(t, err, "unable to send payment") + paymentPreImage, route, err := ctx.router.SendPayment(payment) + require.NoErrorf(t, err, "unable to send payment: %v", + payment.paymentHash) require.True(t, errorReturned, "failed to simulate error in the first payment attempt", @@ -943,14 +928,10 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { // Craft a LightningPayment struct that'll send a payment from roasbeef // to sophon for 1k satoshis. - var payHash lntypes.Hash amt := lnwire.NewMSatFromSatoshis(1000) - payment := LightningPayment{ - Target: ctx.aliases["sophon"], - Amount: amt, - FeeLimit: noFeeLimit, - paymentHash: &payHash, - } + payment := createDummyLightningPayment( + t, ctx.aliases["sophon"], amt, + ) var preImage [32]byte copy(preImage[:], bytes.Repeat([]byte{9}, 32)) @@ -1020,8 +1001,9 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { // Send off the payment request to the router, this payment should // succeed as we should actually go through Pham Nuwen in order to get // to Sophon, even though he has higher fees. - paymentPreImage, rt, err := ctx.router.SendPayment(&payment) - require.NoError(t, err, "unable to send payment") + paymentPreImage, rt, err := ctx.router.SendPayment(payment) + require.NoErrorf(t, err, "unable to send payment: %v", + payment.paymentHash) assertExpectedPath(paymentPreImage, rt) @@ -1045,8 +1027,9 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { // w.r.t to the block height, and instead go through Pham Nuwen. We // flip a bit in the payment hash to allow resending this payment. payment.paymentHash[1] ^= 1 - paymentPreImage, rt, err = ctx.router.SendPayment(&payment) - require.NoError(t, err, "unable to send payment") + paymentPreImage, rt, err = ctx.router.SendPayment(payment) + require.NoErrorf(t, err, "unable to send payment: %v", + payment.paymentHash) assertExpectedPath(paymentPreImage, rt) } @@ -1062,14 +1045,10 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // Craft a LightningPayment struct that'll send a payment from roasbeef // to luo ji for 1000 satoshis, with a maximum of 1000 satoshis in fees. - var payHash lntypes.Hash paymentAmt := lnwire.NewMSatFromSatoshis(1000) - payment := LightningPayment{ - Target: ctx.aliases["sophon"], - Amount: paymentAmt, - FeeLimit: noFeeLimit, - paymentHash: &payHash, - } + payment := createDummyLightningPayment( + t, ctx.aliases["sophon"], paymentAmt, + ) var preImage [32]byte copy(preImage[:], bytes.Repeat([]byte{9}, 32)) @@ -1113,7 +1092,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // When we try to dispatch that payment, we should receive an error as // both attempts should fail and cause both routes to be pruned. - _, _, err := ctx.router.SendPayment(&payment) + _, _, err := ctx.router.SendPayment(payment) require.Error(t, err, "payment didn't return error") // The final error returned should also indicate that the peer wasn't @@ -1121,7 +1100,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { require.Equal(t, channeldb.FailureReasonNoRoute, err) // Inspect the two attempts that were made before the payment failed. - p, err := ctx.router.cfg.Control.FetchPayment(payHash) + p, err := ctx.router.cfg.Control.FetchPayment(*payment.paymentHash) require.NoError(t, err) htlcs := p.GetHTLCs() @@ -1158,8 +1137,9 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // This shouldn't return an error, as we'll make a payment attempt via // the pham nuwen channel based on the assumption that there might be an // intermittent issue with the songoku <-> sophon channel. - paymentPreImage, rt, err := ctx.router.SendPayment(&payment) - require.NoError(t, err, "unable send payment") + paymentPreImage, rt, err := ctx.router.SendPayment(payment) + require.NoErrorf(t, err, "unable to send payment: %v", + payment.paymentHash) // This path should go: roasbeef -> pham nuwen -> sophon require.Equal(t, 2, len(rt.Hops), "incorrect route length") @@ -1193,8 +1173,9 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // We flip a bit in the payment hash to allow resending this payment. payment.paymentHash[1] ^= 1 - paymentPreImage, rt, err = ctx.router.SendPayment(&payment) - require.NoError(t, err, "unable send payment") + paymentPreImage, rt, err = ctx.router.SendPayment(payment) + require.NoErrorf(t, err, "unable to send payment: %v", + payment.paymentHash) // This should succeed finally. The route selected should have two // hops. @@ -2754,13 +2735,9 @@ func TestUnknownErrorSource(t *testing.T) { ) // Create a payment to node c. - var payHash lntypes.Hash - payment := LightningPayment{ - Target: ctx.aliases["c"], - Amount: lnwire.NewMSatFromSatoshis(1000), - FeeLimit: noFeeLimit, - paymentHash: &payHash, - } + payment := createDummyLightningPayment( + t, ctx.aliases["c"], lnwire.NewMSatFromSatoshis(1000), + ) // We'll modify the SendToSwitch method so that it simulates hop b as a // node that returns an unparsable failure if approached via the a->b @@ -2784,8 +2761,9 @@ func TestUnknownErrorSource(t *testing.T) { // the route a->b->c is tried first. An unreadable faiure is returned // which should pruning the channel a->b. We expect the payment to // succeed via a->d. - _, _, err = ctx.router.SendPayment(&payment) - require.NoError(t, err, "expected payment to succeed, but got") + _, _, err = ctx.router.SendPayment(payment) + require.NoErrorf(t, err, "unable to send payment: %v", + payment.paymentHash) // Next we modify payment result to return an unknown failure. ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult( @@ -2804,9 +2782,8 @@ func TestUnknownErrorSource(t *testing.T) { // Send off the payment request to the router. We expect the payment to // fail because both routes have been pruned. - payHash = lntypes.Hash{1} - payment.paymentHash = &payHash - _, _, err = ctx.router.SendPayment(&payment) + payment.paymentHash[1] ^= 1 + _, _, err = ctx.router.SendPayment(payment) if err == nil { t.Fatalf("expected payment to fail") } From 27ee917a2024aab577034d30de950fd5ced234b2 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 9 Mar 2023 19:25:41 +0800 Subject: [PATCH 20/27] docs: update release note for payment lifecycle --- docs/release-notes/release-notes-0.18.0.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 7a288a471..8d88cd1e2 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -37,6 +37,9 @@ might panic due to empty witness data found in a transaction. More details can be found [here](https://github.com/bitcoin/bitcoin/issues/28730). +* [Fixed a case](https://github.com/lightningnetwork/lnd/pull/7503) where it's + possible a failed payment might be stuck in pending. + # New Features ## Functional Enhancements @@ -90,6 +93,11 @@ `lnrpc.GetInfoResponse` message along with the `chain` field in the `lnrpc.Chain` message have also been deprecated for the same reason. +* The payment lifecycle code has been refactored to improve its maintainablity. + In particular, the complexity involved in the lifecycle loop has been + decoupled into logical steps, with each step having its own responsibility, + making it easier to reason about the payment flow. + ## Breaking Changes ## Performance Improvements From e3dadd528b6267093c83facc5a5bf4a641b7e5c9 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Sun, 10 Sep 2023 09:22:30 +0800 Subject: [PATCH 21/27] routing: mark payment as failed when no route can be found --- routing/payment_lifecycle.go | 51 ++++++++++------------ routing/payment_lifecycle_test.go | 71 +++++-------------------------- 2 files changed, 32 insertions(+), 90 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 20c9c9cde..fdaa68042 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -266,9 +266,14 @@ lifecycle: return exitWithErr(err) } - // NOTE: might cause an infinite loop, see notes in - // `requestRoute` for details. + // We may not be able to find a route for current attempt. In + // that case, we continue the loop and move straight to the + // next iteration in case there are results for inflight HTLCs + // that still need to be collected. if rt == nil { + log.Errorf("No route found for payment %v", + p.identifier) + continue lifecycle } @@ -363,42 +368,30 @@ func (p *paymentLifecycle) requestRoute( log.Warnf("Failed to find route for payment %v: %v", p.identifier, err) // If the error belongs to `noRouteError` set, it means a non-critical - // error has happened during path finding and we might be able to find - // another route during next HTLC attempt. Otherwise, we'll return the - // critical error found. + // error has happened during path finding and we will mark the payment + // failed with this reason. Otherwise, we'll return the critical error + // found to abort the lifecycle. var routeErr noRouteError if !errors.As(err, &routeErr) { return nil, fmt.Errorf("requestRoute got: %w", err) } - // There is no route to try, and we have no active shards. This means - // that there is no way for us to send the payment, so mark it failed - // with no route. - // - // NOTE: if we have zero `numShardsInFlight`, it means all the HTLC - // attempts have failed. Otherwise, if there are still inflight - // attempts, we might enter an infinite loop in our lifecycle if - // there's still remaining amount since we will keep adding new HTLC - // attempts and they all fail with `noRouteError`. - // - // TODO(yy): further check the error returned here. It's the - // `paymentSession`'s responsibility to find a route for us with best - // effort. When it cannot find a path, we need to treat it as a - // terminal condition and fail the payment no matter it has inflight + // It's the `paymentSession`'s responsibility to find a route for us + // with best effort. When it cannot find a path, we need to treat it as + // a terminal condition and fail the payment no matter it has inflight // HTLCs or not. - if ps.NumAttemptsInFlight == 0 { - failureCode := routeErr.FailureReason() - log.Debugf("Marking payment %v permanently failed with no "+ - "route: %v", p.identifier, failureCode) + failureCode := routeErr.FailureReason() + log.Warnf("Marking payment %v permanently failed with no route: %v", + p.identifier, failureCode) - err := p.router.cfg.Control.FailPayment( - p.identifier, failureCode, - ) - if err != nil { - return nil, fmt.Errorf("FailPayment got: %w", err) - } + err = p.router.cfg.Control.FailPayment(p.identifier, failureCode) + if err != nil { + return nil, fmt.Errorf("FailPayment got: %w", err) } + // NOTE: we decide to not return the non-critical noRouteError here to + // avoid terminating the payment lifecycle as there might be other + // inflight HTLCs which we must wait for their results. return nil, nil } diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index f4b637ee4..f0a18165e 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -419,13 +419,8 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) { func TestRequestRouteHandleNoRouteErr(t *testing.T) { t.Parallel() - p := createTestPaymentLifecycle() - - // Create a mock payment session. - paySession := &mockPaymentSession{} - - // Mount the mocked payment session. - p.paySession = paySession + // Create a paymentLifecycle with mockers. + p, m := newTestPaymentLifecycle(t) // Create a dummy payment state. ps := &channeldb.MPPaymentState{ @@ -437,68 +432,22 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { // Mock remainingFees to be 1. p.feeLimit = ps.FeesPaid + 1 - // Mock the paySession's `RequestRoute` method to return an error. - paySession.On("RequestRoute", + // Mock the paySession's `RequestRoute` method to return a NoRouteErr + // type. + m.paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, ).Return(nil, errNoTlvPayload) + // The payment should be failed with reason no route. + m.control.On("FailPayment", + p.identifier, channeldb.FailureReasonNoRoute, + ).Return(nil).Once() + result, err := p.requestRoute(ps) // Expect no error is returned since it's not critical. require.NoError(t, err, "expected no error") require.Nil(t, result, "expected no route returned") - - // Assert that `RequestRoute` is called as expected. - paySession.AssertExpectations(t) -} - -// TestRequestRouteFailPaymentSucceed checks that `requestRoute` fails the -// payment when received an `noRouteError` returned from payment session while -// it has no inflight attempts. -func TestRequestRouteFailPaymentSucceed(t *testing.T) { - t.Parallel() - - p := createTestPaymentLifecycle() - - // Create a mock payment session. - paySession := &mockPaymentSession{} - - // Mock the control tower's `FailPayment` method. - ct := &mockControlTower{} - ct.On("FailPayment", - p.identifier, errNoTlvPayload.FailureReason(), - ).Return(nil) - - // Mount the mocked control tower and payment session. - p.router.cfg.Control = ct - p.paySession = paySession - - // Create a dummy payment state with zero inflight attempts. - ps := &channeldb.MPPaymentState{ - NumAttemptsInFlight: 0, - RemainingAmt: 1, - FeesPaid: 100, - } - - // Mock remainingFees to be 1. - p.feeLimit = ps.FeesPaid + 1 - - // Mock the paySession's `RequestRoute` method to return an error. - paySession.On("RequestRoute", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(nil, errNoTlvPayload) - - result, err := p.requestRoute(ps) - - // Expect no error is returned since it's not critical. - require.NoError(t, err, "expected no error") - require.Nil(t, result, "expected no route returned") - - // Assert that `RequestRoute` is called as expected. - paySession.AssertExpectations(t) - - // Assert that `FailPayment` is called as expected. - ct.AssertExpectations(t) } // TestRequestRouteFailPaymentError checks that `requestRoute` returns the From 8f5c6e8367c353c89b0181330c10d3df2f3877e3 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 20 Oct 2023 06:30:54 +0800 Subject: [PATCH 22/27] trivial: fix typos --- routing/router_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/routing/router_test.go b/routing/router_test.go index 1eb39777f..67a8964bf 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3521,7 +3521,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { payment := &mockMPPayment{} controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() - // Mock the payment to return nil failrue reason. + // Mock the payment to return nil failure reason. payment.On("TerminalInfo").Return(nil, nil) // Expect a successful send to route. @@ -3602,7 +3602,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { mock.Anything, rt, mock.Anything, mock.Anything, ).Return(nil, nil) - // Mock the payment to return nil failrue reason. + // Mock the payment to return nil failure reason. payment.On("TerminalInfo").Return(nil, nil) // Expect a failed send to route. @@ -3687,7 +3687,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { payment := &mockMPPayment{} controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() - // Mock the payment to return a failrue reason. + // Mock the payment to return a failure reason. payment.On("TerminalInfo").Return(nil, &failureReason) // Expect a failed send to route. @@ -3765,7 +3765,7 @@ func TestSendToRouteTempFailure(t *testing.T) { payment := &mockMPPayment{} controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() - // Mock the payment to return nil failrue reason. + // Mock the payment to return nil failure reason. payment.On("TerminalInfo").Return(nil, nil) // Return a nil reason to mock a temporary failure. From 7ccb77269d0bc5e7387c96548ac12fb0a06acf29 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 20 Oct 2023 07:13:22 +0800 Subject: [PATCH 23/27] routing: log preimage when it's failed to be saved to db --- routing/payment_lifecycle.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index fdaa68042..9853f8fd9 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -558,7 +558,9 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( }, ) if err != nil { - log.Errorf("Unable to settle payment attempt: %v", err) + log.Errorf("Error settling attempt %v for payment %v with "+ + "preimage %v: %v", attempt.AttemptID, p.identifier, + result.Preimage, err) // We won't mark the attempt as failed since we already have // the preimage. From 98378d9408eb8fa5bf666cc4c6c739deaecae165 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 20 Oct 2023 07:45:52 +0800 Subject: [PATCH 24/27] routing: unify all dummy errors to be `errDummy` --- routing/payment_lifecycle_test.go | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index f0a18165e..51862a184 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/require" ) +// errDummy is used by the mockers to return a dummy error. var errDummy = errors.New("dummy") // createTestPaymentLifecycle creates a `paymentLifecycle` without mocks. @@ -304,10 +305,9 @@ func TestCheckTimeoutTimedOut(t *testing.T) { // by the function too. // // Mock `FailPayment` to return a dummy error. - dummyErr := errors.New("dummy") ct = &mockControlTower{} ct.On("FailPayment", - p.identifier, channeldb.FailureReasonTimeout).Return(dummyErr) + p.identifier, channeldb.FailureReasonTimeout).Return(errDummy) // Mount the mocked control tower. p.router.cfg.Control = ct @@ -320,7 +320,7 @@ func TestCheckTimeoutTimedOut(t *testing.T) { // Call the function and expect an error. err = p.checkTimeout() - require.ErrorIs(t, err, dummyErr) + require.ErrorIs(t, err, errDummy) // Assert that `FailPayment` is called as expected. ct.AssertExpectations(t) @@ -399,15 +399,14 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) { p.feeLimit = ps.FeesPaid + 1 // Mock the paySession's `RequestRoute` method to return an error. - dummyErr := errors.New("dummy") paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(nil, dummyErr) + ).Return(nil, errDummy) result, err := p.requestRoute(ps) // Expect an error is returned since it's critical. - require.ErrorIs(t, err, dummyErr, "error not matched") + require.ErrorIs(t, err, errDummy, "error not matched") require.Nil(t, result, "expected no route returned") // Assert that `RequestRoute` is called as expected. @@ -462,10 +461,9 @@ func TestRequestRouteFailPaymentError(t *testing.T) { // Mock the control tower's `FailPayment` method. ct := &mockControlTower{} - dummyErr := errors.New("dummy") ct.On("FailPayment", p.identifier, errNoTlvPayload.FailureReason(), - ).Return(dummyErr) + ).Return(errDummy) // Mount the mocked control tower and payment session. p.router.cfg.Control = ct @@ -489,7 +487,7 @@ func TestRequestRouteFailPaymentError(t *testing.T) { result, err := p.requestRoute(ps) // Expect an error is returned. - require.ErrorIs(t, err, dummyErr, "error not matched") + require.ErrorIs(t, err, errDummy, "error not matched") require.Nil(t, result, "expected no route returned") // Assert that `RequestRoute` is called as expected. @@ -1247,18 +1245,17 @@ func TestCollectResultExitOnErr(t *testing.T) { // `CancelShard` should be called with the attemptID. m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once() - // Mock `FailAttempt` to return a switch error. - switchErr := errors.New("switch err") + // Mock `FailAttempt` to return a dummy error. m.control.On("FailAttempt", p.identifier, attempt.AttemptID, mock.Anything, - ).Return(nil, switchErr).Once() + ).Return(nil, errDummy).Once() // Mock the clock to return a current time. m.clock.On("Now").Return(time.Now()) // Now call the method under test. result, err := p.collectResult(attempt) - require.ErrorIs(t, err, switchErr, "expected switch error") + require.ErrorIs(t, err, errDummy, "expected dummy error") require.Nil(t, result, "expected nil attempt") } @@ -1299,18 +1296,17 @@ func TestCollectResultExitOnResultErr(t *testing.T) { // `CancelShard` should be called with the attemptID. m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once() - // Mock `FailAttempt` to return a switch error. - switchErr := errors.New("switch err") + // Mock `FailAttempt` to return a dummy error. m.control.On("FailAttempt", p.identifier, attempt.AttemptID, mock.Anything, - ).Return(nil, switchErr).Once() + ).Return(nil, errDummy).Once() // Mock the clock to return a current time. m.clock.On("Now").Return(time.Now()) // Now call the method under test. result, err := p.collectResult(attempt) - require.ErrorIs(t, err, switchErr, "expected switch error") + require.ErrorIs(t, err, errDummy, "expected dummy error") require.Nil(t, result, "expected nil attempt") } From 168cfd7cd58e05b1f2120888446cfa0de6463571 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 20 Oct 2023 09:18:39 +0800 Subject: [PATCH 25/27] docs: emphasize the new payment status `StatusInitiated` --- docs/release-notes/release-notes-0.18.0.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 8d88cd1e2..fa3ac9a3c 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -61,6 +61,13 @@ [http-header-timeout](https://github.com/lightningnetwork/lnd/pull/7715), is added so users can specify the amount of time the http server will wait for a request to complete before closing the connection. The default value is 5 seconds. ## RPC Additions + +* [Deprecated](https://github.com/lightningnetwork/lnd/pull/7175) + `StatusUnknown` from the payment's rpc response in its status and added a new + status, `StatusInitiated`, to explicitly report its current state. Before + running this new version, please make sure to upgrade your client application + to include this new status so it can understand the RPC response properly. + ## lncli Additions # Improvements From 678f416008fefba439a43e007f73f156ccd7e172 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 24 Oct 2023 14:53:39 +0800 Subject: [PATCH 26/27] routing+docs: make sure non-MPP cannot use skipTempErr --- docs/release-notes/release-notes-0.18.0.md | 6 +++ routing/router.go | 11 ++++ routing/router_test.go | 58 +++++++++++++++++++++- 3 files changed, 74 insertions(+), 1 deletion(-) diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index fa3ac9a3c..b269d8339 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -91,6 +91,12 @@ hash](https://github.com/lightningnetwork/lnd/pull/8106) to the signer.SignMessage/signer.VerifyMessage RPCs. +* `sendtoroute` will return an error when it's called using the flag + `--skip_temp_err` on a payment that's not a MPP. This is needed as a temp + error is defined as a routing error found in one of a MPP's HTLC attempts. + If, however, there's only one HTLC attempt, when it's failed, this payment is + considered failed, thus there's no such thing as temp error for a non-MPP. + ## lncli Updates ## Code Health diff --git a/routing/router.go b/routing/router.go index f7b20a376..c70a824fa 100644 --- a/routing/router.go +++ b/routing/router.go @@ -119,6 +119,10 @@ var ( // provided by either a blinded route or a cleartext pubkey. ErrNoTarget = errors.New("destination not set in target or blinded " + "path") + + // ErrSkipTempErr is returned when a non-MPP is made yet the + // skipTempErr flag is set. + ErrSkipTempErr = errors.New("cannot skip temp error for non-MPP") ) // ChannelGraphSource represents the source of information about the topology @@ -2450,6 +2454,13 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, amt = mpp.TotalMsat() } + // For non-MPP, there's no such thing as temp error as there's only one + // HTLC attempt being made. When this HTLC is failed, the payment is + // failed hence cannot be retried. + if skipTempErr && mpp == nil { + return nil, ErrSkipTempErr + } + // For non-AMP payments the overall payment identifier will be the same // hash as used for this HTLC. paymentIdentifier := htlcHash diff --git a/routing/router_test.go b/routing/router_test.go index 67a8964bf..d6abfe7a9 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3384,7 +3384,8 @@ func TestBlockDifferenceFix(t *testing.T) { initialBlockHeight := uint32(0) - // Starting height here is set to 0, which is behind where we want to be. + // Starting height here is set to 0, which is behind where we want to + // be. ctx := createTestCtxSingleNode(t, initialBlockHeight) // Add initial block to our mini blockchain. @@ -3454,6 +3455,8 @@ func TestBlockDifferenceFix(t *testing.T) { // TestSendToRouteSkipTempErrSuccess validates a successful payment send. func TestSendToRouteSkipTempErrSuccess(t *testing.T) { + t.Parallel() + var ( payHash lntypes.Hash payAmt = lnwire.MilliSatoshi(10000) @@ -3536,9 +3539,62 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { payment.AssertExpectations(t) } +// TestSendToRouteSkipTempErrNonMPP checks that an error is return when +// skipping temp error for non-MPP. +func TestSendToRouteSkipTempErrNonMPP(t *testing.T) { + t.Parallel() + + var ( + payHash lntypes.Hash + payAmt = lnwire.MilliSatoshi(10000) + ) + + node, err := createTestNode() + require.NoError(t, err) + + // Create a simple 1-hop route without the MPP field. + hops := []*route.Hop{ + { + ChannelID: 1, + PubKeyBytes: node.PubKeyBytes, + AmtToForward: payAmt, + }, + } + rt, err := route.NewRouteFromHops(payAmt, 100, node.PubKeyBytes, hops) + require.NoError(t, err) + + // Create mockers. + controlTower := &mockControlTower{} + payer := &mockPaymentAttemptDispatcher{} + missionControl := &mockMissionControl{} + + // Create the router. + router := &ChannelRouter{cfg: &Config{ + Control: controlTower, + Payer: payer, + MissionControl: missionControl, + Clock: clock.NewTestClock(time.Unix(1, 0)), + NextPaymentID: func() (uint64, error) { + return 0, nil + }, + }} + + // Expect an error to be returned. + attempt, err := router.SendToRouteSkipTempErr(payHash, rt) + require.ErrorIs(t, ErrSkipTempErr, err) + require.Nil(t, attempt) + + // Assert the above methods are not called. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + missionControl.AssertExpectations(t) +} + // TestSendToRouteSkipTempErrTempFailure validates a temporary failure won't // cause the payment to be failed. func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { + t.Parallel() + var ( payHash lntypes.Hash payAmt = lnwire.MilliSatoshi(10000) From 5168af55a9d54be470ab03507d55debfe813baa0 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 27 Oct 2023 16:30:58 +0800 Subject: [PATCH 27/27] itest: fix flake in `max_htlc_pathfind` ``` lnd_max_htlcs_test.go:149: Error Trace: /home/runner/work/lnd/lnd/itest/lnd_max_htlcs_test.go:149 /home/runner/work/lnd/lnd/itest/lnd_max_htlcs_test.go:40 /home/runner/work/lnd/lnd/lntest/harness.go:286 /home/runner/work/lnd/lnd/itest/lnd_test.go:136 Error: Not equal: expected: 3 actual : 0 Test: TestLightningNetworkDaemon/tranche01/60-of-134/btcd/max_htlc_pathfind Messages: expected accepted ``` --- itest/lnd_max_htlcs_test.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/itest/lnd_max_htlcs_test.go b/itest/lnd_max_htlcs_test.go index 36a99e1f1..a9bee5ae4 100644 --- a/itest/lnd_max_htlcs_test.go +++ b/itest/lnd_max_htlcs_test.go @@ -122,8 +122,7 @@ func acceptHoldInvoice(ht *lntest.HarnessTest, idx int, sender, invoice := receiver.RPC.AddHoldInvoice(req) invStream := receiver.RPC.SubscribeSingleInvoice(hash[:]) - inv := ht.ReceiveSingleInvoice(invStream) - require.Equal(ht, lnrpc.Invoice_OPEN, inv.State, "expect open") + ht.AssertInvoiceState(invStream, lnrpc.Invoice_OPEN) sendReq := &routerrpc.SendPaymentRequest{ PaymentRequest: invoice.PaymentRequest, @@ -145,9 +144,7 @@ func acceptHoldInvoice(ht *lntest.HarnessTest, idx int, sender, ) require.Len(ht, payment.Htlcs, 1) - inv = ht.ReceiveSingleInvoice(invStream) - require.Equal(ht, lnrpc.Invoice_ACCEPTED, inv.State, - "expected accepted") + ht.AssertInvoiceState(invStream, lnrpc.Invoice_ACCEPTED) return &holdSubscription{ recipient: receiver,