wtclient: allow multiplie callback dial functions

This commit is a step towards prepping the watchtower client test
harness to be able to handle the case where the client connects to
multiple mock servers.
This commit is contained in:
Elle Mouton
2022-09-07 11:47:54 +02:00
parent ab4d4a19be
commit 4828fd902d

View File

@@ -2,6 +2,7 @@ package wtclient_test
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"net" "net"
"sync" "sync"
"testing" "testing"
@@ -76,37 +77,34 @@ func randPrivKey(t *testing.T) *btcec.PrivateKey {
} }
type mockNet struct { type mockNet struct {
mu sync.RWMutex mu sync.RWMutex
connCallback func(wtserver.Peer) connCallbacks map[string]func(wtserver.Peer)
} }
func newMockNet(cb func(wtserver.Peer)) *mockNet { func newMockNet() *mockNet {
return &mockNet{ return &mockNet{
connCallback: cb, connCallbacks: make(map[string]func(peer wtserver.Peer)),
} }
} }
func (m *mockNet) Dial(network string, address string, func (m *mockNet) Dial(_, _ string, _ time.Duration) (net.Conn, error) {
timeout time.Duration) (net.Conn, error) {
return nil, nil return nil, nil
} }
func (m *mockNet) LookupHost(host string) ([]string, error) { func (m *mockNet) LookupHost(_ string) ([]string, error) {
panic("not implemented") panic("not implemented")
} }
func (m *mockNet) LookupSRV(service string, proto string, name string) (string, []*net.SRV, error) { func (m *mockNet) LookupSRV(_, _, _ string) (string, []*net.SRV, error) {
panic("not implemented") panic("not implemented")
} }
func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr, error) { func (m *mockNet) ResolveTCPAddr(_, _ string) (*net.TCPAddr, error) {
panic("not implemented") panic("not implemented")
} }
func (m *mockNet) AuthDial(local keychain.SingleKeyECDH, func (m *mockNet) AuthDial(local keychain.SingleKeyECDH,
netAddr *lnwire.NetAddress, netAddr *lnwire.NetAddress, _ tor.DialFunc) (wtserver.Peer, error) {
dialer tor.DialFunc) (wtserver.Peer, error) {
localPk := local.PubKey() localPk := local.PubKey()
localAddr := &net.TCPAddr{ localAddr := &net.TCPAddr{
@@ -119,16 +117,31 @@ func (m *mockNet) AuthDial(local keychain.SingleKeyECDH,
) )
m.mu.RLock() m.mu.RLock()
m.connCallback(remotePeer) defer m.mu.RUnlock()
m.mu.RUnlock() cb, ok := m.connCallbacks[netAddr.String()]
if !ok {
return nil, fmt.Errorf("no callback registered for this peer")
}
cb(remotePeer)
return localPeer, nil return localPeer, nil
} }
func (m *mockNet) setConnCallback(cb func(wtserver.Peer)) { func (m *mockNet) registerConnCallback(netAddr *lnwire.NetAddress,
cb func(wtserver.Peer)) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.connCallback = cb
m.connCallbacks[netAddr.String()] = cb
}
func (m *mockNet) removeConnCallback(netAddr *lnwire.NetAddress) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.connCallbacks, netAddr.String())
} }
type mockChannel struct { type mockChannel struct {
@@ -416,11 +429,8 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
NoAckCreateSession: cfg.noAckCreateSession, NoAckCreateSession: cfg.noAckCreateSession,
} }
server, err := wtserver.New(serverCfg)
require.NoError(t, err, "unable to create wtserver")
signer := wtmock.NewMockSigner() signer := wtmock.NewMockSigner()
mockNet := newMockNet(server.InboundPeerConnected) mockNet := newMockNet()
clientDB := wtmock.NewClientDB() clientDB := wtmock.NewClientDB()
clientCfg := &wtclient.Config{ clientCfg := &wtclient.Config{
@@ -442,19 +452,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
client, err := wtclient.New(clientCfg) client, err := wtclient.New(clientCfg)
require.NoError(t, err, "Unable to create wtclient") require.NoError(t, err, "Unable to create wtclient")
err = server.Start()
require.NoError(t, err)
t.Cleanup(func() {
_ = server.Stop()
})
err = client.Start()
require.NoError(t, err)
t.Cleanup(client.ForceQuit)
err = client.AddTower(towerAddr)
require.NoError(t, err)
h := &testHarness{ h := &testHarness{
t: t, t: t,
cfg: cfg, cfg: cfg,
@@ -466,11 +463,20 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
serverAddr: towerAddr, serverAddr: towerAddr,
serverDB: serverDB, serverDB: serverDB,
serverCfg: serverCfg, serverCfg: serverCfg,
server: server,
net: mockNet, net: mockNet,
channels: make(map[lnwire.ChannelID]*mockChannel), channels: make(map[lnwire.ChannelID]*mockChannel),
} }
h.startServer()
t.Cleanup(h.stopServer)
err = client.Start()
require.NoError(t, err)
t.Cleanup(client.ForceQuit)
err = client.AddTower(towerAddr)
require.NoError(t, err)
h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance)
if !cfg.noRegisterChan0 { if !cfg.noRegisterChan0 {
h.registerChannel(0) h.registerChannel(0)
@@ -488,11 +494,20 @@ func (h *testHarness) startServer() {
h.server, err = wtserver.New(h.serverCfg) h.server, err = wtserver.New(h.serverCfg)
require.NoError(h.t, err) require.NoError(h.t, err)
h.net.setConnCallback(h.server.InboundPeerConnected) h.net.registerConnCallback(h.serverAddr, h.server.InboundPeerConnected)
require.NoError(h.t, h.server.Start()) require.NoError(h.t, h.server.Start())
} }
// stopServer stops the main harness server.
func (h *testHarness) stopServer() {
h.t.Helper()
h.net.removeConnCallback(h.serverAddr)
require.NoError(h.t, h.server.Stop())
}
// startClient creates a new server using the harness's current clientCf and // startClient creates a new server using the harness's current clientCf and
// starts it. // starts it.
func (h *testHarness) startClient() { func (h *testHarness) startClient() {
@@ -932,7 +947,7 @@ var clientTests = []clientTest{
// Now, restart the server and prevent it from acking // Now, restart the server and prevent it from acking
// state updates. // state updates.
h.server.Stop() h.stopServer()
h.serverCfg.NoAckUpdates = true h.serverCfg.NoAckUpdates = true
h.startServer() h.startServer()
@@ -952,7 +967,7 @@ var clientTests = []clientTest{
// Restart the server and allow it to ack the updates // Restart the server and allow it to ack the updates
// after the client retransmits the unacked update. // after the client retransmits the unacked update.
h.server.Stop() h.stopServer()
h.serverCfg.NoAckUpdates = false h.serverCfg.NoAckUpdates = false
h.startServer() h.startServer()
@@ -1002,7 +1017,7 @@ var clientTests = []clientTest{
// Restart the server and prevent it from acking state // Restart the server and prevent it from acking state
// updates. // updates.
h.server.Stop() h.stopServer()
h.serverCfg.NoAckUpdates = true h.serverCfg.NoAckUpdates = true
h.startServer() h.startServer()
@@ -1020,7 +1035,7 @@ var clientTests = []clientTest{
// Restart the server and allow it to ack the updates // Restart the server and allow it to ack the updates
// after the client retransmits the unacked updates. // after the client retransmits the unacked updates.
h.server.Stop() h.stopServer()
h.serverCfg.NoAckUpdates = false h.serverCfg.NoAckUpdates = false
h.startServer() h.startServer()
@@ -1163,7 +1178,7 @@ var clientTests = []clientTest{
// Restart the server and allow it to ack session // Restart the server and allow it to ack session
// creation. // creation.
h.server.Stop() h.stopServer()
h.serverCfg.NoAckCreateSession = false h.serverCfg.NoAckCreateSession = false
h.startServer() h.startServer()
@@ -1219,7 +1234,7 @@ var clientTests = []clientTest{
// Restart the server and allow it to ack session // Restart the server and allow it to ack session
// creation. // creation.
h.server.Stop() h.stopServer()
h.serverCfg.NoAckCreateSession = false h.serverCfg.NoAckCreateSession = false
h.startServer() h.startServer()
@@ -1390,8 +1405,7 @@ var clientTests = []clientTest{
// Re-add the tower. We prevent the tower from acking // Re-add the tower. We prevent the tower from acking
// session creation to ensure the inactive sessions are // session creation to ensure the inactive sessions are
// not used. // not used.
err := h.server.Stop() h.stopServer()
require.Nil(h.t, err)
h.serverCfg.NoAckCreateSession = true h.serverCfg.NoAckCreateSession = true
h.startServer() h.startServer()
h.addTower(h.serverAddr) h.addTower(h.serverAddr)
@@ -1400,8 +1414,7 @@ var clientTests = []clientTest{
// Finally, allow the tower to ack session creation, // Finally, allow the tower to ack session creation,
// allowing the state updates to be sent through the new // allowing the state updates to be sent through the new
// session. // session.
err = h.server.Stop() h.stopServer()
require.Nil(h.t, err)
h.serverCfg.NoAckCreateSession = false h.serverCfg.NoAckCreateSession = false
h.startServer() h.startServer()
h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second) h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second)
@@ -1440,8 +1453,7 @@ var clientTests = []clientTest{
// Now, restart the tower and prevent it from acking any // Now, restart the tower and prevent it from acking any
// new sessions. We do this here as once the last slot // new sessions. We do this here as once the last slot
// is exhausted the client will attempt to renegotiate. // is exhausted the client will attempt to renegotiate.
err := h.server.Stop() h.stopServer()
require.Nil(h.t, err)
h.serverCfg.NoAckCreateSession = true h.serverCfg.NoAckCreateSession = true
h.startServer() h.startServer()
@@ -1458,8 +1470,7 @@ var clientTests = []clientTest{
// state to process. After the force quite delay // state to process. After the force quite delay
// expires, the client should force quite itself and // expires, the client should force quite itself and
// allow the test to complete. // allow the test to complete.
err = h.client.Stop() h.stopServer()
require.Nil(h.t, err)
}, },
}, },
} }