From f85661d94a045255226d9b845d8f6a72fd624f55 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 29 Feb 2024 03:07:22 +0800 Subject: [PATCH] lnwallet+sweep: add new method `CheckMempoolAcceptance` --- lnmock/chain.go | 159 +++++++++++++++++++++++++++ lntest/mock/walletcontroller.go | 4 + lnwallet/btcwallet/btcwallet.go | 31 ++++++ lnwallet/btcwallet/btcwallet_test.go | 90 +++++++++++++++ lnwallet/interface.go | 5 + lnwallet/mock.go | 4 + sweep/interface.go | 5 + sweep/mock_test.go | 12 ++ 8 files changed, 310 insertions(+) create mode 100644 lnmock/chain.go diff --git a/lnmock/chain.go b/lnmock/chain.go new file mode 100644 index 000000000..dd208c33e --- /dev/null +++ b/lnmock/chain.go @@ -0,0 +1,159 @@ +package lnmock + +import ( + "github.com/btcsuite/btcd/btcjson" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/chain" + "github.com/btcsuite/btcwallet/waddrmgr" + "github.com/stretchr/testify/mock" +) + +// MockChain is a mock implementation of the Chain interface. +type MockChain struct { + mock.Mock +} + +// Compile-time constraint to ensure MockChain implements the Chain interface. +var _ chain.Interface = (*MockChain)(nil) + +func (m *MockChain) Start() error { + args := m.Called() + + return args.Error(0) +} + +func (m *MockChain) Stop() { + m.Called() +} + +func (m *MockChain) WaitForShutdown() { + m.Called() +} + +func (m *MockChain) GetBestBlock() (*chainhash.Hash, int32, error) { + args := m.Called() + + if args.Get(0) == nil { + return nil, args.Get(1).(int32), args.Error(2) + } + + return args.Get(0).(*chainhash.Hash), args.Get(1).(int32), args.Error(2) +} + +func (m *MockChain) GetBlock(hash *chainhash.Hash) (*wire.MsgBlock, error) { + args := m.Called(hash) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*wire.MsgBlock), args.Error(1) +} + +func (m *MockChain) GetBlockHash(height int64) (*chainhash.Hash, error) { + args := m.Called(height) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*chainhash.Hash), args.Error(1) +} + +func (m *MockChain) GetBlockHeader(hash *chainhash.Hash) ( + *wire.BlockHeader, error) { + + args := m.Called(hash) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*wire.BlockHeader), args.Error(1) +} + +func (m *MockChain) IsCurrent() bool { + args := m.Called() + + return args.Bool(0) +} + +func (m *MockChain) FilterBlocks(req *chain.FilterBlocksRequest) ( + *chain.FilterBlocksResponse, error) { + + args := m.Called(req) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*chain.FilterBlocksResponse), args.Error(1) +} + +func (m *MockChain) BlockStamp() (*waddrmgr.BlockStamp, error) { + args := m.Called() + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*waddrmgr.BlockStamp), args.Error(1) +} + +func (m *MockChain) SendRawTransaction(tx *wire.MsgTx, allowHighFees bool) ( + *chainhash.Hash, error) { + + args := m.Called(tx, allowHighFees) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*chainhash.Hash), args.Error(1) +} + +func (m *MockChain) Rescan(startHash *chainhash.Hash, addrs []btcutil.Address, + outPoints map[wire.OutPoint]btcutil.Address) error { + + args := m.Called(startHash, addrs, outPoints) + + return args.Error(0) +} + +func (m *MockChain) NotifyReceived(addrs []btcutil.Address) error { + args := m.Called(addrs) + + return args.Error(0) +} + +func (m *MockChain) NotifyBlocks() error { + args := m.Called() + + return args.Error(0) +} + +func (m *MockChain) Notifications() <-chan interface{} { + args := m.Called() + + return args.Get(0).(<-chan interface{}) +} + +func (m *MockChain) BackEnd() string { + args := m.Called() + + return args.String(0) +} + +func (m *MockChain) TestMempoolAccept(txns []*wire.MsgTx, maxFeeRate float64) ( + []*btcjson.TestMempoolAcceptResult, error) { + + args := m.Called(txns, maxFeeRate) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).([]*btcjson.TestMempoolAcceptResult), args.Error(1) +} diff --git a/lntest/mock/walletcontroller.go b/lntest/mock/walletcontroller.go index 6d09acd54..21d78add3 100644 --- a/lntest/mock/walletcontroller.go +++ b/lntest/mock/walletcontroller.go @@ -282,3 +282,7 @@ func (w *WalletController) FetchTx(chainhash.Hash) (*wire.MsgTx, error) { func (w *WalletController) RemoveDescendants(*wire.MsgTx) error { return nil } + +func (w *WalletController) CheckMempoolAcceptance(tx *wire.MsgTx) error { + return nil +} diff --git a/lnwallet/btcwallet/btcwallet.go b/lnwallet/btcwallet/btcwallet.go index ec4bc5d9b..ebca031c5 100644 --- a/lnwallet/btcwallet/btcwallet.go +++ b/lnwallet/btcwallet/btcwallet.go @@ -1898,3 +1898,34 @@ func (b *BtcWallet) RemoveDescendants(tx *wire.MsgTx) error { return b.wallet.TxStore.RemoveUnminedTx(wtxmgrNs, txRecord) }) } + +// CheckMempoolAcceptance is a wrapper around `TestMempoolAccept` which checks +// the mempool acceptance of a transaction. +func (b *BtcWallet) CheckMempoolAcceptance(tx *wire.MsgTx) error { + // Use a max feerate of 0 means the default value will be used when + // testing mempool acceptance. The default max feerate is 0.10 BTC/kvb, + // or 10,000 sat/vb. + results, err := b.chain.TestMempoolAccept([]*wire.MsgTx{tx}, 0) + if err != nil { + return err + } + + // Sanity check that the expected single result is returned. + if len(results) != 1 { + return fmt.Errorf("expected 1 result from TestMempoolAccept, "+ + "instead got %v", len(results)) + } + + result := results[0] + log.Debugf("TestMempoolAccept result: %s", spew.Sdump(result)) + + // Mempool check failed, we now map the reject reason to a proper RPC + // error and return it. + if !result.Allowed { + err := rpcclient.MapRPCErr(errors.New(result.RejectReason)) + + return fmt.Errorf("mempool rejection: %w", err) + } + + return nil +} diff --git a/lnwallet/btcwallet/btcwallet_test.go b/lnwallet/btcwallet/btcwallet_test.go index 28b783acc..892ec25fd 100644 --- a/lnwallet/btcwallet/btcwallet_test.go +++ b/lnwallet/btcwallet/btcwallet_test.go @@ -3,8 +3,12 @@ package btcwallet import ( "testing" + "github.com/btcsuite/btcd/btcjson" + "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/chain" "github.com/btcsuite/btcwallet/wallet" + "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" ) @@ -132,3 +136,89 @@ func TestPreviousOutpoints(t *testing.T) { }) } } + +// TestCheckMempoolAcceptance asserts the CheckMempoolAcceptance behaves as +// expected. +func TestCheckMempoolAcceptance(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock chain.Interface. + mockChain := &lnmock.MockChain{} + defer mockChain.AssertExpectations(t) + + // Create a test tx and a test max feerate. + tx := wire.NewMsgTx(2) + maxFeeRate := float64(0) + + // Create a test wallet. + wallet := &BtcWallet{ + chain: mockChain, + } + + // Assert that when the chain backend doesn't support + // `TestMempoolAccept`, an error is returned. + // + // Mock the chain backend to not support `TestMempoolAccept`. + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + nil, rpcclient.ErrBackendVersion).Once() + + err := wallet.CheckMempoolAcceptance(tx) + rt.ErrorIs(err, rpcclient.ErrBackendVersion) + + // Assert that when the chain backend doesn't implement + // `TestMempoolAccept`, an error is returned. + // + // Mock the chain backend to not support `TestMempoolAccept`. + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + nil, chain.ErrUnimplemented).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.ErrorIs(err, chain.ErrUnimplemented) + + // Assert that when the returned results are not as expected, an error + // is returned. + // + // Mock the chain backend to return more than one result. + results := []*btcjson.TestMempoolAcceptResult{ + {Txid: "txid1", Allowed: true}, + {Txid: "txid2", Allowed: false}, + } + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + results, nil).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.ErrorContains(err, "expected 1 result from TestMempoolAccept") + + // Assert that when the tx is rejected, the reason is converted to an + // RPC error and returned. + // + // Mock the chain backend to return one result. + results = []*btcjson.TestMempoolAcceptResult{{ + Txid: tx.TxHash().String(), + Allowed: false, + RejectReason: "insufficient fee", + }} + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + results, nil).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.ErrorIs(err, rpcclient.ErrInsufficientFee) + + // Assert that when the tx is accepted, no error is returned. + // + // Mock the chain backend to return one result. + results = []*btcjson.TestMempoolAcceptResult{ + {Txid: tx.TxHash().String(), Allowed: true}, + } + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + results, nil).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.NoError(err) +} diff --git a/lnwallet/interface.go b/lnwallet/interface.go index e26f4f291..59e6f5aab 100644 --- a/lnwallet/interface.go +++ b/lnwallet/interface.go @@ -536,6 +536,11 @@ type WalletController interface { // which could be e.g. btcd, bitcoind, neutrino, or another consensus // service. BackEnd() string + + // CheckMempoolAcceptance checks whether a transaction follows mempool + // policies and returns an error if it cannot be accepted into the + // mempool. + CheckMempoolAcceptance(tx *wire.MsgTx) error } // BlockChainIO is a dedicated source which will be used to obtain queries diff --git a/lnwallet/mock.go b/lnwallet/mock.go index f0f257ef0..0146df57e 100644 --- a/lnwallet/mock.go +++ b/lnwallet/mock.go @@ -294,6 +294,10 @@ func (w *mockWalletController) RemoveDescendants(*wire.MsgTx) error { return nil } +func (w *mockWalletController) CheckMempoolAcceptance(tx *wire.MsgTx) error { + return nil +} + // mockChainNotifier is a mock implementation of the ChainNotifier interface. type mockChainNotifier struct { SpendChan chan *chainntnfs.SpendDetail diff --git a/sweep/interface.go b/sweep/interface.go index a9de8bc57..e58cc8507 100644 --- a/sweep/interface.go +++ b/sweep/interface.go @@ -41,4 +41,9 @@ type Wallet interface { // used to ensure that invalid transactions (inputs spent) aren't // retried in the background. CancelRebroadcast(tx chainhash.Hash) + + // CheckMempoolAcceptance checks whether a transaction follows mempool + // policies and returns an error if it cannot be accepted into the + // mempool. + CheckMempoolAcceptance(tx *wire.MsgTx) error } diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 270c3844e..3688db72c 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -46,6 +46,10 @@ func newMockBackend(t *testing.T, notifier *MockNotifier) *mockBackend { } } +func (b *mockBackend) CheckMempoolAcceptance(tx *wire.MsgTx) error { + return nil +} + func (b *mockBackend) publishTransaction(tx *wire.MsgTx) error { b.lock.Lock() defer b.lock.Unlock() @@ -344,6 +348,14 @@ type MockWallet struct { // Compile-time constraint to ensure MockWallet implements Wallet. var _ Wallet = (*MockWallet)(nil) +// CheckMempoolAcceptance checks if the transaction can be accepted to the +// mempool. +func (m *MockWallet) CheckMempoolAcceptance(tx *wire.MsgTx) error { + args := m.Called(tx) + + return args.Error(0) +} + // PublishTransaction performs cursory validation (dust checks, etc) and // broadcasts the passed transaction to the Bitcoin network. func (m *MockWallet) PublishTransaction(tx *wire.MsgTx, label string) error {