From c03509397f58b568aaa13a5cd8f447982dbe073f Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 31 Jan 2024 03:25:58 +0800 Subject: [PATCH] sweep: add mocks and patch unit test for `sweepPendingInputs` --- sweep/mock_test.go | 114 ++++++++++++++++++++++++++++++++++++++++++ sweep/sweeper_test.go | 64 ++++++++++++++++++++++++ 2 files changed, 178 insertions(+) diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 5ef6b78a2..fc7ff9c34 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/mock" @@ -332,3 +333,116 @@ func (m *mockUtxoAggregator) ClusterInputs(inputs pendingInputs) []InputSet { return args.Get(0).([]InputSet) } + +// MockWallet is a mock implementation of the Wallet interface. +type MockWallet struct { + mock.Mock +} + +// Compile-time constraint to ensure MockWallet implements Wallet. +var _ Wallet = (*MockWallet)(nil) + +// PublishTransaction performs cursory validation (dust checks, etc) and +// broadcasts the passed transaction to the Bitcoin network. +func (m *MockWallet) PublishTransaction(tx *wire.MsgTx, label string) error { + args := m.Called(tx, label) + + return args.Error(0) +} + +// ListUnspentWitnessFromDefaultAccount returns all unspent outputs which are +// version 0 witness programs from the default wallet account. The 'minConfs' +// and 'maxConfs' parameters indicate the minimum and maximum number of +// confirmations an output needs in order to be returned by this method. +func (m *MockWallet) ListUnspentWitnessFromDefaultAccount( + minConfs, maxConfs int32) ([]*lnwallet.Utxo, error) { + + args := m.Called(minConfs, maxConfs) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).([]*lnwallet.Utxo), args.Error(1) +} + +// WithCoinSelectLock will execute the passed function closure in a +// synchronized manner preventing any coin selection operations from proceeding +// while the closure is executing. This can be seen as the ability to execute a +// function closure under an exclusive coin selection lock. +func (m *MockWallet) WithCoinSelectLock(f func() error) error { + m.Called(f) + + return f() +} + +// RemoveDescendants removes any wallet transactions that spends +// outputs created by the specified transaction. +func (m *MockWallet) RemoveDescendants(tx *wire.MsgTx) error { + args := m.Called(tx) + + return args.Error(0) +} + +// FetchTx returns the transaction that corresponds to the transaction +// hash passed in. If the transaction can't be found then a nil +// transaction pointer is returned. +func (m *MockWallet) FetchTx(txid chainhash.Hash) (*wire.MsgTx, error) { + args := m.Called(txid) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*wire.MsgTx), args.Error(1) +} + +// CancelRebroadcast is used to inform the rebroadcaster sub-system +// that it no longer needs to try to rebroadcast a transaction. This is +// used to ensure that invalid transactions (inputs spent) aren't +// retried in the background. +func (m *MockWallet) CancelRebroadcast(tx chainhash.Hash) { + m.Called(tx) +} + +// MockInputSet is a mock implementation of the InputSet interface. +type MockInputSet struct { + mock.Mock +} + +// Compile-time constraint to ensure MockInputSet implements InputSet. +var _ InputSet = (*MockInputSet)(nil) + +// Inputs returns the set of inputs that should be used to create a tx. +func (m *MockInputSet) Inputs() []input.Input { + args := m.Called() + + if args.Get(0) == nil { + return nil + } + + return args.Get(0).([]input.Input) +} + +// FeeRate returns the fee rate that should be used for the tx. +func (m *MockInputSet) FeeRate() chainfee.SatPerKWeight { + args := m.Called() + + return args.Get(0).(chainfee.SatPerKWeight) +} + +// AddWalletInputs adds wallet inputs to the set until a non-dust +// change output can be made. Return an error if there are not enough +// wallet inputs. +func (m *MockInputSet) AddWalletInputs(wallet Wallet) error { + args := m.Called(wallet) + + return args.Error(0) +} + +// NeedWalletInput returns true if the input set needs more wallet +// inputs. +func (m *MockInputSet) NeedWalletInput() bool { + args := m.Called() + + return args.Bool(0) +} diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index e7cd27827..7396407b7 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -21,6 +21,7 @@ import ( lnmock "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -2413,3 +2414,66 @@ func TestMarkInputFailed(t *testing.T) { // Assert the state is updated. require.Equal(t, StateFailed, pi.state) } + +// TestSweepPendingInputs checks that `sweepPendingInputs` correctly executes +// its workflow based on the returned values from the interfaces. +func TestSweepPendingInputs(t *testing.T) { + t.Parallel() + + // Create a mock wallet and aggregator. + wallet := &MockWallet{} + aggregator := &mockUtxoAggregator{} + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Wallet: wallet, + Aggregator: aggregator, + }) + + // Create an input set that needs wallet inputs. + setNeedWallet := &MockInputSet{} + + // Mock this set to ask for wallet input. + setNeedWallet.On("NeedWalletInput").Return(true).Once() + setNeedWallet.On("AddWalletInputs", wallet).Return(nil).Once() + + // Mock the wallet to require the lock once. + wallet.On("WithCoinSelectLock", mock.Anything).Return(nil).Once() + + // Create an input set that doesn't need wallet inputs. + normalSet := &MockInputSet{} + normalSet.On("NeedWalletInput").Return(false).Once() + + // Mock the methods used in `sweep`. This is not important for this + // unit test. + feeRate := chainfee.SatPerKWeight(1000) + setNeedWallet.On("Inputs").Return(nil).Once() + setNeedWallet.On("FeeRate").Return(feeRate).Once() + normalSet.On("Inputs").Return(nil).Once() + normalSet.On("FeeRate").Return(feeRate).Once() + + // Make pending inputs for testing. We don't need real values here as + // the returned clusters are mocked. + pis := make(pendingInputs) + + // Mock the aggregator to return the mocked input sets. + aggregator.On("ClusterInputs", pis).Return([]InputSet{ + setNeedWallet, normalSet, + }) + + // Set change output script to an invalid value. This should cause the + // `createSweepTx` inside `sweep` to fail. This is done so we can + // terminate the method early as we are only interested in testing the + // workflow in `sweepPendingInputs`. We don't need to test `sweep` here + // as it should be tested in its own unit test. + s.currentOutputScript = []byte{1} + + // Call the method under test. + s.sweepPendingInputs(pis) + + // Assert mocked methods are called as expected. + wallet.AssertExpectations(t) + aggregator.AssertExpectations(t) + setNeedWallet.AssertExpectations(t) + normalSet.AssertExpectations(t) +}