From f815c88ee4a6dd59c9979aef910f6713cbad27a4 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 14:59:11 +0200 Subject: [PATCH 1/8] watchtower: fix formatting In order to make upcoming commits in the PR easier to parse, this commit makes some basic formatting changes to some of the watchtower files. --- watchtower/wtclient/interface.go | 7 ++--- watchtower/wtdb/client_db.go | 15 ++++++++--- watchtower/wtdb/client_db_test.go | 44 +++++++++++++++++++++---------- 3 files changed, 45 insertions(+), 21 deletions(-) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 69f367293..dbf2faf71 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -62,7 +62,8 @@ type DB interface { // still be able to accept state updates. An optional tower ID can be // used to filter out any client sessions in the response that do not // correspond to this tower. - ListClientSessions(*wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) + ListClientSessions(*wtdb.TowerID) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) // FetchChanSummaries loads a mapping from all registered channels to // their channel summaries. @@ -96,8 +97,8 @@ type DB interface { AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error } -// AuthDialer connects to a remote node using an authenticated transport, such as -// brontide. The dialer argument is used to specify a resolver, which allows +// AuthDialer connects to a remote node using an authenticated transport, such +// as brontide. The dialer argument is used to specify a resolver, which allows // this method to be used over Tor or clear net connections. type AuthDialer func(localKey keychain.SingleKeyECDH, netAddr *lnwire.NetAddress, diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 91df574d2..c868b949f 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -113,7 +113,8 @@ var ( // NewBoltBackendCreator returns a function that creates a new bbolt backend for // the watchtower database. func NewBoltBackendCreator(active bool, dbPath, - dbFileName string) func(boltCfg *kvdb.BoltConfig) (kvdb.Backend, error) { + dbFileName string) func(boltCfg *kvdb.BoltConfig) (kvdb.Backend, + error) { // If the watchtower client isn't active, we return a function that // always returns a nil DB to make sure we don't create empty database @@ -575,7 +576,9 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { // Check that client session with this session id doesn't // already exist. - existingSessionBytes := sessions.NestedReadWriteBucket(session.ID[:]) + existingSessionBytes := sessions.NestedReadWriteBucket( + session.ID[:], + ) if existingSessionBytes != nil { return ErrClientSessionAlreadyExists } @@ -662,7 +665,9 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID, // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession, error) { +func (c *ClientDB) ListClientSessions(id *TowerID) ( + map[SessionID]*ClientSession, error) { + var clientSessions map[SessionID]*ClientSession err := kvdb.View(c.db, func(tx kvdb.RTx) error { sessions := tx.ReadBucket(cSessionBkt) @@ -951,7 +956,9 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, // If the commits sub-bucket doesn't exist, there can't possibly // be a corresponding committed update to remove. - sessionCommits := sessionBkt.NestedReadWriteBucket(cSessionCommits) + sessionCommits := sessionBkt.NestedReadWriteBucket( + cSessionCommits, + ) if sessionCommits == nil { return ErrCommittedUpdateNotFound } diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index f694b746f..42189df41 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -37,7 +37,9 @@ func newClientDBHarness(t *testing.T, init clientDBInit) *clientDBHarness { return h } -func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr error) { +func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, + expErr error) { + h.t.Helper() err := h.db.CreateClientSession(session) @@ -47,7 +49,9 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr erro } } -func (h *clientDBHarness) listSessions(id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession { +func (h *clientDBHarness) listSessions( + id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession { + h.t.Helper() sessions, err := h.db.ListClientSessions(id) @@ -82,7 +86,8 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress, tower, err := h.db.CreateTower(lnAddr) if err != expErr { - h.t.Fatalf("expected create tower error: %v, got: %v", expErr, err) + h.t.Fatalf("expected create tower error: %v, got: %v", expErr, + err) } if tower.ID == 0 { @@ -106,35 +111,38 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr, h.t.Helper() if err := h.db.RemoveTower(pubKey, addr); err != expErr { - h.t.Fatalf("expected remove tower error: %v, got %v", expErr, err) + h.t.Fatalf("expected remove tower error: %v, got %v", expErr, + err) } if expErr != nil { return } + pubKeyStr := pubKey.SerializeCompressed() + if addr != nil { tower, err := h.db.LoadTower(pubKey) if err != nil { h.t.Fatalf("expected tower %x to still exist", - pubKey.SerializeCompressed()) + pubKeyStr) } removedAddr := addr.String() for _, towerAddr := range tower.Addresses { if towerAddr.String() == removedAddr { - h.t.Fatalf("address %v not removed for tower %x", - removedAddr, pubKey.SerializeCompressed()) + h.t.Fatalf("address %v not removed for tower "+ + "%x", removedAddr, pubKeyStr) } } } else { tower, err := h.db.LoadTower(pubKey) if hasSessions && err != nil { h.t.Fatalf("expected tower %x with sessions to still "+ - "exist", pubKey.SerializeCompressed()) + "exist", pubKeyStr) } if !hasSessions && err == nil { h.t.Fatalf("expected tower %x with no sessions to not "+ - "exist", pubKey.SerializeCompressed()) + "exist", pubKeyStr) } if !hasSessions { return @@ -149,23 +157,29 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr, } } -func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey, expErr error) *wtdb.Tower { +func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey, + expErr error) *wtdb.Tower { + h.t.Helper() tower, err := h.db.LoadTower(pubKey) if err != expErr { - h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err) + h.t.Fatalf("expected load tower error: %v, got: %v", expErr, + err) } return tower } -func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, expErr error) *wtdb.Tower { +func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, + expErr error) *wtdb.Tower { + h.t.Helper() tower, err := h.db.LoadTowerByID(id) if err != expErr { - h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err) + h.t.Fatalf("expected load tower error: %v, got: %v", expErr, + err) } return tower @@ -320,7 +334,9 @@ func testFilterClientSessions(h *clientDBHarness) { }, ID: sessionID, }, nil) - towerSessions[towerID] = append(towerSessions[towerID], sessionID) + towerSessions[towerID] = append( + towerSessions[towerID], sessionID, + ) } // We should see the expected sessions for each tower when filtering From 5dabf7cb3e7e2720c3a8ed9a2c8490db9866d87b Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 10 Oct 2022 12:47:08 +0200 Subject: [PATCH 2/8] watchtower/wtdb: update tests to use require package In this commit, all the tests in the wtdb package are updated in order to make use of the `require` package where appropriate. --- watchtower/wtdb/client_db_test.go | 293 ++++++++++-------------------- watchtower/wtdb/codec_test.go | 32 +--- watchtower/wtdb/tower_db_test.go | 125 ++++--------- 3 files changed, 129 insertions(+), 321 deletions(-) diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 42189df41..b0e4af6a3 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -1,11 +1,9 @@ package wtdb_test import ( - "bytes" crand "crypto/rand" "io" "net" - "reflect" "testing" "github.com/btcsuite/btcd/btcec/v2" @@ -16,6 +14,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" + "github.com/stretchr/testify/require" ) // clientDBInit is a closure used to initialize a wtclient.DB instance. @@ -43,10 +42,7 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, h.t.Helper() err := h.db.CreateClientSession(session) - if err != expErr { - h.t.Fatalf("expected create client session error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } func (h *clientDBHarness) listSessions( @@ -55,9 +51,7 @@ func (h *clientDBHarness) listSessions( h.t.Helper() sessions, err := h.db.ListClientSessions(id) - if err != nil { - h.t.Fatalf("unable to list client sessions: %v", err) - } + require.NoError(h.t, err, "unable to list client sessions") return sessions } @@ -68,13 +62,8 @@ func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, h.t.Helper() index, err := h.db.NextSessionKeyIndex(id, blobType) - if err != nil { - h.t.Fatalf("unable to create next session key index: %v", err) - } - - if index == 0 { - h.t.Fatalf("next key index should never be 0") - } + require.NoError(h.t, err, "unable to create next session key index") + require.NotZero(h.t, index, "next key index should never be 0") return index } @@ -85,21 +74,11 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress, h.t.Helper() tower, err := h.db.CreateTower(lnAddr) - if err != expErr { - h.t.Fatalf("expected create tower error: %v, got: %v", expErr, - err) - } - - if tower.ID == 0 { - h.t.Fatalf("tower id should never be 0") - } + require.ErrorIs(h.t, err, expErr) + require.NotZero(h.t, tower.ID, "tower id should never be 0") for _, session := range h.listSessions(&tower.ID) { - if session.Status != wtdb.CSessionActive { - h.t.Fatalf("expected status for session %v to be %v, "+ - "got %v", session.ID, wtdb.CSessionActive, - session.Status) - } + require.Equal(h.t, wtdb.CSessionActive, session.Status) } return tower @@ -110,10 +89,9 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr, h.t.Helper() - if err := h.db.RemoveTower(pubKey, addr); err != expErr { - h.t.Fatalf("expected remove tower error: %v, got %v", expErr, - err) - } + err := h.db.RemoveTower(pubKey, addr) + require.ErrorIs(h.t, err, expErr) + if expErr != nil { return } @@ -122,37 +100,31 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr, if addr != nil { tower, err := h.db.LoadTower(pubKey) - if err != nil { - h.t.Fatalf("expected tower %x to still exist", - pubKeyStr) - } + require.NoErrorf(h.t, err, "expected tower %x to still exist", + pubKeyStr) removedAddr := addr.String() for _, towerAddr := range tower.Addresses { - if towerAddr.String() == removedAddr { - h.t.Fatalf("address %v not removed for tower "+ - "%x", removedAddr, pubKeyStr) - } + require.NotEqualf(h.t, removedAddr, towerAddr, + "address %v not removed for tower %x", + removedAddr, pubKeyStr) } } else { tower, err := h.db.LoadTower(pubKey) - if hasSessions && err != nil { - h.t.Fatalf("expected tower %x with sessions to still "+ - "exist", pubKeyStr) - } - if !hasSessions && err == nil { - h.t.Fatalf("expected tower %x with no sessions to not "+ - "exist", pubKeyStr) - } - if !hasSessions { + if hasSessions { + require.NoError(h.t, err, "expected tower %x with "+ + "sessions to still exist", pubKeyStr) + } else { + require.Errorf(h.t, err, "expected tower %x with no "+ + "sessions to not exist", pubKeyStr) return } + for _, session := range h.listSessions(&tower.ID) { - if session.Status != wtdb.CSessionInactive { - h.t.Fatalf("expected status for session %v to "+ - "be %v, got %v", session.ID, - wtdb.CSessionInactive, session.Status) - } + require.Equal(h.t, wtdb.CSessionInactive, + session.Status, "expected status for session "+ + "%v to be %v, got %v", session.ID, + wtdb.CSessionInactive, session.Status) } } } @@ -163,10 +135,7 @@ func (h *clientDBHarness) loadTower(pubKey *btcec.PublicKey, h.t.Helper() tower, err := h.db.LoadTower(pubKey) - if err != expErr { - h.t.Fatalf("expected load tower error: %v, got: %v", expErr, - err) - } + require.ErrorIs(h.t, err, expErr) return tower } @@ -177,10 +146,7 @@ func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, h.t.Helper() tower, err := h.db.LoadTowerByID(id) - if err != expErr { - h.t.Fatalf("expected load tower error: %v, got: %v", expErr, - err) - } + require.ErrorIs(h.t, err, expErr) return tower } @@ -189,9 +155,7 @@ func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientC h.t.Helper() summaries, err := h.db.FetchChanSummaries() - if err != nil { - h.t.Fatalf("unable to fetch chan summaries: %v", err) - } + require.NoError(h.t, err) return summaries } @@ -202,10 +166,7 @@ func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID, h.t.Helper() err := h.db.RegisterChannel(chanID, sweepPkScript) - if err != expErr { - h.t.Fatalf("expected register channel error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID, @@ -214,10 +175,7 @@ func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID, h.t.Helper() lastApplied, err := h.db.CommitUpdate(id, update) - if err != expErr { - h.t.Fatalf("expected commit update error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) return lastApplied } @@ -228,10 +186,7 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, h.t.Helper() err := h.db.AckUpdate(id, seqNum, lastApplied) - if err != expErr { - h.t.Fatalf("expected commit update error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } // testCreateClientSession asserts various conditions regarding the creation of @@ -259,9 +214,9 @@ func testCreateClientSession(h *clientDBHarness) { // First, assert that this session is not already present in the // database. - if _, ok := h.listSessions(nil)[session.ID]; ok { - h.t.Fatalf("session for id %x should not exist yet", session.ID) - } + _, ok := h.listSessions(nil)[session.ID] + require.Falsef(h.t, ok, "session for id %x should not exist yet", + session.ID) // Attempting to insert the client session without reserving a session // key index should fail. @@ -278,10 +233,8 @@ func testCreateClientSession(h *clientDBHarness) { // successfully created, it should return the same index to maintain // idempotency across restarts. keyIndex2 := h.nextKeyIndex(session.TowerID, blobType) - if keyIndex != keyIndex2 { - h.t.Fatalf("next key index should be idempotent: want: %v, "+ - "got %v", keyIndex, keyIndex2) - } + require.Equalf(h.t, keyIndex, keyIndex2, "next key index should "+ + "be idempotent: want: %v, got %v", keyIndex, keyIndex2) // Now, set the client session's key index so that it is proper and // insert it. This should succeed. @@ -289,9 +242,8 @@ func testCreateClientSession(h *clientDBHarness) { h.insertSession(session, nil) // Verify that the session now exists in the database. - if _, ok := h.listSessions(nil)[session.ID]; !ok { - h.t.Fatalf("session for id %x should exist now", session.ID) - } + _, ok = h.listSessions(nil)[session.ID] + require.Truef(h.t, ok, "session for id %x should exist now", session.ID) // Attempt to insert the session again, which should fail due to the // session already existing. @@ -300,9 +252,8 @@ func testCreateClientSession(h *clientDBHarness) { // Finally, assert that reserving another key index succeeds with a // different key index, now that the first one has been finalized. keyIndex3 := h.nextKeyIndex(session.TowerID, blobType) - if keyIndex == keyIndex3 { - h.t.Fatalf("key index still reserved after creating session") - } + require.NotEqualf(h.t, keyIndex, keyIndex3, "key index still "+ + "reserved after creating session") } // testFilterClientSessions asserts that we can correctly filter client sessions @@ -343,15 +294,12 @@ func testFilterClientSessions(h *clientDBHarness) { // them. for towerID, expectedSessions := range towerSessions { sessions := h.listSessions(&towerID) - if len(sessions) != len(expectedSessions) { - h.t.Fatalf("expected %v sessions for tower %v, got %v", - len(expectedSessions), towerID, len(sessions)) - } + require.Len(h.t, sessions, len(expectedSessions)) + for _, expectedSession := range expectedSessions { - if _, ok := sessions[expectedSession]; !ok { - h.t.Fatalf("expected session %v for tower %v", - expectedSession, towerID) - } + _, ok := sessions[expectedSession] + require.Truef(h.t, ok, "expected session %v for "+ + "tower %v", expectedSession, towerID) } } } @@ -380,26 +328,18 @@ func testCreateTower(h *clientDBHarness) { // Load the tower from the database and assert that it matches the tower // we created. tower2 := h.loadTowerByID(tower.ID, nil) - if !reflect.DeepEqual(tower, tower2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - tower, tower2) - } - tower2 = h.loadTower(pk, err) - if !reflect.DeepEqual(tower, tower2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - tower, tower2) - } + require.Equal(h.t, tower, tower2) + + tower2 = h.loadTower(tower.IdentityKey, nil) + require.Equal(h.t, tower, tower2) // Insert the address again into the database. Since the address is the // same, this should result in an unmodified tower record. towerDupAddr := h.createTower(lnAddr, nil) - if len(towerDupAddr.Addresses) != 1 { - h.t.Fatalf("duplicate address should be deduped") - } - if !reflect.DeepEqual(tower, towerDupAddr) { - h.t.Fatalf("mismatch towers, want: %v, got: %v", - tower, towerDupAddr) - } + require.Lenf(h.t, towerDupAddr.Addresses, 1, "duplicate address "+ + "should be deduped") + + require.Equal(h.t, tower, towerDupAddr) // Generate a new address for this tower. addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911} @@ -416,26 +356,18 @@ func testCreateTower(h *clientDBHarness) { // Load the tower from the database, and assert that it matches the // tower returned from creation. towerNewAddr2 := h.loadTowerByID(tower.ID, nil) - if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - towerNewAddr, towerNewAddr2) - } - towerNewAddr2 = h.loadTower(pk, nil) - if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) { - h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", - towerNewAddr, towerNewAddr2) - } + require.Equal(h.t, towerNewAddr, towerNewAddr2) + + towerNewAddr2 = h.loadTower(tower.IdentityKey, nil) + require.Equal(h.t, towerNewAddr, towerNewAddr2) // Assert that there are now two addresses on the tower object. - if len(towerNewAddr.Addresses) != 2 { - h.t.Fatalf("new address should be added") - } + require.Lenf(h.t, towerNewAddr.Addresses, 2, "new address should be "+ + "added") // Finally, assert that the new address was prepended since it is deemed // fresher. - if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) { - h.t.Fatalf("new address should be prepended") - } + require.Equal(h.t, tower.Addresses, towerNewAddr.Addresses[1:]) } // testRemoveTower asserts the behavior of removing Tower objects as a whole and @@ -443,9 +375,7 @@ func testCreateTower(h *clientDBHarness) { func testRemoveTower(h *clientDBHarness) { // Generate a random public key we'll use for our tower. pk, err := randPubKey() - if err != nil { - h.t.Fatalf("unable to generate pubkey: %v", err) - } + require.NoError(h.t, err) // Removing a tower that does not exist within the database should // result in a NOP. @@ -523,28 +453,23 @@ func testRemoveTower(h *clientDBHarness) { func testChanSummaries(h *clientDBHarness) { // First, assert that this channel is not already registered. var chanID lnwire.ChannelID - if _, ok := h.fetchChanSummaries()[chanID]; ok { - h.t.Fatalf("pkscript for channel %x should not exist yet", - chanID) - } + _, ok := h.fetchChanSummaries()[chanID] + require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet", + chanID) // Generate a random sweep pkscript and register it for this channel. expPkScript := make([]byte, 22) - if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil { - h.t.Fatalf("unable to generate pkscript: %v", err) - } + _, err := io.ReadFull(crand.Reader, expPkScript) + require.NoError(h.t, err) + h.registerChan(chanID, expPkScript, nil) // Assert that the channel exists and that its sweep pkscript matches // the one we registered. summary, ok := h.fetchChanSummaries()[chanID] - if !ok { - h.t.Fatalf("pkscript for channel %x should not exist yet", - chanID) - } else if bytes.Compare(expPkScript, summary.SweepPkScript) != 0 { - h.t.Fatalf("pkscript mismatch, want: %x, got: %x", - expPkScript, summary.SweepPkScript) - } + require.Truef(h.t, ok, "pkscript for channel %x should not exist yet", + chanID) + require.Equal(h.t, expPkScript, summary.SweepPkScript) // Finally, assert that re-registering the same channel produces a // failure. @@ -581,10 +506,7 @@ func testCommitUpdate(h *clientDBHarness) { // succeed. The lastApplied value should be 0 since we have not received // an ack from the tower. lastApplied := h.commitUpdate(&session.ID, update1, nil) - if lastApplied != 0 { - h.t.Fatalf("last applied mismatch, want: 0, got: %v", - lastApplied) - } + require.Zero(h.t, lastApplied) // Assert that the committed update appears in the client session's // CommittedUpdates map when loaded from disk and that there are no @@ -600,10 +522,7 @@ func testCommitUpdate(h *clientDBHarness) { // the on-disk update's hint). The lastApplied value should remain // unchanged. lastApplied2 := h.commitUpdate(&session.ID, update1, nil) - if lastApplied2 != lastApplied { - h.t.Fatalf("last applied should not have changed, got %v", - lastApplied2) - } + require.Equal(h.t, lastApplied, lastApplied2) // Assert that the loaded ClientSession is the same as before. dbSession = h.listSessions(nil)[session.ID] @@ -621,10 +540,7 @@ func testCommitUpdate(h *clientDBHarness) { // which should succeed. update2.SeqNum = 2 lastApplied3 := h.commitUpdate(&session.ID, update2, nil) - if lastApplied3 != lastApplied { - h.t.Fatalf("last applied should not have changed, got %v", - lastApplied3) - } + require.Equal(h.t, lastApplied, lastApplied3) // Check that both updates now appear as committed on the ClientSession // loaded from disk. @@ -684,10 +600,7 @@ func testAckUpdate(h *clientDBHarness) { // Commit to a random update at seqnum 1. update1 := randCommittedUpdate(h.t, 1) lastApplied := h.commitUpdate(&session.ID, update1, nil) - if lastApplied != 0 { - h.t.Fatalf("last applied mismatch, want: 0, got: %v", - lastApplied) - } + require.Zero(h.t, lastApplied) // Acking seqnum 1 should succeed. h.ackUpdate(&session.ID, 1, 1, nil) @@ -715,10 +628,7 @@ func testAckUpdate(h *clientDBHarness) { // ack. update2 := randCommittedUpdate(h.t, 2) lastApplied = h.commitUpdate(&session.ID, update2, nil) - if lastApplied != 1 { - h.t.Fatalf("last applied mismatch, want: 1, got: %v", - lastApplied) - } + require.EqualValues(h.t, 1, lastApplied) // Ack seqnum 2. h.ackUpdate(&session.ID, 2, 2, nil) @@ -756,10 +666,7 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, expUpdates = make([]wtdb.CommittedUpdate, 0) } - if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) { - t.Fatalf("committed updates mismatch, want: %v, got: %v", - expUpdates, session.CommittedUpdates) - } + require.Equal(t, expUpdates, session.CommittedUpdates) } // checkAckedUpdates asserts that the AckedUpdates on a session match the @@ -774,10 +681,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, expUpdates = make(map[uint16]wtdb.BackupID) } - if !reflect.DeepEqual(session.AckedUpdates, expUpdates) { - t.Fatalf("acked updates mismatch, want: %v, got: %v", - expUpdates, session.AckedUpdates) - } + require.Equal(t, expUpdates, session.AckedUpdates) } // TestClientDB asserts the behavior of a fresh client db, a reopened client db, @@ -795,14 +699,10 @@ func TestClientDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, t.TempDir(), "wtclient.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenClientDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() @@ -819,27 +719,19 @@ func TestClientDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, path, "wtclient.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenClientDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db.Close() bdb, err = wtdb.NewBoltBackendCreator( true, path, "wtclient.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err = wtdb.OpenClientDB(bdb) - if err != nil { - t.Fatalf("unable to reopen db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() @@ -909,19 +801,16 @@ func TestClientDB(t *testing.T) { // randCommittedUpdate generates a random committed update. func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { var chanID lnwire.ChannelID - if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - } + _, err := io.ReadFull(crand.Reader, chanID[:]) + require.NoError(t, err) var hint blob.BreachHint - if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil { - t.Fatalf("unable to generate breach hint: %v", err) - } + _, err = io.ReadFull(crand.Reader, hint[:]) + require.NoError(t, err) encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type())) - if _, err := io.ReadFull(crand.Reader, encBlob); err != nil { - t.Fatalf("unable to generate encrypted blob: %v", err) - } + _, err = io.ReadFull(crand.Reader, encBlob) + require.NoError(t, err) return &wtdb.CommittedUpdate{ SeqNum: seqNum, diff --git a/watchtower/wtdb/codec_test.go b/watchtower/wtdb/codec_test.go index 7842b13bc..c2628b86a 100644 --- a/watchtower/wtdb/codec_test.go +++ b/watchtower/wtdb/codec_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/stretchr/testify/require" ) func randPubKey() (*btcec.PublicKey, error) { @@ -134,10 +135,7 @@ func TestCodec(tt *testing.T) { // Ensure encoding the object succeeds. var b bytes.Buffer err := obj.Encode(&b) - if err != nil { - t.Fatalf("unable to encode: %v", err) - return false - } + require.NoError(t, err) var obj2 dbObject switch obj.(type) { @@ -162,17 +160,10 @@ func TestCodec(tt *testing.T) { // Ensure decoding the object succeeds. err = obj2.Decode(bytes.NewReader(b.Bytes())) - if err != nil { - t.Fatalf("unable to decode: %v", err) - return false - } + require.NoError(t, err) // Assert the original and decoded object match. - if !reflect.DeepEqual(obj, obj2) { - t.Fatalf("encode/decode mismatch, want: %v, "+ - "got: %v", obj, obj2) - return false - } + require.Equal(t, obj, obj2) return true } @@ -180,16 +171,10 @@ func TestCodec(tt *testing.T) { customTypeGen := map[string]func([]reflect.Value, *rand.Rand){ "Tower": func(v []reflect.Value, r *rand.Rand) { pk, err := randPubKey() - if err != nil { - t.Fatalf("unable to generate pubkey: %v", err) - return - } + require.NoError(t, err) addrs, err := randAddrs(r) - if err != nil { - t.Fatalf("unable to generate addrs: %v", err) - return - } + require.NoError(t, err) obj := wtdb.Tower{ IdentityKey: pk, @@ -260,10 +245,7 @@ func TestCodec(tt *testing.T) { } err := quick.Check(test.scenario, config) - if err != nil { - t.Fatalf("fuzz checks for msg=%s failed: %v", - test.name, err) - } + require.NoError(h, err) }) } } diff --git a/watchtower/wtdb/tower_db_test.go b/watchtower/wtdb/tower_db_test.go index 177dbd233..9459f34d3 100644 --- a/watchtower/wtdb/tower_db_test.go +++ b/watchtower/wtdb/tower_db_test.go @@ -3,7 +3,6 @@ package wtdb_test import ( "bytes" "encoding/binary" - "reflect" "testing" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -14,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" + "github.com/stretchr/testify/require" ) var ( @@ -48,10 +48,7 @@ func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) { h.t.Helper() err := h.db.InsertSessionInfo(s) - if err != expErr { - h.t.Fatalf("expected insert session error: %v, got : %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } // getSession retrieves the session identified by id, asserting that the call @@ -62,10 +59,7 @@ func (h *towerDBHarness) getSession(id *wtdb.SessionID, h.t.Helper() session, err := h.db.GetSessionInfo(id) - if err != expErr { - h.t.Fatalf("expected get session error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) return session } @@ -79,10 +73,7 @@ func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate, h.t.Helper() lastApplied, err := h.db.InsertStateUpdate(s) - if err != expErr { - h.t.Fatalf("expected insert update error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) return lastApplied } @@ -93,10 +84,7 @@ func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) { h.t.Helper() err := h.db.DeleteSession(id) - if err != expErr { - h.t.Fatalf("expected deletion error: %v, got: %v", - expErr, err) - } + require.ErrorIs(h.t, err, expErr) } // queryMatches queries that database for the passed breach hint, returning all @@ -105,9 +93,7 @@ func (h *towerDBHarness) queryMatches(hint blob.BreachHint) []wtdb.Match { h.t.Helper() matches, err := h.db.QueryMatches([]blob.BreachHint{hint}) - if err != nil { - h.t.Fatalf("unable to query matches: %v", err) - } + require.NoError(h.t, err) return matches } @@ -119,14 +105,10 @@ func (h *towerDBHarness) hasUpdate(hint blob.BreachHint) wtdb.Match { h.t.Helper() matches := h.queryMatches(hint) - if len(matches) != 1 { - h.t.Fatalf("expected 1 match, found: %d", len(matches)) - } + require.Len(h.t, matches, 1) match := matches[0] - if match.Hint != hint { - h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint) - } + require.Equal(h.t, hint, match.Hint) return match } @@ -158,11 +140,7 @@ func testInsertSession(h *towerDBHarness) { h.insertSession(session, nil) session2 := h.getSession(&id, nil) - - if !reflect.DeepEqual(session, session2) { - h.t.Fatalf("expected session: %v, got %v", - session, session2) - } + require.Equal(h.t, session, session2) h.insertSession(session, nil) @@ -211,28 +189,21 @@ func testMultipleMatches(h *towerDBHarness) { // Query the db for matches on the chosen hint. matches := h.queryMatches(hint) - if len(matches) != numUpdates { - h.t.Fatalf("num updates mismatch, want: %d, got: %d", - numUpdates, len(matches)) - } + require.Len(h.t, matches, numUpdates) // Assert that the hints are what we asked for, and compute the set of // sessions returned. sessions := make(map[wtdb.SessionID]struct{}) for _, match := range matches { - if match.Hint != hint { - h.t.Fatalf("hint mismatch, want: %v, got: %v", - hint, match.Hint) - } + require.Equal(h.t, hint, match.Hint) sessions[match.ID] = struct{}{} } // Assert that the sessions returned match the session ids of the // sessions we initially created. for i := 0; i < numUpdates; i++ { - if _, ok := sessions[*id(i)]; !ok { - h.t.Fatalf("match for session %v not found", *id(i)) - } + _, ok := sessions[*id(i)] + require.Truef(h.t, ok, "match for session %v not found", *id(i)) } } @@ -242,33 +213,22 @@ func testMultipleMatches(h *towerDBHarness) { func testLookoutTip(h *towerDBHarness) { // Retrieve lookout tip on fresh db. epoch, err := h.db.GetLookoutTip() - if err != nil { - h.t.Fatalf("unable to fetch lookout tip: %v", err) - } + require.NoError(h.t, err) // Assert that the epoch is nil. - if epoch != nil { - h.t.Fatalf("lookout tip should not be set, found: %v", epoch) - } + require.Nil(h.t, epoch) // Create a closure that inserts an epoch, retrieves it, and asserts // that the returned epoch matches what was inserted. setAndCheck := func(i int) { expEpoch := epochFromInt(1) err = h.db.SetLookoutTip(expEpoch) - if err != nil { - h.t.Fatalf("unable to set lookout tip: %v", err) - } + require.NoError(h.t, err) epoch, err = h.db.GetLookoutTip() - if err != nil { - h.t.Fatalf("unable to fetch lookout tip: %v", err) - } + require.NoError(h.t, err) - if !reflect.DeepEqual(epoch, expEpoch) { - h.t.Fatalf("lookout tip mismatch, want: %v, got: %v", - expEpoch, epoch) - } + require.Equal(h.t, expEpoch, epoch) } // Set and assert the lookout tip. @@ -348,15 +308,10 @@ func testDeleteSession(h *towerDBHarness) { // Assert that only one update is still present. matches := h.queryMatches(hint) - if len(matches) != 1 { - h.t.Fatalf("expected one update, found: %d", len(matches)) - } + require.Len(h.t, matches, 1) // Assert that the update belongs to the first session. - if matches[0].ID != *id0 { - h.t.Fatalf("expected match for %v, instead is for: %v", - *id0, matches[0].ID) - } + require.Equal(h.t, *id0, matches[0].ID) // Finally, remove the first session added. h.deleteSession(*id0, nil) @@ -366,9 +321,7 @@ func testDeleteSession(h *towerDBHarness) { // No matches should exist for this hint. matches = h.queryMatches(hint) - if len(matches) != 0 { - h.t.Fatalf("expected zero updates, found: %d", len(matches)) - } + require.Zero(h.t, len(matches)) } type stateUpdateTest struct { @@ -403,10 +356,9 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) { *expSession = *test.session } - if len(test.updates) != len(test.updateErrs) { - h.t.Fatalf("malformed test case, num updates " + - "should match num errors") - } + require.Lenf(h.t, test.updates, len(test.updateErrs), + "malformed test case, num updates should match num "+ + "errors") // Send any updates provided in the test. for i, update := range test.updates { @@ -430,10 +382,7 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) { expSession.ClientLastApplied = update.LastApplied match := h.hasUpdate(update.Hint) - if !reflect.DeepEqual(match.SessionInfo, expSession) { - h.t.Fatalf("expected session: %v, got: %v", - expSession, match.SessionInfo) - } + require.Equal(h.t, expSession, match.SessionInfo) } } } @@ -640,14 +589,10 @@ func TestTowerDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, path, "watchtower.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenTowerDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() @@ -664,14 +609,10 @@ func TestTowerDB(t *testing.T) { bdb, err := wtdb.NewBoltBackendCreator( true, path, "watchtower.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err := wtdb.OpenTowerDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db.Close() // Open the db again, ensuring we test a @@ -680,14 +621,10 @@ func TestTowerDB(t *testing.T) { bdb, err = wtdb.NewBoltBackendCreator( true, path, "watchtower.db", )(dbCfg) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) db, err = wtdb.OpenTowerDB(bdb) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } + require.NoError(t, err) t.Cleanup(func() { db.Close() From e150bb83d1d94928fc2ad8b9fd23ee8a5c4f337e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 15:08:18 +0200 Subject: [PATCH 3/8] watchtower/wtdb: check tower exists on session create Before creating a new session, first check that the TowerID that the ClientSession is referencing refers to an existing tower. This is done to prevent the situation where RemoveTower is called right before CreateClientSession is called which would, before this commit, lead to the session being created with a tower ID that does not refer to any existing tower. --- watchtower/wtdb/client_db.go | 11 ++++++ watchtower/wtdb/client_db_test.go | 63 ++++++++++++++++++------------- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index c868b949f..7f4d7d6f5 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -574,6 +574,11 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { return ErrUninitializedDB } + towers := tx.ReadBucket(cTowerBkt) + if towers == nil { + return ErrUninitializedDB + } + // Check that client session with this session id doesn't // already exist. existingSessionBytes := sessions.NestedReadWriteBucket( @@ -583,7 +588,13 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { return ErrClientSessionAlreadyExists } + // Ensure that a tower with the given ID actually exists in the + // DB. towerID := session.TowerID + if _, err := getTower(towers, towerID.Bytes()); err != nil { + return err + } + blobType := session.Policy.BlobType // Check that this tower has a reserved key index. diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index b0e4af6a3..d4f1699c9 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -17,6 +17,9 @@ import ( "github.com/stretchr/testify/require" ) +// pseudoAddr is a fake network address to be used for testing purposes. +var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} + // clientDBInit is a closure used to initialize a wtclient.DB instance. type clientDBInit func(t *testing.T) wtclient.DB @@ -189,6 +192,21 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, require.ErrorIs(h.t, err, expErr) } +// newTower is a helper function that creates a new tower with a randomly +// generated public key and inserts it into the client DB. +func (h *clientDBHarness) newTower() *wtdb.Tower { + h.t.Helper() + + pk, err := randPubKey() + require.NoError(h.t, err) + + // Insert a random tower into the database. + return h.createTower(&lnwire.NetAddress{ + IdentityKey: pk, + Address: pseudoAddr, + }, nil) +} + // testCreateClientSession asserts various conditions regarding the creation of // a new ClientSession. The test asserts: // - client sessions can only be created if a session key index is reserved. @@ -197,10 +215,12 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, func testCreateClientSession(h *clientDBHarness) { const blobType = blob.TypeAltruistAnchorCommit + tower := h.newTower() + // Create a test client session to insert. session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: wtdb.TowerID(3), + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, @@ -265,15 +285,12 @@ func testFilterClientSessions(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID) for i := 0; i < numSessions; i++ { - towerID := wtdb.TowerID(1) - if i == numSessions-1 { - towerID = wtdb.TowerID(2) - } - keyIndex := h.nextKeyIndex(towerID, blobType) + tower := h.newTower() + keyIndex := h.nextKeyIndex(tower.ID, blobType) sessionID := wtdb.SessionID([33]byte{byte(i)}) h.insertSession(&wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: towerID, + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, @@ -285,8 +302,8 @@ func testFilterClientSessions(h *clientDBHarness) { }, ID: sessionID, }, nil) - towerSessions[towerID] = append( - towerSessions[towerID], sessionID, + towerSessions[tower.ID] = append( + towerSessions[tower.ID], sessionID, ) } @@ -311,19 +328,9 @@ func testCreateTower(h *clientDBHarness) { // Test that loading a tower with an arbitrary tower id fails. h.loadTowerByID(20, wtdb.ErrTowerNotFound) - pk, err := randPubKey() - if err != nil { - h.t.Fatalf("unable to generate pubkey: %v", err) - } - - addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} - lnAddr := &lnwire.NetAddress{ - IdentityKey: pk, - Address: addr1, - } - - // Insert a random tower into the database. - tower := h.createTower(lnAddr, nil) + tower := h.newTower() + require.Len(h.t, tower.LNAddrs(), 1) + towerAddr := tower.LNAddrs()[0] // Load the tower from the database and assert that it matches the tower // we created. @@ -335,7 +342,7 @@ func testCreateTower(h *clientDBHarness) { // Insert the address again into the database. Since the address is the // same, this should result in an unmodified tower record. - towerDupAddr := h.createTower(lnAddr, nil) + towerDupAddr := h.createTower(towerAddr, nil) require.Lenf(h.t, towerDupAddr.Addresses, 1, "duplicate address "+ "should be deduped") @@ -345,7 +352,7 @@ func testCreateTower(h *clientDBHarness) { addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911} lnAddr2 := &lnwire.NetAddress{ - IdentityKey: pk, + IdentityKey: tower.IdentityKey, Address: addr2, } @@ -479,9 +486,11 @@ func testChanSummaries(h *clientDBHarness) { // testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can func testCommitUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit + + tower := h.newTower() session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: wtdb.TowerID(3), + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, @@ -570,10 +579,12 @@ func testCommitUpdate(h *clientDBHarness) { func testAckUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit + tower := h.newTower() + // Create a new session that the updates in this will be tied to. session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: wtdb.TowerID(3), + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, From c60ecaccbf921b86c58a6b27b3f97e39e85d2b1f Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 15:18:40 +0200 Subject: [PATCH 4/8] watchtower: always populate Tower in ClientSession In this commit, we make sure to always populate the Tower member of a ClientSession. This is done for consistency. --- watchtower/wtclient/client.go | 10 ++-------- watchtower/wtdb/client_db.go | 35 ++++++++++++++++++++++++---------- watchtower/wtmock/client_db.go | 1 + 3 files changed, 28 insertions(+), 18 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 720bdcca7..436905f76 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -354,8 +354,8 @@ func New(config *Config) (*TowerClient, error) { // optional filter can be provided to filter out any undesired client sessions. // // NOTE: This method should only be used when deserialization of a -// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the -// existing ListClientSessions method should be used. +// ClientSession's SessionPrivKey field is desired, otherwise, the existing +// ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, passesFilter func(*wtdb.ClientSession) bool) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { @@ -371,12 +371,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, // requests. This prevents us from having to store the private keys on // disk. for _, s := range sessions { - tower, err := db.LoadTowerByID(s.TowerID) - if err != nil { - return nil, err - } - s.Tower = tower - towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{ Family: keychain.KeyFamilyTowerSession, Index: s.KeyIndex, diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 7f4d7d6f5..9d8383cb5 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -288,7 +288,7 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { } towerID := TowerIDFromBytes(towerIDBytes) towerSessions, err := listClientSessions( - sessions, &towerID, + sessions, towers, &towerID, ) if err != nil { return err @@ -389,7 +389,9 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { return ErrUninitializedDB } towerID := TowerIDFromBytes(towerIDBytes) - towerSessions, err := listClientSessions(sessions, &towerID) + towerSessions, err := listClientSessions( + sessions, towers, &towerID, + ) if err != nil { return err } @@ -685,8 +687,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( if sessions == nil { return ErrUninitializedDB } + + towers := tx.ReadBucket(cTowerBkt) + if towers == nil { + return ErrUninitializedDB + } + var err error - clientSessions, err = listClientSessions(sessions, id) + clientSessions, err = listClientSessions(sessions, towers, id) return err }, func() { clientSessions = nil @@ -701,7 +709,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( // listClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func listClientSessions(sessions kvdb.RBucket, +func listClientSessions(sessions, towers kvdb.RBucket, id *TowerID) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) @@ -710,7 +718,7 @@ func listClientSessions(sessions kvdb.RBucket, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessions, k) + session, err := getClientSession(sessions, towers, k) if err != nil { return err } @@ -1022,8 +1030,8 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, // getClientSessionBody loads the body of a ClientSession from the sessions // bucket corresponding to the serialized session id. This does not deserialize -// the CommittedUpdates or AckUpdates associated with the session. If the caller -// requires this info, use getClientSession. +// the CommittedUpdates, AckUpdates or the Tower associated with the session. +// If the caller requires this info, use getClientSession. func getClientSessionBody(sessions kvdb.RBucket, idBytes []byte) (*ClientSession, error) { @@ -1050,9 +1058,9 @@ func getClientSessionBody(sessions kvdb.RBucket, } // getClientSession loads the full ClientSession associated with the serialized -// session id. This method populates the CommittedUpdates and AckUpdates in -// addition to the ClientSession's body. -func getClientSession(sessions kvdb.RBucket, +// session id. This method populates the CommittedUpdates, AckUpdates and Tower +// in addition to the ClientSession's body. +func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte) (*ClientSession, error) { session, err := getClientSessionBody(sessions, idBytes) @@ -1060,6 +1068,12 @@ func getClientSession(sessions kvdb.RBucket, return nil, err } + // Fetch the tower associated with this session. + tower, err := getTower(towers, session.TowerID.Bytes()) + if err != nil { + return nil, err + } + // Fetch the committed updates for this session. commitedUpdates, err := getClientSessionCommits(sessions, idBytes) if err != nil { @@ -1072,6 +1086,7 @@ func getClientSession(sessions kvdb.RBucket, return nil, err } + session.Tower = tower session.CommittedUpdates = commitedUpdates session.AckedUpdates = ackedUpdates diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 28dafd04c..2a3825e87 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -220,6 +220,7 @@ func (m *ClientDB) listClientSessions( if tower != nil && *tower != session.TowerID { continue } + session.Tower = m.towers[session.TowerID] sessions[session.ID] = &session } From 354a3b16bd12eb7f7577b002a50b6583785d8eec Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 15:49:17 +0200 Subject: [PATCH 5/8] watchtower/wtdb: add new towerID-to-sessionID index This commit adds a new towerID-to-sessionID index to the wtclient DB. The commit also contains the necessary migration required in order to build the index for an existing client. This index will greatly improve the lookup of sessions for a given tower ID. --- watchtower/wtdb/client_db.go | 58 ++++- watchtower/wtdb/log.go | 2 + watchtower/wtdb/migration1/client_db.go | 145 +++++++++++ watchtower/wtdb/migration1/client_db_test.go | 155 ++++++++++++ watchtower/wtdb/migration1/codec.go | 241 +++++++++++++++++++ watchtower/wtdb/migration1/log.go | 14 ++ watchtower/wtdb/version.go | 7 +- 7 files changed, 620 insertions(+), 2 deletions(-) create mode 100644 watchtower/wtdb/migration1/client_db.go create mode 100644 watchtower/wtdb/migration1/client_db_test.go create mode 100644 watchtower/wtdb/migration1/codec.go create mode 100644 watchtower/wtdb/migration1/log.go diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 9d8383cb5..537c8cc73 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -48,6 +48,12 @@ var ( // tower-pubkey -> tower-id. cTowerIndexBkt = []byte("client-tower-index-bucket") + // cTowerToSessionIndexBkt is a top-level bucket storing: + // tower-id -> session-id -> 1 + cTowerToSessionIndexBkt = []byte( + "client-tower-to-session-index-bucket", + ) + // ErrTowerNotFound signals that the target tower was not found in the // database. ErrTowerNotFound = errors.New("tower not found") @@ -196,6 +202,7 @@ func initClientDBBuckets(tx kvdb.RwTx) error { cSessionBkt, cTowerBkt, cTowerIndexBkt, + cTowerToSessionIndexBkt, } for _, bucket := range buckets { @@ -260,6 +267,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { return ErrUninitializedDB } + towerToSessionIndex := tx.ReadWriteBucket( + cTowerToSessionIndexBkt, + ) + if towerToSessionIndex == nil { + return ErrUninitializedDB + } + // Check if the tower index already knows of this pubkey. towerIDBytes := towerIndex.Get(towerPubKey[:]) if len(towerIDBytes) == 8 { @@ -321,6 +335,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { if err != nil { return err } + + // Create a new bucket for this tower in the + // tower-to-sessions index. + _, err = towerToSessionIndex.CreateBucket(towerIDBytes) + if err != nil { + return err + } } // Store the new or updated tower under its tower id. @@ -349,11 +370,19 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { if towers == nil { return ErrUninitializedDB } + towerIndex := tx.ReadWriteBucket(cTowerIndexBkt) if towerIndex == nil { return ErrUninitializedDB } + towersToSessionsIndex := tx.ReadWriteBucket( + cTowerToSessionIndexBkt, + ) + if towersToSessionsIndex == nil { + return ErrUninitializedDB + } + // Don't return an error if the watchtower doesn't exist to act // as a NOP. pubKeyBytes := pubKey.SerializeCompressed() @@ -402,7 +431,14 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { if err := towerIndex.Delete(pubKeyBytes); err != nil { return err } - return towers.Delete(towerIDBytes) + + if err := towers.Delete(towerIDBytes); err != nil { + return err + } + + return towersToSessionsIndex.DeleteNestedBucket( + towerIDBytes, + ) } // We'll mark its sessions as inactive as long as they don't @@ -581,6 +617,13 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { return ErrUninitializedDB } + towerToSessionIndex := tx.ReadWriteBucket( + cTowerToSessionIndexBkt, + ) + if towerToSessionIndex == nil { + return ErrUninitializedDB + } + // Check that client session with this session id doesn't // already exist. existingSessionBytes := sessions.NestedReadWriteBucket( @@ -625,6 +668,19 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { } } + // Add the new entry to the towerID-to-SessionID index. + indexBkt := towerToSessionIndex.NestedReadWriteBucket( + towerID.Bytes(), + ) + if indexBkt == nil { + return ErrTowerNotFound + } + + err = indexBkt.Put(session.ID[:], []byte{1}) + if err != nil { + return err + } + // Finally, write the client session's body in the sessions // bucket. return putClientSessionBody(sessions, session) diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go index 0e14ea996..6ddb6c35f 100644 --- a/watchtower/wtdb/log.go +++ b/watchtower/wtdb/log.go @@ -3,6 +3,7 @@ package wtdb import ( "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration1" ) // log is a logger that is initialized with no output filters. This @@ -26,6 +27,7 @@ func DisableLog() { // using btclog. func UseLogger(logger btclog.Logger) { log = logger + migration1.UseLogger(logger) } // logClosure is used to provide a closure over expensive logging operations so diff --git a/watchtower/wtdb/migration1/client_db.go b/watchtower/wtdb/migration1/client_db.go new file mode 100644 index 000000000..d09ef6ef7 --- /dev/null +++ b/watchtower/wtdb/migration1/client_db.go @@ -0,0 +1,145 @@ +package migration1 + +import ( + "bytes" + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAcks => seqnum -> encoded BackupID + cSessionBkt = []byte("client-session-bucket") + + // cSessionBody is a sub-bucket of cSessionBkt storing only the body of + // the ClientSession. + cSessionBody = []byte("client-session-body") + + // cTowerIDToSessionIDIndexBkt is a top-level bucket storing: + // tower-id -> session-id -> 1 + cTowerIDToSessionIDIndexBkt = []byte( + "client-tower-to-session-index-bucket", + ) + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + // ErrClientSessionNotFound signals that the requested client session + // was not found in the database. + ErrClientSessionNotFound = errors.New("client session not found") + + // ErrCorruptClientSession signals that the client session's on-disk + // structure deviates from what is expected. + ErrCorruptClientSession = errors.New("client session corrupted") +) + +// MigrateTowerToSessionIndex constructs a new towerID-to-sessionID for the +// watchtower client DB. +func MigrateTowerToSessionIndex(tx kvdb.RwTx) error { + log.Infof("Migrating the tower client db to add a " + + "towerID-to-sessionID index") + + // First, we collect all the entries we want to add to the index. + entries, err := getIndexEntries(tx) + if err != nil { + return err + } + + // Then we create a new top-level bucket for the index. + indexBkt, err := tx.CreateTopLevelBucket(cTowerIDToSessionIDIndexBkt) + if err != nil { + return err + } + + // Finally, we add all the collected entries to the index. + for towerID, sessions := range entries { + // Create a sub-bucket using the tower ID. + towerBkt, err := indexBkt.CreateBucketIfNotExists( + towerID.Bytes(), + ) + if err != nil { + return err + } + + for sessionID := range sessions { + err := addIndex(towerBkt, sessionID) + if err != nil { + return err + } + } + } + + return nil +} + +// addIndex adds a new towerID-sessionID pair to the given bucket. The +// session ID is used as a key within the bucket and a value of []byte{1} is +// used for each session ID key. +func addIndex(towerBkt kvdb.RwBucket, sessionID SessionID) error { + session := towerBkt.Get(sessionID[:]) + if session != nil { + return fmt.Errorf("session %x duplicated", sessionID) + } + + return towerBkt.Put(sessionID[:], []byte{1}) +} + +// getIndexEntries collects all the towerID-sessionID entries that need to be +// added to the new index. +func getIndexEntries(tx kvdb.RwTx) (map[TowerID]map[SessionID]bool, error) { + sessions := tx.ReadBucket(cSessionBkt) + if sessions == nil { + return nil, ErrUninitializedDB + } + + index := make(map[TowerID]map[SessionID]bool) + err := sessions.ForEach(func(k, _ []byte) error { + session, err := getClientSession(sessions, k) + if err != nil { + return err + } + + if index[session.TowerID] == nil { + index[session.TowerID] = make(map[SessionID]bool) + } + + index[session.TowerID][session.ID] = true + return nil + }) + if err != nil { + return nil, err + } + + return index, nil +} + +// getClientSession fetches the session with the given ID from the db. +func getClientSession(sessions kvdb.RBucket, idBytes []byte) (*ClientSession, + error) { + + sessionBkt := sessions.NestedReadBucket(idBytes) + if sessionBkt == nil { + return nil, ErrClientSessionNotFound + } + + // Should never have a sessionBkt without also having its body. + sessionBody := sessionBkt.Get(cSessionBody) + if sessionBody == nil { + return nil, ErrCorruptClientSession + } + + var session ClientSession + copy(session.ID[:], idBytes) + + err := session.Decode(bytes.NewReader(sessionBody)) + if err != nil { + return nil, err + } + + return &session, nil +} diff --git a/watchtower/wtdb/migration1/client_db_test.go b/watchtower/wtdb/migration1/client_db_test.go new file mode 100644 index 000000000..acae177ad --- /dev/null +++ b/watchtower/wtdb/migration1/client_db_test.go @@ -0,0 +1,155 @@ +package migration1 + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + s1 = &ClientSessionBody{ + TowerID: TowerID(1), + } + s2 = &ClientSessionBody{ + TowerID: TowerID(3), + } + s3 = &ClientSessionBody{ + TowerID: TowerID(6), + } + + // pre is the expected data in the DB before the migration. + pre = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("2"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s3), + }, + sessionIDString("3"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("4"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("5"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s2), + }, + } + + // preFailNoSessionBody should fail the migration due to there being a + // session without an associated session body. + preFailNoSessionBody = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("2"): map[string]interface{}{}, + } + + // post is the expected data after migration. + post = map[string]interface{}{ + towerIDString(1): map[string]interface{}{ + sessionIDString("1"): string([]byte{1}), + sessionIDString("3"): string([]byte{1}), + sessionIDString("4"): string([]byte{1}), + }, + towerIDString(3): map[string]interface{}{ + sessionIDString("5"): string([]byte{1}), + }, + towerIDString(6): map[string]interface{}{ + sessionIDString("2"): string([]byte{1}), + }, + } +) + +// TestMigrateTowerToSessionIndex tests that the TestMigrateTowerToSessionIndex +// function correctly adds a new towerID-to-sessionID index to the tower client +// db. +func TestMigrateTowerToSessionIndex(t *testing.T) { + tests := []struct { + name string + shouldFail bool + pre map[string]interface{} + post map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + pre: pre, + post: post, + }, + { + name: "fail due to corrupt db", + shouldFail: true, + pre: preFailNoSessionBody, + post: nil, + }, + { + name: "no sessions", + shouldFail: false, + pre: nil, + post: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + // Before the migration we have a sessions bucket. + before := func(tx kvdb.RwTx) error { + return migtest.RestoreDB( + tx, cSessionBkt, test.pre, + ) + } + + // After the migration, we should have an untouched + // sessions bucket and a new index bucket. + after := func(tx kvdb.RwTx) error { + if err := migtest.VerifyDB( + tx, cSessionBkt, test.pre, + ); err != nil { + return err + } + + // If we expect our migration to fail, we don't + // expect an index bucket. + if test.shouldFail { + return nil + } + + return migtest.VerifyDB( + tx, cTowerIDToSessionIDIndexBkt, + test.post, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateTowerToSessionIndex, + test.shouldFail, + ) + }) + } +} + +func sessionIDString(id string) string { + var sessID SessionID + copy(sessID[:], id) + return string(sessID[:]) +} + +func clientSessionString(s *ClientSessionBody) string { + var b bytes.Buffer + err := s.Encode(&b) + if err != nil { + panic(err) + } + + return b.String() +} + +func towerIDString(id int) string { + towerID := TowerID(id) + return string(towerID.Bytes()) +} diff --git a/watchtower/wtdb/migration1/codec.go b/watchtower/wtdb/migration1/codec.go new file mode 100644 index 000000000..8c5a2299c --- /dev/null +++ b/watchtower/wtdb/migration1/codec.go @@ -0,0 +1,241 @@ +package migration1 + +import ( + "encoding/binary" + "io" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" +) + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// UnknownElementType is an alias for channeldb.UnknownElementType. +type UnknownElementType = channeldb.UnknownElementType + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// TowerID is a unique 64-bit identifier allocated to each unique watchtower. +// This allows the client to conserve on-disk space by not needing to always +// reference towers by their pubkey. +type TowerID uint64 + +// Bytes encodes a TowerID into an 8-byte slice in big-endian byte order. +func (id TowerID) Bytes() []byte { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], uint64(id)) + return buf[:] +} + +// ClientSession encapsulates a SessionInfo returned from a successful +// session negotiation, and also records the tower and ephemeral secret used for +// communicating with the tower. +type ClientSession struct { + // ID is the client's public key used when authenticating with the + // tower. + ID SessionID + ClientSessionBody +} + +// CSessionStatus is a bit-field representing the possible statuses of +// ClientSessions. +type CSessionStatus uint8 + +type ClientSessionBody struct { + // SeqNum is the next unallocated sequence number that can be sent to + // the tower. + SeqNum uint16 + + // TowerLastApplied the last last-applied the tower has echoed back. + TowerLastApplied uint16 + + // TowerID is the unique, db-assigned identifier that references the + // Tower with which the session is negotiated. + TowerID TowerID + + // KeyIndex is the index of key locator used to derive the client's + // session key so that it can authenticate with the tower to update its + // session. In order to rederive the private key, the key locator should + // use the keychain.KeyFamilyTowerSession key family. + KeyIndex uint32 + + // Policy holds the negotiated session parameters. + Policy wtpolicy.Policy + + // Status indicates the current state of the ClientSession. + Status CSessionStatus + + // RewardPkScript is the pkscript that the tower's reward will be + // deposited to if a sweep transaction confirms and the sessions + // specifies a reward output. + RewardPkScript []byte +} + +// Encode writes a ClientSessionBody to the passed io.Writer. +func (s *ClientSessionBody) Encode(w io.Writer) error { + return WriteElements(w, + s.SeqNum, + s.TowerLastApplied, + uint64(s.TowerID), + s.KeyIndex, + uint8(s.Status), + s.Policy, + s.RewardPkScript, + ) +} + +// Decode reads a ClientSessionBody from the passed io.Reader. +func (s *ClientSessionBody) Decode(r io.Reader) error { + var ( + towerID uint64 + status uint8 + ) + err := ReadElements(r, + &s.SeqNum, + &s.TowerLastApplied, + &towerID, + &s.KeyIndex, + &status, + &s.Policy, + &s.RewardPkScript, + ) + if err != nil { + return err + } + + s.TowerID = TowerID(towerID) + s.Status = CSessionStatus(status) + + return nil +} + +// WriteElements serializes a variadic list of elements into the given +// io.Writer. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + if err := WriteElement(w, element); err != nil { + return err + } + } + + return nil +} + +// WriteElement serializes a single element into the provided io.Writer. +func WriteElement(w io.Writer, element interface{}) error { + err := channeldb.WriteElement(w, element) + switch { + // Known to channeldb codec. + case err == nil: + return nil + + // Fail if error is not UnknownElementType. + default: + if _, ok := err.(UnknownElementType); !ok { + return err + } + } + + // Process any wtdb-specific extensions to the codec. + switch e := element.(type) { + case SessionID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case blob.BreachHint: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case wtpolicy.Policy: + return channeldb.WriteElements(w, + uint16(e.BlobType), + e.MaxUpdates, + e.RewardBase, + e.RewardRate, + uint64(e.SweepFeeRate), + ) + + // Type is still unknown to wtdb extensions, fail. + default: + return channeldb.NewUnknownElementType( + "WriteElement", element, + ) + } + + return nil +} + +// ReadElements deserializes the provided io.Reader into a variadic list of +// target elements. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + if err := ReadElement(r, element); err != nil { + return err + } + } + + return nil +} + +// ReadElement deserializes a single element from the provided io.Reader. +func ReadElement(r io.Reader, element interface{}) error { + err := channeldb.ReadElement(r, element) + switch { + // Known to channeldb codec. + case err == nil: + return nil + + // Fail if error is not UnknownElementType. + default: + if _, ok := err.(UnknownElementType); !ok { + return err + } + } + + // Process any wtdb-specific extensions to the codec. + switch e := element.(type) { + case *SessionID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *blob.BreachHint: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wtpolicy.Policy: + var ( + blobType uint16 + sweepFeeRate uint64 + ) + err := channeldb.ReadElements(r, + &blobType, + &e.MaxUpdates, + &e.RewardBase, + &e.RewardRate, + &sweepFeeRate, + ) + if err != nil { + return err + } + + e.BlobType = blob.Type(blobType) + e.SweepFeeRate = chainfee.SatPerKWeight(sweepFeeRate) + + // Type is still unknown to wtdb extensions, fail. + default: + return channeldb.NewUnknownElementType( + "ReadElement", element, + ) + } + + return nil +} diff --git a/watchtower/wtdb/migration1/log.go b/watchtower/wtdb/migration1/log.go new file mode 100644 index 000000000..1dc105280 --- /dev/null +++ b/watchtower/wtdb/migration1/log.go @@ -0,0 +1,14 @@ +package migration1 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index 229b8a9dd..4785b0ae2 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -3,6 +3,7 @@ package wtdb import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration1" ) // migration is a function which takes a prior outdated version of the database @@ -24,7 +25,11 @@ var towerDBVersions = []version{} // clientDBVersions stores all versions and migrations of the client database. // This list will be used when opening the database to determine if any // migrations must be applied. -var clientDBVersions = []version{} +var clientDBVersions = []version{ + { + migration: migration1.MigrateTowerToSessionIndex, + }, +} // getLatestDBVersion returns the last known database version. func getLatestDBVersion(versions []version) uint32 { From ecd2eb965a75fd1a0dc55fe533f5ac8ebe529209 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 16:06:53 +0200 Subject: [PATCH 6/8] watchtower: make use of the new tower-to-session index In this commit, the towerID-to-sessionID index added in the previous commit is put to use in order to make session lookup more efficient in certain places. In the process, 2 TODO's are also removed from the code. --- watchtower/wtdb/client_db.go | 102 +++++++++++++++++++++++++---------- 1 file changed, 73 insertions(+), 29 deletions(-) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 537c8cc73..3cb5a8c70 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -293,27 +293,32 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { // If there are any client sessions that correspond to // this tower, we'll mark them as active to ensure we // load them upon restarts. - // - // TODO(wilmer): with an index of tower -> sessions we - // can avoid the linear lookup. + towerSessIndex := towerToSessionIndex.NestedReadBucket( + tower.ID.Bytes(), + ) + if towerSessIndex == nil { + return ErrTowerNotFound + } + sessions := tx.ReadWriteBucket(cSessionBkt) if sessions == nil { return ErrUninitializedDB } - towerID := TowerIDFromBytes(towerIDBytes) - towerSessions, err := listClientSessions( - sessions, towers, &towerID, - ) - if err != nil { - return err - } - for _, session := range towerSessions { - err := markSessionStatus( - sessions, session, CSessionActive, + + err = towerSessIndex.ForEach(func(k, _ []byte) error { + session, err := getClientSessionBody( + sessions, k, ) if err != nil { return err } + + return markSessionStatus( + sessions, session, CSessionActive, + ) + }) + if err != nil { + return err } } else { // No such tower exists, create a new tower id for our @@ -410,16 +415,13 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { // Otherwise, we should attempt to mark the tower's sessions as // inactive. - // - // TODO(wilmer): with an index of tower -> sessions we can avoid - // the linear lookup. sessions := tx.ReadWriteBucket(cSessionBkt) if sessions == nil { return ErrUninitializedDB } towerID := TowerIDFromBytes(towerIDBytes) - towerSessions, err := listClientSessions( - sessions, towers, &towerID, + towerSessions, err := listTowerSessions( + towerID, sessions, towers, towersToSessionsIndex, ) if err != nil { return err @@ -750,7 +752,25 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( } var err error - clientSessions, err = listClientSessions(sessions, towers, id) + + // If no tower ID is specified, then fetch all the sessions + // known to the db. + if id == nil { + clientSessions, err = listClientAllSessions( + sessions, towers, + ) + return err + } + + // Otherwise, fetch the sessions for the given tower. + towerToSessionIndex := tx.ReadBucket(cTowerToSessionIndexBkt) + if towerToSessionIndex == nil { + return ErrUninitializedDB + } + + clientSessions, err = listTowerSessions( + *id, sessions, towers, towerToSessionIndex, + ) return err }, func() { clientSessions = nil @@ -762,11 +782,9 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( return clientSessions, nil } -// listClientSessions returns the set of all client sessions known to the db. An -// optional tower ID can be used to filter out any client sessions in the -// response that do not correspond to this tower. -func listClientSessions(sessions, towers kvdb.RBucket, - id *TowerID) (map[SessionID]*ClientSession, error) { +// listClientAllSessions returns the set of all client sessions known to the db. +func listClientAllSessions(sessions, + towers kvdb.RBucket) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) err := sessions.ForEach(func(k, _ []byte) error { @@ -779,14 +797,40 @@ func listClientSessions(sessions, towers kvdb.RBucket, return err } - // Filter out any sessions that don't correspond to the given - // tower if one was set. - if id != nil && session.TowerID != *id { - return nil + clientSessions[session.ID] = session + + return nil + }) + if err != nil { + return nil, err + } + + return clientSessions, nil +} + +// listTowerSessions returns the set of all client sessions known to the db +// that are associated with the given tower id. +func listTowerSessions(id TowerID, sessionsBkt, towersBkt, + towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession, + error) { + + towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes()) + if towerIndexBkt == nil { + return nil, ErrTowerNotFound + } + + clientSessions := make(map[SessionID]*ClientSession) + err := towerIndexBkt.ForEach(func(k, _ []byte) error { + // We'll load the full client session since the client will need + // the CommittedUpdates and AckedUpdates on startup to resume + // committed updates and compute the highest known commit height + // for each channel. + session, err := getClientSession(sessionsBkt, towersBkt, k) + if err != nil { + return err } clientSessions[session.ID] = session - return nil }) if err != nil { From 105c44df9b0c35a182d094a523868580abf499b6 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 16:26:36 +0200 Subject: [PATCH 7/8] watchtower: use more efficient session query on startup In this commit, the functions used to fetch candidate sessions and towers on creation of the watchtower Client are changed to make use of the more efficient lookup functions. Previously, all sessions were listed from the DB and then these were used to collect the active towers which in certain situations lead to some users getting the "tower not found" error on start up. With this commit, we instead first list all Towers in the DB and then we fetch the sessions for each of those towers. --- watchtower/wtclient/client.go | 80 +++++++++++++++++++++++++++++------ 1 file changed, 68 insertions(+), 12 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 436905f76..3d23f0b82 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -287,26 +287,33 @@ func New(config *Config) (*TowerClient, error) { } plog := build.NewPrefixLog(prefix, log) - // Next, load all candidate sessions and towers from the database into - // the client. We will use any of these session if their policies match + // Next, load all candidate towers and sessions from the database into + // the client. We will use any of these sessions if their policies match // the current policy of the client, otherwise they will be ignored and // new sessions will be requested. isAnchorClient := cfg.Policy.IsAnchorChannel() activeSessionFilter := genActiveSessionFilter(isAnchorClient) - candidateSessions, err := getClientSessions( - cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter, + candidateTowers := newTowerListIterator() + perActiveTower := func(tower *wtdb.Tower) { + // If the tower has already been marked as active, then there is + // no need to add it to the iterator again. + if candidateTowers.IsActive(tower.ID) { + return + } + + log.Infof("Using private watchtower %s, offering policy %s", + tower, cfg.Policy) + + // Add the tower to the set of candidate towers. + candidateTowers.AddCandidate(tower) + } + candidateSessions, err := getTowerAndSessionCandidates( + cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower, ) if err != nil { return nil, err } - var candidateTowers []*wtdb.Tower - for _, s := range candidateSessions { - plog.Infof("Using private watchtower %s, offering policy %s", - s.Tower, cfg.Policy) - candidateTowers = append(candidateTowers, s.Tower) - } - // Load the sweep pkscripts that have been generated for all previously // registered channels. chanSummaries, err := cfg.DB.FetchChanSummaries() @@ -318,7 +325,7 @@ func New(config *Config) (*TowerClient, error) { cfg: cfg, log: plog, pipeline: newTaskPipeline(plog), - candidateTowers: newTowerListIterator(candidateTowers...), + candidateTowers: candidateTowers, candidateSessions: candidateSessions, activeSessions: make(sessionQueueSet), summaries: chanSummaries, @@ -349,6 +356,55 @@ func New(config *Config) (*TowerClient, error) { return c, nil } +// getTowerAndSessionCandidates loads all the towers from the DB and then +// fetches the sessions for each of tower. Sessions are only collected if they +// pass the sessionFilter check. If a tower has a session that does pass the +// sessionFilter check then the perActiveTower call-back will be called on that +// tower. +func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, + sessionFilter func(*wtdb.ClientSession) bool, + perActiveTower func(tower *wtdb.Tower)) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { + + towers, err := db.ListTowers() + if err != nil { + return nil, err + } + + candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) + for _, tower := range towers { + sessions, err := db.ListClientSessions(&tower.ID) + if err != nil { + return nil, err + } + + for _, s := range sessions { + towerKeyDesc, err := keyRing.DeriveKey( + keychain.KeyLocator{ + Family: keychain.KeyFamilyTowerSession, + Index: s.KeyIndex, + }, + ) + if err != nil { + return nil, err + } + s.SessionKeyECDH = keychain.NewPubKeyECDH( + towerKeyDesc, keyRing, + ) + + if !sessionFilter(s) { + continue + } + + // Add the session to the set of candidate sessions. + candidateSessions[s.ID] = s + perActiveTower(tower) + } + } + + return candidateSessions, nil +} + // getClientSessions retrieves the client sessions for a particular tower if // specified, otherwise all client sessions for all towers are retrieved. An // optional filter can be provided to filter out any undesired client sessions. From 2294e11b7a068152de5e545f24bdd27685ee9aea Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 16:58:36 +0200 Subject: [PATCH 8/8] docs: update release notes --- docs/release-notes/release-notes-0.16.0.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/release-notes/release-notes-0.16.0.md b/docs/release-notes/release-notes-0.16.0.md index 17df7f4d1..402c8ae35 100644 --- a/docs/release-notes/release-notes-0.16.0.md +++ b/docs/release-notes/release-notes-0.16.0.md @@ -102,6 +102,14 @@ crash](https://github.com/lightningnetwork/lnd/pull/7019). * [The `tlv` package now allows decoding records larger than 65535 bytes. The caller is expected to know that doing so with untrusted input is unsafe.](https://github.com/lightningnetwork/lnd/pull/6779) + +## Watchtowers + +* [Create a towerID-to-sessionID index in the wtclient DB to improve the + speed of listing sessions for a particular tower ID]( + https://github.com/lightningnetwork/lnd/pull/6972). This PR also ensures a + closer coupling of Towers and Sessions and ensures that a session cannot be + added if the tower it is referring to does not exist. ### Tooling and documentation