watchtower: add ListClientSessions functional options

This commit adds functional options to the ListClientSessions call that
can be used to perform a variety of extra operations during the DB
query. These functional options are not yet used in this commit.
This commit is contained in:
Elle Mouton
2022-10-13 13:46:52 +02:00
parent 3ac3b6a90d
commit 40e0ebf417
4 changed files with 113 additions and 36 deletions

View File

@@ -83,10 +83,12 @@ type Client interface {
// RegisteredTowers retrieves the list of watchtowers registered with
// the client.
RegisteredTowers() ([]*RegisteredTower, error)
RegisteredTowers(...wtdb.ClientSessionListOption) ([]*RegisteredTower,
error)
// LookupTower retrieves a registered watchtower through its public key.
LookupTower(*btcec.PublicKey) (*RegisteredTower, error)
LookupTower(*btcec.PublicKey,
...wtdb.ClientSessionListOption) (*RegisteredTower, error)
// Stats returns the in-memory statistics of the client since startup.
Stats() ClientStats
@@ -363,7 +365,8 @@ func New(config *Config) (*TowerClient, error) {
// tower.
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
sessionFilter func(*wtdb.ClientSession) bool,
perActiveTower func(tower *wtdb.Tower)) (
perActiveTower func(tower *wtdb.Tower),
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
towers, err := db.ListTowers()
@@ -373,7 +376,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, tower := range towers {
sessions, err := db.ListClientSessions(&tower.ID)
sessions, err := db.ListClientSessions(&tower.ID, opts...)
if err != nil {
return nil, err
}
@@ -413,10 +416,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
// ListClientSessions method should be used.
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
passesFilter func(*wtdb.ClientSession) bool) (
passesFilter func(*wtdb.ClientSession) bool,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
sessions, err := db.ListClientSessions(forTower)
sessions, err := db.ListClientSessions(forTower, opts...)
if err != nil {
return nil, err
}
@@ -1233,13 +1237,15 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
// RegisteredTowers retrieves the list of watchtowers registered with the
// client.
func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) {
func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
[]*RegisteredTower, error) {
// Retrieve all of our towers along with all of our sessions.
towers, err := c.cfg.DB.ListTowers()
if err != nil {
return nil, err
}
clientSessions, err := c.cfg.DB.ListClientSessions(nil)
clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...)
if err != nil {
return nil, err
}
@@ -1272,13 +1278,15 @@ func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) {
}
// LookupTower retrieves a registered watchtower through its public key.
func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey) (*RegisteredTower, error) {
func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey,
opts ...wtdb.ClientSessionListOption) (*RegisteredTower, error) {
tower, err := c.cfg.DB.LoadTower(pubKey)
if err != nil {
return nil, err
}
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID)
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...)
if err != nil {
return nil, err
}

View File

