From ecd2eb965a75fd1a0dc55fe533f5ac8ebe529209 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 16:06:53 +0200 Subject: [PATCH] watchtower: make use of the new tower-to-session index In this commit, the towerID-to-sessionID index added in the previous commit is put to use in order to make session lookup more efficient in certain places. In the process, 2 TODO's are also removed from the code. --- watchtower/wtdb/client_db.go | 102 +++++++++++++++++++++++++---------- 1 file changed, 73 insertions(+), 29 deletions(-) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 537c8cc73..3cb5a8c70 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -293,27 +293,32 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { // If there are any client sessions that correspond to // this tower, we'll mark them as active to ensure we // load them upon restarts. - // - // TODO(wilmer): with an index of tower -> sessions we - // can avoid the linear lookup. + towerSessIndex := towerToSessionIndex.NestedReadBucket( + tower.ID.Bytes(), + ) + if towerSessIndex == nil { + return ErrTowerNotFound + } + sessions := tx.ReadWriteBucket(cSessionBkt) if sessions == nil { return ErrUninitializedDB } - towerID := TowerIDFromBytes(towerIDBytes) - towerSessions, err := listClientSessions( - sessions, towers, &towerID, - ) - if err != nil { - return err - } - for _, session := range towerSessions { - err := markSessionStatus( - sessions, session, CSessionActive, + + err = towerSessIndex.ForEach(func(k, _ []byte) error { + session, err := getClientSessionBody( + sessions, k, ) if err != nil { return err } + + return markSessionStatus( + sessions, session, CSessionActive, + ) + }) + if err != nil { + return err } } else { // No such tower exists, create a new tower id for our @@ -410,16 +415,13 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { // Otherwise, we should attempt to mark the tower's sessions as // inactive. - // - // TODO(wilmer): with an index of tower -> sessions we can avoid - // the linear lookup. sessions := tx.ReadWriteBucket(cSessionBkt) if sessions == nil { return ErrUninitializedDB } towerID := TowerIDFromBytes(towerIDBytes) - towerSessions, err := listClientSessions( - sessions, towers, &towerID, + towerSessions, err := listTowerSessions( + towerID, sessions, towers, towersToSessionsIndex, ) if err != nil { return err @@ -750,7 +752,25 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( } var err error - clientSessions, err = listClientSessions(sessions, towers, id) + + // If no tower ID is specified, then fetch all the sessions + // known to the db. + if id == nil { + clientSessions, err = listClientAllSessions( + sessions, towers, + ) + return err + } + + // Otherwise, fetch the sessions for the given tower. + towerToSessionIndex := tx.ReadBucket(cTowerToSessionIndexBkt) + if towerToSessionIndex == nil { + return ErrUninitializedDB + } + + clientSessions, err = listTowerSessions( + *id, sessions, towers, towerToSessionIndex, + ) return err }, func() { clientSessions = nil @@ -762,11 +782,9 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( return clientSessions, nil } -// listClientSessions returns the set of all client sessions known to the db. An -// optional tower ID can be used to filter out any client sessions in the -// response that do not correspond to this tower. -func listClientSessions(sessions, towers kvdb.RBucket, - id *TowerID) (map[SessionID]*ClientSession, error) { +// listClientAllSessions returns the set of all client sessions known to the db. +func listClientAllSessions(sessions, + towers kvdb.RBucket) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) err := sessions.ForEach(func(k, _ []byte) error { @@ -779,14 +797,40 @@ func listClientSessions(sessions, towers kvdb.RBucket, return err } - // Filter out any sessions that don't correspond to the given - // tower if one was set. - if id != nil && session.TowerID != *id { - return nil + clientSessions[session.ID] = session + + return nil + }) + if err != nil { + return nil, err + } + + return clientSessions, nil +} + +// listTowerSessions returns the set of all client sessions known to the db +// that are associated with the given tower id. +func listTowerSessions(id TowerID, sessionsBkt, towersBkt, + towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession, + error) { + + towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes()) + if towerIndexBkt == nil { + return nil, ErrTowerNotFound + } + + clientSessions := make(map[SessionID]*ClientSession) + err := towerIndexBkt.ForEach(func(k, _ []byte) error { + // We'll load the full client session since the client will need + // the CommittedUpdates and AckedUpdates on startup to resume + // committed updates and compute the highest known commit height + // for each channel. + session, err := getClientSession(sessionsBkt, towersBkt, k) + if err != nil { + return err } clientSessions[session.ID] = session - return nil }) if err != nil {