diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 42189df41..b0e4af6a3 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -1,11 +1,9 @@ package wtdb_test import ( - "bytes" crand "crypto/rand" "io" "net" - "reflect" "testing" "github.com/btcsuite/btcd/btcec/v2" @@ -16,6 +14,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" + "github.com/stretchr/testify/require" ) // clientDBInit is a closure used to initialize a wtclient.DB instance. @@ -43,10 +42,7 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, h.t.Helper() err := h.db.CreateClientSession(session) - if err != expErr { - h.t.Fatalf("expected create client session error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } func (h *clientDBHarness) listSessions( @@ -55,9 +51,7 @@ func (h *clientDBHarness) listSessions( h.t.Helper() sessions, err := h.db.ListClientSessions(id) - if err != nil { - h.t.Fatalf("unable to list client sessions: %v", err) - } + require.NoError(h.t, err, "unable to list client sessions") return sessions } @@ -68,13 +62,8 @@ func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, h.t.Helper() index, err := h.db.NextSessionKeyIndex(id, blobType) - if err != nil { - h.t.Fatalf("unable to create next session key index: %v", err) - } - - if index == 0 { - h.t.Fatalf("next key index should never be 0") - } + require.NoError(h.t, err, "unable to create next session key index") + require.NotZero(h.t, index, "next key index should never be 0") return index } @@ -85,21 +74,11 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress, h.t.Helper() tower, err := h.db.CreateTower(lnAddr) - if err != expErr { - h.t.Fatalf("expected create tower error: %v, got: %v", expErr, - err) - } - - if tower.ID == 0 { - h.t.Fatalf("tower id should never be 0") - } + require.ErrorIs(h.t, err, expErr) + require.NotZero(h.t, tower.ID, "tower id should never be 0") for _, session := range h.listSessions(&tower.ID) { - if session.Status != wtdb.CSessionActive { - h.t.Fatalf("expected status for session %v to be %v, "+ - "got %v", session.ID, wtdb.CSessionActive, - session.Status) - } + require.Equal(h.t, wtdb.CSessionActive, session.Status) } return tower @@ -110,10 +89,9 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr, h.t.Helper() - if err := h.db.RemoveTower(pubKey, addr); err != expErr { - h.t.Fatalf("expected remove tower error: %v, got %v", expErr, - err) - } + err := h.db.RemoveTower(pubKey, addr) + require.ErrorIs(h.t, err, expErr) + if expErr != nil { return } @@ -122,37 +100,31 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr, if addr != nil { tower, err := h.db.LoadTower(pubKey) - if err != nil { - h.t.Fatalf("expected tower %x to still exist", - pubKeyStr) - } + require.NoErrorf(h.t, err, "expected tower %x to still exist", + pubKeyStr) removedAddr := addr.String() for _, towerAddr := range tower.Addresses { - if towerAddr.String() == removedAddr { - h.t.Fatalf("address %v not removed for tower "+ - "%x", removedAddr, pubKeyStr) - } + require.NotEqualf(h.t, removedAddr, towerAddr, + "address %v not removed for tower %x", + removedAddr, pubKeyStr) } } else { tower, err := h.db.LoadTower(pubKey) - if hasSessions && err != nil { - h.t.Fatalf("expected tower %x with sessions to still "+ - "exist", pubKeyStr) - } - if !hasSessions && err == nil { - h.t.Fatalf("expected tower %x with no sessions to not "+ - "exist", pubKeyStr) - } - if !hasSessions { + if hasSessions { + require.NoError(h.t, err, "expected tower %x with "+ + "sessions to still exist", pubKeyStr) + } else { + require.Errorf(h.t, err, "expected tower %x with no "+ + "sessions to not exist", pubKeyStr) return } + for _, session := range h.listSessions(&tower.ID) { - if session.Status != wtdb.CSessionInactive { - h.t.Fatalf("expected status for session %v to "+ - "be %v, got %v", session.ID, - wtdb.CSessionInactive, session.Status) - } + require.Equal(h.t, wtdb.CSessionInactive, + session.Status, "expected status for session "+ + "%v to be %v, got %v", session.ID, + wtdb.CSessionInactive, session.Status) } } } @@ -163,10 +135,7 @@ func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey, h.t.Helper() tower, err := h.db.LoadTower(pubKey) - if err != expErr { - h.t.Fatalf("expected load tower error: %v, got: %v", expErr, - err) - } + require.ErrorIs(h.t, err, expErr) return tower } @@ -177,10 +146,7 @@ func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, h.t.Helper() tower, err := h.db.LoadTowerByID(id) - if err != expErr { - h.t.Fatalf("expected load tower error: %v, got: %v", expErr, - err) - } + require.ErrorIs(h.t, err, expErr) return tower } @@ -189,9 +155,7 @@ func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientC h.t.Helper() summaries, err := h.db.FetchChanSummaries() - if err != nil { - h.t.Fatalf("unable to fetch chan summaries: %v", err) - } + require.NoError(h.t, err) return summaries } @@ -202,10 +166,7 @@ func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID, h.t.Helper() err := h.db.RegisterChannel(chanID, sweepPkScript) - if err != expErr { - h.t.Fatalf("expected register channel error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID, @@ -214,10 +175,7 @@ func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID, h.t.Helper() lastApplied, err := h.db.CommitUpdate(id, update) - if err != expErr { - h.t.Fatalf("expected commit update error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) return lastApplied } @@ -228,10 +186,7 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, h.t.Helper() err := h.db.AckUpdate(id, seqNum, lastApplied) - if err != expErr { - h.t.Fatalf("expected commit update error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } // testCreateClientSession asserts various conditions regarding the creation of @@ -259,9 +214,9 @@ func testCreateClientSession(h *clientDBHarness) { // First, assert that this session is not already present in the // database. - if _, ok := h.listSessions(nil)[session.ID]; ok { - h.t.Fatalf("session for id %x should not exist yet", session.ID) - } + _, ok := h.listSessions(nil)[session.ID] + require.Falsef(h.t, ok, "session for id %x should not exist yet", + session.ID) // Attempting to insert the client session without reserving a session // key index should fail. @@ -278,10 +233,8 @@ func testCreateClientSession(h *clientDBHarness) { // successfully created, it should return the same index to maintain // idempotency across restarts. keyIndex2 := h.nextKeyIndex(session.TowerID, blobType) - if keyIndex != keyIndex2 { - h.t.Fatalf("next key index should be idempotent: want: %v, "+ - "got %v", keyIndex, keyIndex2) - } + require.Equalf(h.t, keyIndex, keyIndex2, "next key index should "+ + "be idempotent: want: %v, got %v", keyIndex, keyIndex2) // Now, set the client session's key index so that it is proper and // insert it. This should succeed. @@ -289,9 +242,8 @@ func testCreateClientSession(h *clientDBHarness) { h.insertSession(session, nil) // Verify that the session now exists in the database. - if _, ok := h.listSessions(nil)[session.ID]; !ok { - h.t.Fatalf("session for id %x should exist now", session.ID) - } + _, ok = h.listSessions(nil)[session.ID] + require.Truef(h.t, ok, "session for id %x should exist now", session.ID) // Attempt to insert the session again, which should fail due to the // session already existing. @@ -300,9 +252,8 @@ func testCreateClientSession(h *clientDBHarness) { // Finally, assert that reserving another key index succeeds with a // different key index, now that the first one has been finalized. keyIndex3 := h.nextKeyIndex(session.TowerID, blobType) - if keyIndex == keyIndex3 { - h.t.Fatalf("key index still reserved after creating session") - } + require.NotEqualf(h.t, keyIndex, keyIndex3, "key index still "+ + "reserved after creating session") } // testFilterClientSessions asserts that we can correctly filter client sessions @@ -343,15 +294,12 @@ func testFilterClientSessions(h *clientDBHarness) { // them. for towerID, expectedSessions := range towerSessions { sessions := h.listSessions(&towerID) - if len(sessions) != len(expectedSessions) { - h.t.Fatalf("expected %v sessions for tower %v, got %v", - len(expectedSessions), towerID, len(sessions)) - } + require.Len(h.t, sessions, len(expectedSessions)) + for _, expectedSession := range expectedSessions { - if _, ok := sessions[expectedSession]; !ok { - h.t.Fatalf("expected session %v for tower %v", - expectedSession, towerID) - } + _, ok := sessions[expectedSession] + require.Truef(h.t, ok, "expected session %v for "+ + "tower %v", expectedSession, towerID) } } } @@ -380,26 +328,18 @@ func testCreateTower(h *clientDBHarness) { // Load the tower from the database and assert that it matches the tower // we created. tower2 := h.loadTowerByID(tower.ID, nil) - if !reflect.DeepEqual(tower, tower2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - tower, tower2) - } - tower2 = h.loadTower(pk, err) - if !reflect.DeepEqual(tower, tower2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - tower, tower2) - } + require.Equal(h.t, tower, tower2) + + tower2 = h.loadTower(tower.IdentityKey, nil) + require.Equal(h.t, tower, tower2) // Insert the address again into the database. Since the address is the // same, this should result in an unmodified tower record. towerDupAddr := h.createTower(lnAddr, nil) - if len(towerDupAddr.Addresses) != 1 { - h.t.Fatalf("duplicate address should be deduped") - } - if !reflect.DeepEqual(tower, towerDupAddr) { - h.t.Fatalf("mismatch towers, want: %v, got: %v", - tower, towerDupAddr) - } + require.Lenf(h.t, towerDupAddr.Addresses, 1, "duplicate address "+ + "should be deduped") + + require.Equal(h.t, tower, towerDupAddr) // Generate a new address for this tower. addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911} @@ -416,26 +356,18 @@ func testCreateTower(h *clientDBHarness) { // Load the tower from the database, and assert that it matches the // tower returned from creation. towerNewAddr2 := h.loadTowerByID(tower.ID, nil) - if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - towerNewAddr, towerNewAddr2) - } - towerNewAddr2 = h.loadTower(pk, nil) - if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - towerNewAddr, towerNewAddr2) - } + require.Equal(h.t, towerNewAddr, towerNewAddr2) + + towerNewAddr2 = h.loadTower(tower.IdentityKey, nil) + require.Equal(h.t, towerNewAddr, towerNewAddr2) // Assert that there are now two addresses on the tower object. - if len(towerNewAddr.Addresses) != 2 { - h.t.Fatalf("new address should be added") - } + require.Lenf(h.t, towerNewAddr.Addresses, 2, "new address should be "+ + "added") // Finally, assert that the new address was prepended since it is deemed // fresher. - if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) { - h.t.Fatalf("new address should be prepended") - } + require.Equal(h.t, tower.Addresses, towerNewAddr.Addresses[1:]) } // testRemoveTower asserts the behavior of removing Tower objects as a whole and @@ -443,9 +375,7 @@ func testCreateTower(h *clientDBHarness) { func testRemoveTower(h *clientDBHarness) { // Generate a random public key we'll use for our tower. pk, err := randPubKey() - if err != nil { - h.t.Fatalf("unable to generate pubkey: %v", err) - } + require.NoError(h.t, err) // Removing a tower that does not exist within the database should // result in a NOP. @@ -523,28 +453,23 @@ func testRemoveTower(h *clientDBHarness) { func testChanSummaries(h *clientDBHarness) { // First, assert that this channel is not already registered. var chanID lnwire.ChannelID - if _, ok := h.fetchChanSummaries()[chanID]; ok { - h.t.Fatalf("pkscript for channel %x should not exist yet", - chanID) - } + _, ok := h.fetchChanSummaries()[chanID] + require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet", + chanID) // Generate a random sweep pkscript and register it for this channel. expPkScript := make([]byte, 22) - if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil { - h.t.Fatalf("unable to generate pkscript: %v", err) - } + _, err := io.ReadFull(crand.Reader, expPkScript) + require.NoError(h.t, err) + h.registerChan(chanID, expPkScript, nil) // Assert that the channel exists and that its sweep pkscript matches // the one we registered. summary, ok := h.fetchChanSummaries()[chanID] - if !ok { - h.t.Fatalf("pkscript for channel %x should not exist yet", - chanID) - } else if bytes.Compare(expPkScript, summary.SweepPkScript) != 0 { - h.t.Fatalf("pkscript mismatch, want: %x, got: %x", - expPkScript, summary.SweepPkScript) - } + require.Truef(h.t, ok, "pkscript for channel %x should not exist yet", + chanID) + require.Equal(h.t, expPkScript, summary.SweepPkScript) // Finally, assert that re-registering the same channel produces a // failure. @@ -581,10 +506,7 @@ func testCommitUpdate(h *clientDBHarness) { // succeed. The lastApplied value should be 0 since we have not received // an ack from the tower. lastApplied := h.commitUpdate(&session.ID, update1, nil) - if lastApplied != 0 { - h.t.Fatalf("last applied mismatch, want: 0, got: %v", - lastApplied) - } + require.Zero(h.t, lastApplied) // Assert that the committed update appears in the client session's // CommittedUpdates map when loaded from disk and that there are no @@ -600,10 +522,7 @@ func testCommitUpdate(h *clientDBHarness) { // the on-disk update's hint). The lastApplied value should remain // unchanged. lastApplied2 := h.commitUpdate(&session.ID, update1, nil) - if lastApplied2 != lastApplied { - h.t.Fatalf("last applied should not have changed, got %v", - lastApplied2) - } + require.Equal(h.t, lastApplied, lastApplied2) // Assert that the loaded ClientSession is the same as before. dbSession = h.listSessions(nil)[session.ID] @@ -621,10 +540,7 @@ func testCommitUpdate(h *clientDBHarness) { // which should succeed. update2.SeqNum = 2 lastApplied3 := h.commitUpdate(&session.ID, update2, nil) - if lastApplied3 != lastApplied { - h.t.Fatalf("last applied should not have changed, got %v", - lastApplied3) - } + require.Equal(h.t, lastApplied, lastApplied3) // Check that both updates now appear as committed on the ClientSession // loaded from disk. @@ -684,10 +600,7 @@ func testAckUpdate(h *clientDBHarness) { // Commit to a random update at seqnum 1. update1 := randCommittedUpdate(h.t, 1) lastApplied := h.commitUpdate(&session.ID, update1, nil) - if lastApplied != 0 { - h.t.Fatalf("last applied mismatch, want: 0, got: %v", - lastApplied) - } + require.Zero(h.t, lastApplied) // Acking seqnum 1 should succeed. h.ackUpdate(&session.ID, 1, 1, nil) @@ -715,10 +628,7 @@ func testAckUpdate(h *clientDBHarness) { // ack. update2 := randCommittedUpdate(h.t, 2) lastApplied = h.commitUpdate(&session.ID, update2, nil) - if lastApplied != 1 { - h.t.Fatalf("last applied mismatch, want: 1, got: %v", - lastApplied) - } + require.EqualValues(h.t, 1, lastApplied) // Ack seqnum 2. h.ackUpdate(&session.ID, 2, 2, nil) @@ -756,10 +666,7 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, expUpdates = make([]wtdb.CommittedUpdate, 0) } - if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) { - t.Fatalf("committed updates mismatch, want: %v, got: %v", - expUpdates, session.CommittedUpdates) - } + require.Equal(t, expUpdates, session.CommittedUpdates) } // checkAckedUpdates asserts that the AckedUpdates on a session match the @@ -774,10 +681,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, expUpdates = make(map[uint16]wtdb.BackupID) } - if !reflect.DeepEqual(session.AckedUpdates, expUpdates) { - t.Fatalf("acked updates mismatch, want: %v, got: %v", - expUpdates, session.AckedUpdates) - } + require.Equal(t, expUpdates, session.AckedUpdates) } // TestClientDB asserts the behavior of a fresh client db, a reopened client db, @@ -795,14 +699,10 @@ func TestClientDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, t.TempDir(), "wtclient.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenClientDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() @@ -819,27 +719,19 @@ func TestClientDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, path, "wtclient.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenClientDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db.Close() bdb, err = wtdb.NewBoltBackendCreator( true, path, "wtclient.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err = wtdb.OpenClientDB(bdb) - if err != nil { - t.Fatalf("unable to reopen db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() @@ -909,19 +801,16 @@ func TestClientDB(t *testing.T) { // randCommittedUpdate generates a random committed update. func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { var chanID lnwire.ChannelID - if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - } + _, err := io.ReadFull(crand.Reader, chanID[:]) + require.NoError(t, err) var hint blob.BreachHint - if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil { - t.Fatalf("unable to generate breach hint: %v", err) - } + _, err = io.ReadFull(crand.Reader, hint[:]) + require.NoError(t, err) encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type())) - if _, err := io.ReadFull(crand.Reader, encBlob); err != nil { - t.Fatalf("unable to generate encrypted blob: %v", err) - } + _, err = io.ReadFull(crand.Reader, encBlob) + require.NoError(t, err) return &wtdb.CommittedUpdate{ SeqNum: seqNum, diff --git a/watchtower/wtdb/codec_test.go b/watchtower/wtdb/codec_test.go index 7842b13bc..c2628b86a 100644 --- a/watchtower/wtdb/codec_test.go +++ b/watchtower/wtdb/codec_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/stretchr/testify/require" ) func randPubKey() (*btcec.PublicKey, error) { @@ -134,10 +135,7 @@ func TestCodec(tt *testing.T) { // Ensure encoding the object succeeds. var b bytes.Buffer err := obj.Encode(&b) - if err != nil { - t.Fatalf("unable to encode: %v", err) - return false - } + require.NoError(t, err) var obj2 dbObject switch obj.(type) { @@ -162,17 +160,10 @@ func TestCodec(tt *testing.T) { // Ensure decoding the object succeeds. err = obj2.Decode(bytes.NewReader(b.Bytes())) - if err != nil { - t.Fatalf("unable to decode: %v", err) - return false - } + require.NoError(t, err) // Assert the original and decoded object match. - if !reflect.DeepEqual(obj, obj2) { - t.Fatalf("encode/decode mismatch, want: %v, "+ - "got: %v", obj, obj2) - return false - } + require.Equal(t, obj, obj2) return true } @@ -180,16 +171,10 @@ func TestCodec(tt *testing.T) { customTypeGen := map[string]func([]reflect.Value, *rand.Rand){ "Tower": func(v []reflect.Value, r *rand.Rand) { pk, err := randPubKey() - if err != nil { - t.Fatalf("unable to generate pubkey: %v", err) - return - } + require.NoError(t, err) addrs, err := randAddrs(r) - if err != nil { - t.Fatalf("unable to generate addrs: %v", err) - return - } + require.NoError(t, err) obj := wtdb.Tower{ IdentityKey: pk, @@ -260,10 +245,7 @@ func TestCodec(tt *testing.T) { } err := quick.Check(test.scenario, config) - if err != nil { - t.Fatalf("fuzz checks for msg=%s failed: %v", - test.name, err) - } + require.NoError(h, err) }) } } diff --git a/watchtower/wtdb/tower_db_test.go b/watchtower/wtdb/tower_db_test.go index 177dbd233..9459f34d3 100644 --- a/watchtower/wtdb/tower_db_test.go +++ b/watchtower/wtdb/tower_db_test.go @@ -3,7 +3,6 @@ package wtdb_test import ( "bytes" "encoding/binary" - "reflect" "testing" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -14,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" + "github.com/stretchr/testify/require" ) var ( @@ -48,10 +48,7 @@ func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) { h.t.Helper() err := h.db.InsertSessionInfo(s) - if err != expErr { - h.t.Fatalf("expected insert session error: %v, got : %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } // getSession retrieves the session identified by id, asserting that the call @@ -62,10 +59,7 @@ func (h *towerDBHarness) getSession(id *wtdb.SessionID, h.t.Helper() session, err := h.db.GetSessionInfo(id) - if err != expErr { - h.t.Fatalf("expected get session error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) return session } @@ -79,10 +73,7 @@ func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate, h.t.Helper() lastApplied, err := h.db.InsertStateUpdate(s) - if err != expErr { - h.t.Fatalf("expected insert update error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) return lastApplied } @@ -93,10 +84,7 @@ func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) { h.t.Helper() err := h.db.DeleteSession(id) - if err != expErr { - h.t.Fatalf("expected deletion error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } // queryMatches queries that database for the passed breach hint, returning all @@ -105,9 +93,7 @@ func (h *towerDBHarness) queryMatches(hint blob.BreachHint) []wtdb.Match { h.t.Helper() matches, err := h.db.QueryMatches([]blob.BreachHint{hint}) - if err != nil { - h.t.Fatalf("unable to query matches: %v", err) - } + require.NoError(h.t, err) return matches } @@ -119,14 +105,10 @@ func (h *towerDBHarness) hasUpdate(hint blob.BreachHint) wtdb.Match { h.t.Helper() matches := h.queryMatches(hint) - if len(matches) != 1 { - h.t.Fatalf("expected 1 match, found: %d", len(matches)) - } + require.Len(h.t, matches, 1) match := matches[0] - if match.Hint != hint { - h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint) - } + require.Equal(h.t, hint, match.Hint) return match } @@ -158,11 +140,7 @@ func testInsertSession(h *towerDBHarness) { h.insertSession(session, nil) session2 := h.getSession(&id, nil) - - if !reflect.DeepEqual(session, session2) { - h.t.Fatalf("expected session: %v, got %v", - session, session2) - } + require.Equal(h.t, session, session2) h.insertSession(session, nil) @@ -211,28 +189,21 @@ func testMultipleMatches(h *towerDBHarness) { // Query the db for matches on the chosen hint. matches := h.queryMatches(hint) - if len(matches) != numUpdates { - h.t.Fatalf("num updates mismatch, want: %d, got: %d", - numUpdates, len(matches)) - } + require.Len(h.t, matches, numUpdates) // Assert that the hints are what we asked for, and compute the set of // sessions returned. sessions := make(map[wtdb.SessionID]struct{}) for _, match := range matches { - if match.Hint != hint { - h.t.Fatalf("hint mismatch, want: %v, got: %v", - hint, match.Hint) - } + require.Equal(h.t, hint, match.Hint) sessions[match.ID] = struct{}{} } // Assert that the sessions returned match the session ids of the // sessions we initially created. for i := 0; i < numUpdates; i++ { - if _, ok := sessions[*id(i)]; !ok { - h.t.Fatalf("match for session %v not found", *id(i)) - } + _, ok := sessions[*id(i)] + require.Truef(h.t, ok, "match for session %v not found", *id(i)) } } @@ -242,33 +213,22 @@ func testMultipleMatches(h *towerDBHarness) { func testLookoutTip(h *towerDBHarness) { // Retrieve lookout tip on fresh db. epoch, err := h.db.GetLookoutTip() - if err != nil { - h.t.Fatalf("unable to fetch lookout tip: %v", err) - } + require.NoError(h.t, err) // Assert that the epoch is nil. - if epoch != nil { - h.t.Fatalf("lookout tip should not be set, found: %v", epoch) - } + require.Nil(h.t, epoch) // Create a closure that inserts an epoch, retrieves it, and asserts // that the returned epoch matches what was inserted. setAndCheck := func(i int) { expEpoch := epochFromInt(1) err = h.db.SetLookoutTip(expEpoch) - if err != nil { - h.t.Fatalf("unable to set lookout tip: %v", err) - } + require.NoError(h.t, err) epoch, err = h.db.GetLookoutTip() - if err != nil { - h.t.Fatalf("unable to fetch lookout tip: %v", err) - } + require.NoError(h.t, err) - if !reflect.DeepEqual(epoch, expEpoch) { - h.t.Fatalf("lookout tip mismatch, want: %v, got: %v", - expEpoch, epoch) - } + require.Equal(h.t, expEpoch, epoch) } // Set and assert the lookout tip. @@ -348,15 +308,10 @@ func testDeleteSession(h *towerDBHarness) { // Assert that only one update is still present. matches := h.queryMatches(hint) - if len(matches) != 1 { - h.t.Fatalf("expected one update, found: %d", len(matches)) - } + require.Len(h.t, matches, 1) // Assert that the update belongs to the first session. - if matches[0].ID != *id0 { - h.t.Fatalf("expected match for %v, instead is for: %v", - *id0, matches[0].ID) - } + require.Equal(h.t, *id0, matches[0].ID) // Finally, remove the first session added. h.deleteSession(*id0, nil) @@ -366,9 +321,7 @@ func testDeleteSession(h *towerDBHarness) { // No matches should exist for this hint. matches = h.queryMatches(hint) - if len(matches) != 0 { - h.t.Fatalf("expected zero updates, found: %d", len(matches)) - } + require.Zero(h.t, len(matches)) } type stateUpdateTest struct { @@ -403,10 +356,9 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) { *expSession = *test.session } - if len(test.updates) != len(test.updateErrs) { - h.t.Fatalf("malformed test case, num updates " + - "should match num errors") - } + require.Lenf(h.t, test.updates, len(test.updateErrs), + "malformed test case, num updates should match num "+ + "errors") // Send any updates provided in the test. for i, update := range test.updates { @@ -430,10 +382,7 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) { expSession.ClientLastApplied = update.LastApplied match := h.hasUpdate(update.Hint) - if !reflect.DeepEqual(match.SessionInfo, expSession) { - h.t.Fatalf("expected session: %v, got: %v", - expSession, match.SessionInfo) - } + require.Equal(h.t, expSession, match.SessionInfo) } } } @@ -640,14 +589,10 @@ func TestTowerDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, path, "watchtower.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenTowerDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() @@ -664,14 +609,10 @@ func TestTowerDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, path, "watchtower.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenTowerDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db.Close() // Open the db again, ensuring we test a @@ -680,14 +621,10 @@ func TestTowerDB(t *testing.T) { bdb, err = wtdb.NewBoltBackendCreator( true, path, "watchtower.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err = wtdb.OpenTowerDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close()