watchtower: add ListClientSessions functional options

This commit adds functional options to the ListClientSessions call that
can be used to perform a variety of extra operations during the DB
query. These functional options are not yet used in this commit.
This commit is contained in:
Elle Mouton
2022-10-13 13:46:52 +02:00
parent 3ac3b6a90d
commit 40e0ebf417
4 changed files with 113 additions and 36 deletions

View File

@@ -83,10 +83,12 @@ type Client interface {
// RegisteredTowers retrieves the list of watchtowers registered with // RegisteredTowers retrieves the list of watchtowers registered with
// the client. // the client.
RegisteredTowers() ([]*RegisteredTower, error) RegisteredTowers(...wtdb.ClientSessionListOption) ([]*RegisteredTower,
error)
// LookupTower retrieves a registered watchtower through its public key. // LookupTower retrieves a registered watchtower through its public key.
LookupTower(*btcec.PublicKey) (*RegisteredTower, error) LookupTower(*btcec.PublicKey,
...wtdb.ClientSessionListOption) (*RegisteredTower, error)
// Stats returns the in-memory statistics of the client since startup. // Stats returns the in-memory statistics of the client since startup.
Stats() ClientStats Stats() ClientStats
@@ -363,7 +365,8 @@ func New(config *Config) (*TowerClient, error) {
// tower. // tower.
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
sessionFilter func(*wtdb.ClientSession) bool, sessionFilter func(*wtdb.ClientSession) bool,
perActiveTower func(tower *wtdb.Tower)) ( perActiveTower func(tower *wtdb.Tower),
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) { map[wtdb.SessionID]*wtdb.ClientSession, error) {
towers, err := db.ListTowers() towers, err := db.ListTowers()
@@ -373,7 +376,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, tower := range towers { for _, tower := range towers {
sessions, err := db.ListClientSessions(&tower.ID) sessions, err := db.ListClientSessions(&tower.ID, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -413,10 +416,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
// ClientSession's SessionPrivKey field is desired, otherwise, the existing // ClientSession's SessionPrivKey field is desired, otherwise, the existing
// ListClientSessions method should be used. // ListClientSessions method should be used.
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
passesFilter func(*wtdb.ClientSession) bool) ( passesFilter func(*wtdb.ClientSession) bool,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) { map[wtdb.SessionID]*wtdb.ClientSession, error) {
sessions, err := db.ListClientSessions(forTower) sessions, err := db.ListClientSessions(forTower, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1233,13 +1237,15 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
// RegisteredTowers retrieves the list of watchtowers registered with the // RegisteredTowers retrieves the list of watchtowers registered with the
// client. // client.
func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) { func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
[]*RegisteredTower, error) {
// Retrieve all of our towers along with all of our sessions. // Retrieve all of our towers along with all of our sessions.
towers, err := c.cfg.DB.ListTowers() towers, err := c.cfg.DB.ListTowers()
if err != nil { if err != nil {
return nil, err return nil, err
} }
clientSessions, err := c.cfg.DB.ListClientSessions(nil) clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1272,13 +1278,15 @@ func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) {
} }
// LookupTower retrieves a registered watchtower through its public key. // LookupTower retrieves a registered watchtower through its public key.
func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey) (*RegisteredTower, error) { func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey,
opts ...wtdb.ClientSessionListOption) (*RegisteredTower, error) {
tower, err := c.cfg.DB.LoadTower(pubKey) tower, err := c.cfg.DB.LoadTower(pubKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID) towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -62,7 +62,7 @@ type DB interface {
// still be able to accept state updates. An optional tower ID can be // still be able to accept state updates. An optional tower ID can be
// used to filter out any client sessions in the response that do not // used to filter out any client sessions in the response that do not
// correspond to this tower. // correspond to this tower.
ListClientSessions(*wtdb.TowerID) ( ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) map[wtdb.SessionID]*wtdb.ClientSession, error)
// FetchChanSummaries loads a mapping from all registered channels to // FetchChanSummaries loads a mapping from all registered channels to

View File

@@ -736,8 +736,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
// ListClientSessions returns the set of all client sessions known to the db. An // ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower. // response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID) ( func (c *ClientDB) ListClientSessions(id *TowerID,
map[SessionID]*ClientSession, error) { opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
var clientSessions map[SessionID]*ClientSession var clientSessions map[SessionID]*ClientSession
err := kvdb.View(c.db, func(tx kvdb.RTx) error { err := kvdb.View(c.db, func(tx kvdb.RTx) error {
@@ -757,7 +757,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
// known to the db. // known to the db.
if id == nil { if id == nil {
clientSessions, err = listClientAllSessions( clientSessions, err = listClientAllSessions(
sessions, towers, sessions, towers, opts...,
) )
return err return err
} }
@@ -769,7 +769,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
} }
clientSessions, err = listTowerSessions( clientSessions, err = listTowerSessions(
*id, sessions, towers, towerToSessionIndex, *id, sessions, towers, towerToSessionIndex, opts...,
) )
return err return err
}, func() { }, func() {
@@ -783,8 +783,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
} }
// listClientAllSessions returns the set of all client sessions known to the db. // listClientAllSessions returns the set of all client sessions known to the db.
func listClientAllSessions(sessions, func listClientAllSessions(sessions, towers kvdb.RBucket,
towers kvdb.RBucket) (map[SessionID]*ClientSession, error) { opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession) clientSessions := make(map[SessionID]*ClientSession)
err := sessions.ForEach(func(k, _ []byte) error { err := sessions.ForEach(func(k, _ []byte) error {
@@ -792,7 +792,7 @@ func listClientAllSessions(sessions,
// the CommittedUpdates and AckedUpdates on startup to resume // the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height // committed updates and compute the highest known commit height
// for each channel. // for each channel.
session, err := getClientSession(sessions, towers, k) session, err := getClientSession(sessions, towers, k, opts...)
if err != nil { if err != nil {
return err return err
} }
@@ -811,8 +811,8 @@ func listClientAllSessions(sessions,
// listTowerSessions returns the set of all client sessions known to the db // listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id. // that are associated with the given tower id.
func listTowerSessions(id TowerID, sessionsBkt, towersBkt, func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession, towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
error) { map[SessionID]*ClientSession, error) {
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes()) towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
if towerIndexBkt == nil { if towerIndexBkt == nil {
@@ -825,7 +825,9 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
// the CommittedUpdates and AckedUpdates on startup to resume // the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height // committed updates and compute the highest known commit height
// for each channel. // for each channel.
session, err := getClientSession(sessionsBkt, towersBkt, k) session, err := getClientSession(
sessionsBkt, towersBkt, k, opts...,
)
if err != nil { if err != nil {
return err return err
} }
@@ -1157,11 +1159,63 @@ func getClientSessionBody(sessions kvdb.RBucket,
return &session, nil return &session, nil
} }
// PerAckedUpdateCB describes the signature of a callback function that can be
// called for each of a session's acked updates.
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
// PerCommittedUpdateCB describes the signature of a callback function that can
// be called for each of a session's committed updates (updates that the client
// has not yet received an ACK for).
type PerCommittedUpdateCB func(*ClientSession, *CommittedUpdate)
// ClientSessionListOption describes the signature of a functional option that
// can be used when listing client sessions in order to provide any extra
// instruction to the query.
type ClientSessionListOption func(cfg *ClientSessionListCfg)
// ClientSessionListCfg defines various query parameters that will be used when
// querying the DB for client sessions.
type ClientSessionListCfg struct {
// PerAckedUpdate will, if set, be called for each of the session's
// acked updates.
PerAckedUpdate PerAckedUpdateCB
// PerCommittedUpdate will, if set, be called for each of the session's
// committed (un-acked) updates.
PerCommittedUpdate PerCommittedUpdateCB
}
// NewClientSessionCfg constructs a new ClientSessionListCfg.
func NewClientSessionCfg() *ClientSessionListCfg {
return &ClientSessionListCfg{}
}
// WithPerAckedUpdate constructs a functional option that will set a call-back
// function to be called for each of a client's acked updates.
func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PerAckedUpdate = cb
}
}
// WithPerCommittedUpdate constructs a functional option that will set a
// call-back function to be called for each of a client's un-acked updates.
func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PerCommittedUpdate = cb
}
}
// getClientSession loads the full ClientSession associated with the serialized // getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates, AckUpdates and Tower // session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body. // in addition to the ClientSession's body.
func getClientSession(sessions, towers kvdb.RBucket, func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
idBytes []byte) (*ClientSession, error) { opts ...ClientSessionListOption) (*ClientSession, error) {
cfg := NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}
session, err := getClientSessionBody(sessions, idBytes) session, err := getClientSessionBody(sessions, idBytes)
if err != nil { if err != nil {
@@ -1178,13 +1232,17 @@ func getClientSession(sessions, towers kvdb.RBucket,
sessionBkt := sessions.NestedReadBucket(idBytes) sessionBkt := sessions.NestedReadBucket(idBytes)
// Fetch the committed updates for this session. // Fetch the committed updates for this session.
commitedUpdates, err := getClientSessionCommits(sessionBkt) commitedUpdates, err := getClientSessionCommits(
sessionBkt, session, cfg.PerCommittedUpdate,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Fetch the acked updates for this session. // Fetch the acked updates for this session.
ackedUpdates, err := getClientSessionAcks(sessionBkt) ackedUpdates, err := getClientSessionAcks(
sessionBkt, session, cfg.PerAckedUpdate,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1197,11 +1255,12 @@ func getClientSession(sessions, towers kvdb.RBucket,
} }
// getClientSessionCommits retrieves all committed updates for the session // getClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id. // identified by the serialized session id. If a PerCommittedUpdateCB is
func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate, // provided, then it will be called for each of the session's committed updates.
error) { func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
cb PerCommittedUpdateCB) ([]CommittedUpdate, error) {
// Initialize commitedUpdates so that we can return an initialized map // Initialize committedUpdates so that we can return an initialized map
// if no committed updates exist. // if no committed updates exist.
committedUpdates := make([]CommittedUpdate, 0) committedUpdates := make([]CommittedUpdate, 0)
@@ -1220,6 +1279,10 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
committedUpdates = append(committedUpdates, committedUpdate) committedUpdates = append(committedUpdates, committedUpdate)
if cb != nil {
cb(s, &committedUpdate)
}
return nil return nil
}) })
if err != nil { if err != nil {
@@ -1231,8 +1294,8 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
// getClientSessionAcks retrieves all acked updates for the session identified // getClientSessionAcks retrieves all acked updates for the session identified
// by the serialized session id. // by the serialized session id.
func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID, func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
error) { cb PerAckedUpdateCB) (map[uint16]BackupID, error) {
// Initialize ackedUpdates so that we can return an initialized map if // Initialize ackedUpdates so that we can return an initialized map if
// no acked updates exist. // no acked updates exist.
@@ -1254,6 +1317,10 @@ func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID,
ackedUpdates[seqNum] = backupID ackedUpdates[seqNum] = backupID
if cb != nil {
cb(s, seqNum, backupID)
}
return nil return nil
}) })
if err != nil { if err != nil {

View File

@@ -200,19 +200,21 @@ func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight ui
// ListClientSessions returns the set of all client sessions known to the db. An // ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower. // response that do not correspond to this tower.
func (m *ClientDB) ListClientSessions( func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
return m.listClientSessions(tower) return m.listClientSessions(tower, opts...)
} }
// listClientSessions returns the set of all client sessions known to the db. An // listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower. // response that do not correspond to this tower.
func (m *ClientDB) listClientSessions( func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { _ ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, session := range m.activeSessions { for _, session := range m.activeSessions {