routing: update mockers in unit test

This commit adds more mockers to be used in coming unit tests and
simplified the mockers to be more straightforward.
This commit is contained in:
yyforyongyu
2023-02-13 20:57:18 +08:00
parent 01e3bd87ab
commit ddad6ad4c4
2 changed files with 98 additions and 43 deletions

View File

@@ -12,7 +12,9 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/routing/shards"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
) )
@@ -572,8 +574,6 @@ func (m *mockControlTowerOld) SubscribeAllPayments() (
type mockPaymentAttemptDispatcher struct { type mockPaymentAttemptDispatcher struct {
mock.Mock mock.Mock
resultChan chan *htlcswitch.PaymentResult
} }
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
@@ -589,11 +589,14 @@ func (m *mockPaymentAttemptDispatcher) GetAttemptResult(attemptID uint64,
paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) ( paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
<-chan *htlcswitch.PaymentResult, error) { <-chan *htlcswitch.PaymentResult, error) {
m.Called(attemptID, paymentHash, deobfuscator) args := m.Called(attemptID, paymentHash, deobfuscator)
// Instead of returning the mocked returned values, we need to return resultChan := args.Get(0)
// the chan resultChan so it can be converted into a read-only chan. if resultChan == nil {
return m.resultChan, nil return nil, args.Error(1)
}
return args.Get(0).(chan *htlcswitch.PaymentResult), args.Error(1)
} }
func (m *mockPaymentAttemptDispatcher) CleanStore( func (m *mockPaymentAttemptDispatcher) CleanStore(
@@ -698,7 +701,6 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
type mockControlTower struct { type mockControlTower struct {
mock.Mock mock.Mock
sync.Mutex
} }
var _ ControlTower = (*mockControlTower)(nil) var _ ControlTower = (*mockControlTower)(nil)
@@ -718,9 +720,6 @@ func (m *mockControlTower) DeleteFailedAttempts(phash lntypes.Hash) error {
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
a *channeldb.HTLCAttemptInfo) error { a *channeldb.HTLCAttemptInfo) error {
m.Lock()
defer m.Unlock()
args := m.Called(phash, a) args := m.Called(phash, a)
return args.Error(0) return args.Error(0)
} }
@@ -729,29 +728,32 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
pid uint64, settleInfo *channeldb.HTLCSettleInfo) ( pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
*channeldb.HTLCAttempt, error) { *channeldb.HTLCAttempt, error) {
m.Lock()
defer m.Unlock()
args := m.Called(phash, pid, settleInfo) args := m.Called(phash, pid, settleInfo)
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
attempt := args.Get(0)
if attempt == nil {
return nil, args.Error(1)
}
return attempt.(*channeldb.HTLCAttempt), args.Error(1)
} }
func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) { failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
m.Lock()
defer m.Unlock()
args := m.Called(phash, pid, failInfo) args := m.Called(phash, pid, failInfo)
attempt := args.Get(0)
if attempt == nil {
return nil, args.Error(1)
}
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
} }
func (m *mockControlTower) FailPayment(phash lntypes.Hash, func (m *mockControlTower) FailPayment(phash lntypes.Hash,
reason channeldb.FailureReason) error { reason channeldb.FailureReason) error {
m.Lock()
defer m.Unlock()
args := m.Called(phash, reason) args := m.Called(phash, reason)
return args.Error(0) return args.Error(0)
} }
@@ -877,3 +879,70 @@ func (m *mockLink) EligibleToForward() bool {
func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error { func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error {
return m.mayAddOutgoingErr return m.mayAddOutgoingErr
} }
type mockShardTracker struct {
mock.Mock
}
var _ shards.ShardTracker = (*mockShardTracker)(nil)
func (m *mockShardTracker) NewShard(attemptID uint64,
lastShard bool) (shards.PaymentShard, error) {
args := m.Called(attemptID, lastShard)
shard := args.Get(0)
if shard == nil {
return nil, args.Error(1)
}
return shard.(shards.PaymentShard), args.Error(1)
}
func (m *mockShardTracker) GetHash(attemptID uint64) (lntypes.Hash, error) {
args := m.Called(attemptID)
return args.Get(0).(lntypes.Hash), args.Error(1)
}
func (m *mockShardTracker) CancelShard(attemptID uint64) error {
args := m.Called(attemptID)
return args.Error(0)
}
type mockShard struct {
mock.Mock
}
var _ shards.PaymentShard = (*mockShard)(nil)
// Hash returns the hash used for the HTLC representing this shard.
func (m *mockShard) Hash() lntypes.Hash {
args := m.Called()
return args.Get(0).(lntypes.Hash)
}
// MPP returns any extra MPP records that should be set for the final
// hop on the route used by this shard.
func (m *mockShard) MPP() *record.MPP {
args := m.Called()
r := args.Get(0)
if r == nil {
return nil
}
return r.(*record.MPP)
}
// AMP returns any extra AMP records that should be set for the final
// hop on the route used by this shard.
func (m *mockShard) AMP() *record.AMP {
args := m.Called()
r := args.Get(0)
if r == nil {
return nil
}
return r.(*record.AMP)
}

View File

@@ -3528,12 +3528,12 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
).Return(nil) ).Return(nil)
// Create a buffered chan and it will be returned by GetAttemptResult. // Create a buffered chan and it will be returned by GetAttemptResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 1) resultChan := make(chan *htlcswitch.PaymentResult, 1)
payer.On("GetAttemptResult", payer.On("GetAttemptResult",
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Run(func(_ mock.Arguments) { ).Return(resultChan, nil).Run(func(_ mock.Arguments) {
// Send a successful payment result. // Send a successful payment result.
payer.resultChan <- &htlcswitch.PaymentResult{} resultChan <- &htlcswitch.PaymentResult{}
}) })
missionControl.On("ReportPaymentSuccess", missionControl.On("ReportPaymentSuccess",
@@ -3599,6 +3599,11 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
}, },
}} }}
// Create the error to be returned.
tempErr := htlcswitch.NewForwardingError(
&lnwire.FailTemporaryChannelFailure{}, 1,
)
// Register mockers with the expected method calls. // Register mockers with the expected method calls.
controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
@@ -3608,26 +3613,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
payer.On("SendHTLC", payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(nil) ).Return(tempErr)
// Create a buffered chan and it will be returned by GetAttemptResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 1)
// Create the error to be returned.
tempErr := htlcswitch.NewForwardingError(
&lnwire.FailTemporaryChannelFailure{},
1,
)
// Mock GetAttemptResult to return a failure.
payer.On("GetAttemptResult",
mock.Anything, mock.Anything, mock.Anything,
).Run(func(_ mock.Arguments) {
// Send an attempt failure.
payer.resultChan <- &htlcswitch.PaymentResult{
Error: tempErr,
}
})
// Mock the control tower to return the mocked payment. // Mock the control tower to return the mocked payment.
payment := &mockMPPayment{} payment := &mockMPPayment{}