mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-07 19:30:46 +02:00
watchtower: add ListClientSessions functional options
This commit adds functional options to the ListClientSessions call that can be used to perform a variety of extra operations during the DB query. These functional options are not yet used in this commit.
This commit is contained in:
@@ -736,8 +736,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
|
||||
// ListClientSessions returns the set of all client sessions known to the db. An
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
var clientSessions map[SessionID]*ClientSession
|
||||
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||
@@ -757,7 +757,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
// known to the db.
|
||||
if id == nil {
|
||||
clientSessions, err = listClientAllSessions(
|
||||
sessions, towers,
|
||||
sessions, towers, opts...,
|
||||
)
|
||||
return err
|
||||
}
|
||||
@@ -769,7 +769,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
}
|
||||
|
||||
clientSessions, err = listTowerSessions(
|
||||
*id, sessions, towers, towerToSessionIndex,
|
||||
*id, sessions, towers, towerToSessionIndex, opts...,
|
||||
)
|
||||
return err
|
||||
}, func() {
|
||||
@@ -783,8 +783,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
||||
}
|
||||
|
||||
// listClientAllSessions returns the set of all client sessions known to the db.
|
||||
func listClientAllSessions(sessions,
|
||||
towers kvdb.RBucket) (map[SessionID]*ClientSession, error) {
|
||||
func listClientAllSessions(sessions, towers kvdb.RBucket,
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
clientSessions := make(map[SessionID]*ClientSession)
|
||||
err := sessions.ForEach(func(k, _ []byte) error {
|
||||
@@ -792,7 +792,7 @@ func listClientAllSessions(sessions,
|
||||
// the CommittedUpdates and AckedUpdates on startup to resume
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := getClientSession(sessions, towers, k)
|
||||
session, err := getClientSession(sessions, towers, k, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -811,8 +811,8 @@ func listClientAllSessions(sessions,
|
||||
// listTowerSessions returns the set of all client sessions known to the db
|
||||
// that are associated with the given tower id.
|
||||
func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
|
||||
towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession,
|
||||
error) {
|
||||
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
|
||||
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
|
||||
if towerIndexBkt == nil {
|
||||
@@ -825,7 +825,9 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
|
||||
// the CommittedUpdates and AckedUpdates on startup to resume
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := getClientSession(sessionsBkt, towersBkt, k)
|
||||
session, err := getClientSession(
|
||||
sessionsBkt, towersBkt, k, opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1157,11 +1159,63 @@ func getClientSessionBody(sessions kvdb.RBucket,
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// PerAckedUpdateCB describes the signature of a callback function that can be
|
||||
// called for each of a session's acked updates.
|
||||
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
|
||||
|
||||
// PerCommittedUpdateCB describes the signature of a callback function that can
|
||||
// be called for each of a session's committed updates (updates that the client
|
||||
// has not yet received an ACK for).
|
||||
type PerCommittedUpdateCB func(*ClientSession, *CommittedUpdate)
|
||||
|
||||
// ClientSessionListOption describes the signature of a functional option that
|
||||
// can be used when listing client sessions in order to provide any extra
|
||||
// instruction to the query.
|
||||
type ClientSessionListOption func(cfg *ClientSessionListCfg)
|
||||
|
||||
// ClientSessionListCfg defines various query parameters that will be used when
|
||||
// querying the DB for client sessions.
|
||||
type ClientSessionListCfg struct {
|
||||
// PerAckedUpdate will, if set, be called for each of the session's
|
||||
// acked updates.
|
||||
PerAckedUpdate PerAckedUpdateCB
|
||||
|
||||
// PerCommittedUpdate will, if set, be called for each of the session's
|
||||
// committed (un-acked) updates.
|
||||
PerCommittedUpdate PerCommittedUpdateCB
|
||||
}
|
||||
|
||||
// NewClientSessionCfg constructs a new ClientSessionListCfg.
|
||||
func NewClientSessionCfg() *ClientSessionListCfg {
|
||||
return &ClientSessionListCfg{}
|
||||
}
|
||||
|
||||
// WithPerAckedUpdate constructs a functional option that will set a call-back
|
||||
// function to be called for each of a client's acked updates.
|
||||
func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption {
|
||||
return func(cfg *ClientSessionListCfg) {
|
||||
cfg.PerAckedUpdate = cb
|
||||
}
|
||||
}
|
||||
|
||||
// WithPerCommittedUpdate constructs a functional option that will set a
|
||||
// call-back function to be called for each of a client's un-acked updates.
|
||||
func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
|
||||
return func(cfg *ClientSessionListCfg) {
|
||||
cfg.PerCommittedUpdate = cb
|
||||
}
|
||||
}
|
||||
|
||||
// getClientSession loads the full ClientSession associated with the serialized
|
||||
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
|
||||
// in addition to the ClientSession's body.
|
||||
func getClientSession(sessions, towers kvdb.RBucket,
|
||||
idBytes []byte) (*ClientSession, error) {
|
||||
func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
|
||||
opts ...ClientSessionListOption) (*ClientSession, error) {
|
||||
|
||||
cfg := NewClientSessionCfg()
|
||||
for _, o := range opts {
|
||||
o(cfg)
|
||||
}
|
||||
|
||||
session, err := getClientSessionBody(sessions, idBytes)
|
||||
if err != nil {
|
||||
@@ -1178,13 +1232,17 @@ func getClientSession(sessions, towers kvdb.RBucket,
|
||||
sessionBkt := sessions.NestedReadBucket(idBytes)
|
||||
|
||||
// Fetch the committed updates for this session.
|
||||
commitedUpdates, err := getClientSessionCommits(sessionBkt)
|
||||
commitedUpdates, err := getClientSessionCommits(
|
||||
sessionBkt, session, cfg.PerCommittedUpdate,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the acked updates for this session.
|
||||
ackedUpdates, err := getClientSessionAcks(sessionBkt)
|
||||
ackedUpdates, err := getClientSessionAcks(
|
||||
sessionBkt, session, cfg.PerAckedUpdate,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1197,11 +1255,12 @@ func getClientSession(sessions, towers kvdb.RBucket,
|
||||
}
|
||||
|
||||
// getClientSessionCommits retrieves all committed updates for the session
|
||||
// identified by the serialized session id.
|
||||
func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
|
||||
error) {
|
||||
// identified by the serialized session id. If a PerCommittedUpdateCB is
|
||||
// provided, then it will be called for each of the session's committed updates.
|
||||
func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
|
||||
cb PerCommittedUpdateCB) ([]CommittedUpdate, error) {
|
||||
|
||||
// Initialize commitedUpdates so that we can return an initialized map
|
||||
// Initialize committedUpdates so that we can return an initialized map
|
||||
// if no committed updates exist.
|
||||
committedUpdates := make([]CommittedUpdate, 0)
|
||||
|
||||
@@ -1220,6 +1279,10 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
|
||||
|
||||
committedUpdates = append(committedUpdates, committedUpdate)
|
||||
|
||||
if cb != nil {
|
||||
cb(s, &committedUpdate)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1231,8 +1294,8 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate,
|
||||
|
||||
// getClientSessionAcks retrieves all acked updates for the session identified
|
||||
// by the serialized session id.
|
||||
func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID,
|
||||
error) {
|
||||
func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
|
||||
cb PerAckedUpdateCB) (map[uint16]BackupID, error) {
|
||||
|
||||
// Initialize ackedUpdates so that we can return an initialized map if
|
||||
// no acked updates exist.
|
||||
@@ -1254,6 +1317,10 @@ func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID,
|
||||
|
||||
ackedUpdates[seqNum] = backupID
|
||||
|
||||
if cb != nil {
|
||||
cb(s, seqNum, backupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
|
Reference in New Issue
Block a user