diff --git a/server.go b/server.go index 7758ebe8c..61b31ffa6 100644 --- a/server.go +++ b/server.go @@ -1512,7 +1512,16 @@ func newServer(cfg *Config, listenAddrs []net.Addr, ) } + fetchClosedChannel := s.chanStateDB.FetchClosedChannelForID + s.towerClient, err = wtclient.New(&wtclient.Config{ + FetchClosedChannel: fetchClosedChannel, + SubscribeChannelEvents: func() (subscribe.Subscription, + error) { + + return s.channelNotifier. + SubscribeChannelEvents() + }, Signer: cc.Wallet.Cfg.Signer, NewAddress: newSweepPkScriptGen(cc.Wallet), SecretKeyRing: s.cc.KeyRing, @@ -1536,6 +1545,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr, blob.Type(blob.FlagAnchorChannel) s.anchorTowerClient, err = wtclient.New(&wtclient.Config{ + FetchClosedChannel: fetchClosedChannel, + SubscribeChannelEvents: func() (subscribe.Subscription, + error) { + + return s.channelNotifier. + SubscribeChannelEvents() + }, Signer: cc.Wallet.Cfg.Signer, NewAddress: newSweepPkScriptGen(cc.Wallet), SecretKeyRing: s.cc.KeyRing, diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index fda072840..c0a4c2331 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -2,6 +2,7 @@ package wtclient import ( "bytes" + "errors" "fmt" "net" "sync" @@ -12,10 +13,12 @@ import ( "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/subscribe" "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" @@ -146,6 +149,16 @@ type Config struct { // transaction. Signer input.Signer + // SubscribeChannelEvents can be used to subscribe to channel event + // notifications. + SubscribeChannelEvents func() (subscribe.Subscription, error) + + // FetchClosedChannel can be used to fetch the info about a closed + // channel. If the channel is not found or not yet closed then + // channeldb.ErrClosedChannelNotFound will be returned. + FetchClosedChannel func(cid lnwire.ChannelID) ( + *channeldb.ChannelCloseSummary, error) + // NewAddress generates a new on-chain sweep pkscript. NewAddress func() ([]byte, error) @@ -269,6 +282,7 @@ type TowerClient struct { staleTowers chan *staleTowerMsg wg sync.WaitGroup + quit chan struct{} forceQuit chan struct{} } @@ -319,6 +333,7 @@ func New(config *Config) (*TowerClient, error) { newTowers: make(chan *newTowerMsg), staleTowers: make(chan *staleTowerMsg), forceQuit: make(chan struct{}), + quit: make(chan struct{}), } // perUpdate is a callback function that will be used to inspect the @@ -364,7 +379,7 @@ func New(config *Config) (*TowerClient, error) { return } - log.Infof("Using private watchtower %s, offering policy %s", + c.log.Infof("Using private watchtower %s, offering policy %s", tower, cfg.Policy) // Add the tower to the set of candidate towers. @@ -540,10 +555,45 @@ func (c *TowerClient) Start() error { } } + chanSub, err := c.cfg.SubscribeChannelEvents() + if err != nil { + returnErr = err + return + } + + // Iterate over the list of registered channels and check if + // any of them can be marked as closed. + for id := range c.summaries { + isClosed, closedHeight, err := c.isChannelClosed(id) + if err != nil { + returnErr = err + return + } + + if !isClosed { + continue + } + + _, err = c.cfg.DB.MarkChannelClosed(id, closedHeight) + if err != nil { + c.log.Errorf("could not mark channel(%s) as "+ + "closed: %v", id, err) + + continue + } + + // Since the channel has been marked as closed, we can + // also remove it from the channel summaries map. + delete(c.summaries, id) + } + + c.wg.Add(1) + go c.handleChannelCloses(chanSub) + // Now start the session negotiator, which will allow us to // request new session as soon as the backupDispatcher starts // up. - err := c.negotiator.Start() + err = c.negotiator.Start() if err != nil { returnErr = err return @@ -591,6 +641,7 @@ func (c *TowerClient) Stop() error { // dispatcher to exit. The backup queue will signal it's // completion to the dispatcher, which releases the wait group // after all tasks have been assigned to session queues. + close(c.quit) c.wg.Wait() // 4. Since all valid tasks have been assigned to session @@ -772,6 +823,82 @@ func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { return c.getOrInitActiveQueue(candidateSession, updates), nil } +// handleChannelCloses listens for channel close events and marks channels as +// closed in the DB. +// +// NOTE: This method MUST be run as a goroutine. +func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) { + defer c.wg.Done() + + c.log.Debugf("Starting channel close handler") + defer c.log.Debugf("Stopping channel close handler") + + for { + select { + case update, ok := <-chanSub.Updates(): + if !ok { + c.log.Debugf("Channel notifier has exited") + return + } + + // We only care about channel-close events. + event, ok := update.(channelnotifier.ClosedChannelEvent) + if !ok { + continue + } + + chanID := lnwire.NewChanIDFromOutPoint( + &event.CloseSummary.ChanPoint, + ) + + c.log.Debugf("Received ClosedChannelEvent for "+ + "channel: %s", chanID) + + err := c.handleClosedChannel( + chanID, event.CloseSummary.CloseHeight, + ) + if err != nil { + c.log.Errorf("Could not handle channel close "+ + "event for channel(%s): %v", chanID, + err) + } + + case <-c.forceQuit: + return + + case <-c.quit: + return + } + } +} + +// handleClosedChannel handles the closure of a single channel. It will mark the +// channel as closed in the DB. +func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, + closeHeight uint32) error { + + c.backupMu.Lock() + defer c.backupMu.Unlock() + + // We only care about channels registered with the tower client. + if _, ok := c.summaries[chanID]; !ok { + return nil + } + + c.log.Debugf("Marking channel(%s) as closed", chanID) + + _, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight) + if err != nil { + return fmt.Errorf("could not mark channel(%s) as closed: %w", + chanID, err) + } + + delete(c.summaries, chanID) + delete(c.chanCommitHeights, chanID) + + return nil +} + // backupDispatcher processes events coming from the taskPipeline and is // responsible for detecting when the client needs to renegotiate a session to // fulfill continuing demand. The event loop exits after all tasks have been @@ -1145,6 +1272,22 @@ func (c *TowerClient) initActiveQueue(s *ClientSession, return sq } +// isChanClosed can be used to check if the channel with the given ID has been +// closed. If it has been, the block height in which its closing transaction was +// mined will also be returned. +func (c *TowerClient) isChannelClosed(id lnwire.ChannelID) (bool, uint32, + error) { + + chanSum, err := c.cfg.FetchClosedChannel(id) + if errors.Is(err, channeldb.ErrClosedChannelNotFound) { + return false, 0, nil + } else if err != nil { + return false, 0, err + } + + return true, chanSum.CloseHeight, nil +} + // AddTower adds a new watchtower reachable at the given address and considers // it for new sessions. If the watchtower already exists, then any new addresses // included will be considered when dialing it for session negotiations and diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 1490c6d10..29c4e7a53 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -1,6 +1,7 @@ package wtclient_test import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -16,11 +17,13 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/subscribe" "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtclient" @@ -393,8 +396,12 @@ type testHarness struct { server *wtserver.Server net *mockNet - mu sync.Mutex - channels map[lnwire.ChannelID]*mockChannel + channelEvents *mockSubscription + sendUpdatesOn bool + + mu sync.Mutex + channels map[lnwire.ChannelID]*mockChannel + closedChannels map[lnwire.ChannelID]uint32 quit chan struct{} } @@ -441,13 +448,50 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { mockNet := newMockNet() clientDB := wtmock.NewClientDB() - clientCfg := &wtclient.Config{ - Signer: signer, - Dial: mockNet.Dial, - DB: clientDB, - AuthDial: mockNet.AuthDial, - SecretKeyRing: wtmock.NewSecretKeyRing(), - Policy: cfg.policy, + h := &testHarness{ + t: t, + cfg: cfg, + signer: signer, + capacity: cfg.localBalance + cfg.remoteBalance, + clientDB: clientDB, + serverAddr: towerAddr, + serverDB: serverDB, + serverCfg: serverCfg, + net: mockNet, + channelEvents: newMockSubscription(t), + channels: make(map[lnwire.ChannelID]*mockChannel), + closedChannels: make(map[lnwire.ChannelID]uint32), + quit: make(chan struct{}), + } + t.Cleanup(func() { + close(h.quit) + }) + + fetchChannel := func(id lnwire.ChannelID) ( + *channeldb.ChannelCloseSummary, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + height, ok := h.closedChannels[id] + if !ok { + return nil, channeldb.ErrClosedChannelNotFound + } + + return &channeldb.ChannelCloseSummary{CloseHeight: height}, nil + } + + h.clientCfg = &wtclient.Config{ + Signer: signer, + SubscribeChannelEvents: func() (subscribe.Subscription, error) { + return h.channelEvents, nil + }, + FetchClosedChannel: fetchChannel, + Dial: mockNet.Dial, + DB: clientDB, + AuthDial: mockNet.AuthDial, + SecretKeyRing: wtmock.NewSecretKeyRing(), + Policy: cfg.policy, NewAddress: func() ([]byte, error) { return addrScript, nil }, @@ -458,24 +502,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { ForceQuitDelay: 10 * time.Second, } - h := &testHarness{ - t: t, - cfg: cfg, - signer: signer, - capacity: cfg.localBalance + cfg.remoteBalance, - clientDB: clientDB, - clientCfg: clientCfg, - serverAddr: towerAddr, - serverDB: serverDB, - serverCfg: serverCfg, - net: mockNet, - channels: make(map[lnwire.ChannelID]*mockChannel), - quit: make(chan struct{}), - } - t.Cleanup(func() { - close(h.quit) - }) - if !cfg.noServerStart { h.startServer() t.Cleanup(h.stopServer) @@ -576,6 +602,41 @@ func (h *testHarness) channel(id uint64) *mockChannel { return c } +// closeChannel marks a channel as closed. +// +// NOTE: The method fails if a channel for id does not exist. +func (h *testHarness) closeChannel(id uint64, height uint32) { + h.t.Helper() + + h.mu.Lock() + defer h.mu.Unlock() + + chanID := chanIDFromInt(id) + + _, ok := h.channels[chanID] + require.Truef(h.t, ok, "unable to fetch channel %d", id) + + h.closedChannels[chanID] = height + delete(h.channels, chanID) + + chanPointHash, err := chainhash.NewHash(chanID[:]) + require.NoError(h.t, err) + + if !h.sendUpdatesOn { + return + } + + h.channelEvents.sendUpdate(channelnotifier.ClosedChannelEvent{ + CloseSummary: &channeldb.ChannelCloseSummary{ + ChanPoint: wire.OutPoint{ + Hash: *chanPointHash, + Index: 0, + }, + CloseHeight: height, + }, + }) +} + // registerChannel registers the channel identified by id with the client. func (h *testHarness) registerChannel(id uint64) { h.t.Helper() @@ -624,7 +685,7 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { err := h.client.BackupState( &chanID, retribution, channeldb.SingleFunderBit, ) - require.ErrorIs(h.t, expErr, err) + require.ErrorIs(h.t, err, expErr) } // sendPayments instructs the channel identified by id to send amt to the remote @@ -770,11 +831,94 @@ func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) { require.NoError(h.t, err) } +// relevantSessions returns a list of session IDs that have acked updates for +// the given channel ID. +func (h *testHarness) relevantSessions(chanID uint64) []wtdb.SessionID { + h.t.Helper() + + var ( + sessionIDs []wtdb.SessionID + cID = chanIDFromInt(chanID) + ) + + collectSessions := wtdb.WithPerNumAckedUpdates( + func(session *wtdb.ClientSession, id lnwire.ChannelID, + _ uint16) { + + if !bytes.Equal(id[:], cID[:]) { + return + } + + sessionIDs = append(sessionIDs, session.ID) + }, + ) + + _, err := h.clientDB.ListClientSessions(nil, nil, collectSessions) + require.NoError(h.t, err) + + return sessionIDs +} + +// isSessionClosable returns true if the given session has been marked as +// closable in the DB. +func (h *testHarness) isSessionClosable(id wtdb.SessionID) bool { + h.t.Helper() + + cs, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + + _, ok := cs[id] + + return ok +} + +// mockSubscription is a mock subscription client that blocks on sends into the +// updates channel. +type mockSubscription struct { + t *testing.T + updates chan interface{} + + // Embed the subscription interface in this mock so that we satisfy it. + subscribe.Subscription +} + +// newMockSubscription creates a mock subscription. +func newMockSubscription(t *testing.T) *mockSubscription { + t.Helper() + + return &mockSubscription{ + t: t, + updates: make(chan interface{}), + } +} + +// sendUpdate sends an update into our updates channel, mocking the dispatch of +// an update from a subscription server. This call will fail the test if the +// update is not consumed within our timeout. +func (m *mockSubscription) sendUpdate(update interface{}) { + select { + case m.updates <- update: + + case <-time.After(waitTime): + m.t.Fatalf("update: %v timeout", update) + } +} + +// Updates returns the updates channel for the mock. +func (m *mockSubscription) Updates() <-chan interface{} { + return m.updates +} + const ( localBalance = lnwire.MilliSatoshi(100000000) remoteBalance = lnwire.MilliSatoshi(200000000) ) +var defaultTxPolicy = wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, +} + type clientTest struct { name string cfg harnessCfg @@ -791,10 +935,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 20000, }, noRegisterChan0: true, @@ -825,10 +966,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 20000, }, }, @@ -860,10 +998,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -927,10 +1062,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 20000, }, }, @@ -1006,10 +1138,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1062,10 +1191,7 @@ var clientTests = []clientTest{ localBalance: 100000001, // ensure (% amt != 0) remoteBalance: 200000001, // ensure (% amt != 0) policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 1000, }, }, @@ -1106,10 +1232,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1156,10 +1279,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noAckCreateSession: true, @@ -1212,10 +1332,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noAckCreateSession: true, @@ -1274,10 +1391,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 10, }, }, @@ -1333,10 +1447,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1381,10 +1492,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1489,10 +1597,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1557,10 +1662,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noServerStart: true, @@ -1654,6 +1756,142 @@ var clientTests = []clientTest{ }, waitTime) require.NoError(h.t, err) }, + }, { + name: "assert that sessions are correctly marked as closable", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: defaultTxPolicy, + MaxUpdates: 5, + }, + }, + fn: func(h *testHarness) { + const numUpdates = 5 + + // In this test we assert that a channel is correctly + // marked as closed and that sessions are also correctly + // marked as closable. + + // We start with the sendUpdatesOn parameter set to + // false so that we can test that channels are correctly + // evaluated at startup. + h.sendUpdatesOn = false + + // Advance channel 0 to create all states and back them + // all up. This will saturate the session with updates + // for channel 0 which means that the session should be + // considered closable when channel 0 is closed. + hints := h.advanceChannelN(0, numUpdates) + h.backupStates(0, 0, numUpdates, nil) + h.waitServerUpdates(hints, waitTime) + + // We expect only 1 session to have updates for this + // channel. + sessionIDs := h.relevantSessions(0) + require.Len(h.t, sessionIDs, 1) + + // Since channel 0 is still open, the session should not + // yet be closable. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Close the channel. + h.closeChannel(0, 1) + + // Since updates are currently not being sent, we expect + // the session to still not be marked as closable. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Restart the client. + h.client.ForceQuit() + h.startClient() + + // The session should now have been marked as closable. + err := wait.Predicate(func() bool { + return h.isSessionClosable(sessionIDs[0]) + }, waitTime) + require.NoError(h.t, err) + + // Now we set sendUpdatesOn to true and do the same with + // a new channel. A restart should now not be necessary + // anymore. + h.sendUpdatesOn = true + + h.makeChannel( + 1, h.cfg.localBalance, h.cfg.remoteBalance, + ) + h.registerChannel(1) + + hints = h.advanceChannelN(1, numUpdates) + h.backupStates(1, 0, numUpdates, nil) + h.waitServerUpdates(hints, waitTime) + + // Determine the ID of the session of interest. + sessionIDs = h.relevantSessions(1) + + // We expect only 1 session to have updates for this + // channel. + require.Len(h.t, sessionIDs, 1) + + // Assert that the session is not yet closable since + // the channel is still open. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Now close the channel. + h.closeChannel(1, 1) + + // Since the updates have been turned on, the session + // should now show up as closable. + err = wait.Predicate(func() bool { + return h.isSessionClosable(sessionIDs[0]) + }, waitTime) + require.NoError(h.t, err) + + // Now we test that a session must be exhausted with all + // channels closed before it is seen as closable. + h.makeChannel( + 2, h.cfg.localBalance, h.cfg.remoteBalance, + ) + h.registerChannel(2) + + // Fill up only half of the session updates. + hints = h.advanceChannelN(2, numUpdates) + h.backupStates(2, 0, numUpdates/2, nil) + h.waitServerUpdates(hints[:numUpdates/2], waitTime) + + // Determine the ID of the session of interest. + sessionIDs = h.relevantSessions(2) + + // We expect only 1 session to have updates for this + // channel. + require.Len(h.t, sessionIDs, 1) + + // Now close the channel. + h.closeChannel(2, 1) + + // The session should _not_ be closable due to it not + // being exhausted yet. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Create a new channel. + h.makeChannel( + 3, h.cfg.localBalance, h.cfg.remoteBalance, + ) + h.registerChannel(3) + + hints = h.advanceChannelN(3, numUpdates) + h.backupStates(3, 0, numUpdates, nil) + h.waitServerUpdates(hints, waitTime) + + // Close it. + h.closeChannel(3, 1) + + // Now the session should be closable. + err = wait.Predicate(func() bool { + return h.isSessionClosable(sessionIDs[0]) + }, waitTime) + require.NoError(h.t, err) + }, }, }