From 105c44df9b0c35a182d094a523868580abf499b6 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 16:26:36 +0200 Subject: [PATCH] watchtower: use more efficient session query on startup In this commit, the functions used to fetch candidate sessions and towers on creation of the watchtower Client are changed to make use of the more efficient lookup functions. Previously, all sessions were listed from the DB and then these were used to collect the active towers which in certain situations lead to some users getting the "tower not found" error on start up. With this commit, we instead first list all Towers in the DB and then we fetch the sessions for each of those towers. --- watchtower/wtclient/client.go | 80 +++++++++++++++++++++++++++++------ 1 file changed, 68 insertions(+), 12 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 436905f76..3d23f0b82 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -287,26 +287,33 @@ func New(config *Config) (*TowerClient, error) { } plog := build.NewPrefixLog(prefix, log) - // Next, load all candidate sessions and towers from the database into - // the client. We will use any of these session if their policies match + // Next, load all candidate towers and sessions from the database into + // the client. We will use any of these sessions if their policies match // the current policy of the client, otherwise they will be ignored and // new sessions will be requested. isAnchorClient := cfg.Policy.IsAnchorChannel() activeSessionFilter := genActiveSessionFilter(isAnchorClient) - candidateSessions, err := getClientSessions( - cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter, + candidateTowers := newTowerListIterator() + perActiveTower := func(tower *wtdb.Tower) { + // If the tower has already been marked as active, then there is + // no need to add it to the iterator again. + if candidateTowers.IsActive(tower.ID) { + return + } + + log.Infof("Using private watchtower %s, offering policy %s", + tower, cfg.Policy) + + // Add the tower to the set of candidate towers. + candidateTowers.AddCandidate(tower) + } + candidateSessions, err := getTowerAndSessionCandidates( + cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower, ) if err != nil { return nil, err } - var candidateTowers []*wtdb.Tower - for _, s := range candidateSessions { - plog.Infof("Using private watchtower %s, offering policy %s", - s.Tower, cfg.Policy) - candidateTowers = append(candidateTowers, s.Tower) - } - // Load the sweep pkscripts that have been generated for all previously // registered channels. chanSummaries, err := cfg.DB.FetchChanSummaries() @@ -318,7 +325,7 @@ func New(config *Config) (*TowerClient, error) { cfg: cfg, log: plog, pipeline: newTaskPipeline(plog), - candidateTowers: newTowerListIterator(candidateTowers...), + candidateTowers: candidateTowers, candidateSessions: candidateSessions, activeSessions: make(sessionQueueSet), summaries: chanSummaries, @@ -349,6 +356,55 @@ func New(config *Config) (*TowerClient, error) { return c, nil } +// getTowerAndSessionCandidates loads all the towers from the DB and then +// fetches the sessions for each of tower. Sessions are only collected if they +// pass the sessionFilter check. If a tower has a session that does pass the +// sessionFilter check then the perActiveTower call-back will be called on that +// tower. +func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, + sessionFilter func(*wtdb.ClientSession) bool, + perActiveTower func(tower *wtdb.Tower)) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { + + towers, err := db.ListTowers() + if err != nil { + return nil, err + } + + candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) + for _, tower := range towers { + sessions, err := db.ListClientSessions(&tower.ID) + if err != nil { + return nil, err + } + + for _, s := range sessions { + towerKeyDesc, err := keyRing.DeriveKey( + keychain.KeyLocator{ + Family: keychain.KeyFamilyTowerSession, + Index: s.KeyIndex, + }, + ) + if err != nil { + return nil, err + } + s.SessionKeyECDH = keychain.NewPubKeyECDH( + towerKeyDesc, keyRing, + ) + + if !sessionFilter(s) { + continue + } + + // Add the session to the set of candidate sessions. + candidateSessions[s.ID] = s + perActiveTower(tower) + } + } + + return candidateSessions, nil +} + // getClientSessions retrieves the client sessions for a particular tower if // specified, otherwise all client sessions for all towers are retrieved. An // optional filter can be provided to filter out any undesired client sessions.