mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-30 07:35:07 +02:00
routing: update mockers in unit test
This commit adds more mockers to be used in coming unit tests and simplified the mockers to be more straightforward.
This commit is contained in:
@@ -12,7 +12,9 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/htlcswitch"
|
"github.com/lightningnetwork/lnd/htlcswitch"
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/record"
|
||||||
"github.com/lightningnetwork/lnd/routing/route"
|
"github.com/lightningnetwork/lnd/routing/route"
|
||||||
|
"github.com/lightningnetwork/lnd/routing/shards"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -572,8 +574,6 @@ func (m *mockControlTowerOld) SubscribeAllPayments() (
|
|||||||
|
|
||||||
type mockPaymentAttemptDispatcher struct {
|
type mockPaymentAttemptDispatcher struct {
|
||||||
mock.Mock
|
mock.Mock
|
||||||
|
|
||||||
resultChan chan *htlcswitch.PaymentResult
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
|
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
|
||||||
@@ -589,11 +589,14 @@ func (m *mockPaymentAttemptDispatcher) GetAttemptResult(attemptID uint64,
|
|||||||
paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
|
paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
|
||||||
<-chan *htlcswitch.PaymentResult, error) {
|
<-chan *htlcswitch.PaymentResult, error) {
|
||||||
|
|
||||||
m.Called(attemptID, paymentHash, deobfuscator)
|
args := m.Called(attemptID, paymentHash, deobfuscator)
|
||||||
|
|
||||||
// Instead of returning the mocked returned values, we need to return
|
resultChan := args.Get(0)
|
||||||
// the chan resultChan so it can be converted into a read-only chan.
|
if resultChan == nil {
|
||||||
return m.resultChan, nil
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return args.Get(0).(chan *htlcswitch.PaymentResult), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockPaymentAttemptDispatcher) CleanStore(
|
func (m *mockPaymentAttemptDispatcher) CleanStore(
|
||||||
@@ -698,7 +701,6 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
|
|||||||
|
|
||||||
type mockControlTower struct {
|
type mockControlTower struct {
|
||||||
mock.Mock
|
mock.Mock
|
||||||
sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ControlTower = (*mockControlTower)(nil)
|
var _ ControlTower = (*mockControlTower)(nil)
|
||||||
@@ -718,9 +720,6 @@ func (m *mockControlTower) DeleteFailedAttempts(phash lntypes.Hash) error {
|
|||||||
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
|
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
|
||||||
a *channeldb.HTLCAttemptInfo) error {
|
a *channeldb.HTLCAttemptInfo) error {
|
||||||
|
|
||||||
m.Lock()
|
|
||||||
defer m.Unlock()
|
|
||||||
|
|
||||||
args := m.Called(phash, a)
|
args := m.Called(phash, a)
|
||||||
return args.Error(0)
|
return args.Error(0)
|
||||||
}
|
}
|
||||||
@@ -729,29 +728,32 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
|
|||||||
pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
|
pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
|
||||||
*channeldb.HTLCAttempt, error) {
|
*channeldb.HTLCAttempt, error) {
|
||||||
|
|
||||||
m.Lock()
|
|
||||||
defer m.Unlock()
|
|
||||||
|
|
||||||
args := m.Called(phash, pid, settleInfo)
|
args := m.Called(phash, pid, settleInfo)
|
||||||
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
|
|
||||||
|
attempt := args.Get(0)
|
||||||
|
if attempt == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return attempt.(*channeldb.HTLCAttempt), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
|
func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
|
||||||
failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
|
failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
|
||||||
|
|
||||||
m.Lock()
|
|
||||||
defer m.Unlock()
|
|
||||||
|
|
||||||
args := m.Called(phash, pid, failInfo)
|
args := m.Called(phash, pid, failInfo)
|
||||||
|
|
||||||
|
attempt := args.Get(0)
|
||||||
|
if attempt == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
|
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockControlTower) FailPayment(phash lntypes.Hash,
|
func (m *mockControlTower) FailPayment(phash lntypes.Hash,
|
||||||
reason channeldb.FailureReason) error {
|
reason channeldb.FailureReason) error {
|
||||||
|
|
||||||
m.Lock()
|
|
||||||
defer m.Unlock()
|
|
||||||
|
|
||||||
args := m.Called(phash, reason)
|
args := m.Called(phash, reason)
|
||||||
return args.Error(0)
|
return args.Error(0)
|
||||||
}
|
}
|
||||||
@@ -877,3 +879,70 @@ func (m *mockLink) EligibleToForward() bool {
|
|||||||
func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error {
|
func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error {
|
||||||
return m.mayAddOutgoingErr
|
return m.mayAddOutgoingErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockShardTracker struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ shards.ShardTracker = (*mockShardTracker)(nil)
|
||||||
|
|
||||||
|
func (m *mockShardTracker) NewShard(attemptID uint64,
|
||||||
|
lastShard bool) (shards.PaymentShard, error) {
|
||||||
|
|
||||||
|
args := m.Called(attemptID, lastShard)
|
||||||
|
|
||||||
|
shard := args.Get(0)
|
||||||
|
if shard == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return shard.(shards.PaymentShard), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockShardTracker) GetHash(attemptID uint64) (lntypes.Hash, error) {
|
||||||
|
args := m.Called(attemptID)
|
||||||
|
return args.Get(0).(lntypes.Hash), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockShardTracker) CancelShard(attemptID uint64) error {
|
||||||
|
args := m.Called(attemptID)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockShard struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ shards.PaymentShard = (*mockShard)(nil)
|
||||||
|
|
||||||
|
// Hash returns the hash used for the HTLC representing this shard.
|
||||||
|
func (m *mockShard) Hash() lntypes.Hash {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Get(0).(lntypes.Hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MPP returns any extra MPP records that should be set for the final
|
||||||
|
// hop on the route used by this shard.
|
||||||
|
func (m *mockShard) MPP() *record.MPP {
|
||||||
|
args := m.Called()
|
||||||
|
|
||||||
|
r := args.Get(0)
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.(*record.MPP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AMP returns any extra AMP records that should be set for the final
|
||||||
|
// hop on the route used by this shard.
|
||||||
|
func (m *mockShard) AMP() *record.AMP {
|
||||||
|
args := m.Called()
|
||||||
|
|
||||||
|
r := args.Get(0)
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.(*record.AMP)
|
||||||
|
}
|
||||||
|
@@ -3528,12 +3528,12 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
|
|||||||
).Return(nil)
|
).Return(nil)
|
||||||
|
|
||||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 1)
|
resultChan := make(chan *htlcswitch.PaymentResult, 1)
|
||||||
payer.On("GetAttemptResult",
|
payer.On("GetAttemptResult",
|
||||||
mock.Anything, mock.Anything, mock.Anything,
|
mock.Anything, mock.Anything, mock.Anything,
|
||||||
).Run(func(_ mock.Arguments) {
|
).Return(resultChan, nil).Run(func(_ mock.Arguments) {
|
||||||
// Send a successful payment result.
|
// Send a successful payment result.
|
||||||
payer.resultChan <- &htlcswitch.PaymentResult{}
|
resultChan <- &htlcswitch.PaymentResult{}
|
||||||
})
|
})
|
||||||
|
|
||||||
missionControl.On("ReportPaymentSuccess",
|
missionControl.On("ReportPaymentSuccess",
|
||||||
@@ -3599,6 +3599,11 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
// Create the error to be returned.
|
||||||
|
tempErr := htlcswitch.NewForwardingError(
|
||||||
|
&lnwire.FailTemporaryChannelFailure{}, 1,
|
||||||
|
)
|
||||||
|
|
||||||
// Register mockers with the expected method calls.
|
// Register mockers with the expected method calls.
|
||||||
controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
|
controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
|
||||||
controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
|
controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
|
||||||
@@ -3608,26 +3613,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
|
|||||||
|
|
||||||
payer.On("SendHTLC",
|
payer.On("SendHTLC",
|
||||||
mock.Anything, mock.Anything, mock.Anything,
|
mock.Anything, mock.Anything, mock.Anything,
|
||||||
).Return(nil)
|
).Return(tempErr)
|
||||||
|
|
||||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
|
||||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 1)
|
|
||||||
|
|
||||||
// Create the error to be returned.
|
|
||||||
tempErr := htlcswitch.NewForwardingError(
|
|
||||||
&lnwire.FailTemporaryChannelFailure{},
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Mock GetAttemptResult to return a failure.
|
|
||||||
payer.On("GetAttemptResult",
|
|
||||||
mock.Anything, mock.Anything, mock.Anything,
|
|
||||||
).Run(func(_ mock.Arguments) {
|
|
||||||
// Send an attempt failure.
|
|
||||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
|
||||||
Error: tempErr,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Mock the control tower to return the mocked payment.
|
// Mock the control tower to return the mocked payment.
|
||||||
payment := &mockMPPayment{}
|
payment := &mockMPPayment{}
|
||||||
|
Reference in New Issue
Block a user