watchtower/wtclient: upgrade pkg to use require

Upgrade all the tests in the wtclient package to make use of the
`require` package.
This commit is contained in:
Elle Mouton
2022-10-11 17:31:12 +02:00
parent d29a55bbb5
commit ab4d4a19be
3 changed files with 108 additions and 233 deletions

View File

@@ -325,10 +325,8 @@ func (c *mockChannel) sendPayment(t *testing.T, amt lnwire.MilliSatoshi) {
c.mu.Lock()
defer c.mu.Unlock()
if c.localBalance < amt {
t.Fatalf("insufficient funds to send, need: %v, have: %v",
amt, c.localBalance)
}
require.GreaterOrEqualf(t, c.localBalance, amt, "insufficient funds "+
"to send, need: %v, have: %v", amt, c.localBalance)
c.localBalance -= amt
c.remoteBalance += amt
@@ -343,10 +341,8 @@ func (c *mockChannel) receivePayment(t *testing.T, amt lnwire.MilliSatoshi) {
c.mu.Lock()
defer c.mu.Unlock()
if c.remoteBalance < amt {
t.Fatalf("insufficient funds to recv, need: %v, have: %v",
amt, c.remoteBalance)
}
require.GreaterOrEqualf(t, c.remoteBalance, amt, "insufficient funds "+
"to recv, need: %v, have: %v", amt, c.remoteBalance)
c.localBalance += amt
c.remoteBalance -= amt
@@ -446,21 +442,18 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
client, err := wtclient.New(clientCfg)
require.NoError(t, err, "Unable to create wtclient")
if err := server.Start(); err != nil {
t.Fatalf("Unable to start wtserver: %v", err)
}
err = server.Start()
require.NoError(t, err)
t.Cleanup(func() {
_ = server.Stop()
})
if err = client.Start(); err != nil {
t.Fatalf("Unable to start wtclient: %v", err)
}
err = client.Start()
require.NoError(t, err)
t.Cleanup(client.ForceQuit)
if err := client.AddTower(towerAddr); err != nil {
t.Fatalf("Unable to add tower to wtclient: %v", err)
}
err = client.AddTower(towerAddr)
require.NoError(t, err)
h := &testHarness{
t: t,
@@ -493,15 +486,11 @@ func (h *testHarness) startServer() {
var err error
h.server, err = wtserver.New(h.serverCfg)
if err != nil {
h.t.Fatalf("unable to create wtserver: %v", err)
}
require.NoError(h.t, err)
h.net.setConnCallback(h.server.InboundPeerConnected)
if err := h.server.Start(); err != nil {
h.t.Fatalf("unable to start wtserver: %v", err)
}
require.NoError(h.t, h.server.Start())
}
// startClient creates a new server using the harness's current clientCf and
@@ -510,24 +499,16 @@ func (h *testHarness) startClient() {
h.t.Helper()
towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr)
if err != nil {
h.t.Fatalf("Unable to resolve tower TCP addr: %v", err)
}
require.NoError(h.t, err)
towerAddr := &lnwire.NetAddress{
IdentityKey: h.serverCfg.NodeKeyECDH.PubKey(),
Address: towerTCPAddr,
}
h.client, err = wtclient.New(h.clientCfg)
if err != nil {
h.t.Fatalf("unable to create wtclient: %v", err)
}
if err := h.client.Start(); err != nil {
h.t.Fatalf("unable to start wtclient: %v", err)
}
if err := h.client.AddTower(towerAddr); err != nil {
h.t.Fatalf("unable to add tower to wtclient: %v", err)
}
require.NoError(h.t, err)
require.NoError(h.t, h.client.Start())
require.NoError(h.t, h.client.AddTower(towerAddr))
}
// chanIDFromInt creates a unique channel id given a unique integral id.
@@ -556,9 +537,7 @@ func (h *testHarness) makeChannel(id uint64,
}
c.mu.Unlock()
if ok {
h.t.Fatalf("channel %d already created", id)
}
require.Falsef(h.t, ok, "channel %d already created", id)
}
// channel retrieves the channel corresponding to id.
@@ -570,9 +549,7 @@ func (h *testHarness) channel(id uint64) *mockChannel {
h.mu.Lock()
c, ok := h.channels[chanIDFromInt(id)]
h.mu.Unlock()
if !ok {
h.t.Fatalf("unable to fetch channel %d", id)
}
require.Truef(h.t, ok, "unable to fetch channel %d", id)
return c
}
@@ -583,9 +560,7 @@ func (h *testHarness) registerChannel(id uint64) {
chanID := chanIDFromInt(id)
err := h.client.RegisterChannel(chanID)
if err != nil {
h.t.Fatalf("unable to register channel %d: %v", id, err)
}
require.NoError(h.t, err)
}
// advanceChannelN calls advanceState on the channel identified by id the number
@@ -624,11 +599,10 @@ func (h *testHarness) backupState(id, i uint64, expErr error) {
_, retribution := h.channel(id).getState(i)
chanID := chanIDFromInt(id)
err := h.client.BackupState(&chanID, retribution, channeldb.SingleFunderBit)
if err != expErr {
h.t.Fatalf("back error mismatch, want: %v, got: %v",
expErr, err)
}
err := h.client.BackupState(
&chanID, retribution, channeldb.SingleFunderBit,
)
require.ErrorIs(h.t, expErr, err)
}
// sendPayments instructs the channel identified by id to send amt to the remote
@@ -688,10 +662,8 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint,
hintSet[hint] = struct{}{}
}
if len(hints) != len(hintSet) {
h.t.Fatalf("breach hints are not unique, list-len: %d "+
"set-len: %d", len(hints), len(hintSet))
}
require.Lenf(h.t, hints, len(hintSet), "breach hints are not unique, "+
"list-len: %d set-len: %d", len(hints), len(hintSet))
// Closure to assert the server's matches are consistent with the hint
// set.
@@ -701,12 +673,9 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint,
}
for _, match := range matches {
if _, ok := hintSet[match.Hint]; ok {
continue
}
h.t.Fatalf("match %v in db is not in hint set",
match.Hint)
_, ok := hintSet[match.Hint]
require.Truef(h.t, ok, "match %v in db is not in "+
"hint set", match.Hint)
}
return true
@@ -717,31 +686,24 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint,
select {
case <-time.After(time.Second):
matches, err := h.serverDB.QueryMatches(hints)
switch {
case err != nil:
h.t.Fatalf("unable to query for hints: %v", err)
require.NoError(h.t, err, "unable to query for hints")
case wantUpdates && serverHasHints(matches):
if wantUpdates && serverHasHints(matches) {
return
}
case wantUpdates:
if wantUpdates {
h.t.Logf("Received %d/%d\n", len(matches),
len(hints))
}
case <-failTimeout:
matches, err := h.serverDB.QueryMatches(hints)
switch {
case err != nil:
h.t.Fatalf("unable to query for hints: %v", err)
case serverHasHints(matches):
return
default:
h.t.Fatalf("breach hints not received, only "+
"got %d/%d", len(matches), len(hints))
}
require.NoError(h.t, err, "unable to query for hints")
require.Truef(h.t, serverHasHints(matches), "breach "+
"hints not received, only got %d/%d",
len(matches), len(hints))
return
}
}
}
@@ -754,25 +716,18 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint,
// Query for matches on the provided hints.
matches, err := h.serverDB.QueryMatches(hints)
if err != nil {
h.t.Fatalf("unable to query for matches: %v", err)
}
require.NoError(h.t, err)
// Assert that the number of matches is exactly the number of provided
// hints.
if len(matches) != len(hints) {
h.t.Fatalf("expected: %d matches, got: %d", len(hints),
len(matches))
}
require.Lenf(h.t, matches, len(hints), "expected: %d matches, got: %d",
len(hints), len(matches))
// Assert that all of the matches correspond to a session with the
// expected policy.
for _, match := range matches {
matchPolicy := match.SessionInfo.Policy
if expPolicy != matchPolicy {
h.t.Fatalf("expected session to have policy: %v, "+
"got: %v", expPolicy, matchPolicy)
}
require.Equal(h.t, expPolicy, matchPolicy)
}
}
@@ -780,9 +735,8 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint,
func (h *testHarness) addTower(addr *lnwire.NetAddress) {
h.t.Helper()
if err := h.client.AddTower(addr); err != nil {
h.t.Fatalf("unable to add tower: %v", err)
}
err := h.client.AddTower(addr)
require.NoError(h.t, err)
}
// removeTower removes a tower from the client. If `addr` is specified, then the
@@ -790,9 +744,8 @@ func (h *testHarness) addTower(addr *lnwire.NetAddress) {
func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) {
h.t.Helper()
if err := h.client.RemoveTower(pubKey, addr); err != nil {
h.t.Fatalf("unable to remove tower: %v", err)
}
err := h.client.RemoveTower(pubKey, addr)
require.NoError(h.t, err)
}
const (