diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index 91d147acc..0578bcc9d 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -128,8 +129,9 @@ type ChanStatusManager struct { // become inactive. statusSampleTicker *time.Ticker - wg sync.WaitGroup - quit chan struct{} + wg sync.WaitGroup + quit chan struct{} + cancel fn.Option[context.CancelFunc] } // NewChanStatusManager initializes a new ChanStatusManager using the given @@ -177,12 +179,16 @@ func (m *ChanStatusManager) Start() error { var err error m.started.Do(func() { log.Info("Channel Status Manager starting") - err = m.start() + + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = fn.Some(cancel) + + err = m.start(ctx) }) return err } -func (m *ChanStatusManager) start() error { +func (m *ChanStatusManager) start(ctx context.Context) error { channels, err := m.fetchChannels() if err != nil { return err @@ -190,7 +196,7 @@ func (m *ChanStatusManager) start() error { // Populate the initial states of all confirmed, public channels. for _, c := range channels { - _, err := m.getOrInitChanStatus(c.FundingOutpoint) + _, err := m.getOrInitChanStatus(ctx, c.FundingOutpoint) switch { // If we can't retrieve the edge info for this channel, it may @@ -219,7 +225,7 @@ func (m *ChanStatusManager) start() error { } m.wg.Add(1) - go m.statusManager() + go m.statusManager(ctx) return nil } @@ -230,6 +236,10 @@ func (m *ChanStatusManager) Stop() error { log.Info("Channel Status Manager shutting down...") defer log.Debug("Channel Status Manager shutdown complete") + m.cancel.WhenSome(func(cancel context.CancelFunc) { + cancel() + }) + close(m.quit) m.wg.Wait() }) @@ -333,7 +343,7 @@ func (m *ChanStatusManager) submitRequest(reqChan chan statusRequest, // should be scheduled or broadcast. // // NOTE: This method MUST be run as a goroutine. -func (m *ChanStatusManager) statusManager() { +func (m *ChanStatusManager) statusManager(ctx context.Context) { defer m.wg.Done() for { @@ -341,15 +351,20 @@ func (m *ChanStatusManager) statusManager() { // Process any requests to mark channel as enabled. case req := <-m.enableRequests: - req.errChan <- m.processEnableRequest(req.outpoint, req.manual) + req.errChan <- m.processEnableRequest( + ctx, req.outpoint, req.manual, + ) // Process any requests to mark channel as disabled. case req := <-m.disableRequests: - req.errChan <- m.processDisableRequest(req.outpoint, req.manual) + req.errChan <- m.processDisableRequest( + ctx, req.outpoint, req.manual, + ) - // Process any requests to restore automatic channel state management. + // Process any requests to restore automatic channel state + // management. case req := <-m.autoRequests: - req.errChan <- m.processAutoRequest(req.outpoint) + req.errChan <- m.processAutoRequest(ctx, req.outpoint) // Use long-polling to detect when channels become inactive. case <-m.statusSampleTicker.C: @@ -358,12 +373,12 @@ func (m *ChanStatusManager) statusManager() { // ChanStatusPendingDisabled. The channel will then be // disabled if no request to enable is received before // the ChanDisableTimeout expires. - m.markPendingInactiveChannels() + m.markPendingInactiveChannels(ctx) // Now, do another sweep to disable any channels that // were marked in a prior iteration as pending inactive // if the inactive chan timeout has elapsed. - m.disableInactiveChannels() + m.disableInactiveChannels(ctx) case <-m.quit: return @@ -383,10 +398,10 @@ func (m *ChanStatusManager) statusManager() { // // An update will be broadcast only if the channel is currently disabled, // otherwise no update will be sent on the network. -func (m *ChanStatusManager) processEnableRequest(outpoint wire.OutPoint, - manual bool) error { +func (m *ChanStatusManager) processEnableRequest(ctx context.Context, + outpoint wire.OutPoint, manual bool) error { - curState, err := m.getOrInitChanStatus(outpoint) + curState, err := m.getOrInitChanStatus(ctx, outpoint) if err != nil { return err } @@ -423,7 +438,7 @@ func (m *ChanStatusManager) processEnableRequest(outpoint wire.OutPoint, case ChanStatusDisabled: log.Infof("Announcing channel(%v) enabled", outpoint) - err := m.signAndSendNextUpdate(outpoint, false) + err := m.signAndSendNextUpdate(ctx, outpoint, false) if err != nil { return err } @@ -441,10 +456,10 @@ func (m *ChanStatusManager) processEnableRequest(outpoint wire.OutPoint, // // An update will only be sent if the channel has a status other than // ChanStatusEnabled, otherwise no update will be sent on the network. -func (m *ChanStatusManager) processDisableRequest(outpoint wire.OutPoint, - manual bool) error { +func (m *ChanStatusManager) processDisableRequest(ctx context.Context, + outpoint wire.OutPoint, manual bool) error { - curState, err := m.getOrInitChanStatus(outpoint) + curState, err := m.getOrInitChanStatus(ctx, outpoint) if err != nil { return err } @@ -454,7 +469,7 @@ func (m *ChanStatusManager) processDisableRequest(outpoint wire.OutPoint, log.Infof("Announcing channel(%v) disabled [requested]", outpoint) - err := m.signAndSendNextUpdate(outpoint, true) + err := m.signAndSendNextUpdate(ctx, outpoint, true) if err != nil { return err } @@ -483,8 +498,10 @@ func (m *ChanStatusManager) processDisableRequest(outpoint wire.OutPoint, // which automatic / background requests are ignored). // // No update will be sent on the network. -func (m *ChanStatusManager) processAutoRequest(outpoint wire.OutPoint) error { - curState, err := m.getOrInitChanStatus(outpoint) +func (m *ChanStatusManager) processAutoRequest(ctx context.Context, + outpoint wire.OutPoint) error { + + curState, err := m.getOrInitChanStatus(ctx, outpoint) if err != nil { return err } @@ -505,7 +522,7 @@ func (m *ChanStatusManager) processAutoRequest(outpoint wire.OutPoint) error { // request to enable is received before the scheduled disable is broadcast, or // the channel is successfully re-enabled and channel is returned to an active // state from the POV of the ChanStatusManager. -func (m *ChanStatusManager) markPendingInactiveChannels() { +func (m *ChanStatusManager) markPendingInactiveChannels(ctx context.Context) { channels, err := m.fetchChannels() if err != nil { log.Errorf("Unable to load active channels: %v", err) @@ -515,7 +532,7 @@ func (m *ChanStatusManager) markPendingInactiveChannels() { for _, c := range channels { // Determine the initial status of the active channel, and // populate the entry in the chanStates map. - curState, err := m.getOrInitChanStatus(c.FundingOutpoint) + curState, err := m.getOrInitChanStatus(ctx, c.FundingOutpoint) if err != nil { log.Errorf("Unable to retrieve chan status for "+ "Channel(%v): %v", c.FundingOutpoint, err) @@ -554,7 +571,7 @@ func (m *ChanStatusManager) markPendingInactiveChannels() { // disableInactiveChannels scans through the set of monitored channels, and // broadcast a disable update for any pending inactive channels whose // SendDisableTime has been superseded by the current time. -func (m *ChanStatusManager) disableInactiveChannels() { +func (m *ChanStatusManager) disableInactiveChannels(ctx context.Context) { // Now, disable any channels whose inactive chan timeout has elapsed. now := time.Now() for outpoint, state := range m.chanStates { @@ -573,7 +590,7 @@ func (m *ChanStatusManager) disableInactiveChannels() { "[detected]", outpoint) // Sign an update disabling the channel. - err := m.signAndSendNextUpdate(outpoint, true) + err := m.signAndSendNextUpdate(ctx, outpoint, true) if err != nil { log.Errorf("Unable to sign update disabling "+ "channel(%v): %v", outpoint, err) @@ -626,12 +643,14 @@ func (m *ChanStatusManager) fetchChannels() ([]*channeldb.OpenChannel, error) { // use the current time as the update's timestamp, or increment the old // timestamp by 1 to ensure the update can propagate. If signing is successful, // the new update will be sent out on the network. -func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, - disabled bool) error { +func (m *ChanStatusManager) signAndSendNextUpdate(ctx context.Context, + outpoint wire.OutPoint, disabled bool) error { // Retrieve the latest update for this channel. We'll use this // as our starting point to send the new update. - chanUpdate, private, err := m.fetchLastChanUpdateByOutPoint(outpoint) + chanUpdate, private, err := m.fetchLastChanUpdateByOutPoint( + ctx, outpoint, + ) if err != nil { return err } @@ -651,10 +670,8 @@ func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, // a channel, and crafts a new ChannelUpdate with this policy. Returns an error // in case our ChannelEdgePolicy is not found in the database. Also returns if // the channel is private by checking AuthProof for nil. -func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) ( - *lnwire.ChannelUpdate1, bool, error) { - - ctx := context.TODO() +func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(ctx context.Context, + op wire.OutPoint) (*lnwire.ChannelUpdate1, bool, error) { // Get the edge info and policies for this channel from the graph. info, edge1, edge2, err := m.cfg.Graph.FetchChannelEdgesByOutpoint( @@ -675,10 +692,10 @@ func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) ( // ChanStatusEnabled or ChanStatusDisabled, determined by inspecting the bits on // the most recent announcement. An error is returned if the latest update could // not be retrieved. -func (m *ChanStatusManager) loadInitialChanState( +func (m *ChanStatusManager) loadInitialChanState(ctx context.Context, outpoint *wire.OutPoint) (ChannelState, error) { - lastUpdate, _, err := m.fetchLastChanUpdateByOutPoint(*outpoint) + lastUpdate, _, err := m.fetchLastChanUpdateByOutPoint(ctx, *outpoint) if err != nil { return ChannelState{}, err } @@ -701,7 +718,7 @@ func (m *ChanStatusManager) loadInitialChanState( // outpoint. If the chanStates map already contains an entry for the outpoint, // the value in the map is returned. Otherwise, the outpoint's initial status is // computed and updated in the chanStates map before being returned. -func (m *ChanStatusManager) getOrInitChanStatus( +func (m *ChanStatusManager) getOrInitChanStatus(ctx context.Context, outpoint wire.OutPoint) (ChannelState, error) { // Return the current ChannelState from the chanStates map if it is @@ -712,7 +729,7 @@ func (m *ChanStatusManager) getOrInitChanStatus( // Otherwise, determine the initial state based on the last update we // sent for the outpoint. - initialState, err := m.loadInitialChanState(&outpoint) + initialState, err := m.loadInitialChanState(ctx, &outpoint) if err != nil { return ChannelState{}, err }