@@ -62,7 +62,7 @@ type DB interface {
// still be able to accept state updates. An optional tower ID can be
// used to filter out any client sessions in the response that do not
// correspond to this tower.
ListClientSessions(*wtdb.TowerID) (
ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error)
// FetchChanSummaries loads a mapping from all registered channels to

View File

@@ -736,8 +736,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
// ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID) (
map[SessionID]*ClientSession, error) {
func (c *ClientDB) ListClientSessions(id *TowerID,
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
var clientSessions map[SessionID]*ClientSession
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
@@ -757,7 +757,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
// known to the db.
if id == nil {
clientSessions, err = listClientAllSessions(
sessions, towers,
sessions, towers, opts...,
)
return err
}
@@ -769,7 +769,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
}
clientSessions, err = listTowerSessions(
*id, sessions, towers, towerToSessionIndex,
*id, sessions, towers, towerToSessionIndex, opts...,
)
return err
}, func() {
@@ -783,8 +783,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
}
// listClientAllSessions returns the set of all client sessions known to the db.
func listClientAllSessions(sessions,
towers kvdb.RBucket) (map[SessionID]*ClientSession, error) {
func listClientAllSessions(sessions, towers kvdb.RBucket,
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
err := sessions.ForEach(func(k, _ []byte) error {
@@ -792,7 +792,7 @@ func listClientAllSessions(sessions,
// the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height
// for each channel.
session, err := getClientSession(sessions, towers, k)
session, err := getClientSession(sessions, towers, k, opts...)
if err != nil {
return err
}
@@ -811,8 +811,8 @@ func listClientAllSessions(sessions,
// listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id.
func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession,
error) {
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
if towerIndexBkt == nil {
@@ -825,7 +825,9 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
// the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height
// for each channel.
session, err := getClientSession(sessionsBkt, towersBkt, k)
session, err := getClientSession(
sessionsBkt, towersBkt, k, opts...,
)
if err != nil {
return err
}
@@ -1157,11 +1159,63 @@ func getClientSessionBody(sessions kvdb.RBucket,
return &session, nil
}
// PerAckedUpdateCB describes the signature of a callback function that can be
// called for each of a session's acked updates.
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
// PerCommittedUpdateCB describes the signature of a callback function that can
// be called for each of a session's committed updates (updates that the client
// has not yet received an ACK for).
type PerCommittedUpdateCB func(*ClientSession, *CommittedUpdate)
// ClientSessionListOption describes the signature of a functional option that
// can be used when listing client sessions in order to provide any extra
// instruction to the query.
type ClientSessionListOption func(cfg *ClientSessionListCfg)
// ClientSessionListCfg defines various query parameters that will be used when
// querying the DB for client sessions.
type ClientSessionListCfg struct {
// PerAckedUpdate will, if set, be called for each of the session's
// acked updates.
PerAckedUpdate PerAckedUpdateCB
// PerCommittedUpdate will, if set, be called for each of the session's
// committed (un-acked) updates.
PerCommittedUpdate PerCommittedUpdateCB
}
// NewClientSessionCfg constructs a new ClientSessionListCfg.
func NewClientSessionCfg() *ClientSessionListCfg {
return &ClientSessionListCfg{}
}
// WithPerAckedUpdate constructs a functional option that will set a call-back
// function to be called for each of a client's acked updates.
func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PerAckedUpdate = cb
}
}
// WithPerCommittedUpdate constructs a functional option that will set a
// call-back function to be called for each of a client's un-acked updates.
func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PerCommittedUpdate = cb
}
}
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body.
func getClientSession(sessions, towers kvdb.RBucket,
idBytes []byte) (*ClientSession, error) {
func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
opts ...ClientSessionListOption) (*ClientSession, error) {
cfg := NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}
session, err := getClientSessionBody(sessions, idBytes)
if err != nil {
@@ -1178,13 +1232,17 @@ func getClientSession(sessions, towers kvdb.RBucket,
sessionBkt := sessions.NestedReadBucket(idBytes)
// Fetch the committed updates for this session.
commitedUpdates, err := getClientSessionCommits(sessionBkt)
commitedUpdates, err := getClientSessionCommits(
sessionBkt, session, cfg.PerCommittedUpdate,
)
if err != nil {
return nil, err
}
// Fetch the acked updates for this session.
ackedUpdates, err := getClientSessionAcks(sessionBkt)
ackedUpdates, err := getClientSessionAcks(
sessionBkt, session, cfg.PerAckedUpdate,
)
if err != nil {
return nil, err
}
@@ -1197,11 +1255,12 @@ func getClientSession(sessions, towers kvdb.RBucket,
}
// getClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id.
func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
error) {
// identified by the serialized session id. If a PerCommittedUpdateCB is
// provided, then it will be called for each of the session's committed updates.
func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
cb PerCommittedUpdateCB) ([]CommittedUpdate, error) {
// Initialize commitedUpdates so that we can return an initialized map
// Initialize committedUpdates so that we can return an initialized map
// if no committed updates exist.
committedUpdates := make([]CommittedUpdate, 0)
@@ -1220,6 +1279,10 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
committedUpdates = append(committedUpdates, committedUpdate)
if cb != nil {
cb(s, &committedUpdate)
}
return nil
})
if err != nil {
@@ -1231,8 +1294,8 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
// getClientSessionAcks retrieves all acked updates for the session identified
// by the serialized session id.
func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID,
error) {
func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
cb PerAckedUpdateCB) (map[uint16]BackupID, error) {
// Initialize ackedUpdates so that we can return an initialized map if
// no acked updates exist.
@@ -1254,6 +1317,10 @@ func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID,
ackedUpdates[seqNum] = backupID
if cb != nil {
cb(s, seqNum, backupID)
}
return nil
})
if err != nil {

View File

@@ -200,19 +200,21 @@ func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight ui
// ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) ListClientSessions(
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.listClientSessions(tower)
return m.listClientSessions(tower, opts...)
}
// listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) listClientSessions(
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
_ ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, session := range m.activeSessions {