routing+channeldb: make MPPayment into an interface

This commit turns `MPPayment` into an interface inside `routing`. Having
this interface gives us the benefit to write more granular unit tests
inside payment lifecycle. As seen from the modified unit tests, several
hacky ways of testing the `SendPayment` method is now replaced by a mock
over `MPPayment`.
This commit is contained in:
yyforyongyu
2023-02-09 12:51:43 +08:00
committed by Olaoluwa Osuntokun
parent c412ab5ccb
commit 34d0e5d4c5
8 changed files with 257 additions and 167 deletions

View File

@@ -1120,15 +1120,16 @@ func TestSendPaymentErrorPathPruning(t *testing.T) {
p, err := ctx.router.cfg.Control.FetchPayment(payHash)
require.NoError(t, err)
require.Equal(t, 2, len(p.HTLCs), "expected two attempts")
htlcs := p.GetHTLCs()
require.Equal(t, 2, len(htlcs), "expected two attempts")
// We expect the first attempt to have failed with a
// TemporaryChannelFailure, the second with UnknownNextPeer.
msg := p.HTLCs[0].Failure.Message
msg := htlcs[0].Failure.Message
_, ok := msg.(*lnwire.FailTemporaryChannelFailure)
require.True(t, ok, "unexpected fail message")
msg = p.HTLCs[1].Failure.Message
msg = htlcs[1].Failure.Message
_, ok = msg.(*lnwire.FailUnknownNextPeer)
require.True(t, ok, "unexpected fail message")
@@ -3470,29 +3471,64 @@ func TestSendMPPaymentSucceed(t *testing.T) {
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Mock the InFlightHTLCs.
var (
htlcs []channeldb.HTLCAttempt
numAttempts atomic.Uint32
)
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment",
identifier,
).Return(payment, nil).Run(func(args mock.Arguments) {
// When number of attempts made is less than 4, we will mock
// the payment's methods to allow the lifecycle to continue.
if numAttempts.Load() < 4 {
payment.On("Terminated").Return(false).Times(2).
On("NeedWaitAttempts").Return(false, nil).Once()
return
}
// Otherwise, terminate the lifecycle.
payment.On("Terminated").Return(true).
On("NeedWaitAttempts").Return(true, nil)
})
// Mock SettleAttempt.
preimage := lntypes.Preimage{1, 2, 3}
settledAttempt := makeSettledAttempt(
int(paymentAmt/4), 0, preimage,
)
controlTower.On("SettleAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
payment.On("GetHTLCs").Return(
[]channeldb.HTLCAttempt{settledAttempt},
)
})
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
require.NoError(t, err, "failed to create route")
session.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(shard, nil)
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
// Increase the counter whenever an attempt is made.
numAttempts.Add(1)
})
// Create a buffered chan and it will be returned by GetAttemptResult.
@@ -3509,30 +3545,12 @@ func TestSendMPPaymentSucceed(t *testing.T) {
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
missionControl.On("ReportPaymentSuccess",
mock.Anything, mock.Anything,
).Return(nil)
// Mock SettleAttempt by changing one of the HTLCs to be settled.
preimage := lntypes.Preimage{1, 2, 3}
settledAttempt := makeSettledAttempt(
int(paymentAmt/4), 0, preimage,
)
controlTower.On("SettleAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt settled and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle == nil {
attempt.Settle = &channeldb.HTLCSettleInfo{
Preimage: preimage,
}
payment.HTLCs[i] = attempt
return
}
}
).Return(nil).Run(func(args mock.Arguments) {
})
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
// Call the actual method SendPayment on router. This is place inside a
@@ -3565,6 +3583,7 @@ func TestSendMPPaymentSucceed(t *testing.T) {
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}
// TestSendMPPaymentSucceedOnExtraShards tests that we need extra attempts if
@@ -3639,13 +3658,34 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Mock the InFlightHTLCs.
var (
htlcs []channeldb.HTLCAttempt
numAttempts atomic.Uint32
failAttemptCount atomic.Uint32
)
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment",
identifier,
).Return(payment, nil).Run(func(args mock.Arguments) {
// When number of attempts made is less than 6, we will mock
// the payment's methods to allow the lifecycle to continue.
if numAttempts.Load() < 6 {
payment.On("Terminated").Return(false).Times(2).
On("NeedWaitAttempts").Return(false, nil).Once()
return
}
// Otherwise, terminate the lifecycle.
payment.On("Terminated").Return(true).
On("NeedWaitAttempts").Return(true, nil)
})
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
@@ -3657,11 +3697,11 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
// Increase the counter whenever an attempt is made.
numAttempts.Add(1)
})
// Create a buffered chan and it will be returned by GetAttemptResult.
@@ -3670,7 +3710,6 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
// We use the failAttemptCount to track how many attempts we want to
// fail. Each time the following mock method is called, the count gets
// updated.
failAttemptCount := 0
payer.On("GetAttemptResult",
mock.Anything, identifier, mock.Anything,
).Run(func(args mock.Arguments) {
@@ -3678,11 +3717,11 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
// the read-only chan.
// Update the counter.
failAttemptCount++
failAttemptCount.Add(1)
// We will make the first two attempts failed with temporary
// error.
if failAttemptCount <= 2 {
if failAttemptCount.Load() <= 2 {
payer.resultChan <- &htlcswitch.PaymentResult{
Error: htlcswitch.NewForwardingError(
&lnwire.FailTemporaryChannelFailure{},
@@ -3700,20 +3739,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
var failedAttempt channeldb.HTLCAttempt
controlTower.On("FailAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt as failed and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle != nil || attempt.Failure != nil {
continue
}
attempt.Failure = &channeldb.HTLCFailInfo{}
failedAttempt = attempt
payment.HTLCs[i] = attempt
return
}
})
).Return(&failedAttempt, nil)
// Setup ReportPaymentFail to return nil reason and error so the
// payment won't fail.
@@ -3737,20 +3763,13 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
controlTower.On("SettleAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt settled and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle != nil || attempt.Failure != nil {
continue
}
attempt.Settle = &channeldb.HTLCSettleInfo{
Preimage: preimage,
}
payment.HTLCs[i] = attempt
return
}
// Whenever this method is invoked, we will mock the payment's
// GetHTLCs() to return the settled htlc.
payment.On("GetHTLCs").Return(
[]channeldb.HTLCAttempt{settledAttempt},
)
})
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
// Call the actual method SendPayment on router. This is place inside a
@@ -3779,6 +3798,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}
// TestSendMPPaymentFailed tests that when one of the shard fails with a
@@ -3853,12 +3873,18 @@ func TestSendMPPaymentFailed(t *testing.T) {
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
// Mock the InFlightHTLCs.
var htlcs []channeldb.HTLCAttempt
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
On("GetStatus").Return(channeldb.StatusInFlight).
On("Terminated").Return(false).
On("NeedWaitAttempts").Return(false, nil)
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
@@ -3871,12 +3897,9 @@ func TestSendMPPaymentFailed(t *testing.T) {
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
})
).Return(nil)
// Create a buffered chan and it will be returned by GetAttemptResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
@@ -3918,43 +3941,24 @@ func TestSendMPPaymentFailed(t *testing.T) {
var failedAttempt channeldb.HTLCAttempt
controlTower.On("FailAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt as failed and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle != nil || attempt.Failure != nil {
continue
}
attempt.Failure = &channeldb.HTLCFailInfo{}
failedAttempt = attempt
payment.HTLCs[i] = attempt
return
}
})
).Return(&failedAttempt, nil)
// Setup ReportPaymentFail to return nil reason and error so the
// payment won't fail.
var called bool
failureReason := channeldb.FailureReasonPaymentDetails
missionControl.On("ReportPaymentFail",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(&failureReason, nil).Run(func(args mock.Arguments) {
// We only return the terminal error once, thus when the method
// is called, we will return it with a nil error.
if called {
args[0] = nil
return
}
// If it's the first time calling this method, we will return a
// terminal error.
payment.FailureReason = &failureReason
called = true
})
).Return(&failureReason, nil)
// Simple mocking the rest.
controlTower.On("FailPayment", identifier, failureReason).Return(nil)
controlTower.On("FailPayment",
identifier, failureReason,
).Return(nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mock the payment's
// Terminated() to be True.
payment.On("Terminated").Return(true)
})
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
@@ -3985,6 +3989,7 @@ func TestSendMPPaymentFailed(t *testing.T) {
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}
// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in
@@ -4059,12 +4064,18 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
// Mock the InFlightHTLCs.
var htlcs []channeldb.HTLCAttempt
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetStatus").Return(channeldb.StatusInFlight).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
On("Terminated").Return(false).
On("NeedWaitAttempts").Return(false, nil)
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
@@ -4077,12 +4088,9 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
})
).Return(nil)
// Create a buffered chan and it will be returned by GetAttemptResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
@@ -4130,28 +4138,25 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
var failedAttempt channeldb.HTLCAttempt
controlTower.On("FailAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt as failed and exit.
failedAttempt = payment.HTLCs[0]
failedAttempt.Failure = &channeldb.HTLCFailInfo{}
payment.HTLCs[0] = failedAttempt
})
).Return(&failedAttempt, nil)
// Setup ReportPaymentFail to return nil reason and error so the
// payment won't fail.
failureReason := channeldb.FailureReasonPaymentDetails
cntReportPaymentFail := 0
missionControl.On("ReportPaymentFail",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(&failureReason, nil).Run(func(args mock.Arguments) {
payment.FailureReason = &failureReason
cntReportPaymentFail++
})
).Return(&failureReason, nil)
// Simple mocking the rest.
cntFail := 0
controlTower.On("FailPayment", identifier, failureReason).Return(nil)
controlTower.On("FailPayment",
identifier, failureReason,
).Return(nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mock the payment's
// Terminated() to be True.
payment.On("Terminated").Return(true)
})
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
@@ -4179,7 +4184,6 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
require.Error(t, err, "expected send payment error")
require.EqualValues(t, [32]byte{}, p, "preimage not match")
require.GreaterOrEqual(t, getPaymentResultCnt, 1)
require.Equal(t, getPaymentResultCnt, cntReportPaymentFail)
require.Equal(t, getPaymentResultCnt, cntFail)
controlTower.AssertExpectations(t)
@@ -4187,6 +4191,7 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}
// TestBlockDifferenceFix tests if when the router is behind on blocks, the