From 7bc86ca42ebf665b2e6b422a079cc58819246e85 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 20 Mar 2023 17:06:48 +0200 Subject: [PATCH] watchtower: add PreEvaluateFilterFn callback In this commit, a PreEvaluateFilterFn option is added to the wtdb.ClientSessionListCfg and it is used instead of a separate ClientSessionFilterFn parameter. This neatens quiet a few function signatures. --- watchtower/wtclient/client.go | 36 ++++++++++-------------- watchtower/wtclient/client_test.go | 8 ++---- watchtower/wtclient/interface.go | 3 +- watchtower/wtdb/client_db.go | 45 ++++++++++++++++++++---------- watchtower/wtdb/client_db_test.go | 13 ++++----- watchtower/wtmock/client_db.go | 13 +++++---- 6 files changed, 61 insertions(+), 57 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index e92b8b4cf..81c2d6b0c 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -410,8 +410,9 @@ func New(config *Config) (*TowerClient, error) { // current policy of the client, otherwise they will be ignored and new // sessions will be requested. candidateSessions, err := getTowerAndSessionCandidates( - cfg.DB, cfg.SecretKeyRing, c.genSessionFilter(true), - perActiveTower, wtdb.WithPerMaxHeight(perMaxHeight), + cfg.DB, cfg.SecretKeyRing, perActiveTower, + wtdb.WithPreEvalFilterFn(c.genSessionFilter(true)), + wtdb.WithPerMaxHeight(perMaxHeight), wtdb.WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { @@ -444,7 +445,6 @@ func New(config *Config) (*TowerClient, error) { // sessionFilter check then the perActiveTower call-back will be called on that // tower. func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, - sessionFilter wtdb.ClientSessionFilterFn, perActiveTower func(tower *Tower), opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*ClientSession, error) { @@ -461,18 +461,12 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, return nil, err } - sessions, err := db.ListClientSessions( - &tower.ID, sessionFilter, opts..., - ) + sessions, err := db.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err } for _, s := range sessions { - if !sessionFilter(s) { - continue - } - cs, err := NewClientSessionFromDBSession( s, tower, keyRing, ) @@ -498,13 +492,10 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, // ClientSession's SessionPrivKey field is desired, otherwise, the existing // ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, - sessionFilter wtdb.ClientSessionFilterFn, opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*ClientSession, error) { - dbSessions, err := db.ListClientSessions( - forTower, sessionFilter, opts..., - ) + dbSessions, err := db.ListClientSessions(forTower, opts...) if err != nil { return nil, err } @@ -1635,7 +1626,7 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { // Include all of its corresponding sessions to our set of candidates. sessions, err := getClientSessions( c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID, - c.genSessionFilter(true), + wtdb.WithPreEvalFilterFn(c.genSessionFilter(true)), ) if err != nil { return fmt.Errorf("unable to determine sessions for tower %x: "+ @@ -1721,7 +1712,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // Otherwise, the tower should no longer be used for future session // negotiations and backups. pubKey := msg.pubKey.SerializeCompressed() - sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID, nil) + sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID) if err != nil { return fmt.Errorf("unable to retrieve sessions for tower %x: "+ "%v", pubKey, err) @@ -1752,9 +1743,10 @@ func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( if err != nil { return nil, err } - clientSessions, err := c.cfg.DB.ListClientSessions( - nil, c.genSessionFilter(false), opts..., - ) + + opts = append(opts, wtdb.WithPreEvalFilterFn(c.genSessionFilter(false))) + + clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...) if err != nil { return nil, err } @@ -1795,9 +1787,9 @@ func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey, return nil, err } - towerSessions, err := c.cfg.DB.ListClientSessions( - &tower.ID, c.genSessionFilter(false), opts..., - ) + opts = append(opts, wtdb.WithPreEvalFilterFn(c.genSessionFilter(false))) + + towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err } diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 2657e691b..f870af527 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -870,7 +870,7 @@ func (h *testHarness) relevantSessions(chanID uint64) []wtdb.SessionID { }, ) - _, err := h.clientDB.ListClientSessions(nil, nil, collectSessions) + _, err := h.clientDB.ListClientSessions(nil, collectSessions) require.NoError(h.t, err) return sessionIDs @@ -1969,7 +1969,7 @@ var clientTests = []clientTest{ // Also make a note of the total number of sessions the // client has. - sessions, err := h.clientDB.ListClientSessions(nil, nil) + sessions, err := h.clientDB.ListClientSessions(nil) require.NoError(h.t, err) require.Len(h.t, sessions, 4) @@ -1981,9 +1981,7 @@ var clientTests = []clientTest{ // marked as closable. The server should also no longer // have these sessions in its DB. err = wait.Predicate(func() bool { - sess, err := h.clientDB.ListClientSessions( - nil, nil, - ) + sess, err := h.clientDB.ListClientSessions(nil) require.NoError(h.t, err) cs, err := h.clientDB.ListClosableSessions() diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 4eebef4e5..8a6d9b320 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -60,8 +60,7 @@ type DB interface { // ListClientSessions returns the set of all client sessions known to // the db. An optional tower ID can be used to filter out any client // sessions in the response that do not correspond to this tower. - ListClientSessions(*wtdb.TowerID, wtdb.ClientSessionFilterFn, - ...wtdb.ClientSessionListOption) ( + ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) // GetClientSession loads the ClientSession with the given ID from the diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 53c643b6f..2e64e6122 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -517,7 +517,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { towerSessions, err := c.listTowerSessions( towerID, sessions, chanIDIndexBkt, - towersToSessionsIndex, nil, + towersToSessionsIndex, WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { @@ -1055,7 +1055,7 @@ func (c *ClientDB) GetClientSession(id SessionID, } session, err := c.getClientSession( - sessionsBkt, chanIDIndexBkt, id[:], nil, opts..., + sessionsBkt, chanIDIndexBkt, id[:], opts..., ) if err != nil { return err @@ -1073,8 +1073,7 @@ func (c *ClientDB) GetClientSession(id SessionID, // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. func (c *ClientDB) ListClientSessions(id *TowerID, - filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) ( - map[SessionID]*ClientSession, error) { + opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { var clientSessions map[SessionID]*ClientSession err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -1093,7 +1092,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, var err error if id == nil { clientSessions, err = c.listClientAllSessions( - sessions, chanIDIndexBkt, filterFn, opts..., + sessions, chanIDIndexBkt, opts..., ) return err } @@ -1106,7 +1105,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, clientSessions, err = c.listTowerSessions( *id, sessions, chanIDIndexBkt, towerToSessionIndex, - filterFn, opts..., + opts..., ) return err }, func() { @@ -1121,8 +1120,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, // listClientAllSessions returns the set of all client sessions known to the db. func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket, - filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) ( - map[SessionID]*ClientSession, error) { + opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) err := sessions.ForEach(func(k, _ []byte) error { @@ -1131,7 +1129,7 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket, // committed updates and compute the highest known commit height // for each channel. session, err := c.getClientSession( - sessions, chanIDIndexBkt, k, filterFn, opts..., + sessions, chanIDIndexBkt, k, opts..., ) if errors.Is(err, ErrSessionFailedFilterFn) { return nil @@ -1153,8 +1151,8 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket, // listTowerSessions returns the set of all client sessions known to the db // that are associated with the given tower id. func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt, - towerToSessionIndex kvdb.RBucket, filterFn ClientSessionFilterFn, - opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { + towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) ( + map[SessionID]*ClientSession, error) { towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes()) if towerIndexBkt == nil { @@ -1168,7 +1166,7 @@ func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt, // committed updates and compute the highest known commit height // for each channel. session, err := c.getClientSession( - sessionsBkt, chanIDIndexBkt, k, filterFn, opts..., + sessionsBkt, chanIDIndexBkt, k, opts..., ) if errors.Is(err, ErrSessionFailedFilterFn) { return nil @@ -2140,6 +2138,13 @@ type ClientSessionListCfg struct { // PerCommittedUpdate will, if set, be called for each of the session's // committed (un-acked) updates. PerCommittedUpdate PerCommittedUpdateCB + + // PreEvaluateFilterFn will be run after loading a session from the DB + // and _before_ any of the other call-back functions in + // ClientSessionListCfg. Therefore, if a session fails this filter + // function, then it will not be passed to any of the other call backs + // and won't be included in the return list. + PreEvaluateFilterFn ClientSessionFilterFn } // NewClientSessionCfg constructs a new ClientSessionListCfg. @@ -2174,12 +2179,22 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { } } +// WithPreEvalFilterFn constructs a functional option that will set a call-back +// function that will be called immediately after loading a session. If the +// session fails this filter function, then it will not be passed to any of the +// other evaluation call-back functions. +func WithPreEvalFilterFn(fn ClientSessionFilterFn) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PreEvaluateFilterFn = fn + } +} + // getClientSession loads the full ClientSession associated with the serialized // session id. This method populates the CommittedUpdates, AckUpdates and Tower // in addition to the ClientSession's body. func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket, - idBytes []byte, filterFn ClientSessionFilterFn, - opts ...ClientSessionListOption) (*ClientSession, error) { + idBytes []byte, opts ...ClientSessionListOption) (*ClientSession, + error) { cfg := NewClientSessionCfg() for _, o := range opts { @@ -2191,7 +2206,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket, return nil, err } - if filterFn != nil && !filterFn(session) { + if cfg.PreEvaluateFilterFn != nil && !cfg.PreEvaluateFilterFn(session) { return nil, ErrSessionFailedFilterFn } diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index b3d241175..26d02bcf3 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -52,12 +52,11 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, } func (h *clientDBHarness) listSessions(id *wtdb.TowerID, - filterFn wtdb.ClientSessionFilterFn, opts ...wtdb.ClientSessionListOption) map[wtdb.SessionID]*wtdb.ClientSession { h.t.Helper() - sessions, err := h.db.ListClientSessions(id, filterFn, opts...) + sessions, err := h.db.ListClientSessions(id, opts...) require.NoError(h.t, err, "unable to list client sessions") return sessions @@ -84,7 +83,7 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress, require.ErrorIs(h.t, err, expErr) require.NotZero(h.t, tower.ID, "tower id should never be 0") - for _, session := range h.listSessions(&tower.ID, nil) { + for _, session := range h.listSessions(&tower.ID) { require.Equal(h.t, wtdb.CSessionActive, session.Status) } @@ -127,7 +126,7 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr, return } - for _, session := range h.listSessions(&tower.ID, nil) { + for _, session := range h.listSessions(&tower.ID) { require.Equal(h.t, wtdb.CSessionInactive, session.Status, "expected status for session "+ "%v to be %v, got %v", session.ID, @@ -301,7 +300,7 @@ func testCreateClientSession(h *clientDBHarness) { // First, assert that this session is not already present in the // database. - _, ok := h.listSessions(nil, nil)[session.ID] + _, ok := h.listSessions(nil)[session.ID] require.Falsef(h.t, ok, "session for id %x should not exist yet", session.ID) @@ -329,7 +328,7 @@ func testCreateClientSession(h *clientDBHarness) { h.insertSession(session, nil) // Verify that the session now exists in the database. - _, ok = h.listSessions(nil, nil)[session.ID] + _, ok = h.listSessions(nil)[session.ID] require.Truef(h.t, ok, "session for id %x should exist now", session.ID) // Attempt to insert the session again, which should fail due to the @@ -377,7 +376,7 @@ func testFilterClientSessions(h *clientDBHarness) { // We should see the expected sessions for each tower when filtering // them. for towerID, expectedSessions := range towerSessions { - sessions := h.listSessions(&towerID, nil) + sessions := h.listSessions(&towerID) require.Len(h.t, sessions, len(expectedSessions)) for _, expectedSession := range expectedSessions { diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index e004fcdaf..22b21a563 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -89,7 +89,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { tower = m.towers[towerID] tower.AddAddress(lnAddr.Address) - towerSessions, err := m.listClientSessions(&towerID, nil) + towerSessions, err := m.listClientSessions(&towerID) if err != nil { return nil, err } @@ -141,7 +141,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { return nil } - towerSessions, err := m.listClientSessions(&tower.ID, nil) + towerSessions, err := m.listClientSessions(&tower.ID) if err != nil { return err } @@ -226,20 +226,19 @@ func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error { // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID, - filterFn wtdb.ClientSessionFilterFn, opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { m.mu.Lock() defer m.mu.Unlock() - return m.listClientSessions(tower, filterFn, opts...) + + return m.listClientSessions(tower, opts...) } // listClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, - filterFn wtdb.ClientSessionFilterFn, opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { @@ -255,7 +254,9 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, continue } - if filterFn != nil && !filterFn(&session) { + if cfg.PreEvaluateFilterFn != nil && + !cfg.PreEvaluateFilterFn(&session) { + continue }