diff --git a/lnpeer/mock_peer.go b/lnpeer/mock_peer.go new file mode 100644 index 000000000..a8953092c --- /dev/null +++ b/lnpeer/mock_peer.go @@ -0,0 +1,82 @@ +package lnpeer + +import ( + "net" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/mock" +) + +// MockPeer implements the `lnpeer.Peer` interface. +type MockPeer struct { + mock.Mock +} + +// Compile time assertion that MockPeer implements lnpeer.Peer. +var _ Peer = (*MockPeer)(nil) + +func (m *MockPeer) SendMessage(sync bool, msgs ...lnwire.Message) error { + args := m.Called(sync, msgs) + return args.Error(0) +} + +func (m *MockPeer) SendMessageLazy(sync bool, msgs ...lnwire.Message) error { + args := m.Called(sync, msgs) + return args.Error(0) +} + +func (m *MockPeer) AddNewChannel(channel *channeldb.OpenChannel, + cancel <-chan struct{}) error { + + args := m.Called(channel, cancel) + return args.Error(0) +} + +func (m *MockPeer) AddPendingChannel(cid lnwire.ChannelID, + cancel <-chan struct{}) error { + + args := m.Called(cid, cancel) + return args.Error(0) +} + +func (m *MockPeer) RemovePendingChannel(cid lnwire.ChannelID) error { + args := m.Called(cid) + return args.Error(0) +} + +func (m *MockPeer) WipeChannel(op *wire.OutPoint) { + m.Called(op) +} + +func (m *MockPeer) PubKey() [33]byte { + args := m.Called() + return args.Get(0).([33]byte) +} + +func (m *MockPeer) IdentityKey() *btcec.PublicKey { + args := m.Called() + return args.Get(0).(*btcec.PublicKey) +} + +func (m *MockPeer) Address() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} + +func (m *MockPeer) QuitSignal() <-chan struct{} { + args := m.Called() + return args.Get(0).(<-chan struct{}) +} + +func (m *MockPeer) LocalFeatures() *lnwire.FeatureVector { + args := m.Called() + return args.Get(0).(*lnwire.FeatureVector) +} + +func (m *MockPeer) RemoteFeatures() *lnwire.FeatureVector { + args := m.Called() + return args.Get(0).(*lnwire.FeatureVector) +} diff --git a/peer/brontide_test.go b/peer/brontide_test.go index 4b0929574..428ea934e 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -16,6 +16,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntest/mock" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chancloser" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/pool" @@ -1171,3 +1172,171 @@ func TestUpdateNextRevocation(t *testing.T) { // TODO(yy): add test for `addActiveChannel` and `handleNewActiveChannel` once // we have interfaced `lnwallet.LightningChannel` and // `*contractcourt.ChainArbitrator`. + +// TestHandleNewPendingChannel checks the method `handleNewPendingChannel` +// behaves as expected. +func TestHandleNewPendingChannel(t *testing.T) { + t.Parallel() + + // Create three channel IDs for testing. + chanIDActive := lnwire.ChannelID{0} + chanIDNotExist := lnwire.ChannelID{1} + chanIDPending := lnwire.ChannelID{2} + + // Create a test brontide. + dummyConfig := Config{} + peer := NewBrontide(dummyConfig) + + // Create the test state. + peer.activeChannels.Store(chanIDActive, &lnwallet.LightningChannel{}) + peer.activeChannels.Store(chanIDPending, nil) + + // Assert test state, we should have two channels store, one active and + // one pending. + require.Equal(t, 2, peer.activeChannels.Len()) + + testCases := []struct { + name string + chanID lnwire.ChannelID + + // expectChanAdded specifies whether this chanID will be added + // to the peer's state. + expectChanAdded bool + }{ + { + name: "noop on active channel", + chanID: chanIDActive, + expectChanAdded: false, + }, + { + name: "noop on pending channel", + chanID: chanIDPending, + expectChanAdded: false, + }, + { + name: "new channel should be added", + chanID: chanIDNotExist, + expectChanAdded: true, + }, + } + + for _, tc := range testCases { + tc := tc + + // Create a request for testing. + errChan := make(chan error, 1) + req := &newChannelMsg{ + channelID: tc.chanID, + err: errChan, + } + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Get the number of channels before mutating the + // state. + numChans := peer.activeChannels.Len() + + // Call the method. + peer.handleNewPendingChannel(req) + + // Add one if we expect this channel to be added. + if tc.expectChanAdded { + numChans++ + } + + // Assert the number of channels is correct. + require.Equal(numChans, peer.activeChannels.Len()) + + // Assert the request's error chan is closed. + err, ok := <-req.err + require.False(ok, "expect err chan to be closed") + require.NoError(err, "expect no error") + }) + } +} + +// TestHandleRemovePendingChannel checks the method +// `handleRemovePendingChannel` behaves as expected. +func TestHandleRemovePendingChannel(t *testing.T) { + t.Parallel() + + // Create three channel IDs for testing. + chanIDActive := lnwire.ChannelID{0} + chanIDNotExist := lnwire.ChannelID{1} + chanIDPending := lnwire.ChannelID{2} + + // Create a test brontide. + dummyConfig := Config{} + peer := NewBrontide(dummyConfig) + + // Create the test state. + peer.activeChannels.Store(chanIDActive, &lnwallet.LightningChannel{}) + peer.activeChannels.Store(chanIDPending, nil) + + // Assert test state, we should have two channels store, one active and + // one pending. + require.Equal(t, 2, peer.activeChannels.Len()) + + testCases := []struct { + name string + chanID lnwire.ChannelID + + // expectDeleted specifies whether this chanID will be removed + // from the peer's state. + expectDeleted bool + }{ + { + name: "noop on active channel", + chanID: chanIDActive, + expectDeleted: false, + }, + { + name: "pending channel should be removed", + chanID: chanIDPending, + expectDeleted: true, + }, + { + name: "noop on non-exist channel", + chanID: chanIDNotExist, + expectDeleted: false, + }, + } + + for _, tc := range testCases { + tc := tc + + // Create a request for testing. + errChan := make(chan error, 1) + req := &newChannelMsg{ + channelID: tc.chanID, + err: errChan, + } + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Get the number of channels before mutating the + // state. + numChans := peer.activeChannels.Len() + + // Call the method. + peer.handleRemovePendingChannel(req) + + // Minus one if we expect this channel to be removed. + if tc.expectDeleted { + numChans-- + } + + // Assert the number of channels is correct. + require.Equal(numChans, peer.activeChannels.Len()) + + // Assert the request's error chan is closed. + err, ok := <-req.err + require.False(ok, "expect err chan to be closed") + require.NoError(err, "expect no error") + }) + } +}