watchtower: remove CommittedUpdates from ClientSession

In this commit, the new ListClientSession functional options and new
FetchSessionCommittedUpdates function are utilised in order to allow us
to completely remove the CommittedUpdates member from the ClientSession
struct.
This commit is contained in:
Elle Mouton
2022-09-30 12:18:08 +02:00
parent fe3d9174ea
commit 75e5339217
5 changed files with 129 additions and 56 deletions

View File

@@ -23,12 +23,13 @@ type keyIndexKey struct {
type ClientDB struct {
nextTowerID uint64 // to be used atomically
mu sync.Mutex
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower
mu sync.Mutex
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower
nextIndex uint32
indexes map[keyIndexKey]uint32
@@ -38,13 +39,14 @@ type ClientDB struct {
// NewClientDB initializes a new mock ClientDB.
func NewClientDB() *ClientDB {
return &ClientDB{
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32),
legacyIndexes: make(map[wtdb.TowerID]uint32),
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate),
towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32),
legacyIndexes: make(map[wtdb.TowerID]uint32),
}
}
@@ -131,7 +133,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
}
for id, session := range towerSessions {
if len(session.CommittedUpdates) > 0 {
if len(m.committedUpdates[session.ID]) > 0 {
return wtdb.ErrTowerUnackedUpdates
}
session.Status = wtdb.CSessionInactive
@@ -237,6 +239,13 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
cfg.PerAckedUpdate(&session, seq, id)
}
}
if cfg.PerCommittedUpdate != nil {
for _, update := range m.committedUpdates[session.ID] {
update := update
cfg.PerCommittedUpdate(&session, &update)
}
}
}
return sessions, nil
@@ -250,12 +259,12 @@ func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) (
m.mu.Lock()
defer m.mu.Unlock()
sess, ok := m.activeSessions[*id]
updates, ok := m.committedUpdates[*id]
if !ok {
return nil, wtdb.ErrClientSessionNotFound
}
return sess.CommittedUpdates, nil
return updates, nil
}
// CreateClientSession records a newly negotiated client session in the set of
@@ -302,9 +311,9 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
Policy: session.Policy,
RewardPkScript: cloneBytes(session.RewardPkScript),
},
CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
}
m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID)
m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0)
return nil
}
@@ -365,7 +374,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
}
// Check if an update has already been committed for this state.
for _, dbUpdate := range session.CommittedUpdates {
for _, dbUpdate := range m.committedUpdates[session.ID] {
if dbUpdate.SeqNum == update.SeqNum {
// If the breach hint matches, we'll just return the
// last applied value so the client can retransmit.
@@ -384,7 +393,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
}
// Save the update and increment the sequence number.
session.CommittedUpdates = append(session.CommittedUpdates, *update)
m.committedUpdates[session.ID] = append(
m.committedUpdates[session.ID], *update,
)
session.SeqNum++
m.activeSessions[*id] = session
@@ -394,7 +405,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
// removes the update from the set of committed updates, and validates the
// lastApplied value returned from the tower.
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error {
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
lastApplied uint16) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -418,7 +431,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
// Retrieve the committed update, failing if none is found. We should
// only receive acks for state updates that we send.
updates := session.CommittedUpdates
updates := m.committedUpdates[session.ID]
for i, update := range updates {
if update.SeqNum != seqNum {
continue
@@ -429,7 +442,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
// along with the next update.
copy(updates[:i], updates[i+1:])
updates[len(updates)-1] = wtdb.CommittedUpdate{}
session.CommittedUpdates = updates[:len(updates)-1]
m.committedUpdates[session.ID] = updates[:len(updates)-1]
m.ackedUpdates[*id][seqNum] = update.BackupID
session.TowerLastApplied = lastApplied