mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-11-10 14:17:56 +01:00
watchtower: add ListClientSessions functional options
This commit adds functional options to the ListClientSessions call that can be used to perform a variety of extra operations during the DB query. These functional options are not yet used in this commit.
This commit is contained in:
@@ -83,10 +83,12 @@ type Client interface {
|
|||||||
|
|
||||||
// RegisteredTowers retrieves the list of watchtowers registered with
|
// RegisteredTowers retrieves the list of watchtowers registered with
|
||||||
// the client.
|
// the client.
|
||||||
RegisteredTowers() ([]*RegisteredTower, error)
|
RegisteredTowers(...wtdb.ClientSessionListOption) ([]*RegisteredTower,
|
||||||
|
error)
|
||||||
|
|
||||||
// LookupTower retrieves a registered watchtower through its public key.
|
// LookupTower retrieves a registered watchtower through its public key.
|
||||||
LookupTower(*btcec.PublicKey) (*RegisteredTower, error)
|
LookupTower(*btcec.PublicKey,
|
||||||
|
...wtdb.ClientSessionListOption) (*RegisteredTower, error)
|
||||||
|
|
||||||
// Stats returns the in-memory statistics of the client since startup.
|
// Stats returns the in-memory statistics of the client since startup.
|
||||||
Stats() ClientStats
|
Stats() ClientStats
|
||||||
@@ -363,7 +365,8 @@ func New(config *Config) (*TowerClient, error) {
|
|||||||
// tower.
|
// tower.
|
||||||
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
||||||
sessionFilter func(*wtdb.ClientSession) bool,
|
sessionFilter func(*wtdb.ClientSession) bool,
|
||||||
perActiveTower func(tower *wtdb.Tower)) (
|
perActiveTower func(tower *wtdb.Tower),
|
||||||
|
opts ...wtdb.ClientSessionListOption) (
|
||||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||||
|
|
||||||
towers, err := db.ListTowers()
|
towers, err := db.ListTowers()
|
||||||
@@ -373,7 +376,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
|||||||
|
|
||||||
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
||||||
for _, tower := range towers {
|
for _, tower := range towers {
|
||||||
sessions, err := db.ListClientSessions(&tower.ID)
|
sessions, err := db.ListClientSessions(&tower.ID, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -413,10 +416,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
|||||||
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
|
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
|
||||||
// ListClientSessions method should be used.
|
// ListClientSessions method should be used.
|
||||||
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
||||||
passesFilter func(*wtdb.ClientSession) bool) (
|
passesFilter func(*wtdb.ClientSession) bool,
|
||||||
|
opts ...wtdb.ClientSessionListOption) (
|
||||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||||
|
|
||||||
sessions, err := db.ListClientSessions(forTower)
|
sessions, err := db.ListClientSessions(forTower, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1233,13 +1237,15 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
|||||||
|
|
||||||
// RegisteredTowers retrieves the list of watchtowers registered with the
|
// RegisteredTowers retrieves the list of watchtowers registered with the
|
||||||
// client.
|
// client.
|
||||||
func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) {
|
func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
|
||||||
|
[]*RegisteredTower, error) {
|
||||||
|
|
||||||
// Retrieve all of our towers along with all of our sessions.
|
// Retrieve all of our towers along with all of our sessions.
|
||||||
towers, err := c.cfg.DB.ListTowers()
|
towers, err := c.cfg.DB.ListTowers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
clientSessions, err := c.cfg.DB.ListClientSessions(nil)
|
clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1272,13 +1278,15 @@ func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LookupTower retrieves a registered watchtower through its public key.
|
// LookupTower retrieves a registered watchtower through its public key.
|
||||||
func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey) (*RegisteredTower, error) {
|
func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey,
|
||||||
|
opts ...wtdb.ClientSessionListOption) (*RegisteredTower, error) {
|
||||||
|
|
||||||
tower, err := c.cfg.DB.LoadTower(pubKey)
|
tower, err := c.cfg.DB.LoadTower(pubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID)
|
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ type DB interface {
|
|||||||
// still be able to accept state updates. An optional tower ID can be
|
// still be able to accept state updates. An optional tower ID can be
|
||||||
// used to filter out any client sessions in the response that do not
|
// used to filter out any client sessions in the response that do not
|
||||||
// correspond to this tower.
|
// correspond to this tower.
|
||||||
ListClientSessions(*wtdb.TowerID) (
|
ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) (
|
||||||
map[wtdb.SessionID]*wtdb.ClientSession, error)
|
map[wtdb.SessionID]*wtdb.ClientSession, error)
|
||||||
|
|
||||||
// FetchChanSummaries loads a mapping from all registered channels to
|
// FetchChanSummaries loads a mapping from all registered channels to
|
||||||
|
|||||||
@@ -736,8 +736,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
|
|||||||
// ListClientSessions returns the set of all client sessions known to the db. An
|
// ListClientSessions returns the set of all client sessions known to the db. An
|
||||||
// optional tower ID can be used to filter out any client sessions in the
|
// optional tower ID can be used to filter out any client sessions in the
|
||||||
// response that do not correspond to this tower.
|
// response that do not correspond to this tower.
|
||||||
func (c *ClientDB) ListClientSessions(id *TowerID) (
|
func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||||
map[SessionID]*ClientSession, error) {
|
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||||
|
|
||||||
var clientSessions map[SessionID]*ClientSession
|
var clientSessions map[SessionID]*ClientSession
|
||||||
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||||
@@ -757,7 +757,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
|||||||
// known to the db.
|
// known to the db.
|
||||||
if id == nil {
|
if id == nil {
|
||||||
clientSessions, err = listClientAllSessions(
|
clientSessions, err = listClientAllSessions(
|
||||||
sessions, towers,
|
sessions, towers, opts...,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -769,7 +769,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
clientSessions, err = listTowerSessions(
|
clientSessions, err = listTowerSessions(
|
||||||
*id, sessions, towers, towerToSessionIndex,
|
*id, sessions, towers, towerToSessionIndex, opts...,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}, func() {
|
}, func() {
|
||||||
@@ -783,8 +783,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// listClientAllSessions returns the set of all client sessions known to the db.
|
// listClientAllSessions returns the set of all client sessions known to the db.
|
||||||
func listClientAllSessions(sessions,
|
func listClientAllSessions(sessions, towers kvdb.RBucket,
|
||||||
towers kvdb.RBucket) (map[SessionID]*ClientSession, error) {
|
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||||
|
|
||||||
clientSessions := make(map[SessionID]*ClientSession)
|
clientSessions := make(map[SessionID]*ClientSession)
|
||||||
err := sessions.ForEach(func(k, _ []byte) error {
|
err := sessions.ForEach(func(k, _ []byte) error {
|
||||||
@@ -792,7 +792,7 @@ func listClientAllSessions(sessions,
|
|||||||
// the CommittedUpdates and AckedUpdates on startup to resume
|
// the CommittedUpdates and AckedUpdates on startup to resume
|
||||||
// committed updates and compute the highest known commit height
|
// committed updates and compute the highest known commit height
|
||||||
// for each channel.
|
// for each channel.
|
||||||
session, err := getClientSession(sessions, towers, k)
|
session, err := getClientSession(sessions, towers, k, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -811,8 +811,8 @@ func listClientAllSessions(sessions,
|
|||||||
// listTowerSessions returns the set of all client sessions known to the db
|
// listTowerSessions returns the set of all client sessions known to the db
|
||||||
// that are associated with the given tower id.
|
// that are associated with the given tower id.
|
||||||
func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
|
func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
|
||||||
towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession,
|
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
|
||||||
error) {
|
map[SessionID]*ClientSession, error) {
|
||||||
|
|
||||||
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
|
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
|
||||||
if towerIndexBkt == nil {
|
if towerIndexBkt == nil {
|
||||||
@@ -825,7 +825,9 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
|
|||||||
// the CommittedUpdates and AckedUpdates on startup to resume
|
// the CommittedUpdates and AckedUpdates on startup to resume
|
||||||
// committed updates and compute the highest known commit height
|
// committed updates and compute the highest known commit height
|
||||||
// for each channel.
|
// for each channel.
|
||||||
session, err := getClientSession(sessionsBkt, towersBkt, k)
|
session, err := getClientSession(
|
||||||
|
sessionsBkt, towersBkt, k, opts...,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1157,11 +1159,63 @@ func getClientSessionBody(sessions kvdb.RBucket,
|
|||||||
return &session, nil
|
return &session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PerAckedUpdateCB describes the signature of a callback function that can be
|
||||||
|
// called for each of a session's acked updates.
|
||||||
|
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
|
||||||
|
|
||||||
|
// PerCommittedUpdateCB describes the signature of a callback function that can
|
||||||
|
// be called for each of a session's committed updates (updates that the client
|
||||||
|
// has not yet received an ACK for).
|
||||||
|
type PerCommittedUpdateCB func(*ClientSession, *CommittedUpdate)
|
||||||
|
|
||||||
|
// ClientSessionListOption describes the signature of a functional option that
|
||||||
|
// can be used when listing client sessions in order to provide any extra
|
||||||
|
// instruction to the query.
|
||||||
|
type ClientSessionListOption func(cfg *ClientSessionListCfg)
|
||||||
|
|
||||||
|
// ClientSessionListCfg defines various query parameters that will be used when
|
||||||
|
// querying the DB for client sessions.
|
||||||
|
type ClientSessionListCfg struct {
|
||||||
|
// PerAckedUpdate will, if set, be called for each of the session's
|
||||||
|
// acked updates.
|
||||||
|
PerAckedUpdate PerAckedUpdateCB
|
||||||
|
|
||||||
|
// PerCommittedUpdate will, if set, be called for each of the session's
|
||||||
|
// committed (un-acked) updates.
|
||||||
|
PerCommittedUpdate PerCommittedUpdateCB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientSessionCfg constructs a new ClientSessionListCfg.
|
||||||
|
func NewClientSessionCfg() *ClientSessionListCfg {
|
||||||
|
return &ClientSessionListCfg{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPerAckedUpdate constructs a functional option that will set a call-back
|
||||||
|
// function to be called for each of a client's acked updates.
|
||||||
|
func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption {
|
||||||
|
return func(cfg *ClientSessionListCfg) {
|
||||||
|
cfg.PerAckedUpdate = cb
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPerCommittedUpdate constructs a functional option that will set a
|
||||||
|
// call-back function to be called for each of a client's un-acked updates.
|
||||||
|
func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
|
||||||
|
return func(cfg *ClientSessionListCfg) {
|
||||||
|
cfg.PerCommittedUpdate = cb
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getClientSession loads the full ClientSession associated with the serialized
|
// getClientSession loads the full ClientSession associated with the serialized
|
||||||
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
|
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
|
||||||
// in addition to the ClientSession's body.
|
// in addition to the ClientSession's body.
|
||||||
func getClientSession(sessions, towers kvdb.RBucket,
|
func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
|
||||||
idBytes []byte) (*ClientSession, error) {
|
opts ...ClientSessionListOption) (*ClientSession, error) {
|
||||||
|
|
||||||
|
cfg := NewClientSessionCfg()
|
||||||
|
for _, o := range opts {
|
||||||
|
o(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
session, err := getClientSessionBody(sessions, idBytes)
|
session, err := getClientSessionBody(sessions, idBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1178,13 +1232,17 @@ func getClientSession(sessions, towers kvdb.RBucket,
|
|||||||
sessionBkt := sessions.NestedReadBucket(idBytes)
|
sessionBkt := sessions.NestedReadBucket(idBytes)
|
||||||
|
|
||||||
// Fetch the committed updates for this session.
|
// Fetch the committed updates for this session.
|
||||||
commitedUpdates, err := getClientSessionCommits(sessionBkt)
|
commitedUpdates, err := getClientSessionCommits(
|
||||||
|
sessionBkt, session, cfg.PerCommittedUpdate,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the acked updates for this session.
|
// Fetch the acked updates for this session.
|
||||||
ackedUpdates, err := getClientSessionAcks(sessionBkt)
|
ackedUpdates, err := getClientSessionAcks(
|
||||||
|
sessionBkt, session, cfg.PerAckedUpdate,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1197,11 +1255,12 @@ func getClientSession(sessions, towers kvdb.RBucket,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getClientSessionCommits retrieves all committed updates for the session
|
// getClientSessionCommits retrieves all committed updates for the session
|
||||||
// identified by the serialized session id.
|
// identified by the serialized session id. If a PerCommittedUpdateCB is
|
||||||
func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
|
// provided, then it will be called for each of the session's committed updates.
|
||||||
error) {
|
func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
|
||||||
|
cb PerCommittedUpdateCB) ([]CommittedUpdate, error) {
|
||||||
|
|
||||||
// Initialize commitedUpdates so that we can return an initialized map
|
// Initialize committedUpdates so that we can return an initialized map
|
||||||
// if no committed updates exist.
|
// if no committed updates exist.
|
||||||
committedUpdates := make([]CommittedUpdate, 0)
|
committedUpdates := make([]CommittedUpdate, 0)
|
||||||
|
|
||||||
@@ -1220,6 +1279,10 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
|
|||||||
|
|
||||||
committedUpdates = append(committedUpdates, committedUpdate)
|
committedUpdates = append(committedUpdates, committedUpdate)
|
||||||
|
|
||||||
|
if cb != nil {
|
||||||
|
cb(s, &committedUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1231,8 +1294,8 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
|
|||||||
|
|
||||||
// getClientSessionAcks retrieves all acked updates for the session identified
|
// getClientSessionAcks retrieves all acked updates for the session identified
|
||||||
// by the serialized session id.
|
// by the serialized session id.
|
||||||
func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID,
|
func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
|
||||||
error) {
|
cb PerAckedUpdateCB) (map[uint16]BackupID, error) {
|
||||||
|
|
||||||
// Initialize ackedUpdates so that we can return an initialized map if
|
// Initialize ackedUpdates so that we can return an initialized map if
|
||||||
// no acked updates exist.
|
// no acked updates exist.
|
||||||
@@ -1254,6 +1317,10 @@ func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID,
|
|||||||
|
|
||||||
ackedUpdates[seqNum] = backupID
|
ackedUpdates[seqNum] = backupID
|
||||||
|
|
||||||
|
if cb != nil {
|
||||||
|
cb(s, seqNum, backupID)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -200,19 +200,21 @@ func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight ui
|
|||||||
// ListClientSessions returns the set of all client sessions known to the db. An
|
// ListClientSessions returns the set of all client sessions known to the db. An
|
||||||
// optional tower ID can be used to filter out any client sessions in the
|
// optional tower ID can be used to filter out any client sessions in the
|
||||||
// response that do not correspond to this tower.
|
// response that do not correspond to this tower.
|
||||||
func (m *ClientDB) ListClientSessions(
|
func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
|
||||||
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
opts ...wtdb.ClientSessionListOption) (
|
||||||
|
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
return m.listClientSessions(tower)
|
return m.listClientSessions(tower, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// listClientSessions returns the set of all client sessions known to the db. An
|
// listClientSessions returns the set of all client sessions known to the db. An
|
||||||
// optional tower ID can be used to filter out any client sessions in the
|
// optional tower ID can be used to filter out any client sessions in the
|
||||||
// response that do not correspond to this tower.
|
// response that do not correspond to this tower.
|
||||||
func (m *ClientDB) listClientSessions(
|
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
|
||||||
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
_ ...wtdb.ClientSessionListOption) (
|
||||||
|
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||||
|
|
||||||
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
||||||
for _, session := range m.activeSessions {
|
for _, session := range m.activeSessions {
|
||||||
|
|||||||
Reference in New Issue
Block a user