watchtower/wtdb: update tests to use require package

In this commit, all the tests in the wtdb package are updated in order
to make use of the `require` package where appropriate.
This commit is contained in:
Elle Mouton
2022-10-10 12:47:08 +02:00
parent f815c88ee4
commit 5dabf7cb3e
3 changed files with 129 additions and 321 deletions

View File

@@ -1,11 +1,9 @@
package wtdb_test
import (
"bytes"
crand "crypto/rand"
"io"
"net"
"reflect"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
@@ -16,6 +14,7 @@ import (
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/stretchr/testify/require"
)
// clientDBInit is a closure used to initialize a wtclient.DB instance.
@@ -43,10 +42,7 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
h.t.Helper()
err := h.db.CreateClientSession(session)
if err != expErr {
h.t.Fatalf("expected create client session error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
}
func (h *clientDBHarness) listSessions(
@@ -55,9 +51,7 @@ func (h *clientDBHarness) listSessions(
h.t.Helper()
sessions, err := h.db.ListClientSessions(id)
if err != nil {
h.t.Fatalf("unable to list client sessions: %v", err)
}
require.NoError(h.t, err, "unable to list client sessions")
return sessions
}
@@ -68,13 +62,8 @@ func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID,
h.t.Helper()
index, err := h.db.NextSessionKeyIndex(id, blobType)
if err != nil {
h.t.Fatalf("unable to create next session key index: %v", err)
}
if index == 0 {
h.t.Fatalf("next key index should never be 0")
}
require.NoError(h.t, err, "unable to create next session key index")
require.NotZero(h.t, index, "next key index should never be 0")
return index
}
@@ -85,21 +74,11 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
h.t.Helper()
tower, err := h.db.CreateTower(lnAddr)
if err != expErr {
h.t.Fatalf("expected create tower error: %v, got: %v", expErr,
err)
}
if tower.ID == 0 {
h.t.Fatalf("tower id should never be 0")
}
require.ErrorIs(h.t, err, expErr)
require.NotZero(h.t, tower.ID, "tower id should never be 0")
for _, session := range h.listSessions(&tower.ID) {
if session.Status != wtdb.CSessionActive {
h.t.Fatalf("expected status for session %v to be %v, "+
"got %v", session.ID, wtdb.CSessionActive,
session.Status)
}
require.Equal(h.t, wtdb.CSessionActive, session.Status)
}
return tower
@@ -110,10 +89,9 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
h.t.Helper()
if err := h.db.RemoveTower(pubKey, addr); err != expErr {
h.t.Fatalf("expected remove tower error: %v, got %v", expErr,
err)
}
err := h.db.RemoveTower(pubKey, addr)
require.ErrorIs(h.t, err, expErr)
if expErr != nil {
return
}
@@ -122,37 +100,31 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
if addr != nil {
tower, err := h.db.LoadTower(pubKey)
if err != nil {
h.t.Fatalf("expected tower %x to still exist",
pubKeyStr)
}
require.NoErrorf(h.t, err, "expected tower %x to still exist",
pubKeyStr)
removedAddr := addr.String()
for _, towerAddr := range tower.Addresses {
if towerAddr.String() == removedAddr {
h.t.Fatalf("address %v not removed for tower "+
"%x", removedAddr, pubKeyStr)
}
require.NotEqualf(h.t, removedAddr, towerAddr,
"address %v not removed for tower %x",
removedAddr, pubKeyStr)
}
} else {
tower, err := h.db.LoadTower(pubKey)
if hasSessions && err != nil {
h.t.Fatalf("expected tower %x with sessions to still "+
"exist", pubKeyStr)
}
if !hasSessions && err == nil {
h.t.Fatalf("expected tower %x with no sessions to not "+
"exist", pubKeyStr)
}
if !hasSessions {
if hasSessions {
require.NoError(h.t, err, "expected tower %x with "+
"sessions to still exist", pubKeyStr)
} else {
require.Errorf(h.t, err, "expected tower %x with no "+
"sessions to not exist", pubKeyStr)
return
}
for _, session := range h.listSessions(&tower.ID) {
if session.Status != wtdb.CSessionInactive {
h.t.Fatalf("expected status for session %v to "+
"be %v, got %v", session.ID,
wtdb.CSessionInactive, session.Status)
}
require.Equal(h.t, wtdb.CSessionInactive,
session.Status, "expected status for session "+
"%v to be %v, got %v", session.ID,
wtdb.CSessionInactive, session.Status)
}
}
}
@@ -163,10 +135,7 @@ func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey,
h.t.Helper()
tower, err := h.db.LoadTower(pubKey)
if err != expErr {
h.t.Fatalf("expected load tower error: %v, got: %v", expErr,
err)
}
require.ErrorIs(h.t, err, expErr)
return tower
}
@@ -177,10 +146,7 @@ func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID,
h.t.Helper()
tower, err := h.db.LoadTowerByID(id)
if err != expErr {
h.t.Fatalf("expected load tower error: %v, got: %v", expErr,
err)
}
require.ErrorIs(h.t, err, expErr)
return tower
}
@@ -189,9 +155,7 @@ func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientC
h.t.Helper()
summaries, err := h.db.FetchChanSummaries()
if err != nil {
h.t.Fatalf("unable to fetch chan summaries: %v", err)
}
require.NoError(h.t, err)
return summaries
}
@@ -202,10 +166,7 @@ func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID,
h.t.Helper()
err := h.db.RegisterChannel(chanID, sweepPkScript)
if err != expErr {
h.t.Fatalf("expected register channel error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
}
func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
@@ -214,10 +175,7 @@ func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
h.t.Helper()
lastApplied, err := h.db.CommitUpdate(id, update)
if err != expErr {
h.t.Fatalf("expected commit update error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
return lastApplied
}
@@ -228,10 +186,7 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
h.t.Helper()
err := h.db.AckUpdate(id, seqNum, lastApplied)
if err != expErr {
h.t.Fatalf("expected commit update error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
}
// testCreateClientSession asserts various conditions regarding the creation of
@@ -259,9 +214,9 @@ func testCreateClientSession(h *clientDBHarness) {
// First, assert that this session is not already present in the
// database.
if _, ok := h.listSessions(nil)[session.ID]; ok {
h.t.Fatalf("session for id %x should not exist yet", session.ID)
}
_, ok := h.listSessions(nil)[session.ID]
require.Falsef(h.t, ok, "session for id %x should not exist yet",
session.ID)
// Attempting to insert the client session without reserving a session
// key index should fail.
@@ -278,10 +233,8 @@ func testCreateClientSession(h *clientDBHarness) {
// successfully created, it should return the same index to maintain
// idempotency across restarts.
keyIndex2 := h.nextKeyIndex(session.TowerID, blobType)
if keyIndex != keyIndex2 {
h.t.Fatalf("next key index should be idempotent: want: %v, "+
"got %v", keyIndex, keyIndex2)
}
require.Equalf(h.t, keyIndex, keyIndex2, "next key index should "+
"be idempotent: want: %v, got %v", keyIndex, keyIndex2)
// Now, set the client session's key index so that it is proper and
// insert it. This should succeed.
@@ -289,9 +242,8 @@ func testCreateClientSession(h *clientDBHarness) {
h.insertSession(session, nil)
// Verify that the session now exists in the database.
if _, ok := h.listSessions(nil)[session.ID]; !ok {
h.t.Fatalf("session for id %x should exist now", session.ID)
}
_, ok = h.listSessions(nil)[session.ID]
require.Truef(h.t, ok, "session for id %x should exist now", session.ID)
// Attempt to insert the session again, which should fail due to the
// session already existing.
@@ -300,9 +252,8 @@ func testCreateClientSession(h *clientDBHarness) {
// Finally, assert that reserving another key index succeeds with a
// different key index, now that the first one has been finalized.
keyIndex3 := h.nextKeyIndex(session.TowerID, blobType)
if keyIndex == keyIndex3 {
h.t.Fatalf("key index still reserved after creating session")
}
require.NotEqualf(h.t, keyIndex, keyIndex3, "key index still "+
"reserved after creating session")
}
// testFilterClientSessions asserts that we can correctly filter client sessions
@@ -343,15 +294,12 @@ func testFilterClientSessions(h *clientDBHarness) {
// them.
for towerID, expectedSessions := range towerSessions {
sessions := h.listSessions(&towerID)
if len(sessions) != len(expectedSessions) {
h.t.Fatalf("expected %v sessions for tower %v, got %v",
len(expectedSessions), towerID, len(sessions))
}
require.Len(h.t, sessions, len(expectedSessions))
for _, expectedSession := range expectedSessions {
if _, ok := sessions[expectedSession]; !ok {
h.t.Fatalf("expected session %v for tower %v",
expectedSession, towerID)
}
_, ok := sessions[expectedSession]
require.Truef(h.t, ok, "expected session %v for "+
"tower %v", expectedSession, towerID)
}
}
}
@@ -380,26 +328,18 @@ func testCreateTower(h *clientDBHarness) {
// Load the tower from the database and assert that it matches the tower
// we created.
tower2 := h.loadTowerByID(tower.ID, nil)
if !reflect.DeepEqual(tower, tower2) {
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
tower, tower2)
}
tower2 = h.loadTower(pk, err)
if !reflect.DeepEqual(tower, tower2) {
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
tower, tower2)
}
require.Equal(h.t, tower, tower2)
tower2 = h.loadTower(tower.IdentityKey, nil)
require.Equal(h.t, tower, tower2)
// Insert the address again into the database. Since the address is the
// same, this should result in an unmodified tower record.
towerDupAddr := h.createTower(lnAddr, nil)
if len(towerDupAddr.Addresses) != 1 {
h.t.Fatalf("duplicate address should be deduped")
}
if !reflect.DeepEqual(tower, towerDupAddr) {
h.t.Fatalf("mismatch towers, want: %v, got: %v",
tower, towerDupAddr)
}
require.Lenf(h.t, towerDupAddr.Addresses, 1, "duplicate address "+
"should be deduped")
require.Equal(h.t, tower, towerDupAddr)
// Generate a new address for this tower.
addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
@@ -416,26 +356,18 @@ func testCreateTower(h *clientDBHarness) {
// Load the tower from the database, and assert that it matches the
// tower returned from creation.
towerNewAddr2 := h.loadTowerByID(tower.ID, nil)
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
towerNewAddr, towerNewAddr2)
}
towerNewAddr2 = h.loadTower(pk, nil)
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
towerNewAddr, towerNewAddr2)
}
require.Equal(h.t, towerNewAddr, towerNewAddr2)
towerNewAddr2 = h.loadTower(tower.IdentityKey, nil)
require.Equal(h.t, towerNewAddr, towerNewAddr2)
// Assert that there are now two addresses on the tower object.
if len(towerNewAddr.Addresses) != 2 {
h.t.Fatalf("new address should be added")
}
require.Lenf(h.t, towerNewAddr.Addresses, 2, "new address should be "+
"added")
// Finally, assert that the new address was prepended since it is deemed
// fresher.
if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) {
h.t.Fatalf("new address should be prepended")
}
require.Equal(h.t, tower.Addresses, towerNewAddr.Addresses[1:])
}
// testRemoveTower asserts the behavior of removing Tower objects as a whole and
@@ -443,9 +375,7 @@ func testCreateTower(h *clientDBHarness) {
func testRemoveTower(h *clientDBHarness) {
// Generate a random public key we'll use for our tower.
pk, err := randPubKey()
if err != nil {
h.t.Fatalf("unable to generate pubkey: %v", err)
}
require.NoError(h.t, err)
// Removing a tower that does not exist within the database should
// result in a NOP.
@@ -523,28 +453,23 @@ func testRemoveTower(h *clientDBHarness) {
func testChanSummaries(h *clientDBHarness) {
// First, assert that this channel is not already registered.
var chanID lnwire.ChannelID
if _, ok := h.fetchChanSummaries()[chanID]; ok {
h.t.Fatalf("pkscript for channel %x should not exist yet",
chanID)
}
_, ok := h.fetchChanSummaries()[chanID]
require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet",
chanID)
// Generate a random sweep pkscript and register it for this channel.
expPkScript := make([]byte, 22)
if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil {
h.t.Fatalf("unable to generate pkscript: %v", err)
}
_, err := io.ReadFull(crand.Reader, expPkScript)
require.NoError(h.t, err)
h.registerChan(chanID, expPkScript, nil)
// Assert that the channel exists and that its sweep pkscript matches
// the one we registered.
summary, ok := h.fetchChanSummaries()[chanID]
if !ok {
h.t.Fatalf("pkscript for channel %x should not exist yet",
chanID)
} else if bytes.Compare(expPkScript, summary.SweepPkScript) != 0 {
h.t.Fatalf("pkscript mismatch, want: %x, got: %x",
expPkScript, summary.SweepPkScript)
}
require.Truef(h.t, ok, "pkscript for channel %x should not exist yet",
chanID)
require.Equal(h.t, expPkScript, summary.SweepPkScript)
// Finally, assert that re-registering the same channel produces a
// failure.
@@ -581,10 +506,7 @@ func testCommitUpdate(h *clientDBHarness) {
// succeed. The lastApplied value should be 0 since we have not received
// an ack from the tower.
lastApplied := h.commitUpdate(&session.ID, update1, nil)
if lastApplied != 0 {
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
lastApplied)
}
require.Zero(h.t, lastApplied)
// Assert that the committed update appears in the client session's
// CommittedUpdates map when loaded from disk and that there are no
@@ -600,10 +522,7 @@ func testCommitUpdate(h *clientDBHarness) {
// the on-disk update's hint). The lastApplied value should remain
// unchanged.
lastApplied2 := h.commitUpdate(&session.ID, update1, nil)
if lastApplied2 != lastApplied {
h.t.Fatalf("last applied should not have changed, got %v",
lastApplied2)
}
require.Equal(h.t, lastApplied, lastApplied2)
// Assert that the loaded ClientSession is the same as before.
dbSession = h.listSessions(nil)[session.ID]
@@ -621,10 +540,7 @@ func testCommitUpdate(h *clientDBHarness) {
// which should succeed.
update2.SeqNum = 2
lastApplied3 := h.commitUpdate(&session.ID, update2, nil)
if lastApplied3 != lastApplied {
h.t.Fatalf("last applied should not have changed, got %v",
lastApplied3)
}
require.Equal(h.t, lastApplied, lastApplied3)
// Check that both updates now appear as committed on the ClientSession
// loaded from disk.
@@ -684,10 +600,7 @@ func testAckUpdate(h *clientDBHarness) {
// Commit to a random update at seqnum 1.
update1 := randCommittedUpdate(h.t, 1)
lastApplied := h.commitUpdate(&session.ID, update1, nil)
if lastApplied != 0 {
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
lastApplied)
}
require.Zero(h.t, lastApplied)
// Acking seqnum 1 should succeed.
h.ackUpdate(&session.ID, 1, 1, nil)
@@ -715,10 +628,7 @@ func testAckUpdate(h *clientDBHarness) {
// ack.
update2 := randCommittedUpdate(h.t, 2)
lastApplied = h.commitUpdate(&session.ID, update2, nil)
if lastApplied != 1 {
h.t.Fatalf("last applied mismatch, want: 1, got: %v",
lastApplied)
}
require.EqualValues(h.t, 1, lastApplied)
// Ack seqnum 2.
h.ackUpdate(&session.ID, 2, 2, nil)
@@ -756,10 +666,7 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
expUpdates = make([]wtdb.CommittedUpdate, 0)
}
if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) {
t.Fatalf("committed updates mismatch, want: %v, got: %v",
expUpdates, session.CommittedUpdates)
}
require.Equal(t, expUpdates, session.CommittedUpdates)
}
// checkAckedUpdates asserts that the AckedUpdates on a session match the
@@ -774,10 +681,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
expUpdates = make(map[uint16]wtdb.BackupID)
}
if !reflect.DeepEqual(session.AckedUpdates, expUpdates) {
t.Fatalf("acked updates mismatch, want: %v, got: %v",
expUpdates, session.AckedUpdates)
}
require.Equal(t, expUpdates, session.AckedUpdates)
}
// TestClientDB asserts the behavior of a fresh client db, a reopened client db,
@@ -795,14 +699,10 @@ func TestClientDB(t *testing.T) {
bdb, err := wtdb.NewBoltBackendCreator(
true, t.TempDir(), "wtclient.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err := wtdb.OpenClientDB(bdb)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
t.Cleanup(func() {
db.Close()
@@ -819,27 +719,19 @@ func TestClientDB(t *testing.T) {
bdb, err := wtdb.NewBoltBackendCreator(
true, path, "wtclient.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err := wtdb.OpenClientDB(bdb)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db.Close()
bdb, err = wtdb.NewBoltBackendCreator(
true, path, "wtclient.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err = wtdb.OpenClientDB(bdb)
if err != nil {
t.Fatalf("unable to reopen db: %v", err)
}
require.NoError(t, err)
t.Cleanup(func() {
db.Close()
@@ -909,19 +801,16 @@ func TestClientDB(t *testing.T) {
// randCommittedUpdate generates a random committed update.
func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
var chanID lnwire.ChannelID
if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil {
t.Fatalf("unable to generate chan id: %v", err)
}
_, err := io.ReadFull(crand.Reader, chanID[:])
require.NoError(t, err)
var hint blob.BreachHint
if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil {
t.Fatalf("unable to generate breach hint: %v", err)
}
_, err = io.ReadFull(crand.Reader, hint[:])
require.NoError(t, err)
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
if _, err := io.ReadFull(crand.Reader, encBlob); err != nil {
t.Fatalf("unable to generate encrypted blob: %v", err)
}
_, err = io.ReadFull(crand.Reader, encBlob)
require.NoError(t, err)
return &wtdb.CommittedUpdate{
SeqNum: seqNum,

View File

@@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/stretchr/testify/require"
)
func randPubKey() (*btcec.PublicKey, error) {
@@ -134,10 +135,7 @@ func TestCodec(tt *testing.T) {
// Ensure encoding the object succeeds.
var b bytes.Buffer
err := obj.Encode(&b)
if err != nil {
t.Fatalf("unable to encode: %v", err)
return false
}
require.NoError(t, err)
var obj2 dbObject
switch obj.(type) {
@@ -162,17 +160,10 @@ func TestCodec(tt *testing.T) {
// Ensure decoding the object succeeds.
err = obj2.Decode(bytes.NewReader(b.Bytes()))
if err != nil {
t.Fatalf("unable to decode: %v", err)
return false
}
require.NoError(t, err)
// Assert the original and decoded object match.
if !reflect.DeepEqual(obj, obj2) {
t.Fatalf("encode/decode mismatch, want: %v, "+
"got: %v", obj, obj2)
return false
}
require.Equal(t, obj, obj2)
return true
}
@@ -180,16 +171,10 @@ func TestCodec(tt *testing.T) {
customTypeGen := map[string]func([]reflect.Value, *rand.Rand){
"Tower": func(v []reflect.Value, r *rand.Rand) {
pk, err := randPubKey()
if err != nil {
t.Fatalf("unable to generate pubkey: %v", err)
return
}
require.NoError(t, err)
addrs, err := randAddrs(r)
if err != nil {
t.Fatalf("unable to generate addrs: %v", err)
return
}
require.NoError(t, err)
obj := wtdb.Tower{
IdentityKey: pk,
@@ -260,10 +245,7 @@ func TestCodec(tt *testing.T) {
}
err := quick.Check(test.scenario, config)
if err != nil {
t.Fatalf("fuzz checks for msg=%s failed: %v",
test.name, err)
}
require.NoError(h, err)
})
}
}

View File

@@ -3,7 +3,6 @@ package wtdb_test
import (
"bytes"
"encoding/binary"
"reflect"
"testing"
"github.com/btcsuite/btcd/chaincfg/chainhash"
@@ -14,6 +13,7 @@ import (
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/stretchr/testify/require"
)
var (
@@ -48,10 +48,7 @@ func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) {
h.t.Helper()
err := h.db.InsertSessionInfo(s)
if err != expErr {
h.t.Fatalf("expected insert session error: %v, got : %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
}
// getSession retrieves the session identified by id, asserting that the call
@@ -62,10 +59,7 @@ func (h *towerDBHarness) getSession(id *wtdb.SessionID,
h.t.Helper()
session, err := h.db.GetSessionInfo(id)
if err != expErr {
h.t.Fatalf("expected get session error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
return session
}
@@ -79,10 +73,7 @@ func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate,
h.t.Helper()
lastApplied, err := h.db.InsertStateUpdate(s)
if err != expErr {
h.t.Fatalf("expected insert update error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
return lastApplied
}
@@ -93,10 +84,7 @@ func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) {
h.t.Helper()
err := h.db.DeleteSession(id)
if err != expErr {
h.t.Fatalf("expected deletion error: %v, got: %v",
expErr, err)
}
require.ErrorIs(h.t, err, expErr)
}
// queryMatches queries that database for the passed breach hint, returning all
@@ -105,9 +93,7 @@ func (h *towerDBHarness) queryMatches(hint blob.BreachHint) []wtdb.Match {
h.t.Helper()
matches, err := h.db.QueryMatches([]blob.BreachHint{hint})
if err != nil {
h.t.Fatalf("unable to query matches: %v", err)
}
require.NoError(h.t, err)
return matches
}
@@ -119,14 +105,10 @@ func (h *towerDBHarness) hasUpdate(hint blob.BreachHint) wtdb.Match {
h.t.Helper()
matches := h.queryMatches(hint)
if len(matches) != 1 {
h.t.Fatalf("expected 1 match, found: %d", len(matches))
}
require.Len(h.t, matches, 1)
match := matches[0]
if match.Hint != hint {
h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint)
}
require.Equal(h.t, hint, match.Hint)
return match
}
@@ -158,11 +140,7 @@ func testInsertSession(h *towerDBHarness) {
h.insertSession(session, nil)
session2 := h.getSession(&id, nil)
if !reflect.DeepEqual(session, session2) {
h.t.Fatalf("expected session: %v, got %v",
session, session2)
}
require.Equal(h.t, session, session2)
h.insertSession(session, nil)
@@ -211,28 +189,21 @@ func testMultipleMatches(h *towerDBHarness) {
// Query the db for matches on the chosen hint.
matches := h.queryMatches(hint)
if len(matches) != numUpdates {
h.t.Fatalf("num updates mismatch, want: %d, got: %d",
numUpdates, len(matches))
}
require.Len(h.t, matches, numUpdates)
// Assert that the hints are what we asked for, and compute the set of
// sessions returned.
sessions := make(map[wtdb.SessionID]struct{})
for _, match := range matches {
if match.Hint != hint {
h.t.Fatalf("hint mismatch, want: %v, got: %v",
hint, match.Hint)
}
require.Equal(h.t, hint, match.Hint)
sessions[match.ID] = struct{}{}
}
// Assert that the sessions returned match the session ids of the
// sessions we initially created.
for i := 0; i < numUpdates; i++ {
if _, ok := sessions[*id(i)]; !ok {
h.t.Fatalf("match for session %v not found", *id(i))
}
_, ok := sessions[*id(i)]
require.Truef(h.t, ok, "match for session %v not found", *id(i))
}
}
@@ -242,33 +213,22 @@ func testMultipleMatches(h *towerDBHarness) {
func testLookoutTip(h *towerDBHarness) {
// Retrieve lookout tip on fresh db.
epoch, err := h.db.GetLookoutTip()
if err != nil {
h.t.Fatalf("unable to fetch lookout tip: %v", err)
}
require.NoError(h.t, err)
// Assert that the epoch is nil.
if epoch != nil {
h.t.Fatalf("lookout tip should not be set, found: %v", epoch)
}
require.Nil(h.t, epoch)
// Create a closure that inserts an epoch, retrieves it, and asserts
// that the returned epoch matches what was inserted.
setAndCheck := func(i int) {
expEpoch := epochFromInt(1)
err = h.db.SetLookoutTip(expEpoch)
if err != nil {
h.t.Fatalf("unable to set lookout tip: %v", err)
}
require.NoError(h.t, err)
epoch, err = h.db.GetLookoutTip()
if err != nil {
h.t.Fatalf("unable to fetch lookout tip: %v", err)
}
require.NoError(h.t, err)
if !reflect.DeepEqual(epoch, expEpoch) {
h.t.Fatalf("lookout tip mismatch, want: %v, got: %v",
expEpoch, epoch)
}
require.Equal(h.t, expEpoch, epoch)
}
// Set and assert the lookout tip.
@@ -348,15 +308,10 @@ func testDeleteSession(h *towerDBHarness) {
// Assert that only one update is still present.
matches := h.queryMatches(hint)
if len(matches) != 1 {
h.t.Fatalf("expected one update, found: %d", len(matches))
}
require.Len(h.t, matches, 1)
// Assert that the update belongs to the first session.
if matches[0].ID != *id0 {
h.t.Fatalf("expected match for %v, instead is for: %v",
*id0, matches[0].ID)
}
require.Equal(h.t, *id0, matches[0].ID)
// Finally, remove the first session added.
h.deleteSession(*id0, nil)
@@ -366,9 +321,7 @@ func testDeleteSession(h *towerDBHarness) {
// No matches should exist for this hint.
matches = h.queryMatches(hint)
if len(matches) != 0 {
h.t.Fatalf("expected zero updates, found: %d", len(matches))
}
require.Zero(h.t, len(matches))
}
type stateUpdateTest struct {
@@ -403,10 +356,9 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
*expSession = *test.session
}
if len(test.updates) != len(test.updateErrs) {
h.t.Fatalf("malformed test case, num updates " +
"should match num errors")
}
require.Lenf(h.t, test.updates, len(test.updateErrs),
"malformed test case, num updates should match num "+
"errors")
// Send any updates provided in the test.
for i, update := range test.updates {
@@ -430,10 +382,7 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
expSession.ClientLastApplied = update.LastApplied
match := h.hasUpdate(update.Hint)
if !reflect.DeepEqual(match.SessionInfo, expSession) {
h.t.Fatalf("expected session: %v, got: %v",
expSession, match.SessionInfo)
}
require.Equal(h.t, expSession, match.SessionInfo)
}
}
}
@@ -640,14 +589,10 @@ func TestTowerDB(t *testing.T) {
bdb, err := wtdb.NewBoltBackendCreator(
true, path, "watchtower.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err := wtdb.OpenTowerDB(bdb)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
t.Cleanup(func() {
db.Close()
@@ -664,14 +609,10 @@ func TestTowerDB(t *testing.T) {
bdb, err := wtdb.NewBoltBackendCreator(
true, path, "watchtower.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err := wtdb.OpenTowerDB(bdb)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db.Close()
// Open the db again, ensuring we test a
@@ -680,14 +621,10 @@ func TestTowerDB(t *testing.T) {
bdb, err = wtdb.NewBoltBackendCreator(
true, path, "watchtower.db",
)(dbCfg)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
db, err = wtdb.OpenTowerDB(bdb)
if err != nil {
t.Fatalf("unable to open db: %v", err)
}
require.NoError(t, err)
t.Cleanup(func() {
db.Close()