mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-10-10 23:52:35 +02:00
routing: refactor update payment state tests
This commit refactors the resumePayment to extract some logics back to paymentState so that the code is more testable. It also adds unit tests for paymentState, and breaks the original MPPayment tests into independent tests so that it's easier to maintain and debug. All the new tests are built using mock so that the control flow is eaiser to setup and change.
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
@@ -1069,7 +1070,8 @@ func TestSendPaymentErrorPathPruning(t *testing.T) {
|
||||
_, ok = msg.(*lnwire.FailUnknownNextPeer)
|
||||
require.True(t, ok, "unexpected fail message")
|
||||
|
||||
ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory()
|
||||
err = ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory()
|
||||
require.NoError(t, err, "reset history failed")
|
||||
|
||||
// Next, we'll modify the SendToSwitch method to indicate that the
|
||||
// connection between songoku and isn't up.
|
||||
@@ -3436,3 +3438,796 @@ func TestChannelOnChainRejectionZombie(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
assertChanChainRejection(t, ctx, edge, ErrInvalidFundingOutput)
|
||||
}
|
||||
|
||||
func createDummyTestGraph(t *testing.T) *testGraphInstance {
|
||||
// Setup two simple channels such that we can mock sending along this
|
||||
// route.
|
||||
chanCapSat := btcutil.Amount(100000)
|
||||
testChannels := []*testChannel{
|
||||
symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{
|
||||
Expiry: 144,
|
||||
FeeRate: 400,
|
||||
MinHTLC: 1,
|
||||
MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat),
|
||||
}, 1),
|
||||
symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{
|
||||
Expiry: 144,
|
||||
FeeRate: 400,
|
||||
MinHTLC: 1,
|
||||
MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat),
|
||||
}, 2),
|
||||
}
|
||||
|
||||
testGraph, err := createTestGraphFromChannels(testChannels, "a")
|
||||
require.NoError(t, err, "failed to create graph")
|
||||
return testGraph
|
||||
}
|
||||
|
||||
func createDummyLightningPayment(t *testing.T,
|
||||
target route.Vertex, amt lnwire.MilliSatoshi) *LightningPayment {
|
||||
|
||||
var preImage lntypes.Preimage
|
||||
_, err := rand.Read(preImage[:])
|
||||
require.NoError(t, err, "unable to generate preimage")
|
||||
|
||||
payHash := preImage.Hash()
|
||||
|
||||
return &LightningPayment{
|
||||
Target: target,
|
||||
Amount: amt,
|
||||
FeeLimit: noFeeLimit,
|
||||
paymentHash: &payHash,
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendMPPaymentSucceed tests that we can successfully send a MPPayment via
|
||||
// router.SendPayment. This test mainly focuses on testing the logic of the
|
||||
// method resumePayment is implemented as expected.
|
||||
func TestSendMPPaymentSucceed(t *testing.T) {
|
||||
const startingBlockHeight = 101
|
||||
|
||||
// Create mockers to initialize the router.
|
||||
controlTower := &mockControlTower{}
|
||||
sessionSource := &mockPaymentSessionSource{}
|
||||
missionControl := &mockMissionControl{}
|
||||
payer := &mockPaymentAttemptDispatcher{}
|
||||
chain := newMockChain(startingBlockHeight)
|
||||
chainView := newMockChainView(chain)
|
||||
testGraph := createDummyTestGraph(t)
|
||||
|
||||
// Define the behavior of the mockers to the point where we can
|
||||
// successfully start the router.
|
||||
controlTower.On("FetchInFlightPayments").Return(
|
||||
[]*channeldb.MPPayment{}, nil,
|
||||
)
|
||||
payer.On("CleanStore", mock.Anything).Return(nil)
|
||||
|
||||
// Create and start the router.
|
||||
router, err := New(Config{
|
||||
Control: controlTower,
|
||||
SessionSource: sessionSource,
|
||||
MissionControl: missionControl,
|
||||
Payer: payer,
|
||||
|
||||
// TODO(yy): create new mocks for the chain and chainview.
|
||||
Chain: chain,
|
||||
ChainView: chainView,
|
||||
|
||||
// TODO(yy): mock the graph once it's changed into interface.
|
||||
Graph: testGraph.graph,
|
||||
|
||||
Clock: clock.NewTestClock(time.Unix(1, 0)),
|
||||
GraphPruneInterval: time.Hour * 2,
|
||||
NextPaymentID: func() (uint64, error) {
|
||||
next := atomic.AddUint64(&uniquePaymentID, 1)
|
||||
return next, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
// Make sure the router can start and stop without error.
|
||||
require.NoError(t, router.Start(), "router failed to start")
|
||||
defer func() {
|
||||
require.NoError(t, router.Stop(), "router failed to stop")
|
||||
}()
|
||||
|
||||
// Once the router is started, check that the mocked methods are called
|
||||
// as expected.
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
|
||||
// Mock the methods to the point where we are inside the function
|
||||
// resumePayment.
|
||||
paymentAmt := lnwire.MilliSatoshi(10000)
|
||||
req := createDummyLightningPayment(
|
||||
t, testGraph.aliasMap["c"], paymentAmt,
|
||||
)
|
||||
identifier := lntypes.Hash(req.Identifier())
|
||||
session := &mockPaymentSession{}
|
||||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{}
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
|
||||
require.NoError(t, err, "failed to create route")
|
||||
session.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(shard, nil)
|
||||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
})
|
||||
|
||||
// Create a buffered chan and it will be returned by GetPaymentResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
|
||||
payer.On("GetPaymentResult",
|
||||
mock.Anything, identifier, mock.Anything,
|
||||
).Run(func(args mock.Arguments) {
|
||||
// Before the mock method is returned, we send the result to
|
||||
// the read-only chan.
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{}
|
||||
})
|
||||
|
||||
// Simple mocking the rest.
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
missionControl.On("ReportPaymentSuccess",
|
||||
mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
// Mock SettleAttempt by changing one of the HTLCs to be settled.
|
||||
preimage := lntypes.Preimage{1, 2, 3}
|
||||
settledAttempt := makeSettledAttempt(
|
||||
int(paymentAmt/4), 0, preimage,
|
||||
)
|
||||
controlTower.On("SettleAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt settled and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle == nil {
|
||||
attempt.Settle = &channeldb.HTLCSettleInfo{
|
||||
Preimage: preimage,
|
||||
}
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Call the actual method SendPayment on router. This is place inside a
|
||||
// goroutine so we can set a timeout for the whole test, in case
|
||||
// anything goes wrong and the test never finishes.
|
||||
done := make(chan struct{})
|
||||
var p lntypes.Hash
|
||||
go func() {
|
||||
p, _, err = router.SendPayment(req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatalf("SendPayment didn't exit")
|
||||
}
|
||||
|
||||
// Finally, validate the returned values and check that the mock
|
||||
// methods are called as expected.
|
||||
require.NoError(t, err, "send payment failed")
|
||||
require.EqualValues(t, preimage, p, "preimage not match")
|
||||
|
||||
// Note that we also implicitly check the methods such as FailAttempt,
|
||||
// ReportPaymentFail, etc, are not called because we never mocked them
|
||||
// in this test. If any of the unexpected methods was called, the test
|
||||
// would fail.
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendMPPaymentSucceedOnExtraShards tests that we need extra attempts if
|
||||
// there are failed ones,so that a payment is successfully sent. This test
|
||||
// mainly focuses on testing the logic of the method resumePayment is
|
||||
// implemented as expected.
|
||||
func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
||||
const startingBlockHeight = 101
|
||||
|
||||
// Create mockers to initialize the router.
|
||||
controlTower := &mockControlTower{}
|
||||
sessionSource := &mockPaymentSessionSource{}
|
||||
missionControl := &mockMissionControl{}
|
||||
payer := &mockPaymentAttemptDispatcher{}
|
||||
chain := newMockChain(startingBlockHeight)
|
||||
chainView := newMockChainView(chain)
|
||||
testGraph := createDummyTestGraph(t)
|
||||
|
||||
// Define the behavior of the mockers to the point where we can
|
||||
// successfully start the router.
|
||||
controlTower.On("FetchInFlightPayments").Return(
|
||||
[]*channeldb.MPPayment{}, nil,
|
||||
)
|
||||
payer.On("CleanStore", mock.Anything).Return(nil)
|
||||
|
||||
// Create and start the router.
|
||||
router, err := New(Config{
|
||||
Control: controlTower,
|
||||
SessionSource: sessionSource,
|
||||
MissionControl: missionControl,
|
||||
Payer: payer,
|
||||
|
||||
// TODO(yy): create new mocks for the chain and chainview.
|
||||
Chain: chain,
|
||||
ChainView: chainView,
|
||||
|
||||
// TODO(yy): mock the graph once it's changed into interface.
|
||||
Graph: testGraph.graph,
|
||||
|
||||
Clock: clock.NewTestClock(time.Unix(1, 0)),
|
||||
GraphPruneInterval: time.Hour * 2,
|
||||
NextPaymentID: func() (uint64, error) {
|
||||
next := atomic.AddUint64(&uniquePaymentID, 1)
|
||||
return next, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
// Make sure the router can start and stop without error.
|
||||
require.NoError(t, router.Start(), "router failed to start")
|
||||
defer func() {
|
||||
require.NoError(t, router.Stop(), "router failed to stop")
|
||||
}()
|
||||
|
||||
// Once the router is started, check that the mocked methods are called
|
||||
// as expected.
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
|
||||
// Mock the methods to the point where we are inside the function
|
||||
// resumePayment.
|
||||
paymentAmt := lnwire.MilliSatoshi(20000)
|
||||
req := createDummyLightningPayment(
|
||||
t, testGraph.aliasMap["c"], paymentAmt,
|
||||
)
|
||||
identifier := lntypes.Hash(req.Identifier())
|
||||
session := &mockPaymentSession{}
|
||||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{}
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
|
||||
require.NoError(t, err, "failed to create route")
|
||||
session.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(shard, nil)
|
||||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
})
|
||||
|
||||
// Create a buffered chan and it will be returned by GetPaymentResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
|
||||
|
||||
// We use the failAttemptCount to track how many attempts we want to
|
||||
// fail. Each time the following mock method is called, the count gets
|
||||
// updated.
|
||||
failAttemptCount := 0
|
||||
payer.On("GetPaymentResult",
|
||||
mock.Anything, identifier, mock.Anything,
|
||||
).Run(func(args mock.Arguments) {
|
||||
// Before the mock method is returned, we send the result to
|
||||
// the read-only chan.
|
||||
|
||||
// Update the counter.
|
||||
failAttemptCount++
|
||||
|
||||
// We will make the first two attempts failed with temporary
|
||||
// error.
|
||||
if failAttemptCount <= 2 {
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
||||
Error: htlcswitch.NewForwardingError(
|
||||
&lnwire.FailTemporaryChannelFailure{},
|
||||
1,
|
||||
),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise we will mark the attempt succeeded.
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{}
|
||||
})
|
||||
|
||||
// Mock the FailAttempt method to fail one of the attempts.
|
||||
var failedAttempt channeldb.HTLCAttempt
|
||||
controlTower.On("FailAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt as failed and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle != nil || attempt.Failure != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
attempt.Failure = &channeldb.HTLCFailInfo{}
|
||||
failedAttempt = attempt
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
// Setup ReportPaymentFail to return nil reason and error so the
|
||||
// payment won't fail.
|
||||
missionControl.On("ReportPaymentFail",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil, nil)
|
||||
|
||||
// Simple mocking the rest.
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
missionControl.On("ReportPaymentSuccess",
|
||||
mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
// Mock SettleAttempt by changing one of the HTLCs to be settled.
|
||||
preimage := lntypes.Preimage{1, 2, 3}
|
||||
settledAttempt := makeSettledAttempt(
|
||||
int(paymentAmt/4), 0, preimage,
|
||||
)
|
||||
controlTower.On("SettleAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt settled and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle != nil || attempt.Failure != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
attempt.Settle = &channeldb.HTLCSettleInfo{
|
||||
Preimage: preimage,
|
||||
}
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
// Call the actual method SendPayment on router. This is place inside a
|
||||
// goroutine so we can set a timeout for the whole test, in case
|
||||
// anything goes wrong and the test never finishes.
|
||||
done := make(chan struct{})
|
||||
var p lntypes.Hash
|
||||
go func() {
|
||||
p, _, err = router.SendPayment(req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatalf("SendPayment didn't exit")
|
||||
}
|
||||
|
||||
// Finally, validate the returned values and check that the mock
|
||||
// methods are called as expected.
|
||||
require.NoError(t, err, "send payment failed")
|
||||
require.EqualValues(t, preimage, p, "preimage not match")
|
||||
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendMPPaymentFailed tests that when one of the shard fails with a
|
||||
// terminal error, the router will stop attempting and the payment will fail.
|
||||
// This test mainly focuses on testing the logic of the method resumePayment
|
||||
// is implemented as expected.
|
||||
func TestSendMPPaymentFailed(t *testing.T) {
|
||||
const startingBlockHeight = 101
|
||||
|
||||
// Create mockers to initialize the router.
|
||||
controlTower := &mockControlTower{}
|
||||
sessionSource := &mockPaymentSessionSource{}
|
||||
missionControl := &mockMissionControl{}
|
||||
payer := &mockPaymentAttemptDispatcher{}
|
||||
chain := newMockChain(startingBlockHeight)
|
||||
chainView := newMockChainView(chain)
|
||||
testGraph := createDummyTestGraph(t)
|
||||
|
||||
// Define the behavior of the mockers to the point where we can
|
||||
// successfully start the router.
|
||||
controlTower.On("FetchInFlightPayments").Return(
|
||||
[]*channeldb.MPPayment{}, nil,
|
||||
)
|
||||
payer.On("CleanStore", mock.Anything).Return(nil)
|
||||
|
||||
// Create and start the router.
|
||||
router, err := New(Config{
|
||||
Control: controlTower,
|
||||
SessionSource: sessionSource,
|
||||
MissionControl: missionControl,
|
||||
Payer: payer,
|
||||
|
||||
// TODO(yy): create new mocks for the chain and chainview.
|
||||
Chain: chain,
|
||||
ChainView: chainView,
|
||||
|
||||
// TODO(yy): mock the graph once it's changed into interface.
|
||||
Graph: testGraph.graph,
|
||||
|
||||
Clock: clock.NewTestClock(time.Unix(1, 0)),
|
||||
GraphPruneInterval: time.Hour * 2,
|
||||
NextPaymentID: func() (uint64, error) {
|
||||
next := atomic.AddUint64(&uniquePaymentID, 1)
|
||||
return next, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
// Make sure the router can start and stop without error.
|
||||
require.NoError(t, router.Start(), "router failed to start")
|
||||
defer func() {
|
||||
require.NoError(t, router.Stop(), "router failed to stop")
|
||||
}()
|
||||
|
||||
// Once the router is started, check that the mocked methods are called
|
||||
// as expected.
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
|
||||
// Mock the methods to the point where we are inside the function
|
||||
// resumePayment.
|
||||
paymentAmt := lnwire.MilliSatoshi(10000)
|
||||
req := createDummyLightningPayment(
|
||||
t, testGraph.aliasMap["c"], paymentAmt,
|
||||
)
|
||||
identifier := lntypes.Hash(req.Identifier())
|
||||
session := &mockPaymentSession{}
|
||||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{}
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
|
||||
require.NoError(t, err, "failed to create route")
|
||||
session.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(shard, nil)
|
||||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
})
|
||||
|
||||
// Create a buffered chan and it will be returned by GetPaymentResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
|
||||
|
||||
// We use the failAttemptCount to track how many attempts we want to
|
||||
// fail. Each time the following mock method is called, the count gets
|
||||
// updated.
|
||||
failAttemptCount := 0
|
||||
payer.On("GetPaymentResult",
|
||||
mock.Anything, identifier, mock.Anything,
|
||||
).Run(func(args mock.Arguments) {
|
||||
// Before the mock method is returned, we send the result to
|
||||
// the read-only chan.
|
||||
|
||||
// Update the counter.
|
||||
failAttemptCount++
|
||||
|
||||
// We fail the first attempt with terminal error.
|
||||
if failAttemptCount == 1 {
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
||||
Error: htlcswitch.NewForwardingError(
|
||||
&lnwire.FailIncorrectDetails{},
|
||||
1,
|
||||
),
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// We will make the rest attempts failed with temporary error.
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
||||
Error: htlcswitch.NewForwardingError(
|
||||
&lnwire.FailTemporaryChannelFailure{},
|
||||
1,
|
||||
),
|
||||
}
|
||||
})
|
||||
|
||||
// Mock the FailAttempt method to fail one of the attempts.
|
||||
var failedAttempt channeldb.HTLCAttempt
|
||||
controlTower.On("FailAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt as failed and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle != nil || attempt.Failure != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
attempt.Failure = &channeldb.HTLCFailInfo{}
|
||||
failedAttempt = attempt
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
// Setup ReportPaymentFail to return nil reason and error so the
|
||||
// payment won't fail.
|
||||
var called bool
|
||||
failureReason := channeldb.FailureReasonPaymentDetails
|
||||
missionControl.On("ReportPaymentFail",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil, nil).Run(func(args mock.Arguments) {
|
||||
// We only return the terminal error once, thus when the method
|
||||
// is called, we will return it with a nil error.
|
||||
if called {
|
||||
missionControl.failReason = nil
|
||||
return
|
||||
}
|
||||
|
||||
// If it's the first time calling this method, we will return a
|
||||
// terminal error.
|
||||
missionControl.failReason = &failureReason
|
||||
payment.FailureReason = &failureReason
|
||||
called = true
|
||||
})
|
||||
|
||||
// Simple mocking the rest.
|
||||
controlTower.On("Fail", identifier, failureReason).Return(nil)
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
// Call the actual method SendPayment on router. This is place inside a
|
||||
// goroutine so we can set a timeout for the whole test, in case
|
||||
// anything goes wrong and the test never finishes.
|
||||
done := make(chan struct{})
|
||||
var p lntypes.Hash
|
||||
go func() {
|
||||
p, _, err = router.SendPayment(req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatalf("SendPayment didn't exit")
|
||||
}
|
||||
|
||||
// Finally, validate the returned values and check that the mock
|
||||
// methods are called as expected.
|
||||
require.Error(t, err, "expected send payment error")
|
||||
require.EqualValues(t, [32]byte{}, p, "preimage not match")
|
||||
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in
|
||||
// terminal state, even if we have shards in flight, we still fail the payment
|
||||
// and exit. This test mainly focuses on testing the logic of the method
|
||||
// resumePayment is implemented as expected.
|
||||
func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
||||
const startingBlockHeight = 101
|
||||
|
||||
// Create mockers to initialize the router.
|
||||
controlTower := &mockControlTower{}
|
||||
sessionSource := &mockPaymentSessionSource{}
|
||||
missionControl := &mockMissionControl{}
|
||||
payer := &mockPaymentAttemptDispatcher{}
|
||||
chain := newMockChain(startingBlockHeight)
|
||||
chainView := newMockChainView(chain)
|
||||
testGraph := createDummyTestGraph(t)
|
||||
|
||||
// Define the behavior of the mockers to the point where we can
|
||||
// successfully start the router.
|
||||
controlTower.On("FetchInFlightPayments").Return(
|
||||
[]*channeldb.MPPayment{}, nil,
|
||||
)
|
||||
payer.On("CleanStore", mock.Anything).Return(nil)
|
||||
|
||||
// Create and start the router.
|
||||
router, err := New(Config{
|
||||
Control: controlTower,
|
||||
SessionSource: sessionSource,
|
||||
MissionControl: missionControl,
|
||||
Payer: payer,
|
||||
|
||||
// TODO(yy): create new mocks for the chain and chainview.
|
||||
Chain: chain,
|
||||
ChainView: chainView,
|
||||
|
||||
// TODO(yy): mock the graph once it's changed into interface.
|
||||
Graph: testGraph.graph,
|
||||
|
||||
Clock: clock.NewTestClock(time.Unix(1, 0)),
|
||||
GraphPruneInterval: time.Hour * 2,
|
||||
NextPaymentID: func() (uint64, error) {
|
||||
next := atomic.AddUint64(&uniquePaymentID, 1)
|
||||
return next, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
// Make sure the router can start and stop without error.
|
||||
require.NoError(t, router.Start(), "router failed to start")
|
||||
defer func() {
|
||||
require.NoError(t, router.Stop(), "router failed to stop")
|
||||
}()
|
||||
|
||||
// Once the router is started, check that the mocked methods are called
|
||||
// as expected.
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
|
||||
// Mock the methods to the point where we are inside the function
|
||||
// resumePayment.
|
||||
paymentAmt := lnwire.MilliSatoshi(10000)
|
||||
req := createDummyLightningPayment(
|
||||
t, testGraph.aliasMap["c"], paymentAmt,
|
||||
)
|
||||
identifier := lntypes.Hash(req.Identifier())
|
||||
session := &mockPaymentSession{}
|
||||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{}
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
|
||||
require.NoError(t, err, "failed to create route")
|
||||
session.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(shard, nil)
|
||||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
})
|
||||
|
||||
// Create a buffered chan and it will be returned by GetPaymentResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
|
||||
|
||||
// We use the failAttemptCount to track how many attempts we want to
|
||||
// fail. Each time the following mock method is called, the count gets
|
||||
// updated.
|
||||
failAttemptCount := 0
|
||||
payer.On("GetPaymentResult",
|
||||
mock.Anything, identifier, mock.Anything,
|
||||
).Run(func(args mock.Arguments) {
|
||||
// Before the mock method is returned, we send the result to
|
||||
// the read-only chan.
|
||||
|
||||
// Update the counter.
|
||||
failAttemptCount++
|
||||
|
||||
// We fail the first attempt with terminal error.
|
||||
if failAttemptCount == 1 {
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
||||
Error: htlcswitch.NewForwardingError(
|
||||
&lnwire.FailIncorrectDetails{},
|
||||
1,
|
||||
),
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// For the rest attempts we will NOT send anything to the
|
||||
// resultChan, thus making all the shards in active state,
|
||||
// neither settled or failed.
|
||||
})
|
||||
|
||||
// Mock the FailAttempt method to fail EXACTLY once.
|
||||
var failedAttempt channeldb.HTLCAttempt
|
||||
controlTower.On("FailAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt as failed and exit.
|
||||
failedAttempt = payment.HTLCs[0]
|
||||
failedAttempt.Failure = &channeldb.HTLCFailInfo{}
|
||||
payment.HTLCs[0] = failedAttempt
|
||||
}).Once()
|
||||
|
||||
// Setup ReportPaymentFail to return nil reason and error so the
|
||||
// payment won't fail.
|
||||
failureReason := channeldb.FailureReasonPaymentDetails
|
||||
missionControl.On("ReportPaymentFail",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(failureReason, nil).Run(func(args mock.Arguments) {
|
||||
missionControl.failReason = &failureReason
|
||||
payment.FailureReason = &failureReason
|
||||
}).Once()
|
||||
|
||||
// Simple mocking the rest.
|
||||
controlTower.On("Fail", identifier, failureReason).Return(nil).Once()
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
// Call the actual method SendPayment on router. This is place inside a
|
||||
// goroutine so we can set a timeout for the whole test, in case
|
||||
// anything goes wrong and the test never finishes.
|
||||
done := make(chan struct{})
|
||||
var p lntypes.Hash
|
||||
go func() {
|
||||
p, _, err = router.SendPayment(req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatalf("SendPayment didn't exit")
|
||||
}
|
||||
|
||||
// Finally, validate the returned values and check that the mock
|
||||
// methods are called as expected.
|
||||
require.Error(t, err, "expected send payment error")
|
||||
require.EqualValues(t, [32]byte{}, p, "preimage not match")
|
||||
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
}
|
||||
|
Reference in New Issue
Block a user