mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-29 15:11:09 +02:00
watchtower/wtdb: update tests to use require package
In this commit, all the tests in the wtdb package are updated in order to make use of the `require` package where appropriate.
This commit is contained in:
@@ -1,11 +1,9 @@
|
||||
package wtdb_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
@@ -16,6 +14,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// clientDBInit is a closure used to initialize a wtclient.DB instance.
|
||||
@@ -43,10 +42,7 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.CreateClientSession(session)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected create client session error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) listSessions(
|
||||
@@ -55,9 +51,7 @@ func (h *clientDBHarness) listSessions(
|
||||
h.t.Helper()
|
||||
|
||||
sessions, err := h.db.ListClientSessions(id)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to list client sessions: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err, "unable to list client sessions")
|
||||
|
||||
return sessions
|
||||
}
|
||||
@@ -68,13 +62,8 @@ func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID,
|
||||
h.t.Helper()
|
||||
|
||||
index, err := h.db.NextSessionKeyIndex(id, blobType)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to create next session key index: %v", err)
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
h.t.Fatalf("next key index should never be 0")
|
||||
}
|
||||
require.NoError(h.t, err, "unable to create next session key index")
|
||||
require.NotZero(h.t, index, "next key index should never be 0")
|
||||
|
||||
return index
|
||||
}
|
||||
@@ -85,21 +74,11 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
|
||||
h.t.Helper()
|
||||
|
||||
tower, err := h.db.CreateTower(lnAddr)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected create tower error: %v, got: %v", expErr,
|
||||
err)
|
||||
}
|
||||
|
||||
if tower.ID == 0 {
|
||||
h.t.Fatalf("tower id should never be 0")
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
require.NotZero(h.t, tower.ID, "tower id should never be 0")
|
||||
|
||||
for _, session := range h.listSessions(&tower.ID) {
|
||||
if session.Status != wtdb.CSessionActive {
|
||||
h.t.Fatalf("expected status for session %v to be %v, "+
|
||||
"got %v", session.ID, wtdb.CSessionActive,
|
||||
session.Status)
|
||||
}
|
||||
require.Equal(h.t, wtdb.CSessionActive, session.Status)
|
||||
}
|
||||
|
||||
return tower
|
||||
@@ -110,10 +89,9 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
if err := h.db.RemoveTower(pubKey, addr); err != expErr {
|
||||
h.t.Fatalf("expected remove tower error: %v, got %v", expErr,
|
||||
err)
|
||||
}
|
||||
err := h.db.RemoveTower(pubKey, addr)
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
if expErr != nil {
|
||||
return
|
||||
}
|
||||
@@ -122,37 +100,31 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
|
||||
|
||||
if addr != nil {
|
||||
tower, err := h.db.LoadTower(pubKey)
|
||||
if err != nil {
|
||||
h.t.Fatalf("expected tower %x to still exist",
|
||||
pubKeyStr)
|
||||
}
|
||||
require.NoErrorf(h.t, err, "expected tower %x to still exist",
|
||||
pubKeyStr)
|
||||
|
||||
removedAddr := addr.String()
|
||||
for _, towerAddr := range tower.Addresses {
|
||||
if towerAddr.String() == removedAddr {
|
||||
h.t.Fatalf("address %v not removed for tower "+
|
||||
"%x", removedAddr, pubKeyStr)
|
||||
}
|
||||
require.NotEqualf(h.t, removedAddr, towerAddr,
|
||||
"address %v not removed for tower %x",
|
||||
removedAddr, pubKeyStr)
|
||||
}
|
||||
} else {
|
||||
tower, err := h.db.LoadTower(pubKey)
|
||||
if hasSessions && err != nil {
|
||||
h.t.Fatalf("expected tower %x with sessions to still "+
|
||||
"exist", pubKeyStr)
|
||||
}
|
||||
if !hasSessions && err == nil {
|
||||
h.t.Fatalf("expected tower %x with no sessions to not "+
|
||||
"exist", pubKeyStr)
|
||||
}
|
||||
if !hasSessions {
|
||||
if hasSessions {
|
||||
require.NoError(h.t, err, "expected tower %x with "+
|
||||
"sessions to still exist", pubKeyStr)
|
||||
} else {
|
||||
require.Errorf(h.t, err, "expected tower %x with no "+
|
||||
"sessions to not exist", pubKeyStr)
|
||||
return
|
||||
}
|
||||
|
||||
for _, session := range h.listSessions(&tower.ID) {
|
||||
if session.Status != wtdb.CSessionInactive {
|
||||
h.t.Fatalf("expected status for session %v to "+
|
||||
"be %v, got %v", session.ID,
|
||||
wtdb.CSessionInactive, session.Status)
|
||||
}
|
||||
require.Equal(h.t, wtdb.CSessionInactive,
|
||||
session.Status, "expected status for session "+
|
||||
"%v to be %v, got %v", session.ID,
|
||||
wtdb.CSessionInactive, session.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -163,10 +135,7 @@ func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey,
|
||||
h.t.Helper()
|
||||
|
||||
tower, err := h.db.LoadTower(pubKey)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected load tower error: %v, got: %v", expErr,
|
||||
err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return tower
|
||||
}
|
||||
@@ -177,10 +146,7 @@ func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID,
|
||||
h.t.Helper()
|
||||
|
||||
tower, err := h.db.LoadTowerByID(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected load tower error: %v, got: %v", expErr,
|
||||
err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return tower
|
||||
}
|
||||
@@ -189,9 +155,7 @@ func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientC
|
||||
h.t.Helper()
|
||||
|
||||
summaries, err := h.db.FetchChanSummaries()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to fetch chan summaries: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
return summaries
|
||||
}
|
||||
@@ -202,10 +166,7 @@ func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID,
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.RegisterChannel(chanID, sweepPkScript)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected register channel error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
|
||||
@@ -214,10 +175,7 @@ func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
|
||||
h.t.Helper()
|
||||
|
||||
lastApplied, err := h.db.CommitUpdate(id, update)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected commit update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return lastApplied
|
||||
}
|
||||
@@ -228,10 +186,7 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.AckUpdate(id, seqNum, lastApplied)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected commit update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
// testCreateClientSession asserts various conditions regarding the creation of
|
||||
@@ -259,9 +214,9 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
|
||||
// First, assert that this session is not already present in the
|
||||
// database.
|
||||
if _, ok := h.listSessions(nil)[session.ID]; ok {
|
||||
h.t.Fatalf("session for id %x should not exist yet", session.ID)
|
||||
}
|
||||
_, ok := h.listSessions(nil)[session.ID]
|
||||
require.Falsef(h.t, ok, "session for id %x should not exist yet",
|
||||
session.ID)
|
||||
|
||||
// Attempting to insert the client session without reserving a session
|
||||
// key index should fail.
|
||||
@@ -278,10 +233,8 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
// successfully created, it should return the same index to maintain
|
||||
// idempotency across restarts.
|
||||
keyIndex2 := h.nextKeyIndex(session.TowerID, blobType)
|
||||
if keyIndex != keyIndex2 {
|
||||
h.t.Fatalf("next key index should be idempotent: want: %v, "+
|
||||
"got %v", keyIndex, keyIndex2)
|
||||
}
|
||||
require.Equalf(h.t, keyIndex, keyIndex2, "next key index should "+
|
||||
"be idempotent: want: %v, got %v", keyIndex, keyIndex2)
|
||||
|
||||
// Now, set the client session's key index so that it is proper and
|
||||
// insert it. This should succeed.
|
||||
@@ -289,9 +242,8 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Verify that the session now exists in the database.
|
||||
if _, ok := h.listSessions(nil)[session.ID]; !ok {
|
||||
h.t.Fatalf("session for id %x should exist now", session.ID)
|
||||
}
|
||||
_, ok = h.listSessions(nil)[session.ID]
|
||||
require.Truef(h.t, ok, "session for id %x should exist now", session.ID)
|
||||
|
||||
// Attempt to insert the session again, which should fail due to the
|
||||
// session already existing.
|
||||
@@ -300,9 +252,8 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
// Finally, assert that reserving another key index succeeds with a
|
||||
// different key index, now that the first one has been finalized.
|
||||
keyIndex3 := h.nextKeyIndex(session.TowerID, blobType)
|
||||
if keyIndex == keyIndex3 {
|
||||
h.t.Fatalf("key index still reserved after creating session")
|
||||
}
|
||||
require.NotEqualf(h.t, keyIndex, keyIndex3, "key index still "+
|
||||
"reserved after creating session")
|
||||
}
|
||||
|
||||
// testFilterClientSessions asserts that we can correctly filter client sessions
|
||||
@@ -343,15 +294,12 @@ func testFilterClientSessions(h *clientDBHarness) {
|
||||
// them.
|
||||
for towerID, expectedSessions := range towerSessions {
|
||||
sessions := h.listSessions(&towerID)
|
||||
if len(sessions) != len(expectedSessions) {
|
||||
h.t.Fatalf("expected %v sessions for tower %v, got %v",
|
||||
len(expectedSessions), towerID, len(sessions))
|
||||
}
|
||||
require.Len(h.t, sessions, len(expectedSessions))
|
||||
|
||||
for _, expectedSession := range expectedSessions {
|
||||
if _, ok := sessions[expectedSession]; !ok {
|
||||
h.t.Fatalf("expected session %v for tower %v",
|
||||
expectedSession, towerID)
|
||||
}
|
||||
_, ok := sessions[expectedSession]
|
||||
require.Truef(h.t, ok, "expected session %v for "+
|
||||
"tower %v", expectedSession, towerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -380,26 +328,18 @@ func testCreateTower(h *clientDBHarness) {
|
||||
// Load the tower from the database and assert that it matches the tower
|
||||
// we created.
|
||||
tower2 := h.loadTowerByID(tower.ID, nil)
|
||||
if !reflect.DeepEqual(tower, tower2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
tower, tower2)
|
||||
}
|
||||
tower2 = h.loadTower(pk, err)
|
||||
if !reflect.DeepEqual(tower, tower2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
tower, tower2)
|
||||
}
|
||||
require.Equal(h.t, tower, tower2)
|
||||
|
||||
tower2 = h.loadTower(tower.IdentityKey, nil)
|
||||
require.Equal(h.t, tower, tower2)
|
||||
|
||||
// Insert the address again into the database. Since the address is the
|
||||
// same, this should result in an unmodified tower record.
|
||||
towerDupAddr := h.createTower(lnAddr, nil)
|
||||
if len(towerDupAddr.Addresses) != 1 {
|
||||
h.t.Fatalf("duplicate address should be deduped")
|
||||
}
|
||||
if !reflect.DeepEqual(tower, towerDupAddr) {
|
||||
h.t.Fatalf("mismatch towers, want: %v, got: %v",
|
||||
tower, towerDupAddr)
|
||||
}
|
||||
require.Lenf(h.t, towerDupAddr.Addresses, 1, "duplicate address "+
|
||||
"should be deduped")
|
||||
|
||||
require.Equal(h.t, tower, towerDupAddr)
|
||||
|
||||
// Generate a new address for this tower.
|
||||
addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
|
||||
@@ -416,26 +356,18 @@ func testCreateTower(h *clientDBHarness) {
|
||||
// Load the tower from the database, and assert that it matches the
|
||||
// tower returned from creation.
|
||||
towerNewAddr2 := h.loadTowerByID(tower.ID, nil)
|
||||
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
towerNewAddr, towerNewAddr2)
|
||||
}
|
||||
towerNewAddr2 = h.loadTower(pk, nil)
|
||||
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
towerNewAddr, towerNewAddr2)
|
||||
}
|
||||
require.Equal(h.t, towerNewAddr, towerNewAddr2)
|
||||
|
||||
towerNewAddr2 = h.loadTower(tower.IdentityKey, nil)
|
||||
require.Equal(h.t, towerNewAddr, towerNewAddr2)
|
||||
|
||||
// Assert that there are now two addresses on the tower object.
|
||||
if len(towerNewAddr.Addresses) != 2 {
|
||||
h.t.Fatalf("new address should be added")
|
||||
}
|
||||
require.Lenf(h.t, towerNewAddr.Addresses, 2, "new address should be "+
|
||||
"added")
|
||||
|
||||
// Finally, assert that the new address was prepended since it is deemed
|
||||
// fresher.
|
||||
if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) {
|
||||
h.t.Fatalf("new address should be prepended")
|
||||
}
|
||||
require.Equal(h.t, tower.Addresses, towerNewAddr.Addresses[1:])
|
||||
}
|
||||
|
||||
// testRemoveTower asserts the behavior of removing Tower objects as a whole and
|
||||
@@ -443,9 +375,7 @@ func testCreateTower(h *clientDBHarness) {
|
||||
func testRemoveTower(h *clientDBHarness) {
|
||||
// Generate a random public key we'll use for our tower.
|
||||
pk, err := randPubKey()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to generate pubkey: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Removing a tower that does not exist within the database should
|
||||
// result in a NOP.
|
||||
@@ -523,28 +453,23 @@ func testRemoveTower(h *clientDBHarness) {
|
||||
func testChanSummaries(h *clientDBHarness) {
|
||||
// First, assert that this channel is not already registered.
|
||||
var chanID lnwire.ChannelID
|
||||
if _, ok := h.fetchChanSummaries()[chanID]; ok {
|
||||
h.t.Fatalf("pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
}
|
||||
_, ok := h.fetchChanSummaries()[chanID]
|
||||
require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
|
||||
// Generate a random sweep pkscript and register it for this channel.
|
||||
expPkScript := make([]byte, 22)
|
||||
if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil {
|
||||
h.t.Fatalf("unable to generate pkscript: %v", err)
|
||||
}
|
||||
_, err := io.ReadFull(crand.Reader, expPkScript)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
h.registerChan(chanID, expPkScript, nil)
|
||||
|
||||
// Assert that the channel exists and that its sweep pkscript matches
|
||||
// the one we registered.
|
||||
summary, ok := h.fetchChanSummaries()[chanID]
|
||||
if !ok {
|
||||
h.t.Fatalf("pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
} else if bytes.Compare(expPkScript, summary.SweepPkScript) != 0 {
|
||||
h.t.Fatalf("pkscript mismatch, want: %x, got: %x",
|
||||
expPkScript, summary.SweepPkScript)
|
||||
}
|
||||
require.Truef(h.t, ok, "pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
require.Equal(h.t, expPkScript, summary.SweepPkScript)
|
||||
|
||||
// Finally, assert that re-registering the same channel produces a
|
||||
// failure.
|
||||
@@ -581,10 +506,7 @@ func testCommitUpdate(h *clientDBHarness) {
|
||||
// succeed. The lastApplied value should be 0 since we have not received
|
||||
// an ack from the tower.
|
||||
lastApplied := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied != 0 {
|
||||
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
require.Zero(h.t, lastApplied)
|
||||
|
||||
// Assert that the committed update appears in the client session's
|
||||
// CommittedUpdates map when loaded from disk and that there are no
|
||||
@@ -600,10 +522,7 @@ func testCommitUpdate(h *clientDBHarness) {
|
||||
// the on-disk update's hint). The lastApplied value should remain
|
||||
// unchanged.
|
||||
lastApplied2 := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied2 != lastApplied {
|
||||
h.t.Fatalf("last applied should not have changed, got %v",
|
||||
lastApplied2)
|
||||
}
|
||||
require.Equal(h.t, lastApplied, lastApplied2)
|
||||
|
||||
// Assert that the loaded ClientSession is the same as before.
|
||||
dbSession = h.listSessions(nil)[session.ID]
|
||||
@@ -621,10 +540,7 @@ func testCommitUpdate(h *clientDBHarness) {
|
||||
// which should succeed.
|
||||
update2.SeqNum = 2
|
||||
lastApplied3 := h.commitUpdate(&session.ID, update2, nil)
|
||||
if lastApplied3 != lastApplied {
|
||||
h.t.Fatalf("last applied should not have changed, got %v",
|
||||
lastApplied3)
|
||||
}
|
||||
require.Equal(h.t, lastApplied, lastApplied3)
|
||||
|
||||
// Check that both updates now appear as committed on the ClientSession
|
||||
// loaded from disk.
|
||||
@@ -684,10 +600,7 @@ func testAckUpdate(h *clientDBHarness) {
|
||||
// Commit to a random update at seqnum 1.
|
||||
update1 := randCommittedUpdate(h.t, 1)
|
||||
lastApplied := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied != 0 {
|
||||
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
require.Zero(h.t, lastApplied)
|
||||
|
||||
// Acking seqnum 1 should succeed.
|
||||
h.ackUpdate(&session.ID, 1, 1, nil)
|
||||
@@ -715,10 +628,7 @@ func testAckUpdate(h *clientDBHarness) {
|
||||
// ack.
|
||||
update2 := randCommittedUpdate(h.t, 2)
|
||||
lastApplied = h.commitUpdate(&session.ID, update2, nil)
|
||||
if lastApplied != 1 {
|
||||
h.t.Fatalf("last applied mismatch, want: 1, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
require.EqualValues(h.t, 1, lastApplied)
|
||||
|
||||
// Ack seqnum 2.
|
||||
h.ackUpdate(&session.ID, 2, 2, nil)
|
||||
@@ -756,10 +666,7 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
|
||||
expUpdates = make([]wtdb.CommittedUpdate, 0)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) {
|
||||
t.Fatalf("committed updates mismatch, want: %v, got: %v",
|
||||
expUpdates, session.CommittedUpdates)
|
||||
}
|
||||
require.Equal(t, expUpdates, session.CommittedUpdates)
|
||||
}
|
||||
|
||||
// checkAckedUpdates asserts that the AckedUpdates on a session match the
|
||||
@@ -774,10 +681,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
|
||||
expUpdates = make(map[uint16]wtdb.BackupID)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(session.AckedUpdates, expUpdates) {
|
||||
t.Fatalf("acked updates mismatch, want: %v, got: %v",
|
||||
expUpdates, session.AckedUpdates)
|
||||
}
|
||||
require.Equal(t, expUpdates, session.AckedUpdates)
|
||||
}
|
||||
|
||||
// TestClientDB asserts the behavior of a fresh client db, a reopened client db,
|
||||
@@ -795,14 +699,10 @@ func TestClientDB(t *testing.T) {
|
||||
bdb, err := wtdb.NewBoltBackendCreator(
|
||||
true, t.TempDir(), "wtclient.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := wtdb.OpenClientDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
@@ -819,27 +719,19 @@ func TestClientDB(t *testing.T) {
|
||||
bdb, err := wtdb.NewBoltBackendCreator(
|
||||
true, path, "wtclient.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := wtdb.OpenClientDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
db.Close()
|
||||
|
||||
bdb, err = wtdb.NewBoltBackendCreator(
|
||||
true, path, "wtclient.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err = wtdb.OpenClientDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to reopen db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
@@ -909,19 +801,16 @@ func TestClientDB(t *testing.T) {
|
||||
// randCommittedUpdate generates a random committed update.
|
||||
func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
|
||||
var chanID lnwire.ChannelID
|
||||
if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil {
|
||||
t.Fatalf("unable to generate chan id: %v", err)
|
||||
}
|
||||
_, err := io.ReadFull(crand.Reader, chanID[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var hint blob.BreachHint
|
||||
if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil {
|
||||
t.Fatalf("unable to generate breach hint: %v", err)
|
||||
}
|
||||
_, err = io.ReadFull(crand.Reader, hint[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
|
||||
if _, err := io.ReadFull(crand.Reader, encBlob); err != nil {
|
||||
t.Fatalf("unable to generate encrypted blob: %v", err)
|
||||
}
|
||||
_, err = io.ReadFull(crand.Reader, encBlob)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &wtdb.CommittedUpdate{
|
||||
SeqNum: seqNum,
|
||||
|
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
"github.com/lightningnetwork/lnd/tor"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func randPubKey() (*btcec.PublicKey, error) {
|
||||
@@ -134,10 +135,7 @@ func TestCodec(tt *testing.T) {
|
||||
// Ensure encoding the object succeeds.
|
||||
var b bytes.Buffer
|
||||
err := obj.Encode(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode: %v", err)
|
||||
return false
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
var obj2 dbObject
|
||||
switch obj.(type) {
|
||||
@@ -162,17 +160,10 @@ func TestCodec(tt *testing.T) {
|
||||
|
||||
// Ensure decoding the object succeeds.
|
||||
err = obj2.Decode(bytes.NewReader(b.Bytes()))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode: %v", err)
|
||||
return false
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert the original and decoded object match.
|
||||
if !reflect.DeepEqual(obj, obj2) {
|
||||
t.Fatalf("encode/decode mismatch, want: %v, "+
|
||||
"got: %v", obj, obj2)
|
||||
return false
|
||||
}
|
||||
require.Equal(t, obj, obj2)
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -180,16 +171,10 @@ func TestCodec(tt *testing.T) {
|
||||
customTypeGen := map[string]func([]reflect.Value, *rand.Rand){
|
||||
"Tower": func(v []reflect.Value, r *rand.Rand) {
|
||||
pk, err := randPubKey()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate pubkey: %v", err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
addrs, err := randAddrs(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate addrs: %v", err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
obj := wtdb.Tower{
|
||||
IdentityKey: pk,
|
||||
@@ -260,10 +245,7 @@ func TestCodec(tt *testing.T) {
|
||||
}
|
||||
|
||||
err := quick.Check(test.scenario, config)
|
||||
if err != nil {
|
||||
t.Fatalf("fuzz checks for msg=%s failed: %v",
|
||||
test.name, err)
|
||||
}
|
||||
require.NoError(h, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -3,7 +3,6 @@ package wtdb_test
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
@@ -14,6 +13,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -48,10 +48,7 @@ func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.InsertSessionInfo(s)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected insert session error: %v, got : %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
// getSession retrieves the session identified by id, asserting that the call
|
||||
@@ -62,10 +59,7 @@ func (h *towerDBHarness) getSession(id *wtdb.SessionID,
|
||||
h.t.Helper()
|
||||
|
||||
session, err := h.db.GetSessionInfo(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected get session error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return session
|
||||
}
|
||||
@@ -79,10 +73,7 @@ func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate,
|
||||
h.t.Helper()
|
||||
|
||||
lastApplied, err := h.db.InsertStateUpdate(s)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected insert update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
|
||||
return lastApplied
|
||||
}
|
||||
@@ -93,10 +84,7 @@ func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.DeleteSession(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected deletion error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
}
|
||||
|
||||
// queryMatches queries that database for the passed breach hint, returning all
|
||||
@@ -105,9 +93,7 @@ func (h *towerDBHarness) queryMatches(hint blob.BreachHint) []wtdb.Match {
|
||||
h.t.Helper()
|
||||
|
||||
matches, err := h.db.QueryMatches([]blob.BreachHint{hint})
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to query matches: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
return matches
|
||||
}
|
||||
@@ -119,14 +105,10 @@ func (h *towerDBHarness) hasUpdate(hint blob.BreachHint) wtdb.Match {
|
||||
h.t.Helper()
|
||||
|
||||
matches := h.queryMatches(hint)
|
||||
if len(matches) != 1 {
|
||||
h.t.Fatalf("expected 1 match, found: %d", len(matches))
|
||||
}
|
||||
require.Len(h.t, matches, 1)
|
||||
|
||||
match := matches[0]
|
||||
if match.Hint != hint {
|
||||
h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint)
|
||||
}
|
||||
require.Equal(h.t, hint, match.Hint)
|
||||
|
||||
return match
|
||||
}
|
||||
@@ -158,11 +140,7 @@ func testInsertSession(h *towerDBHarness) {
|
||||
h.insertSession(session, nil)
|
||||
|
||||
session2 := h.getSession(&id, nil)
|
||||
|
||||
if !reflect.DeepEqual(session, session2) {
|
||||
h.t.Fatalf("expected session: %v, got %v",
|
||||
session, session2)
|
||||
}
|
||||
require.Equal(h.t, session, session2)
|
||||
|
||||
h.insertSession(session, nil)
|
||||
|
||||
@@ -211,28 +189,21 @@ func testMultipleMatches(h *towerDBHarness) {
|
||||
|
||||
// Query the db for matches on the chosen hint.
|
||||
matches := h.queryMatches(hint)
|
||||
if len(matches) != numUpdates {
|
||||
h.t.Fatalf("num updates mismatch, want: %d, got: %d",
|
||||
numUpdates, len(matches))
|
||||
}
|
||||
require.Len(h.t, matches, numUpdates)
|
||||
|
||||
// Assert that the hints are what we asked for, and compute the set of
|
||||
// sessions returned.
|
||||
sessions := make(map[wtdb.SessionID]struct{})
|
||||
for _, match := range matches {
|
||||
if match.Hint != hint {
|
||||
h.t.Fatalf("hint mismatch, want: %v, got: %v",
|
||||
hint, match.Hint)
|
||||
}
|
||||
require.Equal(h.t, hint, match.Hint)
|
||||
sessions[match.ID] = struct{}{}
|
||||
}
|
||||
|
||||
// Assert that the sessions returned match the session ids of the
|
||||
// sessions we initially created.
|
||||
for i := 0; i < numUpdates; i++ {
|
||||
if _, ok := sessions[*id(i)]; !ok {
|
||||
h.t.Fatalf("match for session %v not found", *id(i))
|
||||
}
|
||||
_, ok := sessions[*id(i)]
|
||||
require.Truef(h.t, ok, "match for session %v not found", *id(i))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,33 +213,22 @@ func testMultipleMatches(h *towerDBHarness) {
|
||||
func testLookoutTip(h *towerDBHarness) {
|
||||
// Retrieve lookout tip on fresh db.
|
||||
epoch, err := h.db.GetLookoutTip()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to fetch lookout tip: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Assert that the epoch is nil.
|
||||
if epoch != nil {
|
||||
h.t.Fatalf("lookout tip should not be set, found: %v", epoch)
|
||||
}
|
||||
require.Nil(h.t, epoch)
|
||||
|
||||
// Create a closure that inserts an epoch, retrieves it, and asserts
|
||||
// that the returned epoch matches what was inserted.
|
||||
setAndCheck := func(i int) {
|
||||
expEpoch := epochFromInt(1)
|
||||
err = h.db.SetLookoutTip(expEpoch)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to set lookout tip: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
epoch, err = h.db.GetLookoutTip()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to fetch lookout tip: %v", err)
|
||||
}
|
||||
require.NoError(h.t, err)
|
||||
|
||||
if !reflect.DeepEqual(epoch, expEpoch) {
|
||||
h.t.Fatalf("lookout tip mismatch, want: %v, got: %v",
|
||||
expEpoch, epoch)
|
||||
}
|
||||
require.Equal(h.t, expEpoch, epoch)
|
||||
}
|
||||
|
||||
// Set and assert the lookout tip.
|
||||
@@ -348,15 +308,10 @@ func testDeleteSession(h *towerDBHarness) {
|
||||
|
||||
// Assert that only one update is still present.
|
||||
matches := h.queryMatches(hint)
|
||||
if len(matches) != 1 {
|
||||
h.t.Fatalf("expected one update, found: %d", len(matches))
|
||||
}
|
||||
require.Len(h.t, matches, 1)
|
||||
|
||||
// Assert that the update belongs to the first session.
|
||||
if matches[0].ID != *id0 {
|
||||
h.t.Fatalf("expected match for %v, instead is for: %v",
|
||||
*id0, matches[0].ID)
|
||||
}
|
||||
require.Equal(h.t, *id0, matches[0].ID)
|
||||
|
||||
// Finally, remove the first session added.
|
||||
h.deleteSession(*id0, nil)
|
||||
@@ -366,9 +321,7 @@ func testDeleteSession(h *towerDBHarness) {
|
||||
|
||||
// No matches should exist for this hint.
|
||||
matches = h.queryMatches(hint)
|
||||
if len(matches) != 0 {
|
||||
h.t.Fatalf("expected zero updates, found: %d", len(matches))
|
||||
}
|
||||
require.Zero(h.t, len(matches))
|
||||
}
|
||||
|
||||
type stateUpdateTest struct {
|
||||
@@ -403,10 +356,9 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
|
||||
*expSession = *test.session
|
||||
}
|
||||
|
||||
if len(test.updates) != len(test.updateErrs) {
|
||||
h.t.Fatalf("malformed test case, num updates " +
|
||||
"should match num errors")
|
||||
}
|
||||
require.Lenf(h.t, test.updates, len(test.updateErrs),
|
||||
"malformed test case, num updates should match num "+
|
||||
"errors")
|
||||
|
||||
// Send any updates provided in the test.
|
||||
for i, update := range test.updates {
|
||||
@@ -430,10 +382,7 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
|
||||
expSession.ClientLastApplied = update.LastApplied
|
||||
|
||||
match := h.hasUpdate(update.Hint)
|
||||
if !reflect.DeepEqual(match.SessionInfo, expSession) {
|
||||
h.t.Fatalf("expected session: %v, got: %v",
|
||||
expSession, match.SessionInfo)
|
||||
}
|
||||
require.Equal(h.t, expSession, match.SessionInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -640,14 +589,10 @@ func TestTowerDB(t *testing.T) {
|
||||
bdb, err := wtdb.NewBoltBackendCreator(
|
||||
true, path, "watchtower.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := wtdb.OpenTowerDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
@@ -664,14 +609,10 @@ func TestTowerDB(t *testing.T) {
|
||||
bdb, err := wtdb.NewBoltBackendCreator(
|
||||
true, path, "watchtower.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := wtdb.OpenTowerDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
db.Close()
|
||||
|
||||
// Open the db again, ensuring we test a
|
||||
@@ -680,14 +621,10 @@ func TestTowerDB(t *testing.T) {
|
||||
bdb, err = wtdb.NewBoltBackendCreator(
|
||||
true, path, "watchtower.db",
|
||||
)(dbCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err = wtdb.OpenTowerDB(bdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
|
Reference in New Issue
Block a user