watchtower: remove CommittedUpdates from ClientSession

In this commit, the new ListClientSession functional options and new
FetchSessionCommittedUpdates function are utilised in order to allow us
to completely remove the CommittedUpdates member from the ClientSession
struct.
This commit is contained in:
Elle Mouton 2022-09-30 12:18:08 +02:00
parent fe3d9174ea
commit 75e5339217
No known key found for this signature in database
GPG Key ID: D7D916376026F177
5 changed files with 129 additions and 56 deletions

View File

@ -489,7 +489,7 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
// Start initializes the watchtower client by loading or negotiating an active // Start initializes the watchtower client by loading or negotiating an active
// session and then begins processing backup tasks from the request pipeline. // session and then begins processing backup tasks from the request pipeline.
func (c *TowerClient) Start() error { func (c *TowerClient) Start() error {
var err error var returnErr error
c.started.Do(func() { c.started.Do(func() {
c.log.Infof("Watchtower client starting") c.log.Infof("Watchtower client starting")
@ -498,19 +498,27 @@ func (c *TowerClient) Start() error {
// sessions will be able to flush the committed updates after a // sessions will be able to flush the committed updates after a
// restart. // restart.
for _, session := range c.candidateSessions { for _, session := range c.candidateSessions {
if len(session.CommittedUpdates) > 0 { committedUpdates, err := c.cfg.DB.FetchSessionCommittedUpdates(&session.ID)
if err != nil {
returnErr = err
return
}
if len(committedUpdates) > 0 {
c.log.Infof("Starting session=%s to process "+ c.log.Infof("Starting session=%s to process "+
"%d committed backups", session.ID, "%d committed backups", session.ID,
len(session.CommittedUpdates)) len(committedUpdates))
c.initActiveQueue(session)
c.initActiveQueue(session, committedUpdates)
} }
} }
// Now start the session negotiator, which will allow us to // Now start the session negotiator, which will allow us to
// request new session as soon as the backupDispatcher starts // request new session as soon as the backupDispatcher starts
// up. // up.
err = c.negotiator.Start() err := c.negotiator.Start()
if err != nil { if err != nil {
returnErr = err
return return
} }
@ -523,7 +531,7 @@ func (c *TowerClient) Start() error {
c.log.Infof("Watchtower client started successfully") c.log.Infof("Watchtower client started successfully")
}) })
return err return returnErr
} }
// Stop idempotently initiates a graceful shutdown of the watchtower client. // Stop idempotently initiates a graceful shutdown of the watchtower client.
@ -699,7 +707,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
// active client's advertised policy will be ignored, but may be resumed if the // active client's advertised policy will be ignored, but may be resumed if the
// client is restarted with a matching policy. If no candidates were found, nil // client is restarted with a matching policy. If no candidates were found, nil
// is returned to signal that we need to request a new policy. // is returned to signal that we need to request a new policy.
func (c *TowerClient) nextSessionQueue() *sessionQueue { func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) {
// Select any candidate session at random, and remove it from the set of // Select any candidate session at random, and remove it from the set of
// candidate sessions. // candidate sessions.
var candidateSession *wtdb.ClientSession var candidateSession *wtdb.ClientSession
@ -721,13 +729,20 @@ func (c *TowerClient) nextSessionQueue() *sessionQueue {
// If none of the sessions could be used or none were found, we'll // If none of the sessions could be used or none were found, we'll
// return nil to signal that we need another session to be negotiated. // return nil to signal that we need another session to be negotiated.
if candidateSession == nil { if candidateSession == nil {
return nil return nil, nil
}
updates, err := c.cfg.DB.FetchSessionCommittedUpdates(
&candidateSession.ID,
)
if err != nil {
return nil, err
} }
// Initialize the session queue and spin it up so it can begin handling // Initialize the session queue and spin it up so it can begin handling
// updates. If the queue was already made active on startup, this will // updates. If the queue was already made active on startup, this will
// simply return the existing session queue from the set. // simply return the existing session queue from the set.
return c.getOrInitActiveQueue(candidateSession) return c.getOrInitActiveQueue(candidateSession, updates), nil
} }
// backupDispatcher processes events coming from the taskPipeline and is // backupDispatcher processes events coming from the taskPipeline and is
@ -800,7 +815,13 @@ func (c *TowerClient) backupDispatcher() {
// We've exhausted the prior session, we'll pop another // We've exhausted the prior session, we'll pop another
// from the remaining sessions and continue processing // from the remaining sessions and continue processing
// backup tasks. // backup tasks.
c.sessionQueue = c.nextSessionQueue() var err error
c.sessionQueue, err = c.nextSessionQueue()
if err != nil {
c.log.Errorf("error fetching next session "+
"queue: %v", err)
}
if c.sessionQueue != nil { if c.sessionQueue != nil {
c.log.Debugf("Loaded next candidate session "+ c.log.Debugf("Loaded next candidate session "+
"queue id=%s", c.sessionQueue.ID()) "queue id=%s", c.sessionQueue.ID())
@ -1048,7 +1069,9 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error
// newSessionQueue creates a sessionQueue from a ClientSession loaded from the // newSessionQueue creates a sessionQueue from a ClientSession loaded from the
// database and supplying it with the resources needed by the client. // database and supplying it with the resources needed by the client.
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue { func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue {
return newSessionQueue(&sessionQueueConfig{ return newSessionQueue(&sessionQueueConfig{
ClientSession: s, ClientSession: s,
ChainHash: c.cfg.ChainHash, ChainHash: c.cfg.ChainHash,
@ -1060,28 +1083,32 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue {
MinBackoff: c.cfg.MinBackoff, MinBackoff: c.cfg.MinBackoff,
MaxBackoff: c.cfg.MaxBackoff, MaxBackoff: c.cfg.MaxBackoff,
Log: c.log, Log: c.log,
}) }, updates)
} }
// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the // getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the
// passed ClientSession. If it exists, the active sessionQueue is returned. // passed ClientSession. If it exists, the active sessionQueue is returned.
// Otherwise a new sessionQueue is initialized and added to the set. // Otherwise a new sessionQueue is initialized and added to the set.
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession) *sessionQueue { func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue {
if sq, ok := c.activeSessions[s.ID]; ok { if sq, ok := c.activeSessions[s.ID]; ok {
return sq return sq
} }
return c.initActiveQueue(s) return c.initActiveQueue(s, updates)
} }
// initActiveQueue creates a new sessionQueue from the passed ClientSession, // initActiveQueue creates a new sessionQueue from the passed ClientSession,
// adds the sessionQueue to the activeSessions set, and starts the sessionQueue // adds the sessionQueue to the activeSessions set, and starts the sessionQueue
// so that it can deliver any committed updates or begin accepting newly // so that it can deliver any committed updates or begin accepting newly
// assigned tasks. // assigned tasks.
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession) *sessionQueue { func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue {
// Initialize the session queue, providing it with all of the resources // Initialize the session queue, providing it with all of the resources
// it requires from the client instance. // it requires from the client instance.
sq := c.newSessionQueue(s) sq := c.newSessionQueue(s, updates)
// Add the session queue as an active session so that we remember to // Add the session queue as an active session so that we remember to
// stop it on shutdown. // stop it on shutdown.

View File

@ -109,7 +109,9 @@ type sessionQueue struct {
} }
// newSessionQueue intiializes a fresh sessionQueue. // newSessionQueue intiializes a fresh sessionQueue.
func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue { func newSessionQueue(cfg *sessionQueueConfig,
updates []wtdb.CommittedUpdate) *sessionQueue {
localInit := wtwire.NewInitMessage( localInit := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
cfg.ChainHash, cfg.ChainHash,
@ -137,7 +139,7 @@ func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
// The database should return them in sorted order, and session queue's // The database should return them in sorted order, and session queue's
// sequence number will be equal to that of the last committed update. // sequence number will be equal to that of the last committed update.
for _, update := range sq.cfg.ClientSession.CommittedUpdates { for _, update := range updates {
sq.commitQueue.PushBack(update) sq.commitQueue.PushBack(update)
} }

View File

@ -420,8 +420,17 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return ErrUninitializedDB return ErrUninitializedDB
} }
towerID := TowerIDFromBytes(towerIDBytes) towerID := TowerIDFromBytes(towerIDBytes)
committedUpdateCount := make(map[SessionID]uint16)
perCommittedUpdate := func(s *ClientSession,
_ *CommittedUpdate) {
committedUpdateCount[s.ID]++
}
towerSessions, err := listTowerSessions( towerSessions, err := listTowerSessions(
towerID, sessions, towers, towersToSessionsIndex, towerID, sessions, towers, towersToSessionsIndex,
WithPerCommittedUpdate(perCommittedUpdate),
) )
if err != nil { if err != nil {
return err return err
@ -447,7 +456,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
// have any pending updates to ensure we don't load them upon // have any pending updates to ensure we don't load them upon
// restarts. // restarts.
for _, session := range towerSessions { for _, session := range towerSessions {
if len(session.CommittedUpdates) > 0 { if committedUpdateCount[session.ID] > 0 {
return ErrTowerUnackedUpdates return ErrTowerUnackedUpdates
} }
err := markSessionStatus( err := markSessionStatus(
@ -1257,12 +1266,14 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
if err != nil { if err != nil {
return nil, err return nil, err
} }
session.Tower = tower
// Can't fail because client session body has already been read. // Can't fail because client session body has already been read.
sessionBkt := sessions.NestedReadBucket(idBytes) sessionBkt := sessions.NestedReadBucket(idBytes)
// Fetch the committed updates for this session. // Pass the session's committed (un-acked) updates through the call-back
commitedUpdates, err := getClientSessionCommits( // if one is provided.
err = filterClientSessionCommits(
sessionBkt, session, cfg.PerCommittedUpdate, sessionBkt, session, cfg.PerCommittedUpdate,
) )
if err != nil { if err != nil {
@ -1276,9 +1287,6 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
return nil, err return nil, err
} }
session.Tower = tower
session.CommittedUpdates = commitedUpdates
return session, nil return session, nil
} }
@ -1354,6 +1362,39 @@ func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
return nil return nil
} }
// filterClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id and passes them to the given
// PerCommittedUpdateCB callback.
func filterClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
cb PerCommittedUpdateCB) error {
if cb == nil {
return nil
}
sessionCommits := sessionBkt.NestedReadBucket(cSessionCommits)
if sessionCommits == nil {
return nil
}
err := sessionCommits.ForEach(func(k, v []byte) error {
var committedUpdate CommittedUpdate
err := committedUpdate.Decode(bytes.NewReader(v))
if err != nil {
return err
}
committedUpdate.SeqNum = byteOrder.Uint16(k)
cb(s, &committedUpdate)
return nil
})
if err != nil {
return err
}
return nil
}
// putClientSessionBody stores the body of the ClientSession (everything but the // putClientSessionBody stores the body of the ClientSession (everything but the
// CommittedUpdates and AckedUpdates). // CommittedUpdates and AckedUpdates).
func putClientSessionBody(sessions kvdb.RwBucket, func putClientSessionBody(sessions kvdb.RwBucket,

View File

@ -37,16 +37,6 @@ type ClientSession struct {
ClientSessionBody ClientSessionBody
// CommittedUpdates is a sorted list of unacked updates. These updates
// can be resent after a restart if the updates failed to send or
// receive an acknowledgment.
//
// NOTE: This list is serialized in it's own bucket, separate from the
// body of the ClientSession. The representation on disk is a key value
// map from sequence number to CommittedUpdateBody to allow efficient
// insertion and retrieval.
CommittedUpdates []CommittedUpdate
// Tower holds the pubkey and address of the watchtower. // Tower holds the pubkey and address of the watchtower.
// //
// NOTE: This value is not serialized. It is recovered by looking up the // NOTE: This value is not serialized. It is recovered by looking up the

View File

@ -27,6 +27,7 @@ type ClientDB struct {
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]wtdb.ClientSession activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
towerIndex map[towerPK]wtdb.TowerID towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower towers map[wtdb.TowerID]*wtdb.Tower
@ -41,6 +42,7 @@ func NewClientDB() *ClientDB {
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID), ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate),
towerIndex: make(map[towerPK]wtdb.TowerID), towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower), towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32), indexes: make(map[keyIndexKey]uint32),
@ -131,7 +133,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
} }
for id, session := range towerSessions { for id, session := range towerSessions {
if len(session.CommittedUpdates) > 0 { if len(m.committedUpdates[session.ID]) > 0 {
return wtdb.ErrTowerUnackedUpdates return wtdb.ErrTowerUnackedUpdates
} }
session.Status = wtdb.CSessionInactive session.Status = wtdb.CSessionInactive
@ -237,6 +239,13 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
cfg.PerAckedUpdate(&session, seq, id) cfg.PerAckedUpdate(&session, seq, id)
} }
} }
if cfg.PerCommittedUpdate != nil {
for _, update := range m.committedUpdates[session.ID] {
update := update
cfg.PerCommittedUpdate(&session, &update)
}
}
} }
return sessions, nil return sessions, nil
@ -250,12 +259,12 @@ func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) (
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
sess, ok := m.activeSessions[*id] updates, ok := m.committedUpdates[*id]
if !ok { if !ok {
return nil, wtdb.ErrClientSessionNotFound return nil, wtdb.ErrClientSessionNotFound
} }
return sess.CommittedUpdates, nil return updates, nil
} }
// CreateClientSession records a newly negotiated client session in the set of // CreateClientSession records a newly negotiated client session in the set of
@ -302,9 +311,9 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
Policy: session.Policy, Policy: session.Policy,
RewardPkScript: cloneBytes(session.RewardPkScript), RewardPkScript: cloneBytes(session.RewardPkScript),
}, },
CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
} }
m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID) m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID)
m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0)
return nil return nil
} }
@ -365,7 +374,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
} }
// Check if an update has already been committed for this state. // Check if an update has already been committed for this state.
for _, dbUpdate := range session.CommittedUpdates { for _, dbUpdate := range m.committedUpdates[session.ID] {
if dbUpdate.SeqNum == update.SeqNum { if dbUpdate.SeqNum == update.SeqNum {
// If the breach hint matches, we'll just return the // If the breach hint matches, we'll just return the
// last applied value so the client can retransmit. // last applied value so the client can retransmit.
@ -384,7 +393,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
} }
// Save the update and increment the sequence number. // Save the update and increment the sequence number.
session.CommittedUpdates = append(session.CommittedUpdates, *update) m.committedUpdates[session.ID] = append(
m.committedUpdates[session.ID], *update,
)
session.SeqNum++ session.SeqNum++
m.activeSessions[*id] = session m.activeSessions[*id] = session
@ -394,7 +405,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This // AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
// removes the update from the set of committed updates, and validates the // removes the update from the set of committed updates, and validates the
// lastApplied value returned from the tower. // lastApplied value returned from the tower.
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error { func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
lastApplied uint16) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@ -418,7 +431,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
// Retrieve the committed update, failing if none is found. We should // Retrieve the committed update, failing if none is found. We should
// only receive acks for state updates that we send. // only receive acks for state updates that we send.
updates := session.CommittedUpdates updates := m.committedUpdates[session.ID]
for i, update := range updates { for i, update := range updates {
if update.SeqNum != seqNum { if update.SeqNum != seqNum {
continue continue
@ -429,7 +442,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
// along with the next update. // along with the next update.
copy(updates[:i], updates[i+1:]) copy(updates[:i], updates[i+1:])
updates[len(updates)-1] = wtdb.CommittedUpdate{} updates[len(updates)-1] = wtdb.CommittedUpdate{}
session.CommittedUpdates = updates[:len(updates)-1] m.committedUpdates[session.ID] = updates[:len(updates)-1]
m.ackedUpdates[*id][seqNum] = update.BackupID m.ackedUpdates[*id][seqNum] = update.BackupID
session.TowerLastApplied = lastApplied session.TowerLastApplied = lastApplied