diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 0c2b0c2c5..7273a7c73 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -337,11 +337,10 @@ func TestCancelInvoice(t *testing.T) { func TestSettleHoldInvoice(t *testing.T) { defer timeout()() - cdb, cleanup, err := newTestChannelDB(clock.NewTestClock(time.Time{})) + cdb, err := newTestChannelDB(t, clock.NewTestClock(time.Time{})) if err != nil { t.Fatal(err) } - defer cleanup() // Instantiate and start the invoice ctx.registry. cfg := RegistryConfig{ @@ -510,11 +509,10 @@ func TestSettleHoldInvoice(t *testing.T) { func TestCancelHoldInvoice(t *testing.T) { defer timeout()() - cdb, cleanup, err := newTestChannelDB(clock.NewTestClock(time.Time{})) + cdb, err := newTestChannelDB(t, clock.NewTestClock(time.Time{})) if err != nil { t.Fatal(err) } - defer cleanup() // Instantiate and start the invoice ctx.registry. cfg := RegistryConfig{ @@ -935,9 +933,7 @@ func TestMppPayment(t *testing.T) { func TestInvoiceExpiryWithRegistry(t *testing.T) { t.Parallel() - cdb, cleanup, err := newTestChannelDB(clock.NewTestClock(time.Time{})) - defer cleanup() - + cdb, err := newTestChannelDB(t, clock.NewTestClock(time.Time{})) if err != nil { t.Fatal(err) } @@ -1043,9 +1039,7 @@ func TestOldInvoiceRemovalOnStart(t *testing.T) { t.Parallel() testClock := clock.NewTestClock(testTime) - cdb, cleanup, err := newTestChannelDB(testClock) - defer cleanup() - + cdb, err := newTestChannelDB(t, testClock) require.NoError(t, err) cfg := RegistryConfig{ diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index b1bdafeff..d158d2e9e 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "encoding/hex" "fmt" - "io/ioutil" "os" "runtime/pprof" "sync" @@ -162,29 +161,20 @@ var ( } ) -func newTestChannelDB(clock clock.Clock) (*channeldb.DB, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - return nil, nil, err - } - - // Next, create channeldb for the first time. +func newTestChannelDB(t *testing.T, clock clock.Clock) (*channeldb.DB, error) { + // Create channeldb for the first time. cdb, err := channeldb.Open( - tempDirName, channeldb.OptionClock(clock), + t.TempDir(), channeldb.OptionClock(clock), ) if err != nil { - os.RemoveAll(tempDirName) - return nil, nil, err + return nil, err } - cleanUp := func() { + t.Cleanup(func() { cdb.Close() - os.RemoveAll(tempDirName) - } + }) - return cdb, cleanUp, nil + return cdb, nil } type testContext struct { @@ -200,7 +190,7 @@ type testContext struct { func newTestContext(t *testing.T) *testContext { clock := clock.NewTestClock(testTime) - cdb, cleanup, err := newTestChannelDB(clock) + cdb, err := newTestChannelDB(t, clock) if err != nil { t.Fatal(err) } @@ -221,7 +211,6 @@ func newTestContext(t *testing.T) *testContext { err = registry.Start() if err != nil { - cleanup() t.Fatal(err) } @@ -235,7 +224,6 @@ func newTestContext(t *testing.T) *testContext { if err = registry.Stop(); err != nil { t.Fatalf("failed to stop invoice registry: %v", err) } - cleanup() }, }