From c60ecaccbf921b86c58a6b27b3f97e39e85d2b1f Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 15:18:40 +0200 Subject: [PATCH] watchtower: always populate Tower in ClientSession In this commit, we make sure to always populate the Tower member of a ClientSession. This is done for consistency. --- watchtower/wtclient/client.go | 10 ++-------- watchtower/wtdb/client_db.go | 35 ++++++++++++++++++++++++---------- watchtower/wtmock/client_db.go | 1 + 3 files changed, 28 insertions(+), 18 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 720bdcca7..436905f76 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -354,8 +354,8 @@ func New(config *Config) (*TowerClient, error) { // optional filter can be provided to filter out any undesired client sessions. // // NOTE: This method should only be used when deserialization of a -// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the -// existing ListClientSessions method should be used. +// ClientSession's SessionPrivKey field is desired, otherwise, the existing +// ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, passesFilter func(*wtdb.ClientSession) bool) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { @@ -371,12 +371,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, // requests. This prevents us from having to store the private keys on // disk. for _, s := range sessions { - tower, err := db.LoadTowerByID(s.TowerID) - if err != nil { - return nil, err - } - s.Tower = tower - towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{ Family: keychain.KeyFamilyTowerSession, Index: s.KeyIndex, diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 7f4d7d6f5..9d8383cb5 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -288,7 +288,7 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { } towerID := TowerIDFromBytes(towerIDBytes) towerSessions, err := listClientSessions( - sessions, &towerID, + sessions, towers, &towerID, ) if err != nil { return err @@ -389,7 +389,9 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { return ErrUninitializedDB } towerID := TowerIDFromBytes(towerIDBytes) - towerSessions, err := listClientSessions(sessions, &towerID) + towerSessions, err := listClientSessions( + sessions, towers, &towerID, + ) if err != nil { return err } @@ -685,8 +687,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( if sessions == nil { return ErrUninitializedDB } + + towers := tx.ReadBucket(cTowerBkt) + if towers == nil { + return ErrUninitializedDB + } + var err error - clientSessions, err = listClientSessions(sessions, id) + clientSessions, err = listClientSessions(sessions, towers, id) return err }, func() { clientSessions = nil @@ -701,7 +709,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( // listClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func listClientSessions(sessions kvdb.RBucket, +func listClientSessions(sessions, towers kvdb.RBucket, id *TowerID) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) @@ -710,7 +718,7 @@ func listClientSessions(sessions kvdb.RBucket, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessions, k) + session, err := getClientSession(sessions, towers, k) if err != nil { return err } @@ -1022,8 +1030,8 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, // getClientSessionBody loads the body of a ClientSession from the sessions // bucket corresponding to the serialized session id. This does not deserialize -// the CommittedUpdates or AckUpdates associated with the session. If the caller -// requires this info, use getClientSession. +// the CommittedUpdates, AckUpdates or the Tower associated with the session. +// If the caller requires this info, use getClientSession. func getClientSessionBody(sessions kvdb.RBucket, idBytes []byte) (*ClientSession, error) { @@ -1050,9 +1058,9 @@ func getClientSessionBody(sessions kvdb.RBucket, } // getClientSession loads the full ClientSession associated with the serialized -// session id. This method populates the CommittedUpdates and AckUpdates in -// addition to the ClientSession's body. -func getClientSession(sessions kvdb.RBucket, +// session id. This method populates the CommittedUpdates, AckUpdates and Tower +// in addition to the ClientSession's body. +func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte) (*ClientSession, error) { session, err := getClientSessionBody(sessions, idBytes) @@ -1060,6 +1068,12 @@ func getClientSession(sessions kvdb.RBucket, return nil, err } + // Fetch the tower associated with this session. + tower, err := getTower(towers, session.TowerID.Bytes()) + if err != nil { + return nil, err + } + // Fetch the committed updates for this session. commitedUpdates, err := getClientSessionCommits(sessions, idBytes) if err != nil { @@ -1072,6 +1086,7 @@ func getClientSession(sessions kvdb.RBucket, return nil, err } + session.Tower = tower session.CommittedUpdates = commitedUpdates session.AckedUpdates = ackedUpdates diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 28dafd04c..2a3825e87 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -220,6 +220,7 @@ func (m *ClientDB) listClientSessions( if tower != nil && *tower != session.TowerID { continue } + session.Tower = m.towers[session.TowerID] sessions[session.ID] = &session }