mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-01 10:11:11 +02:00
wtclient: add TerminateSession method
This commit is contained in:
@@ -148,6 +148,19 @@ type deactivateTowerMsg struct {
|
||||
errChan chan error
|
||||
}
|
||||
|
||||
// terminateSessMsg is an internal message we'll use within the TowerClient to
|
||||
// signal that a session should be terminated.
|
||||
type terminateSessMsg struct {
|
||||
// id is the session identifier.
|
||||
id wtdb.SessionID
|
||||
|
||||
// errChan is the channel through which we'll send a response back to
|
||||
// the caller when handling their request.
|
||||
//
|
||||
// NOTE: This channel must be buffered.
|
||||
errChan chan error
|
||||
}
|
||||
|
||||
// clientCfg holds the configuration values required by a client.
|
||||
type clientCfg struct {
|
||||
*Config
|
||||
@@ -181,9 +194,10 @@ type client struct {
|
||||
statTicker *time.Ticker
|
||||
stats *clientStats
|
||||
|
||||
newTowers chan *newTowerMsg
|
||||
staleTowers chan *staleTowerMsg
|
||||
deactivateTowers chan *deactivateTowerMsg
|
||||
newTowers chan *newTowerMsg
|
||||
staleTowers chan *staleTowerMsg
|
||||
deactivateTowers chan *deactivateTowerMsg
|
||||
terminateSessions chan *terminateSessMsg
|
||||
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
@@ -209,16 +223,17 @@ func newClient(cfg *clientCfg) (*client, error) {
|
||||
}
|
||||
|
||||
c := &client{
|
||||
cfg: cfg,
|
||||
log: plog,
|
||||
pipeline: queue,
|
||||
activeSessions: newSessionQueueSet(),
|
||||
statTicker: time.NewTicker(DefaultStatInterval),
|
||||
stats: new(clientStats),
|
||||
newTowers: make(chan *newTowerMsg),
|
||||
staleTowers: make(chan *staleTowerMsg),
|
||||
deactivateTowers: make(chan *deactivateTowerMsg),
|
||||
quit: make(chan struct{}),
|
||||
cfg: cfg,
|
||||
log: plog,
|
||||
pipeline: queue,
|
||||
activeSessions: newSessionQueueSet(),
|
||||
statTicker: time.NewTicker(DefaultStatInterval),
|
||||
stats: new(clientStats),
|
||||
newTowers: make(chan *newTowerMsg),
|
||||
staleTowers: make(chan *staleTowerMsg),
|
||||
deactivateTowers: make(chan *deactivateTowerMsg),
|
||||
terminateSessions: make(chan *terminateSessMsg),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
candidateTowers := newTowerListIterator()
|
||||
@@ -718,6 +733,10 @@ func (c *client) backupDispatcher() {
|
||||
case msg := <-c.deactivateTowers:
|
||||
msg.errChan <- c.handleDeactivateTower(msg)
|
||||
|
||||
// A request has come through to terminate a session.
|
||||
case msg := <-c.terminateSessions:
|
||||
msg.errChan <- c.handleTerminateSession(msg)
|
||||
|
||||
case <-c.quit:
|
||||
return
|
||||
}
|
||||
@@ -807,6 +826,10 @@ func (c *client) backupDispatcher() {
|
||||
case msg := <-c.deactivateTowers:
|
||||
msg.errChan <- c.handleDeactivateTower(msg)
|
||||
|
||||
// A request has come through to terminate a session.
|
||||
case msg := <-c.terminateSessions:
|
||||
msg.errChan <- c.handleTerminateSession(msg)
|
||||
|
||||
case <-c.quit:
|
||||
return
|
||||
}
|
||||
@@ -1074,6 +1097,53 @@ func (c *client) initActiveQueue(s *ClientSession,
|
||||
return sq
|
||||
}
|
||||
|
||||
// terminateSession sets the given session's status to CSessionTerminal meaning
|
||||
// that it will not be used again.
|
||||
func (c *client) terminateSession(id wtdb.SessionID) error {
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
select {
|
||||
case c.terminateSessions <- &terminateSessMsg{
|
||||
id: id,
|
||||
errChan: errChan,
|
||||
}:
|
||||
case <-c.pipeline.quit:
|
||||
return ErrClientExiting
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return err
|
||||
case <-c.pipeline.quit:
|
||||
return ErrClientExiting
|
||||
}
|
||||
}
|
||||
|
||||
// handleTerminateSession handles a request to terminate a session. It will
|
||||
// first shut down the session if it is part of the active session set, then
|
||||
// it will ensure that the active session queue is set reset if it is using the
|
||||
// session in question. Finally, the session's status in the DB will be updated.
|
||||
func (c *client) handleTerminateSession(msg *terminateSessMsg) error {
|
||||
id := msg.id
|
||||
|
||||
delete(c.candidateSessions, id)
|
||||
|
||||
err := c.activeSessions.StopAndRemove(id, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not stop session %s: %w", id, err)
|
||||
}
|
||||
|
||||
// If our active session queue corresponds to the session being
|
||||
// terminated, then we'll proceed to negotiate a new one.
|
||||
if c.sessionQueue != nil {
|
||||
if bytes.Equal(c.sessionQueue.ID()[:], id[:]) {
|
||||
c.sessionQueue = nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deactivateTower sends a tower deactivation request to the backupDispatcher
|
||||
// where it will be handled synchronously. The request should result in all the
|
||||
// sessions that we have with the given tower being shutdown and removed from
|
||||
|
@@ -1034,15 +1034,19 @@ func (s *serverHarness) waitForUpdates(hints []blob.BreachHint,
|
||||
// Closure to assert the server's matches are consistent with the hint
|
||||
// set.
|
||||
serverHasHints := func(matches []wtdb.Match) bool {
|
||||
if len(hintSet) != len(matches) {
|
||||
// De-dup the server matches since it might very well have
|
||||
// multiple matches for a hint if that update was backed up on
|
||||
// more than one session.
|
||||
matchHints := make(map[blob.BreachHint]struct{})
|
||||
for _, match := range matches {
|
||||
matchHints[match.Hint] = struct{}{}
|
||||
}
|
||||
|
||||
if len(hintSet) != len(matchHints) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
_, ok := hintSet[match.Hint]
|
||||
require.Truef(s.t, ok, "match %v in db is not in "+
|
||||
"hint set", match.Hint)
|
||||
}
|
||||
require.EqualValues(s.t, hintSet, matchHints)
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -2770,6 +2774,123 @@ var clientTests = []clientTest{
|
||||
h.server.waitForUpdates(hints[numUpdates-1:], waitTime)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "terminate session",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
numUpdates = 10
|
||||
chanIDInt = 0
|
||||
)
|
||||
|
||||
// Advance the channel with a few updates.
|
||||
hints := h.advanceChannelN(chanIDInt, numUpdates)
|
||||
|
||||
// Backup one of these updates and wait for it to
|
||||
// arrive at the server.
|
||||
h.backupStates(chanIDInt, 0, 1, nil)
|
||||
h.server.waitForUpdates(hints[:1], waitTime)
|
||||
|
||||
// Now, restart the server in a state where it will not
|
||||
// ack updates. This will allow us to wait for an update
|
||||
// to be un-acked and persisted.
|
||||
h.server.restart(func(cfg *wtserver.Config) {
|
||||
cfg.NoAckUpdates = true
|
||||
})
|
||||
|
||||
// Backup another update. These should remain in the
|
||||
// client as un-acked.
|
||||
h.backupStates(chanIDInt, 1, 2, nil)
|
||||
|
||||
// Wait for the update to be persisted.
|
||||
fetchUnacked := h.clientDB.FetchSessionCommittedUpdates
|
||||
var sessID wtdb.SessionID
|
||||
err := wait.Predicate(func() bool {
|
||||
sessions, err := h.clientDB.ListClientSessions(
|
||||
nil,
|
||||
)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
var updates []wtdb.CommittedUpdate
|
||||
for id := range sessions {
|
||||
sessID = id
|
||||
updates, err = fetchUnacked(&id)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
return len(updates) == 1
|
||||
}
|
||||
|
||||
return false
|
||||
}, waitTime)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Now try to terminate the session by directly calling
|
||||
// the DB terminate method. This is expected to fail
|
||||
// since the session still has un-acked updates.
|
||||
err = h.clientDB.TerminateSession(sessID)
|
||||
require.ErrorIs(
|
||||
h.t, err, wtdb.ErrSessionHasUnackedUpdates,
|
||||
)
|
||||
|
||||
// If we try to terminate the session through the client
|
||||
// interface though, it should succeed since the client
|
||||
// will handle the un-acked updates of the session.
|
||||
err = h.clientMgr.TerminateSession(sessID)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
// Fetch the session from the DB and assert that it is
|
||||
// in the terminal state and that it is not exhausted.
|
||||
sess, err := h.clientDB.GetClientSession(sessID)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
require.Equal(h.t, wtdb.CSessionTerminal, sess.Status)
|
||||
require.NotEqual(
|
||||
h.t, sess.Policy.MaxUpdates, sess.SeqNum,
|
||||
)
|
||||
|
||||
// Restart the server and allow it to ack updates again.
|
||||
h.server.restart(func(cfg *wtserver.Config) {
|
||||
cfg.NoAckUpdates = false
|
||||
})
|
||||
|
||||
// Wait for the update from before to appear on the
|
||||
// server. The server will actually have this back-up
|
||||
// stored twice now since it would have stored it for
|
||||
// the first session even though it did not send an ACK
|
||||
// for it.
|
||||
h.server.waitForUpdates(hints[1:2], waitTime)
|
||||
|
||||
// Now we want to assert that this update was definitely
|
||||
// not sent on the terminated session but was instead
|
||||
// sent in a new session.
|
||||
var (
|
||||
updateCounts = make(map[wtdb.SessionID]uint16)
|
||||
totalUpdates uint16
|
||||
)
|
||||
sessions, err := h.clientDB.ListClientSessions(nil,
|
||||
wtdb.WithPerNumAckedUpdates(
|
||||
func(s *wtdb.ClientSession,
|
||||
_ lnwire.ChannelID,
|
||||
num uint16) {
|
||||
|
||||
updateCounts[s.ID] += num
|
||||
totalUpdates += num
|
||||
},
|
||||
),
|
||||
)
|
||||
require.NoError(h.t, err)
|
||||
require.Len(h.t, sessions, 2)
|
||||
require.EqualValues(h.t, 1, updateCounts[sessID])
|
||||
require.EqualValues(h.t, 2, totalUpdates)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestClient executes the client test suite, asserting the ability to backup
|
||||
|
@@ -43,6 +43,10 @@ type ClientManager interface {
|
||||
// be used while the tower is inactive.
|
||||
DeactivateTower(pubKey *btcec.PublicKey) error
|
||||
|
||||
// TerminateSession sets the given session's status to CSessionTerminal
|
||||
// meaning that it will not be used again.
|
||||
TerminateSession(id wtdb.SessionID) error
|
||||
|
||||
// Stats returns the in-memory statistics of the client since startup.
|
||||
Stats() ClientStats
|
||||
|
||||
@@ -436,6 +440,23 @@ func (m *Manager) RemoveTower(key *btcec.PublicKey, addr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TerminateSession sets the given session's status to CSessionTerminal meaning
|
||||
// that it will not be used again.
|
||||
func (m *Manager) TerminateSession(id wtdb.SessionID) error {
|
||||
m.clientsMu.Lock()
|
||||
defer m.clientsMu.Unlock()
|
||||
|
||||
for _, client := range m.clients {
|
||||
err := client.terminateSession(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, mark the session as terminated in the DB.
|
||||
return m.cfg.DB.TerminateSession(id)
|
||||
}
|
||||
|
||||
// DeactivateTower sets the given tower's status to inactive so that it is not
|
||||
// considered for session negotiation. Its sessions will also not be used while
|
||||
// the tower is inactive.
|
||||
|
Reference in New Issue
Block a user