diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index a9c4a330b..312a7bc1b 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -2,6 +2,7 @@ package wtclient_test import ( "encoding/binary" + "fmt" "net" "sync" "testing" @@ -76,37 +77,34 @@ func randPrivKey(t *testing.T) *btcec.PrivateKey { } type mockNet struct { - mu sync.RWMutex - connCallback func(wtserver.Peer) + mu sync.RWMutex + connCallbacks map[string]func(wtserver.Peer) } -func newMockNet(cb func(wtserver.Peer)) *mockNet { +func newMockNet() *mockNet { return &mockNet{ - connCallback: cb, + connCallbacks: make(map[string]func(peer wtserver.Peer)), } } -func (m *mockNet) Dial(network string, address string, - timeout time.Duration) (net.Conn, error) { - +func (m *mockNet) Dial(_, _ string, _ time.Duration) (net.Conn, error) { return nil, nil } -func (m *mockNet) LookupHost(host string) ([]string, error) { +func (m *mockNet) LookupHost(_ string) ([]string, error) { panic("not implemented") } -func (m *mockNet) LookupSRV(service string, proto string, name string) (string, []*net.SRV, error) { +func (m *mockNet) LookupSRV(_, _, _ string) (string, []*net.SRV, error) { panic("not implemented") } -func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr, error) { +func (m *mockNet) ResolveTCPAddr(_, _ string) (*net.TCPAddr, error) { panic("not implemented") } func (m *mockNet) AuthDial(local keychain.SingleKeyECDH, - netAddr *lnwire.NetAddress, - dialer tor.DialFunc) (wtserver.Peer, error) { + netAddr *lnwire.NetAddress, _ tor.DialFunc) (wtserver.Peer, error) { localPk := local.PubKey() localAddr := &net.TCPAddr{ @@ -119,16 +117,31 @@ func (m *mockNet) AuthDial(local keychain.SingleKeyECDH, ) m.mu.RLock() - m.connCallback(remotePeer) - m.mu.RUnlock() + defer m.mu.RUnlock() + cb, ok := m.connCallbacks[netAddr.String()] + if !ok { + return nil, fmt.Errorf("no callback registered for this peer") + } + + cb(remotePeer) return localPeer, nil } -func (m *mockNet) setConnCallback(cb func(wtserver.Peer)) { +func (m *mockNet) registerConnCallback(netAddr *lnwire.NetAddress, + cb func(wtserver.Peer)) { + m.mu.Lock() defer m.mu.Unlock() - m.connCallback = cb + + m.connCallbacks[netAddr.String()] = cb +} + +func (m *mockNet) removeConnCallback(netAddr *lnwire.NetAddress) { + m.mu.Lock() + defer m.mu.Unlock() + + delete(m.connCallbacks, netAddr.String()) } type mockChannel struct { @@ -416,11 +429,8 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { NoAckCreateSession: cfg.noAckCreateSession, } - server, err := wtserver.New(serverCfg) - require.NoError(t, err, "unable to create wtserver") - signer := wtmock.NewMockSigner() - mockNet := newMockNet(server.InboundPeerConnected) + mockNet := newMockNet() clientDB := wtmock.NewClientDB() clientCfg := &wtclient.Config{ @@ -442,19 +452,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { client, err := wtclient.New(clientCfg) require.NoError(t, err, "Unable to create wtclient") - err = server.Start() - require.NoError(t, err) - t.Cleanup(func() { - _ = server.Stop() - }) - - err = client.Start() - require.NoError(t, err) - t.Cleanup(client.ForceQuit) - - err = client.AddTower(towerAddr) - require.NoError(t, err) - h := &testHarness{ t: t, cfg: cfg, @@ -466,11 +463,20 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { serverAddr: towerAddr, serverDB: serverDB, serverCfg: serverCfg, - server: server, net: mockNet, channels: make(map[lnwire.ChannelID]*mockChannel), } + h.startServer() + t.Cleanup(h.stopServer) + + err = client.Start() + require.NoError(t, err) + t.Cleanup(client.ForceQuit) + + err = client.AddTower(towerAddr) + require.NoError(t, err) + h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) if !cfg.noRegisterChan0 { h.registerChannel(0) @@ -488,11 +494,20 @@ func (h *testHarness) startServer() { h.server, err = wtserver.New(h.serverCfg) require.NoError(h.t, err) - h.net.setConnCallback(h.server.InboundPeerConnected) + h.net.registerConnCallback(h.serverAddr, h.server.InboundPeerConnected) require.NoError(h.t, h.server.Start()) } +// stopServer stops the main harness server. +func (h *testHarness) stopServer() { + h.t.Helper() + + h.net.removeConnCallback(h.serverAddr) + + require.NoError(h.t, h.server.Stop()) +} + // startClient creates a new server using the harness's current clientCf and // starts it. func (h *testHarness) startClient() { @@ -932,7 +947,7 @@ var clientTests = []clientTest{ // Now, restart the server and prevent it from acking // state updates. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = true h.startServer() @@ -952,7 +967,7 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack the updates // after the client retransmits the unacked update. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = false h.startServer() @@ -1002,7 +1017,7 @@ var clientTests = []clientTest{ // Restart the server and prevent it from acking state // updates. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = true h.startServer() @@ -1020,7 +1035,7 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack the updates // after the client retransmits the unacked updates. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckUpdates = false h.startServer() @@ -1163,7 +1178,7 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack session // creation. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckCreateSession = false h.startServer() @@ -1219,7 +1234,7 @@ var clientTests = []clientTest{ // Restart the server and allow it to ack session // creation. - h.server.Stop() + h.stopServer() h.serverCfg.NoAckCreateSession = false h.startServer() @@ -1390,8 +1405,7 @@ var clientTests = []clientTest{ // Re-add the tower. We prevent the tower from acking // session creation to ensure the inactive sessions are // not used. - err := h.server.Stop() - require.Nil(h.t, err) + h.stopServer() h.serverCfg.NoAckCreateSession = true h.startServer() h.addTower(h.serverAddr) @@ -1400,8 +1414,7 @@ var clientTests = []clientTest{ // Finally, allow the tower to ack session creation, // allowing the state updates to be sent through the new // session. - err = h.server.Stop() - require.Nil(h.t, err) + h.stopServer() h.serverCfg.NoAckCreateSession = false h.startServer() h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second) @@ -1440,8 +1453,7 @@ var clientTests = []clientTest{ // Now, restart the tower and prevent it from acking any // new sessions. We do this here as once the last slot // is exhausted the client will attempt to renegotiate. - err := h.server.Stop() - require.Nil(h.t, err) + h.stopServer() h.serverCfg.NoAckCreateSession = true h.startServer() @@ -1458,8 +1470,7 @@ var clientTests = []clientTest{ // state to process. After the force quite delay // expires, the client should force quite itself and // allow the test to complete. - err = h.client.Stop() - require.Nil(h.t, err) + h.stopServer() }, }, }