mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-30 07:35:07 +02:00
watchtower: add PreEvaluateFilterFn callback
In this commit, a PreEvaluateFilterFn option is added to the wtdb.ClientSessionListCfg and it is used instead of a separate ClientSessionFilterFn parameter. This neatens quiet a few function signatures.
This commit is contained in:
@@ -410,8 +410,9 @@ func New(config *Config) (*TowerClient, error) {
|
||||
// current policy of the client, otherwise they will be ignored and new
|
||||
// sessions will be requested.
|
||||
candidateSessions, err := getTowerAndSessionCandidates(
|
||||
cfg.DB, cfg.SecretKeyRing, c.genSessionFilter(true),
|
||||
perActiveTower, wtdb.WithPerMaxHeight(perMaxHeight),
|
||||
cfg.DB, cfg.SecretKeyRing, perActiveTower,
|
||||
wtdb.WithPreEvalFilterFn(c.genSessionFilter(true)),
|
||||
wtdb.WithPerMaxHeight(perMaxHeight),
|
||||
wtdb.WithPerCommittedUpdate(perCommittedUpdate),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -444,7 +445,6 @@ func New(config *Config) (*TowerClient, error) {
|
||||
// sessionFilter check then the perActiveTower call-back will be called on that
|
||||
// tower.
|
||||
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
||||
sessionFilter wtdb.ClientSessionFilterFn,
|
||||
perActiveTower func(tower *Tower),
|
||||
opts ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*ClientSession, error) {
|
||||
@@ -461,18 +461,12 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessions, err := db.ListClientSessions(
|
||||
&tower.ID, sessionFilter, opts...,
|
||||
)
|
||||
sessions, err := db.ListClientSessions(&tower.ID, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, s := range sessions {
|
||||
if !sessionFilter(s) {
|
||||
continue
|
||||
}
|
||||
|
||||
cs, err := NewClientSessionFromDBSession(
|
||||
s, tower, keyRing,
|
||||
)
|
||||
@@ -498,13 +492,10 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
|
||||
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
|
||||
// ListClientSessions method should be used.
|
||||
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
||||
sessionFilter wtdb.ClientSessionFilterFn,
|
||||
opts ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*ClientSession, error) {
|
||||
|
||||
dbSessions, err := db.ListClientSessions(
|
||||
forTower, sessionFilter, opts...,
|
||||
)
|
||||
dbSessions, err := db.ListClientSessions(forTower, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1635,7 +1626,7 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error {
|
||||
// Include all of its corresponding sessions to our set of candidates.
|
||||
sessions, err := getClientSessions(
|
||||
c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID,
|
||||
c.genSessionFilter(true),
|
||||
wtdb.WithPreEvalFilterFn(c.genSessionFilter(true)),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to determine sessions for tower %x: "+
|
||||
@@ -1721,7 +1712,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
||||
// Otherwise, the tower should no longer be used for future session
|
||||
// negotiations and backups.
|
||||
pubKey := msg.pubKey.SerializeCompressed()
|
||||
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID, nil)
|
||||
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to retrieve sessions for tower %x: "+
|
||||
"%v", pubKey, err)
|
||||
@@ -1752,9 +1743,10 @@ func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientSessions, err := c.cfg.DB.ListClientSessions(
|
||||
nil, c.genSessionFilter(false), opts...,
|
||||
)
|
||||
|
||||
opts = append(opts, wtdb.WithPreEvalFilterFn(c.genSessionFilter(false)))
|
||||
|
||||
clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1795,9 +1787,9 @@ func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
towerSessions, err := c.cfg.DB.ListClientSessions(
|
||||
&tower.ID, c.genSessionFilter(false), opts...,
|
||||
)
|
||||
opts = append(opts, wtdb.WithPreEvalFilterFn(c.genSessionFilter(false)))
|
||||
|
||||
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -870,7 +870,7 @@ func (h *testHarness) relevantSessions(chanID uint64) []wtdb.SessionID {
|
||||
},
|
||||
)
|
||||
|
||||
_, err := h.clientDB.ListClientSessions(nil, nil, collectSessions)
|
||||
_, err := h.clientDB.ListClientSessions(nil, collectSessions)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
return sessionIDs
|
||||
@@ -1969,7 +1969,7 @@ var clientTests = []clientTest{
|
||||
|
||||
// Also make a note of the total number of sessions the
|
||||
// client has.
|
||||
sessions, err := h.clientDB.ListClientSessions(nil, nil)
|
||||
sessions, err := h.clientDB.ListClientSessions(nil)
|
||||
require.NoError(h.t, err)
|
||||
require.Len(h.t, sessions, 4)
|
||||
|
||||
@@ -1981,9 +1981,7 @@ var clientTests = []clientTest{
|
||||
// marked as closable. The server should also no longer
|
||||
// have these sessions in its DB.
|
||||
err = wait.Predicate(func() bool {
|
||||
sess, err := h.clientDB.ListClientSessions(
|
||||
nil, nil,
|
||||
)
|
||||
sess, err := h.clientDB.ListClientSessions(nil)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
cs, err := h.clientDB.ListClosableSessions()
|
||||
|
@@ -60,8 +60,7 @@ type DB interface {
|
||||
// ListClientSessions returns the set of all client sessions known to
|
||||
// the db. An optional tower ID can be used to filter out any client
|
||||
// sessions in the response that do not correspond to this tower.
|
||||
ListClientSessions(*wtdb.TowerID, wtdb.ClientSessionFilterFn,
|
||||
...wtdb.ClientSessionListOption) (
|
||||
ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error)
|
||||
|
||||
// GetClientSession loads the ClientSession with the given ID from the
|
||||
|
@@ -517,7 +517,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
|
||||
towerSessions, err := c.listTowerSessions(
|
||||
towerID, sessions, chanIDIndexBkt,
|
||||
towersToSessionsIndex, nil,
|
||||
towersToSessionsIndex,
|
||||
WithPerCommittedUpdate(perCommittedUpdate),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1055,7 +1055,7 @@ func (c *ClientDB) GetClientSession(id SessionID,
|
||||
}
|
||||
|
||||
session, err := c.getClientSession(
|
||||
sessionsBkt, chanIDIndexBkt, id[:], nil, opts...,
|
||||
sessionsBkt, chanIDIndexBkt, id[:], opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1073,8 +1073,7 @@ func (c *ClientDB) GetClientSession(id SessionID,
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
var clientSessions map[SessionID]*ClientSession
|
||||
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||
@@ -1093,7 +1092,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
var err error
|
||||
if id == nil {
|
||||
clientSessions, err = c.listClientAllSessions(
|
||||
sessions, chanIDIndexBkt, filterFn, opts...,
|
||||
sessions, chanIDIndexBkt, opts...,
|
||||
)
|
||||
return err
|
||||
}
|
||||
@@ -1106,7 +1105,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
|
||||
clientSessions, err = c.listTowerSessions(
|
||||
*id, sessions, chanIDIndexBkt, towerToSessionIndex,
|
||||
filterFn, opts...,
|
||||
opts...,
|
||||
)
|
||||
return err
|
||||
}, func() {
|
||||
@@ -1121,8 +1120,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
|
||||
|
||||
// listClientAllSessions returns the set of all client sessions known to the db.
|
||||
func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
|
||||
filterFn ClientSessionFilterFn, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
|
||||
clientSessions := make(map[SessionID]*ClientSession)
|
||||
err := sessions.ForEach(func(k, _ []byte) error {
|
||||
@@ -1131,7 +1129,7 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := c.getClientSession(
|
||||
sessions, chanIDIndexBkt, k, filterFn, opts...,
|
||||
sessions, chanIDIndexBkt, k, opts...,
|
||||
)
|
||||
if errors.Is(err, ErrSessionFailedFilterFn) {
|
||||
return nil
|
||||
@@ -1153,8 +1151,8 @@ func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
|
||||
// listTowerSessions returns the set of all client sessions known to the db
|
||||
// that are associated with the given tower id.
|
||||
func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
|
||||
towerToSessionIndex kvdb.RBucket, filterFn ClientSessionFilterFn,
|
||||
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
|
||||
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
|
||||
map[SessionID]*ClientSession, error) {
|
||||
|
||||
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
|
||||
if towerIndexBkt == nil {
|
||||
@@ -1168,7 +1166,7 @@ func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
|
||||
// committed updates and compute the highest known commit height
|
||||
// for each channel.
|
||||
session, err := c.getClientSession(
|
||||
sessionsBkt, chanIDIndexBkt, k, filterFn, opts...,
|
||||
sessionsBkt, chanIDIndexBkt, k, opts...,
|
||||
)
|
||||
if errors.Is(err, ErrSessionFailedFilterFn) {
|
||||
return nil
|
||||
@@ -2140,6 +2138,13 @@ type ClientSessionListCfg struct {
|
||||
// PerCommittedUpdate will, if set, be called for each of the session's
|
||||
// committed (un-acked) updates.
|
||||
PerCommittedUpdate PerCommittedUpdateCB
|
||||
|
||||
// PreEvaluateFilterFn will be run after loading a session from the DB
|
||||
// and _before_ any of the other call-back functions in
|
||||
// ClientSessionListCfg. Therefore, if a session fails this filter
|
||||
// function, then it will not be passed to any of the other call backs
|
||||
// and won't be included in the return list.
|
||||
PreEvaluateFilterFn ClientSessionFilterFn
|
||||
}
|
||||
|
||||
// NewClientSessionCfg constructs a new ClientSessionListCfg.
|
||||
@@ -2174,12 +2179,22 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPreEvalFilterFn constructs a functional option that will set a call-back
|
||||
// function that will be called immediately after loading a session. If the
|
||||
// session fails this filter function, then it will not be passed to any of the
|
||||
// other evaluation call-back functions.
|
||||
func WithPreEvalFilterFn(fn ClientSessionFilterFn) ClientSessionListOption {
|
||||
return func(cfg *ClientSessionListCfg) {
|
||||
cfg.PreEvaluateFilterFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// getClientSession loads the full ClientSession associated with the serialized
|
||||
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
|
||||
// in addition to the ClientSession's body.
|
||||
func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
|
||||
idBytes []byte, filterFn ClientSessionFilterFn,
|
||||
opts ...ClientSessionListOption) (*ClientSession, error) {
|
||||
idBytes []byte, opts ...ClientSessionListOption) (*ClientSession,
|
||||
error) {
|
||||
|
||||
cfg := NewClientSessionCfg()
|
||||
for _, o := range opts {
|
||||
@@ -2191,7 +2206,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if filterFn != nil && !filterFn(session) {
|
||||
if cfg.PreEvaluateFilterFn != nil && !cfg.PreEvaluateFilterFn(session) {
|
||||
return nil, ErrSessionFailedFilterFn
|
||||
}
|
||||
|
||||
|
@@ -52,12 +52,11 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) listSessions(id *wtdb.TowerID,
|
||||
filterFn wtdb.ClientSessionFilterFn,
|
||||
opts ...wtdb.ClientSessionListOption) map[wtdb.SessionID]*wtdb.ClientSession {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
sessions, err := h.db.ListClientSessions(id, filterFn, opts...)
|
||||
sessions, err := h.db.ListClientSessions(id, opts...)
|
||||
require.NoError(h.t, err, "unable to list client sessions")
|
||||
|
||||
return sessions
|
||||
@@ -84,7 +83,7 @@ func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
|
||||
require.ErrorIs(h.t, err, expErr)
|
||||
require.NotZero(h.t, tower.ID, "tower id should never be 0")
|
||||
|
||||
for _, session := range h.listSessions(&tower.ID, nil) {
|
||||
for _, session := range h.listSessions(&tower.ID) {
|
||||
require.Equal(h.t, wtdb.CSessionActive, session.Status)
|
||||
}
|
||||
|
||||
@@ -127,7 +126,7 @@ func (h *clientDBHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr,
|
||||
return
|
||||
}
|
||||
|
||||
for _, session := range h.listSessions(&tower.ID, nil) {
|
||||
for _, session := range h.listSessions(&tower.ID) {
|
||||
require.Equal(h.t, wtdb.CSessionInactive,
|
||||
session.Status, "expected status for session "+
|
||||
"%v to be %v, got %v", session.ID,
|
||||
@@ -301,7 +300,7 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
|
||||
// First, assert that this session is not already present in the
|
||||
// database.
|
||||
_, ok := h.listSessions(nil, nil)[session.ID]
|
||||
_, ok := h.listSessions(nil)[session.ID]
|
||||
require.Falsef(h.t, ok, "session for id %x should not exist yet",
|
||||
session.ID)
|
||||
|
||||
@@ -329,7 +328,7 @@ func testCreateClientSession(h *clientDBHarness) {
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Verify that the session now exists in the database.
|
||||
_, ok = h.listSessions(nil, nil)[session.ID]
|
||||
_, ok = h.listSessions(nil)[session.ID]
|
||||
require.Truef(h.t, ok, "session for id %x should exist now", session.ID)
|
||||
|
||||
// Attempt to insert the session again, which should fail due to the
|
||||
@@ -377,7 +376,7 @@ func testFilterClientSessions(h *clientDBHarness) {
|
||||
// We should see the expected sessions for each tower when filtering
|
||||
// them.
|
||||
for towerID, expectedSessions := range towerSessions {
|
||||
sessions := h.listSessions(&towerID, nil)
|
||||
sessions := h.listSessions(&towerID)
|
||||
require.Len(h.t, sessions, len(expectedSessions))
|
||||
|
||||
for _, expectedSession := range expectedSessions {
|
||||
|
@@ -89,7 +89,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
||||
tower = m.towers[towerID]
|
||||
tower.AddAddress(lnAddr.Address)
|
||||
|
||||
towerSessions, err := m.listClientSessions(&towerID, nil)
|
||||
towerSessions, err := m.listClientSessions(&towerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -141,7 +141,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
towerSessions, err := m.listClientSessions(&tower.ID, nil)
|
||||
towerSessions, err := m.listClientSessions(&tower.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -226,20 +226,19 @@ func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error {
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
|
||||
filterFn wtdb.ClientSessionFilterFn,
|
||||
opts ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.listClientSessions(tower, filterFn, opts...)
|
||||
|
||||
return m.listClientSessions(tower, opts...)
|
||||
}
|
||||
|
||||
// listClientSessions returns the set of all client sessions known to the db. An
|
||||
// optional tower ID can be used to filter out any client sessions in the
|
||||
// response that do not correspond to this tower.
|
||||
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
|
||||
filterFn wtdb.ClientSessionFilterFn,
|
||||
opts ...wtdb.ClientSessionListOption) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
@@ -255,7 +254,9 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
|
||||
continue
|
||||
}
|
||||
|
||||
if filterFn != nil && !filterFn(&session) {
|
||||
if cfg.PreEvaluateFilterFn != nil &&
|
||||
!cfg.PreEvaluateFilterFn(&session) {
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user