diff --git a/peer/brontide_test.go b/peer/brontide_test.go index a0a354c1c..44e63267b 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -43,11 +43,10 @@ func TestPeerChannelClosureAcceptFeeResponder(t *testing.T) { mockSwitch := &mockMessageSwitch{} - alicePeer, bobChan, cleanUp, err := createTestPeer( - notifier, broadcastTxChan, noUpdate, mockSwitch, + alicePeer, bobChan, err := createTestPeer( + t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(t, err, "unable to create test channels") - defer cleanUp() chanID := lnwire.NewChanIDFromOutPoint(bobChan.ChannelPoint()) @@ -147,11 +146,10 @@ func TestPeerChannelClosureAcceptFeeInitiator(t *testing.T) { mockSwitch := &mockMessageSwitch{} - alicePeer, bobChan, cleanUp, err := createTestPeer( - notifier, broadcastTxChan, noUpdate, mockSwitch, + alicePeer, bobChan, err := createTestPeer( + t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(t, err, "unable to create test channels") - defer cleanUp() chanID := lnwire.NewChanIDFromOutPoint(bobChan.ChannelPoint()) mockLink := newMockUpdateHandler(chanID) @@ -270,11 +268,10 @@ func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) { mockSwitch := &mockMessageSwitch{} - alicePeer, bobChan, cleanUp, err := createTestPeer( - notifier, broadcastTxChan, noUpdate, mockSwitch, + alicePeer, bobChan, err := createTestPeer( + t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(t, err, "unable to create test channels") - defer cleanUp() chanID := lnwire.NewChanIDFromOutPoint(bobChan.ChannelPoint()) @@ -456,11 +453,10 @@ func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) { mockSwitch := &mockMessageSwitch{} - alicePeer, bobChan, cleanUp, err := createTestPeer( - notifier, broadcastTxChan, noUpdate, mockSwitch, + alicePeer, bobChan, err := createTestPeer( + t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(t, err, "unable to create test channels") - defer cleanUp() chanID := lnwire.NewChanIDFromOutPoint(bobChan.ChannelPoint()) mockLink := newMockUpdateHandler(chanID) @@ -784,14 +780,13 @@ func TestCustomShutdownScript(t *testing.T) { mockSwitch := &mockMessageSwitch{} // Open a channel. - alicePeer, bobChan, cleanUp, err := createTestPeer( - notifier, broadcastTxChan, test.update, + alicePeer, bobChan, err := createTestPeer( + t, notifier, broadcastTxChan, test.update, mockSwitch, ) if err != nil { t.Fatalf("unable to create test channels: %v", err) } - defer cleanUp() chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) diff --git a/peer/test_utils.go b/peer/test_utils.go index abebf5260..938c3e51c 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -5,10 +5,8 @@ import ( crand "crypto/rand" "encoding/binary" "io" - "io/ioutil" "math/rand" "net" - "os" "testing" "time" @@ -56,10 +54,10 @@ var noUpdate = func(a, b *channeldb.OpenChannel) {} // one of the nodes, together with the channel seen from both nodes. It takes // an updateChan function which can be used to modify the default values on // the channel states for each peer. -func createTestPeer(notifier chainntnfs.ChainNotifier, +func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier, publTx chan *wire.MsgTx, updateChan func(a, b *channeldb.OpenChannel), mockSwitch *mockMessageSwitch) ( - *Brontide, *lnwallet.LightningChannel, func(), error) { + *Brontide, *lnwallet.LightningChannel, error) { nodeKeyLocator := keychain.KeyLocator{ Family: keychain.KeyFamilyNodeKey, @@ -141,23 +139,23 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize()) if err != nil { - return nil, nil, nil, err + return nil, nil, err } bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot) bobFirstRevoke, err := bobPreimageProducer.AtIndex(0) if err != nil { - return nil, nil, nil, err + return nil, nil, err } bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:]) aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize()) if err != nil { - return nil, nil, nil, err + return nil, nil, err } alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot) aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0) if err != nil { - return nil, nil, nil, err + return nil, nil, err } aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:]) @@ -167,33 +165,29 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, isAliceInitiator, 0, ) if err != nil { - return nil, nil, nil, err + return nil, nil, err } - alicePath, err := ioutil.TempDir("", "alicedb") + dbAlice, err := channeldb.Open(t.TempDir()) if err != nil { - return nil, nil, nil, err + return nil, nil, err } + t.Cleanup(func() { + require.NoError(t, dbAlice.Close()) + }) - dbAlice, err := channeldb.Open(alicePath) + dbBob, err := channeldb.Open(t.TempDir()) if err != nil { - return nil, nil, nil, err - } - - bobPath, err := ioutil.TempDir("", "bobdb") - if err != nil { - return nil, nil, nil, err - } - - dbBob, err := channeldb.Open(bobPath) - if err != nil { - return nil, nil, nil, err + return nil, nil, err } + t.Cleanup(func() { + require.NoError(t, dbBob.Close()) + }) estimator := chainfee.NewStaticEstimator(12500, 0) feePerKw, err := estimator.EstimateFeePerKW(1) if err != nil { - return nil, nil, nil, err + return nil, nil, err } // TODO(roasbeef): need to factor in commit fee? @@ -218,7 +212,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, var chanIDBytes [8]byte if _, err := io.ReadFull(crand.Reader, chanIDBytes[:]); err != nil { - return nil, nil, nil, err + return nil, nil, err } shortChanID := lnwire.NewShortChanIDFromInt( @@ -269,7 +263,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, } if err := aliceChannelState.SyncPending(aliceAddr, 0); err != nil { - return nil, nil, nil, err + return nil, nil, err } bobAddr := &net.TCPAddr{ @@ -278,12 +272,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, } if err := bobChannelState.SyncPending(bobAddr, 0); err != nil { - return nil, nil, nil, err - } - - cleanUpFunc := func() { - os.RemoveAll(bobPath) - os.RemoveAll(alicePath) + return nil, nil, err } aliceSigner := &mock.SingleSigner{Privkey: aliceKeyPriv} @@ -294,18 +283,24 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, aliceSigner, aliceChannelState, alicePool, ) if err != nil { - return nil, nil, nil, err + return nil, nil, err } _ = alicePool.Start() + t.Cleanup(func() { + require.NoError(t, alicePool.Stop()) + }) bobPool := lnwallet.NewSigPool(1, bobSigner) channelBob, err := lnwallet.NewLightningChannel( bobSigner, bobChannelState, bobPool, ) if err != nil { - return nil, nil, nil, err + return nil, nil, err } _ = bobPool.Start() + t.Cleanup(func() { + require.NoError(t, bobPool.Stop()) + }) chainIO := &mock.ChainIO{ BestHeight: broadcastHeight, @@ -344,15 +339,15 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, }, }) if err != nil { - return nil, nil, nil, err + return nil, nil, err } if err = chanStatusMgr.Start(); err != nil { - return nil, nil, nil, err + return nil, nil, err } errBuffer, err := queue.NewCircularBuffer(ErrorBufferSize) if err != nil { - return nil, nil, nil, err + return nil, nil, err } var pubKey [33]byte @@ -392,7 +387,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, alicePeer.wg.Add(1) go alicePeer.channelManager() - return alicePeer, channelBob, cleanUpFunc, nil + return alicePeer, channelBob, nil } // mockMessageSwitch is a mock implementation of the messageSwitch interface