From 8d49dfb07ea00352effc9ed99ce995db270b78e6 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 10 Jun 2022 01:34:22 +0800 Subject: [PATCH] routing: stop tracking `totalAmount` in `paymentLifecycle` This commit removes the field `totalAmount` from `paymentLifecycle` and only reads it from the channeldb payment. --- routing/mock_test.go | 1 + routing/payment_lifecycle.go | 8 ++++---- routing/payment_lifecycle_test.go | 13 +++++++++---- routing/router.go | 16 +++++++--------- routing/router_test.go | 16 ++++++++++++---- 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/routing/mock_test.go b/routing/mock_test.go index 6d6dbf9b8..afd0126dd 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -758,6 +758,7 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( // Make a copy of the payment here to avoid data race. p := args.Get(0).(*channeldb.MPPayment) payment := &channeldb.MPPayment{ + Info: p.Info, FailureReason: p.FailureReason, } payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs)) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 6a7198ba2..32f0d811d 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -23,7 +23,6 @@ var errShardHandlerExiting = fmt.Errorf("shard handler exiting") // needed to resume if from any point. type paymentLifecycle struct { router *ChannelRouter - totalAmount lnwire.MilliSatoshi feeLimit lnwire.MilliSatoshi identifier lntypes.Hash paySession PaymentSession @@ -83,9 +82,10 @@ func (p *paymentLifecycle) fetchPaymentState() (*channeldb.MPPayment, sentAmt, fees := payment.SentAmt() // Sanity check we haven't sent a value larger than the payment amount. - if sentAmt > p.totalAmount { + totalAmt := payment.Info.Value + if sentAmt > totalAmt { return nil, nil, fmt.Errorf("amount sent %v exceeds "+ - "total amount %v", sentAmt, p.totalAmount) + "total amount %v", sentAmt, totalAmt) } // We'll subtract the used fee from our fee budget, but allow the fees @@ -109,7 +109,7 @@ func (p *paymentLifecycle) fetchPaymentState() (*channeldb.MPPayment, // Update the payment state. state := &paymentState{ numShardsInFlight: len(payment.InFlightHTLCs()), - remainingAmt: p.totalAmount - sentAmt, + remainingAmt: totalAmt - sentAmt, remainingFees: feeBudget, terminate: terminate, } diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index a220df8d0..102aa4c01 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -1052,10 +1052,9 @@ func TestUpdatePaymentState(t *testing.T) { ct := &mockControlTower{} rt := &ChannelRouter{cfg: &Config{Control: ct}} pl := &paymentLifecycle{ - router: rt, - identifier: paymentHash, - totalAmount: lnwire.MilliSatoshi(tc.totalAmt), - feeLimit: lnwire.MilliSatoshi(tc.feeLimit), + router: rt, + identifier: paymentHash, + feeLimit: lnwire.MilliSatoshi(tc.feeLimit), } if tc.payment == nil { @@ -1066,6 +1065,12 @@ func TestUpdatePaymentState(t *testing.T) { nil, dummyErr, ) } else { + // Attach the payment info. + info := &channeldb.PaymentCreationInfo{ + Value: lnwire.MilliSatoshi(tc.totalAmt), + } + tc.payment.Info = info + // Otherwise we will return the payment. ct.On("FetchPayment", paymentHash).Return( tc.payment, nil, diff --git a/routing/router.go b/routing/router.go index affad930b..01d4b5022 100644 --- a/routing/router.go +++ b/routing/router.go @@ -667,9 +667,8 @@ func (r *ChannelRouter) Start() error { // also set a zero fee limit, as no more routes should // be tried. _, _, err := r.sendPayment( - payment.Info.Value, 0, - payment.Info.PaymentIdentifier, 0, paySession, - shardTracker, + 0, payment.Info.PaymentIdentifier, 0, + paySession, shardTracker, ) if err != nil { log.Errorf("Resuming payment %v failed: %v.", @@ -2048,7 +2047,7 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte, // Since this is the first time this payment is being made, we pass nil // for the existing attempt. return r.sendPayment( - payment.Amount, payment.FeeLimit, payment.Identifier(), + payment.FeeLimit, payment.Identifier(), payment.PayAttemptTimeout, paySession, shardTracker, ) } @@ -2071,7 +2070,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error { spewPayment(payment)) _, _, err := r.sendPayment( - payment.Amount, payment.FeeLimit, payment.Identifier(), + payment.FeeLimit, payment.Identifier(), payment.PayAttemptTimeout, paySession, shardTracker, ) if err != nil { @@ -2335,9 +2334,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // carry out its execution. After restarts it is safe, and assumed, that the // router will call this method for every payment still in-flight according to // the ControlTower. -func (r *ChannelRouter) sendPayment( - totalAmt, feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash, - timeout time.Duration, paySession PaymentSession, +func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi, + identifier lntypes.Hash, timeout time.Duration, + paySession PaymentSession, shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) { // We'll also fetch the current block height so we can properly @@ -2351,7 +2350,6 @@ func (r *ChannelRouter) sendPayment( // can resume the payment from the current state. p := &paymentLifecycle{ router: r, - totalAmount: totalAmt, feeLimit: feeLimit, identifier: identifier, paySession: paySession, diff --git a/routing/router_test.go b/routing/router_test.go index 8387c284d..ea76140a2 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3421,7 +3421,9 @@ func TestSendMPPaymentSucceed(t *testing.T) { // The following mocked methods are called inside resumePayment. Note // that the payment object below will determine the state of the // paymentLifecycle. - payment := &channeldb.MPPayment{} + payment := &channeldb.MPPayment{ + Info: &channeldb.PaymentCreationInfo{Value: paymentAmt}, + } controlTower.On("FetchPayment", identifier).Return(payment, nil) // Create a route that can send 1/4 of the total amount. This value @@ -3588,7 +3590,9 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { // The following mocked methods are called inside resumePayment. Note // that the payment object below will determine the state of the // paymentLifecycle. - payment := &channeldb.MPPayment{} + payment := &channeldb.MPPayment{ + Info: &channeldb.PaymentCreationInfo{Value: paymentAmt}, + } controlTower.On("FetchPayment", identifier).Return(payment, nil) // Create a route that can send 1/4 of the total amount. This value @@ -3800,7 +3804,9 @@ func TestSendMPPaymentFailed(t *testing.T) { // The following mocked methods are called inside resumePayment. Note // that the payment object below will determine the state of the // paymentLifecycle. - payment := &channeldb.MPPayment{} + payment := &channeldb.MPPayment{ + Info: &channeldb.PaymentCreationInfo{Value: paymentAmt}, + } controlTower.On("FetchPayment", identifier).Return(payment, nil) // Create a route that can send 1/4 of the total amount. This value @@ -4004,7 +4010,9 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { // The following mocked methods are called inside resumePayment. Note // that the payment object below will determine the state of the // paymentLifecycle. - payment := &channeldb.MPPayment{} + payment := &channeldb.MPPayment{ + Info: &channeldb.PaymentCreationInfo{Value: paymentAmt}, + } controlTower.On("FetchPayment", identifier).Return(payment, nil) // Create a route that can send 1/4 of the total amount. This value