diff --git a/peer/brontide_test.go b/peer/brontide_test.go index 5772e9679..270085674 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -48,11 +48,14 @@ func TestPeerChannelClosureShutdownResponseLinkRemoved(t *testing.T) { mockSwitch := &mockMessageSwitch{} - alicePeer, bobChan, err := createTestPeerWithChannel( + harness, err := createTestPeerWithChannel( t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(t, err, "unable to create test channels") + alicePeer := harness.peer + bobChan := harness.channel + chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) @@ -96,11 +99,14 @@ func TestPeerChannelClosureAcceptFeeResponder(t *testing.T) { mockSwitch := &mockMessageSwitch{} - alicePeer, bobChan, err := createTestPeerWithChannel( + harness, err := createTestPeerWithChannel( t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(t, err, "unable to create test channels") + alicePeer := harness.peer + bobChan := harness.channel + chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) @@ -201,11 +207,14 @@ func TestPeerChannelClosureAcceptFeeInitiator(t *testing.T) { mockSwitch := &mockMessageSwitch{} - alicePeer, bobChan, err := createTestPeerWithChannel( + harness, err := createTestPeerWithChannel( t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(t, err, "unable to create test channels") + bobChan := harness.channel + alicePeer := harness.peer + chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) mockLink := newMockUpdateHandler(chanID) @@ -325,11 +334,14 @@ func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) { mockSwitch := &mockMessageSwitch{} - alicePeer, bobChan, err := createTestPeerWithChannel( + harness, err := createTestPeerWithChannel( t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(t, err, "unable to create test channels") + bobChan := harness.channel + alicePeer := harness.peer + chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) @@ -512,11 +524,14 @@ func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) { mockSwitch := &mockMessageSwitch{} - alicePeer, bobChan, err := createTestPeerWithChannel( + harness, err := createTestPeerWithChannel( t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(t, err, "unable to create test channels") + alicePeer := harness.peer + bobChan := harness.channel + chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) mockLink := newMockUpdateHandler(chanID) @@ -840,7 +855,7 @@ func TestCustomShutdownScript(t *testing.T) { mockSwitch := &mockMessageSwitch{} // Open a channel. - alicePeer, bobChan, err := createTestPeerWithChannel( + harness, err := createTestPeerWithChannel( t, notifier, broadcastTxChan, test.update, mockSwitch, ) @@ -848,13 +863,16 @@ func TestCustomShutdownScript(t *testing.T) { t.Fatalf("unable to create test channels: %v", err) } + alicePeer := harness.peer + bobChan := harness.channel + chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) mockLink := newMockUpdateHandler(chanID) mockSwitch.links = append(mockSwitch.links, mockLink) - // Request initiator to cooperatively close the channel, with - // a specified delivery address. + // Request initiator to cooperatively close the channel, + // with a specified delivery address. updateChan := make(chan interface{}, 1) errChan := make(chan error, 1) closeCommand := htlcswitch.ChanClose{ @@ -1195,11 +1213,14 @@ func TestUpdateNextRevocation(t *testing.T) { broadcastTxChan := make(chan *wire.MsgTx) mockSwitch := &mockMessageSwitch{} - alicePeer, bobChan, err := createTestPeerWithChannel( + harness, err := createTestPeerWithChannel( t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(err, "unable to create test channels") + bobChan := harness.channel + alicePeer := harness.peer + // testChannel is used to test the updateNextRevocation function. testChannel := bobChan.State() @@ -1432,11 +1453,13 @@ func TestStartupWriteMessageRace(t *testing.T) { // createTestPeerWithChannel creates a peer and a channel with that // peer. - peer, _, err := createTestPeerWithChannel( + harness, err := createTestPeerWithChannel( t, notifier, broadcastTxChan, getChannels, mockSwitch, ) require.NoError(t, err, "unable to create test channel") + peer := harness.peer + // Avoid the need to mock the channel graph by marking the channel // borked. Borked channels still get a reestablish message sent on // reconnect, while skipping channel graph checks and link creation. @@ -1528,11 +1551,13 @@ func TestRemovePendingChannel(t *testing.T) { // createTestPeerWithChannel creates a peer and a channel with that // peer. - peer, _, err := createTestPeerWithChannel( + harness, err := createTestPeerWithChannel( t, notifier, broadcastTxChan, noUpdate, mockSwitch, ) require.NoError(t, err, "unable to create test channel") + peer := harness.peer + // Add a pending channel to the peer Alice. errChan := make(chan error, 1) pendingChanID := lnwire.ChannelID{1} diff --git a/peer/test_utils.go b/peer/test_utils.go index 5080043a2..2f8f17bb7 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -53,6 +53,11 @@ var ( // the channels set up. var noUpdate = func(a, b *channeldb.OpenChannel) {} +type peerTestCtx struct { + peer *Brontide + channel *lnwallet.LightningChannel +} + // createTestPeerWithChannel creates a channel between two nodes, and returns a // peer for one of the nodes, together with the channel seen from both nodes. // It takes an updateChan function which can be used to modify the default @@ -60,7 +65,7 @@ var noUpdate = func(a, b *channeldb.OpenChannel) {} func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, publTx chan *wire.MsgTx, updateChan func(a, b *channeldb.OpenChannel), mockSwitch *mockMessageSwitch) ( - *Brontide, *lnwallet.LightningChannel, error) { + *peerTestCtx, error) { nodeKeyLocator := keychain.KeyLocator{ Family: keychain.KeyFamilyNodeKey, @@ -142,23 +147,23 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize()) if err != nil { - return nil, nil, err + return nil, err } bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot) bobFirstRevoke, err := bobPreimageProducer.AtIndex(0) if err != nil { - return nil, nil, err + return nil, err } bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:]) aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize()) if err != nil { - return nil, nil, err + return nil, err } alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot) aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0) if err != nil { - return nil, nil, err + return nil, err } aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:]) @@ -168,12 +173,12 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, isAliceInitiator, 0, ) if err != nil { - return nil, nil, err + return nil, err } dbAlice, err := channeldb.Open(t.TempDir()) if err != nil { - return nil, nil, err + return nil, err } t.Cleanup(func() { require.NoError(t, dbAlice.Close()) @@ -181,7 +186,7 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, dbBob, err := channeldb.Open(t.TempDir()) if err != nil { - return nil, nil, err + return nil, err } t.Cleanup(func() { require.NoError(t, dbBob.Close()) @@ -190,7 +195,7 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, estimator := chainfee.NewStaticEstimator(12500, 0) feePerKw, err := estimator.EstimateFeePerKW(1) if err != nil { - return nil, nil, err + return nil, err } // TODO(roasbeef): need to factor in commit fee? @@ -215,7 +220,7 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, var chanIDBytes [8]byte if _, err := io.ReadFull(crand.Reader, chanIDBytes[:]); err != nil { - return nil, nil, err + return nil, err } shortChanID := lnwire.NewShortChanIDFromInt( @@ -266,7 +271,7 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, } if err := aliceChannelState.SyncPending(aliceAddr, 0); err != nil { - return nil, nil, err + return nil, err } bobAddr := &net.TCPAddr{ @@ -275,7 +280,7 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, } if err := bobChannelState.SyncPending(bobAddr, 0); err != nil { - return nil, nil, err + return nil, err } aliceSigner := input.NewMockSigner( @@ -290,7 +295,7 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, aliceSigner, aliceChannelState, alicePool, ) if err != nil { - return nil, nil, err + return nil, err } _ = alicePool.Start() t.Cleanup(func() { @@ -302,7 +307,7 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, bobSigner, bobChannelState, bobPool, ) if err != nil { - return nil, nil, err + return nil, err } _ = bobPool.Start() t.Cleanup(func() { @@ -346,15 +351,15 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, }, }) if err != nil { - return nil, nil, err + return nil, err } if err = chanStatusMgr.Start(); err != nil { - return nil, nil, err + return nil, err } errBuffer, err := queue.NewCircularBuffer(ErrorBufferSize) if err != nil { - return nil, nil, err + return nil, err } var pubKey [33]byte @@ -381,7 +386,7 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, }, ) if err != nil { - return nil, nil, err + return nil, err } // TODO(yy): change ChannelNotifier to be an interface. @@ -419,7 +424,10 @@ func createTestPeerWithChannel(t *testing.T, notifier chainntnfs.ChainNotifier, alicePeer.wg.Add(1) go alicePeer.channelManager() - return alicePeer, channelBob, nil + return &peerTestCtx{ + peer: alicePeer, + channel: channelBob, + }, nil } // mockMessageSwitch is a mock implementation of the messageSwitch interface