watchtower: add PreEvaluateFilterFn callback

In this commit, a PreEvaluateFilterFn option is added to the
wtdb.ClientSessionListCfg and it is used instead of a separate
ClientSessionFilterFn parameter. This neatens quiet a few function
signatures.
This commit is contained in:
Elle Mouton
2023-03-20 17:06:48 +02:00
parent c4c1f1ac92
commit 7bc86ca42e
6 changed files with 61 additions and 57 deletions

View File

@@ -517,7 +517,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
towerSessions, err := c.listTowerSessions(
towerID, sessions, chanIDIndexBkt,
towersToSessionsIndex, nil,
towersToSessionsIndex,
WithPerCommittedUpdate(perCommittedUpdate),
)
if err != nil {
@@ -1055,7 +1055,7 @@ func (c *ClientDB) GetClientSession(id SessionID,
}
session, err := c.getClientSession(
sessionsBkt, chanIDIndexBkt, id[:], nil, opts...,
sessionsBkt, chanIDIndexBkt, id[:], opts...,
)
if err != nil {
return err
@@ -1073,8 +1073,7 @@ func (c *ClientDB) GetClientSession(id SessionID,
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID,
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
var clientSessions map[SessionID]*ClientSession
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
@@ -1093,7 +1092,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
var err error
if id == nil {
clientSessions, err = c.listClientAllSessions(
sessions, chanIDIndexBkt, filterFn, opts...,
sessions, chanIDIndexBkt, opts...,
)
return err
}
@@ -1106,7 +1105,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
clientSessions, err = c.listTowerSessions(
*id, sessions, chanIDIndexBkt, towerToSessionIndex,
filterFn, opts...,
opts...,
)
return err
}, func() {
@@ -1121,8 +1120,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
// listClientAllSessions returns the set of all client sessions known to the db.
func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
err := sessions.ForEach(func(k, _ []byte) error {
@@ -1131,7 +1129,7 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
// committed updates and compute the highest known commit height
// for each channel.
session, err := c.getClientSession(
sessions, chanIDIndexBkt, k, filterFn, opts...,
sessions, chanIDIndexBkt, k, opts...,
)
if errors.Is(err, ErrSessionFailedFilterFn) {
return nil
@@ -1153,8 +1151,8 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
// listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id.
func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
towerToSessionIndex kvdb.RBucket, filterFn ClientSessionFilterFn,
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
if towerIndexBkt == nil {
@@ -1168,7 +1166,7 @@ func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
// committed updates and compute the highest known commit height
// for each channel.
session, err := c.getClientSession(
sessionsBkt, chanIDIndexBkt, k, filterFn, opts...,
sessionsBkt, chanIDIndexBkt, k, opts...,
)
if errors.Is(err, ErrSessionFailedFilterFn) {
return nil
@@ -2140,6 +2138,13 @@ type ClientSessionListCfg struct {
// PerCommittedUpdate will, if set, be called for each of the session's
// committed (un-acked) updates.
PerCommittedUpdate PerCommittedUpdateCB
// PreEvaluateFilterFn will be run after loading a session from the DB
// and _before_ any of the other call-back functions in
// ClientSessionListCfg. Therefore, if a session fails this filter
// function, then it will not be passed to any of the other call backs
// and won't be included in the return list.
PreEvaluateFilterFn ClientSessionFilterFn
}
// NewClientSessionCfg constructs a new ClientSessionListCfg.
@@ -2174,12 +2179,22 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
}
}
// WithPreEvalFilterFn constructs a functional option that will set a call-back
// function that will be called immediately after loading a session. If the
// session fails this filter function, then it will not be passed to any of the
// other evaluation call-back functions.
func WithPreEvalFilterFn(fn ClientSessionFilterFn) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PreEvaluateFilterFn = fn
}
}
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body.
func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
idBytes []byte, filterFn ClientSessionFilterFn,
opts ...ClientSessionListOption) (*ClientSession, error) {
idBytes []byte, opts ...ClientSessionListOption) (*ClientSession,
error) {
cfg := NewClientSessionCfg()
for _, o := range opts {
@@ -2191,7 +2206,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
return nil, err
}
if filterFn != nil && !filterFn(session) {
if cfg.PreEvaluateFilterFn != nil && !cfg.PreEvaluateFilterFn(session) {
return nil, ErrSessionFailedFilterFn
}