From 3ac3b6a90df2c0b0cde6c151a0f8381716aae0db Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 30 Sep 2022 10:47:10 +0200 Subject: [PATCH 1/6] watchtower: refactor getClientSession helper funcs Small refactor to some of the tower client db helper functions in order to simplify upcoming commits. --- watchtower/wtdb/client_db.go | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 3cb5a8c70..f862b186c 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -1174,14 +1174,17 @@ func getClientSession(sessions, towers kvdb.RBucket, return nil, err } + // Can't fail because client session body has already been read. + sessionBkt := sessions.NestedReadBucket(idBytes) + // Fetch the committed updates for this session. - commitedUpdates, err := getClientSessionCommits(sessions, idBytes) + commitedUpdates, err := getClientSessionCommits(sessionBkt) if err != nil { return nil, err } // Fetch the acked updates for this session. - ackedUpdates, err := getClientSessionAcks(sessions, idBytes) + ackedUpdates, err := getClientSessionAcks(sessionBkt) if err != nil { return nil, err } @@ -1195,11 +1198,8 @@ func getClientSession(sessions, towers kvdb.RBucket, // getClientSessionCommits retrieves all committed updates for the session // identified by the serialized session id. -func getClientSessionCommits(sessions kvdb.RBucket, - idBytes []byte) ([]CommittedUpdate, error) { - - // Can't fail because client session body has already been read. - sessionBkt := sessions.NestedReadBucket(idBytes) +func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate, + error) { // Initialize commitedUpdates so that we can return an initialized map // if no committed updates exist. @@ -1231,11 +1231,8 @@ func getClientSessionCommits(sessions kvdb.RBucket, // getClientSessionAcks retrieves all acked updates for the session identified // by the serialized session id. -func getClientSessionAcks(sessions kvdb.RBucket, - idBytes []byte) (map[uint16]BackupID, error) { - - // Can't fail because client session body has already been read. - sessionBkt := sessions.NestedReadBucket(idBytes) +func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID, + error) { // Initialize ackedUpdates so that we can return an initialized map if // no acked updates exist. From 40e0ebf4171bfe018ff81eeb9956b634d82f73e9 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 13 Oct 2022 13:46:52 +0200 Subject: [PATCH 2/6] watchtower: add ListClientSessions functional options This commit adds functional options to the ListClientSessions call that can be used to perform a variety of extra operations during the DB query. These functional options are not yet used in this commit. --- watchtower/wtclient/client.go | 28 +++++--- watchtower/wtclient/interface.go | 2 +- watchtower/wtdb/client_db.go | 107 +++++++++++++++++++++++++------ watchtower/wtmock/client_db.go | 12 ++-- 4 files changed, 113 insertions(+), 36 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 3d23f0b82..0003c2b10 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -83,10 +83,12 @@ type Client interface { // RegisteredTowers retrieves the list of watchtowers registered with // the client. - RegisteredTowers() ([]*RegisteredTower, error) + RegisteredTowers(...wtdb.ClientSessionListOption) ([]*RegisteredTower, + error) // LookupTower retrieves a registered watchtower through its public key. - LookupTower(*btcec.PublicKey) (*RegisteredTower, error) + LookupTower(*btcec.PublicKey, + ...wtdb.ClientSessionListOption) (*RegisteredTower, error) // Stats returns the in-memory statistics of the client since startup. Stats() ClientStats @@ -363,7 +365,8 @@ func New(config *Config) (*TowerClient, error) { // tower. func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, sessionFilter func(*wtdb.ClientSession) bool, - perActiveTower func(tower *wtdb.Tower)) ( + perActiveTower func(tower *wtdb.Tower), + opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { towers, err := db.ListTowers() @@ -373,7 +376,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, tower := range towers { - sessions, err := db.ListClientSessions(&tower.ID) + sessions, err := db.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err } @@ -413,10 +416,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, // ClientSession's SessionPrivKey field is desired, otherwise, the existing // ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, - passesFilter func(*wtdb.ClientSession) bool) ( + passesFilter func(*wtdb.ClientSession) bool, + opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { - sessions, err := db.ListClientSessions(forTower) + sessions, err := db.ListClientSessions(forTower, opts...) if err != nil { return nil, err } @@ -1233,13 +1237,15 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // RegisteredTowers retrieves the list of watchtowers registered with the // client. -func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) { +func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( + []*RegisteredTower, error) { + // Retrieve all of our towers along with all of our sessions. towers, err := c.cfg.DB.ListTowers() if err != nil { return nil, err } - clientSessions, err := c.cfg.DB.ListClientSessions(nil) + clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...) if err != nil { return nil, err } @@ -1272,13 +1278,15 @@ func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) { } // LookupTower retrieves a registered watchtower through its public key. -func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey) (*RegisteredTower, error) { +func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey, + opts ...wtdb.ClientSessionListOption) (*RegisteredTower, error) { + tower, err := c.cfg.DB.LoadTower(pubKey) if err != nil { return nil, err } - towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID) + towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err } diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index dbf2faf71..eb4a450a2 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -62,7 +62,7 @@ type DB interface { // still be able to accept state updates. An optional tower ID can be // used to filter out any client sessions in the response that do not // correspond to this tower. - ListClientSessions(*wtdb.TowerID) ( + ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) // FetchChanSummaries loads a mapping from all registered channels to diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index f862b186c..9e33a1df3 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -736,8 +736,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID, // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (c *ClientDB) ListClientSessions(id *TowerID) ( - map[SessionID]*ClientSession, error) { +func (c *ClientDB) ListClientSessions(id *TowerID, + opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { var clientSessions map[SessionID]*ClientSession err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -757,7 +757,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( // known to the db. if id == nil { clientSessions, err = listClientAllSessions( - sessions, towers, + sessions, towers, opts..., ) return err } @@ -769,7 +769,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( } clientSessions, err = listTowerSessions( - *id, sessions, towers, towerToSessionIndex, + *id, sessions, towers, towerToSessionIndex, opts..., ) return err }, func() { @@ -783,8 +783,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( } // listClientAllSessions returns the set of all client sessions known to the db. -func listClientAllSessions(sessions, - towers kvdb.RBucket) (map[SessionID]*ClientSession, error) { +func listClientAllSessions(sessions, towers kvdb.RBucket, + opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) err := sessions.ForEach(func(k, _ []byte) error { @@ -792,7 +792,7 @@ func listClientAllSessions(sessions, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessions, towers, k) + session, err := getClientSession(sessions, towers, k, opts...) if err != nil { return err } @@ -811,8 +811,8 @@ func listClientAllSessions(sessions, // listTowerSessions returns the set of all client sessions known to the db // that are associated with the given tower id. func listTowerSessions(id TowerID, sessionsBkt, towersBkt, - towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession, - error) { + towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) ( + map[SessionID]*ClientSession, error) { towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes()) if towerIndexBkt == nil { @@ -825,7 +825,9 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessionsBkt, towersBkt, k) + session, err := getClientSession( + sessionsBkt, towersBkt, k, opts..., + ) if err != nil { return err } @@ -1157,11 +1159,63 @@ func getClientSessionBody(sessions kvdb.RBucket, return &session, nil } +// PerAckedUpdateCB describes the signature of a callback function that can be +// called for each of a session's acked updates. +type PerAckedUpdateCB func(*ClientSession, uint16, BackupID) + +// PerCommittedUpdateCB describes the signature of a callback function that can +// be called for each of a session's committed updates (updates that the client +// has not yet received an ACK for). +type PerCommittedUpdateCB func(*ClientSession, *CommittedUpdate) + +// ClientSessionListOption describes the signature of a functional option that +// can be used when listing client sessions in order to provide any extra +// instruction to the query. +type ClientSessionListOption func(cfg *ClientSessionListCfg) + +// ClientSessionListCfg defines various query parameters that will be used when +// querying the DB for client sessions. +type ClientSessionListCfg struct { + // PerAckedUpdate will, if set, be called for each of the session's + // acked updates. + PerAckedUpdate PerAckedUpdateCB + + // PerCommittedUpdate will, if set, be called for each of the session's + // committed (un-acked) updates. + PerCommittedUpdate PerCommittedUpdateCB +} + +// NewClientSessionCfg constructs a new ClientSessionListCfg. +func NewClientSessionCfg() *ClientSessionListCfg { + return &ClientSessionListCfg{} +} + +// WithPerAckedUpdate constructs a functional option that will set a call-back +// function to be called for each of a client's acked updates. +func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PerAckedUpdate = cb + } +} + +// WithPerCommittedUpdate constructs a functional option that will set a +// call-back function to be called for each of a client's un-acked updates. +func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PerCommittedUpdate = cb + } +} + // getClientSession loads the full ClientSession associated with the serialized // session id. This method populates the CommittedUpdates, AckUpdates and Tower // in addition to the ClientSession's body. -func getClientSession(sessions, towers kvdb.RBucket, - idBytes []byte) (*ClientSession, error) { +func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, + opts ...ClientSessionListOption) (*ClientSession, error) { + + cfg := NewClientSessionCfg() + for _, o := range opts { + o(cfg) + } session, err := getClientSessionBody(sessions, idBytes) if err != nil { @@ -1178,13 +1232,17 @@ func getClientSession(sessions, towers kvdb.RBucket, sessionBkt := sessions.NestedReadBucket(idBytes) // Fetch the committed updates for this session. - commitedUpdates, err := getClientSessionCommits(sessionBkt) + commitedUpdates, err := getClientSessionCommits( + sessionBkt, session, cfg.PerCommittedUpdate, + ) if err != nil { return nil, err } // Fetch the acked updates for this session. - ackedUpdates, err := getClientSessionAcks(sessionBkt) + ackedUpdates, err := getClientSessionAcks( + sessionBkt, session, cfg.PerAckedUpdate, + ) if err != nil { return nil, err } @@ -1197,11 +1255,12 @@ func getClientSession(sessions, towers kvdb.RBucket, } // getClientSessionCommits retrieves all committed updates for the session -// identified by the serialized session id. -func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate, - error) { +// identified by the serialized session id. If a PerCommittedUpdateCB is +// provided, then it will be called for each of the session's committed updates. +func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, + cb PerCommittedUpdateCB) ([]CommittedUpdate, error) { - // Initialize commitedUpdates so that we can return an initialized map + // Initialize committedUpdates so that we can return an initialized map // if no committed updates exist. committedUpdates := make([]CommittedUpdate, 0) @@ -1220,6 +1279,10 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate, committedUpdates = append(committedUpdates, committedUpdate) + if cb != nil { + cb(s, &committedUpdate) + } + return nil }) if err != nil { @@ -1231,8 +1294,8 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket) ([]CommittedUpdate, // getClientSessionAcks retrieves all acked updates for the session identified // by the serialized session id. -func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID, - error) { +func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession, + cb PerAckedUpdateCB) (map[uint16]BackupID, error) { // Initialize ackedUpdates so that we can return an initialized map if // no acked updates exist. @@ -1254,6 +1317,10 @@ func getClientSessionAcks(sessionBkt kvdb.RBucket) (map[uint16]BackupID, ackedUpdates[seqNum] = backupID + if cb != nil { + cb(s, seqNum, backupID) + } + return nil }) if err != nil { diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 2a3825e87..f569991fc 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -200,19 +200,21 @@ func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight ui // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (m *ClientDB) ListClientSessions( - tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { +func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID, + opts ...wtdb.ClientSessionListOption) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { m.mu.Lock() defer m.mu.Unlock() - return m.listClientSessions(tower) + return m.listClientSessions(tower, opts...) } // listClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (m *ClientDB) listClientSessions( - tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { +func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, + _ ...wtdb.ClientSessionListOption) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, session := range m.activeSessions { From 15858cae1c7511dddfcf0278ccf19d872b52eeae Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 13 Oct 2022 14:24:15 +0200 Subject: [PATCH 3/6] watchtower+lnrpc: remove AckedUpdates from ClientSession struct In this commit, we start making use of the new ListClientSession functional options added in the previous commit. We use the functional options in order to calculate the max commit heights per channel on the construction of the tower client. We also use the options to count the total number of acked and committed updates. With this commit, we are also able to completely remove the AckedUpdates member of the ClientSession since it is no longer used anywhere in the code. --- lnrpc/wtclientrpc/wtclient.go | 70 +++++++++++++--- watchtower/wtclient/client.go | 130 +++++++++++++++--------------- watchtower/wtdb/client_db.go | 36 ++++----- watchtower/wtdb/client_db_test.go | 65 ++++++++------- watchtower/wtdb/client_session.go | 7 -- watchtower/wtmock/client_db.go | 23 ++++-- 6 files changed, 194 insertions(+), 137 deletions(-) diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index 5864b6608..1d9d41849 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -265,12 +265,16 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context, return nil, err } - anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers() + opts, ackCounts, committedUpdateCounts := constructFunctionalOptions( + req.IncludeSessions, + ) + + anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers(opts...) if err != nil { return nil, err } - legacyTowers, err := c.cfg.Client.RegisteredTowers() + legacyTowers, err := c.cfg.Client.RegisteredTowers(opts...) if err != nil { return nil, err } @@ -286,7 +290,10 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context, rpcTowers := make([]*Tower, 0, len(towers)) for _, tower := range towers { - rpcTower := marshallTower(tower, req.IncludeSessions) + rpcTower := marshallTower( + tower, req.IncludeSessions, ackCounts, + committedUpdateCounts, + ) rpcTowers = append(rpcTowers, rpcTower) } @@ -306,16 +313,59 @@ func (c *WatchtowerClient) GetTowerInfo(ctx context.Context, return nil, err } + opts, ackCounts, committedUpdateCounts := constructFunctionalOptions( + req.IncludeSessions, + ) + var tower *wtclient.RegisteredTower - tower, err = c.cfg.Client.LookupTower(pubKey) + tower, err = c.cfg.Client.LookupTower(pubKey, opts...) if err == wtdb.ErrTowerNotFound { - tower, err = c.cfg.AnchorClient.LookupTower(pubKey) + tower, err = c.cfg.AnchorClient.LookupTower(pubKey, opts...) } if err != nil { return nil, err } - return marshallTower(tower, req.IncludeSessions), nil + return marshallTower( + tower, req.IncludeSessions, ackCounts, committedUpdateCounts, + ), nil +} + +// constructFunctionalOptions is a helper function that constructs a list of +// functional options to be used when fetching a tower from the DB. It also +// returns a map of acked-update counts and one for un-acked-update counts that +// will be populated once the db call has been made. +func constructFunctionalOptions(includeSessions bool) ( + []wtdb.ClientSessionListOption, map[wtdb.SessionID]uint16, + map[wtdb.SessionID]uint16) { + + var ( + opts []wtdb.ClientSessionListOption + ackCounts = make(map[wtdb.SessionID]uint16) + committedUpdateCounts = make(map[wtdb.SessionID]uint16) + ) + if !includeSessions { + return opts, ackCounts, committedUpdateCounts + } + + perAckedUpdate := func(s *wtdb.ClientSession, _ uint16, + _ wtdb.BackupID) { + + ackCounts[s.ID]++ + } + + perCommittedUpdate := func(s *wtdb.ClientSession, + _ *wtdb.CommittedUpdate) { + + committedUpdateCounts[s.ID]++ + } + + opts = []wtdb.ClientSessionListOption{ + wtdb.WithPerAckedUpdate(perAckedUpdate), + wtdb.WithPerCommittedUpdate(perCommittedUpdate), + } + + return opts, ackCounts, committedUpdateCounts } // Stats returns the in-memory statistics of the client since startup. @@ -387,7 +437,9 @@ func (c *WatchtowerClient) Policy(ctx context.Context, // marshallTower converts a client registered watchtower into its corresponding // RPC type. -func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool) *Tower { +func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool, + ackCounts, pendingCounts map[wtdb.SessionID]uint16) *Tower { + rpcAddrs := make([]string, 0, len(tower.Addresses)) for _, addr := range tower.Addresses { rpcAddrs = append(rpcAddrs, addr.String()) @@ -399,8 +451,8 @@ func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool) *Tower for _, session := range tower.Sessions { satPerVByte := session.Policy.SweepFeeRate.FeePerKVByte() / 1000 rpcSessions = append(rpcSessions, &TowerSession{ - NumBackups: uint32(len(session.AckedUpdates)), - NumPendingBackups: uint32(len(session.CommittedUpdates)), + NumBackups: uint32(ackCounts[session.ID]), + NumPendingBackups: uint32(pendingCounts[session.ID]), MaxBackups: uint32(session.Policy.MaxUpdates), SweepSatPerVbyte: uint32(satPerVByte), diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 0003c2b10..2e019c536 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -289,12 +289,67 @@ func New(config *Config) (*TowerClient, error) { } plog := build.NewPrefixLog(prefix, log) - // Next, load all candidate towers and sessions from the database into - // the client. We will use any of these sessions if their policies match - // the current policy of the client, otherwise they will be ignored and - // new sessions will be requested. + // Load the sweep pkscripts that have been generated for all previously + // registered channels. + chanSummaries, err := cfg.DB.FetchChanSummaries() + if err != nil { + return nil, err + } + + c := &TowerClient{ + cfg: cfg, + log: plog, + pipeline: newTaskPipeline(plog), + chanCommitHeights: make(map[lnwire.ChannelID]uint64), + activeSessions: make(sessionQueueSet), + summaries: chanSummaries, + statTicker: time.NewTicker(DefaultStatInterval), + stats: new(ClientStats), + newTowers: make(chan *newTowerMsg), + staleTowers: make(chan *staleTowerMsg), + forceQuit: make(chan struct{}), + } + + // perUpdate is a callback function that will be used to inspect the + // full set of candidate client sessions loaded from disk, and to + // determine the highest known commit height for each channel. This + // allows the client to reject backups that it has already processed for + // its active policy. + perUpdate := func(policy wtpolicy.Policy, id wtdb.BackupID) { + // We only want to consider accepted updates that have been + // accepted under an identical policy to the client's current + // policy. + if policy != c.cfg.Policy { + return + } + + // Take the highest commit height found in the session's acked + // updates. + height, ok := c.chanCommitHeights[id.ChanID] + if !ok || id.CommitHeight > height { + c.chanCommitHeights[id.ChanID] = id.CommitHeight + } + } + + perAckedUpdate := func(s *wtdb.ClientSession, _ uint16, + id wtdb.BackupID) { + + perUpdate(s.Policy, id) + } + + perCommittedUpdate := func(s *wtdb.ClientSession, + u *wtdb.CommittedUpdate) { + + perUpdate(s.Policy, u.BackupID) + } + + // Load all candidate sessions and towers from the database into the + // client. We will use any of these sessions if their policies match the + // current policy of the client, otherwise they will be ignored and new + // sessions will be requested. isAnchorClient := cfg.Policy.IsAnchorChannel() activeSessionFilter := genActiveSessionFilter(isAnchorClient) + candidateTowers := newTowerListIterator() perActiveTower := func(tower *wtdb.Tower) { // If the tower has already been marked as active, then there is @@ -309,34 +364,19 @@ func New(config *Config) (*TowerClient, error) { // Add the tower to the set of candidate towers. candidateTowers.AddCandidate(tower) } + candidateSessions, err := getTowerAndSessionCandidates( cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower, + wtdb.WithPerAckedUpdate(perAckedUpdate), + wtdb.WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { return nil, err } - // Load the sweep pkscripts that have been generated for all previously - // registered channels. - chanSummaries, err := cfg.DB.FetchChanSummaries() - if err != nil { - return nil, err - } + c.candidateTowers = candidateTowers + c.candidateSessions = candidateSessions - c := &TowerClient{ - cfg: cfg, - log: plog, - pipeline: newTaskPipeline(plog), - candidateTowers: candidateTowers, - candidateSessions: candidateSessions, - activeSessions: make(sessionQueueSet), - summaries: chanSummaries, - statTicker: time.NewTicker(DefaultStatInterval), - stats: new(ClientStats), - newTowers: make(chan *newTowerMsg), - staleTowers: make(chan *staleTowerMsg), - forceQuit: make(chan struct{}), - } c.negotiator = newSessionNegotiator(&NegotiatorConfig{ DB: cfg.DB, SecretKeyRing: cfg.SecretKeyRing, @@ -351,10 +391,6 @@ func New(config *Config) (*TowerClient, error) { Log: plog, }) - // Reconstruct the highest commit height processed for each channel - // under the client's current policy. - c.buildHighestCommitHeights() - return c, nil } @@ -450,44 +486,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, return sessions, nil } -// buildHighestCommitHeights inspects the full set of candidate client sessions -// loaded from disk, and determines the highest known commit height for each -// channel. This allows the client to reject backups that it has already -// processed for it's active policy. -func (c *TowerClient) buildHighestCommitHeights() { - chanCommitHeights := make(map[lnwire.ChannelID]uint64) - for _, s := range c.candidateSessions { - // We only want to consider accepted updates that have been - // accepted under an identical policy to the client's current - // policy. - if s.Policy != c.cfg.Policy { - continue - } - - // Take the highest commit height found in the session's - // committed updates. - for _, committedUpdate := range s.CommittedUpdates { - bid := committedUpdate.BackupID - - height, ok := chanCommitHeights[bid.ChanID] - if !ok || bid.CommitHeight > height { - chanCommitHeights[bid.ChanID] = bid.CommitHeight - } - } - - // Take the heights commit height found in the session's acked - // updates. - for _, bid := range s.AckedUpdates { - height, ok := chanCommitHeights[bid.ChanID] - if !ok || bid.CommitHeight > height { - chanCommitHeights[bid.ChanID] = bid.CommitHeight - } - } - } - - c.chanCommitHeights = chanCommitHeights -} - // Start initializes the watchtower client by loading or negotiating an active // session and then begins processing backup tasks from the request pipeline. func (c *TowerClient) Start() error { diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 9e33a1df3..bb4db7444 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -1239,17 +1239,15 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, return nil, err } - // Fetch the acked updates for this session. - ackedUpdates, err := getClientSessionAcks( - sessionBkt, session, cfg.PerAckedUpdate, - ) + // Pass the session's acked updates through the call-back if one is + // provided. + err = filterClientSessionAcks(sessionBkt, session, cfg.PerAckedUpdate) if err != nil { return nil, err } session.Tower = tower session.CommittedUpdates = commitedUpdates - session.AckedUpdates = ackedUpdates return session, nil } @@ -1292,18 +1290,19 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, return committedUpdates, nil } -// getClientSessionAcks retrieves all acked updates for the session identified -// by the serialized session id. -func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession, - cb PerAckedUpdateCB) (map[uint16]BackupID, error) { +// filterClientSessionAcks retrieves all acked updates for the session +// identified by the serialized session id and passes them to the provided +// call back if one is provided. +func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession, + cb PerAckedUpdateCB) error { - // Initialize ackedUpdates so that we can return an initialized map if - // no acked updates exist. - ackedUpdates := make(map[uint16]BackupID) + if cb == nil { + return nil + } sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks) if sessionAcks == nil { - return ackedUpdates, nil + return nil } err := sessionAcks.ForEach(func(k, v []byte) error { @@ -1315,19 +1314,14 @@ func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession, return err } - ackedUpdates[seqNum] = backupID - - if cb != nil { - cb(s, seqNum, backupID) - } - + cb(s, seqNum, backupID) return nil }) if err != nil { - return nil, err + return err } - return ackedUpdates, nil + return nil } // putClientSessionBody stores the body of the ClientSession (everything but the diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index d4f1699c9..8bff940d1 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -48,12 +48,12 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, require.ErrorIs(h.t, err, expErr) } -func (h *clientDBHarness) listSessions( - id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession { +func (h *clientDBHarness) listSessions(id *wtdb.TowerID, + opts ...wtdb.ClientSessionListOption) map[wtdb.SessionID]*wtdb.ClientSession { h.t.Helper() - sessions, err := h.db.ListClientSessions(id) + sessions, err := h.db.ListClientSessions(id, opts...) require.NoError(h.t, err, "unable to list client sessions") return sessions @@ -520,11 +520,7 @@ func testCommitUpdate(h *clientDBHarness) { // Assert that the committed update appears in the client session's // CommittedUpdates map when loaded from disk and that there are no // AckedUpdates. - dbSession := h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ - *update1, - }) - checkAckedUpdates(h.t, dbSession, nil) + h.assertUpdates(session.ID, []wtdb.CommittedUpdate{*update1}, nil) // Try to commit the same update, which should succeed due to // idempotency (which is preserved when the breach hint is identical to @@ -534,11 +530,7 @@ func testCommitUpdate(h *clientDBHarness) { require.Equal(h.t, lastApplied, lastApplied2) // Assert that the loaded ClientSession is the same as before. - dbSession = h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ - *update1, - }) - checkAckedUpdates(h.t, dbSession, nil) + h.assertUpdates(session.ID, []wtdb.CommittedUpdate{*update1}, nil) // Generate another random update and try to commit it at the identical // sequence number. Since the breach hint has changed, this should fail. @@ -553,12 +545,10 @@ func testCommitUpdate(h *clientDBHarness) { // Check that both updates now appear as committed on the ClientSession // loaded from disk. - dbSession = h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ + h.assertUpdates(session.ID, []wtdb.CommittedUpdate{ *update1, *update2, - }) - checkAckedUpdates(h.t, dbSession, nil) + }, nil) // Finally, create one more random update and try to commit it at index // 4, which should be rejected since 3 is the next slot the database @@ -567,12 +557,20 @@ func testCommitUpdate(h *clientDBHarness) { h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate) // Assert that the ClientSession loaded from disk remains unchanged. - dbSession = h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ + h.assertUpdates(session.ID, []wtdb.CommittedUpdate{ *update1, *update2, - }) - checkAckedUpdates(h.t, dbSession, nil) + }, nil) +} + +func perAckedUpdate(updates map[uint16]wtdb.BackupID) func( + _ *wtdb.ClientSession, seq uint16, id wtdb.BackupID) { + + return func(_ *wtdb.ClientSession, seq uint16, + id wtdb.BackupID) { + + updates[seq] = id + } } // testAckUpdate asserts the behavior of AckUpdate. @@ -628,9 +626,7 @@ func testAckUpdate(h *clientDBHarness) { // Assert that the ClientSession loaded from disk has one update in it's // AckedUpdates map, and that the committed update has been removed. - dbSession := h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, nil) - checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ + h.assertUpdates(session.ID, nil, map[uint16]wtdb.BackupID{ 1: update1.BackupID, }) @@ -645,9 +641,7 @@ func testAckUpdate(h *clientDBHarness) { h.ackUpdate(&session.ID, 2, 2, nil) // Assert that both updates exist as AckedUpdates when loaded from disk. - dbSession = h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, nil) - checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ + h.assertUpdates(session.ID, nil, map[uint16]wtdb.BackupID{ 1: update1.BackupID, 2: update2.BackupID, }) @@ -663,6 +657,19 @@ func testAckUpdate(h *clientDBHarness) { h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied) } +func (h *clientDBHarness) assertUpdates(id wtdb.SessionID, + expectedPending []wtdb.CommittedUpdate, + expectedAcked map[uint16]wtdb.BackupID) { + + ackedUpdates := make(map[uint16]wtdb.BackupID) + _ = h.listSessions( + nil, wtdb.WithPerAckedUpdate(perAckedUpdate(ackedUpdates)), + ) + dbSession := h.listSessions(nil)[id] + checkCommittedUpdates(h.t, dbSession, expectedPending) + checkAckedUpdates(h.t, ackedUpdates, expectedAcked) +} + // checkCommittedUpdates asserts that the CommittedUpdates on session match the // expUpdates provided. func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, @@ -682,7 +689,7 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, // checkAckedUpdates asserts that the AckedUpdates on a session match the // expUpdates provided. -func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, +func checkAckedUpdates(t *testing.T, actualUpdates, expUpdates map[uint16]wtdb.BackupID) { // We promote nil expUpdates to an initialized map since the database @@ -692,7 +699,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, expUpdates = make(map[uint16]wtdb.BackupID) } - require.Equal(t, expUpdates, session.AckedUpdates) + require.Equal(t, expUpdates, actualUpdates) } // TestClientDB asserts the behavior of a fresh client db, a reopened client db, diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index 1d3e01f5a..556b19937 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -47,13 +47,6 @@ type ClientSession struct { // insertion and retrieval. CommittedUpdates []CommittedUpdate - // AckedUpdates is a map from sequence number to backup id to record - // which revoked states were uploaded via this session. - // - // NOTE: This map is serialized in it's own bucket, separate from the - // body of the ClientSession. - AckedUpdates map[uint16]BackupID - // Tower holds the pubkey and address of the watchtower. // // NOTE: This value is not serialized. It is recovered by looking up the diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index f569991fc..ec79bc0e0 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -26,6 +26,7 @@ type ClientDB struct { mu sync.Mutex summaries map[lnwire.ChannelID]wtdb.ClientChanSummary activeSessions map[wtdb.SessionID]wtdb.ClientSession + ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID towerIndex map[towerPK]wtdb.TowerID towers map[wtdb.TowerID]*wtdb.Tower @@ -39,6 +40,7 @@ func NewClientDB() *ClientDB { return &ClientDB{ summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), + ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID), towerIndex: make(map[towerPK]wtdb.TowerID), towers: make(map[wtdb.TowerID]*wtdb.Tower), indexes: make(map[keyIndexKey]uint32), @@ -75,7 +77,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { } else { towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) tower = &wtdb.Tower{ - ID: wtdb.TowerID(towerID), + ID: towerID, IdentityKey: lnAddr.IdentityKey, Addresses: []net.Addr{lnAddr.Address}, } @@ -193,7 +195,7 @@ func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) { // MarkBackupIneligible records that particular commit height is ineligible for // backup. This allows the client to track which updates it should not attempt // to retry after startup. -func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error { +func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error { return nil } @@ -213,9 +215,14 @@ func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID, // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, - _ ...wtdb.ClientSessionListOption) ( + opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { + cfg := wtdb.NewClientSessionCfg() + for _, o := range opts { + o(cfg) + } + sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, session := range m.activeSessions { session := session @@ -224,6 +231,12 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, } session.Tower = m.towers[session.TowerID] sessions[session.ID] = &session + + if cfg.PerAckedUpdate != nil { + for seq, id := range m.ackedUpdates[session.ID] { + cfg.PerAckedUpdate(&session, seq, id) + } + } } return sessions, nil @@ -274,8 +287,8 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { RewardPkScript: cloneBytes(session.RewardPkScript), }, CommittedUpdates: make([]wtdb.CommittedUpdate, 0), - AckedUpdates: make(map[uint16]wtdb.BackupID), } + m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID) return nil } @@ -402,7 +415,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err updates[len(updates)-1] = wtdb.CommittedUpdate{} session.CommittedUpdates = updates[:len(updates)-1] - session.AckedUpdates[seqNum] = update.BackupID + m.ackedUpdates[*id][seqNum] = update.BackupID session.TowerLastApplied = lastApplied m.activeSessions[*id] = session From fe3d9174ea446ed55e7aca2fecbbcbb3ad9dba77 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 30 Sep 2022 11:47:54 +0200 Subject: [PATCH 4/6] watchtower: add FetchSessionCommittedUpdates func to DB In this commit, a new tower client db function is added that can be used to fetch all the committed updates for a given session ID. This is done in preparation for an upcoming commit where the CommittedUpdates will be removed from the ClientSession struct. --- watchtower/wtclient/interface.go | 5 +++++ watchtower/wtdb/client_db.go | 30 ++++++++++++++++++++++++++++++ watchtower/wtdb/client_db_test.go | 25 +++++++++++++++++++++---- watchtower/wtmock/client_db.go | 16 ++++++++++++++++ 4 files changed, 72 insertions(+), 4 deletions(-) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index eb4a450a2..c67d7eac3 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -65,6 +65,11 @@ type DB interface { ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) + // FetchSessionCommittedUpdates retrieves the current set of un-acked + // updates of the given session. + FetchSessionCommittedUpdates(id *wtdb.SessionID) ( + []wtdb.CommittedUpdate, error) + // FetchChanSummaries loads a mapping from all registered channels to // their channel summaries. FetchChanSummaries() (wtdb.ChannelSummaries, error) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index bb4db7444..c0c7d5118 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -842,6 +842,36 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt, return clientSessions, nil } +// FetchSessionCommittedUpdates retrieves the current set of un-acked updates +// of the given session. +func (c *ClientDB) FetchSessionCommittedUpdates(id *SessionID) ( + []CommittedUpdate, error) { + + var committedUpdates []CommittedUpdate + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + sessions := tx.ReadBucket(cSessionBkt) + if sessions == nil { + return ErrUninitializedDB + } + + sessionBkt := sessions.NestedReadBucket(id[:]) + if sessionBkt == nil { + return ErrClientSessionNotFound + } + + var err error + committedUpdates, err = getClientSessionCommits( + sessionBkt, nil, nil, + ) + return err + }, func() {}) + if err != nil { + return nil, err + } + + return committedUpdates, nil +} + // FetchChanSummaries loads a mapping from all registered channels to their // channel summaries. func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 8bff940d1..aa30cc713 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -207,6 +207,20 @@ func (h *clientDBHarness) newTower() *wtdb.Tower { }, nil) } +func (h *clientDBHarness) fetchSessionCommittedUpdates(id *wtdb.SessionID, + expErr error) []wtdb.CommittedUpdate { + + h.t.Helper() + + updates, err := h.db.FetchSessionCommittedUpdates(id) + if err != expErr { + h.t.Fatalf("expected fetch session committed updates error: "+ + "%v, got: %v", expErr, err) + } + + return updates +} + // testCreateClientSession asserts various conditions regarding the creation of // a new ClientSession. The test asserts: // - client sessions can only be created if a session key index is reserved. @@ -506,6 +520,9 @@ func testCommitUpdate(h *clientDBHarness) { // session, which should fail. update1 := randCommittedUpdate(h.t, 1) h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound) + h.fetchSessionCommittedUpdates( + &session.ID, wtdb.ErrClientSessionNotFound, + ) // Reserve a session key index and insert the session. session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType) @@ -665,14 +682,14 @@ func (h *clientDBHarness) assertUpdates(id wtdb.SessionID, _ = h.listSessions( nil, wtdb.WithPerAckedUpdate(perAckedUpdate(ackedUpdates)), ) - dbSession := h.listSessions(nil)[id] - checkCommittedUpdates(h.t, dbSession, expectedPending) + committedUpates := h.fetchSessionCommittedUpdates(&id, nil) + checkCommittedUpdates(h.t, committedUpates, expectedPending) checkAckedUpdates(h.t, ackedUpdates, expectedAcked) } // checkCommittedUpdates asserts that the CommittedUpdates on session match the // expUpdates provided. -func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, +func checkCommittedUpdates(t *testing.T, actualUpdates, expUpdates []wtdb.CommittedUpdate) { t.Helper() @@ -684,7 +701,7 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, expUpdates = make([]wtdb.CommittedUpdate, 0) } - require.Equal(t, expUpdates, session.CommittedUpdates) + require.Equal(t, expUpdates, actualUpdates) } // checkAckedUpdates asserts that the AckedUpdates on a session match the diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index ec79bc0e0..f0df13fad 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -242,6 +242,22 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, return sessions, nil } +// FetchSessionCommittedUpdates retrieves the current set of un-acked updates +// of the given session. +func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) ( + []wtdb.CommittedUpdate, error) { + + m.mu.Lock() + defer m.mu.Unlock() + + sess, ok := m.activeSessions[*id] + if !ok { + return nil, wtdb.ErrClientSessionNotFound + } + + return sess.CommittedUpdates, nil +} + // CreateClientSession records a newly negotiated client session in the set of // active sessions. The session can be identified by its SessionID. func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { From 75e5339217c36007488cc1404ba89e935d6f7f60 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 30 Sep 2022 12:18:08 +0200 Subject: [PATCH 5/6] watchtower: remove CommittedUpdates from ClientSession In this commit, the new ListClientSession functional options and new FetchSessionCommittedUpdates function are utilised in order to allow us to completely remove the CommittedUpdates member from the ClientSession struct. --- watchtower/wtclient/client.go | 59 ++++++++++++++++++++-------- watchtower/wtclient/session_queue.go | 6 ++- watchtower/wtdb/client_db.go | 53 ++++++++++++++++++++++--- watchtower/wtdb/client_session.go | 10 ----- watchtower/wtmock/client_db.go | 57 ++++++++++++++++----------- 5 files changed, 129 insertions(+), 56 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 2e019c536..35c0ef22b 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -489,7 +489,7 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, // Start initializes the watchtower client by loading or negotiating an active // session and then begins processing backup tasks from the request pipeline. func (c *TowerClient) Start() error { - var err error + var returnErr error c.started.Do(func() { c.log.Infof("Watchtower client starting") @@ -498,19 +498,27 @@ func (c *TowerClient) Start() error { // sessions will be able to flush the committed updates after a // restart. for _, session := range c.candidateSessions { - if len(session.CommittedUpdates) > 0 { + committedUpdates, err := c.cfg.DB.FetchSessionCommittedUpdates(&session.ID) + if err != nil { + returnErr = err + return + } + + if len(committedUpdates) > 0 { c.log.Infof("Starting session=%s to process "+ "%d committed backups", session.ID, - len(session.CommittedUpdates)) - c.initActiveQueue(session) + len(committedUpdates)) + + c.initActiveQueue(session, committedUpdates) } } // Now start the session negotiator, which will allow us to // request new session as soon as the backupDispatcher starts // up. - err = c.negotiator.Start() + err := c.negotiator.Start() if err != nil { + returnErr = err return } @@ -523,7 +531,7 @@ func (c *TowerClient) Start() error { c.log.Infof("Watchtower client started successfully") }) - return err + return returnErr } // Stop idempotently initiates a graceful shutdown of the watchtower client. @@ -699,7 +707,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, // active client's advertised policy will be ignored, but may be resumed if the // client is restarted with a matching policy. If no candidates were found, nil // is returned to signal that we need to request a new policy. -func (c *TowerClient) nextSessionQueue() *sessionQueue { +func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { // Select any candidate session at random, and remove it from the set of // candidate sessions. var candidateSession *wtdb.ClientSession @@ -721,13 +729,20 @@ func (c *TowerClient) nextSessionQueue() *sessionQueue { // If none of the sessions could be used or none were found, we'll // return nil to signal that we need another session to be negotiated. if candidateSession == nil { - return nil + return nil, nil + } + + updates, err := c.cfg.DB.FetchSessionCommittedUpdates( + &candidateSession.ID, + ) + if err != nil { + return nil, err } // Initialize the session queue and spin it up so it can begin handling // updates. If the queue was already made active on startup, this will // simply return the existing session queue from the set. - return c.getOrInitActiveQueue(candidateSession) + return c.getOrInitActiveQueue(candidateSession, updates), nil } // backupDispatcher processes events coming from the taskPipeline and is @@ -800,7 +815,13 @@ func (c *TowerClient) backupDispatcher() { // We've exhausted the prior session, we'll pop another // from the remaining sessions and continue processing // backup tasks. - c.sessionQueue = c.nextSessionQueue() + var err error + c.sessionQueue, err = c.nextSessionQueue() + if err != nil { + c.log.Errorf("error fetching next session "+ + "queue: %v", err) + } + if c.sessionQueue != nil { c.log.Debugf("Loaded next candidate session "+ "queue id=%s", c.sessionQueue.ID()) @@ -1048,7 +1069,9 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error // newSessionQueue creates a sessionQueue from a ClientSession loaded from the // database and supplying it with the resources needed by the client. -func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue { +func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession, + updates []wtdb.CommittedUpdate) *sessionQueue { + return newSessionQueue(&sessionQueueConfig{ ClientSession: s, ChainHash: c.cfg.ChainHash, @@ -1060,28 +1083,32 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue { MinBackoff: c.cfg.MinBackoff, MaxBackoff: c.cfg.MaxBackoff, Log: c.log, - }) + }, updates) } // getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the // passed ClientSession. If it exists, the active sessionQueue is returned. // Otherwise a new sessionQueue is initialized and added to the set. -func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession) *sessionQueue { +func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession, + updates []wtdb.CommittedUpdate) *sessionQueue { + if sq, ok := c.activeSessions[s.ID]; ok { return sq } - return c.initActiveQueue(s) + return c.initActiveQueue(s, updates) } // initActiveQueue creates a new sessionQueue from the passed ClientSession, // adds the sessionQueue to the activeSessions set, and starts the sessionQueue // so that it can deliver any committed updates or begin accepting newly // assigned tasks. -func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession) *sessionQueue { +func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession, + updates []wtdb.CommittedUpdate) *sessionQueue { + // Initialize the session queue, providing it with all of the resources // it requires from the client instance. - sq := c.newSessionQueue(s) + sq := c.newSessionQueue(s, updates) // Add the session queue as an active session so that we remember to // stop it on shutdown. diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index 8b2a9ad5d..adffdca06 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -109,7 +109,9 @@ type sessionQueue struct { } // newSessionQueue intiializes a fresh sessionQueue. -func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue { +func newSessionQueue(cfg *sessionQueueConfig, + updates []wtdb.CommittedUpdate) *sessionQueue { + localInit := wtwire.NewInitMessage( lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), cfg.ChainHash, @@ -137,7 +139,7 @@ func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue { // The database should return them in sorted order, and session queue's // sequence number will be equal to that of the last committed update. - for _, update := range sq.cfg.ClientSession.CommittedUpdates { + for _, update := range updates { sq.commitQueue.PushBack(update) } diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index c0c7d5118..796226d94 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -420,8 +420,17 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { return ErrUninitializedDB } towerID := TowerIDFromBytes(towerIDBytes) + + committedUpdateCount := make(map[SessionID]uint16) + perCommittedUpdate := func(s *ClientSession, + _ *CommittedUpdate) { + + committedUpdateCount[s.ID]++ + } + towerSessions, err := listTowerSessions( towerID, sessions, towers, towersToSessionsIndex, + WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { return err @@ -447,7 +456,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { // have any pending updates to ensure we don't load them upon // restarts. for _, session := range towerSessions { - if len(session.CommittedUpdates) > 0 { + if committedUpdateCount[session.ID] > 0 { return ErrTowerUnackedUpdates } err := markSessionStatus( @@ -1257,12 +1266,14 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, if err != nil { return nil, err } + session.Tower = tower // Can't fail because client session body has already been read. sessionBkt := sessions.NestedReadBucket(idBytes) - // Fetch the committed updates for this session. - commitedUpdates, err := getClientSessionCommits( + // Pass the session's committed (un-acked) updates through the call-back + // if one is provided. + err = filterClientSessionCommits( sessionBkt, session, cfg.PerCommittedUpdate, ) if err != nil { @@ -1276,9 +1287,6 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, return nil, err } - session.Tower = tower - session.CommittedUpdates = commitedUpdates - return session, nil } @@ -1354,6 +1362,39 @@ func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession, return nil } +// filterClientSessionCommits retrieves all committed updates for the session +// identified by the serialized session id and passes them to the given +// PerCommittedUpdateCB callback. +func filterClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, + cb PerCommittedUpdateCB) error { + + if cb == nil { + return nil + } + + sessionCommits := sessionBkt.NestedReadBucket(cSessionCommits) + if sessionCommits == nil { + return nil + } + + err := sessionCommits.ForEach(func(k, v []byte) error { + var committedUpdate CommittedUpdate + err := committedUpdate.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + committedUpdate.SeqNum = byteOrder.Uint16(k) + + cb(s, &committedUpdate) + return nil + }) + if err != nil { + return err + } + + return nil +} + // putClientSessionBody stores the body of the ClientSession (everything but the // CommittedUpdates and AckedUpdates). func putClientSessionBody(sessions kvdb.RwBucket, diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index 556b19937..a4d5c5ecc 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -37,16 +37,6 @@ type ClientSession struct { ClientSessionBody - // CommittedUpdates is a sorted list of unacked updates. These updates - // can be resent after a restart if the updates failed to send or - // receive an acknowledgment. - // - // NOTE: This list is serialized in it's own bucket, separate from the - // body of the ClientSession. The representation on disk is a key value - // map from sequence number to CommittedUpdateBody to allow efficient - // insertion and retrieval. - CommittedUpdates []CommittedUpdate - // Tower holds the pubkey and address of the watchtower. // // NOTE: This value is not serialized. It is recovered by looking up the diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index f0df13fad..8a47bdf7f 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -23,12 +23,13 @@ type keyIndexKey struct { type ClientDB struct { nextTowerID uint64 // to be used atomically - mu sync.Mutex - summaries map[lnwire.ChannelID]wtdb.ClientChanSummary - activeSessions map[wtdb.SessionID]wtdb.ClientSession - ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID - towerIndex map[towerPK]wtdb.TowerID - towers map[wtdb.TowerID]*wtdb.Tower + mu sync.Mutex + summaries map[lnwire.ChannelID]wtdb.ClientChanSummary + activeSessions map[wtdb.SessionID]wtdb.ClientSession + ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID + committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate + towerIndex map[towerPK]wtdb.TowerID + towers map[wtdb.TowerID]*wtdb.Tower nextIndex uint32 indexes map[keyIndexKey]uint32 @@ -38,13 +39,14 @@ type ClientDB struct { // NewClientDB initializes a new mock ClientDB. func NewClientDB() *ClientDB { return &ClientDB{ - summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), - activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), - ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID), - towerIndex: make(map[towerPK]wtdb.TowerID), - towers: make(map[wtdb.TowerID]*wtdb.Tower), - indexes: make(map[keyIndexKey]uint32), - legacyIndexes: make(map[wtdb.TowerID]uint32), + summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), + activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), + ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID), + committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate), + towerIndex: make(map[towerPK]wtdb.TowerID), + towers: make(map[wtdb.TowerID]*wtdb.Tower), + indexes: make(map[keyIndexKey]uint32), + legacyIndexes: make(map[wtdb.TowerID]uint32), } } @@ -131,7 +133,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { } for id, session := range towerSessions { - if len(session.CommittedUpdates) > 0 { + if len(m.committedUpdates[session.ID]) > 0 { return wtdb.ErrTowerUnackedUpdates } session.Status = wtdb.CSessionInactive @@ -237,6 +239,13 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, cfg.PerAckedUpdate(&session, seq, id) } } + + if cfg.PerCommittedUpdate != nil { + for _, update := range m.committedUpdates[session.ID] { + update := update + cfg.PerCommittedUpdate(&session, &update) + } + } } return sessions, nil @@ -250,12 +259,12 @@ func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) ( m.mu.Lock() defer m.mu.Unlock() - sess, ok := m.activeSessions[*id] + updates, ok := m.committedUpdates[*id] if !ok { return nil, wtdb.ErrClientSessionNotFound } - return sess.CommittedUpdates, nil + return updates, nil } // CreateClientSession records a newly negotiated client session in the set of @@ -302,9 +311,9 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { Policy: session.Policy, RewardPkScript: cloneBytes(session.RewardPkScript), }, - CommittedUpdates: make([]wtdb.CommittedUpdate, 0), } m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID) + m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0) return nil } @@ -365,7 +374,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, } // Check if an update has already been committed for this state. - for _, dbUpdate := range session.CommittedUpdates { + for _, dbUpdate := range m.committedUpdates[session.ID] { if dbUpdate.SeqNum == update.SeqNum { // If the breach hint matches, we'll just return the // last applied value so the client can retransmit. @@ -384,7 +393,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, } // Save the update and increment the sequence number. - session.CommittedUpdates = append(session.CommittedUpdates, *update) + m.committedUpdates[session.ID] = append( + m.committedUpdates[session.ID], *update, + ) session.SeqNum++ m.activeSessions[*id] = session @@ -394,7 +405,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, // AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This // removes the update from the set of committed updates, and validates the // lastApplied value returned from the tower. -func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error { +func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, + lastApplied uint16) error { + m.mu.Lock() defer m.mu.Unlock() @@ -418,7 +431,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err // Retrieve the committed update, failing if none is found. We should // only receive acks for state updates that we send. - updates := session.CommittedUpdates + updates := m.committedUpdates[session.ID] for i, update := range updates { if update.SeqNum != seqNum { continue @@ -429,7 +442,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err // along with the next update. copy(updates[:i], updates[i+1:]) updates[len(updates)-1] = wtdb.CommittedUpdate{} - session.CommittedUpdates = updates[:len(updates)-1] + m.committedUpdates[session.ID] = updates[:len(updates)-1] m.ackedUpdates[*id][seqNum] = update.BackupID session.TowerLastApplied = lastApplied From e6eed5682722caccf62acc1f36b07bb9630b9043 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 13 Oct 2022 14:33:17 +0200 Subject: [PATCH 6/6] docs: update release notes for #6928 --- docs/release-notes/release-notes-0.16.0.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/release-notes/release-notes-0.16.0.md b/docs/release-notes/release-notes-0.16.0.md index 5bdcce5b3..7a48142be 100644 --- a/docs/release-notes/release-notes-0.16.0.md +++ b/docs/release-notes/release-notes-0.16.0.md @@ -111,6 +111,10 @@ crash](https://github.com/lightningnetwork/lnd/pull/7019). closer coupling of Towers and Sessions and ensures that a session cannot be added if the tower it is referring to does not exist. +* [Remove `AckedUpdates` & `CommittedUpdates` from the `ClientSession` + struct](https://github.com/lightningnetwork/lnd/pull/6928) in order to + improve the performance of fetching a `ClientSession` from the DB. + * [Create a helper function to wait for peer to come online](https://github.com/lightningnetwork/lnd/pull/6931).