diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 7d3178f3e..c536c433b 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -2,9 +2,6 @@ package wtclient import ( "bytes" - "crypto/rand" - "io" - "reflect" "testing" "github.com/btcsuite/btcd/btcec/v2" @@ -12,7 +9,6 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -54,14 +50,6 @@ var ( } ) -func makeAddrSlice(size int) []byte { - addr := make([]byte, size) - if _, err := io.ReadFull(rand.Reader, addr); err != nil { - panic("cannot make addr") - } - return addr -} - type backupTaskTest struct { name string chanID lnwire.ChannelID @@ -502,35 +490,12 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Assert that all parameters set during initialization are properly // populated. - if task.id.ChanID != test.chanID { - t.Fatalf("channel id mismatch, want: %s, got: %s", - test.chanID, task.id.ChanID) - } - - if task.id.CommitHeight != test.breachInfo.RevokedStateNum { - t.Fatalf("commit height mismatch, want: %d, got: %d", - test.breachInfo.RevokedStateNum, task.id.CommitHeight) - } - - if task.totalAmt != test.expTotalAmt { - t.Fatalf("total amount mismatch, want: %d, got: %v", - test.expTotalAmt, task.totalAmt) - } - - if !reflect.DeepEqual(task.breachInfo, test.breachInfo) { - t.Fatalf("breach info mismatch, want: %v, got: %v", - test.breachInfo, task.breachInfo) - } - - if !reflect.DeepEqual(task.toLocalInput, test.expToLocalInput) { - t.Fatalf("to-local input mismatch, want: %v, got: %v", - test.expToLocalInput, task.toLocalInput) - } - - if !reflect.DeepEqual(task.toRemoteInput, test.expToRemoteInput) { - t.Fatalf("to-local input mismatch, want: %v, got: %v", - test.expToRemoteInput, task.toRemoteInput) - } + require.Equal(t, test.chanID, task.id.ChanID) + require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight) + require.Equal(t, test.expTotalAmt, task.totalAmt) + require.Equal(t, test.breachInfo, task.breachInfo) + require.Equal(t, test.expToLocalInput, task.toLocalInput) + require.Equal(t, test.expToRemoteInput, task.toRemoteInput) // Reconstruct the expected input.Inputs that will be returned by the // task's inputs() method. @@ -545,34 +510,24 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Assert that the inputs method returns the correct slice of // input.Inputs. inputs := task.inputs() - if !reflect.DeepEqual(expInputs, inputs) { - t.Fatalf("inputs mismatch, want: %v, got: %v", - expInputs, inputs) - } + require.Equal(t, expInputs, inputs) // Now, bind the session to the task. If successful, this locks in the // session's negotiated parameters and allows the backup task to derive // the final free variables in the justice transaction. err := task.bindSession(test.session) - if err != test.bindErr { - t.Fatalf("expected: %v when binding session, got: %v", - test.bindErr, err) - } + require.ErrorIs(t, err, test.bindErr) // Exit early if the bind was supposed to fail. But first, we check that // all fields set during a bind are still unset. This ensure that a // failed bind doesn't have side-effects if the task is retried with a // different session. if test.bindErr != nil { - if task.blobType != 0 { - t.Fatalf("blob type should not be set on failed bind, "+ - "found: %s", task.blobType) - } + require.Zerof(t, task.blobType, "blob type should not be set "+ + "on failed bind, found: %s", task.blobType) - if task.outputs != nil { - t.Fatalf("justice outputs should not be set on failed bind, "+ - "found: %v", task.outputs) - } + require.Nilf(t, task.outputs, "justice outputs should not be "+ + " set on failed bind, found: %v", task.outputs) return } @@ -580,10 +535,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Otherwise, the binding succeeded. Assert that all values set during // the bind are properly populated. policy := test.session.Policy - if task.blobType != policy.BlobType { - t.Fatalf("blob type mismatch, want: %s, got %s", - policy.BlobType, task.blobType) - } + require.Equal(t, policy.BlobType, task.blobType) // Compute the expected outputs on the justice transaction. var expOutputs = []*wire.TxOut{ @@ -603,10 +555,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { } // Assert that the computed outputs match our expected outputs. - if !reflect.DeepEqual(expOutputs, task.outputs) { - t.Fatalf("justice txn output mismatch, want: %v,\ngot: %v", - spew.Sdump(expOutputs), spew.Sdump(task.outputs)) - } + require.Equal(t, expOutputs, task.outputs) // Now, we'll construct, sign, and encrypt the blob containing the parts // needed to reconstruct the justice transaction. @@ -616,10 +565,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Verify that the breach hint matches the breach txid's prefix. breachTxID := test.breachInfo.BreachTxHash expHint := blob.NewBreachHintFromHash(&breachTxID) - if hint != expHint { - t.Fatalf("breach hint mismatch, want: %x, got: %v", - expHint, hint) - } + require.Equal(t, expHint, hint) // Decrypt the return blob to obtain the JusticeKit containing its // contents. @@ -634,14 +580,8 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Assert that the blob contained the serialized revocation and to-local // pubkeys. - if !bytes.Equal(jKit.RevocationPubKey[:], expRevPK) { - t.Fatalf("revocation pk mismatch, want: %x, got: %x", - expRevPK, jKit.RevocationPubKey[:]) - } - if !bytes.Equal(jKit.LocalDelayPubKey[:], expToLocalPK) { - t.Fatalf("revocation pk mismatch, want: %x, got: %x", - expToLocalPK, jKit.LocalDelayPubKey[:]) - } + require.Equal(t, expRevPK, jKit.RevocationPubKey[:]) + require.Equal(t, expToLocalPK, jKit.LocalDelayPubKey[:]) // Determine if the breach transaction has a to-remote output and/or // to-local output to spend from. Note the seemingly-reversed @@ -650,32 +590,19 @@ func testBackupTask(t *testing.T, test backupTaskTest) { hasToLocal := test.breachInfo.RemoteOutputSignDesc != nil // If the to-remote output is present, assert that the to-remote public - // key was included in the blob. - if hasToRemote && - !bytes.Equal(jKit.CommitToRemotePubKey[:], expToRemotePK) { - t.Fatalf("mismatch to-remote pubkey, want: %x, got: %x", - expToRemotePK, jKit.CommitToRemotePubKey) - } - - // Otherwise if the to-local output is not present, assert that a blank - // public key was inserted. - if !hasToRemote && - !bytes.Equal(jKit.CommitToRemotePubKey[:], zeroPK[:]) { - t.Fatalf("mismatch to-remote pubkey, want: %x, got: %x", - zeroPK, jKit.CommitToRemotePubKey) + // key was included in the blob. Otherwise assert that a blank public + // key was inserted. + if hasToRemote { + require.Equal(t, expToRemotePK, jKit.CommitToRemotePubKey[:]) + } else { + require.Equal(t, zeroPK[:], jKit.CommitToRemotePubKey[:]) } // Assert that the CSV is encoded in the blob. - if jKit.CSVDelay != test.breachInfo.RemoteDelay { - t.Fatalf("mismatch remote delay, want: %d, got: %v", - test.breachInfo.RemoteDelay, jKit.CSVDelay) - } + require.Equal(t, test.breachInfo.RemoteDelay, jKit.CSVDelay) // Assert that the sweep pkscript is included. - if !bytes.Equal(jKit.SweepAddress, test.expSweepScript) { - t.Fatalf("sweep pkscript mismatch, want: %x, got: %x", - test.expSweepScript, jKit.SweepAddress) - } + require.Equal(t, test.expSweepScript, jKit.SweepAddress) // Finally, verify that the signatures are encoded in the justice kit. // We don't validate the actual signatures produced here, since at the @@ -684,18 +611,20 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // TODO(conner): include signature validation checks emptyToLocalSig := bytes.Equal(jKit.CommitToLocalSig[:], zeroSig[:]) - switch { - case hasToLocal && emptyToLocalSig: - t.Fatalf("to-local signature should not be empty") - case !hasToLocal && !emptyToLocalSig: - t.Fatalf("to-local signature should be empty") + if hasToLocal { + require.False(t, emptyToLocalSig, "to-local signature should "+ + "not be empty") + } else { + require.True(t, emptyToLocalSig, "to-local signature should "+ + "be empty") } emptyToRemoteSig := bytes.Equal(jKit.CommitToRemoteSig[:], zeroSig[:]) - switch { - case hasToRemote && emptyToRemoteSig: - t.Fatalf("to-remote signature should not be empty") - case !hasToRemote && !emptyToRemoteSig: - t.Fatalf("to-remote signature should be empty") + if hasToRemote { + require.False(t, emptyToRemoteSig, "to-remote signature "+ + "should not be empty") + } else { + require.True(t, emptyToRemoteSig, "to-remote signature "+ + "should be empty") } } diff --git a/watchtower/wtclient/candidate_iterator_test.go b/watchtower/wtclient/candidate_iterator_test.go index 99547d794..9a919e103 100644 --- a/watchtower/wtclient/candidate_iterator_test.go +++ b/watchtower/wtclient/candidate_iterator_test.go @@ -4,12 +4,10 @@ import ( "encoding/binary" "math/rand" "net" - "reflect" "testing" "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/stretchr/testify/require" ) @@ -19,15 +17,16 @@ func init() { } func randAddr(t *testing.T) net.Addr { - var ip [4]byte - if _, err := rand.Read(ip[:]); err != nil { - t.Fatal(err) - } - var port [2]byte - if _, err := rand.Read(port[:]); err != nil { - t.Fatal(err) + t.Helper() + + var ip [4]byte + _, err := rand.Read(ip[:]) + require.NoError(t, err) + + var port [2]byte + _, err = rand.Read(port[:]) + require.NoError(t, err) - } return &net.TCPAddr{ IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(port[:])), @@ -35,6 +34,8 @@ func randAddr(t *testing.T) net.Addr { } func randTower(t *testing.T) *wtdb.Tower { + t.Helper() + priv, err := btcec.NewPrivateKey() require.NoError(t, err, "unable to create private key") pubKey := priv.PubKey() @@ -58,27 +59,24 @@ func copyTower(tower *wtdb.Tower) *wtdb.Tower { func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, c *wtdb.Tower, active bool) { + t.Helper() + isCandidate := i.IsActive(c.ID) - if isCandidate && !active { - t.Fatalf("expected tower %v to no longer be an active candidate", - c.ID) - } - if !isCandidate && active { - t.Fatalf("expected tower %v to be an active candidate", c.ID) + if isCandidate { + require.Truef(t, active, "expected tower %v to no longer be "+ + "an active candidate", c.ID) + return } + require.Falsef(t, active, "expected tower %v to be an active candidate", + c.ID) } func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *wtdb.Tower) { t.Helper() tower, err := i.Next() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(tower, c) { - t.Fatalf("expected tower: %v\ngot: %v", spew.Sdump(c), - spew.Sdump(tower)) - } + require.NoError(t, err) + require.Equal(t, c, tower) } // TestTowerCandidateIterator asserts the internal state of a @@ -104,18 +102,13 @@ func TestTowerCandidateIterator(t *testing.T) { // were added. for _, expTower := range towers { tower, err := towerIterator.Next() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(tower, expTower) { - t.Fatalf("expected tower: %v\ngot: %v", - spew.Sdump(expTower), spew.Sdump(tower)) - } + require.NoError(t, err) + require.Equal(t, expTower, tower) } - if _, err := towerIterator.Next(); err != ErrTowerCandidatesExhausted { - t.Fatalf("expected ErrTowerCandidatesExhausted, got %v", err) - } + _, err := towerIterator.Next() + require.ErrorIs(t, err, ErrTowerCandidatesExhausted) + towerIterator.Reset() // We'll then attempt to test the RemoveCandidate behavior of the diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 53ef15d15..a9c4a330b 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -325,10 +325,8 @@ func (c *mockChannel) sendPayment(t *testing.T, amt lnwire.MilliSatoshi) { c.mu.Lock() defer c.mu.Unlock() - if c.localBalance < amt { - t.Fatalf("insufficient funds to send, need: %v, have: %v", - amt, c.localBalance) - } + require.GreaterOrEqualf(t, c.localBalance, amt, "insufficient funds "+ + "to send, need: %v, have: %v", amt, c.localBalance) c.localBalance -= amt c.remoteBalance += amt @@ -343,10 +341,8 @@ func (c *mockChannel) receivePayment(t *testing.T, amt lnwire.MilliSatoshi) { c.mu.Lock() defer c.mu.Unlock() - if c.remoteBalance < amt { - t.Fatalf("insufficient funds to recv, need: %v, have: %v", - amt, c.remoteBalance) - } + require.GreaterOrEqualf(t, c.remoteBalance, amt, "insufficient funds "+ + "to recv, need: %v, have: %v", amt, c.remoteBalance) c.localBalance += amt c.remoteBalance -= amt @@ -446,21 +442,18 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { client, err := wtclient.New(clientCfg) require.NoError(t, err, "Unable to create wtclient") - if err := server.Start(); err != nil { - t.Fatalf("Unable to start wtserver: %v", err) - } + err = server.Start() + require.NoError(t, err) t.Cleanup(func() { _ = server.Stop() }) - if err = client.Start(); err != nil { - t.Fatalf("Unable to start wtclient: %v", err) - } + err = client.Start() + require.NoError(t, err) t.Cleanup(client.ForceQuit) - if err := client.AddTower(towerAddr); err != nil { - t.Fatalf("Unable to add tower to wtclient: %v", err) - } + err = client.AddTower(towerAddr) + require.NoError(t, err) h := &testHarness{ t: t, @@ -493,15 +486,11 @@ func (h *testHarness) startServer() { var err error h.server, err = wtserver.New(h.serverCfg) - if err != nil { - h.t.Fatalf("unable to create wtserver: %v", err) - } + require.NoError(h.t, err) h.net.setConnCallback(h.server.InboundPeerConnected) - if err := h.server.Start(); err != nil { - h.t.Fatalf("unable to start wtserver: %v", err) - } + require.NoError(h.t, h.server.Start()) } // startClient creates a new server using the harness's current clientCf and @@ -510,24 +499,16 @@ func (h *testHarness) startClient() { h.t.Helper() towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr) - if err != nil { - h.t.Fatalf("Unable to resolve tower TCP addr: %v", err) - } + require.NoError(h.t, err) towerAddr := &lnwire.NetAddress{ IdentityKey: h.serverCfg.NodeKeyECDH.PubKey(), Address: towerTCPAddr, } h.client, err = wtclient.New(h.clientCfg) - if err != nil { - h.t.Fatalf("unable to create wtclient: %v", err) - } - if err := h.client.Start(); err != nil { - h.t.Fatalf("unable to start wtclient: %v", err) - } - if err := h.client.AddTower(towerAddr); err != nil { - h.t.Fatalf("unable to add tower to wtclient: %v", err) - } + require.NoError(h.t, err) + require.NoError(h.t, h.client.Start()) + require.NoError(h.t, h.client.AddTower(towerAddr)) } // chanIDFromInt creates a unique channel id given a unique integral id. @@ -556,9 +537,7 @@ func (h *testHarness) makeChannel(id uint64, } c.mu.Unlock() - if ok { - h.t.Fatalf("channel %d already created", id) - } + require.Falsef(h.t, ok, "channel %d already created", id) } // channel retrieves the channel corresponding to id. @@ -570,9 +549,7 @@ func (h *testHarness) channel(id uint64) *mockChannel { h.mu.Lock() c, ok := h.channels[chanIDFromInt(id)] h.mu.Unlock() - if !ok { - h.t.Fatalf("unable to fetch channel %d", id) - } + require.Truef(h.t, ok, "unable to fetch channel %d", id) return c } @@ -583,9 +560,7 @@ func (h *testHarness) registerChannel(id uint64) { chanID := chanIDFromInt(id) err := h.client.RegisterChannel(chanID) - if err != nil { - h.t.Fatalf("unable to register channel %d: %v", id, err) - } + require.NoError(h.t, err) } // advanceChannelN calls advanceState on the channel identified by id the number @@ -624,11 +599,10 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { _, retribution := h.channel(id).getState(i) chanID := chanIDFromInt(id) - err := h.client.BackupState(&chanID, retribution, channeldb.SingleFunderBit) - if err != expErr { - h.t.Fatalf("back error mismatch, want: %v, got: %v", - expErr, err) - } + err := h.client.BackupState( + &chanID, retribution, channeldb.SingleFunderBit, + ) + require.ErrorIs(h.t, expErr, err) } // sendPayments instructs the channel identified by id to send amt to the remote @@ -688,10 +662,8 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint, hintSet[hint] = struct{}{} } - if len(hints) != len(hintSet) { - h.t.Fatalf("breach hints are not unique, list-len: %d "+ - "set-len: %d", len(hints), len(hintSet)) - } + require.Lenf(h.t, hints, len(hintSet), "breach hints are not unique, "+ + "list-len: %d set-len: %d", len(hints), len(hintSet)) // Closure to assert the server's matches are consistent with the hint // set. @@ -701,12 +673,9 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint, } for _, match := range matches { - if _, ok := hintSet[match.Hint]; ok { - continue - } - - h.t.Fatalf("match %v in db is not in hint set", - match.Hint) + _, ok := hintSet[match.Hint] + require.Truef(h.t, ok, "match %v in db is not in "+ + "hint set", match.Hint) } return true @@ -717,31 +686,24 @@ func (h *testHarness) waitServerUpdates(hints []blob.BreachHint, select { case <-time.After(time.Second): matches, err := h.serverDB.QueryMatches(hints) - switch { - case err != nil: - h.t.Fatalf("unable to query for hints: %v", err) + require.NoError(h.t, err, "unable to query for hints") - case wantUpdates && serverHasHints(matches): + if wantUpdates && serverHasHints(matches) { return + } - case wantUpdates: + if wantUpdates { h.t.Logf("Received %d/%d\n", len(matches), len(hints)) } case <-failTimeout: matches, err := h.serverDB.QueryMatches(hints) - switch { - case err != nil: - h.t.Fatalf("unable to query for hints: %v", err) - - case serverHasHints(matches): - return - - default: - h.t.Fatalf("breach hints not received, only "+ - "got %d/%d", len(matches), len(hints)) - } + require.NoError(h.t, err, "unable to query for hints") + require.Truef(h.t, serverHasHints(matches), "breach "+ + "hints not received, only got %d/%d", + len(matches), len(hints)) + return } } } @@ -754,25 +716,18 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint, // Query for matches on the provided hints. matches, err := h.serverDB.QueryMatches(hints) - if err != nil { - h.t.Fatalf("unable to query for matches: %v", err) - } + require.NoError(h.t, err) // Assert that the number of matches is exactly the number of provided // hints. - if len(matches) != len(hints) { - h.t.Fatalf("expected: %d matches, got: %d", len(hints), - len(matches)) - } + require.Lenf(h.t, matches, len(hints), "expected: %d matches, got: %d", + len(hints), len(matches)) // Assert that all of the matches correspond to a session with the // expected policy. for _, match := range matches { matchPolicy := match.SessionInfo.Policy - if expPolicy != matchPolicy { - h.t.Fatalf("expected session to have policy: %v, "+ - "got: %v", expPolicy, matchPolicy) - } + require.Equal(h.t, expPolicy, matchPolicy) } } @@ -780,9 +735,8 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint, func (h *testHarness) addTower(addr *lnwire.NetAddress) { h.t.Helper() - if err := h.client.AddTower(addr); err != nil { - h.t.Fatalf("unable to add tower: %v", err) - } + err := h.client.AddTower(addr) + require.NoError(h.t, err) } // removeTower removes a tower from the client. If `addr` is specified, then the @@ -790,9 +744,8 @@ func (h *testHarness) addTower(addr *lnwire.NetAddress) { func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) { h.t.Helper() - if err := h.client.RemoveTower(pubKey, addr); err != nil { - h.t.Fatalf("unable to remove tower: %v", err) - } + err := h.client.RemoveTower(pubKey, addr) + require.NoError(h.t, err) } const (