mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-29 23:21:12 +02:00
watchtower: add ListClientSessions functional options
This commit adds functional options to the ListClientSessions call that can be used to perform a variety of extra operations during the DB query. These functional options are not yet used in this commit.
This commit is contained in:
@@ -83,10 +83,12 @@ type Client interface {
|
||||
|
||||
// RegisteredTowers retrieves the list of watchtowers registered with
|
||||
// the client.
|
||||
RegisteredTowers() ([]*RegisteredTower, error)
|
||||
RegisteredTowers(...wtdb.ClientSessionListOption) ([]*RegisteredTower,
|
||||
error)
|
||||
|
||||
// LookupTower retrieves a registered watchtower through its public key.
|
||||
LookupTower(*btcec.PublicKey) (*RegisteredTower, error)
|
||||
LookupTower(*btcec.PublicKey,
|
||||
...wtdb.ClientSessionListOption) (*RegisteredTower, error)
|
||||
|
||||
// Stats returns the in-memory statistics of the client since startup.
|
||||
Stats() ClientStats
|
||||
@@ -363,7 +365,8 @@ func New(config *Config) (*TowerClient, error) {
|
||||
// tower.
|
||||
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
||||
sessionFilter func(*wtdb.ClientSession) bool,
|
||||
perActiveTower func(tower *wtdb.Tower)) (
|
||||
perActiveTower func(tower *wtdb.Tower),
|
||||
opts ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
towers, err := db.ListTowers()
|
||||
@@ -373,7 +376,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
||||
|
||||
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
||||
for _, tower := range towers {
|
||||
sessions, err := db.ListClientSessions(&tower.ID)
|
||||
sessions, err := db.ListClientSessions(&tower.ID, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -413,10 +416,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
||||
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
|
||||
// ListClientSessions method should be used.
|
||||
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
||||
passesFilter func(*wtdb.ClientSession) bool) (
|
||||
passesFilter func(*wtdb.ClientSession) bool,
|
||||
opts ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
sessions, err := db.ListClientSessions(forTower)
|
||||
sessions, err := db.ListClientSessions(forTower, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1233,13 +1237,15 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
||||
|
||||
// RegisteredTowers retrieves the list of watchtowers registered with the
|
||||
// client.
|
||||
func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) {
|
||||
func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
|
||||
[]*RegisteredTower, error) {
|
||||
|
||||
// Retrieve all of our towers along with all of our sessions.
|
||||
towers, err := c.cfg.DB.ListTowers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientSessions, err := c.cfg.DB.ListClientSessions(nil)
|
||||
clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1272,13 +1278,15 @@ func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) {
|
||||
}
|
||||
|
||||
// LookupTower retrieves a registered watchtower through its public key.
|
||||
func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey) (*RegisteredTower, error) {
|
||||
func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey,
|
||||
opts ...wtdb.ClientSessionListOption) (*RegisteredTower, error) {
|
||||
|
||||
tower, err := c.cfg.DB.LoadTower(pubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID)
|
||||
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -62,7 +62,7 @@ type DB interface {
|
||||
// still be able to accept state updates. An optional tower ID can be
|
||||
// used to filter out any client sessions in the response that do not
|
||||
// correspond to this tower.
|
||||
ListClientSessions(*wtdb.TowerID) (
|
||||
ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error)
|
||||
|
||||
// FetchChanSummaries loads a mapping from all registered channels to
|
||||
|
@@ -736,8 +736,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
|
||||
// ListClientSessions returns the set of all client sessions known to the db. An
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
var clientSessions map[SessionID]*ClientSession
|
||||
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||
@@ -757,7 +757,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
// known to the db.
|
||||
if id == nil {
|
||||
clientSessions, err = listClientAllSessions(
|
||||
sessions, towers,
|
||||
sessions, towers, opts...,
|
||||
)
|
||||
return err
|
||||
}
|
||||
@@ -769,7 +769,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
}
|
||||
|
||||
clientSessions, err = listTowerSessions(
|
||||
*id, sessions, towers, towerToSessionIndex,
|
||||
*id, sessions, towers, towerToSessionIndex, opts...,
|
||||
)
|
||||
return err
|
||||
}, func() {
|
||||
@@ -783,8 +783,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
}
|
||||
|
||||
// listClientAllSessions returns the set of all client sessions known to the db.
|
||||
func listClientAllSessions(sessions,
|
||||
towers kvdb.RBucket) (map[SessionID]*ClientSession, error) {
|
||||
func listClientAllSessions(sessions, towers kvdb.RBucket,
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
clientSessions := make(map[SessionID]*ClientSession)
|
||||
err := sessions.ForEach(func(k, _ []byte) error {
|
||||
@@ -792,7 +792,7 @@ func listClientAllSessions(sessions,
|
||||
// the CommittedUpdates and AckedUpdates on startup to resume
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := getClientSession(sessions, towers, k)
|
||||
session, err := getClientSession(sessions, towers, k, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -811,8 +811,8 @@ func listClientAllSessions(sessions,
|
||||
// listTowerSessions returns the set of all client sessions known to the db
|
||||
// that are associated with the given tower id.
|
||||
func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
|
||||
towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession,
|
||||
error) {
|
||||
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
|
||||
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
|
||||
if towerIndexBkt == nil {
|
||||
@@ -825,7 +825,9 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
|
||||
// the CommittedUpdates and AckedUpdates on startup to resume
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := getClientSession(sessionsBkt, towersBkt, k)
|
||||
session, err := getClientSession(
|
||||
sessionsBkt, towersBkt, k, opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1157,11 +1159,63 @@ func getClientSessionBody(sessions kvdb.RBucket,
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// PerAckedUpdateCB describes the signature of a callback function that can be
|
||||
// called for each of a session's acked updates.
|
||||
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
|
||||
|
||||
// PerCommittedUpdateCB describes the signature of a callback function that can
|
||||
// be called for each of a session's committed updates (updates that the client
|
||||
// has not yet received an ACK for).
|
||||
type PerCommittedUpdateCB func(*ClientSession, *CommittedUpdate)
|
||||
|
||||
// ClientSessionListOption describes the signature of a functional option that
|
||||
// can be used when listing client sessions in order to provide any extra
|
||||
// instruction to the query.
|
||||
type ClientSessionListOption func(cfg *ClientSessionListCfg)
|
||||
|
||||
// ClientSessionListCfg defines various query parameters that will be used when
|
||||
// querying the DB for client sessions.
|
||||
type ClientSessionListCfg struct {
|
||||
// PerAckedUpdate will, if set, be called for each of the session's
|
||||
// acked updates.
|
||||
PerAckedUpdate PerAckedUpdateCB
|
||||
|
||||
// PerCommittedUpdate will, if set, be called for each of the session's
|
||||
// committed (un-acked) updates.
|
||||
PerCommittedUpdate PerCommittedUpdateCB
|
||||
}
|
||||
|
||||
// NewClientSessionCfg constructs a new ClientSessionListCfg.
|
||||
func NewClientSessionCfg() *ClientSessionListCfg {
|
||||
return &ClientSessionListCfg{}
|
||||
}
|
||||
|
||||
// WithPerAckedUpdate constructs a functional option that will set a call-back
|
||||
// function to be called for each of a client's acked updates.
|
||||
func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption {
|
||||
return func(cfg *ClientSessionListCfg) {
|
||||
cfg.PerAckedUpdate = cb
|
||||
}
|
||||
}
|
||||
|
||||
// WithPerCommittedUpdate constructs a functional option that will set a
|
||||
// call-back function to be called for each of a client's un-acked updates.
|
||||
func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
|
||||
return func(cfg *ClientSessionListCfg) {
|
||||
cfg.PerCommittedUpdate = cb
|
||||
}
|
||||
}
|
||||
|
||||
// getClientSession loads the full ClientSession associated with the serialized
|
||||
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
|
||||
// in addition to the ClientSession's body.
|
||||
func getClientSession(sessions, towers kvdb.RBucket,
|
||||
idBytes []byte) (*ClientSession, error) {
|
||||
func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
|
||||
opts ...ClientSessionListOption) (*ClientSession, error) {
|
||||
|
||||
cfg := NewClientSessionCfg()
|
||||
for _, o := range opts {
|
||||
o(cfg)
|
||||
}
|
||||
|
||||
session, err := getClientSessionBody(sessions, idBytes)
|
||||
if err != nil {
|
||||
@@ -1178,13 +1232,17 @@ func getClientSession(sessions, towers kvdb.RBucket,
|
||||
sessionBkt := sessions.NestedReadBucket(idBytes)
|
||||
|
||||
// Fetch the committed updates for this session.
|
||||
commitedUpdates, err := getClientSessionCommits(sessionBkt)
|
||||
commitedUpdates, err := getClientSessionCommits(
|
||||
sessionBkt, session, cfg.PerCommittedUpdate,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the acked updates for this session.
|
||||
ackedUpdates, err := getClientSessionAcks(sessionBkt)
|
||||
ackedUpdates, err := getClientSessionAcks(
|
||||
sessionBkt, session, cfg.PerAckedUpdate,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1197,11 +1255,12 @@ func getClientSession(sessions, towers kvdb.RBucket,
|
||||
}
|
||||
|
||||
// getClientSessionCommits retrieves all committed updates for the session
|
||||
// identified by the serialized session id.
|
||||
func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
|
||||
error) {
|
||||
// identified by the serialized session id. If a PerCommittedUpdateCB is
|
||||
// provided, then it will be called for each of the session's committed updates.
|
||||
func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
|
||||
cb PerCommittedUpdateCB) ([]CommittedUpdate, error) {
|
||||
|
||||
// Initialize commitedUpdates so that we can return an initialized map
|
||||
// Initialize committedUpdates so that we can return an initialized map
|
||||
// if no committed updates exist.
|
||||
committedUpdates := make([]CommittedUpdate, 0)
|
||||
|
||||
@@ -1220,6 +1279,10 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
|
||||
|
||||
committedUpdates = append(committedUpdates, committedUpdate)
|
||||
|
||||
if cb != nil {
|
||||
cb(s, &committedUpdate)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1231,8 +1294,8 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
|
||||
|
||||
// getClientSessionAcks retrieves all acked updates for the session identified
|
||||
// by the serialized session id.
|
||||
func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID,
|
||||
error) {
|
||||
func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
|
||||
cb PerAckedUpdateCB) (map[uint16]BackupID, error) {
|
||||
|
||||
// Initialize ackedUpdates so that we can return an initialized map if
|
||||
// no acked updates exist.
|
||||
@@ -1254,6 +1317,10 @@ func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID,
|
||||
|
||||
ackedUpdates[seqNum] = backupID
|
||||
|
||||
if cb != nil {
|
||||
cb(s, seqNum, backupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
|
@@ -200,19 +200,21 @@ func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight ui
|
||||
// ListClientSessions returns the set of all client sessions known to the db. An
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (m *ClientDB) ListClientSessions(
|
||||
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
|
||||
opts ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.listClientSessions(tower)
|
||||
return m.listClientSessions(tower, opts...)
|
||||
}
|
||||
|
||||
// listClientSessions returns the set of all client sessions known to the db. An
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (m *ClientDB) listClientSessions(
|
||||
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
|
||||
_ ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
||||
for _, session := range m.activeSessions {
|
||||
|
Reference in New Issue
Block a user