mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-11-29 15:36:34 +01:00
routing+channeldb: make MPPayment into an interface
This commit turns `MPPayment` into an interface inside `routing`. Having this interface gives us the benefit to write more granular unit tests inside payment lifecycle. As seen from the modified unit tests, several hacky ways of testing the `SendPayment` method is now replaced by a mock over `MPPayment`.
This commit is contained in:
committed by
Olaoluwa Osuntokun
parent
c412ab5ccb
commit
34d0e5d4c5
@@ -1120,15 +1120,16 @@ func TestSendPaymentErrorPathPruning(t *testing.T) {
|
||||
p, err := ctx.router.cfg.Control.FetchPayment(payHash)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 2, len(p.HTLCs), "expected two attempts")
|
||||
htlcs := p.GetHTLCs()
|
||||
require.Equal(t, 2, len(htlcs), "expected two attempts")
|
||||
|
||||
// We expect the first attempt to have failed with a
|
||||
// TemporaryChannelFailure, the second with UnknownNextPeer.
|
||||
msg := p.HTLCs[0].Failure.Message
|
||||
msg := htlcs[0].Failure.Message
|
||||
_, ok := msg.(*lnwire.FailTemporaryChannelFailure)
|
||||
require.True(t, ok, "unexpected fail message")
|
||||
|
||||
msg = p.HTLCs[1].Failure.Message
|
||||
msg = htlcs[1].Failure.Message
|
||||
_, ok = msg.(*lnwire.FailUnknownNextPeer)
|
||||
require.True(t, ok, "unexpected fail message")
|
||||
|
||||
@@ -3470,29 +3471,64 @@ func TestSendMPPaymentSucceed(t *testing.T) {
|
||||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||
}
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
// Mock the InFlightHTLCs.
|
||||
var (
|
||||
htlcs []channeldb.HTLCAttempt
|
||||
numAttempts atomic.Uint32
|
||||
)
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment",
|
||||
identifier,
|
||||
).Return(payment, nil).Run(func(args mock.Arguments) {
|
||||
// When number of attempts made is less than 4, we will mock
|
||||
// the payment's methods to allow the lifecycle to continue.
|
||||
if numAttempts.Load() < 4 {
|
||||
payment.On("Terminated").Return(false).Times(2).
|
||||
On("NeedWaitAttempts").Return(false, nil).Once()
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, terminate the lifecycle.
|
||||
payment.On("Terminated").Return(true).
|
||||
On("NeedWaitAttempts").Return(true, nil)
|
||||
})
|
||||
|
||||
// Mock SettleAttempt.
|
||||
preimage := lntypes.Preimage{1, 2, 3}
|
||||
settledAttempt := makeSettledAttempt(
|
||||
int(paymentAmt/4), 0, preimage,
|
||||
)
|
||||
|
||||
controlTower.On("SettleAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
|
||||
payment.On("GetHTLCs").Return(
|
||||
[]channeldb.HTLCAttempt{settledAttempt},
|
||||
)
|
||||
})
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
|
||||
require.NoError(t, err, "failed to create route")
|
||||
|
||||
session.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(shard, nil)
|
||||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
// Increase the counter whenever an attempt is made.
|
||||
numAttempts.Add(1)
|
||||
})
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
@@ -3509,30 +3545,12 @@ func TestSendMPPaymentSucceed(t *testing.T) {
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
missionControl.On("ReportPaymentSuccess",
|
||||
mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
// Mock SettleAttempt by changing one of the HTLCs to be settled.
|
||||
preimage := lntypes.Preimage{1, 2, 3}
|
||||
settledAttempt := makeSettledAttempt(
|
||||
int(paymentAmt/4), 0, preimage,
|
||||
)
|
||||
controlTower.On("SettleAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt settled and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle == nil {
|
||||
attempt.Settle = &channeldb.HTLCSettleInfo{
|
||||
Preimage: preimage,
|
||||
}
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
}
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
})
|
||||
|
||||
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
|
||||
|
||||
// Call the actual method SendPayment on router. This is place inside a
|
||||
@@ -3565,6 +3583,7 @@ func TestSendMPPaymentSucceed(t *testing.T) {
|
||||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
payment.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendMPPaymentSucceedOnExtraShards tests that we need extra attempts if
|
||||
@@ -3639,13 +3658,34 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
||||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||
}
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
// Mock the InFlightHTLCs.
|
||||
var (
|
||||
htlcs []channeldb.HTLCAttempt
|
||||
numAttempts atomic.Uint32
|
||||
failAttemptCount atomic.Uint32
|
||||
)
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment",
|
||||
identifier,
|
||||
).Return(payment, nil).Run(func(args mock.Arguments) {
|
||||
// When number of attempts made is less than 6, we will mock
|
||||
// the payment's methods to allow the lifecycle to continue.
|
||||
if numAttempts.Load() < 6 {
|
||||
payment.On("Terminated").Return(false).Times(2).
|
||||
On("NeedWaitAttempts").Return(false, nil).Once()
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, terminate the lifecycle.
|
||||
payment.On("Terminated").Return(true).
|
||||
On("NeedWaitAttempts").Return(true, nil)
|
||||
})
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
@@ -3657,11 +3697,11 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
||||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
// Increase the counter whenever an attempt is made.
|
||||
numAttempts.Add(1)
|
||||
})
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
@@ -3670,7 +3710,6 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
||||
// We use the failAttemptCount to track how many attempts we want to
|
||||
// fail. Each time the following mock method is called, the count gets
|
||||
// updated.
|
||||
failAttemptCount := 0
|
||||
payer.On("GetAttemptResult",
|
||||
mock.Anything, identifier, mock.Anything,
|
||||
).Run(func(args mock.Arguments) {
|
||||
@@ -3678,11 +3717,11 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
||||
// the read-only chan.
|
||||
|
||||
// Update the counter.
|
||||
failAttemptCount++
|
||||
failAttemptCount.Add(1)
|
||||
|
||||
// We will make the first two attempts failed with temporary
|
||||
// error.
|
||||
if failAttemptCount <= 2 {
|
||||
if failAttemptCount.Load() <= 2 {
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
||||
Error: htlcswitch.NewForwardingError(
|
||||
&lnwire.FailTemporaryChannelFailure{},
|
||||
@@ -3700,20 +3739,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
||||
var failedAttempt channeldb.HTLCAttempt
|
||||
controlTower.On("FailAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt as failed and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle != nil || attempt.Failure != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
attempt.Failure = &channeldb.HTLCFailInfo{}
|
||||
failedAttempt = attempt
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
})
|
||||
).Return(&failedAttempt, nil)
|
||||
|
||||
// Setup ReportPaymentFail to return nil reason and error so the
|
||||
// payment won't fail.
|
||||
@@ -3737,20 +3763,13 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
||||
controlTower.On("SettleAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt settled and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle != nil || attempt.Failure != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
attempt.Settle = &channeldb.HTLCSettleInfo{
|
||||
Preimage: preimage,
|
||||
}
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
// Whenever this method is invoked, we will mock the payment's
|
||||
// GetHTLCs() to return the settled htlc.
|
||||
payment.On("GetHTLCs").Return(
|
||||
[]channeldb.HTLCAttempt{settledAttempt},
|
||||
)
|
||||
})
|
||||
|
||||
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
|
||||
|
||||
// Call the actual method SendPayment on router. This is place inside a
|
||||
@@ -3779,6 +3798,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
||||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
payment.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendMPPaymentFailed tests that when one of the shard fails with a
|
||||
@@ -3853,12 +3873,18 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
||||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||
}
|
||||
// Mock the InFlightHTLCs.
|
||||
var htlcs []channeldb.HTLCAttempt
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
|
||||
On("GetStatus").Return(channeldb.StatusInFlight).
|
||||
On("Terminated").Return(false).
|
||||
On("NeedWaitAttempts").Return(false, nil)
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
@@ -3871,12 +3897,9 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
||||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
})
|
||||
).Return(nil)
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
|
||||
@@ -3918,43 +3941,24 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
||||
var failedAttempt channeldb.HTLCAttempt
|
||||
controlTower.On("FailAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt as failed and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle != nil || attempt.Failure != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
attempt.Failure = &channeldb.HTLCFailInfo{}
|
||||
failedAttempt = attempt
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
})
|
||||
).Return(&failedAttempt, nil)
|
||||
|
||||
// Setup ReportPaymentFail to return nil reason and error so the
|
||||
// payment won't fail.
|
||||
var called bool
|
||||
failureReason := channeldb.FailureReasonPaymentDetails
|
||||
missionControl.On("ReportPaymentFail",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(&failureReason, nil).Run(func(args mock.Arguments) {
|
||||
// We only return the terminal error once, thus when the method
|
||||
// is called, we will return it with a nil error.
|
||||
if called {
|
||||
args[0] = nil
|
||||
return
|
||||
}
|
||||
|
||||
// If it's the first time calling this method, we will return a
|
||||
// terminal error.
|
||||
payment.FailureReason = &failureReason
|
||||
called = true
|
||||
})
|
||||
).Return(&failureReason, nil)
|
||||
|
||||
// Simple mocking the rest.
|
||||
controlTower.On("FailPayment", identifier, failureReason).Return(nil)
|
||||
controlTower.On("FailPayment",
|
||||
identifier, failureReason,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mock the payment's
|
||||
// Terminated() to be True.
|
||||
payment.On("Terminated").Return(true)
|
||||
})
|
||||
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
@@ -3985,6 +3989,7 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
||||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
payment.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in
|
||||
@@ -4059,12 +4064,18 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
||||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||
}
|
||||
// Mock the InFlightHTLCs.
|
||||
var htlcs []channeldb.HTLCAttempt
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetStatus").Return(channeldb.StatusInFlight).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
|
||||
On("Terminated").Return(false).
|
||||
On("NeedWaitAttempts").Return(false, nil)
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
@@ -4077,12 +4088,9 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
||||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
})
|
||||
).Return(nil)
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
|
||||
@@ -4130,28 +4138,25 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
||||
var failedAttempt channeldb.HTLCAttempt
|
||||
controlTower.On("FailAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt as failed and exit.
|
||||
failedAttempt = payment.HTLCs[0]
|
||||
failedAttempt.Failure = &channeldb.HTLCFailInfo{}
|
||||
payment.HTLCs[0] = failedAttempt
|
||||
})
|
||||
).Return(&failedAttempt, nil)
|
||||
|
||||
// Setup ReportPaymentFail to return nil reason and error so the
|
||||
// payment won't fail.
|
||||
failureReason := channeldb.FailureReasonPaymentDetails
|
||||
cntReportPaymentFail := 0
|
||||
missionControl.On("ReportPaymentFail",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(&failureReason, nil).Run(func(args mock.Arguments) {
|
||||
payment.FailureReason = &failureReason
|
||||
cntReportPaymentFail++
|
||||
})
|
||||
).Return(&failureReason, nil)
|
||||
|
||||
// Simple mocking the rest.
|
||||
cntFail := 0
|
||||
controlTower.On("FailPayment", identifier, failureReason).Return(nil)
|
||||
controlTower.On("FailPayment",
|
||||
identifier, failureReason,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mock the payment's
|
||||
// Terminated() to be True.
|
||||
payment.On("Terminated").Return(true)
|
||||
})
|
||||
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
@@ -4179,7 +4184,6 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
||||
require.Error(t, err, "expected send payment error")
|
||||
require.EqualValues(t, [32]byte{}, p, "preimage not match")
|
||||
require.GreaterOrEqual(t, getPaymentResultCnt, 1)
|
||||
require.Equal(t, getPaymentResultCnt, cntReportPaymentFail)
|
||||
require.Equal(t, getPaymentResultCnt, cntFail)
|
||||
|
||||
controlTower.AssertExpectations(t)
|
||||
@@ -4187,6 +4191,7 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
||||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
payment.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestBlockDifferenceFix tests if when the router is behind on blocks, the
|
||||
|
||||
Reference in New Issue
Block a user