mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-07 19:30:46 +02:00
watchtower: add PreEvaluateFilterFn callback
In this commit, a PreEvaluateFilterFn option is added to the wtdb.ClientSessionListCfg and it is used instead of a separate ClientSessionFilterFn parameter. This neatens quiet a few function signatures.
This commit is contained in:
@@ -517,7 +517,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
|
||||
towerSessions, err := c.listTowerSessions(
|
||||
towerID, sessions, chanIDIndexBkt,
|
||||
towersToSessionsIndex, nil,
|
||||
towersToSessionsIndex,
|
||||
WithPerCommittedUpdate(perCommittedUpdate),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1055,7 +1055,7 @@ func (c *ClientDB) GetClientSession(id SessionID,
|
||||
}
|
||||
|
||||
session, err := c.getClientSession(
|
||||
sessionsBkt, chanIDIndexBkt, id[:], nil, opts...,
|
||||
sessionsBkt, chanIDIndexBkt, id[:], opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1073,8 +1073,7 @@ func (c *ClientDB) GetClientSession(id SessionID,
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
var clientSessions map[SessionID]*ClientSession
|
||||
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||
@@ -1093,7 +1092,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
var err error
|
||||
if id == nil {
|
||||
clientSessions, err = c.listClientAllSessions(
|
||||
sessions, chanIDIndexBkt, filterFn, opts...,
|
||||
sessions, chanIDIndexBkt, opts...,
|
||||
)
|
||||
return err
|
||||
}
|
||||
@@ -1106,7 +1105,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
|
||||
clientSessions, err = c.listTowerSessions(
|
||||
*id, sessions, chanIDIndexBkt, towerToSessionIndex,
|
||||
filterFn, opts...,
|
||||
opts...,
|
||||
)
|
||||
return err
|
||||
}, func() {
|
||||
@@ -1121,8 +1120,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
|
||||
// listClientAllSessions returns the set of all client sessions known to the db.
|
||||
func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
|
||||
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
clientSessions := make(map[SessionID]*ClientSession)
|
||||
err := sessions.ForEach(func(k, _ []byte) error {
|
||||
@@ -1131,7 +1129,7 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := c.getClientSession(
|
||||
sessions, chanIDIndexBkt, k, filterFn, opts...,
|
||||
sessions, chanIDIndexBkt, k, opts...,
|
||||
)
|
||||
if errors.Is(err, ErrSessionFailedFilterFn) {
|
||||
return nil
|
||||
@@ -1153,8 +1151,8 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
|
||||
// listTowerSessions returns the set of all client sessions known to the db
|
||||
// that are associated with the given tower id.
|
||||
func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
|
||||
towerToSessionIndex kvdb.RBucket, filterFn ClientSessionFilterFn,
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
|
||||
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
|
||||
if towerIndexBkt == nil {
|
||||
@@ -1168,7 +1166,7 @@ func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := c.getClientSession(
|
||||
sessionsBkt, chanIDIndexBkt, k, filterFn, opts...,
|
||||
sessionsBkt, chanIDIndexBkt, k, opts...,
|
||||
)
|
||||
if errors.Is(err, ErrSessionFailedFilterFn) {
|
||||
return nil
|
||||
@@ -2140,6 +2138,13 @@ type ClientSessionListCfg struct {
|
||||
// PerCommittedUpdate will, if set, be called for each of the session's
|
||||
// committed (un-acked) updates.
|
||||
PerCommittedUpdate PerCommittedUpdateCB
|
||||
|
||||
// PreEvaluateFilterFn will be run after loading a session from the DB
|
||||
// and _before_ any of the other call-back functions in
|
||||
// ClientSessionListCfg. Therefore, if a session fails this filter
|
||||
// function, then it will not be passed to any of the other call backs
|
||||
// and won't be included in the return list.
|
||||
PreEvaluateFilterFn ClientSessionFilterFn
|
||||
}
|
||||
|
||||
// NewClientSessionCfg constructs a new ClientSessionListCfg.
|
||||
@@ -2174,12 +2179,22 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPreEvalFilterFn constructs a functional option that will set a call-back
|
||||
// function that will be called immediately after loading a session. If the
|
||||
// session fails this filter function, then it will not be passed to any of the
|
||||
// other evaluation call-back functions.
|
||||
func WithPreEvalFilterFn(fn ClientSessionFilterFn) ClientSessionListOption {
|
||||
return func(cfg *ClientSessionListCfg) {
|
||||
cfg.PreEvaluateFilterFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// getClientSession loads the full ClientSession associated with the serialized
|
||||
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
|
||||
// in addition to the ClientSession's body.
|
||||
func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
|
||||
idBytes []byte, filterFn ClientSessionFilterFn,
|
||||
opts ...ClientSessionListOption) (*ClientSession, error) {
|
||||
idBytes []byte, opts ...ClientSessionListOption) (*ClientSession,
|
||||
error) {
|
||||
|
||||
cfg := NewClientSessionCfg()
|
||||
for _, o := range opts {
|
||||
@@ -2191,7 +2206,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if filterFn != nil && !filterFn(session) {
|
||||
if cfg.PreEvaluateFilterFn != nil && !cfg.PreEvaluateFilterFn(session) {
|
||||
return nil, ErrSessionFailedFilterFn
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user