diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index 1f45335fb..228877743 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -390,6 +390,10 @@ func constructFunctionalOptions(includeSessions, return opts, ackCounts, committedUpdateCounts } + perNumRogueUpdates := func(s *wtdb.ClientSession, numUpdates uint16) { + ackCounts[s.ID] += numUpdates + } + perNumAckedUpdates := func(s *wtdb.ClientSession, id lnwire.ChannelID, numUpdates uint16) { @@ -405,6 +409,7 @@ func constructFunctionalOptions(includeSessions, opts = []wtdb.ClientSessionListOption{ wtdb.WithPerNumAckedUpdates(perNumAckedUpdates), wtdb.WithPerCommittedUpdate(perCommittedUpdate), + wtdb.WithPerRogueUpdateCount(perNumRogueUpdates), } if excludeExhaustedSessions { diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index e7b9c1137..084f2dcfe 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -2285,6 +2285,11 @@ type PerMaxHeightCB func(*ClientSession, lnwire.ChannelID, uint64) // number of updates that the session has for the channel. type PerNumAckedUpdatesCB func(*ClientSession, lnwire.ChannelID, uint16) +// PerRogueUpdateCountCB describes the signature of a callback function that can +// be called for each session with the number of rogue updates that the session +// has. +type PerRogueUpdateCountCB func(*ClientSession, uint16) + // PerAckedUpdateCB describes the signature of a callback function that can be // called for each of a session's acked updates. type PerAckedUpdateCB func(*ClientSession, uint16, BackupID) @@ -2307,6 +2312,10 @@ type ClientSessionListCfg struct { // channel. PerNumAckedUpdates PerNumAckedUpdatesCB + // PerRogueUpdateCount will, if set, be called with the number of rogue + // updates that the session has backed up. + PerRogueUpdateCount PerRogueUpdateCountCB + // PerMaxHeight will, if set, be called for each of the session's // channels to communicate the highest commit height of updates stored // for that channel. @@ -2354,6 +2363,15 @@ func WithPerNumAckedUpdates(cb PerNumAckedUpdatesCB) ClientSessionListOption { } } +// WithPerRogueUpdateCount constructs a functional option that will set a +// call-back function to be called with the number of rogue updates that the +// session has backed up. +func WithPerRogueUpdateCount(cb PerRogueUpdateCountCB) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PerRogueUpdateCount = cb + } +} + // WithPerCommittedUpdate constructs a functional option that will set a // call-back function to be called for each of a client's un-acked updates. func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { @@ -2422,7 +2440,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket, // provided. err = c.filterClientSessionAcks( sessionBkt, chanIDIndexBkt, session, cfg.PerMaxHeight, - cfg.PerNumAckedUpdates, + cfg.PerNumAckedUpdates, cfg.PerRogueUpdateCount, ) if err != nil { return nil, err @@ -2480,7 +2498,24 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, // call back if one is provided. func (c *ClientDB) filterClientSessionAcks(sessionBkt, chanIDIndexBkt kvdb.RBucket, s *ClientSession, perMaxCb PerMaxHeightCB, - perNumAckedUpdates PerNumAckedUpdatesCB) error { + perNumAckedUpdates PerNumAckedUpdatesCB, + perRogueUpdateCount PerRogueUpdateCountCB) error { + + if perRogueUpdateCount != nil { + var ( + count uint64 + err error + ) + rogueCountBytes := sessionBkt.Get(cSessionRogueUpdateCount) + if len(rogueCountBytes) != 0 { + count, err = readBigSize(rogueCountBytes) + if err != nil { + return err + } + } + + perRogueUpdateCount(s, uint16(count)) + } if perMaxCb == nil && perNumAckedUpdates == nil { return nil