From 776f2a026c446906b3d82d7f5d70150f63946f7f Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 27 Nov 2023 12:06:00 +0200 Subject: [PATCH] wtdb: supply commited update count to PostEvaluateFilterFn In this commit, we adjust the PostEvaluateFilterFn to also take in a count representing the number of committed updates (ie, persisted un-acked updates) that the session has. This will be made use of in an upcoming commit. --- watchtower/wtclient/client.go | 4 ++-- watchtower/wtdb/client_db.go | 36 +++++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 24d974ced..3383de6ab 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -67,8 +67,8 @@ func (c *client) genSessionFilter( // ExhaustedSessionFilter constructs a wtdb.ClientSessionFilterFn filter // function that will filter out any sessions that have been exhausted. -func ExhaustedSessionFilter() wtdb.ClientSessionFilterFn { - return func(session *wtdb.ClientSession) bool { +func ExhaustedSessionFilter() wtdb.ClientSessWithNumCommittedUpdatesFilterFn { + return func(session *wtdb.ClientSession, _ uint16) bool { return session.SeqNum < session.Policy.MaxUpdates } } diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 635c6cfa8..3a8eb6725 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -2305,6 +2305,12 @@ func getClientSessionBody(sessions kvdb.RBucket, // that read sessions from the DB. type ClientSessionFilterFn func(*ClientSession) bool +// ClientSessWithNumCommittedUpdatesFilterFn describes the signature of a +// callback function that can be used to filter out a session based on the +// contents of ClientSession along with the number of un-acked committed updates +// that the session has. +type ClientSessWithNumCommittedUpdatesFilterFn func(*ClientSession, uint16) bool + // PerMaxHeightCB describes the signature of a callback function that can be // called for each channel that a session has updates for to communicate the // maximum commitment height that the session has backed up for the channel. @@ -2366,7 +2372,7 @@ type ClientSessionListCfg struct { // functions in ClientSessionListCfg. If a session fails this filter // function then all it means is that it won't be included in the list // of sessions to return. - PostEvaluateFilterFn ClientSessionFilterFn + PostEvaluateFilterFn ClientSessWithNumCommittedUpdatesFilterFn } // NewClientSessionCfg constructs a new ClientSessionListCfg. @@ -2427,7 +2433,9 @@ func WithPreEvalFilterFn(fn ClientSessionFilterFn) ClientSessionListOption { // run against the other ClientSessionListCfg call-backs) whereas the session // will only reach the PostEvalFilterFn call-back once it has already been // evaluated by all the other call-backs. -func WithPostEvalFilterFn(fn ClientSessionFilterFn) ClientSessionListOption { +func WithPostEvalFilterFn( + fn ClientSessWithNumCommittedUpdatesFilterFn) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { cfg.PostEvaluateFilterFn = fn } @@ -2459,7 +2467,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket, // Pass the session's committed (un-acked) updates through the call-back // if one is provided. - err = filterClientSessionCommits( + numCommittedUpdates, err := filterClientSessionCommits( sessionBkt, session, cfg.PerCommittedUpdate, ) if err != nil { @@ -2477,7 +2485,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket, } if cfg.PostEvaluateFilterFn != nil && - !cfg.PostEvaluateFilterFn(session) { + !cfg.PostEvaluateFilterFn(session, numCommittedUpdates) { return nil, ErrSessionFailedFilterFn } @@ -2586,18 +2594,21 @@ func (c *ClientDB) filterClientSessionAcks(sessionBkt, // identified by the serialized session id and passes them to the given // PerCommittedUpdateCB callback. func filterClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, - cb PerCommittedUpdateCB) error { - - if cb == nil { - return nil - } + cb PerCommittedUpdateCB) (uint16, error) { sessionCommits := sessionBkt.NestedReadBucket(cSessionCommits) if sessionCommits == nil { - return nil + return 0, nil } + var numUpdates uint16 err := sessionCommits.ForEach(func(k, v []byte) error { + numUpdates++ + + if cb == nil { + return nil + } + var committedUpdate CommittedUpdate err := committedUpdate.Decode(bytes.NewReader(v)) if err != nil { @@ -2606,13 +2617,14 @@ func filterClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, committedUpdate.SeqNum = byteOrder.Uint16(k) cb(s, &committedUpdate) + return nil }) if err != nil { - return err + return 0, err } - return nil + return numUpdates, nil } // putClientSessionBody stores the body of the ClientSession (everything but the