diff --git a/routing/mock_test.go b/routing/mock_test.go index 2458f538b..f712c420d 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -12,7 +12,9 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/routing/shards" "github.com/stretchr/testify/mock" ) @@ -572,8 +574,6 @@ func (m *mockControlTowerOld) SubscribeAllPayments() ( type mockPaymentAttemptDispatcher struct { mock.Mock - - resultChan chan *htlcswitch.PaymentResult } var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) @@ -589,11 +589,14 @@ func (m *mockPaymentAttemptDispatcher) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) ( <-chan *htlcswitch.PaymentResult, error) { - m.Called(attemptID, paymentHash, deobfuscator) + args := m.Called(attemptID, paymentHash, deobfuscator) - // Instead of returning the mocked returned values, we need to return - // the chan resultChan so it can be converted into a read-only chan. - return m.resultChan, nil + resultChan := args.Get(0) + if resultChan == nil { + return nil, args.Error(1) + } + + return args.Get(0).(chan *htlcswitch.PaymentResult), args.Error(1) } func (m *mockPaymentAttemptDispatcher) CleanStore( @@ -698,7 +701,6 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, type mockControlTower struct { mock.Mock - sync.Mutex } var _ ControlTower = (*mockControlTower)(nil) @@ -718,9 +720,6 @@ func (m *mockControlTower) DeleteFailedAttempts(phash lntypes.Hash) error { func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, a *channeldb.HTLCAttemptInfo) error { - m.Lock() - defer m.Unlock() - args := m.Called(phash, a) return args.Error(0) } @@ -729,29 +728,32 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, pid uint64, settleInfo *channeldb.HTLCSettleInfo) ( *channeldb.HTLCAttempt, error) { - m.Lock() - defer m.Unlock() - args := m.Called(phash, pid, settleInfo) - return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) + + attempt := args.Get(0) + if attempt == nil { + return nil, args.Error(1) + } + + return attempt.(*channeldb.HTLCAttempt), args.Error(1) } func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) { - m.Lock() - defer m.Unlock() - args := m.Called(phash, pid, failInfo) + + attempt := args.Get(0) + if attempt == nil { + return nil, args.Error(1) + } + return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) } func (m *mockControlTower) FailPayment(phash lntypes.Hash, reason channeldb.FailureReason) error { - m.Lock() - defer m.Unlock() - args := m.Called(phash, reason) return args.Error(0) } @@ -877,3 +879,70 @@ func (m *mockLink) EligibleToForward() bool { func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error { return m.mayAddOutgoingErr } + +type mockShardTracker struct { + mock.Mock +} + +var _ shards.ShardTracker = (*mockShardTracker)(nil) + +func (m *mockShardTracker) NewShard(attemptID uint64, + lastShard bool) (shards.PaymentShard, error) { + + args := m.Called(attemptID, lastShard) + + shard := args.Get(0) + if shard == nil { + return nil, args.Error(1) + } + + return shard.(shards.PaymentShard), args.Error(1) +} + +func (m *mockShardTracker) GetHash(attemptID uint64) (lntypes.Hash, error) { + args := m.Called(attemptID) + return args.Get(0).(lntypes.Hash), args.Error(1) +} + +func (m *mockShardTracker) CancelShard(attemptID uint64) error { + args := m.Called(attemptID) + return args.Error(0) +} + +type mockShard struct { + mock.Mock +} + +var _ shards.PaymentShard = (*mockShard)(nil) + +// Hash returns the hash used for the HTLC representing this shard. +func (m *mockShard) Hash() lntypes.Hash { + args := m.Called() + return args.Get(0).(lntypes.Hash) +} + +// MPP returns any extra MPP records that should be set for the final +// hop on the route used by this shard. +func (m *mockShard) MPP() *record.MPP { + args := m.Called() + + r := args.Get(0) + if r == nil { + return nil + } + + return r.(*record.MPP) +} + +// AMP returns any extra AMP records that should be set for the final +// hop on the route used by this shard. +func (m *mockShard) AMP() *record.AMP { + args := m.Called() + + r := args.Get(0) + if r == nil { + return nil + } + + return r.(*record.AMP) +} diff --git a/routing/router_test.go b/routing/router_test.go index 79f0b026c..3f1f5dcb1 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3528,12 +3528,12 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { ).Return(nil) // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 1) + resultChan := make(chan *htlcswitch.PaymentResult, 1) payer.On("GetAttemptResult", mock.Anything, mock.Anything, mock.Anything, - ).Run(func(_ mock.Arguments) { + ).Return(resultChan, nil).Run(func(_ mock.Arguments) { // Send a successful payment result. - payer.resultChan <- &htlcswitch.PaymentResult{} + resultChan <- &htlcswitch.PaymentResult{} }) missionControl.On("ReportPaymentSuccess", @@ -3599,6 +3599,11 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { }, }} + // Create the error to be returned. + tempErr := htlcswitch.NewForwardingError( + &lnwire.FailTemporaryChannelFailure{}, 1, + ) + // Register mockers with the expected method calls. controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) @@ -3608,26 +3613,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, - ).Return(nil) - - // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 1) - - // Create the error to be returned. - tempErr := htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, - 1, - ) - - // Mock GetAttemptResult to return a failure. - payer.On("GetAttemptResult", - mock.Anything, mock.Anything, mock.Anything, - ).Run(func(_ mock.Arguments) { - // Send an attempt failure. - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: tempErr, - } - }) + ).Return(tempErr) // Mock the control tower to return the mocked payment. payment := &mockMPPayment{}