mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-11-24 13:06:43 +01:00
routing: stop tracking totalAmount in paymentLifecycle
This commit removes the field `totalAmount` from `paymentLifecycle` and only reads it from the channeldb payment.
This commit is contained in:
@@ -758,6 +758,7 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
|
|||||||
// Make a copy of the payment here to avoid data race.
|
// Make a copy of the payment here to avoid data race.
|
||||||
p := args.Get(0).(*channeldb.MPPayment)
|
p := args.Get(0).(*channeldb.MPPayment)
|
||||||
payment := &channeldb.MPPayment{
|
payment := &channeldb.MPPayment{
|
||||||
|
Info: p.Info,
|
||||||
FailureReason: p.FailureReason,
|
FailureReason: p.FailureReason,
|
||||||
}
|
}
|
||||||
payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs))
|
payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs))
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ var errShardHandlerExiting = fmt.Errorf("shard handler exiting")
|
|||||||
// needed to resume if from any point.
|
// needed to resume if from any point.
|
||||||
type paymentLifecycle struct {
|
type paymentLifecycle struct {
|
||||||
router *ChannelRouter
|
router *ChannelRouter
|
||||||
totalAmount lnwire.MilliSatoshi
|
|
||||||
feeLimit lnwire.MilliSatoshi
|
feeLimit lnwire.MilliSatoshi
|
||||||
identifier lntypes.Hash
|
identifier lntypes.Hash
|
||||||
paySession PaymentSession
|
paySession PaymentSession
|
||||||
@@ -83,9 +82,10 @@ func (p *paymentLifecycle) fetchPaymentState() (*channeldb.MPPayment,
|
|||||||
sentAmt, fees := payment.SentAmt()
|
sentAmt, fees := payment.SentAmt()
|
||||||
|
|
||||||
// Sanity check we haven't sent a value larger than the payment amount.
|
// Sanity check we haven't sent a value larger than the payment amount.
|
||||||
if sentAmt > p.totalAmount {
|
totalAmt := payment.Info.Value
|
||||||
|
if sentAmt > totalAmt {
|
||||||
return nil, nil, fmt.Errorf("amount sent %v exceeds "+
|
return nil, nil, fmt.Errorf("amount sent %v exceeds "+
|
||||||
"total amount %v", sentAmt, p.totalAmount)
|
"total amount %v", sentAmt, totalAmt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We'll subtract the used fee from our fee budget, but allow the fees
|
// We'll subtract the used fee from our fee budget, but allow the fees
|
||||||
@@ -109,7 +109,7 @@ func (p *paymentLifecycle) fetchPaymentState() (*channeldb.MPPayment,
|
|||||||
// Update the payment state.
|
// Update the payment state.
|
||||||
state := &paymentState{
|
state := &paymentState{
|
||||||
numShardsInFlight: len(payment.InFlightHTLCs()),
|
numShardsInFlight: len(payment.InFlightHTLCs()),
|
||||||
remainingAmt: p.totalAmount - sentAmt,
|
remainingAmt: totalAmt - sentAmt,
|
||||||
remainingFees: feeBudget,
|
remainingFees: feeBudget,
|
||||||
terminate: terminate,
|
terminate: terminate,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1052,10 +1052,9 @@ func TestUpdatePaymentState(t *testing.T) {
|
|||||||
ct := &mockControlTower{}
|
ct := &mockControlTower{}
|
||||||
rt := &ChannelRouter{cfg: &Config{Control: ct}}
|
rt := &ChannelRouter{cfg: &Config{Control: ct}}
|
||||||
pl := &paymentLifecycle{
|
pl := &paymentLifecycle{
|
||||||
router: rt,
|
router: rt,
|
||||||
identifier: paymentHash,
|
identifier: paymentHash,
|
||||||
totalAmount: lnwire.MilliSatoshi(tc.totalAmt),
|
feeLimit: lnwire.MilliSatoshi(tc.feeLimit),
|
||||||
feeLimit: lnwire.MilliSatoshi(tc.feeLimit),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if tc.payment == nil {
|
if tc.payment == nil {
|
||||||
@@ -1066,6 +1065,12 @@ func TestUpdatePaymentState(t *testing.T) {
|
|||||||
nil, dummyErr,
|
nil, dummyErr,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
// Attach the payment info.
|
||||||
|
info := &channeldb.PaymentCreationInfo{
|
||||||
|
Value: lnwire.MilliSatoshi(tc.totalAmt),
|
||||||
|
}
|
||||||
|
tc.payment.Info = info
|
||||||
|
|
||||||
// Otherwise we will return the payment.
|
// Otherwise we will return the payment.
|
||||||
ct.On("FetchPayment", paymentHash).Return(
|
ct.On("FetchPayment", paymentHash).Return(
|
||||||
tc.payment, nil,
|
tc.payment, nil,
|
||||||
|
|||||||
@@ -667,9 +667,8 @@ func (r *ChannelRouter) Start() error {
|
|||||||
// also set a zero fee limit, as no more routes should
|
// also set a zero fee limit, as no more routes should
|
||||||
// be tried.
|
// be tried.
|
||||||
_, _, err := r.sendPayment(
|
_, _, err := r.sendPayment(
|
||||||
payment.Info.Value, 0,
|
0, payment.Info.PaymentIdentifier, 0,
|
||||||
payment.Info.PaymentIdentifier, 0, paySession,
|
paySession, shardTracker,
|
||||||
shardTracker,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Resuming payment %v failed: %v.",
|
log.Errorf("Resuming payment %v failed: %v.",
|
||||||
@@ -2048,7 +2047,7 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
|
|||||||
// Since this is the first time this payment is being made, we pass nil
|
// Since this is the first time this payment is being made, we pass nil
|
||||||
// for the existing attempt.
|
// for the existing attempt.
|
||||||
return r.sendPayment(
|
return r.sendPayment(
|
||||||
payment.Amount, payment.FeeLimit, payment.Identifier(),
|
payment.FeeLimit, payment.Identifier(),
|
||||||
payment.PayAttemptTimeout, paySession, shardTracker,
|
payment.PayAttemptTimeout, paySession, shardTracker,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -2071,7 +2070,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
|
|||||||
spewPayment(payment))
|
spewPayment(payment))
|
||||||
|
|
||||||
_, _, err := r.sendPayment(
|
_, _, err := r.sendPayment(
|
||||||
payment.Amount, payment.FeeLimit, payment.Identifier(),
|
payment.FeeLimit, payment.Identifier(),
|
||||||
payment.PayAttemptTimeout, paySession, shardTracker,
|
payment.PayAttemptTimeout, paySession, shardTracker,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -2335,9 +2334,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
|
|||||||
// carry out its execution. After restarts it is safe, and assumed, that the
|
// carry out its execution. After restarts it is safe, and assumed, that the
|
||||||
// router will call this method for every payment still in-flight according to
|
// router will call this method for every payment still in-flight according to
|
||||||
// the ControlTower.
|
// the ControlTower.
|
||||||
func (r *ChannelRouter) sendPayment(
|
func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
|
||||||
totalAmt, feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash,
|
identifier lntypes.Hash, timeout time.Duration,
|
||||||
timeout time.Duration, paySession PaymentSession,
|
paySession PaymentSession,
|
||||||
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {
|
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {
|
||||||
|
|
||||||
// We'll also fetch the current block height so we can properly
|
// We'll also fetch the current block height so we can properly
|
||||||
@@ -2351,7 +2350,6 @@ func (r *ChannelRouter) sendPayment(
|
|||||||
// can resume the payment from the current state.
|
// can resume the payment from the current state.
|
||||||
p := &paymentLifecycle{
|
p := &paymentLifecycle{
|
||||||
router: r,
|
router: r,
|
||||||
totalAmount: totalAmt,
|
|
||||||
feeLimit: feeLimit,
|
feeLimit: feeLimit,
|
||||||
identifier: identifier,
|
identifier: identifier,
|
||||||
paySession: paySession,
|
paySession: paySession,
|
||||||
|
|||||||
@@ -3421,7 +3421,9 @@ func TestSendMPPaymentSucceed(t *testing.T) {
|
|||||||
// The following mocked methods are called inside resumePayment. Note
|
// The following mocked methods are called inside resumePayment. Note
|
||||||
// that the payment object below will determine the state of the
|
// that the payment object below will determine the state of the
|
||||||
// paymentLifecycle.
|
// paymentLifecycle.
|
||||||
payment := &channeldb.MPPayment{}
|
payment := &channeldb.MPPayment{
|
||||||
|
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||||
|
}
|
||||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||||
|
|
||||||
// Create a route that can send 1/4 of the total amount. This value
|
// Create a route that can send 1/4 of the total amount. This value
|
||||||
@@ -3588,7 +3590,9 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
|||||||
// The following mocked methods are called inside resumePayment. Note
|
// The following mocked methods are called inside resumePayment. Note
|
||||||
// that the payment object below will determine the state of the
|
// that the payment object below will determine the state of the
|
||||||
// paymentLifecycle.
|
// paymentLifecycle.
|
||||||
payment := &channeldb.MPPayment{}
|
payment := &channeldb.MPPayment{
|
||||||
|
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||||
|
}
|
||||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||||
|
|
||||||
// Create a route that can send 1/4 of the total amount. This value
|
// Create a route that can send 1/4 of the total amount. This value
|
||||||
@@ -3800,7 +3804,9 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
|||||||
// The following mocked methods are called inside resumePayment. Note
|
// The following mocked methods are called inside resumePayment. Note
|
||||||
// that the payment object below will determine the state of the
|
// that the payment object below will determine the state of the
|
||||||
// paymentLifecycle.
|
// paymentLifecycle.
|
||||||
payment := &channeldb.MPPayment{}
|
payment := &channeldb.MPPayment{
|
||||||
|
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||||
|
}
|
||||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||||
|
|
||||||
// Create a route that can send 1/4 of the total amount. This value
|
// Create a route that can send 1/4 of the total amount. This value
|
||||||
@@ -4004,7 +4010,9 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
|||||||
// The following mocked methods are called inside resumePayment. Note
|
// The following mocked methods are called inside resumePayment. Note
|
||||||
// that the payment object below will determine the state of the
|
// that the payment object below will determine the state of the
|
||||||
// paymentLifecycle.
|
// paymentLifecycle.
|
||||||
payment := &channeldb.MPPayment{}
|
payment := &channeldb.MPPayment{
|
||||||
|
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||||
|
}
|
||||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||||
|
|
||||||
// Create a route that can send 1/4 of the total amount. This value
|
// Create a route that can send 1/4 of the total amount. This value
|
||||||
|
|||||||
Reference in New Issue
Block a user