contractcourt: use T.TempDir to create temporary test directory

Signed-off-by: Eng Zer Jun <engzerjun@gmail.com>
This commit is contained in:
Eng Zer Jun 2022-08-15 21:05:40 +08:00
parent 5d5ad9ce65
commit d1541d6628
No known key found for this signature in database
GPG Key ID: DAEBBD2E34C111E6
6 changed files with 60 additions and 152 deletions

View File

@ -10,10 +10,8 @@ import (
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
"os"
"reflect"
"sync"
"testing"
@ -639,25 +637,13 @@ func TestMockRetributionStore(t *testing.T) {
}
}
func makeTestChannelDB() (*channeldb.DB, func(), error) {
// First, create a temporary directory to be used for the duration of
// this test.
tempDirName, err := ioutil.TempDir("", "channeldb")
func makeTestChannelDB(t *testing.T) (*channeldb.DB, error) {
db, err := channeldb.Open(t.TempDir())
if err != nil {
return nil, nil, err
return nil, err
}
cleanUp := func() {
os.RemoveAll(tempDirName)
}
db, err := channeldb.Open(tempDirName)
if err != nil {
cleanUp()
return nil, nil, err
}
return db, cleanUp, nil
return db, nil
}
// TestChannelDBRetributionStore instantiates a retributionStore backed by a
@ -670,12 +656,11 @@ func TestChannelDBRetributionStore(t *testing.T) {
t.Run(
"channeldbDBRetributionStore."+test.name,
func(tt *testing.T) {
db, cleanUp, err := makeTestChannelDB()
db, err := makeTestChannelDB(t)
if err != nil {
t.Fatalf("unable to open channeldb: %v", err)
}
defer db.Close()
defer cleanUp()
restartDb := func() RetributionStorer {
// Close and reopen channeldb
@ -976,7 +961,7 @@ func initBreachedState(t *testing.T) (*BreachArbiter,
// Create a pair of channels using a notifier that allows us to signal
// a spend of the funding transaction. Alice's channel will be the on
// observing a breach.
alice, bob, cleanUpChans, err := createInitChannels(1)
alice, bob, cleanUpChans, err := createInitChannels(t, 1)
require.NoError(t, err, "unable to create test channels")
// Instantiate a breach arbiter to handle the breach of alice's channel.
@ -2177,8 +2162,7 @@ func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent,
// createInitChannels creates two initialized test channels funded with 10 BTC,
// with 5 BTC allocated to each side. Within the channel, Alice is the
// initiator.
func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwallet.LightningChannel, func(), error) {
func createInitChannels(t *testing.T, revocationWindow int) (*lnwallet.LightningChannel, *lnwallet.LightningChannel, func(), error) {
aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(
channels.AlicesPrivKey,
)
@ -2285,22 +2269,12 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
return nil, nil, nil, err
}
alicePath, err := ioutil.TempDir("", "alicedb")
dbAlice, err := channeldb.Open(t.TempDir())
if err != nil {
return nil, nil, nil, err
}
dbAlice, err := channeldb.Open(alicePath)
if err != nil {
return nil, nil, nil, err
}
bobPath, err := ioutil.TempDir("", "bobdb")
if err != nil {
return nil, nil, nil, err
}
dbBob, err := channeldb.Open(bobPath)
dbBob, err := channeldb.Open(t.TempDir())
if err != nil {
return nil, nil, nil, err
}
@ -2418,8 +2392,6 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
cleanUpFunc := func() {
dbBob.Close()
dbAlice.Close()
os.RemoveAll(bobPath)
os.RemoveAll(alicePath)
}
// Now that the channel are open, simulate the start of a session by

View File

@ -2,9 +2,7 @@ package contractcourt
import (
"crypto/rand"
"io/ioutil"
prand "math/rand"
"os"
"reflect"
"testing"
"time"
@ -146,36 +144,28 @@ var (
}
)
func makeTestDB() (kvdb.Backend, func(), error) {
// First, create a temporary directory to be used for the duration of
// this test.
tempDirName, err := ioutil.TempDir("", "arblog")
if err != nil {
return nil, nil, err
}
func makeTestDB(t *testing.T) (kvdb.Backend, error) {
db, err := kvdb.Create(
kvdb.BoltBackendName, tempDirName+"/test.db", true,
kvdb.BoltBackendName, t.TempDir()+"/test.db", true,
kvdb.DefaultDBTimeout,
)
if err != nil {
return nil, nil, err
return nil, err
}
cleanUp := func() {
t.Cleanup(func() {
db.Close()
os.RemoveAll(tempDirName)
}
})
return db, cleanUp, nil
return db, nil
}
func newTestBoltArbLog(chainhash chainhash.Hash,
op wire.OutPoint) (ArbitratorLog, func(), error) {
func newTestBoltArbLog(t *testing.T, chainhash chainhash.Hash,
op wire.OutPoint) (ArbitratorLog, error) {
testDB, cleanUp, err := makeTestDB()
testDB, err := makeTestDB(t)
if err != nil {
return nil, nil, err
return nil, err
}
testArbCfg := ChannelArbitratorConfig{
@ -187,10 +177,10 @@ func newTestBoltArbLog(chainhash chainhash.Hash,
}
testLog, err := newBoltArbitratorLog(testDB, testArbCfg, chainhash, op)
if err != nil {
return nil, nil, err
return nil, err
}
return testLog, cleanUp, err
return testLog, err
}
func randOutPoint() wire.OutPoint {
@ -304,11 +294,10 @@ func TestContractInsertionRetrieval(t *testing.T) {
// First, we'll create a test instance of the ArbitratorLog
// implementation backed by boltdb.
testLog, cleanUp, err := newTestBoltArbLog(
testChainHash, testChanPoint1,
testLog, err := newTestBoltArbLog(
t, testChainHash, testChanPoint1,
)
require.NoError(t, err, "unable to create test log")
defer cleanUp()
// The log created, we'll create a series of resolvers, each properly
// implementing the ContractResolver interface.
@ -432,11 +421,10 @@ func TestContractResolution(t *testing.T) {
// First, we'll create a test instance of the ArbitratorLog
// implementation backed by boltdb.
testLog, cleanUp, err := newTestBoltArbLog(
testChainHash, testChanPoint1,
testLog, err := newTestBoltArbLog(
t, testChainHash, testChanPoint1,
)
require.NoError(t, err, "unable to create test log")
defer cleanUp()
// We'll now create a timeout resolver that we'll be using for the
// duration of this test.
@ -486,11 +474,10 @@ func TestContractSwapping(t *testing.T) {
// First, we'll create a test instance of the ArbitratorLog
// implementation backed by boltdb.
testLog, cleanUp, err := newTestBoltArbLog(
testChainHash, testChanPoint1,
testLog, err := newTestBoltArbLog(
t, testChainHash, testChanPoint1,
)
require.NoError(t, err, "unable to create test log")
defer cleanUp()
// We'll create two resolvers, a regular timeout resolver, and the
// contest resolver that eventually turns into the timeout resolver.
@ -543,11 +530,10 @@ func TestContractResolutionsStorage(t *testing.T) {
// First, we'll create a test instance of the ArbitratorLog
// implementation backed by boltdb.
testLog, cleanUp, err := newTestBoltArbLog(
testChainHash, testChanPoint1,
testLog, err := newTestBoltArbLog(
t, testChainHash, testChanPoint1,
)
require.NoError(t, err, "unable to create test log")
defer cleanUp()
// With the test log created, we'll now craft a contact resolution that
// will be using for the duration of this test.
@ -659,11 +645,10 @@ func TestContractResolutionsStorage(t *testing.T) {
func TestStateMutation(t *testing.T) {
t.Parallel()
testLog, cleanUp, err := newTestBoltArbLog(
testChainHash, testChanPoint1,
testLog, err := newTestBoltArbLog(
t, testChainHash, testChanPoint1,
)
require.NoError(t, err, "unable to create test log")
defer cleanUp()
// The default state of an arbitrator should be StateDefault.
arbState, err := testLog.CurrentState(nil)
@ -707,17 +692,15 @@ func TestScopeIsolation(t *testing.T) {
// We'll create two distinct test logs. Each log will have a unique
// scope key, and therefore should be isolated from the other on disk.
testLog1, cleanUp1, err := newTestBoltArbLog(
testChainHash, testChanPoint1,
testLog1, err := newTestBoltArbLog(
t, testChainHash, testChanPoint1,
)
require.NoError(t, err, "unable to create test log")
defer cleanUp1()
testLog2, cleanUp2, err := newTestBoltArbLog(
testChainHash, testChanPoint2,
testLog2, err := newTestBoltArbLog(
t, testChainHash, testChanPoint2,
)
require.NoError(t, err, "unable to create test log")
defer cleanUp2()
// We'll now update the current state of both the logs to a unique
// state.
@ -754,11 +737,10 @@ func TestScopeIsolation(t *testing.T) {
func TestCommitSetStorage(t *testing.T) {
t.Parallel()
testLog, cleanUp, err := newTestBoltArbLog(
testChainHash, testChanPoint1,
testLog, err := newTestBoltArbLog(
t, testChainHash, testChanPoint1,
)
require.NoError(t, err, "unable to create test log")
defer cleanUp()
activeHTLCs := []channeldb.HTLC{
{

View File

@ -1,9 +1,7 @@
package contractcourt
import (
"io/ioutil"
"net"
"os"
"testing"
"github.com/btcsuite/btcd/chaincfg/chainhash"
@ -22,13 +20,7 @@ import (
func TestChainArbitratorRepublishCloses(t *testing.T) {
t.Parallel()
tempPath, err := ioutil.TempDir("", "testdb")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempPath)
db, err := channeldb.Open(tempPath)
db, err := channeldb.Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
@ -143,12 +135,7 @@ func TestChainArbitratorRepublishCloses(t *testing.T) {
func TestResolveContract(t *testing.T) {
t.Parallel()
// To start with, we'll create a new temp DB for the duration of this
// test.
tempPath, err := ioutil.TempDir("", "testdb")
require.NoError(t, err, "unable to make temp dir")
defer os.RemoveAll(tempPath)
db, err := channeldb.Open(tempPath)
db, err := channeldb.Open(t.TempDir())
require.NoError(t, err, "unable to open db")
defer db.Close()

View File

@ -200,27 +200,19 @@ type dlpTestCase struct {
// state) are returned.
func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi,
aliceChannel, bobChannel *lnwallet.LightningChannel,
numUpdates uint8) ([]*channeldb.OpenChannel, func(), error) {
numUpdates uint8) ([]*channeldb.OpenChannel, error) {
// We'll make a copy of the channel state before each transition.
var (
chanStates []*channeldb.OpenChannel
cleanupFuncs []func()
chanStates []*channeldb.OpenChannel
)
cleanAll := func() {
for _, f := range cleanupFuncs {
f()
}
}
state, f, err := copyChannelState(aliceChannel.State())
state, err := copyChannelState(t, aliceChannel.State())
if err != nil {
return nil, nil, err
return nil, err
}
chanStates = append(chanStates, state)
cleanupFuncs = append(cleanupFuncs, f)
for i := 0; i < int(numUpdates); i++ {
addFakeHTLC(
@ -229,21 +221,18 @@ func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi,
err := lnwallet.ForceStateTransition(aliceChannel, bobChannel)
if err != nil {
cleanAll()
return nil, nil, err
return nil, err
}
state, f, err := copyChannelState(aliceChannel.State())
state, err := copyChannelState(t, aliceChannel.State())
if err != nil {
cleanAll()
return nil, nil, err
return nil, err
}
chanStates = append(chanStates, state)
cleanupFuncs = append(cleanupFuncs, f)
}
return chanStates, cleanAll, nil
return chanStates, nil
}
// TestChainWatcherDataLossProtect tests that if we've lost data (and are
@ -278,7 +267,7 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
// new HTLC to add to the commitment, and then lock in a state
// transition.
const htlcAmt = 1000
states, cleanStates, err := executeStateTransitions(
states, err := executeStateTransitions(
t, htlcAmt, aliceChannel, bobChannel,
testCase.BroadcastStateNum,
)
@ -287,7 +276,6 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
"transition: %v", err)
return false
}
defer cleanStates()
// We'll use the state this test case wants Alice to start at.
aliceChanState := states[testCase.NumUpdates]
@ -455,7 +443,7 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
// get more coverage of various state hint encodings beyond 0
// and 1.
const htlcAmt = 1000
states, cleanStates, err := executeStateTransitions(
states, err := executeStateTransitions(
t, htlcAmt, aliceChannel, bobChannel, numUpdates,
)
if err != nil {
@ -463,7 +451,6 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
"transition: %v", err)
return false
}
defer cleanStates()
// We'll use the state this test case wants Alice to start at.
aliceChanState := states[localState]

View File

@ -3,8 +3,6 @@ package contractcourt
import (
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"sort"
@ -406,11 +404,7 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog,
var cleanUp func()
if log == nil {
dbDir, err := ioutil.TempDir("", "chanArb")
if err != nil {
return nil, err
}
dbPath := filepath.Join(dbDir, "testdb")
dbPath := filepath.Join(t.TempDir(), "testdb")
db, err := kvdb.Create(
kvdb.BoltBackendName, dbPath, true,
kvdb.DefaultDBTimeout,
@ -427,7 +421,6 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog,
}
cleanUp = func() {
db.Close()
os.RemoveAll(dbDir)
}
log = &testArbLog{

View File

@ -3,7 +3,6 @@ package contractcourt
import (
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"runtime/pprof"
@ -52,47 +51,35 @@ func copyFile(dest, src string) error {
}
// copyChannelState copies the OpenChannel state by copying the database and
// creating a new struct from it. The copied state and a cleanup function are
// returned.
func copyChannelState(state *channeldb.OpenChannel) (
*channeldb.OpenChannel, func(), error) {
// creating a new struct from it. The copied state is returned.
func copyChannelState(t *testing.T, state *channeldb.OpenChannel) (
*channeldb.OpenChannel, error) {
// Make a copy of the DB.
dbFile := filepath.Join(state.Db.GetParentDB().Path(), "channel.db")
tempDbPath, err := ioutil.TempDir("", "past-state")
if err != nil {
return nil, nil, err
}
cleanup := func() {
os.RemoveAll(tempDbPath)
}
tempDbPath := t.TempDir()
tempDbFile := filepath.Join(tempDbPath, "channel.db")
err = copyFile(tempDbFile, dbFile)
err := copyFile(tempDbFile, dbFile)
if err != nil {
cleanup()
return nil, nil, err
return nil, err
}
newDb, err := channeldb.Open(tempDbPath)
if err != nil {
cleanup()
return nil, nil, err
return nil, err
}
chans, err := newDb.ChannelStateDB().FetchAllChannels()
if err != nil {
cleanup()
return nil, nil, err
return nil, err
}
// We only support DBs with a single channel, for now.
if len(chans) != 1 {
cleanup()
return nil, nil, fmt.Errorf("found %d chans in the db",
return nil, fmt.Errorf("found %d chans in the db",
len(chans))
}
return chans[0], cleanup, nil
return chans[0], nil
}