diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 436905f76..3d23f0b82 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -287,26 +287,33 @@ func New(config *Config) (*TowerClient, error) { } plog := build.NewPrefixLog(prefix, log) - // Next, load all candidate sessions and towers from the database into - // the client. We will use any of these session if their policies match + // Next, load all candidate towers and sessions from the database into + // the client. We will use any of these sessions if their policies match // the current policy of the client, otherwise they will be ignored and // new sessions will be requested. isAnchorClient := cfg.Policy.IsAnchorChannel() activeSessionFilter := genActiveSessionFilter(isAnchorClient) - candidateSessions, err := getClientSessions( - cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter, + candidateTowers := newTowerListIterator() + perActiveTower := func(tower *wtdb.Tower) { + // If the tower has already been marked as active, then there is + // no need to add it to the iterator again. + if candidateTowers.IsActive(tower.ID) { + return + } + + log.Infof("Using private watchtower %s, offering policy %s", + tower, cfg.Policy) + + // Add the tower to the set of candidate towers. + candidateTowers.AddCandidate(tower) + } + candidateSessions, err := getTowerAndSessionCandidates( + cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower, ) if err != nil { return nil, err } - var candidateTowers []*wtdb.Tower - for _, s := range candidateSessions { - plog.Infof("Using private watchtower %s, offering policy %s", - s.Tower, cfg.Policy) - candidateTowers = append(candidateTowers, s.Tower) - } - // Load the sweep pkscripts that have been generated for all previously // registered channels. chanSummaries, err := cfg.DB.FetchChanSummaries() @@ -318,7 +325,7 @@ func New(config *Config) (*TowerClient, error) { cfg: cfg, log: plog, pipeline: newTaskPipeline(plog), - candidateTowers: newTowerListIterator(candidateTowers...), + candidateTowers: candidateTowers, candidateSessions: candidateSessions, activeSessions: make(sessionQueueSet), summaries: chanSummaries, @@ -349,6 +356,55 @@ func New(config *Config) (*TowerClient, error) { return c, nil } +// getTowerAndSessionCandidates loads all the towers from the DB and then +// fetches the sessions for each of tower. Sessions are only collected if they +// pass the sessionFilter check. If a tower has a session that does pass the +// sessionFilter check then the perActiveTower call-back will be called on that +// tower. +func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, + sessionFilter func(*wtdb.ClientSession) bool, + perActiveTower func(tower *wtdb.Tower)) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { + + towers, err := db.ListTowers() + if err != nil { + return nil, err + } + + candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) + for _, tower := range towers { + sessions, err := db.ListClientSessions(&tower.ID) + if err != nil { + return nil, err + } + + for _, s := range sessions { + towerKeyDesc, err := keyRing.DeriveKey( + keychain.KeyLocator{ + Family: keychain.KeyFamilyTowerSession, + Index: s.KeyIndex, + }, + ) + if err != nil { + return nil, err + } + s.SessionKeyECDH = keychain.NewPubKeyECDH( + towerKeyDesc, keyRing, + ) + + if !sessionFilter(s) { + continue + } + + // Add the session to the set of candidate sessions. + candidateSessions[s.ID] = s + perActiveTower(tower) + } + } + + return candidateSessions, nil +} + // getClientSessions retrieves the client sessions for a particular tower if // specified, otherwise all client sessions for all towers are retrieved. An // optional filter can be provided to filter out any undesired client sessions.