watchtower: add PreEvaluateFilterFn callback

In this commit, a PreEvaluateFilterFn option is added to the
wtdb.ClientSessionListCfg and it is used instead of a separate
ClientSessionFilterFn parameter. This neatens quiet a few function
signatures.
This commit is contained in:
Elle Mouton
2023-03-20 17:06:48 +02:00
parent c4c1f1ac92
commit 7bc86ca42e
6 changed files with 61 additions and 57 deletions

View File

@@ -410,8 +410,9 @@ func New(config *Config) (*TowerClient, error) {
// current policy of the client, otherwise they will be ignored and new
// sessions will be requested.
candidateSessions, err := getTowerAndSessionCandidates(
cfg.DB, cfg.SecretKeyRing, c.genSessionFilter(true),
perActiveTower, wtdb.WithPerMaxHeight(perMaxHeight),
cfg.DB, cfg.SecretKeyRing, perActiveTower,
wtdb.WithPreEvalFilterFn(c.genSessionFilter(true)),
wtdb.WithPerMaxHeight(perMaxHeight),
wtdb.WithPerCommittedUpdate(perCommittedUpdate),
)
if err != nil {
@@ -444,7 +445,6 @@ func New(config *Config) (*TowerClient, error) {
// sessionFilter check then the perActiveTower call-back will be called on that
// tower.
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
sessionFilter wtdb.ClientSessionFilterFn,
perActiveTower func(tower *Tower),
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*ClientSession, error) {
@@ -461,18 +461,12 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
return nil, err
}
sessions, err := db.ListClientSessions(
&tower.ID, sessionFilter, opts...,
)
sessions, err := db.ListClientSessions(&tower.ID, opts...)
if err != nil {
return nil, err
}
for _, s := range sessions {
if !sessionFilter(s) {
continue
}
cs, err := NewClientSessionFromDBSession(
s, tower, keyRing,
)
@@ -498,13 +492,10 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
// ListClientSessions method should be used.
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
sessionFilter wtdb.ClientSessionFilterFn,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*ClientSession, error) {
dbSessions, err := db.ListClientSessions(
forTower, sessionFilter, opts...,
)
dbSessions, err := db.ListClientSessions(forTower, opts...)
if err != nil {
return nil, err
}
@@ -1635,7 +1626,7 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error {
// Include all of its corresponding sessions to our set of candidates.
sessions, err := getClientSessions(
c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID,
c.genSessionFilter(true),
wtdb.WithPreEvalFilterFn(c.genSessionFilter(true)),
)
if err != nil {
return fmt.Errorf("unable to determine sessions for tower %x: "+
@@ -1721,7 +1712,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
// Otherwise, the tower should no longer be used for future session
// negotiations and backups.
pubKey := msg.pubKey.SerializeCompressed()
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID, nil)
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID)
if err != nil {
return fmt.Errorf("unable to retrieve sessions for tower %x: "+
"%v", pubKey, err)
@@ -1752,9 +1743,10 @@ func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
if err != nil {
return nil, err
}
clientSessions, err := c.cfg.DB.ListClientSessions(
nil, c.genSessionFilter(false), opts...,
)
opts = append(opts, wtdb.WithPreEvalFilterFn(c.genSessionFilter(false)))
clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...)
if err != nil {
return nil, err
}
@@ -1795,9 +1787,9 @@ func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey,
return nil, err
}
towerSessions, err := c.cfg.DB.ListClientSessions(
&tower.ID, c.genSessionFilter(false), opts...,
)
opts = append(opts, wtdb.WithPreEvalFilterFn(c.genSessionFilter(false)))
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...)
if err != nil {
return nil, err
}

View File

@@ -870,7 +870,7 @@ func (h *testHarness) relevantSessions(chanID uint64) []wtdb.SessionID {
},
)
_, err := h.clientDB.ListClientSessions(nil, nil, collectSessions)
_, err := h.clientDB.ListClientSessions(nil, collectSessions)
require.NoError(h.t, err)
return sessionIDs
@@ -1969,7 +1969,7 @@ var clientTests = []clientTest{
// Also make a note of the total number of sessions the
// client has.
sessions, err := h.clientDB.ListClientSessions(nil, nil)
sessions, err := h.clientDB.ListClientSessions(nil)
require.NoError(h.t, err)
require.Len(h.t, sessions, 4)
@@ -1981,9 +1981,7 @@ var clientTests = []clientTest{
// marked as closable. The server should also no longer
// have these sessions in its DB.
err = wait.Predicate(func() bool {
sess, err := h.clientDB.ListClientSessions(
nil, nil,
)
sess, err := h.clientDB.ListClientSessions(nil)
require.NoError(h.t, err)
cs, err := h.clientDB.ListClosableSessions()

View File

@@ -60,8 +60,7 @@ type DB interface {
// ListClientSessions returns the set of all client sessions known to
// the db. An optional tower ID can be used to filter out any client
// sessions in the response that do not correspond to this tower.
ListClientSessions(*wtdb.TowerID, wtdb.ClientSessionFilterFn,
...wtdb.ClientSessionListOption) (
ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error)
// GetClientSession loads the ClientSession with the given ID from the

View File

@@ -517,7 +517,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
towerSessions, err := c.listTowerSessions(
towerID, sessions, chanIDIndexBkt,
towersToSessionsIndex, nil,
towersToSessionsIndex,
WithPerCommittedUpdate(perCommittedUpdate),
)
if err != nil {
@@ -1055,7 +1055,7 @@ func (c *ClientDB) GetClientSession(id SessionID,
}
session, err := c.getClientSession(
sessionsBkt, chanIDIndexBkt, id[:], nil, opts...,
sessionsBkt, chanIDIndexBkt, id[:], opts...,
)
if err != nil {
return err
@@ -1073,8 +1073,7 @@ func (c *ClientDB) GetClientSession(id SessionID,
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID,
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
var clientSessions map[SessionID]*ClientSession
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
@@ -1093,7 +1092,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
var err error
if id == nil {
clientSessions, err = c.listClientAllSessions(
sessions, chanIDIndexBkt, filterFn, opts...,
sessions, chanIDIndexBkt, opts...,
)
return err
}
@@ -1106,7 +1105,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
clientSessions, err = c.listTowerSessions(
*id, sessions, chanIDIndexBkt, towerToSessionIndex,
filterFn, opts...,
opts...,
)
return err
}, func() {
@@ -1121,8 +1120,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
// listClientAllSessions returns the set of all client sessions known to the db.
func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
err := sessions.ForEach(func(k, _ []byte) error {
@@ -1131,7 +1129,7 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
// committed updates and compute the highest known commit height
// for each channel.
session, err := c.getClientSession(
sessions, chanIDIndexBkt, k, filterFn, opts...,
sessions, chanIDIndexBkt, k, opts...,
)
if errors.Is(err, ErrSessionFailedFilterFn) {
return nil
@@ -1153,8 +1151,8 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
// listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id.
func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
towerToSessionIndex kvdb.RBucket, filterFn ClientSessionFilterFn,
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) {
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
if towerIndexBkt == nil {
@@ -1168,7 +1166,7 @@ func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
// committed updates and compute the highest known commit height
// for each channel.
session, err := c.getClientSession(
sessionsBkt, chanIDIndexBkt, k, filterFn, opts...,
sessionsBkt, chanIDIndexBkt, k, opts...,
)
if errors.Is(err, ErrSessionFailedFilterFn) {
return nil
@@ -2140,6 +2138,13 @@ type ClientSessionListCfg struct {
// PerCommittedUpdate will, if set, be called for each of the session's
// committed (un-acked) updates.
PerCommittedUpdate PerCommittedUpdateCB
// PreEvaluateFilterFn will be run after loading a session from the DB
// and _before_ any of the other call-back functions in
// ClientSessionListCfg. Therefore, if a session fails this filter
// function, then it will not be passed to any of the other call backs
// and won't be included in the return list.
PreEvaluateFilterFn ClientSessionFilterFn
}
// NewClientSessionCfg constructs a new ClientSessionListCfg.
@@ -2174,12 +2179,22 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
}
}
// WithPreEvalFilterFn constructs a functional option that will set a call-back
// function that will be called immediately after loading a session. If the
// session fails this filter function, then it will not be passed to any of the
// other evaluation call-back functions.
func WithPreEvalFilterFn(fn ClientSessionFilterFn) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PreEvaluateFilterFn = fn
}
}
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body.
func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
idBytes []byte, filterFn ClientSessionFilterFn,
opts ...ClientSessionListOption) (*ClientSession, error) {
idBytes []byte, opts ...ClientSessionListOption) (*ClientSession,
error) {
cfg := NewClientSessionCfg()
for _, o := range opts {
@@ -2191,7 +2206,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
return nil, err
}
if filterFn != nil && !filterFn(session) {
if cfg.PreEvaluateFilterFn != nil && !cfg.PreEvaluateFilterFn(session) {
return nil, ErrSessionFailedFilterFn
}

View File

@@ -52,12 +52,11 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
}
func (h *clientDBHarness) listSessions(id *wtdb.TowerID,
filterFn wtdb.ClientSessionFilterFn,
opts ...wtdb.ClientSessionListOption) map[wtdb.SessionID]*wtdb.ClientSession {
h.t.Helper()
sessions, err := h.db.ListClientSessions(id, filterFn, opts...)
sessions, err := h.db.ListClientSessions(id, opts...)
require.NoError(h.t, err, "unable to list client sessions")
return sessions
@@ -84,7 +83,7 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
require.ErrorIs(h.t, err, expErr)
require.NotZero(h.t, tower.ID, "tower id should never be 0")
for _, session := range h.listSessions(&tower.ID, nil) {
for _, session := range h.listSessions(&tower.ID) {
require.Equal(h.t, wtdb.CSessionActive, session.Status)
}
@@ -127,7 +126,7 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
return
}
for _, session := range h.listSessions(&tower.ID, nil) {
for _, session := range h.listSessions(&tower.ID) {
require.Equal(h.t, wtdb.CSessionInactive,
session.Status, "expected status for session "+
"%v to be %v, got %v", session.ID,
@@ -301,7 +300,7 @@ func testCreateClientSession(h *clientDBHarness) {
// First, assert that this session is not already present in the
// database.
_, ok := h.listSessions(nil, nil)[session.ID]
_, ok := h.listSessions(nil)[session.ID]
require.Falsef(h.t, ok, "session for id %x should not exist yet",
session.ID)
@@ -329,7 +328,7 @@ func testCreateClientSession(h *clientDBHarness) {
h.insertSession(session, nil)
// Verify that the session now exists in the database.
_, ok = h.listSessions(nil, nil)[session.ID]
_, ok = h.listSessions(nil)[session.ID]
require.Truef(h.t, ok, "session for id %x should exist now", session.ID)
// Attempt to insert the session again, which should fail due to the
@@ -377,7 +376,7 @@ func testFilterClientSessions(h *clientDBHarness) {
// We should see the expected sessions for each tower when filtering
// them.
for towerID, expectedSessions := range towerSessions {
sessions := h.listSessions(&towerID, nil)
sessions := h.listSessions(&towerID)
require.Len(h.t, sessions, len(expectedSessions))
for _, expectedSession := range expectedSessions {

View File

@@ -89,7 +89,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
tower = m.towers[towerID]
tower.AddAddress(lnAddr.Address)
towerSessions, err := m.listClientSessions(&towerID, nil)
towerSessions, err := m.listClientSessions(&towerID)
if err != nil {
return nil, err
}
@@ -141,7 +141,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return nil
}
towerSessions, err := m.listClientSessions(&tower.ID, nil)
towerSessions, err := m.listClientSessions(&tower.ID)
if err != nil {
return err
}
@@ -226,20 +226,19 @@ func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error {
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
filterFn wtdb.ClientSessionFilterFn,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.listClientSessions(tower, filterFn, opts...)
return m.listClientSessions(tower, opts...)
}
// listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
filterFn wtdb.ClientSessionFilterFn,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
@@ -255,7 +254,9 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
continue
}
if filterFn != nil && !filterFn(&session) {
if cfg.PreEvaluateFilterFn != nil &&
!cfg.PreEvaluateFilterFn(&session) {
continue
}