watchtower: always populate Tower in ClientSession

In this commit, we make sure to always populate the Tower member of a
ClientSession. This is done for consistency.
This commit is contained in:
Elle Mouton 2022-10-04 15:18:40 +02:00
parent e150bb83d1
commit c60ecaccbf
No known key found for this signature in database
GPG Key ID: D7D916376026F177
3 changed files with 28 additions and 18 deletions

View File

@ -354,8 +354,8 @@ func New(config *Config) (*TowerClient, error) {
// optional filter can be provided to filter out any undesired client sessions. // optional filter can be provided to filter out any undesired client sessions.
// //
// NOTE: This method should only be used when deserialization of a // NOTE: This method should only be used when deserialization of a
// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the // ClientSession's SessionPrivKey field is desired, otherwise, the existing
// existing ListClientSessions method should be used. // ListClientSessions method should be used.
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
passesFilter func(*wtdb.ClientSession) bool) ( passesFilter func(*wtdb.ClientSession) bool) (
map[wtdb.SessionID]*wtdb.ClientSession, error) { map[wtdb.SessionID]*wtdb.ClientSession, error) {
@ -371,12 +371,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
// requests. This prevents us from having to store the private keys on // requests. This prevents us from having to store the private keys on
// disk. // disk.
for _, s := range sessions { for _, s := range sessions {
tower, err := db.LoadTowerByID(s.TowerID)
if err != nil {
return nil, err
}
s.Tower = tower
towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{ towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{
Family: keychain.KeyFamilyTowerSession, Family: keychain.KeyFamilyTowerSession,
Index: s.KeyIndex, Index: s.KeyIndex,

View File

@ -288,7 +288,7 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
} }
towerID := TowerIDFromBytes(towerIDBytes) towerID := TowerIDFromBytes(towerIDBytes)
towerSessions, err := listClientSessions( towerSessions, err := listClientSessions(
sessions, &towerID, sessions, towers, &towerID,
) )
if err != nil { if err != nil {
return err return err
@ -389,7 +389,9 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return ErrUninitializedDB return ErrUninitializedDB
} }
towerID := TowerIDFromBytes(towerIDBytes) towerID := TowerIDFromBytes(towerIDBytes)
towerSessions, err := listClientSessions(sessions, &towerID) towerSessions, err := listClientSessions(
sessions, towers, &towerID,
)
if err != nil { if err != nil {
return err return err
} }
@ -685,8 +687,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
towers := tx.ReadBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
var err error var err error
clientSessions, err = listClientSessions(sessions, id) clientSessions, err = listClientSessions(sessions, towers, id)
return err return err
}, func() { }, func() {
clientSessions = nil clientSessions = nil
@ -701,7 +709,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
// listClientSessions returns the set of all client sessions known to the db. An // listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower. // response that do not correspond to this tower.
func listClientSessions(sessions kvdb.RBucket, func listClientSessions(sessions, towers kvdb.RBucket,
id *TowerID) (map[SessionID]*ClientSession, error) { id *TowerID) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession) clientSessions := make(map[SessionID]*ClientSession)
@ -710,7 +718,7 @@ func listClientSessions(sessions kvdb.RBucket,
// the CommittedUpdates and AckedUpdates on startup to resume // the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height // committed updates and compute the highest known commit height
// for each channel. // for each channel.
session, err := getClientSession(sessions, k) session, err := getClientSession(sessions, towers, k)
if err != nil { if err != nil {
return err return err
} }
@ -1022,8 +1030,8 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
// getClientSessionBody loads the body of a ClientSession from the sessions // getClientSessionBody loads the body of a ClientSession from the sessions
// bucket corresponding to the serialized session id. This does not deserialize // bucket corresponding to the serialized session id. This does not deserialize
// the CommittedUpdates or AckUpdates associated with the session. If the caller // the CommittedUpdates, AckUpdates or the Tower associated with the session.
// requires this info, use getClientSession. // If the caller requires this info, use getClientSession.
func getClientSessionBody(sessions kvdb.RBucket, func getClientSessionBody(sessions kvdb.RBucket,
idBytes []byte) (*ClientSession, error) { idBytes []byte) (*ClientSession, error) {
@ -1050,9 +1058,9 @@ func getClientSessionBody(sessions kvdb.RBucket,
} }
// getClientSession loads the full ClientSession associated with the serialized // getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates and AckUpdates in // session id. This method populates the CommittedUpdates, AckUpdates and Tower
// addition to the ClientSession's body. // in addition to the ClientSession's body.
func getClientSession(sessions kvdb.RBucket, func getClientSession(sessions, towers kvdb.RBucket,
idBytes []byte) (*ClientSession, error) { idBytes []byte) (*ClientSession, error) {
session, err := getClientSessionBody(sessions, idBytes) session, err := getClientSessionBody(sessions, idBytes)
@ -1060,6 +1068,12 @@ func getClientSession(sessions kvdb.RBucket,
return nil, err return nil, err
} }
// Fetch the tower associated with this session.
tower, err := getTower(towers, session.TowerID.Bytes())
if err != nil {
return nil, err
}
// Fetch the committed updates for this session. // Fetch the committed updates for this session.
commitedUpdates, err := getClientSessionCommits(sessions, idBytes) commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
if err != nil { if err != nil {
@ -1072,6 +1086,7 @@ func getClientSession(sessions kvdb.RBucket,
return nil, err return nil, err
} }
session.Tower = tower
session.CommittedUpdates = commitedUpdates session.CommittedUpdates = commitedUpdates
session.AckedUpdates = ackedUpdates session.AckedUpdates = ackedUpdates

View File

@ -220,6 +220,7 @@ func (m *ClientDB) listClientSessions(
if tower != nil && *tower != session.TowerID { if tower != nil && *tower != session.TowerID {
continue continue
} }
session.Tower = m.towers[session.TowerID]
sessions[session.ID] = &session sessions[session.ID] = &session
} }