routing: stop tracking totalAmount in paymentLifecycle

This commit removes the field `totalAmount` from `paymentLifecycle` and
only reads it from the channeldb payment.
This commit is contained in:
yyforyongyu
2022-06-10 01:34:22 +08:00
parent e3bc4f4cc9
commit 8d49dfb07e
5 changed files with 33 additions and 21 deletions

View File

@ -758,6 +758,7 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
// Make a copy of the payment here to avoid data race.
p := args.Get(0).(*channeldb.MPPayment)
payment := &channeldb.MPPayment{
Info: p.Info,
FailureReason: p.FailureReason,
}
payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs))

View File

@ -23,7 +23,6 @@ var errShardHandlerExiting = fmt.Errorf("shard handler exiting")
// needed to resume if from any point.
type paymentLifecycle struct {
router *ChannelRouter
totalAmount lnwire.MilliSatoshi
feeLimit lnwire.MilliSatoshi
identifier lntypes.Hash
paySession PaymentSession
@ -83,9 +82,10 @@ func (p *paymentLifecycle) fetchPaymentState() (*channeldb.MPPayment,
sentAmt, fees := payment.SentAmt()
// Sanity check we haven't sent a value larger than the payment amount.
if sentAmt > p.totalAmount {
totalAmt := payment.Info.Value
if sentAmt > totalAmt {
return nil, nil, fmt.Errorf("amount sent %v exceeds "+
"total amount %v", sentAmt, p.totalAmount)
"total amount %v", sentAmt, totalAmt)
}
// We'll subtract the used fee from our fee budget, but allow the fees
@ -109,7 +109,7 @@ func (p *paymentLifecycle) fetchPaymentState() (*channeldb.MPPayment,
// Update the payment state.
state := &paymentState{
numShardsInFlight: len(payment.InFlightHTLCs()),
remainingAmt: p.totalAmount - sentAmt,
remainingAmt: totalAmt - sentAmt,
remainingFees: feeBudget,
terminate: terminate,
}

View File

@ -1052,10 +1052,9 @@ func TestUpdatePaymentState(t *testing.T) {
ct := &mockControlTower{}
rt := &ChannelRouter{cfg: &Config{Control: ct}}
pl := &paymentLifecycle{
router: rt,
identifier: paymentHash,
totalAmount: lnwire.MilliSatoshi(tc.totalAmt),
feeLimit: lnwire.MilliSatoshi(tc.feeLimit),
router: rt,
identifier: paymentHash,
feeLimit: lnwire.MilliSatoshi(tc.feeLimit),
}
if tc.payment == nil {
@ -1066,6 +1065,12 @@ func TestUpdatePaymentState(t *testing.T) {
nil, dummyErr,
)
} else {
// Attach the payment info.
info := &channeldb.PaymentCreationInfo{
Value: lnwire.MilliSatoshi(tc.totalAmt),
}
tc.payment.Info = info
// Otherwise we will return the payment.
ct.On("FetchPayment", paymentHash).Return(
tc.payment, nil,

View File

@ -667,9 +667,8 @@ func (r *ChannelRouter) Start() error {
// also set a zero fee limit, as no more routes should
// be tried.
_, _, err := r.sendPayment(
payment.Info.Value, 0,
payment.Info.PaymentIdentifier, 0, paySession,
shardTracker,
0, payment.Info.PaymentIdentifier, 0,
paySession, shardTracker,
)
if err != nil {
log.Errorf("Resuming payment %v failed: %v.",
@ -2048,7 +2047,7 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
return r.sendPayment(
payment.Amount, payment.FeeLimit, payment.Identifier(),
payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, paySession, shardTracker,
)
}
@ -2071,7 +2070,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
spewPayment(payment))
_, _, err := r.sendPayment(
payment.Amount, payment.FeeLimit, payment.Identifier(),
payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, paySession, shardTracker,
)
if err != nil {
@ -2335,9 +2334,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
// carry out its execution. After restarts it is safe, and assumed, that the
// router will call this method for every payment still in-flight according to
// the ControlTower.
func (r *ChannelRouter) sendPayment(
totalAmt, feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash,
timeout time.Duration, paySession PaymentSession,
func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
identifier lntypes.Hash, timeout time.Duration,
paySession PaymentSession,
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {
// We'll also fetch the current block height so we can properly
@ -2351,7 +2350,6 @@ func (r *ChannelRouter) sendPayment(
// can resume the payment from the current state.
p := &paymentLifecycle{
router: r,
totalAmount: totalAmt,
feeLimit: feeLimit,
identifier: identifier,
paySession: paySession,

View File

@ -3421,7 +3421,9 @@ func TestSendMPPaymentSucceed(t *testing.T) {
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{}
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
@ -3588,7 +3590,9 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{}
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
@ -3800,7 +3804,9 @@ func TestSendMPPaymentFailed(t *testing.T) {
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{}
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
@ -4004,7 +4010,9 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{}
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value