wtclient: allow multiplie callback dial functions

This commit is a step towards prepping the watchtower client test
harness to be able to handle the case where the client connects to
multiple mock servers.
This commit is contained in:
Elle Mouton
2022-09-07 11:47:54 +02:00
parent ab4d4a19be
commit 4828fd902d

View File

@@ -2,6 +2,7 @@ package wtclient_test
import (
"encoding/binary"
"fmt"
"net"
"sync"
"testing"
@@ -76,37 +77,34 @@ func randPrivKey(t *testing.T) *btcec.PrivateKey {
}
type mockNet struct {
mu sync.RWMutex
connCallback func(wtserver.Peer)
mu sync.RWMutex
connCallbacks map[string]func(wtserver.Peer)
}
func newMockNet(cb func(wtserver.Peer)) *mockNet {
func newMockNet() *mockNet {
return &mockNet{
connCallback: cb,
connCallbacks: make(map[string]func(peer wtserver.Peer)),
}
}
func (m *mockNet) Dial(network string, address string,
timeout time.Duration) (net.Conn, error) {
func (m *mockNet) Dial(_, _ string, _ time.Duration) (net.Conn, error) {
return nil, nil
}
func (m *mockNet) LookupHost(host string) ([]string, error) {
func (m *mockNet) LookupHost(_ string) ([]string, error) {
panic("not implemented")
}
func (m *mockNet) LookupSRV(service string, proto string, name string) (string, []*net.SRV, error) {
func (m *mockNet) LookupSRV(_, _, _ string) (string, []*net.SRV, error) {
panic("not implemented")
}
func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr, error) {
func (m *mockNet) ResolveTCPAddr(_, _ string) (*net.TCPAddr, error) {
panic("not implemented")
}
func (m *mockNet) AuthDial(local keychain.SingleKeyECDH,
netAddr *lnwire.NetAddress,
dialer tor.DialFunc) (wtserver.Peer, error) {
netAddr *lnwire.NetAddress, _ tor.DialFunc) (wtserver.Peer, error) {
localPk := local.PubKey()
localAddr := &net.TCPAddr{
@@ -119,16 +117,31 @@ func (m *mockNet) AuthDial(local keychain.SingleKeyECDH,
)
m.mu.RLock()
m.connCallback(remotePeer)
m.mu.RUnlock()
defer m.mu.RUnlock()
cb, ok := m.connCallbacks[netAddr.String()]
if !ok {
return nil, fmt.Errorf("no callback registered for this peer")
}
cb(remotePeer)
return localPeer, nil
}
func (m *mockNet) setConnCallback(cb func(wtserver.Peer)) {
func (m *mockNet) registerConnCallback(netAddr *lnwire.NetAddress,
cb func(wtserver.Peer)) {
m.mu.Lock()
defer m.mu.Unlock()
m.connCallback = cb
m.connCallbacks[netAddr.String()] = cb
}
func (m *mockNet) removeConnCallback(netAddr *lnwire.NetAddress) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.connCallbacks, netAddr.String())
}
type mockChannel struct {
@@ -416,11 +429,8 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
NoAckCreateSession: cfg.noAckCreateSession,
}
server, err := wtserver.New(serverCfg)
require.NoError(t, err, "unable to create wtserver")
signer := wtmock.NewMockSigner()
mockNet := newMockNet(server.InboundPeerConnected)
mockNet := newMockNet()
clientDB := wtmock.NewClientDB()
clientCfg := &wtclient.Config{
@@ -442,19 +452,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
client, err := wtclient.New(clientCfg)
require.NoError(t, err, "Unable to create wtclient")
err = server.Start()
require.NoError(t, err)
t.Cleanup(func() {
_ = server.Stop()
})
err = client.Start()
require.NoError(t, err)
t.Cleanup(client.ForceQuit)
err = client.AddTower(towerAddr)
require.NoError(t, err)
h := &testHarness{
t: t,
cfg: cfg,
@@ -466,11 +463,20 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
serverAddr: towerAddr,
serverDB: serverDB,
serverCfg: serverCfg,
server: server,
net: mockNet,
channels: make(map[lnwire.ChannelID]*mockChannel),
}
h.startServer()
t.Cleanup(h.stopServer)
err = client.Start()
require.NoError(t, err)
t.Cleanup(client.ForceQuit)
err = client.AddTower(towerAddr)
require.NoError(t, err)
h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance)
if !cfg.noRegisterChan0 {
h.registerChannel(0)
@@ -488,11 +494,20 @@ func (h *testHarness) startServer() {
h.server, err = wtserver.New(h.serverCfg)
require.NoError(h.t, err)
h.net.setConnCallback(h.server.InboundPeerConnected)
h.net.registerConnCallback(h.serverAddr, h.server.InboundPeerConnected)
require.NoError(h.t, h.server.Start())
}
// stopServer stops the main harness server.
func (h *testHarness) stopServer() {
h.t.Helper()
h.net.removeConnCallback(h.serverAddr)
require.NoError(h.t, h.server.Stop())
}
// startClient creates a new server using the harness's current clientCf and
// starts it.
func (h *testHarness) startClient() {
@@ -932,7 +947,7 @@ var clientTests = []clientTest{
// Now, restart the server and prevent it from acking
// state updates.
h.server.Stop()
h.stopServer()
h.serverCfg.NoAckUpdates = true
h.startServer()
@@ -952,7 +967,7 @@ var clientTests = []clientTest{
// Restart the server and allow it to ack the updates
// after the client retransmits the unacked update.
h.server.Stop()
h.stopServer()
h.serverCfg.NoAckUpdates = false
h.startServer()
@@ -1002,7 +1017,7 @@ var clientTests = []clientTest{
// Restart the server and prevent it from acking state
// updates.
h.server.Stop()
h.stopServer()
h.serverCfg.NoAckUpdates = true
h.startServer()
@@ -1020,7 +1035,7 @@ var clientTests = []clientTest{
// Restart the server and allow it to ack the updates
// after the client retransmits the unacked updates.
h.server.Stop()
h.stopServer()
h.serverCfg.NoAckUpdates = false
h.startServer()
@@ -1163,7 +1178,7 @@ var clientTests = []clientTest{
// Restart the server and allow it to ack session
// creation.
h.server.Stop()
h.stopServer()
h.serverCfg.NoAckCreateSession = false
h.startServer()
@@ -1219,7 +1234,7 @@ var clientTests = []clientTest{
// Restart the server and allow it to ack session
// creation.
h.server.Stop()
h.stopServer()
h.serverCfg.NoAckCreateSession = false
h.startServer()
@@ -1390,8 +1405,7 @@ var clientTests = []clientTest{
// Re-add the tower. We prevent the tower from acking
// session creation to ensure the inactive sessions are
// not used.
err := h.server.Stop()
require.Nil(h.t, err)
h.stopServer()
h.serverCfg.NoAckCreateSession = true
h.startServer()
h.addTower(h.serverAddr)
@@ -1400,8 +1414,7 @@ var clientTests = []clientTest{
// Finally, allow the tower to ack session creation,
// allowing the state updates to be sent through the new
// session.
err = h.server.Stop()
require.Nil(h.t, err)
h.stopServer()
h.serverCfg.NoAckCreateSession = false
h.startServer()
h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second)
@@ -1440,8 +1453,7 @@ var clientTests = []clientTest{
// Now, restart the tower and prevent it from acking any
// new sessions. We do this here as once the last slot
// is exhausted the client will attempt to renegotiate.
err := h.server.Stop()
require.Nil(h.t, err)
h.stopServer()
h.serverCfg.NoAckCreateSession = true
h.startServer()
@@ -1458,8 +1470,7 @@ var clientTests = []clientTest{
// state to process. After the force quite delay
// expires, the client should force quite itself and
// allow the test to complete.
err = h.client.Stop()
require.Nil(h.t, err)
h.stopServer()
},
},
}