mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-14 19:01:10 +02:00
routing+routerrpc: cancelable context in SendPaymentV2
In this commit we set up the payment loop context according to user-provided parameters. The `cancelable` parameter indicates whether the user is able to interrupt the payment loop by cancelling the server stream context. We'll additionally wrap the context in a deadline if the user provided a payment timeout. We remove the timeout channel of the payment_lifecycle.go and in favor of the deadline context.
This commit is contained in:
parent
e729084149
commit
bba01cf634
@ -360,13 +360,25 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The payment context is influenced by two user-provided parameters,
|
||||||
|
// the cancelable flag and the payment attempt timeout.
|
||||||
|
// If the payment is cancelable, we will use the stream context as the
|
||||||
|
// payment context. That way, if the user ends the stream, the payment
|
||||||
|
// loop will be canceled.
|
||||||
|
// The second context parameter is the timeout. If the user provides a
|
||||||
|
// timeout, we will additionally wrap the context in a deadline. If the
|
||||||
|
// user provided 'cancelable' and ends the stream before the timeout is
|
||||||
|
// reached the payment will be canceled.
|
||||||
|
ctx := context.Background()
|
||||||
|
if req.Cancelable {
|
||||||
|
ctx = stream.Context()
|
||||||
|
}
|
||||||
|
|
||||||
// Send the payment asynchronously.
|
// Send the payment asynchronously.
|
||||||
s.cfg.Router.SendPaymentAsync(payment, paySession, shardTracker)
|
s.cfg.Router.SendPaymentAsync(ctx, payment, paySession, shardTracker)
|
||||||
|
|
||||||
// Track the payment and return.
|
// Track the payment and return.
|
||||||
return s.trackPayment(
|
return s.trackPayment(sub, payHash, stream, req.NoInflightUpdates)
|
||||||
sub, payHash, stream, req.NoInflightUpdates,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// EstimateRouteFee allows callers to obtain an expected value w.r.t how much it
|
// EstimateRouteFee allows callers to obtain an expected value w.r.t how much it
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
@ -29,7 +30,6 @@ type paymentLifecycle struct {
|
|||||||
identifier lntypes.Hash
|
identifier lntypes.Hash
|
||||||
paySession PaymentSession
|
paySession PaymentSession
|
||||||
shardTracker shards.ShardTracker
|
shardTracker shards.ShardTracker
|
||||||
timeoutChan <-chan time.Time
|
|
||||||
currentHeight int32
|
currentHeight int32
|
||||||
|
|
||||||
// quit is closed to signal the sub goroutines of the payment lifecycle
|
// quit is closed to signal the sub goroutines of the payment lifecycle
|
||||||
@ -52,7 +52,7 @@ type paymentLifecycle struct {
|
|||||||
// newPaymentLifecycle initiates a new payment lifecycle and returns it.
|
// newPaymentLifecycle initiates a new payment lifecycle and returns it.
|
||||||
func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
|
func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
|
||||||
identifier lntypes.Hash, paySession PaymentSession,
|
identifier lntypes.Hash, paySession PaymentSession,
|
||||||
shardTracker shards.ShardTracker, timeout time.Duration,
|
shardTracker shards.ShardTracker,
|
||||||
currentHeight int32) *paymentLifecycle {
|
currentHeight int32) *paymentLifecycle {
|
||||||
|
|
||||||
p := &paymentLifecycle{
|
p := &paymentLifecycle{
|
||||||
@ -69,13 +69,6 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
|
|||||||
// Mount the result collector.
|
// Mount the result collector.
|
||||||
p.resultCollector = p.collectResultAsync
|
p.resultCollector = p.collectResultAsync
|
||||||
|
|
||||||
// If a timeout is specified, create a timeout channel. If no timeout is
|
|
||||||
// specified, the channel is left nil and will never abort the payment
|
|
||||||
// loop.
|
|
||||||
if timeout != 0 {
|
|
||||||
p.timeoutChan = time.After(timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -167,7 +160,9 @@ func (p *paymentLifecycle) decideNextStep(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// resumePayment resumes the paymentLifecycle from the current state.
|
// resumePayment resumes the paymentLifecycle from the current state.
|
||||||
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
|
func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte,
|
||||||
|
*route.Route, error) {
|
||||||
|
|
||||||
// When the payment lifecycle loop exits, we make sure to signal any
|
// When the payment lifecycle loop exits, we make sure to signal any
|
||||||
// sub goroutine of the HTLC attempt to exit, then wait for them to
|
// sub goroutine of the HTLC attempt to exit, then wait for them to
|
||||||
// return.
|
// return.
|
||||||
@ -221,18 +216,17 @@ lifecycle:
|
|||||||
|
|
||||||
// We now proceed our lifecycle with the following tasks in
|
// We now proceed our lifecycle with the following tasks in
|
||||||
// order,
|
// order,
|
||||||
// 1. check timeout.
|
// 1. check context.
|
||||||
// 2. request route.
|
// 2. request route.
|
||||||
// 3. create HTLC attempt.
|
// 3. create HTLC attempt.
|
||||||
// 4. send HTLC attempt.
|
// 4. send HTLC attempt.
|
||||||
// 5. collect HTLC attempt result.
|
// 5. collect HTLC attempt result.
|
||||||
//
|
//
|
||||||
// Before we attempt any new shard, we'll check to see if
|
// Before we attempt any new shard, we'll check to see if we've
|
||||||
// either we've gone past the payment attempt timeout, or the
|
// gone past the payment attempt timeout, or if the context was
|
||||||
// router is exiting. In either case, we'll stop this payment
|
// cancelled, or the router is exiting. In any of these cases,
|
||||||
// attempt short. If a timeout is not applicable, timeoutChan
|
// we'll stop this payment attempt short.
|
||||||
// will be nil.
|
if err := p.checkContext(ctx); err != nil {
|
||||||
if err := p.checkTimeout(); err != nil {
|
|
||||||
return exitWithErr(err)
|
return exitWithErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -318,19 +312,30 @@ lifecycle:
|
|||||||
return [32]byte{}, nil, *failure
|
return [32]byte{}, nil, *failure
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkTimeout checks whether the payment has reached its timeout.
|
// checkContext checks whether the payment context has been canceled.
|
||||||
func (p *paymentLifecycle) checkTimeout() error {
|
// Cancellation occurs manually or if the context times out.
|
||||||
|
func (p *paymentLifecycle) checkContext(ctx context.Context) error {
|
||||||
select {
|
select {
|
||||||
case <-p.timeoutChan:
|
case <-ctx.Done():
|
||||||
log.Warnf("payment attempt not completed before timeout")
|
// If the context was canceled, we'll mark the payment as
|
||||||
|
// failed. There are two cases to distinguish here: Either a
|
||||||
|
// user-provided timeout was reached, or the context was
|
||||||
|
// canceled, either to a manual cancellation or due to an
|
||||||
|
// unknown error.
|
||||||
|
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||||
|
log.Warnf("Payment attempt not completed before "+
|
||||||
|
"timeout, id=%s", p.identifier.String())
|
||||||
|
} else {
|
||||||
|
log.Warnf("Payment attempt context canceled, id=%s",
|
||||||
|
p.identifier.String())
|
||||||
|
}
|
||||||
|
|
||||||
// By marking the payment failed, depending on whether it has
|
// By marking the payment failed, depending on whether it has
|
||||||
// inflight HTLCs or not, its status will now either be
|
// inflight HTLCs or not, its status will now either be
|
||||||
// `StatusInflight` or `StatusFailed`. In either case, no more
|
// `StatusInflight` or `StatusFailed`. In either case, no more
|
||||||
// HTLCs will be attempted.
|
// HTLCs will be attempted.
|
||||||
err := p.router.cfg.Control.FailPayment(
|
reason := channeldb.FailureReasonTimeout
|
||||||
p.identifier, channeldb.FailureReasonTimeout,
|
err := p.router.cfg.Control.FailPayment(p.identifier, reason)
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("FailPayment got %w", err)
|
return fmt.Errorf("FailPayment got %w", err)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -88,7 +89,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) {
|
|||||||
// Create a test payment lifecycle with no fee limit and no timeout.
|
// Create a test payment lifecycle with no fee limit and no timeout.
|
||||||
p := newPaymentLifecycle(
|
p := newPaymentLifecycle(
|
||||||
rt, noFeeLimit, paymentHash, mockPaymentSession,
|
rt, noFeeLimit, paymentHash, mockPaymentSession,
|
||||||
mockShardTracker, 0, 0,
|
mockShardTracker, 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create a mock payment which is returned from mockControlTower.
|
// Create a mock payment which is returned from mockControlTower.
|
||||||
@ -151,9 +152,9 @@ type resumePaymentResult struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error
|
// sendPaymentAndAssertError calls `resumePayment` and asserts that an error is
|
||||||
// is returned.
|
// returned.
|
||||||
func sendPaymentAndAssertFailed(t *testing.T,
|
func sendPaymentAndAssertError(t *testing.T, ctx context.Context,
|
||||||
p *paymentLifecycle, errExpected error) {
|
p *paymentLifecycle, errExpected error) {
|
||||||
|
|
||||||
resultChan := make(chan *resumePaymentResult, 1)
|
resultChan := make(chan *resumePaymentResult, 1)
|
||||||
@ -161,7 +162,7 @@ func sendPaymentAndAssertFailed(t *testing.T,
|
|||||||
// We now make a call to `resumePayment` and expect it to return the
|
// We now make a call to `resumePayment` and expect it to return the
|
||||||
// error.
|
// error.
|
||||||
go func() {
|
go func() {
|
||||||
preimage, _, err := p.resumePayment()
|
preimage, _, err := p.resumePayment(ctx)
|
||||||
resultChan <- &resumePaymentResult{
|
resultChan <- &resumePaymentResult{
|
||||||
preimage: preimage,
|
preimage: preimage,
|
||||||
err: err,
|
err: err,
|
||||||
@ -189,7 +190,7 @@ func sendPaymentAndAssertSucceeded(t *testing.T,
|
|||||||
// We now make a call to `resumePayment` and expect it to return the
|
// We now make a call to `resumePayment` and expect it to return the
|
||||||
// preimage.
|
// preimage.
|
||||||
go func() {
|
go func() {
|
||||||
preimage, _, err := p.resumePayment()
|
preimage, _, err := p.resumePayment(context.Background())
|
||||||
resultChan <- &resumePaymentResult{
|
resultChan <- &resumePaymentResult{
|
||||||
preimage: preimage,
|
preimage: preimage,
|
||||||
err: err,
|
err: err,
|
||||||
@ -278,6 +279,10 @@ func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo {
|
|||||||
func TestCheckTimeoutTimedOut(t *testing.T) {
|
func TestCheckTimeoutTimedOut(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
deadline := time.Now().Add(time.Nanosecond)
|
||||||
|
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
p := createTestPaymentLifecycle()
|
p := createTestPaymentLifecycle()
|
||||||
|
|
||||||
// Mock the control tower's `FailPayment` method.
|
// Mock the control tower's `FailPayment` method.
|
||||||
@ -288,14 +293,11 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
|
|||||||
// Mount the mocked control tower.
|
// Mount the mocked control tower.
|
||||||
p.router.cfg.Control = ct
|
p.router.cfg.Control = ct
|
||||||
|
|
||||||
// Make the timeout happens instantly.
|
|
||||||
p.timeoutChan = time.After(1 * time.Nanosecond)
|
|
||||||
|
|
||||||
// Sleep one millisecond to make sure it timed out.
|
// Sleep one millisecond to make sure it timed out.
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
|
||||||
// Call the function and expect no error.
|
// Call the function and expect no error.
|
||||||
err := p.checkTimeout()
|
err := p.checkContext(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Assert that `FailPayment` is called as expected.
|
// Assert that `FailPayment` is called as expected.
|
||||||
@ -313,13 +315,15 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
|
|||||||
p.router.cfg.Control = ct
|
p.router.cfg.Control = ct
|
||||||
|
|
||||||
// Make the timeout happens instantly.
|
// Make the timeout happens instantly.
|
||||||
p.timeoutChan = time.After(1 * time.Nanosecond)
|
deadline = time.Now().Add(time.Nanosecond)
|
||||||
|
ctx, cancel = context.WithDeadline(context.Background(), deadline)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// Sleep one millisecond to make sure it timed out.
|
// Sleep one millisecond to make sure it timed out.
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
|
||||||
// Call the function and expect an error.
|
// Call the function and expect an error.
|
||||||
err = p.checkTimeout()
|
err = p.checkContext(ctx)
|
||||||
require.ErrorIs(t, err, errDummy)
|
require.ErrorIs(t, err, errDummy)
|
||||||
|
|
||||||
// Assert that `FailPayment` is called as expected.
|
// Assert that `FailPayment` is called as expected.
|
||||||
@ -331,10 +335,13 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
|
|||||||
func TestCheckTimeoutOnRouterQuit(t *testing.T) {
|
func TestCheckTimeoutOnRouterQuit(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
p := createTestPaymentLifecycle()
|
p := createTestPaymentLifecycle()
|
||||||
|
|
||||||
close(p.router.quit)
|
close(p.router.quit)
|
||||||
err := p.checkTimeout()
|
err := p.checkContext(ctx)
|
||||||
require.ErrorIs(t, err, ErrRouterShuttingDown)
|
require.ErrorIs(t, err, ErrRouterShuttingDown)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -627,7 +634,7 @@ func TestResumePaymentFailOnFetchPayment(t *testing.T) {
|
|||||||
m.control.On("FetchPayment", p.identifier).Return(nil, errDummy)
|
m.control.On("FetchPayment", p.identifier).Return(nil, errDummy)
|
||||||
|
|
||||||
// Send the payment and assert it failed.
|
// Send the payment and assert it failed.
|
||||||
sendPaymentAndAssertFailed(t, p, errDummy)
|
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
|
||||||
|
|
||||||
// Expected collectResultAsync to not be called.
|
// Expected collectResultAsync to not be called.
|
||||||
require.Zero(t, m.collectResultsCount)
|
require.Zero(t, m.collectResultsCount)
|
||||||
@ -656,14 +663,15 @@ func TestResumePaymentFailOnTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
m.payment.On("GetState").Return(ps).Once()
|
m.payment.On("GetState").Return(ps).Once()
|
||||||
|
|
||||||
// NOTE: GetStatus is only used to populate the logs which is
|
// NOTE: GetStatus is only used to populate the logs which is not
|
||||||
// not critical so we loosen the checks on how many times it's
|
// critical, so we loosen the checks on how many times it's been called.
|
||||||
// been called.
|
|
||||||
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
||||||
|
|
||||||
// 3. make the timeout happens instantly and sleep one millisecond to
|
// 3. make the timeout happens instantly and sleep one millisecond to
|
||||||
// make sure it timed out.
|
// make sure it timed out.
|
||||||
p.timeoutChan = time.After(1 * time.Nanosecond)
|
deadline := time.Now().Add(time.Nanosecond)
|
||||||
|
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||||
|
defer cancel()
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
|
||||||
// 4. the payment should be failed with reason timeout.
|
// 4. the payment should be failed with reason timeout.
|
||||||
@ -683,7 +691,7 @@ func TestResumePaymentFailOnTimeout(t *testing.T) {
|
|||||||
m.payment.On("TerminalInfo").Return(nil, &reason)
|
m.payment.On("TerminalInfo").Return(nil, &reason)
|
||||||
|
|
||||||
// Send the payment and assert it failed with the timeout reason.
|
// Send the payment and assert it failed with the timeout reason.
|
||||||
sendPaymentAndAssertFailed(t, p, reason)
|
sendPaymentAndAssertError(t, ctx, p, reason)
|
||||||
|
|
||||||
// Expected collectResultAsync to not be called.
|
// Expected collectResultAsync to not be called.
|
||||||
require.Zero(t, m.collectResultsCount)
|
require.Zero(t, m.collectResultsCount)
|
||||||
@ -721,7 +729,65 @@ func TestResumePaymentFailOnTimeoutErr(t *testing.T) {
|
|||||||
close(p.router.quit)
|
close(p.router.quit)
|
||||||
|
|
||||||
// Send the payment and assert it failed when router is shutting down.
|
// Send the payment and assert it failed when router is shutting down.
|
||||||
sendPaymentAndAssertFailed(t, p, ErrRouterShuttingDown)
|
sendPaymentAndAssertError(
|
||||||
|
t, context.Background(), p, ErrRouterShuttingDown,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Expected collectResultAsync to not be called.
|
||||||
|
require.Zero(t, m.collectResultsCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResumePaymentFailContextCancel checks that the lifecycle fails when the
|
||||||
|
// context is canceled.
|
||||||
|
//
|
||||||
|
// NOTE: No parallel test because it overwrites global variables.
|
||||||
|
//
|
||||||
|
//nolint:paralleltest
|
||||||
|
func TestResumePaymentFailContextCancel(t *testing.T) {
|
||||||
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
||||||
|
p, m := setupTestPaymentLifecycle(t)
|
||||||
|
|
||||||
|
// Create the cancelable payment context.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
paymentAmt := lnwire.MilliSatoshi(10000)
|
||||||
|
|
||||||
|
// We now enter the payment lifecycle loop.
|
||||||
|
//
|
||||||
|
// 1. calls `FetchPayment` and return the payment.
|
||||||
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
||||||
|
|
||||||
|
// 2. calls `GetState` and return the state.
|
||||||
|
ps := &channeldb.MPPaymentState{
|
||||||
|
RemainingAmt: paymentAmt,
|
||||||
|
}
|
||||||
|
m.payment.On("GetState").Return(ps).Once()
|
||||||
|
|
||||||
|
// NOTE: GetStatus is only used to populate the logs which is not
|
||||||
|
// critical, so we loosen the checks on how many times it's been called.
|
||||||
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
||||||
|
|
||||||
|
// 3. Cancel the context and skip the FailPayment error to trigger the
|
||||||
|
// context cancellation of the payment.
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
m.control.On(
|
||||||
|
"FailPayment", p.identifier, channeldb.FailureReasonTimeout,
|
||||||
|
).Return(nil).Once()
|
||||||
|
|
||||||
|
// 5. decideNextStep now returns stepExit.
|
||||||
|
m.payment.On("AllowMoreAttempts").Return(false, nil).Once().
|
||||||
|
On("NeedWaitAttempts").Return(false, nil).Once()
|
||||||
|
|
||||||
|
// 6. Control tower deletes failed attempts.
|
||||||
|
m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once()
|
||||||
|
|
||||||
|
// 7. We will observe FailureReasonError if the context was cancelled.
|
||||||
|
reason := channeldb.FailureReasonError
|
||||||
|
m.payment.On("TerminalInfo").Return(nil, &reason)
|
||||||
|
|
||||||
|
// Send the payment and assert it failed with the timeout reason.
|
||||||
|
sendPaymentAndAssertError(t, ctx, p, reason)
|
||||||
|
|
||||||
// Expected collectResultAsync to not be called.
|
// Expected collectResultAsync to not be called.
|
||||||
require.Zero(t, m.collectResultsCount)
|
require.Zero(t, m.collectResultsCount)
|
||||||
@ -759,7 +825,7 @@ func TestResumePaymentFailOnStepErr(t *testing.T) {
|
|||||||
m.payment.On("AllowMoreAttempts").Return(false, errDummy).Once()
|
m.payment.On("AllowMoreAttempts").Return(false, errDummy).Once()
|
||||||
|
|
||||||
// Send the payment and assert it failed.
|
// Send the payment and assert it failed.
|
||||||
sendPaymentAndAssertFailed(t, p, errDummy)
|
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
|
||||||
|
|
||||||
// Expected collectResultAsync to not be called.
|
// Expected collectResultAsync to not be called.
|
||||||
require.Zero(t, m.collectResultsCount)
|
require.Zero(t, m.collectResultsCount)
|
||||||
@ -803,7 +869,7 @@ func TestResumePaymentFailOnRequestRouteErr(t *testing.T) {
|
|||||||
).Return(nil, errDummy).Once()
|
).Return(nil, errDummy).Once()
|
||||||
|
|
||||||
// Send the payment and assert it failed.
|
// Send the payment and assert it failed.
|
||||||
sendPaymentAndAssertFailed(t, p, errDummy)
|
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
|
||||||
|
|
||||||
// Expected collectResultAsync to not be called.
|
// Expected collectResultAsync to not be called.
|
||||||
require.Zero(t, m.collectResultsCount)
|
require.Zero(t, m.collectResultsCount)
|
||||||
@ -863,7 +929,7 @@ func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) {
|
|||||||
).Return(nil, errDummy).Once()
|
).Return(nil, errDummy).Once()
|
||||||
|
|
||||||
// Send the payment and assert it failed.
|
// Send the payment and assert it failed.
|
||||||
sendPaymentAndAssertFailed(t, p, errDummy)
|
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
|
||||||
|
|
||||||
// Expected collectResultAsync to not be called.
|
// Expected collectResultAsync to not be called.
|
||||||
require.Zero(t, m.collectResultsCount)
|
require.Zero(t, m.collectResultsCount)
|
||||||
@ -955,7 +1021,7 @@ func TestResumePaymentFailOnSendAttemptErr(t *testing.T) {
|
|||||||
).Return(nil, errDummy).Once()
|
).Return(nil, errDummy).Once()
|
||||||
|
|
||||||
// Send the payment and assert it failed.
|
// Send the payment and assert it failed.
|
||||||
sendPaymentAndAssertFailed(t, p, errDummy)
|
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
|
||||||
|
|
||||||
// Expected collectResultAsync to not be called.
|
// Expected collectResultAsync to not be called.
|
||||||
require.Zero(t, m.collectResultsCount)
|
require.Zero(t, m.collectResultsCount)
|
||||||
|
@ -2,6 +2,7 @@ package routing
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -715,13 +716,15 @@ func (r *ChannelRouter) Start() error {
|
|||||||
// result for the in-flight attempt is received.
|
// result for the in-flight attempt is received.
|
||||||
paySession := r.cfg.SessionSource.NewPaymentSessionEmpty()
|
paySession := r.cfg.SessionSource.NewPaymentSessionEmpty()
|
||||||
|
|
||||||
// We pass in a zero timeout value, to indicate we
|
// We pass in a non-timeout context, to indicate we
|
||||||
// don't need it to timeout. It will stop immediately
|
// don't need it to timeout. It will stop immediately
|
||||||
// after the existing attempt has finished anyway. We
|
// after the existing attempt has finished anyway. We
|
||||||
// also set a zero fee limit, as no more routes should
|
// also set a zero fee limit, as no more routes should
|
||||||
// be tried.
|
// be tried.
|
||||||
|
noTimeout := time.Duration(0)
|
||||||
_, _, err := r.sendPayment(
|
_, _, err := r.sendPayment(
|
||||||
0, payment.Info.PaymentIdentifier, 0,
|
context.Background(), 0,
|
||||||
|
payment.Info.PaymentIdentifier, noTimeout,
|
||||||
paySession, shardTracker,
|
paySession, shardTracker,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -2406,18 +2409,16 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
|
|||||||
log.Tracef("Dispatching SendPayment for lightning payment: %v",
|
log.Tracef("Dispatching SendPayment for lightning payment: %v",
|
||||||
spewPayment(payment))
|
spewPayment(payment))
|
||||||
|
|
||||||
// Since this is the first time this payment is being made, we pass nil
|
|
||||||
// for the existing attempt.
|
|
||||||
return r.sendPayment(
|
return r.sendPayment(
|
||||||
payment.FeeLimit, payment.Identifier(),
|
context.Background(), payment.FeeLimit, payment.Identifier(),
|
||||||
payment.PayAttemptTimeout, paySession, shardTracker,
|
payment.PayAttemptTimeout, paySession, shardTracker,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendPaymentAsync is the non-blocking version of SendPayment. The payment
|
// SendPaymentAsync is the non-blocking version of SendPayment. The payment
|
||||||
// result needs to be retrieved via the control tower.
|
// result needs to be retrieved via the control tower.
|
||||||
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
|
func (r *ChannelRouter) SendPaymentAsync(ctx context.Context,
|
||||||
ps PaymentSession, st shards.ShardTracker) {
|
payment *LightningPayment, ps PaymentSession, st shards.ShardTracker) {
|
||||||
|
|
||||||
// Since this is the first time this payment is being made, we pass nil
|
// Since this is the first time this payment is being made, we pass nil
|
||||||
// for the existing attempt.
|
// for the existing attempt.
|
||||||
@ -2429,7 +2430,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
|
|||||||
spewPayment(payment))
|
spewPayment(payment))
|
||||||
|
|
||||||
_, _, err := r.sendPayment(
|
_, _, err := r.sendPayment(
|
||||||
payment.FeeLimit, payment.Identifier(),
|
ctx, payment.FeeLimit, payment.Identifier(),
|
||||||
payment.PayAttemptTimeout, ps, st,
|
payment.PayAttemptTimeout, ps, st,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -2604,9 +2605,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
|
|||||||
// - nil payment session (since we already have a route).
|
// - nil payment session (since we already have a route).
|
||||||
// - no payment timeout.
|
// - no payment timeout.
|
||||||
// - no current block height.
|
// - no current block height.
|
||||||
p := newPaymentLifecycle(
|
p := newPaymentLifecycle(r, 0, paymentIdentifier, nil, shardTracker, 0)
|
||||||
r, 0, paymentIdentifier, nil, shardTracker, 0, 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
// We found a route to try, create a new HTLC attempt to try.
|
// We found a route to try, create a new HTLC attempt to try.
|
||||||
//
|
//
|
||||||
@ -2699,11 +2698,23 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
|
|||||||
// carry out its execution. After restarts, it is safe, and assumed, that the
|
// carry out its execution. After restarts, it is safe, and assumed, that the
|
||||||
// router will call this method for every payment still in-flight according to
|
// router will call this method for every payment still in-flight according to
|
||||||
// the ControlTower.
|
// the ControlTower.
|
||||||
func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
|
func (r *ChannelRouter) sendPayment(ctx context.Context,
|
||||||
identifier lntypes.Hash, timeout time.Duration,
|
feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash,
|
||||||
paySession PaymentSession,
|
paymentAttemptTimeout time.Duration, paySession PaymentSession,
|
||||||
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {
|
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {
|
||||||
|
|
||||||
|
// If the user provides a timeout, we will additionally wrap the context
|
||||||
|
// in a deadline.
|
||||||
|
cancel := func() {}
|
||||||
|
if paymentAttemptTimeout > 0 {
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, paymentAttemptTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since resumePayment is a blocking call, we'll cancel this
|
||||||
|
// context if the payment completes before the optional
|
||||||
|
// deadline.
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// We'll also fetch the current block height, so we can properly
|
// We'll also fetch the current block height, so we can properly
|
||||||
// calculate the required HTLC time locks within the route.
|
// calculate the required HTLC time locks within the route.
|
||||||
_, currentHeight, err := r.cfg.Chain.GetBestBlock()
|
_, currentHeight, err := r.cfg.Chain.GetBestBlock()
|
||||||
@ -2714,11 +2725,11 @@ func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
|
|||||||
// Now set up a paymentLifecycle struct with these params, such that we
|
// Now set up a paymentLifecycle struct with these params, such that we
|
||||||
// can resume the payment from the current state.
|
// can resume the payment from the current state.
|
||||||
p := newPaymentLifecycle(
|
p := newPaymentLifecycle(
|
||||||
r, feeLimit, identifier, paySession,
|
r, feeLimit, identifier, paySession, shardTracker,
|
||||||
shardTracker, timeout, currentHeight,
|
currentHeight,
|
||||||
)
|
)
|
||||||
|
|
||||||
return p.resumePayment()
|
return p.resumePayment(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractChannelUpdate examines the error and extracts the channel update.
|
// extractChannelUpdate examines the error and extracts the channel update.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user