From 24702ede14a2aa2f5066b61c9e242b5403c03b92 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 5 Dec 2023 13:12:38 +0200 Subject: [PATCH] wtclient: add TerminateSession method --- watchtower/wtclient/client.go | 96 ++++++++++++++++++--- watchtower/wtclient/client_test.go | 133 +++++++++++++++++++++++++++-- watchtower/wtclient/manager.go | 21 +++++ 3 files changed, 231 insertions(+), 19 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 04faa1231..e8fec5d4a 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -148,6 +148,19 @@ type deactivateTowerMsg struct { errChan chan error } +// terminateSessMsg is an internal message we'll use within the TowerClient to +// signal that a session should be terminated. +type terminateSessMsg struct { + // id is the session identifier. + id wtdb.SessionID + + // errChan is the channel through which we'll send a response back to + // the caller when handling their request. + // + // NOTE: This channel must be buffered. + errChan chan error +} + // clientCfg holds the configuration values required by a client. type clientCfg struct { *Config @@ -181,9 +194,10 @@ type client struct { statTicker *time.Ticker stats *clientStats - newTowers chan *newTowerMsg - staleTowers chan *staleTowerMsg - deactivateTowers chan *deactivateTowerMsg + newTowers chan *newTowerMsg + staleTowers chan *staleTowerMsg + deactivateTowers chan *deactivateTowerMsg + terminateSessions chan *terminateSessMsg wg sync.WaitGroup quit chan struct{} @@ -209,16 +223,17 @@ func newClient(cfg *clientCfg) (*client, error) { } c := &client{ - cfg: cfg, - log: plog, - pipeline: queue, - activeSessions: newSessionQueueSet(), - statTicker: time.NewTicker(DefaultStatInterval), - stats: new(clientStats), - newTowers: make(chan *newTowerMsg), - staleTowers: make(chan *staleTowerMsg), - deactivateTowers: make(chan *deactivateTowerMsg), - quit: make(chan struct{}), + cfg: cfg, + log: plog, + pipeline: queue, + activeSessions: newSessionQueueSet(), + statTicker: time.NewTicker(DefaultStatInterval), + stats: new(clientStats), + newTowers: make(chan *newTowerMsg), + staleTowers: make(chan *staleTowerMsg), + deactivateTowers: make(chan *deactivateTowerMsg), + terminateSessions: make(chan *terminateSessMsg), + quit: make(chan struct{}), } candidateTowers := newTowerListIterator() @@ -718,6 +733,10 @@ func (c *client) backupDispatcher() { case msg := <-c.deactivateTowers: msg.errChan <- c.handleDeactivateTower(msg) + // A request has come through to terminate a session. + case msg := <-c.terminateSessions: + msg.errChan <- c.handleTerminateSession(msg) + case <-c.quit: return } @@ -807,6 +826,10 @@ func (c *client) backupDispatcher() { case msg := <-c.deactivateTowers: msg.errChan <- c.handleDeactivateTower(msg) + // A request has come through to terminate a session. + case msg := <-c.terminateSessions: + msg.errChan <- c.handleTerminateSession(msg) + case <-c.quit: return } @@ -1074,6 +1097,53 @@ func (c *client) initActiveQueue(s *ClientSession, return sq } +// terminateSession sets the given session's status to CSessionTerminal meaning +// that it will not be used again. +func (c *client) terminateSession(id wtdb.SessionID) error { + errChan := make(chan error, 1) + + select { + case c.terminateSessions <- &terminateSessMsg{ + id: id, + errChan: errChan, + }: + case <-c.pipeline.quit: + return ErrClientExiting + } + + select { + case err := <-errChan: + return err + case <-c.pipeline.quit: + return ErrClientExiting + } +} + +// handleTerminateSession handles a request to terminate a session. It will +// first shut down the session if it is part of the active session set, then +// it will ensure that the active session queue is set reset if it is using the +// session in question. Finally, the session's status in the DB will be updated. +func (c *client) handleTerminateSession(msg *terminateSessMsg) error { + id := msg.id + + delete(c.candidateSessions, id) + + err := c.activeSessions.StopAndRemove(id, true) + if err != nil { + return fmt.Errorf("could not stop session %s: %w", id, err) + } + + // If our active session queue corresponds to the session being + // terminated, then we'll proceed to negotiate a new one. + if c.sessionQueue != nil { + if bytes.Equal(c.sessionQueue.ID()[:], id[:]) { + c.sessionQueue = nil + } + } + + return nil +} + // deactivateTower sends a tower deactivation request to the backupDispatcher // where it will be handled synchronously. The request should result in all the // sessions that we have with the given tower being shutdown and removed from diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 0e5352f88..38d9acd9f 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -1034,15 +1034,19 @@ func (s *serverHarness) waitForUpdates(hints []blob.BreachHint, // Closure to assert the server's matches are consistent with the hint // set. serverHasHints := func(matches []wtdb.Match) bool { - if len(hintSet) != len(matches) { + // De-dup the server matches since it might very well have + // multiple matches for a hint if that update was backed up on + // more than one session. + matchHints := make(map[blob.BreachHint]struct{}) + for _, match := range matches { + matchHints[match.Hint] = struct{}{} + } + + if len(hintSet) != len(matchHints) { return false } - for _, match := range matches { - _, ok := hintSet[match.Hint] - require.Truef(s.t, ok, "match %v in db is not in "+ - "hint set", match.Hint) - } + require.EqualValues(s.t, hintSet, matchHints) return true } @@ -2770,6 +2774,123 @@ var clientTests = []clientTest{ h.server.waitForUpdates(hints[numUpdates-1:], waitTime) }, }, + { + name: "terminate session", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: defaultTxPolicy, + MaxUpdates: 5, + }, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 10 + chanIDInt = 0 + ) + + // Advance the channel with a few updates. + hints := h.advanceChannelN(chanIDInt, numUpdates) + + // Backup one of these updates and wait for it to + // arrive at the server. + h.backupStates(chanIDInt, 0, 1, nil) + h.server.waitForUpdates(hints[:1], waitTime) + + // Now, restart the server in a state where it will not + // ack updates. This will allow us to wait for an update + // to be un-acked and persisted. + h.server.restart(func(cfg *wtserver.Config) { + cfg.NoAckUpdates = true + }) + + // Backup another update. These should remain in the + // client as un-acked. + h.backupStates(chanIDInt, 1, 2, nil) + + // Wait for the update to be persisted. + fetchUnacked := h.clientDB.FetchSessionCommittedUpdates + var sessID wtdb.SessionID + err := wait.Predicate(func() bool { + sessions, err := h.clientDB.ListClientSessions( + nil, + ) + require.NoError(h.t, err) + + var updates []wtdb.CommittedUpdate + for id := range sessions { + sessID = id + updates, err = fetchUnacked(&id) + require.NoError(h.t, err) + + return len(updates) == 1 + } + + return false + }, waitTime) + require.NoError(h.t, err) + + // Now try to terminate the session by directly calling + // the DB terminate method. This is expected to fail + // since the session still has un-acked updates. + err = h.clientDB.TerminateSession(sessID) + require.ErrorIs( + h.t, err, wtdb.ErrSessionHasUnackedUpdates, + ) + + // If we try to terminate the session through the client + // interface though, it should succeed since the client + // will handle the un-acked updates of the session. + err = h.clientMgr.TerminateSession(sessID) + require.NoError(h.t, err) + + // Fetch the session from the DB and assert that it is + // in the terminal state and that it is not exhausted. + sess, err := h.clientDB.GetClientSession(sessID) + require.NoError(h.t, err) + + require.Equal(h.t, wtdb.CSessionTerminal, sess.Status) + require.NotEqual( + h.t, sess.Policy.MaxUpdates, sess.SeqNum, + ) + + // Restart the server and allow it to ack updates again. + h.server.restart(func(cfg *wtserver.Config) { + cfg.NoAckUpdates = false + }) + + // Wait for the update from before to appear on the + // server. The server will actually have this back-up + // stored twice now since it would have stored it for + // the first session even though it did not send an ACK + // for it. + h.server.waitForUpdates(hints[1:2], waitTime) + + // Now we want to assert that this update was definitely + // not sent on the terminated session but was instead + // sent in a new session. + var ( + updateCounts = make(map[wtdb.SessionID]uint16) + totalUpdates uint16 + ) + sessions, err := h.clientDB.ListClientSessions(nil, + wtdb.WithPerNumAckedUpdates( + func(s *wtdb.ClientSession, + _ lnwire.ChannelID, + num uint16) { + + updateCounts[s.ID] += num + totalUpdates += num + }, + ), + ) + require.NoError(h.t, err) + require.Len(h.t, sessions, 2) + require.EqualValues(h.t, 1, updateCounts[sessID]) + require.EqualValues(h.t, 2, totalUpdates) + }, + }, } // TestClient executes the client test suite, asserting the ability to backup diff --git a/watchtower/wtclient/manager.go b/watchtower/wtclient/manager.go index 17a351c15..f1cae979f 100644 --- a/watchtower/wtclient/manager.go +++ b/watchtower/wtclient/manager.go @@ -43,6 +43,10 @@ type ClientManager interface { // be used while the tower is inactive. DeactivateTower(pubKey *btcec.PublicKey) error + // TerminateSession sets the given session's status to CSessionTerminal + // meaning that it will not be used again. + TerminateSession(id wtdb.SessionID) error + // Stats returns the in-memory statistics of the client since startup. Stats() ClientStats @@ -436,6 +440,23 @@ func (m *Manager) RemoveTower(key *btcec.PublicKey, addr net.Addr) error { return nil } +// TerminateSession sets the given session's status to CSessionTerminal meaning +// that it will not be used again. +func (m *Manager) TerminateSession(id wtdb.SessionID) error { + m.clientsMu.Lock() + defer m.clientsMu.Unlock() + + for _, client := range m.clients { + err := client.terminateSession(id) + if err != nil { + return err + } + } + + // Finally, mark the session as terminated in the DB. + return m.cfg.DB.TerminateSession(id) +} + // DeactivateTower sets the given tower's status to inactive so that it is not // considered for session negotiation. Its sessions will also not be used while // the tower is inactive.