watchtower: remove CommittedUpdates from ClientSession

In this commit, the new ListClientSession functional options and new
FetchSessionCommittedUpdates function are utilised in order to allow us
to completely remove the CommittedUpdates member from the ClientSession
struct.
This commit is contained in:
Elle Mouton 2022-09-30 12:18:08 +02:00
parent fe3d9174ea
commit 75e5339217
No known key found for this signature in database
GPG Key ID: D7D916376026F177
5 changed files with 129 additions and 56 deletions

View File

@ -489,7 +489,7 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
// Start initializes the watchtower client by loading or negotiating an active
// session and then begins processing backup tasks from the request pipeline.
func (c *TowerClient) Start() error {
var err error
var returnErr error
c.started.Do(func() {
c.log.Infof("Watchtower client starting")
@ -498,19 +498,27 @@ func (c *TowerClient) Start() error {
// sessions will be able to flush the committed updates after a
// restart.
for _, session := range c.candidateSessions {
if len(session.CommittedUpdates) > 0 {
committedUpdates, err := c.cfg.DB.FetchSessionCommittedUpdates(&session.ID)
if err != nil {
returnErr = err
return
}
if len(committedUpdates) > 0 {
c.log.Infof("Starting session=%s to process "+
"%d committed backups", session.ID,
len(session.CommittedUpdates))
c.initActiveQueue(session)
len(committedUpdates))
c.initActiveQueue(session, committedUpdates)
}
}
// Now start the session negotiator, which will allow us to
// request new session as soon as the backupDispatcher starts
// up.
err = c.negotiator.Start()
err := c.negotiator.Start()
if err != nil {
returnErr = err
return
}
@ -523,7 +531,7 @@ func (c *TowerClient) Start() error {
c.log.Infof("Watchtower client started successfully")
})
return err
return returnErr
}
// Stop idempotently initiates a graceful shutdown of the watchtower client.
@ -699,7 +707,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
// active client's advertised policy will be ignored, but may be resumed if the
// client is restarted with a matching policy. If no candidates were found, nil
// is returned to signal that we need to request a new policy.
func (c *TowerClient) nextSessionQueue() *sessionQueue {
func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) {
// Select any candidate session at random, and remove it from the set of
// candidate sessions.
var candidateSession *wtdb.ClientSession
@ -721,13 +729,20 @@ func (c *TowerClient) nextSessionQueue() *sessionQueue {
// If none of the sessions could be used or none were found, we'll
// return nil to signal that we need another session to be negotiated.
if candidateSession == nil {
return nil
return nil, nil
}
updates, err := c.cfg.DB.FetchSessionCommittedUpdates(
&candidateSession.ID,
)
if err != nil {
return nil, err
}
// Initialize the session queue and spin it up so it can begin handling
// updates. If the queue was already made active on startup, this will
// simply return the existing session queue from the set.
return c.getOrInitActiveQueue(candidateSession)
return c.getOrInitActiveQueue(candidateSession, updates), nil
}
// backupDispatcher processes events coming from the taskPipeline and is
@ -800,7 +815,13 @@ func (c *TowerClient) backupDispatcher() {
// We've exhausted the prior session, we'll pop another
// from the remaining sessions and continue processing
// backup tasks.
c.sessionQueue = c.nextSessionQueue()
var err error
c.sessionQueue, err = c.nextSessionQueue()
if err != nil {
c.log.Errorf("error fetching next session "+
"queue: %v", err)
}
if c.sessionQueue != nil {
c.log.Debugf("Loaded next candidate session "+
"queue id=%s", c.sessionQueue.ID())
@ -1048,7 +1069,9 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error
// newSessionQueue creates a sessionQueue from a ClientSession loaded from the
// database and supplying it with the resources needed by the client.
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue {
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue {
return newSessionQueue(&sessionQueueConfig{
ClientSession: s,
ChainHash: c.cfg.ChainHash,
@ -1060,28 +1083,32 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue {
MinBackoff: c.cfg.MinBackoff,
MaxBackoff: c.cfg.MaxBackoff,
Log: c.log,
})
}, updates)
}
// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the
// passed ClientSession. If it exists, the active sessionQueue is returned.
// Otherwise a new sessionQueue is initialized and added to the set.
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession) *sessionQueue {
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue {
if sq, ok := c.activeSessions[s.ID]; ok {
return sq
}
return c.initActiveQueue(s)
return c.initActiveQueue(s, updates)
}
// initActiveQueue creates a new sessionQueue from the passed ClientSession,
// adds the sessionQueue to the activeSessions set, and starts the sessionQueue
// so that it can deliver any committed updates or begin accepting newly
// assigned tasks.
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession) *sessionQueue {
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue {
// Initialize the session queue, providing it with all of the resources
// it requires from the client instance.
sq := c.newSessionQueue(s)
sq := c.newSessionQueue(s, updates)
// Add the session queue as an active session so that we remember to
// stop it on shutdown.

View File

@ -109,7 +109,9 @@ type sessionQueue struct {
}
// newSessionQueue intiializes a fresh sessionQueue.
func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
func newSessionQueue(cfg *sessionQueueConfig,
updates []wtdb.CommittedUpdate) *sessionQueue {
localInit := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
cfg.ChainHash,
@ -137,7 +139,7 @@ func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
// The database should return them in sorted order, and session queue's
// sequence number will be equal to that of the last committed update.
for _, update := range sq.cfg.ClientSession.CommittedUpdates {
for _, update := range updates {
sq.commitQueue.PushBack(update)
}

View File

@ -420,8 +420,17 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return ErrUninitializedDB
}
towerID := TowerIDFromBytes(towerIDBytes)
committedUpdateCount := make(map[SessionID]uint16)
perCommittedUpdate := func(s *ClientSession,
_ *CommittedUpdate) {
committedUpdateCount[s.ID]++
}
towerSessions, err := listTowerSessions(
towerID, sessions, towers, towersToSessionsIndex,
WithPerCommittedUpdate(perCommittedUpdate),
)
if err != nil {
return err
@ -447,7 +456,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
// have any pending updates to ensure we don't load them upon
// restarts.
for _, session := range towerSessions {
if len(session.CommittedUpdates) > 0 {
if committedUpdateCount[session.ID] > 0 {
return ErrTowerUnackedUpdates
}
err := markSessionStatus(
@ -1257,12 +1266,14 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
if err != nil {
return nil, err
}
session.Tower = tower
// Can't fail because client session body has already been read.
sessionBkt := sessions.NestedReadBucket(idBytes)
// Fetch the committed updates for this session.
commitedUpdates, err := getClientSessionCommits(
// Pass the session's committed (un-acked) updates through the call-back
// if one is provided.
err = filterClientSessionCommits(
sessionBkt, session, cfg.PerCommittedUpdate,
)
if err != nil {
@ -1276,9 +1287,6 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
return nil, err
}
session.Tower = tower
session.CommittedUpdates = commitedUpdates
return session, nil
}
@ -1354,6 +1362,39 @@ func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
return nil
}
// filterClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id and passes them to the given
// PerCommittedUpdateCB callback.
func filterClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
cb PerCommittedUpdateCB) error {
if cb == nil {
return nil
}
sessionCommits := sessionBkt.NestedReadBucket(cSessionCommits)
if sessionCommits == nil {
return nil
}
err := sessionCommits.ForEach(func(k, v []byte) error {
var committedUpdate CommittedUpdate
err := committedUpdate.Decode(bytes.NewReader(v))
if err != nil {
return err
}
committedUpdate.SeqNum = byteOrder.Uint16(k)
cb(s, &committedUpdate)
return nil
})
if err != nil {
return err
}
return nil
}
// putClientSessionBody stores the body of the ClientSession (everything but the
// CommittedUpdates and AckedUpdates).
func putClientSessionBody(sessions kvdb.RwBucket,

View File

@ -37,16 +37,6 @@ type ClientSession struct {
ClientSessionBody
// CommittedUpdates is a sorted list of unacked updates. These updates
// can be resent after a restart if the updates failed to send or
// receive an acknowledgment.
//
// NOTE: This list is serialized in it's own bucket, separate from the
// body of the ClientSession. The representation on disk is a key value
// map from sequence number to CommittedUpdateBody to allow efficient
// insertion and retrieval.
CommittedUpdates []CommittedUpdate
// Tower holds the pubkey and address of the watchtower.
//
// NOTE: This value is not serialized. It is recovered by looking up the

View File

@ -23,12 +23,13 @@ type keyIndexKey struct {
type ClientDB struct {
nextTowerID uint64 // to be used atomically
mu sync.Mutex
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower
mu sync.Mutex
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower
nextIndex uint32
indexes map[keyIndexKey]uint32
@ -38,13 +39,14 @@ type ClientDB struct {
// NewClientDB initializes a new mock ClientDB.
func NewClientDB() *ClientDB {
return &ClientDB{
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32),
legacyIndexes: make(map[wtdb.TowerID]uint32),
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate),
towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32),
legacyIndexes: make(map[wtdb.TowerID]uint32),
}
}
@ -131,7 +133,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
}
for id, session := range towerSessions {
if len(session.CommittedUpdates) > 0 {
if len(m.committedUpdates[session.ID]) > 0 {
return wtdb.ErrTowerUnackedUpdates
}
session.Status = wtdb.CSessionInactive
@ -237,6 +239,13 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
cfg.PerAckedUpdate(&session, seq, id)
}
}
if cfg.PerCommittedUpdate != nil {
for _, update := range m.committedUpdates[session.ID] {
update := update
cfg.PerCommittedUpdate(&session, &update)
}
}
}
return sessions, nil
@ -250,12 +259,12 @@ func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) (
m.mu.Lock()
defer m.mu.Unlock()
sess, ok := m.activeSessions[*id]
updates, ok := m.committedUpdates[*id]
if !ok {
return nil, wtdb.ErrClientSessionNotFound
}
return sess.CommittedUpdates, nil
return updates, nil
}
// CreateClientSession records a newly negotiated client session in the set of
@ -302,9 +311,9 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
Policy: session.Policy,
RewardPkScript: cloneBytes(session.RewardPkScript),
},
CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
}
m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID)
m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0)
return nil
}
@ -365,7 +374,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
}
// Check if an update has already been committed for this state.
for _, dbUpdate := range session.CommittedUpdates {
for _, dbUpdate := range m.committedUpdates[session.ID] {
if dbUpdate.SeqNum == update.SeqNum {
// If the breach hint matches, we'll just return the
// last applied value so the client can retransmit.
@ -384,7 +393,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
}
// Save the update and increment the sequence number.
session.CommittedUpdates = append(session.CommittedUpdates, *update)
m.committedUpdates[session.ID] = append(
m.committedUpdates[session.ID], *update,
)
session.SeqNum++
m.activeSessions[*id] = session
@ -394,7 +405,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
// removes the update from the set of committed updates, and validates the
// lastApplied value returned from the tower.
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error {
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
lastApplied uint16) error {
m.mu.Lock()
defer m.mu.Unlock()
@ -418,7 +431,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
// Retrieve the committed update, failing if none is found. We should
// only receive acks for state updates that we send.
updates := session.CommittedUpdates
updates := m.committedUpdates[session.ID]
for i, update := range updates {
if update.SeqNum != seqNum {
continue
@ -429,7 +442,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
// along with the next update.
copy(updates[:i], updates[i+1:])
updates[len(updates)-1] = wtdb.CommittedUpdate{}
session.CommittedUpdates = updates[:len(updates)-1]
m.committedUpdates[session.ID] = updates[:len(updates)-1]
m.ackedUpdates[*id][seqNum] = update.BackupID
session.TowerLastApplied = lastApplied