diff --git a/docs/release-notes/release-notes-0.16.0.md b/docs/release-notes/release-notes-0.16.0.md index 1f97dfed6..2f8f59a0a 100644 --- a/docs/release-notes/release-notes-0.16.0.md +++ b/docs/release-notes/release-notes-0.16.0.md @@ -111,6 +111,10 @@ crash](https://github.com/lightningnetwork/lnd/pull/7019). closer coupling of Towers and Sessions and ensures that a session cannot be added if the tower it is referring to does not exist. +* [Remove `AckedUpdates` & `CommittedUpdates` from the `ClientSession` + struct](https://github.com/lightningnetwork/lnd/pull/6928) in order to + improve the performance of fetching a `ClientSession` from the DB. + * [Create a helper function to wait for peer to come online](https://github.com/lightningnetwork/lnd/pull/6931). diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index 5864b6608..1d9d41849 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -265,12 +265,16 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context, return nil, err } - anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers() + opts, ackCounts, committedUpdateCounts := constructFunctionalOptions( + req.IncludeSessions, + ) + + anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers(opts...) if err != nil { return nil, err } - legacyTowers, err := c.cfg.Client.RegisteredTowers() + legacyTowers, err := c.cfg.Client.RegisteredTowers(opts...) if err != nil { return nil, err } @@ -286,7 +290,10 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context, rpcTowers := make([]*Tower, 0, len(towers)) for _, tower := range towers { - rpcTower := marshallTower(tower, req.IncludeSessions) + rpcTower := marshallTower( + tower, req.IncludeSessions, ackCounts, + committedUpdateCounts, + ) rpcTowers = append(rpcTowers, rpcTower) } @@ -306,16 +313,59 @@ func (c *WatchtowerClient) GetTowerInfo(ctx context.Context, return nil, err } + opts, ackCounts, committedUpdateCounts := constructFunctionalOptions( + req.IncludeSessions, + ) + var tower *wtclient.RegisteredTower - tower, err = c.cfg.Client.LookupTower(pubKey) + tower, err = c.cfg.Client.LookupTower(pubKey, opts...) if err == wtdb.ErrTowerNotFound { - tower, err = c.cfg.AnchorClient.LookupTower(pubKey) + tower, err = c.cfg.AnchorClient.LookupTower(pubKey, opts...) } if err != nil { return nil, err } - return marshallTower(tower, req.IncludeSessions), nil + return marshallTower( + tower, req.IncludeSessions, ackCounts, committedUpdateCounts, + ), nil +} + +// constructFunctionalOptions is a helper function that constructs a list of +// functional options to be used when fetching a tower from the DB. It also +// returns a map of acked-update counts and one for un-acked-update counts that +// will be populated once the db call has been made. +func constructFunctionalOptions(includeSessions bool) ( + []wtdb.ClientSessionListOption, map[wtdb.SessionID]uint16, + map[wtdb.SessionID]uint16) { + + var ( + opts []wtdb.ClientSessionListOption + ackCounts = make(map[wtdb.SessionID]uint16) + committedUpdateCounts = make(map[wtdb.SessionID]uint16) + ) + if !includeSessions { + return opts, ackCounts, committedUpdateCounts + } + + perAckedUpdate := func(s *wtdb.ClientSession, _ uint16, + _ wtdb.BackupID) { + + ackCounts[s.ID]++ + } + + perCommittedUpdate := func(s *wtdb.ClientSession, + _ *wtdb.CommittedUpdate) { + + committedUpdateCounts[s.ID]++ + } + + opts = []wtdb.ClientSessionListOption{ + wtdb.WithPerAckedUpdate(perAckedUpdate), + wtdb.WithPerCommittedUpdate(perCommittedUpdate), + } + + return opts, ackCounts, committedUpdateCounts } // Stats returns the in-memory statistics of the client since startup. @@ -387,7 +437,9 @@ func (c *WatchtowerClient) Policy(ctx context.Context, // marshallTower converts a client registered watchtower into its corresponding // RPC type. -func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool) *Tower { +func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool, + ackCounts, pendingCounts map[wtdb.SessionID]uint16) *Tower { + rpcAddrs := make([]string, 0, len(tower.Addresses)) for _, addr := range tower.Addresses { rpcAddrs = append(rpcAddrs, addr.String()) @@ -399,8 +451,8 @@ func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool) *Tower for _, session := range tower.Sessions { satPerVByte := session.Policy.SweepFeeRate.FeePerKVByte() / 1000 rpcSessions = append(rpcSessions, &TowerSession{ - NumBackups: uint32(len(session.AckedUpdates)), - NumPendingBackups: uint32(len(session.CommittedUpdates)), + NumBackups: uint32(ackCounts[session.ID]), + NumPendingBackups: uint32(pendingCounts[session.ID]), MaxBackups: uint32(session.Policy.MaxUpdates), SweepSatPerVbyte: uint32(satPerVByte), diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 3d23f0b82..35c0ef22b 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -83,10 +83,12 @@ type Client interface { // RegisteredTowers retrieves the list of watchtowers registered with // the client. - RegisteredTowers() ([]*RegisteredTower, error) + RegisteredTowers(...wtdb.ClientSessionListOption) ([]*RegisteredTower, + error) // LookupTower retrieves a registered watchtower through its public key. - LookupTower(*btcec.PublicKey) (*RegisteredTower, error) + LookupTower(*btcec.PublicKey, + ...wtdb.ClientSessionListOption) (*RegisteredTower, error) // Stats returns the in-memory statistics of the client since startup. Stats() ClientStats @@ -287,12 +289,67 @@ func New(config *Config) (*TowerClient, error) { } plog := build.NewPrefixLog(prefix, log) - // Next, load all candidate towers and sessions from the database into - // the client. We will use any of these sessions if their policies match - // the current policy of the client, otherwise they will be ignored and - // new sessions will be requested. + // Load the sweep pkscripts that have been generated for all previously + // registered channels. + chanSummaries, err := cfg.DB.FetchChanSummaries() + if err != nil { + return nil, err + } + + c := &TowerClient{ + cfg: cfg, + log: plog, + pipeline: newTaskPipeline(plog), + chanCommitHeights: make(map[lnwire.ChannelID]uint64), + activeSessions: make(sessionQueueSet), + summaries: chanSummaries, + statTicker: time.NewTicker(DefaultStatInterval), + stats: new(ClientStats), + newTowers: make(chan *newTowerMsg), + staleTowers: make(chan *staleTowerMsg), + forceQuit: make(chan struct{}), + } + + // perUpdate is a callback function that will be used to inspect the + // full set of candidate client sessions loaded from disk, and to + // determine the highest known commit height for each channel. This + // allows the client to reject backups that it has already processed for + // its active policy. + perUpdate := func(policy wtpolicy.Policy, id wtdb.BackupID) { + // We only want to consider accepted updates that have been + // accepted under an identical policy to the client's current + // policy. + if policy != c.cfg.Policy { + return + } + + // Take the highest commit height found in the session's acked + // updates. + height, ok := c.chanCommitHeights[id.ChanID] + if !ok || id.CommitHeight > height { + c.chanCommitHeights[id.ChanID] = id.CommitHeight + } + } + + perAckedUpdate := func(s *wtdb.ClientSession, _ uint16, + id wtdb.BackupID) { + + perUpdate(s.Policy, id) + } + + perCommittedUpdate := func(s *wtdb.ClientSession, + u *wtdb.CommittedUpdate) { + + perUpdate(s.Policy, u.BackupID) + } + + // Load all candidate sessions and towers from the database into the + // client. We will use any of these sessions if their policies match the + // current policy of the client, otherwise they will be ignored and new + // sessions will be requested. isAnchorClient := cfg.Policy.IsAnchorChannel() activeSessionFilter := genActiveSessionFilter(isAnchorClient) + candidateTowers := newTowerListIterator() perActiveTower := func(tower *wtdb.Tower) { // If the tower has already been marked as active, then there is @@ -307,34 +364,19 @@ func New(config *Config) (*TowerClient, error) { // Add the tower to the set of candidate towers. candidateTowers.AddCandidate(tower) } + candidateSessions, err := getTowerAndSessionCandidates( cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower, + wtdb.WithPerAckedUpdate(perAckedUpdate), + wtdb.WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { return nil, err } - // Load the sweep pkscripts that have been generated for all previously - // registered channels. - chanSummaries, err := cfg.DB.FetchChanSummaries() - if err != nil { - return nil, err - } + c.candidateTowers = candidateTowers + c.candidateSessions = candidateSessions - c := &TowerClient{ - cfg: cfg, - log: plog, - pipeline: newTaskPipeline(plog), - candidateTowers: candidateTowers, - candidateSessions: candidateSessions, - activeSessions: make(sessionQueueSet), - summaries: chanSummaries, - statTicker: time.NewTicker(DefaultStatInterval), - stats: new(ClientStats), - newTowers: make(chan *newTowerMsg), - staleTowers: make(chan *staleTowerMsg), - forceQuit: make(chan struct{}), - } c.negotiator = newSessionNegotiator(&NegotiatorConfig{ DB: cfg.DB, SecretKeyRing: cfg.SecretKeyRing, @@ -349,10 +391,6 @@ func New(config *Config) (*TowerClient, error) { Log: plog, }) - // Reconstruct the highest commit height processed for each channel - // under the client's current policy. - c.buildHighestCommitHeights() - return c, nil } @@ -363,7 +401,8 @@ func New(config *Config) (*TowerClient, error) { // tower. func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, sessionFilter func(*wtdb.ClientSession) bool, - perActiveTower func(tower *wtdb.Tower)) ( + perActiveTower func(tower *wtdb.Tower), + opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { towers, err := db.ListTowers() @@ -373,7 +412,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, tower := range towers { - sessions, err := db.ListClientSessions(&tower.ID) + sessions, err := db.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err } @@ -413,10 +452,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, // ClientSession's SessionPrivKey field is desired, otherwise, the existing // ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, - passesFilter func(*wtdb.ClientSession) bool) ( + passesFilter func(*wtdb.ClientSession) bool, + opts ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) { - sessions, err := db.ListClientSessions(forTower) + sessions, err := db.ListClientSessions(forTower, opts...) if err != nil { return nil, err } @@ -446,48 +486,10 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, return sessions, nil } -// buildHighestCommitHeights inspects the full set of candidate client sessions -// loaded from disk, and determines the highest known commit height for each -// channel. This allows the client to reject backups that it has already -// processed for it's active policy. -func (c *TowerClient) buildHighestCommitHeights() { - chanCommitHeights := make(map[lnwire.ChannelID]uint64) - for _, s := range c.candidateSessions { - // We only want to consider accepted updates that have been - // accepted under an identical policy to the client's current - // policy. - if s.Policy != c.cfg.Policy { - continue - } - - // Take the highest commit height found in the session's - // committed updates. - for _, committedUpdate := range s.CommittedUpdates { - bid := committedUpdate.BackupID - - height, ok := chanCommitHeights[bid.ChanID] - if !ok || bid.CommitHeight > height { - chanCommitHeights[bid.ChanID] = bid.CommitHeight - } - } - - // Take the heights commit height found in the session's acked - // updates. - for _, bid := range s.AckedUpdates { - height, ok := chanCommitHeights[bid.ChanID] - if !ok || bid.CommitHeight > height { - chanCommitHeights[bid.ChanID] = bid.CommitHeight - } - } - } - - c.chanCommitHeights = chanCommitHeights -} - // Start initializes the watchtower client by loading or negotiating an active // session and then begins processing backup tasks from the request pipeline. func (c *TowerClient) Start() error { - var err error + var returnErr error c.started.Do(func() { c.log.Infof("Watchtower client starting") @@ -496,19 +498,27 @@ func (c *TowerClient) Start() error { // sessions will be able to flush the committed updates after a // restart. for _, session := range c.candidateSessions { - if len(session.CommittedUpdates) > 0 { + committedUpdates, err := c.cfg.DB.FetchSessionCommittedUpdates(&session.ID) + if err != nil { + returnErr = err + return + } + + if len(committedUpdates) > 0 { c.log.Infof("Starting session=%s to process "+ "%d committed backups", session.ID, - len(session.CommittedUpdates)) - c.initActiveQueue(session) + len(committedUpdates)) + + c.initActiveQueue(session, committedUpdates) } } // Now start the session negotiator, which will allow us to // request new session as soon as the backupDispatcher starts // up. - err = c.negotiator.Start() + err := c.negotiator.Start() if err != nil { + returnErr = err return } @@ -521,7 +531,7 @@ func (c *TowerClient) Start() error { c.log.Infof("Watchtower client started successfully") }) - return err + return returnErr } // Stop idempotently initiates a graceful shutdown of the watchtower client. @@ -697,7 +707,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, // active client's advertised policy will be ignored, but may be resumed if the // client is restarted with a matching policy. If no candidates were found, nil // is returned to signal that we need to request a new policy. -func (c *TowerClient) nextSessionQueue() *sessionQueue { +func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { // Select any candidate session at random, and remove it from the set of // candidate sessions. var candidateSession *wtdb.ClientSession @@ -719,13 +729,20 @@ func (c *TowerClient) nextSessionQueue() *sessionQueue { // If none of the sessions could be used or none were found, we'll // return nil to signal that we need another session to be negotiated. if candidateSession == nil { - return nil + return nil, nil + } + + updates, err := c.cfg.DB.FetchSessionCommittedUpdates( + &candidateSession.ID, + ) + if err != nil { + return nil, err } // Initialize the session queue and spin it up so it can begin handling // updates. If the queue was already made active on startup, this will // simply return the existing session queue from the set. - return c.getOrInitActiveQueue(candidateSession) + return c.getOrInitActiveQueue(candidateSession, updates), nil } // backupDispatcher processes events coming from the taskPipeline and is @@ -798,7 +815,13 @@ func (c *TowerClient) backupDispatcher() { // We've exhausted the prior session, we'll pop another // from the remaining sessions and continue processing // backup tasks. - c.sessionQueue = c.nextSessionQueue() + var err error + c.sessionQueue, err = c.nextSessionQueue() + if err != nil { + c.log.Errorf("error fetching next session "+ + "queue: %v", err) + } + if c.sessionQueue != nil { c.log.Debugf("Loaded next candidate session "+ "queue id=%s", c.sessionQueue.ID()) @@ -1046,7 +1069,9 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error // newSessionQueue creates a sessionQueue from a ClientSession loaded from the // database and supplying it with the resources needed by the client. -func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue { +func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession, + updates []wtdb.CommittedUpdate) *sessionQueue { + return newSessionQueue(&sessionQueueConfig{ ClientSession: s, ChainHash: c.cfg.ChainHash, @@ -1058,28 +1083,32 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue { MinBackoff: c.cfg.MinBackoff, MaxBackoff: c.cfg.MaxBackoff, Log: c.log, - }) + }, updates) } // getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the // passed ClientSession. If it exists, the active sessionQueue is returned. // Otherwise a new sessionQueue is initialized and added to the set. -func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession) *sessionQueue { +func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession, + updates []wtdb.CommittedUpdate) *sessionQueue { + if sq, ok := c.activeSessions[s.ID]; ok { return sq } - return c.initActiveQueue(s) + return c.initActiveQueue(s, updates) } // initActiveQueue creates a new sessionQueue from the passed ClientSession, // adds the sessionQueue to the activeSessions set, and starts the sessionQueue // so that it can deliver any committed updates or begin accepting newly // assigned tasks. -func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession) *sessionQueue { +func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession, + updates []wtdb.CommittedUpdate) *sessionQueue { + // Initialize the session queue, providing it with all of the resources // it requires from the client instance. - sq := c.newSessionQueue(s) + sq := c.newSessionQueue(s, updates) // Add the session queue as an active session so that we remember to // stop it on shutdown. @@ -1233,13 +1262,15 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // RegisteredTowers retrieves the list of watchtowers registered with the // client. -func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) { +func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( + []*RegisteredTower, error) { + // Retrieve all of our towers along with all of our sessions. towers, err := c.cfg.DB.ListTowers() if err != nil { return nil, err } - clientSessions, err := c.cfg.DB.ListClientSessions(nil) + clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...) if err != nil { return nil, err } @@ -1272,13 +1303,15 @@ func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) { } // LookupTower retrieves a registered watchtower through its public key. -func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey) (*RegisteredTower, error) { +func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey, + opts ...wtdb.ClientSessionListOption) (*RegisteredTower, error) { + tower, err := c.cfg.DB.LoadTower(pubKey) if err != nil { return nil, err } - towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID) + towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err } diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index dbf2faf71..c67d7eac3 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -62,9 +62,14 @@ type DB interface { // still be able to accept state updates. An optional tower ID can be // used to filter out any client sessions in the response that do not // correspond to this tower. - ListClientSessions(*wtdb.TowerID) ( + ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) + // FetchSessionCommittedUpdates retrieves the current set of un-acked + // updates of the given session. + FetchSessionCommittedUpdates(id *wtdb.SessionID) ( + []wtdb.CommittedUpdate, error) + // FetchChanSummaries loads a mapping from all registered channels to // their channel summaries. FetchChanSummaries() (wtdb.ChannelSummaries, error) diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index 8b2a9ad5d..adffdca06 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -109,7 +109,9 @@ type sessionQueue struct { } // newSessionQueue intiializes a fresh sessionQueue. -func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue { +func newSessionQueue(cfg *sessionQueueConfig, + updates []wtdb.CommittedUpdate) *sessionQueue { + localInit := wtwire.NewInitMessage( lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), cfg.ChainHash, @@ -137,7 +139,7 @@ func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue { // The database should return them in sorted order, and session queue's // sequence number will be equal to that of the last committed update. - for _, update := range sq.cfg.ClientSession.CommittedUpdates { + for _, update := range updates { sq.commitQueue.PushBack(update) } diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 3cb5a8c70..796226d94 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -420,8 +420,17 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { return ErrUninitializedDB } towerID := TowerIDFromBytes(towerIDBytes) + + committedUpdateCount := make(map[SessionID]uint16) + perCommittedUpdate := func(s *ClientSession, + _ *CommittedUpdate) { + + committedUpdateCount[s.ID]++ + } + towerSessions, err := listTowerSessions( towerID, sessions, towers, towersToSessionsIndex, + WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { return err @@ -447,7 +456,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { // have any pending updates to ensure we don't load them upon // restarts. for _, session := range towerSessions { - if len(session.CommittedUpdates) > 0 { + if committedUpdateCount[session.ID] > 0 { return ErrTowerUnackedUpdates } err := markSessionStatus( @@ -736,8 +745,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID, // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (c *ClientDB) ListClientSessions(id *TowerID) ( - map[SessionID]*ClientSession, error) { +func (c *ClientDB) ListClientSessions(id *TowerID, + opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { var clientSessions map[SessionID]*ClientSession err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -757,7 +766,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( // known to the db. if id == nil { clientSessions, err = listClientAllSessions( - sessions, towers, + sessions, towers, opts..., ) return err } @@ -769,7 +778,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( } clientSessions, err = listTowerSessions( - *id, sessions, towers, towerToSessionIndex, + *id, sessions, towers, towerToSessionIndex, opts..., ) return err }, func() { @@ -783,8 +792,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) ( } // listClientAllSessions returns the set of all client sessions known to the db. -func listClientAllSessions(sessions, - towers kvdb.RBucket) (map[SessionID]*ClientSession, error) { +func listClientAllSessions(sessions, towers kvdb.RBucket, + opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) err := sessions.ForEach(func(k, _ []byte) error { @@ -792,7 +801,7 @@ func listClientAllSessions(sessions, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessions, towers, k) + session, err := getClientSession(sessions, towers, k, opts...) if err != nil { return err } @@ -811,8 +820,8 @@ func listClientAllSessions(sessions, // listTowerSessions returns the set of all client sessions known to the db // that are associated with the given tower id. func listTowerSessions(id TowerID, sessionsBkt, towersBkt, - towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession, - error) { + towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) ( + map[SessionID]*ClientSession, error) { towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes()) if towerIndexBkt == nil { @@ -825,7 +834,9 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessionsBkt, towersBkt, k) + session, err := getClientSession( + sessionsBkt, towersBkt, k, opts..., + ) if err != nil { return err } @@ -840,6 +851,36 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt, return clientSessions, nil } +// FetchSessionCommittedUpdates retrieves the current set of un-acked updates +// of the given session. +func (c *ClientDB) FetchSessionCommittedUpdates(id *SessionID) ( + []CommittedUpdate, error) { + + var committedUpdates []CommittedUpdate + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + sessions := tx.ReadBucket(cSessionBkt) + if sessions == nil { + return ErrUninitializedDB + } + + sessionBkt := sessions.NestedReadBucket(id[:]) + if sessionBkt == nil { + return ErrClientSessionNotFound + } + + var err error + committedUpdates, err = getClientSessionCommits( + sessionBkt, nil, nil, + ) + return err + }, func() {}) + if err != nil { + return nil, err + } + + return committedUpdates, nil +} + // FetchChanSummaries loads a mapping from all registered channels to their // channel summaries. func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { @@ -1157,11 +1198,63 @@ func getClientSessionBody(sessions kvdb.RBucket, return &session, nil } +// PerAckedUpdateCB describes the signature of a callback function that can be +// called for each of a session's acked updates. +type PerAckedUpdateCB func(*ClientSession, uint16, BackupID) + +// PerCommittedUpdateCB describes the signature of a callback function that can +// be called for each of a session's committed updates (updates that the client +// has not yet received an ACK for). +type PerCommittedUpdateCB func(*ClientSession, *CommittedUpdate) + +// ClientSessionListOption describes the signature of a functional option that +// can be used when listing client sessions in order to provide any extra +// instruction to the query. +type ClientSessionListOption func(cfg *ClientSessionListCfg) + +// ClientSessionListCfg defines various query parameters that will be used when +// querying the DB for client sessions. +type ClientSessionListCfg struct { + // PerAckedUpdate will, if set, be called for each of the session's + // acked updates. + PerAckedUpdate PerAckedUpdateCB + + // PerCommittedUpdate will, if set, be called for each of the session's + // committed (un-acked) updates. + PerCommittedUpdate PerCommittedUpdateCB +} + +// NewClientSessionCfg constructs a new ClientSessionListCfg. +func NewClientSessionCfg() *ClientSessionListCfg { + return &ClientSessionListCfg{} +} + +// WithPerAckedUpdate constructs a functional option that will set a call-back +// function to be called for each of a client's acked updates. +func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PerAckedUpdate = cb + } +} + +// WithPerCommittedUpdate constructs a functional option that will set a +// call-back function to be called for each of a client's un-acked updates. +func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PerCommittedUpdate = cb + } +} + // getClientSession loads the full ClientSession associated with the serialized // session id. This method populates the CommittedUpdates, AckUpdates and Tower // in addition to the ClientSession's body. -func getClientSession(sessions, towers kvdb.RBucket, - idBytes []byte) (*ClientSession, error) { +func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, + opts ...ClientSessionListOption) (*ClientSession, error) { + + cfg := NewClientSessionCfg() + for _, o := range opts { + o(cfg) + } session, err := getClientSessionBody(sessions, idBytes) if err != nil { @@ -1173,35 +1266,37 @@ func getClientSession(sessions, towers kvdb.RBucket, if err != nil { return nil, err } - - // Fetch the committed updates for this session. - commitedUpdates, err := getClientSessionCommits(sessions, idBytes) - if err != nil { - return nil, err - } - - // Fetch the acked updates for this session. - ackedUpdates, err := getClientSessionAcks(sessions, idBytes) - if err != nil { - return nil, err - } - session.Tower = tower - session.CommittedUpdates = commitedUpdates - session.AckedUpdates = ackedUpdates + + // Can't fail because client session body has already been read. + sessionBkt := sessions.NestedReadBucket(idBytes) + + // Pass the session's committed (un-acked) updates through the call-back + // if one is provided. + err = filterClientSessionCommits( + sessionBkt, session, cfg.PerCommittedUpdate, + ) + if err != nil { + return nil, err + } + + // Pass the session's acked updates through the call-back if one is + // provided. + err = filterClientSessionAcks(sessionBkt, session, cfg.PerAckedUpdate) + if err != nil { + return nil, err + } return session, nil } // getClientSessionCommits retrieves all committed updates for the session -// identified by the serialized session id. -func getClientSessionCommits(sessions kvdb.RBucket, - idBytes []byte) ([]CommittedUpdate, error) { +// identified by the serialized session id. If a PerCommittedUpdateCB is +// provided, then it will be called for each of the session's committed updates. +func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, + cb PerCommittedUpdateCB) ([]CommittedUpdate, error) { - // Can't fail because client session body has already been read. - sessionBkt := sessions.NestedReadBucket(idBytes) - - // Initialize commitedUpdates so that we can return an initialized map + // Initialize committedUpdates so that we can return an initialized map // if no committed updates exist. committedUpdates := make([]CommittedUpdate, 0) @@ -1220,6 +1315,10 @@ func getClientSessionCommits(sessions kvdb.RBucket, committedUpdates = append(committedUpdates, committedUpdate) + if cb != nil { + cb(s, &committedUpdate) + } + return nil }) if err != nil { @@ -1229,21 +1328,19 @@ func getClientSessionCommits(sessions kvdb.RBucket, return committedUpdates, nil } -// getClientSessionAcks retrieves all acked updates for the session identified -// by the serialized session id. -func getClientSessionAcks(sessions kvdb.RBucket, - idBytes []byte) (map[uint16]BackupID, error) { +// filterClientSessionAcks retrieves all acked updates for the session +// identified by the serialized session id and passes them to the provided +// call back if one is provided. +func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession, + cb PerAckedUpdateCB) error { - // Can't fail because client session body has already been read. - sessionBkt := sessions.NestedReadBucket(idBytes) - - // Initialize ackedUpdates so that we can return an initialized map if - // no acked updates exist. - ackedUpdates := make(map[uint16]BackupID) + if cb == nil { + return nil + } sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks) if sessionAcks == nil { - return ackedUpdates, nil + return nil } err := sessionAcks.ForEach(func(k, v []byte) error { @@ -1255,15 +1352,47 @@ func getClientSessionAcks(sessions kvdb.RBucket, return err } - ackedUpdates[seqNum] = backupID - + cb(s, seqNum, backupID) return nil }) if err != nil { - return nil, err + return err } - return ackedUpdates, nil + return nil +} + +// filterClientSessionCommits retrieves all committed updates for the session +// identified by the serialized session id and passes them to the given +// PerCommittedUpdateCB callback. +func filterClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, + cb PerCommittedUpdateCB) error { + + if cb == nil { + return nil + } + + sessionCommits := sessionBkt.NestedReadBucket(cSessionCommits) + if sessionCommits == nil { + return nil + } + + err := sessionCommits.ForEach(func(k, v []byte) error { + var committedUpdate CommittedUpdate + err := committedUpdate.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + committedUpdate.SeqNum = byteOrder.Uint16(k) + + cb(s, &committedUpdate) + return nil + }) + if err != nil { + return err + } + + return nil } // putClientSessionBody stores the body of the ClientSession (everything but the diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index d4f1699c9..aa30cc713 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -48,12 +48,12 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, require.ErrorIs(h.t, err, expErr) } -func (h *clientDBHarness) listSessions( - id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession { +func (h *clientDBHarness) listSessions(id *wtdb.TowerID, + opts ...wtdb.ClientSessionListOption) map[wtdb.SessionID]*wtdb.ClientSession { h.t.Helper() - sessions, err := h.db.ListClientSessions(id) + sessions, err := h.db.ListClientSessions(id, opts...) require.NoError(h.t, err, "unable to list client sessions") return sessions @@ -207,6 +207,20 @@ func (h *clientDBHarness) newTower() *wtdb.Tower { }, nil) } +func (h *clientDBHarness) fetchSessionCommittedUpdates(id *wtdb.SessionID, + expErr error) []wtdb.CommittedUpdate { + + h.t.Helper() + + updates, err := h.db.FetchSessionCommittedUpdates(id) + if err != expErr { + h.t.Fatalf("expected fetch session committed updates error: "+ + "%v, got: %v", expErr, err) + } + + return updates +} + // testCreateClientSession asserts various conditions regarding the creation of // a new ClientSession. The test asserts: // - client sessions can only be created if a session key index is reserved. @@ -506,6 +520,9 @@ func testCommitUpdate(h *clientDBHarness) { // session, which should fail. update1 := randCommittedUpdate(h.t, 1) h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound) + h.fetchSessionCommittedUpdates( + &session.ID, wtdb.ErrClientSessionNotFound, + ) // Reserve a session key index and insert the session. session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType) @@ -520,11 +537,7 @@ func testCommitUpdate(h *clientDBHarness) { // Assert that the committed update appears in the client session's // CommittedUpdates map when loaded from disk and that there are no // AckedUpdates. - dbSession := h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ - *update1, - }) - checkAckedUpdates(h.t, dbSession, nil) + h.assertUpdates(session.ID, []wtdb.CommittedUpdate{*update1}, nil) // Try to commit the same update, which should succeed due to // idempotency (which is preserved when the breach hint is identical to @@ -534,11 +547,7 @@ func testCommitUpdate(h *clientDBHarness) { require.Equal(h.t, lastApplied, lastApplied2) // Assert that the loaded ClientSession is the same as before. - dbSession = h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ - *update1, - }) - checkAckedUpdates(h.t, dbSession, nil) + h.assertUpdates(session.ID, []wtdb.CommittedUpdate{*update1}, nil) // Generate another random update and try to commit it at the identical // sequence number. Since the breach hint has changed, this should fail. @@ -553,12 +562,10 @@ func testCommitUpdate(h *clientDBHarness) { // Check that both updates now appear as committed on the ClientSession // loaded from disk. - dbSession = h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ + h.assertUpdates(session.ID, []wtdb.CommittedUpdate{ *update1, *update2, - }) - checkAckedUpdates(h.t, dbSession, nil) + }, nil) // Finally, create one more random update and try to commit it at index // 4, which should be rejected since 3 is the next slot the database @@ -567,12 +574,20 @@ func testCommitUpdate(h *clientDBHarness) { h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate) // Assert that the ClientSession loaded from disk remains unchanged. - dbSession = h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ + h.assertUpdates(session.ID, []wtdb.CommittedUpdate{ *update1, *update2, - }) - checkAckedUpdates(h.t, dbSession, nil) + }, nil) +} + +func perAckedUpdate(updates map[uint16]wtdb.BackupID) func( + _ *wtdb.ClientSession, seq uint16, id wtdb.BackupID) { + + return func(_ *wtdb.ClientSession, seq uint16, + id wtdb.BackupID) { + + updates[seq] = id + } } // testAckUpdate asserts the behavior of AckUpdate. @@ -628,9 +643,7 @@ func testAckUpdate(h *clientDBHarness) { // Assert that the ClientSession loaded from disk has one update in it's // AckedUpdates map, and that the committed update has been removed. - dbSession := h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, nil) - checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ + h.assertUpdates(session.ID, nil, map[uint16]wtdb.BackupID{ 1: update1.BackupID, }) @@ -645,9 +658,7 @@ func testAckUpdate(h *clientDBHarness) { h.ackUpdate(&session.ID, 2, 2, nil) // Assert that both updates exist as AckedUpdates when loaded from disk. - dbSession = h.listSessions(nil)[session.ID] - checkCommittedUpdates(h.t, dbSession, nil) - checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ + h.assertUpdates(session.ID, nil, map[uint16]wtdb.BackupID{ 1: update1.BackupID, 2: update2.BackupID, }) @@ -663,9 +674,22 @@ func testAckUpdate(h *clientDBHarness) { h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied) } +func (h *clientDBHarness) assertUpdates(id wtdb.SessionID, + expectedPending []wtdb.CommittedUpdate, + expectedAcked map[uint16]wtdb.BackupID) { + + ackedUpdates := make(map[uint16]wtdb.BackupID) + _ = h.listSessions( + nil, wtdb.WithPerAckedUpdate(perAckedUpdate(ackedUpdates)), + ) + committedUpates := h.fetchSessionCommittedUpdates(&id, nil) + checkCommittedUpdates(h.t, committedUpates, expectedPending) + checkAckedUpdates(h.t, ackedUpdates, expectedAcked) +} + // checkCommittedUpdates asserts that the CommittedUpdates on session match the // expUpdates provided. -func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, +func checkCommittedUpdates(t *testing.T, actualUpdates, expUpdates []wtdb.CommittedUpdate) { t.Helper() @@ -677,12 +701,12 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, expUpdates = make([]wtdb.CommittedUpdate, 0) } - require.Equal(t, expUpdates, session.CommittedUpdates) + require.Equal(t, expUpdates, actualUpdates) } // checkAckedUpdates asserts that the AckedUpdates on a session match the // expUpdates provided. -func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, +func checkAckedUpdates(t *testing.T, actualUpdates, expUpdates map[uint16]wtdb.BackupID) { // We promote nil expUpdates to an initialized map since the database @@ -692,7 +716,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, expUpdates = make(map[uint16]wtdb.BackupID) } - require.Equal(t, expUpdates, session.AckedUpdates) + require.Equal(t, expUpdates, actualUpdates) } // TestClientDB asserts the behavior of a fresh client db, a reopened client db, diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index 1d3e01f5a..a4d5c5ecc 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -37,23 +37,6 @@ type ClientSession struct { ClientSessionBody - // CommittedUpdates is a sorted list of unacked updates. These updates - // can be resent after a restart if the updates failed to send or - // receive an acknowledgment. - // - // NOTE: This list is serialized in it's own bucket, separate from the - // body of the ClientSession. The representation on disk is a key value - // map from sequence number to CommittedUpdateBody to allow efficient - // insertion and retrieval. - CommittedUpdates []CommittedUpdate - - // AckedUpdates is a map from sequence number to backup id to record - // which revoked states were uploaded via this session. - // - // NOTE: This map is serialized in it's own bucket, separate from the - // body of the ClientSession. - AckedUpdates map[uint16]BackupID - // Tower holds the pubkey and address of the watchtower. // // NOTE: This value is not serialized. It is recovered by looking up the diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 2a3825e87..8a47bdf7f 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -23,11 +23,13 @@ type keyIndexKey struct { type ClientDB struct { nextTowerID uint64 // to be used atomically - mu sync.Mutex - summaries map[lnwire.ChannelID]wtdb.ClientChanSummary - activeSessions map[wtdb.SessionID]wtdb.ClientSession - towerIndex map[towerPK]wtdb.TowerID - towers map[wtdb.TowerID]*wtdb.Tower + mu sync.Mutex + summaries map[lnwire.ChannelID]wtdb.ClientChanSummary + activeSessions map[wtdb.SessionID]wtdb.ClientSession + ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID + committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate + towerIndex map[towerPK]wtdb.TowerID + towers map[wtdb.TowerID]*wtdb.Tower nextIndex uint32 indexes map[keyIndexKey]uint32 @@ -37,12 +39,14 @@ type ClientDB struct { // NewClientDB initializes a new mock ClientDB. func NewClientDB() *ClientDB { return &ClientDB{ - summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), - activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), - towerIndex: make(map[towerPK]wtdb.TowerID), - towers: make(map[wtdb.TowerID]*wtdb.Tower), - indexes: make(map[keyIndexKey]uint32), - legacyIndexes: make(map[wtdb.TowerID]uint32), + summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), + activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), + ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID), + committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate), + towerIndex: make(map[towerPK]wtdb.TowerID), + towers: make(map[wtdb.TowerID]*wtdb.Tower), + indexes: make(map[keyIndexKey]uint32), + legacyIndexes: make(map[wtdb.TowerID]uint32), } } @@ -75,7 +79,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { } else { towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) tower = &wtdb.Tower{ - ID: wtdb.TowerID(towerID), + ID: towerID, IdentityKey: lnAddr.IdentityKey, Addresses: []net.Addr{lnAddr.Address}, } @@ -129,7 +133,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { } for id, session := range towerSessions { - if len(session.CommittedUpdates) > 0 { + if len(m.committedUpdates[session.ID]) > 0 { return wtdb.ErrTowerUnackedUpdates } session.Status = wtdb.CSessionInactive @@ -193,26 +197,33 @@ func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) { // MarkBackupIneligible records that particular commit height is ineligible for // backup. This allows the client to track which updates it should not attempt // to retry after startup. -func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error { +func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error { return nil } // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (m *ClientDB) ListClientSessions( - tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { +func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID, + opts ...wtdb.ClientSessionListOption) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { m.mu.Lock() defer m.mu.Unlock() - return m.listClientSessions(tower) + return m.listClientSessions(tower, opts...) } // listClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. -func (m *ClientDB) listClientSessions( - tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { +func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, + opts ...wtdb.ClientSessionListOption) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { + + cfg := wtdb.NewClientSessionCfg() + for _, o := range opts { + o(cfg) + } sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, session := range m.activeSessions { @@ -222,11 +233,40 @@ func (m *ClientDB) listClientSessions( } session.Tower = m.towers[session.TowerID] sessions[session.ID] = &session + + if cfg.PerAckedUpdate != nil { + for seq, id := range m.ackedUpdates[session.ID] { + cfg.PerAckedUpdate(&session, seq, id) + } + } + + if cfg.PerCommittedUpdate != nil { + for _, update := range m.committedUpdates[session.ID] { + update := update + cfg.PerCommittedUpdate(&session, &update) + } + } } return sessions, nil } +// FetchSessionCommittedUpdates retrieves the current set of un-acked updates +// of the given session. +func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) ( + []wtdb.CommittedUpdate, error) { + + m.mu.Lock() + defer m.mu.Unlock() + + updates, ok := m.committedUpdates[*id] + if !ok { + return nil, wtdb.ErrClientSessionNotFound + } + + return updates, nil +} + // CreateClientSession records a newly negotiated client session in the set of // active sessions. The session can be identified by its SessionID. func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { @@ -271,9 +311,9 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { Policy: session.Policy, RewardPkScript: cloneBytes(session.RewardPkScript), }, - CommittedUpdates: make([]wtdb.CommittedUpdate, 0), - AckedUpdates: make(map[uint16]wtdb.BackupID), } + m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID) + m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0) return nil } @@ -334,7 +374,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, } // Check if an update has already been committed for this state. - for _, dbUpdate := range session.CommittedUpdates { + for _, dbUpdate := range m.committedUpdates[session.ID] { if dbUpdate.SeqNum == update.SeqNum { // If the breach hint matches, we'll just return the // last applied value so the client can retransmit. @@ -353,7 +393,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, } // Save the update and increment the sequence number. - session.CommittedUpdates = append(session.CommittedUpdates, *update) + m.committedUpdates[session.ID] = append( + m.committedUpdates[session.ID], *update, + ) session.SeqNum++ m.activeSessions[*id] = session @@ -363,7 +405,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, // AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This // removes the update from the set of committed updates, and validates the // lastApplied value returned from the tower. -func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error { +func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, + lastApplied uint16) error { + m.mu.Lock() defer m.mu.Unlock() @@ -387,7 +431,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err // Retrieve the committed update, failing if none is found. We should // only receive acks for state updates that we send. - updates := session.CommittedUpdates + updates := m.committedUpdates[session.ID] for i, update := range updates { if update.SeqNum != seqNum { continue @@ -398,9 +442,9 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err // along with the next update. copy(updates[:i], updates[i+1:]) updates[len(updates)-1] = wtdb.CommittedUpdate{} - session.CommittedUpdates = updates[:len(updates)-1] + m.committedUpdates[session.ID] = updates[:len(updates)-1] - session.AckedUpdates[seqNum] = update.BackupID + m.ackedUpdates[*id][seqNum] = update.BackupID session.TowerLastApplied = lastApplied m.activeSessions[*id] = session