watchtower: add filter function to ListTowers

And then only load active towers on client start up.
This commit is contained in:
Elle Mouton
2023-11-24 12:14:04 +02:00
parent ffd355c6c4
commit 0bb1816fff
4 changed files with 22 additions and 6 deletions

View File

@@ -261,7 +261,10 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
opts ...wtdb.ClientSessionListOption) ( opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*ClientSession, error) { map[wtdb.SessionID]*ClientSession, error) {
towers, err := db.ListTowers() // Fetch all active towers from the DB.
towers, err := db.ListTowers(func(tower *wtdb.Tower) bool {
return tower.Status == wtdb.TowerStatusActive
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -41,8 +41,9 @@ type DB interface {
LoadTowerByID(wtdb.TowerID) (*wtdb.Tower, error) LoadTowerByID(wtdb.TowerID) (*wtdb.Tower, error)
// ListTowers retrieves the list of towers available within the // ListTowers retrieves the list of towers available within the
// database. // database. The filter function may be set in order to filter out the
ListTowers() ([]*wtdb.Tower, error) // towers to be returned.
ListTowers(filter wtdb.TowerFilterFn) ([]*wtdb.Tower, error)
// NextSessionKeyIndex reserves a new session key derivation index for a // NextSessionKeyIndex reserves a new session key derivation index for a
// particular tower id and blob type. The index is reserved for that // particular tower id and blob type. The index is reserved for that

View File

@@ -455,7 +455,7 @@ func (m *Manager) Stats() ClientStats {
func (m *Manager) RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( func (m *Manager) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
map[blob.Type][]*RegisteredTower, error) { map[blob.Type][]*RegisteredTower, error) {
towers, err := m.cfg.DB.ListTowers() towers, err := m.cfg.DB.ListTowers(nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -632,8 +632,14 @@ func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
return tower, nil return tower, nil
} }
// ListTowers retrieves the list of towers available within the database. // TowerFilterFn is the signature of a call-back function that can be used to
func (c *ClientDB) ListTowers() ([]*Tower, error) { // skip certain towers in the ListTowers method.
type TowerFilterFn func(*Tower) bool
// ListTowers retrieves the list of towers available within the database that
// have a status matching the given status. The filter function may be set in
// order to filter out the towers to be returned.
func (c *ClientDB) ListTowers(filter TowerFilterFn) ([]*Tower, error) {
var towers []*Tower var towers []*Tower
err := kvdb.View(c.db, func(tx kvdb.RTx) error { err := kvdb.View(c.db, func(tx kvdb.RTx) error {
towerBucket := tx.ReadBucket(cTowerBkt) towerBucket := tx.ReadBucket(cTowerBkt)
@@ -646,7 +652,13 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) {
if err != nil { if err != nil {
return err return err
} }
if filter != nil && !filter(tower) {
return nil
}
towers = append(towers, tower) towers = append(towers, tower)
return nil return nil
}) })
}, func() { }, func() {