sweep: add mocks and patch unit test for sweepPendingInputs

This commit is contained in:
yyforyongyu 2024-01-31 03:25:58 +08:00
parent 210b7838c7
commit c03509397f
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
2 changed files with 178 additions and 0 deletions

View File

@ -7,6 +7,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/stretchr/testify/mock"
@ -332,3 +333,116 @@ func (m *mockUtxoAggregator) ClusterInputs(inputs pendingInputs) []InputSet {
return args.Get(0).([]InputSet)
}
// MockWallet is a mock implementation of the Wallet interface.
type MockWallet struct {
mock.Mock
}
// Compile-time constraint to ensure MockWallet implements Wallet.
var _ Wallet = (*MockWallet)(nil)
// PublishTransaction performs cursory validation (dust checks, etc) and
// broadcasts the passed transaction to the Bitcoin network.
func (m *MockWallet) PublishTransaction(tx *wire.MsgTx, label string) error {
args := m.Called(tx, label)
return args.Error(0)
}
// ListUnspentWitnessFromDefaultAccount returns all unspent outputs which are
// version 0 witness programs from the default wallet account. The 'minConfs'
// and 'maxConfs' parameters indicate the minimum and maximum number of
// confirmations an output needs in order to be returned by this method.
func (m *MockWallet) ListUnspentWitnessFromDefaultAccount(
minConfs, maxConfs int32) ([]*lnwallet.Utxo, error) {
args := m.Called(minConfs, maxConfs)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]*lnwallet.Utxo), args.Error(1)
}
// WithCoinSelectLock will execute the passed function closure in a
// synchronized manner preventing any coin selection operations from proceeding
// while the closure is executing. This can be seen as the ability to execute a
// function closure under an exclusive coin selection lock.
func (m *MockWallet) WithCoinSelectLock(f func() error) error {
m.Called(f)
return f()
}
// RemoveDescendants removes any wallet transactions that spends
// outputs created by the specified transaction.
func (m *MockWallet) RemoveDescendants(tx *wire.MsgTx) error {
args := m.Called(tx)
return args.Error(0)
}
// FetchTx returns the transaction that corresponds to the transaction
// hash passed in. If the transaction can't be found then a nil
// transaction pointer is returned.
func (m *MockWallet) FetchTx(txid chainhash.Hash) (*wire.MsgTx, error) {
args := m.Called(txid)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*wire.MsgTx), args.Error(1)
}
// CancelRebroadcast is used to inform the rebroadcaster sub-system
// that it no longer needs to try to rebroadcast a transaction. This is
// used to ensure that invalid transactions (inputs spent) aren't
// retried in the background.
func (m *MockWallet) CancelRebroadcast(tx chainhash.Hash) {
m.Called(tx)
}
// MockInputSet is a mock implementation of the InputSet interface.
type MockInputSet struct {
mock.Mock
}
// Compile-time constraint to ensure MockInputSet implements InputSet.
var _ InputSet = (*MockInputSet)(nil)
// Inputs returns the set of inputs that should be used to create a tx.
func (m *MockInputSet) Inputs() []input.Input {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).([]input.Input)
}
// FeeRate returns the fee rate that should be used for the tx.
func (m *MockInputSet) FeeRate() chainfee.SatPerKWeight {
args := m.Called()
return args.Get(0).(chainfee.SatPerKWeight)
}
// AddWalletInputs adds wallet inputs to the set until a non-dust
// change output can be made. Return an error if there are not enough
// wallet inputs.
func (m *MockInputSet) AddWalletInputs(wallet Wallet) error {
args := m.Called(wallet)
return args.Error(0)
}
// NeedWalletInput returns true if the input set needs more wallet
// inputs.
func (m *MockInputSet) NeedWalletInput() bool {
args := m.Called()
return args.Bool(0)
}

View File

@ -21,6 +21,7 @@ import (
lnmock "github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
@ -2413,3 +2414,66 @@ func TestMarkInputFailed(t *testing.T) {
// Assert the state is updated.
require.Equal(t, StateFailed, pi.state)
}
// TestSweepPendingInputs checks that `sweepPendingInputs` correctly executes
// its workflow based on the returned values from the interfaces.
func TestSweepPendingInputs(t *testing.T) {
t.Parallel()
// Create a mock wallet and aggregator.
wallet := &MockWallet{}
aggregator := &mockUtxoAggregator{}
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
Wallet: wallet,
Aggregator: aggregator,
})
// Create an input set that needs wallet inputs.
setNeedWallet := &MockInputSet{}
// Mock this set to ask for wallet input.
setNeedWallet.On("NeedWalletInput").Return(true).Once()
setNeedWallet.On("AddWalletInputs", wallet).Return(nil).Once()
// Mock the wallet to require the lock once.
wallet.On("WithCoinSelectLock", mock.Anything).Return(nil).Once()
// Create an input set that doesn't need wallet inputs.
normalSet := &MockInputSet{}
normalSet.On("NeedWalletInput").Return(false).Once()
// Mock the methods used in `sweep`. This is not important for this
// unit test.
feeRate := chainfee.SatPerKWeight(1000)
setNeedWallet.On("Inputs").Return(nil).Once()
setNeedWallet.On("FeeRate").Return(feeRate).Once()
normalSet.On("Inputs").Return(nil).Once()
normalSet.On("FeeRate").Return(feeRate).Once()
// Make pending inputs for testing. We don't need real values here as
// the returned clusters are mocked.
pis := make(pendingInputs)
// Mock the aggregator to return the mocked input sets.
aggregator.On("ClusterInputs", pis).Return([]InputSet{
setNeedWallet, normalSet,
})
// Set change output script to an invalid value. This should cause the
// `createSweepTx` inside `sweep` to fail. This is done so we can
// terminate the method early as we are only interested in testing the
// workflow in `sweepPendingInputs`. We don't need to test `sweep` here
// as it should be tested in its own unit test.
s.currentOutputScript = []byte{1}
// Call the method under test.
s.sweepPendingInputs(pis)
// Assert mocked methods are called as expected.
wallet.AssertExpectations(t)
aggregator.AssertExpectations(t)
setNeedWallet.AssertExpectations(t)
normalSet.AssertExpectations(t)
}