From 40e0ebf4171bfe018ff81eeb9956b634d82f73e9 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 13 Oct 2022 13:46:52 +0200 Subject: [PATCH] watchtower: add ListClientSessions functional options This commit adds functional options to the ListClientSessions call that can be used to perform a variety of extra operations during the DB query. These functional options are not yet used in this commit. --- watchtower/wtclient/client.go | 28 +++++--- watchtower/wtclient/interface.go | 2 +- watchtower/wtdb/client_db.go | 107 +++++++++++++++++++++++++------ watchtower/wtmock/client_db.go | 12 ++-- 4 files changed, 113 insertions(+), 36 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 3d23f0b82..0003c2b10 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -83,10 +83,12 @@ type Client interface { // RegisteredTowers retrieves the list of watchtowers registered with // the client. - RegisteredTowers() ([]*RegisteredTower, error) + RegisteredTowers(...wtdb.ClientSessionListOption) ([]*RegisteredTower, + error) // LookupTower retrieves a registered watchtower through its public key. - LookupTower(*btcec.PublicKey) (*RegisteredTower, error) + LookupTower(*btcec.PublicKey, + ...wtdb.ClientSessionListOption) (*RegisteredTower, error) // Stats returns the in-memory statistics of the client since startup. Stats() ClientStats @@ -363,7 +365,8 @@ func New(config *Config) (*TowerClient, error) { // tower. func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, sessionFilter func(*wtdb.ClientSession) bool, - perActiveTower func(tower *wtdb.Tower)) ( + perActiveTower func(tower *wtdb.Tower), + opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { towers, err := db.ListTowers() @@ -373,7 +376,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, tower := range towers { - sessions, err := db.ListClientSessions(&tower.ID) + sessions, err := db.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err } @@ -413,10 +416,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, // ClientSession's SessionPrivKey field is desired, otherwise, the existing // ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, - passesFilter func(*wtdb.ClientSession) bool) ( + passesFilter func(*wtdb.ClientSession) bool, + opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { - sessions, err := db.ListClientSessions(forTower) + sessions, err := db.ListClientSessions(forTower, opts...) if err != nil { return nil, err } @@ -1233,13 +1237,15 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // RegisteredTowers retrieves the list of watchtowers registered with the // client. -func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) { +func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( + []*RegisteredTower, error) { + // Retrieve all of our towers along with all of our sessions. towers, err := c.cfg.DB.ListTowers() if err != nil { return nil, err } - clientSessions, err := c.cfg.DB.ListClientSessions(nil) + clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...) if err != nil { return nil, err } @@ -1272,13 +1278,15 @@ func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) { } // LookupTower retrieves a registered watchtower through its public key. -func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey) (*RegisteredTower, error) { +func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey, + opts ...wtdb.ClientSessionListOption) (*RegisteredTower, error) { + tower, err := c.cfg.DB.LoadTower(pubKey) if err != nil { return nil, err } - towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID) + towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err } diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index dbf2faf71..eb4a450a2 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -62,7 +62,7 @@ type DB interface { // still be able to accept state updates. An optional tower ID can be // used to filter out any client sessions in the response that do not // correspond to this tower. - ListClientSessions(*wtdb.TowerID) ( + ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) // FetchChanSummaries loads a mapping from all registered channels to diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index f862b186c..9e33a1df3 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -736,8 +736,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID, // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (c *ClientDB) ListClientSessions(id *TowerID) ( - map[SessionID]*ClientSession, error) { +func (c *ClientDB) ListClientSessions(id *TowerID, + opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { var clientSessions map[SessionID]*ClientSession err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -757,7 +757,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( // known to the db. if id == nil { clientSessions, err = listClientAllSessions( - sessions, towers, + sessions, towers, opts..., ) return err } @@ -769,7 +769,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( } clientSessions, err = listTowerSessions( - *id, sessions, towers, towerToSessionIndex, + *id, sessions, towers, towerToSessionIndex, opts..., ) return err }, func() { @@ -783,8 +783,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( } // listClientAllSessions returns the set of all client sessions known to the db. -func listClientAllSessions(sessions, - towers kvdb.RBucket) (map[SessionID]*ClientSession, error) { +func listClientAllSessions(sessions, towers kvdb.RBucket, + opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) err := sessions.ForEach(func(k, _ []byte) error { @@ -792,7 +792,7 @@ func listClientAllSessions(sessions, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessions, towers, k) + session, err := getClientSession(sessions, towers, k, opts...) if err != nil { return err } @@ -811,8 +811,8 @@ func listClientAllSessions(sessions, // listTowerSessions returns the set of all client sessions known to the db // that are associated with the given tower id. func listTowerSessions(id TowerID, sessionsBkt, towersBkt, - towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession, - error) { + towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) ( + map[SessionID]*ClientSession, error) { towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes()) if towerIndexBkt == nil { @@ -825,7 +825,9 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessionsBkt, towersBkt, k) + session, err := getClientSession( + sessionsBkt, towersBkt, k, opts..., + ) if err != nil { return err } @@ -1157,11 +1159,63 @@ func getClientSessionBody(sessions kvdb.RBucket, return &session, nil } +// PerAckedUpdateCB describes the signature of a callback function that can be +// called for each of a session's acked updates. +type PerAckedUpdateCB func(*ClientSession, uint16, BackupID) + +// PerCommittedUpdateCB describes the signature of a callback function that can +// be called for each of a session's committed updates (updates that the client +// has not yet received an ACK for). +type PerCommittedUpdateCB func(*ClientSession, *CommittedUpdate) + +// ClientSessionListOption describes the signature of a functional option that +// can be used when listing client sessions in order to provide any extra +// instruction to the query. +type ClientSessionListOption func(cfg *ClientSessionListCfg) + +// ClientSessionListCfg defines various query parameters that will be used when +// querying the DB for client sessions. +type ClientSessionListCfg struct { + // PerAckedUpdate will, if set, be called for each of the session's + // acked updates. + PerAckedUpdate PerAckedUpdateCB + + // PerCommittedUpdate will, if set, be called for each of the session's + // committed (un-acked) updates. + PerCommittedUpdate PerCommittedUpdateCB +} + +// NewClientSessionCfg constructs a new ClientSessionListCfg. +func NewClientSessionCfg() *ClientSessionListCfg { + return &ClientSessionListCfg{} +} + +// WithPerAckedUpdate constructs a functional option that will set a call-back +// function to be called for each of a client's acked updates. +func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PerAckedUpdate = cb + } +} + +// WithPerCommittedUpdate constructs a functional option that will set a +// call-back function to be called for each of a client's un-acked updates. +func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PerCommittedUpdate = cb + } +} + // getClientSession loads the full ClientSession associated with the serialized // session id. This method populates the CommittedUpdates, AckUpdates and Tower // in addition to the ClientSession's body. -func getClientSession(sessions, towers kvdb.RBucket, - idBytes []byte) (*ClientSession, error) { +func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, + opts ...ClientSessionListOption) (*ClientSession, error) { + + cfg := NewClientSessionCfg() + for _, o := range opts { + o(cfg) + } session, err := getClientSessionBody(sessions, idBytes) if err != nil { @@ -1178,13 +1232,17 @@ func getClientSession(sessions, towers kvdb.RBucket, sessionBkt := sessions.NestedReadBucket(idBytes) // Fetch the committed updates for this session. - commitedUpdates, err := getClientSessionCommits(sessionBkt) + commitedUpdates, err := getClientSessionCommits( + sessionBkt, session, cfg.PerCommittedUpdate, + ) if err != nil { return nil, err } // Fetch the acked updates for this session. - ackedUpdates, err := getClientSessionAcks(sessionBkt) + ackedUpdates, err := getClientSessionAcks( + sessionBkt, session, cfg.PerAckedUpdate, + ) if err != nil { return nil, err } @@ -1197,11 +1255,12 @@ func getClientSession(sessions, towers kvdb.RBucket, } // getClientSessionCommits retrieves all committed updates for the session -// identified by the serialized session id. -func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate, - error) { +// identified by the serialized session id. If a PerCommittedUpdateCB is +// provided, then it will be called for each of the session's committed updates. +func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, + cb PerCommittedUpdateCB) ([]CommittedUpdate, error) { - // Initialize commitedUpdates so that we can return an initialized map + // Initialize committedUpdates so that we can return an initialized map // if no committed updates exist. committedUpdates := make([]CommittedUpdate, 0) @@ -1220,6 +1279,10 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate, committedUpdates = append(committedUpdates, committedUpdate) + if cb != nil { + cb(s, &committedUpdate) + } + return nil }) if err != nil { @@ -1231,8 +1294,8 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate, // getClientSessionAcks retrieves all acked updates for the session identified // by the serialized session id. -func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID, - error) { +func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession, + cb PerAckedUpdateCB) (map[uint16]BackupID, error) { // Initialize ackedUpdates so that we can return an initialized map if // no acked updates exist. @@ -1254,6 +1317,10 @@ func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID, ackedUpdates[seqNum] = backupID + if cb != nil { + cb(s, seqNum, backupID) + } + return nil }) if err != nil { diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 2a3825e87..f569991fc 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -200,19 +200,21 @@ func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight ui // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (m *ClientDB) ListClientSessions( - tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { +func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID, + opts ...wtdb.ClientSessionListOption) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { m.mu.Lock() defer m.mu.Unlock() - return m.listClientSessions(tower) + return m.listClientSessions(tower, opts...) } // listClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (m *ClientDB) listClientSessions( - tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { +func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, + _ ...wtdb.ClientSessionListOption) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, session := range m.activeSessions {