From f13a3a80538e4d658a95d4356d5a02eb323851f4 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 11 Jan 2024 04:18:40 +0800 Subject: [PATCH] sweep: use `testify/mock` for `MockSweeperStore` --- sweep/store_mock.go | 40 ++++++++++++++++++++++------------------ sweep/store_test.go | 40 ++++++++-------------------------------- sweep/sweeper_test.go | 42 +++++++++++++++++++++++------------------- 3 files changed, 53 insertions(+), 69 deletions(-) diff --git a/sweep/store_mock.go b/sweep/store_mock.go index 16f7714a9..73b797963 100644 --- a/sweep/store_mock.go +++ b/sweep/store_mock.go @@ -2,54 +2,58 @@ package sweep import ( "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/stretchr/testify/mock" ) // MockSweeperStore is a mock implementation of sweeper store. This type is // exported, because it is currently used in nursery tests too. type MockSweeperStore struct { - ourTxes map[chainhash.Hash]struct{} + mock.Mock } // NewMockSweeperStore returns a new instance. func NewMockSweeperStore() *MockSweeperStore { - return &MockSweeperStore{ - ourTxes: make(map[chainhash.Hash]struct{}), - } + return &MockSweeperStore{} } -// IsOurTx determines whether a tx is published by us, based on its -// hash. +// IsOurTx determines whether a tx is published by us, based on its hash. func (s *MockSweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) { - _, ok := s.ourTxes[hash] - return ok, nil + args := s.Called(hash) + + return args.Bool(0), args.Error(1) } // StoreTx stores a tx we are about to publish. func (s *MockSweeperStore) StoreTx(tr *TxRecord) error { - s.ourTxes[tr.Txid] = struct{}{} - - return nil + args := s.Called(tr) + return args.Error(0) } // ListSweeps lists all the sweeps we have successfully published. func (s *MockSweeperStore) ListSweeps() ([]chainhash.Hash, error) { - var txns []chainhash.Hash - for tx := range s.ourTxes { - txns = append(txns, tx) - } + args := s.Called() - return txns, nil + return args.Get(0).([]chainhash.Hash), args.Error(1) } // GetTx queries the database to find the tx that matches the given txid. // Returns ErrTxNotFound if it cannot be found. func (s *MockSweeperStore) GetTx(hash chainhash.Hash) (*TxRecord, error) { - return nil, ErrTxNotFound + args := s.Called(hash) + + tr := args.Get(0) + if tr != nil { + return args.Get(0).(*TxRecord), args.Error(1) + } + + return nil, args.Error(1) } // DeleteTx removes the given tx from db. func (s *MockSweeperStore) DeleteTx(txid chainhash.Hash) error { - return nil + args := s.Called(txid) + + return args.Error(0) } // Compile-time constraint to ensure MockSweeperStore implements SweeperStore. diff --git a/sweep/store_test.go b/sweep/store_test.go index 7cfc649c9..ea65b0177 100644 --- a/sweep/store_test.go +++ b/sweep/store_test.go @@ -14,35 +14,13 @@ import ( // TestStore asserts that the store persists the presented data to disk and is // able to retrieve it again. func TestStore(t *testing.T) { - t.Run("bolt", func(t *testing.T) { + // Create new store. + cdb, err := channeldb.MakeTestDB(t) + require.NoError(t, err) - // Create new store. - cdb, err := channeldb.MakeTestDB(t) - if err != nil { - t.Fatalf("unable to open channel db: %v", err) - } - - testStore(t, func() (SweeperStore, error) { - var chain chainhash.Hash - return NewSweeperStore(cdb, &chain) - }) - }) - t.Run("mock", func(t *testing.T) { - store := NewMockSweeperStore() - - testStore(t, func() (SweeperStore, error) { - // Return same store, because the mock has no real - // persistence. - return store, nil - }) - }) -} - -func testStore(t *testing.T, createStore func() (SweeperStore, error)) { - store, err := createStore() - if err != nil { - t.Fatal(err) - } + var chain chainhash.Hash + store, err := NewSweeperStore(cdb, &chain) + require.NoError(t, err) // Notify publication of tx1 tx1 := wire.MsgTx{} @@ -75,10 +53,8 @@ func testStore(t *testing.T, createStore func() (SweeperStore, error)) { require.NoError(t, err) // Recreate the sweeper store - store, err = createStore() - if err != nil { - t.Fatal(err) - } + store, err = NewSweeperStore(cdb, &chain) + require.NoError(t, err) // Assert that both txes are recognized as our own. ours, err := store.IsOurTx(tx1.TxHash()) diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 3054c9f0c..2003254cd 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntest/mock" @@ -41,7 +42,7 @@ type sweeperTestContext struct { notifier *MockNotifier estimator *mockFeeEstimator backend *mockBackend - store *MockSweeperStore + store SweeperStore publishChan chan wire.MsgTx } @@ -102,7 +103,13 @@ func init() { func createSweeperTestContext(t *testing.T) *sweeperTestContext { notifier := NewMockNotifier(t) - store := NewMockSweeperStore() + // Create new store. + cdb, err := channeldb.MakeTestDB(t) + require.NoError(t, err) + + var chain chainhash.Hash + store, err := NewSweeperStore(cdb, &chain) + require.NoError(t, err) backend := newMockBackend(t, notifier) backend.walletUtxos = []*lnwallet.Utxo{ @@ -682,7 +689,6 @@ func TestIdempotency(t *testing.T) { // Timer is still running, but spend notification was delivered before // it expired. - ctx.finish(1) } @@ -701,9 +707,8 @@ func TestRestart(t *testing.T) { // Sweep input and expect sweep tx. input1 := spendableInputs[0] - if _, err := ctx.sweeper.SweepInput(input1, defaultFeePref); err != nil { - t.Fatal(err) - } + _, err := ctx.sweeper.SweepInput(input1, defaultFeePref) + require.NoError(t, err) ctx.receiveTx() @@ -758,23 +763,20 @@ func TestRestart(t *testing.T) { ctx.finish(1) } -// TestRestartRemoteSpend asserts that the sweeper picks up sweeping properly after -// a restart with remote spend. +// TestRestartRemoteSpend asserts that the sweeper picks up sweeping properly +// after a restart with remote spend. func TestRestartRemoteSpend(t *testing.T) { - ctx := createSweeperTestContext(t) // Sweep input. input1 := spendableInputs[0] - if _, err := ctx.sweeper.SweepInput(input1, defaultFeePref); err != nil { - t.Fatal(err) - } + _, err := ctx.sweeper.SweepInput(input1, defaultFeePref) + require.NoError(t, err) // Sweep another input. input2 := spendableInputs[1] - if _, err := ctx.sweeper.SweepInput(input2, defaultFeePref); err != nil { - t.Fatal(err) - } + _, err = ctx.sweeper.SweepInput(input2, defaultFeePref) + require.NoError(t, err) sweepTx := ctx.receiveTx() @@ -798,7 +800,8 @@ func TestRestartRemoteSpend(t *testing.T) { // Mine remote spending tx. ctx.backend.mine() - // Simulate other subsystem (e.g. contract resolver) re-offering input 0. + // Simulate other subsystem (e.g. contract resolver) re-offering input + // 0. spendChan, err := ctx.sweeper.SweepInput(input1, defaultFeePref) if err != nil { t.Fatal(err) @@ -815,8 +818,8 @@ func TestRestartRemoteSpend(t *testing.T) { ctx.finish(1) } -// TestRestartConfirmed asserts that the sweeper picks up sweeping properly after -// a restart with a confirm of our own sweep tx. +// TestRestartConfirmed asserts that the sweeper picks up sweeping properly +// after a restart with a confirm of our own sweep tx. func TestRestartConfirmed(t *testing.T) { ctx := createSweeperTestContext(t) @@ -834,7 +837,8 @@ func TestRestartConfirmed(t *testing.T) { // Mine the sweep tx. ctx.backend.mine() - // Simulate other subsystem (e.g. contract resolver) re-offering input 0. + // Simulate other subsystem (e.g. contract resolver) re-offering input + // 0. spendChan, err := ctx.sweeper.SweepInput(input, defaultFeePref) if err != nil { t.Fatal(err)