mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-01 18:50:09 +02:00
watchtower: remove CommittedUpdates from ClientSession
In this commit, the new ListClientSession functional options and new FetchSessionCommittedUpdates function are utilised in order to allow us to completely remove the CommittedUpdates member from the ClientSession struct.
This commit is contained in:
parent
fe3d9174ea
commit
75e5339217
@ -489,7 +489,7 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
||||
// Start initializes the watchtower client by loading or negotiating an active
|
||||
// session and then begins processing backup tasks from the request pipeline.
|
||||
func (c *TowerClient) Start() error {
|
||||
var err error
|
||||
var returnErr error
|
||||
c.started.Do(func() {
|
||||
c.log.Infof("Watchtower client starting")
|
||||
|
||||
@ -498,19 +498,27 @@ func (c *TowerClient) Start() error {
|
||||
// sessions will be able to flush the committed updates after a
|
||||
// restart.
|
||||
for _, session := range c.candidateSessions {
|
||||
if len(session.CommittedUpdates) > 0 {
|
||||
committedUpdates, err := c.cfg.DB.FetchSessionCommittedUpdates(&session.ID)
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
}
|
||||
|
||||
if len(committedUpdates) > 0 {
|
||||
c.log.Infof("Starting session=%s to process "+
|
||||
"%d committed backups", session.ID,
|
||||
len(session.CommittedUpdates))
|
||||
c.initActiveQueue(session)
|
||||
len(committedUpdates))
|
||||
|
||||
c.initActiveQueue(session, committedUpdates)
|
||||
}
|
||||
}
|
||||
|
||||
// Now start the session negotiator, which will allow us to
|
||||
// request new session as soon as the backupDispatcher starts
|
||||
// up.
|
||||
err = c.negotiator.Start()
|
||||
err := c.negotiator.Start()
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
return
|
||||
}
|
||||
|
||||
@ -523,7 +531,7 @@ func (c *TowerClient) Start() error {
|
||||
|
||||
c.log.Infof("Watchtower client started successfully")
|
||||
})
|
||||
return err
|
||||
return returnErr
|
||||
}
|
||||
|
||||
// Stop idempotently initiates a graceful shutdown of the watchtower client.
|
||||
@ -699,7 +707,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
|
||||
// active client's advertised policy will be ignored, but may be resumed if the
|
||||
// client is restarted with a matching policy. If no candidates were found, nil
|
||||
// is returned to signal that we need to request a new policy.
|
||||
func (c *TowerClient) nextSessionQueue() *sessionQueue {
|
||||
func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) {
|
||||
// Select any candidate session at random, and remove it from the set of
|
||||
// candidate sessions.
|
||||
var candidateSession *wtdb.ClientSession
|
||||
@ -721,13 +729,20 @@ func (c *TowerClient) nextSessionQueue() *sessionQueue {
|
||||
// If none of the sessions could be used or none were found, we'll
|
||||
// return nil to signal that we need another session to be negotiated.
|
||||
if candidateSession == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
updates, err := c.cfg.DB.FetchSessionCommittedUpdates(
|
||||
&candidateSession.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize the session queue and spin it up so it can begin handling
|
||||
// updates. If the queue was already made active on startup, this will
|
||||
// simply return the existing session queue from the set.
|
||||
return c.getOrInitActiveQueue(candidateSession)
|
||||
return c.getOrInitActiveQueue(candidateSession, updates), nil
|
||||
}
|
||||
|
||||
// backupDispatcher processes events coming from the taskPipeline and is
|
||||
@ -800,7 +815,13 @@ func (c *TowerClient) backupDispatcher() {
|
||||
// We've exhausted the prior session, we'll pop another
|
||||
// from the remaining sessions and continue processing
|
||||
// backup tasks.
|
||||
c.sessionQueue = c.nextSessionQueue()
|
||||
var err error
|
||||
c.sessionQueue, err = c.nextSessionQueue()
|
||||
if err != nil {
|
||||
c.log.Errorf("error fetching next session "+
|
||||
"queue: %v", err)
|
||||
}
|
||||
|
||||
if c.sessionQueue != nil {
|
||||
c.log.Debugf("Loaded next candidate session "+
|
||||
"queue id=%s", c.sessionQueue.ID())
|
||||
@ -1048,7 +1069,9 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error
|
||||
|
||||
// newSessionQueue creates a sessionQueue from a ClientSession loaded from the
|
||||
// database and supplying it with the resources needed by the client.
|
||||
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue {
|
||||
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession,
|
||||
updates []wtdb.CommittedUpdate) *sessionQueue {
|
||||
|
||||
return newSessionQueue(&sessionQueueConfig{
|
||||
ClientSession: s,
|
||||
ChainHash: c.cfg.ChainHash,
|
||||
@ -1060,28 +1083,32 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue {
|
||||
MinBackoff: c.cfg.MinBackoff,
|
||||
MaxBackoff: c.cfg.MaxBackoff,
|
||||
Log: c.log,
|
||||
})
|
||||
}, updates)
|
||||
}
|
||||
|
||||
// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the
|
||||
// passed ClientSession. If it exists, the active sessionQueue is returned.
|
||||
// Otherwise a new sessionQueue is initialized and added to the set.
|
||||
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession) *sessionQueue {
|
||||
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession,
|
||||
updates []wtdb.CommittedUpdate) *sessionQueue {
|
||||
|
||||
if sq, ok := c.activeSessions[s.ID]; ok {
|
||||
return sq
|
||||
}
|
||||
|
||||
return c.initActiveQueue(s)
|
||||
return c.initActiveQueue(s, updates)
|
||||
}
|
||||
|
||||
// initActiveQueue creates a new sessionQueue from the passed ClientSession,
|
||||
// adds the sessionQueue to the activeSessions set, and starts the sessionQueue
|
||||
// so that it can deliver any committed updates or begin accepting newly
|
||||
// assigned tasks.
|
||||
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession) *sessionQueue {
|
||||
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession,
|
||||
updates []wtdb.CommittedUpdate) *sessionQueue {
|
||||
|
||||
// Initialize the session queue, providing it with all of the resources
|
||||
// it requires from the client instance.
|
||||
sq := c.newSessionQueue(s)
|
||||
sq := c.newSessionQueue(s, updates)
|
||||
|
||||
// Add the session queue as an active session so that we remember to
|
||||
// stop it on shutdown.
|
||||
|
@ -109,7 +109,9 @@ type sessionQueue struct {
|
||||
}
|
||||
|
||||
// newSessionQueue intiializes a fresh sessionQueue.
|
||||
func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
|
||||
func newSessionQueue(cfg *sessionQueueConfig,
|
||||
updates []wtdb.CommittedUpdate) *sessionQueue {
|
||||
|
||||
localInit := wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
|
||||
cfg.ChainHash,
|
||||
@ -137,7 +139,7 @@ func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
|
||||
|
||||
// The database should return them in sorted order, and session queue's
|
||||
// sequence number will be equal to that of the last committed update.
|
||||
for _, update := range sq.cfg.ClientSession.CommittedUpdates {
|
||||
for _, update := range updates {
|
||||
sq.commitQueue.PushBack(update)
|
||||
}
|
||||
|
||||
|
@ -420,8 +420,17 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
towerID := TowerIDFromBytes(towerIDBytes)
|
||||
|
||||
committedUpdateCount := make(map[SessionID]uint16)
|
||||
perCommittedUpdate := func(s *ClientSession,
|
||||
_ *CommittedUpdate) {
|
||||
|
||||
committedUpdateCount[s.ID]++
|
||||
}
|
||||
|
||||
towerSessions, err := listTowerSessions(
|
||||
towerID, sessions, towers, towersToSessionsIndex,
|
||||
WithPerCommittedUpdate(perCommittedUpdate),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -447,7 +456,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
// have any pending updates to ensure we don't load them upon
|
||||
// restarts.
|
||||
for _, session := range towerSessions {
|
||||
if len(session.CommittedUpdates) > 0 {
|
||||
if committedUpdateCount[session.ID] > 0 {
|
||||
return ErrTowerUnackedUpdates
|
||||
}
|
||||
err := markSessionStatus(
|
||||
@ -1257,12 +1266,14 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session.Tower = tower
|
||||
|
||||
// Can't fail because client session body has already been read.
|
||||
sessionBkt := sessions.NestedReadBucket(idBytes)
|
||||
|
||||
// Fetch the committed updates for this session.
|
||||
commitedUpdates, err := getClientSessionCommits(
|
||||
// Pass the session's committed (un-acked) updates through the call-back
|
||||
// if one is provided.
|
||||
err = filterClientSessionCommits(
|
||||
sessionBkt, session, cfg.PerCommittedUpdate,
|
||||
)
|
||||
if err != nil {
|
||||
@ -1276,9 +1287,6 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.Tower = tower
|
||||
session.CommittedUpdates = commitedUpdates
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
@ -1354,6 +1362,39 @@ func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterClientSessionCommits retrieves all committed updates for the session
|
||||
// identified by the serialized session id and passes them to the given
|
||||
// PerCommittedUpdateCB callback.
|
||||
func filterClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
|
||||
cb PerCommittedUpdateCB) error {
|
||||
|
||||
if cb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sessionCommits := sessionBkt.NestedReadBucket(cSessionCommits)
|
||||
if sessionCommits == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := sessionCommits.ForEach(func(k, v []byte) error {
|
||||
var committedUpdate CommittedUpdate
|
||||
err := committedUpdate.Decode(bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
committedUpdate.SeqNum = byteOrder.Uint16(k)
|
||||
|
||||
cb(s, &committedUpdate)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// putClientSessionBody stores the body of the ClientSession (everything but the
|
||||
// CommittedUpdates and AckedUpdates).
|
||||
func putClientSessionBody(sessions kvdb.RwBucket,
|
||||
|
@ -37,16 +37,6 @@ type ClientSession struct {
|
||||
|
||||
ClientSessionBody
|
||||
|
||||
// CommittedUpdates is a sorted list of unacked updates. These updates
|
||||
// can be resent after a restart if the updates failed to send or
|
||||
// receive an acknowledgment.
|
||||
//
|
||||
// NOTE: This list is serialized in it's own bucket, separate from the
|
||||
// body of the ClientSession. The representation on disk is a key value
|
||||
// map from sequence number to CommittedUpdateBody to allow efficient
|
||||
// insertion and retrieval.
|
||||
CommittedUpdates []CommittedUpdate
|
||||
|
||||
// Tower holds the pubkey and address of the watchtower.
|
||||
//
|
||||
// NOTE: This value is not serialized. It is recovered by looking up the
|
||||
|
@ -23,12 +23,13 @@ type keyIndexKey struct {
|
||||
type ClientDB struct {
|
||||
nextTowerID uint64 // to be used atomically
|
||||
|
||||
mu sync.Mutex
|
||||
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
|
||||
activeSessions map[wtdb.SessionID]wtdb.ClientSession
|
||||
ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
|
||||
towerIndex map[towerPK]wtdb.TowerID
|
||||
towers map[wtdb.TowerID]*wtdb.Tower
|
||||
mu sync.Mutex
|
||||
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
|
||||
activeSessions map[wtdb.SessionID]wtdb.ClientSession
|
||||
ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
|
||||
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
|
||||
towerIndex map[towerPK]wtdb.TowerID
|
||||
towers map[wtdb.TowerID]*wtdb.Tower
|
||||
|
||||
nextIndex uint32
|
||||
indexes map[keyIndexKey]uint32
|
||||
@ -38,13 +39,14 @@ type ClientDB struct {
|
||||
// NewClientDB initializes a new mock ClientDB.
|
||||
func NewClientDB() *ClientDB {
|
||||
return &ClientDB{
|
||||
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
|
||||
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
|
||||
ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
|
||||
towerIndex: make(map[towerPK]wtdb.TowerID),
|
||||
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
||||
indexes: make(map[keyIndexKey]uint32),
|
||||
legacyIndexes: make(map[wtdb.TowerID]uint32),
|
||||
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
|
||||
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
|
||||
ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
|
||||
committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate),
|
||||
towerIndex: make(map[towerPK]wtdb.TowerID),
|
||||
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
||||
indexes: make(map[keyIndexKey]uint32),
|
||||
legacyIndexes: make(map[wtdb.TowerID]uint32),
|
||||
}
|
||||
}
|
||||
|
||||
@ -131,7 +133,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
}
|
||||
|
||||
for id, session := range towerSessions {
|
||||
if len(session.CommittedUpdates) > 0 {
|
||||
if len(m.committedUpdates[session.ID]) > 0 {
|
||||
return wtdb.ErrTowerUnackedUpdates
|
||||
}
|
||||
session.Status = wtdb.CSessionInactive
|
||||
@ -237,6 +239,13 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
|
||||
cfg.PerAckedUpdate(&session, seq, id)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.PerCommittedUpdate != nil {
|
||||
for _, update := range m.committedUpdates[session.ID] {
|
||||
update := update
|
||||
cfg.PerCommittedUpdate(&session, &update)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
@ -250,12 +259,12 @@ func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) (
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
sess, ok := m.activeSessions[*id]
|
||||
updates, ok := m.committedUpdates[*id]
|
||||
if !ok {
|
||||
return nil, wtdb.ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
return sess.CommittedUpdates, nil
|
||||
return updates, nil
|
||||
}
|
||||
|
||||
// CreateClientSession records a newly negotiated client session in the set of
|
||||
@ -302,9 +311,9 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
||||
Policy: session.Policy,
|
||||
RewardPkScript: cloneBytes(session.RewardPkScript),
|
||||
},
|
||||
CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
|
||||
}
|
||||
m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID)
|
||||
m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -365,7 +374,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
|
||||
}
|
||||
|
||||
// Check if an update has already been committed for this state.
|
||||
for _, dbUpdate := range session.CommittedUpdates {
|
||||
for _, dbUpdate := range m.committedUpdates[session.ID] {
|
||||
if dbUpdate.SeqNum == update.SeqNum {
|
||||
// If the breach hint matches, we'll just return the
|
||||
// last applied value so the client can retransmit.
|
||||
@ -384,7 +393,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
|
||||
}
|
||||
|
||||
// Save the update and increment the sequence number.
|
||||
session.CommittedUpdates = append(session.CommittedUpdates, *update)
|
||||
m.committedUpdates[session.ID] = append(
|
||||
m.committedUpdates[session.ID], *update,
|
||||
)
|
||||
session.SeqNum++
|
||||
m.activeSessions[*id] = session
|
||||
|
||||
@ -394,7 +405,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
|
||||
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
|
||||
// removes the update from the set of committed updates, and validates the
|
||||
// lastApplied value returned from the tower.
|
||||
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error {
|
||||
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
|
||||
lastApplied uint16) error {
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@ -418,7 +431,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
|
||||
|
||||
// Retrieve the committed update, failing if none is found. We should
|
||||
// only receive acks for state updates that we send.
|
||||
updates := session.CommittedUpdates
|
||||
updates := m.committedUpdates[session.ID]
|
||||
for i, update := range updates {
|
||||
if update.SeqNum != seqNum {
|
||||
continue
|
||||
@ -429,7 +442,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
|
||||
// along with the next update.
|
||||
copy(updates[:i], updates[i+1:])
|
||||
updates[len(updates)-1] = wtdb.CommittedUpdate{}
|
||||
session.CommittedUpdates = updates[:len(updates)-1]
|
||||
m.committedUpdates[session.ID] = updates[:len(updates)-1]
|
||||
|
||||
m.ackedUpdates[*id][seqNum] = update.BackupID
|
||||
session.TowerLastApplied = lastApplied
|
||||
|
Loading…
x
Reference in New Issue
Block a user