mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-04-05 02:28:11 +02:00
contractcourt: use T.TempDir
to create temporary test directory
Signed-off-by: Eng Zer Jun <engzerjun@gmail.com>
This commit is contained in:
parent
5d5ad9ce65
commit
d1541d6628
@ -10,10 +10,8 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
@ -639,25 +637,13 @@ func TestMockRetributionStore(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestChannelDB() (*channeldb.DB, func(), error) {
|
||||
// First, create a temporary directory to be used for the duration of
|
||||
// this test.
|
||||
tempDirName, err := ioutil.TempDir("", "channeldb")
|
||||
func makeTestChannelDB(t *testing.T) (*channeldb.DB, error) {
|
||||
db, err := channeldb.Open(t.TempDir())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cleanUp := func() {
|
||||
os.RemoveAll(tempDirName)
|
||||
}
|
||||
|
||||
db, err := channeldb.Open(tempDirName)
|
||||
if err != nil {
|
||||
cleanUp()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return db, cleanUp, nil
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// TestChannelDBRetributionStore instantiates a retributionStore backed by a
|
||||
@ -670,12 +656,11 @@ func TestChannelDBRetributionStore(t *testing.T) {
|
||||
t.Run(
|
||||
"channeldbDBRetributionStore."+test.name,
|
||||
func(tt *testing.T) {
|
||||
db, cleanUp, err := makeTestChannelDB()
|
||||
db, err := makeTestChannelDB(t)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open channeldb: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
defer cleanUp()
|
||||
|
||||
restartDb := func() RetributionStorer {
|
||||
// Close and reopen channeldb
|
||||
@ -976,7 +961,7 @@ func initBreachedState(t *testing.T) (*BreachArbiter,
|
||||
// Create a pair of channels using a notifier that allows us to signal
|
||||
// a spend of the funding transaction. Alice's channel will be the on
|
||||
// observing a breach.
|
||||
alice, bob, cleanUpChans, err := createInitChannels(1)
|
||||
alice, bob, cleanUpChans, err := createInitChannels(t, 1)
|
||||
require.NoError(t, err, "unable to create test channels")
|
||||
|
||||
// Instantiate a breach arbiter to handle the breach of alice's channel.
|
||||
@ -2177,8 +2162,7 @@ func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent,
|
||||
// createInitChannels creates two initialized test channels funded with 10 BTC,
|
||||
// with 5 BTC allocated to each side. Within the channel, Alice is the
|
||||
// initiator.
|
||||
func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwallet.LightningChannel, func(), error) {
|
||||
|
||||
func createInitChannels(t *testing.T, revocationWindow int) (*lnwallet.LightningChannel, *lnwallet.LightningChannel, func(), error) {
|
||||
aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(
|
||||
channels.AlicesPrivKey,
|
||||
)
|
||||
@ -2285,22 +2269,12 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
alicePath, err := ioutil.TempDir("", "alicedb")
|
||||
dbAlice, err := channeldb.Open(t.TempDir())
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
dbAlice, err := channeldb.Open(alicePath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
bobPath, err := ioutil.TempDir("", "bobdb")
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
dbBob, err := channeldb.Open(bobPath)
|
||||
dbBob, err := channeldb.Open(t.TempDir())
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -2418,8 +2392,6 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
|
||||
cleanUpFunc := func() {
|
||||
dbBob.Close()
|
||||
dbAlice.Close()
|
||||
os.RemoveAll(bobPath)
|
||||
os.RemoveAll(alicePath)
|
||||
}
|
||||
|
||||
// Now that the channel are open, simulate the start of a session by
|
||||
|
@ -2,9 +2,7 @@ package contractcourt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io/ioutil"
|
||||
prand "math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
@ -146,36 +144,28 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func makeTestDB() (kvdb.Backend, func(), error) {
|
||||
// First, create a temporary directory to be used for the duration of
|
||||
// this test.
|
||||
tempDirName, err := ioutil.TempDir("", "arblog")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func makeTestDB(t *testing.T) (kvdb.Backend, error) {
|
||||
db, err := kvdb.Create(
|
||||
kvdb.BoltBackendName, tempDirName+"/test.db", true,
|
||||
kvdb.BoltBackendName, t.TempDir()+"/test.db", true,
|
||||
kvdb.DefaultDBTimeout,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cleanUp := func() {
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
os.RemoveAll(tempDirName)
|
||||
}
|
||||
})
|
||||
|
||||
return db, cleanUp, nil
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func newTestBoltArbLog(chainhash chainhash.Hash,
|
||||
op wire.OutPoint) (ArbitratorLog, func(), error) {
|
||||
func newTestBoltArbLog(t *testing.T, chainhash chainhash.Hash,
|
||||
op wire.OutPoint) (ArbitratorLog, error) {
|
||||
|
||||
testDB, cleanUp, err := makeTestDB()
|
||||
testDB, err := makeTestDB(t)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
testArbCfg := ChannelArbitratorConfig{
|
||||
@ -187,10 +177,10 @@ func newTestBoltArbLog(chainhash chainhash.Hash,
|
||||
}
|
||||
testLog, err := newBoltArbitratorLog(testDB, testArbCfg, chainhash, op)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return testLog, cleanUp, err
|
||||
return testLog, err
|
||||
}
|
||||
|
||||
func randOutPoint() wire.OutPoint {
|
||||
@ -304,11 +294,10 @@ func TestContractInsertionRetrieval(t *testing.T) {
|
||||
|
||||
// First, we'll create a test instance of the ArbitratorLog
|
||||
// implementation backed by boltdb.
|
||||
testLog, cleanUp, err := newTestBoltArbLog(
|
||||
testChainHash, testChanPoint1,
|
||||
testLog, err := newTestBoltArbLog(
|
||||
t, testChainHash, testChanPoint1,
|
||||
)
|
||||
require.NoError(t, err, "unable to create test log")
|
||||
defer cleanUp()
|
||||
|
||||
// The log created, we'll create a series of resolvers, each properly
|
||||
// implementing the ContractResolver interface.
|
||||
@ -432,11 +421,10 @@ func TestContractResolution(t *testing.T) {
|
||||
|
||||
// First, we'll create a test instance of the ArbitratorLog
|
||||
// implementation backed by boltdb.
|
||||
testLog, cleanUp, err := newTestBoltArbLog(
|
||||
testChainHash, testChanPoint1,
|
||||
testLog, err := newTestBoltArbLog(
|
||||
t, testChainHash, testChanPoint1,
|
||||
)
|
||||
require.NoError(t, err, "unable to create test log")
|
||||
defer cleanUp()
|
||||
|
||||
// We'll now create a timeout resolver that we'll be using for the
|
||||
// duration of this test.
|
||||
@ -486,11 +474,10 @@ func TestContractSwapping(t *testing.T) {
|
||||
|
||||
// First, we'll create a test instance of the ArbitratorLog
|
||||
// implementation backed by boltdb.
|
||||
testLog, cleanUp, err := newTestBoltArbLog(
|
||||
testChainHash, testChanPoint1,
|
||||
testLog, err := newTestBoltArbLog(
|
||||
t, testChainHash, testChanPoint1,
|
||||
)
|
||||
require.NoError(t, err, "unable to create test log")
|
||||
defer cleanUp()
|
||||
|
||||
// We'll create two resolvers, a regular timeout resolver, and the
|
||||
// contest resolver that eventually turns into the timeout resolver.
|
||||
@ -543,11 +530,10 @@ func TestContractResolutionsStorage(t *testing.T) {
|
||||
|
||||
// First, we'll create a test instance of the ArbitratorLog
|
||||
// implementation backed by boltdb.
|
||||
testLog, cleanUp, err := newTestBoltArbLog(
|
||||
testChainHash, testChanPoint1,
|
||||
testLog, err := newTestBoltArbLog(
|
||||
t, testChainHash, testChanPoint1,
|
||||
)
|
||||
require.NoError(t, err, "unable to create test log")
|
||||
defer cleanUp()
|
||||
|
||||
// With the test log created, we'll now craft a contact resolution that
|
||||
// will be using for the duration of this test.
|
||||
@ -659,11 +645,10 @@ func TestContractResolutionsStorage(t *testing.T) {
|
||||
func TestStateMutation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testLog, cleanUp, err := newTestBoltArbLog(
|
||||
testChainHash, testChanPoint1,
|
||||
testLog, err := newTestBoltArbLog(
|
||||
t, testChainHash, testChanPoint1,
|
||||
)
|
||||
require.NoError(t, err, "unable to create test log")
|
||||
defer cleanUp()
|
||||
|
||||
// The default state of an arbitrator should be StateDefault.
|
||||
arbState, err := testLog.CurrentState(nil)
|
||||
@ -707,17 +692,15 @@ func TestScopeIsolation(t *testing.T) {
|
||||
|
||||
// We'll create two distinct test logs. Each log will have a unique
|
||||
// scope key, and therefore should be isolated from the other on disk.
|
||||
testLog1, cleanUp1, err := newTestBoltArbLog(
|
||||
testChainHash, testChanPoint1,
|
||||
testLog1, err := newTestBoltArbLog(
|
||||
t, testChainHash, testChanPoint1,
|
||||
)
|
||||
require.NoError(t, err, "unable to create test log")
|
||||
defer cleanUp1()
|
||||
|
||||
testLog2, cleanUp2, err := newTestBoltArbLog(
|
||||
testChainHash, testChanPoint2,
|
||||
testLog2, err := newTestBoltArbLog(
|
||||
t, testChainHash, testChanPoint2,
|
||||
)
|
||||
require.NoError(t, err, "unable to create test log")
|
||||
defer cleanUp2()
|
||||
|
||||
// We'll now update the current state of both the logs to a unique
|
||||
// state.
|
||||
@ -754,11 +737,10 @@ func TestScopeIsolation(t *testing.T) {
|
||||
func TestCommitSetStorage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testLog, cleanUp, err := newTestBoltArbLog(
|
||||
testChainHash, testChanPoint1,
|
||||
testLog, err := newTestBoltArbLog(
|
||||
t, testChainHash, testChanPoint1,
|
||||
)
|
||||
require.NoError(t, err, "unable to create test log")
|
||||
defer cleanUp()
|
||||
|
||||
activeHTLCs := []channeldb.HTLC{
|
||||
{
|
||||
|
@ -1,9 +1,7 @@
|
||||
package contractcourt
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
@ -22,13 +20,7 @@ import (
|
||||
func TestChainArbitratorRepublishCloses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempPath, err := ioutil.TempDir("", "testdb")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tempPath)
|
||||
|
||||
db, err := channeldb.Open(tempPath)
|
||||
db, err := channeldb.Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -143,12 +135,7 @@ func TestChainArbitratorRepublishCloses(t *testing.T) {
|
||||
func TestResolveContract(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// To start with, we'll create a new temp DB for the duration of this
|
||||
// test.
|
||||
tempPath, err := ioutil.TempDir("", "testdb")
|
||||
require.NoError(t, err, "unable to make temp dir")
|
||||
defer os.RemoveAll(tempPath)
|
||||
db, err := channeldb.Open(tempPath)
|
||||
db, err := channeldb.Open(t.TempDir())
|
||||
require.NoError(t, err, "unable to open db")
|
||||
defer db.Close()
|
||||
|
||||
|
@ -200,27 +200,19 @@ type dlpTestCase struct {
|
||||
// state) are returned.
|
||||
func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi,
|
||||
aliceChannel, bobChannel *lnwallet.LightningChannel,
|
||||
numUpdates uint8) ([]*channeldb.OpenChannel, func(), error) {
|
||||
numUpdates uint8) ([]*channeldb.OpenChannel, error) {
|
||||
|
||||
// We'll make a copy of the channel state before each transition.
|
||||
var (
|
||||
chanStates []*channeldb.OpenChannel
|
||||
cleanupFuncs []func()
|
||||
chanStates []*channeldb.OpenChannel
|
||||
)
|
||||
|
||||
cleanAll := func() {
|
||||
for _, f := range cleanupFuncs {
|
||||
f()
|
||||
}
|
||||
}
|
||||
|
||||
state, f, err := copyChannelState(aliceChannel.State())
|
||||
state, err := copyChannelState(t, aliceChannel.State())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chanStates = append(chanStates, state)
|
||||
cleanupFuncs = append(cleanupFuncs, f)
|
||||
|
||||
for i := 0; i < int(numUpdates); i++ {
|
||||
addFakeHTLC(
|
||||
@ -229,21 +221,18 @@ func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi,
|
||||
|
||||
err := lnwallet.ForceStateTransition(aliceChannel, bobChannel)
|
||||
if err != nil {
|
||||
cleanAll()
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state, f, err := copyChannelState(aliceChannel.State())
|
||||
state, err := copyChannelState(t, aliceChannel.State())
|
||||
if err != nil {
|
||||
cleanAll()
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chanStates = append(chanStates, state)
|
||||
cleanupFuncs = append(cleanupFuncs, f)
|
||||
}
|
||||
|
||||
return chanStates, cleanAll, nil
|
||||
return chanStates, nil
|
||||
}
|
||||
|
||||
// TestChainWatcherDataLossProtect tests that if we've lost data (and are
|
||||
@ -278,7 +267,7 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
|
||||
// new HTLC to add to the commitment, and then lock in a state
|
||||
// transition.
|
||||
const htlcAmt = 1000
|
||||
states, cleanStates, err := executeStateTransitions(
|
||||
states, err := executeStateTransitions(
|
||||
t, htlcAmt, aliceChannel, bobChannel,
|
||||
testCase.BroadcastStateNum,
|
||||
)
|
||||
@ -287,7 +276,6 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
|
||||
"transition: %v", err)
|
||||
return false
|
||||
}
|
||||
defer cleanStates()
|
||||
|
||||
// We'll use the state this test case wants Alice to start at.
|
||||
aliceChanState := states[testCase.NumUpdates]
|
||||
@ -455,7 +443,7 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
|
||||
// get more coverage of various state hint encodings beyond 0
|
||||
// and 1.
|
||||
const htlcAmt = 1000
|
||||
states, cleanStates, err := executeStateTransitions(
|
||||
states, err := executeStateTransitions(
|
||||
t, htlcAmt, aliceChannel, bobChannel, numUpdates,
|
||||
)
|
||||
if err != nil {
|
||||
@ -463,7 +451,6 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
|
||||
"transition: %v", err)
|
||||
return false
|
||||
}
|
||||
defer cleanStates()
|
||||
|
||||
// We'll use the state this test case wants Alice to start at.
|
||||
aliceChanState := states[localState]
|
||||
|
@ -3,8 +3,6 @@ package contractcourt
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sort"
|
||||
@ -406,11 +404,7 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog,
|
||||
|
||||
var cleanUp func()
|
||||
if log == nil {
|
||||
dbDir, err := ioutil.TempDir("", "chanArb")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbPath := filepath.Join(dbDir, "testdb")
|
||||
dbPath := filepath.Join(t.TempDir(), "testdb")
|
||||
db, err := kvdb.Create(
|
||||
kvdb.BoltBackendName, dbPath, true,
|
||||
kvdb.DefaultDBTimeout,
|
||||
@ -427,7 +421,6 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog,
|
||||
}
|
||||
cleanUp = func() {
|
||||
db.Close()
|
||||
os.RemoveAll(dbDir)
|
||||
}
|
||||
|
||||
log = &testArbLog{
|
||||
|
@ -3,7 +3,6 @@ package contractcourt
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
@ -52,47 +51,35 @@ func copyFile(dest, src string) error {
|
||||
}
|
||||
|
||||
// copyChannelState copies the OpenChannel state by copying the database and
|
||||
// creating a new struct from it. The copied state and a cleanup function are
|
||||
// returned.
|
||||
func copyChannelState(state *channeldb.OpenChannel) (
|
||||
*channeldb.OpenChannel, func(), error) {
|
||||
// creating a new struct from it. The copied state is returned.
|
||||
func copyChannelState(t *testing.T, state *channeldb.OpenChannel) (
|
||||
*channeldb.OpenChannel, error) {
|
||||
|
||||
// Make a copy of the DB.
|
||||
dbFile := filepath.Join(state.Db.GetParentDB().Path(), "channel.db")
|
||||
tempDbPath, err := ioutil.TempDir("", "past-state")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(tempDbPath)
|
||||
}
|
||||
tempDbPath := t.TempDir()
|
||||
|
||||
tempDbFile := filepath.Join(tempDbPath, "channel.db")
|
||||
err = copyFile(tempDbFile, dbFile)
|
||||
err := copyFile(tempDbFile, dbFile)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newDb, err := channeldb.Open(tempDbPath)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chans, err := newDb.ChannelStateDB().FetchAllChannels()
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We only support DBs with a single channel, for now.
|
||||
if len(chans) != 1 {
|
||||
cleanup()
|
||||
return nil, nil, fmt.Errorf("found %d chans in the db",
|
||||
return nil, fmt.Errorf("found %d chans in the db",
|
||||
len(chans))
|
||||
}
|
||||
|
||||
return chans[0], cleanup, nil
|
||||
return chans[0], nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user