diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index 75aa48bcb..85b6b2b2c 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -360,13 +360,25 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest, return err } + // The payment context is influenced by two user-provided parameters, + // the cancelable flag and the payment attempt timeout. + // If the payment is cancelable, we will use the stream context as the + // payment context. That way, if the user ends the stream, the payment + // loop will be canceled. + // The second context parameter is the timeout. If the user provides a + // timeout, we will additionally wrap the context in a deadline. If the + // user provided 'cancelable' and ends the stream before the timeout is + // reached the payment will be canceled. + ctx := context.Background() + if req.Cancelable { + ctx = stream.Context() + } + // Send the payment asynchronously. - s.cfg.Router.SendPaymentAsync(payment, paySession, shardTracker) + s.cfg.Router.SendPaymentAsync(ctx, payment, paySession, shardTracker) // Track the payment and return. - return s.trackPayment( - sub, payHash, stream, req.NoInflightUpdates, - ) + return s.trackPayment(sub, payHash, stream, req.NoInflightUpdates) } // EstimateRouteFee allows callers to obtain an expected value w.r.t how much it diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 8d719ae38..8769ca5d3 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -1,6 +1,7 @@ package routing import ( + "context" "errors" "fmt" "time" @@ -29,7 +30,6 @@ type paymentLifecycle struct { identifier lntypes.Hash paySession PaymentSession shardTracker shards.ShardTracker - timeoutChan <-chan time.Time currentHeight int32 // quit is closed to signal the sub goroutines of the payment lifecycle @@ -52,7 +52,7 @@ type paymentLifecycle struct { // newPaymentLifecycle initiates a new payment lifecycle and returns it. func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash, paySession PaymentSession, - shardTracker shards.ShardTracker, timeout time.Duration, + shardTracker shards.ShardTracker, currentHeight int32) *paymentLifecycle { p := &paymentLifecycle{ @@ -69,13 +69,6 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, // Mount the result collector. p.resultCollector = p.collectResultAsync - // If a timeout is specified, create a timeout channel. If no timeout is - // specified, the channel is left nil and will never abort the payment - // loop. - if timeout != 0 { - p.timeoutChan = time.After(timeout) - } - return p } @@ -167,7 +160,9 @@ func (p *paymentLifecycle) decideNextStep( } // resumePayment resumes the paymentLifecycle from the current state. -func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { +func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte, + *route.Route, error) { + // When the payment lifecycle loop exits, we make sure to signal any // sub goroutine of the HTLC attempt to exit, then wait for them to // return. @@ -221,18 +216,17 @@ lifecycle: // We now proceed our lifecycle with the following tasks in // order, - // 1. check timeout. + // 1. check context. // 2. request route. // 3. create HTLC attempt. // 4. send HTLC attempt. // 5. collect HTLC attempt result. // - // Before we attempt any new shard, we'll check to see if - // either we've gone past the payment attempt timeout, or the - // router is exiting. In either case, we'll stop this payment - // attempt short. If a timeout is not applicable, timeoutChan - // will be nil. - if err := p.checkTimeout(); err != nil { + // Before we attempt any new shard, we'll check to see if we've + // gone past the payment attempt timeout, or if the context was + // cancelled, or the router is exiting. In any of these cases, + // we'll stop this payment attempt short. + if err := p.checkContext(ctx); err != nil { return exitWithErr(err) } @@ -318,19 +312,30 @@ lifecycle: return [32]byte{}, nil, *failure } -// checkTimeout checks whether the payment has reached its timeout. -func (p *paymentLifecycle) checkTimeout() error { +// checkContext checks whether the payment context has been canceled. +// Cancellation occurs manually or if the context times out. +func (p *paymentLifecycle) checkContext(ctx context.Context) error { select { - case <-p.timeoutChan: - log.Warnf("payment attempt not completed before timeout") + case <-ctx.Done(): + // If the context was canceled, we'll mark the payment as + // failed. There are two cases to distinguish here: Either a + // user-provided timeout was reached, or the context was + // canceled, either to a manual cancellation or due to an + // unknown error. + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + log.Warnf("Payment attempt not completed before "+ + "timeout, id=%s", p.identifier.String()) + } else { + log.Warnf("Payment attempt context canceled, id=%s", + p.identifier.String()) + } // By marking the payment failed, depending on whether it has // inflight HTLCs or not, its status will now either be // `StatusInflight` or `StatusFailed`. In either case, no more // HTLCs will be attempted. - err := p.router.cfg.Control.FailPayment( - p.identifier, channeldb.FailureReasonTimeout, - ) + reason := channeldb.FailureReasonTimeout + err := p.router.cfg.Control.FailPayment(p.identifier, reason) if err != nil { return fmt.Errorf("FailPayment got %w", err) } diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 51862a184..3b19812c6 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "sync/atomic" "testing" "time" @@ -88,7 +89,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { // Create a test payment lifecycle with no fee limit and no timeout. p := newPaymentLifecycle( rt, noFeeLimit, paymentHash, mockPaymentSession, - mockShardTracker, 0, 0, + mockShardTracker, 0, ) // Create a mock payment which is returned from mockControlTower. @@ -151,9 +152,9 @@ type resumePaymentResult struct { err error } -// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error -// is returned. -func sendPaymentAndAssertFailed(t *testing.T, +// sendPaymentAndAssertError calls `resumePayment` and asserts that an error is +// returned. +func sendPaymentAndAssertError(t *testing.T, ctx context.Context, p *paymentLifecycle, errExpected error) { resultChan := make(chan *resumePaymentResult, 1) @@ -161,7 +162,7 @@ func sendPaymentAndAssertFailed(t *testing.T, // We now make a call to `resumePayment` and expect it to return the // error. go func() { - preimage, _, err := p.resumePayment() + preimage, _, err := p.resumePayment(ctx) resultChan <- &resumePaymentResult{ preimage: preimage, err: err, @@ -189,7 +190,7 @@ func sendPaymentAndAssertSucceeded(t *testing.T, // We now make a call to `resumePayment` and expect it to return the // preimage. go func() { - preimage, _, err := p.resumePayment() + preimage, _, err := p.resumePayment(context.Background()) resultChan <- &resumePaymentResult{ preimage: preimage, err: err, @@ -278,6 +279,10 @@ func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo { func TestCheckTimeoutTimedOut(t *testing.T) { t.Parallel() + deadline := time.Now().Add(time.Nanosecond) + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + p := createTestPaymentLifecycle() // Mock the control tower's `FailPayment` method. @@ -288,14 +293,11 @@ func TestCheckTimeoutTimedOut(t *testing.T) { // Mount the mocked control tower. p.router.cfg.Control = ct - // Make the timeout happens instantly. - p.timeoutChan = time.After(1 * time.Nanosecond) - // Sleep one millisecond to make sure it timed out. time.Sleep(1 * time.Millisecond) // Call the function and expect no error. - err := p.checkTimeout() + err := p.checkContext(ctx) require.NoError(t, err) // Assert that `FailPayment` is called as expected. @@ -313,13 +315,15 @@ func TestCheckTimeoutTimedOut(t *testing.T) { p.router.cfg.Control = ct // Make the timeout happens instantly. - p.timeoutChan = time.After(1 * time.Nanosecond) + deadline = time.Now().Add(time.Nanosecond) + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() // Sleep one millisecond to make sure it timed out. time.Sleep(1 * time.Millisecond) // Call the function and expect an error. - err = p.checkTimeout() + err = p.checkContext(ctx) require.ErrorIs(t, err, errDummy) // Assert that `FailPayment` is called as expected. @@ -331,10 +335,13 @@ func TestCheckTimeoutTimedOut(t *testing.T) { func TestCheckTimeoutOnRouterQuit(t *testing.T) { t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + p := createTestPaymentLifecycle() close(p.router.quit) - err := p.checkTimeout() + err := p.checkContext(ctx) require.ErrorIs(t, err, ErrRouterShuttingDown) } @@ -627,7 +634,7 @@ func TestResumePaymentFailOnFetchPayment(t *testing.T) { m.control.On("FetchPayment", p.identifier).Return(nil, errDummy) // Send the payment and assert it failed. - sendPaymentAndAssertFailed(t, p, errDummy) + sendPaymentAndAssertError(t, context.Background(), p, errDummy) // Expected collectResultAsync to not be called. require.Zero(t, m.collectResultsCount) @@ -656,14 +663,15 @@ func TestResumePaymentFailOnTimeout(t *testing.T) { } m.payment.On("GetState").Return(ps).Once() - // NOTE: GetStatus is only used to populate the logs which is - // not critical so we loosen the checks on how many times it's - // been called. + // NOTE: GetStatus is only used to populate the logs which is not + // critical, so we loosen the checks on how many times it's been called. m.payment.On("GetStatus").Return(channeldb.StatusInFlight) // 3. make the timeout happens instantly and sleep one millisecond to // make sure it timed out. - p.timeoutChan = time.After(1 * time.Nanosecond) + deadline := time.Now().Add(time.Nanosecond) + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() time.Sleep(1 * time.Millisecond) // 4. the payment should be failed with reason timeout. @@ -683,7 +691,7 @@ func TestResumePaymentFailOnTimeout(t *testing.T) { m.payment.On("TerminalInfo").Return(nil, &reason) // Send the payment and assert it failed with the timeout reason. - sendPaymentAndAssertFailed(t, p, reason) + sendPaymentAndAssertError(t, ctx, p, reason) // Expected collectResultAsync to not be called. require.Zero(t, m.collectResultsCount) @@ -721,7 +729,65 @@ func TestResumePaymentFailOnTimeoutErr(t *testing.T) { close(p.router.quit) // Send the payment and assert it failed when router is shutting down. - sendPaymentAndAssertFailed(t, p, ErrRouterShuttingDown) + sendPaymentAndAssertError( + t, context.Background(), p, ErrRouterShuttingDown, + ) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailContextCancel checks that the lifecycle fails when the +// context is canceled. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailContextCancel(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create the cancelable payment context. + ctx, cancel := context.WithCancel(context.Background()) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is not + // critical, so we loosen the checks on how many times it's been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. Cancel the context and skip the FailPayment error to trigger the + // context cancellation of the payment. + cancel() + + m.control.On( + "FailPayment", p.identifier, channeldb.FailureReasonTimeout, + ).Return(nil).Once() + + // 5. decideNextStep now returns stepExit. + m.payment.On("AllowMoreAttempts").Return(false, nil).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() + + // 6. Control tower deletes failed attempts. + m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() + + // 7. We will observe FailureReasonError if the context was cancelled. + reason := channeldb.FailureReasonError + m.payment.On("TerminalInfo").Return(nil, &reason) + + // Send the payment and assert it failed with the timeout reason. + sendPaymentAndAssertError(t, ctx, p, reason) // Expected collectResultAsync to not be called. require.Zero(t, m.collectResultsCount) @@ -759,7 +825,7 @@ func TestResumePaymentFailOnStepErr(t *testing.T) { m.payment.On("AllowMoreAttempts").Return(false, errDummy).Once() // Send the payment and assert it failed. - sendPaymentAndAssertFailed(t, p, errDummy) + sendPaymentAndAssertError(t, context.Background(), p, errDummy) // Expected collectResultAsync to not be called. require.Zero(t, m.collectResultsCount) @@ -803,7 +869,7 @@ func TestResumePaymentFailOnRequestRouteErr(t *testing.T) { ).Return(nil, errDummy).Once() // Send the payment and assert it failed. - sendPaymentAndAssertFailed(t, p, errDummy) + sendPaymentAndAssertError(t, context.Background(), p, errDummy) // Expected collectResultAsync to not be called. require.Zero(t, m.collectResultsCount) @@ -863,7 +929,7 @@ func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) { ).Return(nil, errDummy).Once() // Send the payment and assert it failed. - sendPaymentAndAssertFailed(t, p, errDummy) + sendPaymentAndAssertError(t, context.Background(), p, errDummy) // Expected collectResultAsync to not be called. require.Zero(t, m.collectResultsCount) @@ -955,7 +1021,7 @@ func TestResumePaymentFailOnSendAttemptErr(t *testing.T) { ).Return(nil, errDummy).Once() // Send the payment and assert it failed. - sendPaymentAndAssertFailed(t, p, errDummy) + sendPaymentAndAssertError(t, context.Background(), p, errDummy) // Expected collectResultAsync to not be called. require.Zero(t, m.collectResultsCount) diff --git a/routing/router.go b/routing/router.go index 088f08357..851db4af0 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "fmt" "math" "runtime" @@ -715,13 +716,15 @@ func (r *ChannelRouter) Start() error { // result for the in-flight attempt is received. paySession := r.cfg.SessionSource.NewPaymentSessionEmpty() - // We pass in a zero timeout value, to indicate we + // We pass in a non-timeout context, to indicate we // don't need it to timeout. It will stop immediately // after the existing attempt has finished anyway. We // also set a zero fee limit, as no more routes should // be tried. + noTimeout := time.Duration(0) _, _, err := r.sendPayment( - 0, payment.Info.PaymentIdentifier, 0, + context.Background(), 0, + payment.Info.PaymentIdentifier, noTimeout, paySession, shardTracker, ) if err != nil { @@ -2406,18 +2409,16 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte, log.Tracef("Dispatching SendPayment for lightning payment: %v", spewPayment(payment)) - // Since this is the first time this payment is being made, we pass nil - // for the existing attempt. return r.sendPayment( - payment.FeeLimit, payment.Identifier(), + context.Background(), payment.FeeLimit, payment.Identifier(), payment.PayAttemptTimeout, paySession, shardTracker, ) } // SendPaymentAsync is the non-blocking version of SendPayment. The payment // result needs to be retrieved via the control tower. -func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment, - ps PaymentSession, st shards.ShardTracker) { +func (r *ChannelRouter) SendPaymentAsync(ctx context.Context, + payment *LightningPayment, ps PaymentSession, st shards.ShardTracker) { // Since this is the first time this payment is being made, we pass nil // for the existing attempt. @@ -2429,7 +2430,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment, spewPayment(payment)) _, _, err := r.sendPayment( - payment.FeeLimit, payment.Identifier(), + ctx, payment.FeeLimit, payment.Identifier(), payment.PayAttemptTimeout, ps, st, ) if err != nil { @@ -2604,9 +2605,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // - nil payment session (since we already have a route). // - no payment timeout. // - no current block height. - p := newPaymentLifecycle( - r, 0, paymentIdentifier, nil, shardTracker, 0, 0, - ) + p := newPaymentLifecycle(r, 0, paymentIdentifier, nil, shardTracker, 0) // We found a route to try, create a new HTLC attempt to try. // @@ -2699,11 +2698,23 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // carry out its execution. After restarts, it is safe, and assumed, that the // router will call this method for every payment still in-flight according to // the ControlTower. -func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi, - identifier lntypes.Hash, timeout time.Duration, - paySession PaymentSession, +func (r *ChannelRouter) sendPayment(ctx context.Context, + feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash, + paymentAttemptTimeout time.Duration, paySession PaymentSession, shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) { + // If the user provides a timeout, we will additionally wrap the context + // in a deadline. + cancel := func() {} + if paymentAttemptTimeout > 0 { + ctx, cancel = context.WithTimeout(ctx, paymentAttemptTimeout) + } + + // Since resumePayment is a blocking call, we'll cancel this + // context if the payment completes before the optional + // deadline. + defer cancel() + // We'll also fetch the current block height, so we can properly // calculate the required HTLC time locks within the route. _, currentHeight, err := r.cfg.Chain.GetBestBlock() @@ -2714,11 +2725,11 @@ func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi, // Now set up a paymentLifecycle struct with these params, such that we // can resume the payment from the current state. p := newPaymentLifecycle( - r, feeLimit, identifier, paySession, - shardTracker, timeout, currentHeight, + r, feeLimit, identifier, paySession, shardTracker, + currentHeight, ) - return p.resumePayment() + return p.resumePayment(ctx) } // extractChannelUpdate examines the error and extracts the channel update.