diff --git a/chainntnfs/bitcoindnotify/bitcoind_test.go b/chainntnfs/bitcoindnotify/bitcoind_test.go index be336ef6c..fa51efa5e 100644 --- a/chainntnfs/bitcoindnotify/bitcoind_test.go +++ b/chainntnfs/bitcoindnotify/bitcoind_test.go @@ -37,11 +37,7 @@ var ( func initHintCache(t *testing.T) *channeldb.HeightHintCache { t.Helper() - db, err := channeldb.Open(t.TempDir()) - require.NoError(t, err, "unable to create db") - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) + db := channeldb.OpenForTesting(t, t.TempDir()) testCfg := channeldb.CacheConfig{ QueryDisable: false, diff --git a/chainntnfs/btcdnotify/btcd_test.go b/chainntnfs/btcdnotify/btcd_test.go index 1cfbff731..6a1b97854 100644 --- a/chainntnfs/btcdnotify/btcd_test.go +++ b/chainntnfs/btcdnotify/btcd_test.go @@ -33,11 +33,7 @@ var ( func initHintCache(t *testing.T) *channeldb.HeightHintCache { t.Helper() - db, err := channeldb.Open(t.TempDir()) - require.NoError(t, err, "unable to create db") - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) + db := channeldb.OpenForTesting(t, t.TempDir()) testCfg := channeldb.CacheConfig{ QueryDisable: false, diff --git a/chainntnfs/test/test_interface.go b/chainntnfs/test/test_interface.go index 35e63a45e..99daf54f1 100644 --- a/chainntnfs/test/test_interface.go +++ b/chainntnfs/test/test_interface.go @@ -1906,10 +1906,8 @@ func TestInterfaces(t *testing.T, targetBackEnd string) { // Initialize a height hint cache for each notifier. tempDir := t.TempDir() - db, err := channeldb.Open(tempDir) - if err != nil { - t.Fatalf("unable to create db: %v", err) - } + db := channeldb.OpenForTesting(t, tempDir) + testCfg := channeldb.CacheConfig{ QueryDisable: false, } diff --git a/channeldb/db.go b/channeldb/db.go index 70de0aaf7..bf7909ba5 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -34,6 +34,7 @@ import ( "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) const ( @@ -345,10 +346,11 @@ type DB struct { noRevLogAmtData bool } -// Open opens or creates channeldb. Any necessary schemas migrations due -// to updates will take place as necessary. -// TODO(bhandras): deprecate this function. -func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) { +// OpenForTesting opens or creates a channeldb to be used for tests. Any +// necessary schemas migrations due to updates will take place as necessary. +func OpenForTesting(t testing.TB, dbPath string, + modifiers ...OptionModifier) *DB { + backend, err := kvdb.GetBoltBackend(&kvdb.BoltBackendConfig{ DBPath: dbPath, DBFileName: dbName, @@ -357,16 +359,18 @@ func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) { AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge, DBTimeout: kvdb.DefaultDBTimeout, }) - if err != nil { - return nil, err - } + require.NoError(t, err) db, err := CreateWithBackend(backend, modifiers...) - if err == nil { - db.dbPath = dbPath - } + require.NoError(t, err) - return db, err + db.dbPath = dbPath + + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + + return db } // CreateWithBackend creates channeldb instance using the passed kvdb.Backend. diff --git a/channeldb/db_test.go b/channeldb/db_test.go index 1ae170f5c..9fed9934b 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -65,11 +65,7 @@ func TestOpenWithCreate(t *testing.T) { // Now, reopen the same db in dry run migration mode. Since we have not // applied any migrations, this should ignore the flag and not fail. - cdb, err = Open(dbPath, OptionDryRunMigration(true)) - require.NoError(t, err, "unable to create channeldb") - if err := cdb.Close(); err != nil { - t.Fatalf("unable to close channeldb: %v", err) - } + OpenForTesting(t, dbPath, OptionDryRunMigration(true)) } // TestWipe tests that the database wipe operation completes successfully diff --git a/channeldb/height_hint_test.go b/channeldb/height_hint_test.go index 3d98707e5..1549ee5f4 100644 --- a/channeldb/height_hint_test.go +++ b/channeldb/height_hint_test.go @@ -23,15 +23,11 @@ func initHintCache(t *testing.T) *HeightHintCache { func initHintCacheWithConfig(t *testing.T, cfg CacheConfig) *HeightHintCache { t.Helper() - db, err := Open(t.TempDir()) - require.NoError(t, err, "unable to create db") + db := OpenForTesting(t, t.TempDir()) + hintCache, err := NewHeightHintCache(cfg, db.Backend) require.NoError(t, err, "unable to create hint cache") - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) - return hintCache } diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index bd4ad8568..576009eda 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -635,15 +635,6 @@ func TestMockRetributionStore(t *testing.T) { } } -func makeTestChannelDB(t *testing.T) (*channeldb.DB, error) { - db, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, err - } - - return db, nil -} - // TestChannelDBRetributionStore instantiates a retributionStore backed by a // channeldb.DB, and tests its behavior using the general RetributionStore test // suite. @@ -654,25 +645,19 @@ func TestChannelDBRetributionStore(t *testing.T) { t.Run( "channeldbDBRetributionStore."+test.name, func(tt *testing.T) { - db, err := makeTestChannelDB(t) - if err != nil { - t.Fatalf("unable to open channeldb: %v", err) - } - defer db.Close() + db := channeldb.OpenForTesting(t, t.TempDir()) restartDb := func() RetributionStorer { // Close and reopen channeldb - if err = db.Close(); err != nil { + if err := db.Close(); err != nil { t.Fatalf("unable to close "+ "channeldb during "+ "restart: %v", err) } - db, err = channeldb.Open(db.Path()) - if err != nil { - t.Fatalf("unable to open "+ - "channeldb: %v", err) - } + db = channeldb.OpenForTesting( + t, db.Path(), + ) return NewRetributionStore(db) } @@ -2279,21 +2264,8 @@ func createInitChannels(t *testing.T) ( return nil, nil, err } - dbAlice, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbAlice.Close()) - }) - - dbBob, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbBob.Close()) - }) + dbAlice := channeldb.OpenForTesting(t, t.TempDir()) + dbBob := channeldb.OpenForTesting(t, t.TempDir()) estimator := chainfee.NewStaticEstimator(12500, 0) feePerKw, err := estimator.EstimateFeePerKW(1) diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index 6ede1baf1..fe2603ca5 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -22,13 +22,7 @@ import ( func TestChainArbitratorRepublishCloses(t *testing.T) { t.Parallel() - db, err := channeldb.Open(t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) + db := channeldb.OpenForTesting(t, t.TempDir()) // Create 10 test channels and sync them to the database. const numChans = 10 @@ -139,11 +133,7 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { func TestResolveContract(t *testing.T) { t.Parallel() - db, err := channeldb.Open(t.TempDir()) - require.NoError(t, err, "unable to open db") - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) + db := channeldb.OpenForTesting(t, t.TempDir()) // With the DB created, we'll make a new channel, and mark it as // pending open within the database. diff --git a/contractcourt/utils_test.go b/contractcourt/utils_test.go index 9a3c5308e..994bc57a8 100644 --- a/contractcourt/utils_test.go +++ b/contractcourt/utils_test.go @@ -65,12 +65,9 @@ func copyChannelState(t *testing.T, state *channeldb.OpenChannel) ( return nil, err } - newDb, err := channeldb.Open(tempDbPath) - if err != nil { - return nil, err - } + newDB := channeldb.OpenForTesting(t, tempDbPath) - chans, err := newDb.ChannelStateDB().FetchAllChannels() + chans, err := newDB.ChannelStateDB().FetchAllChannels() if err != nil { return nil, err } diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 6d2ebbfe9..056069940 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -73,21 +73,6 @@ var ( rebroadcastInterval = time.Hour * 1000000 ) -// makeTestDB creates a new instance of the ChannelDB for testing purposes. -func makeTestDB(t *testing.T) (*channeldb.DB, error) { - // Create channeldb for the first time. - cdb, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, err - } - - t.Cleanup(func() { - cdb.Close() - }) - - return cdb, nil -} - type mockGraphSource struct { bestHeight uint32 @@ -734,10 +719,7 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( notifier := newMockNotifier() router := newMockRouter(startHeight) - db, err := makeTestDB(t) - if err != nil { - return nil, err - } + db := channeldb.OpenForTesting(t, t.TempDir()) waitingProofStore, err := channeldb.NewWaitingProofStore(db) if err != nil { diff --git a/discovery/message_store_test.go b/discovery/message_store_test.go index 36c082e36..88c40c144 100644 --- a/discovery/message_store_test.go +++ b/discovery/message_store_test.go @@ -17,19 +17,10 @@ import ( func createTestMessageStore(t *testing.T) *MessageStore { t.Helper() - db, err := channeldb.Open(t.TempDir()) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } - - t.Cleanup(func() { - db.Close() - }) + db := channeldb.OpenForTesting(t, t.TempDir()) store, err := NewMessageStore(db) - if err != nil { - t.Fatalf("unable to initialize message store: %v", err) - } + require.NoError(t, err) return store } diff --git a/funding/manager_test.go b/funding/manager_test.go index c73f19a39..525f69f9a 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -427,10 +427,7 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, } dbDir := filepath.Join(tempTestDir, "cdb") - fullDB, err := channeldb.Open(dbDir) - if err != nil { - return nil, err - } + fullDB := channeldb.OpenForTesting(t, dbDir) cdb := fullDB.ChannelStateDB() diff --git a/htlcswitch/circuit_test.go b/htlcswitch/circuit_test.go index 108692719..ddad11aca 100644 --- a/htlcswitch/circuit_test.go +++ b/htlcswitch/circuit_test.go @@ -625,9 +625,7 @@ func makeCircuitDB(t *testing.T, path string) *channeldb.DB { path = t.TempDir() } - db, err := channeldb.Open(path) - require.NoError(t, err, "unable to open channel db") - t.Cleanup(func() { db.Close() }) + db := channeldb.OpenForTesting(t, path) return db } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 46c47961f..4ee538a58 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -2169,7 +2169,7 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt, BaseFee: lnwire.NewMSatFromSatoshis(1), TimeLockDelta: 6, } - invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta) + invoiceRegistry = newMockRegistry(t) ) pCache := newMockPreimageCache() @@ -2267,7 +2267,6 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt, t.Cleanup(func() { close(alicePeer.quit) - invoiceRegistry.cleanup() }) harness := singleLinkTestHarness{ diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 17626e3fc..ce791bef3 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "net" - "os" "path/filepath" "sync" "sync/atomic" @@ -216,11 +215,7 @@ func initSwitchWithTempDB(t testing.TB, startingHeight uint32) (*Switch, error) { tempPath := filepath.Join(t.TempDir(), "switchdb") - db, err := channeldb.Open(tempPath) - if err != nil { - return nil, err - } - t.Cleanup(func() { db.Close() }) + db := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(startingHeight, db) if err != nil { @@ -254,9 +249,7 @@ func newMockServer(t testing.TB, name string, startingHeight uint32, t.Cleanup(func() { _ = htlcSwitch.Stop() }) - registry := newMockRegistry(defaultDelta) - - t.Cleanup(func() { registry.cleanup() }) + registry := newMockRegistry(t) return &mockServer{ t: t, @@ -977,37 +970,12 @@ func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] { var _ ChannelLink = (*mockChannelLink)(nil) -func newDB() (*channeldb.DB, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := os.MkdirTemp("", "channeldb") - if err != nil { - return nil, nil, err - } - - // Next, create channeldb for the first time. - cdb, err := channeldb.Open(tempDirName) - if err != nil { - os.RemoveAll(tempDirName) - return nil, nil, err - } - - cleanUp := func() { - cdb.Close() - os.RemoveAll(tempDirName) - } - - return cdb, cleanUp, nil -} - const testInvoiceCltvExpiry = 6 type mockInvoiceRegistry struct { settleChan chan lntypes.Hash registry *invoices.InvoiceRegistry - - cleanup func() } type mockChainNotifier struct { @@ -1024,11 +992,8 @@ func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) ( }, nil } -func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { - cdb, cleanup, err := newDB() - if err != nil { - panic(err) - } +func newMockRegistry(t testing.TB) *mockInvoiceRegistry { + cdb := channeldb.OpenForTesting(t, t.TempDir()) modifierMock := &invoices.MockHtlcModifier{} registry := invoices.NewRegistry( @@ -1046,7 +1011,6 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { return &mockInvoiceRegistry{ registry: registry, - cleanup: cleanup, } } diff --git a/htlcswitch/payment_result_test.go b/htlcswitch/payment_result_test.go index 664197f76..f6def1465 100644 --- a/htlcswitch/payment_result_test.go +++ b/htlcswitch/payment_result_test.go @@ -101,11 +101,7 @@ func TestNetworkResultStore(t *testing.T) { const numResults = 4 - db, err := channeldb.Open(t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { db.Close() }) + db := channeldb.OpenForTesting(t, t.TempDir()) store := newNetworkResultStore(db) diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index e68e7d1b1..abfb8e4d5 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1002,9 +1002,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -1096,9 +1094,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { t.Fatalf(err.Error()) } - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err, "unable reinit switch") @@ -1192,9 +1188,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -1286,9 +1280,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { t.Fatalf(err.Error()) } - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err, "unable reinit switch") @@ -1385,9 +1377,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -1471,9 +1461,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { t.Fatalf(err.Error()) } - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err, "unable reinit switch") @@ -1541,9 +1529,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -1622,9 +1608,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { t.Fatalf(err.Error()) } - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err, "unable reinit switch") @@ -1698,9 +1682,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -1778,9 +1760,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { t.Fatalf(err.Error()) } - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err, "unable reinit switch") @@ -1870,9 +1850,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { t.Fatalf(err.Error()) } - cdb3, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb3.Close() }) + cdb3 := channeldb.OpenForTesting(t, tempPath) s3, err := initSwitchWithDB(testStartingHeight, cdb3) require.NoError(t, err, "unable reinit switch") @@ -3827,9 +3805,7 @@ func newInterceptableSwitchTestContext( tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -4914,9 +4890,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err) - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err) @@ -4990,9 +4964,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { err = cdb.Close() require.NoError(t, err) - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err) - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err) @@ -5130,9 +5102,7 @@ func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err) - defer cdb.Close() + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err) @@ -5471,9 +5441,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err) - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err) diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 02993d3bf..0f4b28fb8 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -251,21 +251,8 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, return nil, nil, err } - dbAlice, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbAlice.Close()) - }) - - dbBob, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbBob.Close()) - }) + dbAlice := channeldb.OpenForTesting(t, t.TempDir()) + dbBob := channeldb.OpenForTesting(t, t.TempDir()) estimator := chainfee.NewStaticEstimator(6000, 0) feePerKw, err := estimator.EstimateFeePerKW(1) @@ -403,11 +390,7 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, switch err { case nil: case kvdb.ErrDatabaseNotOpen: - dbAlice, err = channeldb.Open(dbAlice.Path()) - if err != nil { - return nil, errors.Errorf("unable to reopen alice "+ - "db: %v", err) - } + dbAlice = channeldb.OpenForTesting(t, dbAlice.Path()) aliceStoredChannels, err = dbAlice.ChannelStateDB(). FetchOpenChannels(aliceKeyPub) @@ -451,7 +434,7 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, switch err { case nil: case kvdb.ErrDatabaseNotOpen: - dbBob, err = channeldb.Open(dbBob.Path()) + dbBob = channeldb.OpenForTesting(t, dbBob.Path()) if err != nil { return nil, errors.Errorf("unable to reopen bob "+ "db: %v", err) diff --git a/lnwallet/test/test_interface.go b/lnwallet/test/test_interface.go index c85aa2066..c006aa2e5 100644 --- a/lnwallet/test/test_interface.go +++ b/lnwallet/test/test_interface.go @@ -311,16 +311,14 @@ func loadTestCredits(miner *rpctest.Harness, w *lnwallet.LightningWallet, // createTestWallet creates a test LightningWallet will a total of 20BTC // available for funding channels. -func createTestWallet(tempTestDir string, miningNode *rpctest.Harness, - netParams *chaincfg.Params, notifier chainntnfs.ChainNotifier, - wc lnwallet.WalletController, keyRing keychain.SecretKeyRing, - signer input.Signer, bio lnwallet.BlockChainIO) (*lnwallet.LightningWallet, error) { +func createTestWallet(t *testing.T, tempTestDir string, + miningNode *rpctest.Harness, netParams *chaincfg.Params, + notifier chainntnfs.ChainNotifier, wc lnwallet.WalletController, + keyRing keychain.SecretKeyRing, signer input.Signer, + bio lnwallet.BlockChainIO) *lnwallet.LightningWallet { dbDir := filepath.Join(tempTestDir, "cdb") - fullDB, err := channeldb.Open(dbDir) - if err != nil { - return nil, err - } + fullDB := channeldb.OpenForTesting(t, dbDir) cfg := lnwallet.Config{ Database: fullDB.ChannelStateDB(), @@ -335,20 +333,18 @@ func createTestWallet(tempTestDir string, miningNode *rpctest.Harness, } wallet, err := lnwallet.NewLightningWallet(cfg) - if err != nil { - return nil, err - } + require.NoError(t, err) - if err := wallet.Startup(); err != nil { - return nil, err - } + require.NoError(t, wallet.Startup()) + + t.Cleanup(func() { + require.NoError(t, wallet.Shutdown()) + }) // Load our test wallet with 20 outputs each holding 4BTC. - if err := loadTestCredits(miningNode, wallet, 20, 4); err != nil { - return nil, err - } + require.NoError(t, loadTestCredits(miningNode, wallet, 20, 4)) - return wallet, nil + return wallet } func testGetRecoveryInfo(miner *rpctest.Harness, @@ -3206,9 +3202,7 @@ func TestLightningWallet(t *testing.T, targetBackEnd string) { rpcConfig := miningNode.RPCConfig() - tempDir := t.TempDir() - db, err := channeldb.Open(tempDir) - require.NoError(t, err, "unable to create db") + db := channeldb.OpenForTesting(t, t.TempDir()) testCfg := channeldb.CacheConfig{ QueryDisable: false, } @@ -3450,20 +3444,16 @@ func runTests(t *testing.T, walletDriver *lnwallet.WalletDriver, } // Funding via 20 outputs with 4BTC each. - alice, err := createTestWallet( - tempTestDirAlice, miningNode, netParams, + alice := createTestWallet( + t, tempTestDirAlice, miningNode, netParams, chainNotifier, aliceWalletController, aliceKeyRing, aliceSigner, bio, ) - require.NoError(t, err, "unable to create test ln wallet") - defer alice.Shutdown() - bob, err := createTestWallet( - tempTestDirBob, miningNode, netParams, + bob := createTestWallet( + t, tempTestDirBob, miningNode, netParams, chainNotifier, bobWalletController, bobKeyRing, bobSigner, bio, ) - require.NoError(t, err, "unable to create test ln wallet") - defer bob.Shutdown() // Both wallets should now have 80BTC available for // spending. diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index 225323296..ff9adfbd7 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -243,21 +243,8 @@ func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType, return nil, nil, err } - dbAlice, err := channeldb.Open(t.TempDir(), dbModifiers...) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbAlice.Close()) - }) - - dbBob, err := channeldb.Open(t.TempDir(), dbModifiers...) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbBob.Close()) - }) + dbAlice := channeldb.OpenForTesting(t, t.TempDir(), dbModifiers...) + dbBob := channeldb.OpenForTesting(t, t.TempDir(), dbModifiers...) estimator := chainfee.NewStaticEstimator(6000, 0) feePerKw, err := estimator.EstimateFeePerKW(1) diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index 791277296..135d1866b 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -914,11 +914,8 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp ) // Create temporary databases. - dbRemote, err := channeldb.Open(t.TempDir()) - require.NoError(t, err) - - dbLocal, err := channeldb.Open(t.TempDir()) - require.NoError(t, err) + dbRemote := channeldb.OpenForTesting(t, t.TempDir()) + dbLocal := channeldb.OpenForTesting(t, t.TempDir()) // Create the initial commitment transactions for the channel. feePerKw := chainfee.SatPerKWeight(feeRate) diff --git a/peer/test_utils.go b/peer/test_utils.go index 8f60d08b9..eb510a53b 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -203,13 +203,7 @@ func createTestPeerWithChannel(t *testing.T, updateChan func(a, return nil, err } - dbBob, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, err - } - t.Cleanup(func() { - require.NoError(t, dbBob.Close()) - }) + dbBob := channeldb.OpenForTesting(t, t.TempDir()) feePerKw, err := estimator.EstimateFeePerKW(1) if err != nil { @@ -624,11 +618,7 @@ func createTestPeer(t *testing.T) *peerTestCtx { dbAliceGraph, err := graphdb.NewChannelGraph(graphBackend) require.NoError(t, err) - dbAliceChannel, err := channeldb.Open(dbPath) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, dbAliceChannel.Close()) - }) + dbAliceChannel := channeldb.OpenForTesting(t, dbPath) nodeSignerAlice := netann.NewNodeSigner(aliceKeySigner) diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 2baad92f1..5fb271afa 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -47,16 +47,13 @@ var ( func TestControlTowerSubscribeUnknown(t *testing.T) { t.Parallel() - db, err := initDB(t, false) - require.NoError(t, err, "unable to init db") + db := initDB(t, false) pControl := NewControlTower(channeldb.NewPaymentControl(db)) // Subscription should fail when the payment is not known. - _, err = pControl.SubscribePayment(lntypes.Hash{1}) - if err != channeldb.ErrPaymentNotInitiated { - t.Fatal("expected subscribe to fail for unknown payment") - } + _, err := pControl.SubscribePayment(lntypes.Hash{1}) + require.ErrorIs(t, err, channeldb.ErrPaymentNotInitiated) } // TestControlTowerSubscribeSuccess tests that payment updates for a @@ -64,8 +61,7 @@ func TestControlTowerSubscribeUnknown(t *testing.T) { func TestControlTowerSubscribeSuccess(t *testing.T) { t.Parallel() - db, err := initDB(t, false) - require.NoError(t, err, "unable to init db") + db := initDB(t, false) pControl := NewControlTower(channeldb.NewPaymentControl(db)) @@ -184,8 +180,7 @@ func TestPaymentControlSubscribeFail(t *testing.T) { func TestPaymentControlSubscribeAllSuccess(t *testing.T) { t.Parallel() - db, err := initDB(t, true) - require.NoError(t, err, "unable to init db: %v") + db := initDB(t, true) pControl := NewControlTower(channeldb.NewPaymentControl(db)) @@ -298,8 +293,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { func TestPaymentControlSubscribeAllImmediate(t *testing.T) { t.Parallel() - db, err := initDB(t, true) - require.NoError(t, err, "unable to init db: %v") + db := initDB(t, true) pControl := NewControlTower(channeldb.NewPaymentControl(db)) @@ -336,8 +330,7 @@ func TestPaymentControlSubscribeAllImmediate(t *testing.T) { func TestPaymentControlUnsubscribeSuccess(t *testing.T) { t.Parallel() - db, err := initDB(t, true) - require.NoError(t, err, "unable to init db: %v") + db := initDB(t, true) pControl := NewControlTower(channeldb.NewPaymentControl(db)) @@ -406,8 +399,7 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) { func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, keepFailedPaymentAttempts bool) { - db, err := initDB(t, keepFailedPaymentAttempts) - require.NoError(t, err, "unable to init db") + db := initDB(t, keepFailedPaymentAttempts) pControl := NewControlTower(channeldb.NewPaymentControl(db)) @@ -525,17 +517,12 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, } } -func initDB(t *testing.T, keepFailedPaymentAttempts bool) (*channeldb.DB, error) { - db, err := channeldb.Open( - t.TempDir(), channeldb.OptionKeepFailedPaymentAttempts( +func initDB(t *testing.T, keepFailedPaymentAttempts bool) *channeldb.DB { + return channeldb.OpenForTesting( + t, t.TempDir(), channeldb.OptionKeepFailedPaymentAttempts( keepFailedPaymentAttempts, ), ) - if err != nil { - return nil, err - } - - return db, err } func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.HTLCAttemptInfo, diff --git a/rpcserver_test.go b/rpcserver_test.go index 53ec6d0ac..b4b66e719 100644 --- a/rpcserver_test.go +++ b/rpcserver_test.go @@ -41,11 +41,7 @@ func (m *mockDataParser) InlineParseCustomData(msg proto.Message) error { func TestAuxDataParser(t *testing.T) { // We create an empty channeldb, so we can fetch some channels. - cdb, err := channeldb.Open(t.TempDir()) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, cdb.Close()) - }) + cdb := channeldb.OpenForTesting(t, t.TempDir()) r := &rpcServer{ server: &server{