diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 8cd2b7e7d..32622b7da 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -42,12 +42,24 @@ const ( DefaultForceQuitDelay = 10 * time.Second ) -// genActiveSessionFilter generates a filter that selects active sessions that -// also match the desired channel type, either legacy or anchor. -func genActiveSessionFilter(anchor bool) func(*ClientSession) bool { - return func(s *ClientSession) bool { - return s.Status == wtdb.CSessionActive && - anchor == s.Policy.IsAnchorChannel() +// genSessionFilter constructs a filter that can be used to select sessions only +// if they match the policy of the client (namely anchor vs legacy). If +// activeOnly is set, then only active sessions will be returned. +func (c *TowerClient) genSessionFilter( + activeOnly bool) wtdb.ClientSessionFilterFn { + + return func(session *wtdb.ClientSession) bool { + if c.cfg.Policy.IsAnchorChannel() != + session.Policy.IsAnchorChannel() { + + return false + } + + if !activeOnly { + return true + } + + return session.Status == wtdb.CSessionActive } } @@ -344,13 +356,6 @@ func New(config *Config) (*TowerClient, error) { perUpdate(s.Policy, u.BackupID.ChanID, u.BackupID.CommitHeight) } - // Load all candidate sessions and towers from the database into the - // client. We will use any of these sessions if their policies match the - // current policy of the client, otherwise they will be ignored and new - // sessions will be requested. - isAnchorClient := cfg.Policy.IsAnchorChannel() - activeSessionFilter := genActiveSessionFilter(isAnchorClient) - candidateTowers := newTowerListIterator() perActiveTower := func(tower *Tower) { // If the tower has already been marked as active, then there is @@ -366,9 +371,13 @@ func New(config *Config) (*TowerClient, error) { candidateTowers.AddCandidate(tower) } + // Load all candidate sessions and towers from the database into the + // client. We will use any of these sessions if their policies match the + // current policy of the client, otherwise they will be ignored and new + // sessions will be requested. candidateSessions, err := getTowerAndSessionCandidates( - cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower, - wtdb.WithPerMaxHeight(perMaxHeight), + cfg.DB, cfg.SecretKeyRing, c.genSessionFilter(true), + perActiveTower, wtdb.WithPerMaxHeight(perMaxHeight), wtdb.WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { @@ -401,7 +410,7 @@ func New(config *Config) (*TowerClient, error) { // sessionFilter check then the perActiveTower call-back will be called on that // tower. func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, - sessionFilter func(*ClientSession) bool, + sessionFilter wtdb.ClientSessionFilterFn, perActiveTower func(tower *Tower), opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*ClientSession, error) { @@ -418,7 +427,9 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, return nil, err } - sessions, err := db.ListClientSessions(&tower.ID, nil, opts...) + sessions, err := db.ListClientSessions( + &tower.ID, sessionFilter, opts..., + ) if err != nil { return nil, err } @@ -438,19 +449,14 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, towerKeyDesc, keyRing, ) - cs := &ClientSession{ + // Add the session to the set of candidate sessions. + candidateSessions[s.ID] = &ClientSession{ ID: s.ID, ClientSessionBody: s.ClientSessionBody, Tower: tower, SessionKeyECDH: sessionKeyECDH, } - if !sessionFilter(cs) { - continue - } - - // Add the session to the set of candidate sessions. - candidateSessions[s.ID] = cs perActiveTower(tower) } } @@ -466,11 +472,13 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, // ClientSession's SessionPrivKey field is desired, otherwise, the existing // ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, - passesFilter func(*ClientSession) bool, + sessionFilter wtdb.ClientSessionFilterFn, opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*ClientSession, error) { - dbSessions, err := db.ListClientSessions(forTower, nil, opts...) + dbSessions, err := db.ListClientSessions( + forTower, sessionFilter, opts..., + ) if err != nil { return nil, err } @@ -494,6 +502,7 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, if err != nil { return nil, err } + sessionKeyECDH := keychain.NewPubKeyECDH(towerKeyDesc, keyRing) tower, err := NewTowerFromDBTower(dbTower) @@ -501,20 +510,12 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, return nil, err } - cs := &ClientSession{ + sessions[s.ID] = &ClientSession{ ID: s.ID, ClientSessionBody: s.ClientSessionBody, Tower: tower, SessionKeyECDH: sessionKeyECDH, } - - // If an optional filter was provided, use it to filter out any - // undesired sessions. - if passesFilter != nil && !passesFilter(cs) { - continue - } - - sessions[s.ID] = cs } return sessions, nil @@ -1200,10 +1201,9 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { c.candidateTowers.AddCandidate(tower) // Include all of its corresponding sessions to our set of candidates. - isAnchorClient := c.cfg.Policy.IsAnchorChannel() - activeSessionFilter := genActiveSessionFilter(isAnchorClient) sessions, err := getClientSessions( - c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID, activeSessionFilter, + c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID, + c.genSessionFilter(true), ) if err != nil { return fmt.Errorf("unable to determine sessions for tower %x: "+ @@ -1320,7 +1320,9 @@ func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( if err != nil { return nil, err } - clientSessions, err := c.cfg.DB.ListClientSessions(nil, nil, opts...) + clientSessions, err := c.cfg.DB.ListClientSessions( + nil, c.genSessionFilter(false), opts..., + ) if err != nil { return nil, err } @@ -1362,7 +1364,7 @@ func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey, } towerSessions, err := c.cfg.DB.ListClientSessions( - &tower.ID, nil, opts..., + &tower.ID, c.genSessionFilter(false), opts..., ) if err != nil { return nil, err