mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-28 14:40:51 +02:00
routing: update mockers in unit test
This commit adds more mockers to be used in coming unit tests and simplified the mockers to be more straightforward.
This commit is contained in:
@@ -12,7 +12,9 @@ import (
|
||||
"github.com/lightningnetwork/lnd/htlcswitch"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/routing/shards"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
@@ -572,8 +574,6 @@ func (m *mockControlTowerOld) SubscribeAllPayments() (
|
||||
|
||||
type mockPaymentAttemptDispatcher struct {
|
||||
mock.Mock
|
||||
|
||||
resultChan chan *htlcswitch.PaymentResult
|
||||
}
|
||||
|
||||
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
|
||||
@@ -589,11 +589,14 @@ func (m *mockPaymentAttemptDispatcher) GetAttemptResult(attemptID uint64,
|
||||
paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
|
||||
<-chan *htlcswitch.PaymentResult, error) {
|
||||
|
||||
m.Called(attemptID, paymentHash, deobfuscator)
|
||||
args := m.Called(attemptID, paymentHash, deobfuscator)
|
||||
|
||||
// Instead of returning the mocked returned values, we need to return
|
||||
// the chan resultChan so it can be converted into a read-only chan.
|
||||
return m.resultChan, nil
|
||||
resultChan := args.Get(0)
|
||||
if resultChan == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(chan *htlcswitch.PaymentResult), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockPaymentAttemptDispatcher) CleanStore(
|
||||
@@ -698,7 +701,6 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
|
||||
|
||||
type mockControlTower struct {
|
||||
mock.Mock
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
var _ ControlTower = (*mockControlTower)(nil)
|
||||
@@ -718,9 +720,6 @@ func (m *mockControlTower) DeleteFailedAttempts(phash lntypes.Hash) error {
|
||||
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
|
||||
a *channeldb.HTLCAttemptInfo) error {
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
args := m.Called(phash, a)
|
||||
return args.Error(0)
|
||||
}
|
||||
@@ -729,29 +728,32 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
|
||||
pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
|
||||
*channeldb.HTLCAttempt, error) {
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
args := m.Called(phash, pid, settleInfo)
|
||||
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
|
||||
|
||||
attempt := args.Get(0)
|
||||
if attempt == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return attempt.(*channeldb.HTLCAttempt), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
|
||||
failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
args := m.Called(phash, pid, failInfo)
|
||||
|
||||
attempt := args.Get(0)
|
||||
if attempt == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) FailPayment(phash lntypes.Hash,
|
||||
reason channeldb.FailureReason) error {
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
args := m.Called(phash, reason)
|
||||
return args.Error(0)
|
||||
}
|
||||
@@ -877,3 +879,70 @@ func (m *mockLink) EligibleToForward() bool {
|
||||
func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error {
|
||||
return m.mayAddOutgoingErr
|
||||
}
|
||||
|
||||
type mockShardTracker struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
var _ shards.ShardTracker = (*mockShardTracker)(nil)
|
||||
|
||||
func (m *mockShardTracker) NewShard(attemptID uint64,
|
||||
lastShard bool) (shards.PaymentShard, error) {
|
||||
|
||||
args := m.Called(attemptID, lastShard)
|
||||
|
||||
shard := args.Get(0)
|
||||
if shard == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return shard.(shards.PaymentShard), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockShardTracker) GetHash(attemptID uint64) (lntypes.Hash, error) {
|
||||
args := m.Called(attemptID)
|
||||
return args.Get(0).(lntypes.Hash), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockShardTracker) CancelShard(attemptID uint64) error {
|
||||
args := m.Called(attemptID)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type mockShard struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
var _ shards.PaymentShard = (*mockShard)(nil)
|
||||
|
||||
// Hash returns the hash used for the HTLC representing this shard.
|
||||
func (m *mockShard) Hash() lntypes.Hash {
|
||||
args := m.Called()
|
||||
return args.Get(0).(lntypes.Hash)
|
||||
}
|
||||
|
||||
// MPP returns any extra MPP records that should be set for the final
|
||||
// hop on the route used by this shard.
|
||||
func (m *mockShard) MPP() *record.MPP {
|
||||
args := m.Called()
|
||||
|
||||
r := args.Get(0)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.(*record.MPP)
|
||||
}
|
||||
|
||||
// AMP returns any extra AMP records that should be set for the final
|
||||
// hop on the route used by this shard.
|
||||
func (m *mockShard) AMP() *record.AMP {
|
||||
args := m.Called()
|
||||
|
||||
r := args.Get(0)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.(*record.AMP)
|
||||
}
|
||||
|
Reference in New Issue
Block a user