watchtower: add filter function to ListTowers

And then only load active towers on client start up.
This commit is contained in:
Elle Mouton
2023-11-24 12:14:04 +02:00
parent ffd355c6c4
commit 0bb1816fff
4 changed files with 22 additions and 6 deletions

View File

@@ -261,7 +261,10 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*ClientSession, error) {
towers, err := db.ListTowers()
// Fetch all active towers from the DB.
towers, err := db.ListTowers(func(tower *wtdb.Tower) bool {
return tower.Status == wtdb.TowerStatusActive
})
if err != nil {
return nil, err
}

View File

@@ -41,8 +41,9 @@ type DB interface {
LoadTowerByID(wtdb.TowerID) (*wtdb.Tower, error)
// ListTowers retrieves the list of towers available within the
// database.
ListTowers() ([]*wtdb.Tower, error)
// database. The filter function may be set in order to filter out the
// towers to be returned.
ListTowers(filter wtdb.TowerFilterFn) ([]*wtdb.Tower, error)
// NextSessionKeyIndex reserves a new session key derivation index for a
// particular tower id and blob type. The index is reserved for that

View File

@@ -455,7 +455,7 @@ func (m *Manager) Stats() ClientStats {
func (m *Manager) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
map[blob.Type][]*RegisteredTower, error) {
towers, err := m.cfg.DB.ListTowers()
towers, err := m.cfg.DB.ListTowers(nil)
if err != nil {
return nil, err
}

View File

@@ -632,8 +632,14 @@ func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
return tower, nil
}
// ListTowers retrieves the list of towers available within the database.
func (c *ClientDB) ListTowers() ([]*Tower, error) {
// TowerFilterFn is the signature of a call-back function that can be used to
// skip certain towers in the ListTowers method.
type TowerFilterFn func(*Tower) bool
// ListTowers retrieves the list of towers available within the database that
// have a status matching the given status. The filter function may be set in
// order to filter out the towers to be returned.
func (c *ClientDB) ListTowers(filter TowerFilterFn) ([]*Tower, error) {
var towers []*Tower
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
towerBucket := tx.ReadBucket(cTowerBkt)
@@ -646,7 +652,13 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) {
if err != nil {
return err
}
if filter != nil && !filter(tower) {
return nil
}
towers = append(towers, tower)
return nil
})
}, func() {