watchtower: make use of the new tower-to-session index

In this commit, the towerID-to-sessionID index added in the previous
commit is put to use in order to make session lookup more efficient in
certain places. In the process, 2 TODO's are also removed from the code.
This commit is contained in:
Elle Mouton
2022-10-04 16:06:53 +02:00
parent 354a3b16bd
commit ecd2eb965a

View File

@ -293,28 +293,33 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
// If there are any client sessions that correspond to // If there are any client sessions that correspond to
// this tower, we'll mark them as active to ensure we // this tower, we'll mark them as active to ensure we
// load them upon restarts. // load them upon restarts.
// towerSessIndex := towerToSessionIndex.NestedReadBucket(
// TODO(wilmer): with an index of tower -> sessions we tower.ID.Bytes(),
// can avoid the linear lookup. )
if towerSessIndex == nil {
return ErrTowerNotFound
}
sessions := tx.ReadWriteBucket(cSessionBkt) sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
towerID := TowerIDFromBytes(towerIDBytes)
towerSessions, err := listClientSessions( err = towerSessIndex.ForEach(func(k, _ []byte) error {
sessions, towers, &towerID, session, err := getClientSessionBody(
sessions, k,
) )
if err != nil { if err != nil {
return err return err
} }
for _, session := range towerSessions {
err := markSessionStatus( return markSessionStatus(
sessions, session, CSessionActive, sessions, session, CSessionActive,
) )
})
if err != nil { if err != nil {
return err return err
} }
}
} else { } else {
// No such tower exists, create a new tower id for our // No such tower exists, create a new tower id for our
// new tower. The error is unhandled since NextSequence // new tower. The error is unhandled since NextSequence
@ -410,16 +415,13 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
// Otherwise, we should attempt to mark the tower's sessions as // Otherwise, we should attempt to mark the tower's sessions as
// inactive. // inactive.
//
// TODO(wilmer): with an index of tower -> sessions we can avoid
// the linear lookup.
sessions := tx.ReadWriteBucket(cSessionBkt) sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
towerID := TowerIDFromBytes(towerIDBytes) towerID := TowerIDFromBytes(towerIDBytes)
towerSessions, err := listClientSessions( towerSessions, err := listTowerSessions(
sessions, towers, &towerID, towerID, sessions, towers, towersToSessionsIndex,
) )
if err != nil { if err != nil {
return err return err
@ -750,7 +752,25 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
} }
var err error var err error
clientSessions, err = listClientSessions(sessions, towers, id)
// If no tower ID is specified, then fetch all the sessions
// known to the db.
if id == nil {
clientSessions, err = listClientAllSessions(
sessions, towers,
)
return err
}
// Otherwise, fetch the sessions for the given tower.
towerToSessionIndex := tx.ReadBucket(cTowerToSessionIndexBkt)
if towerToSessionIndex == nil {
return ErrUninitializedDB
}
clientSessions, err = listTowerSessions(
*id, sessions, towers, towerToSessionIndex,
)
return err return err
}, func() { }, func() {
clientSessions = nil clientSessions = nil
@ -762,11 +782,9 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
return clientSessions, nil return clientSessions, nil
} }
// listClientSessions returns the set of all client sessions known to the db. An // listClientAllSessions returns the set of all client sessions known to the db.
// optional tower ID can be used to filter out any client sessions in the func listClientAllSessions(sessions,
// response that do not correspond to this tower. towers kvdb.RBucket) (map[SessionID]*ClientSession, error) {
func listClientSessions(sessions, towers kvdb.RBucket,
id *TowerID) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession) clientSessions := make(map[SessionID]*ClientSession)
err := sessions.ForEach(func(k, _ []byte) error { err := sessions.ForEach(func(k, _ []byte) error {
@ -779,14 +797,40 @@ func listClientSessions(sessions, towers kvdb.RBucket,
return err return err
} }
// Filter out any sessions that don't correspond to the given clientSessions[session.ID] = session
// tower if one was set.
if id != nil && session.TowerID != *id {
return nil return nil
})
if err != nil {
return nil, err
}
return clientSessions, nil
}
// listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id.
func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession,
error) {
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
if towerIndexBkt == nil {
return nil, ErrTowerNotFound
}
clientSessions := make(map[SessionID]*ClientSession)
err := towerIndexBkt.ForEach(func(k, _ []byte) error {
// We'll load the full client session since the client will need
// the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height
// for each channel.
session, err := getClientSession(sessionsBkt, towersBkt, k)
if err != nil {
return err
} }
clientSessions[session.ID] = session clientSessions[session.ID] = session
return nil return nil
}) })
if err != nil { if err != nil {