multi: Add atomic start/stop functions.

Make sure that each subsystem only starts and stops once. This makes
sure we don't close e.g. quit channels twice.
This commit is contained in:
ziggie
2024-05-08 20:22:21 +01:00
parent 67c5fa9478
commit 08b68bbaf7
4 changed files with 92 additions and 18 deletions

View File

@@ -12,7 +12,9 @@ package chanfitness
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
@@ -48,6 +50,9 @@ var (
// ChannelEventStore maintains a set of event logs for the node's channels to // ChannelEventStore maintains a set of event logs for the node's channels to
// provide insight into the performance and health of channels. // provide insight into the performance and health of channels.
type ChannelEventStore struct { type ChannelEventStore struct {
started atomic.Bool
stopped atomic.Bool
cfg *Config cfg *Config
// peers tracks all of our currently monitored peers and their channels. // peers tracks all of our currently monitored peers and their channels.
@@ -142,7 +147,11 @@ func NewChannelEventStore(config *Config) *ChannelEventStore {
// information from the store. If this function fails, it cancels its existing // information from the store. If this function fails, it cancels its existing
// subscriptions and returns an error. // subscriptions and returns an error.
func (c *ChannelEventStore) Start() error { func (c *ChannelEventStore) Start() error {
log.Info("ChannelEventStore starting") log.Info("ChannelEventStore starting...")
if c.started.Swap(true) {
return fmt.Errorf("ChannelEventStore started more than once")
}
// Create a subscription to channel events. // Create a subscription to channel events.
channelClient, err := c.cfg.SubscribeChannelEvents() channelClient, err := c.cfg.SubscribeChannelEvents()
@@ -198,13 +207,18 @@ func (c *ChannelEventStore) Start() error {
cancel: cancel, cancel: cancel,
}) })
log.Debug("ChannelEventStore started")
return nil return nil
} }
// Stop terminates all goroutines started by the event store. // Stop terminates all goroutines started by the event store.
func (c *ChannelEventStore) Stop() { func (c *ChannelEventStore) Stop() error {
log.Info("ChannelEventStore shutting down...") log.Info("ChannelEventStore shutting down...")
defer log.Debug("ChannelEventStore shutdown complete")
if c.stopped.Swap(true) {
return fmt.Errorf("ChannelEventStore stopped more than once")
}
// Stop the consume goroutine. // Stop the consume goroutine.
close(c.quit) close(c.quit)
@@ -213,6 +227,10 @@ func (c *ChannelEventStore) Stop() {
// Stop the ticker after the goroutine reading from it has exited, to // Stop the ticker after the goroutine reading from it has exited, to
// avoid a race. // avoid a race.
c.cfg.FlapCountTicker.Stop() c.cfg.FlapCountTicker.Stop()
log.Debugf("ChannelEventStore shutdown complete")
return nil
} }
// addChannel checks whether we are already tracking a channel's peer, creates a // addChannel checks whether we are already tracking a channel's peer, creates a

View File

@@ -4,6 +4,7 @@ import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
@@ -33,6 +34,9 @@ var (
// Settle - routes UpdateFulfillHTLC to the originating link. // Settle - routes UpdateFulfillHTLC to the originating link.
// Fail - routes UpdateFailHTLC to the originating link. // Fail - routes UpdateFailHTLC to the originating link.
type InterceptableSwitch struct { type InterceptableSwitch struct {
started atomic.Bool
stopped atomic.Bool
// htlcSwitch is the underline switch // htlcSwitch is the underline switch
htlcSwitch *Switch htlcSwitch *Switch
@@ -201,6 +205,12 @@ func (s *InterceptableSwitch) SetInterceptor(
} }
func (s *InterceptableSwitch) Start() error { func (s *InterceptableSwitch) Start() error {
log.Info("InterceptableSwitch starting...")
if s.started.Swap(true) {
return fmt.Errorf("InterceptableSwitch started more than once")
}
blockEpochStream, err := s.notifier.RegisterBlockEpochNtfn(nil) blockEpochStream, err := s.notifier.RegisterBlockEpochNtfn(nil)
if err != nil { if err != nil {
return err return err
@@ -217,15 +227,25 @@ func (s *InterceptableSwitch) Start() error {
} }
}() }()
log.Debug("InterceptableSwitch started")
return nil return nil
} }
func (s *InterceptableSwitch) Stop() error { func (s *InterceptableSwitch) Stop() error {
log.Info("InterceptableSwitch shutting down...")
if s.stopped.Swap(true) {
return fmt.Errorf("InterceptableSwitch stopped more than once")
}
close(s.quit) close(s.quit)
s.wg.Wait() s.wg.Wait()
s.blockEpochStream.Cancel() s.blockEpochStream.Cancel()
log.Debug("InterceptableSwitch shutdown complete")
return nil return nil
} }

View File

@@ -101,6 +101,9 @@ func (r *htlcReleaseEvent) Less(other queue.PriorityQueueItem) bool {
// created by the daemon. The registry is a thin wrapper around a map in order // created by the daemon. The registry is a thin wrapper around a map in order
// to ensure that all updates/reads are thread safe. // to ensure that all updates/reads are thread safe.
type InvoiceRegistry struct { type InvoiceRegistry struct {
started atomic.Bool
stopped atomic.Bool
sync.RWMutex sync.RWMutex
nextClientID uint32 // must be used atomically nextClientID uint32 // must be used atomically
@@ -213,33 +216,48 @@ func (i *InvoiceRegistry) scanInvoicesOnStart(ctx context.Context) error {
// Start starts the registry and all goroutines it needs to carry out its task. // Start starts the registry and all goroutines it needs to carry out its task.
func (i *InvoiceRegistry) Start() error { func (i *InvoiceRegistry) Start() error {
// Start InvoiceExpiryWatcher and prepopulate it with existing active var err error
// invoices.
err := i.expiryWatcher.Start(func(hash lntypes.Hash, force bool) error { log.Info("InvoiceRegistry starting...")
return i.cancelInvoiceImpl(context.Background(), hash, force)
}) if i.started.Swap(true) {
return fmt.Errorf("InvoiceRegistry started more than once")
}
// Start InvoiceExpiryWatcher and prepopulate it with existing
// active invoices.
err = i.expiryWatcher.Start(
func(hash lntypes.Hash, force bool) error {
return i.cancelInvoiceImpl(
context.Background(), hash, force,
)
})
if err != nil { if err != nil {
return err return err
} }
log.Info("InvoiceRegistry starting")
i.wg.Add(1) i.wg.Add(1)
go i.invoiceEventLoop() go i.invoiceEventLoop()
// Now scan all pending and removable invoices to the expiry watcher or // Now scan all pending and removable invoices to the expiry
// delete them. // watcher or delete them.
err = i.scanInvoicesOnStart(context.Background()) err = i.scanInvoicesOnStart(context.Background())
if err != nil { if err != nil {
_ = i.Stop() _ = i.Stop()
return err
} }
return nil log.Debug("InvoiceRegistry started")
return err
} }
// Stop signals the registry for a graceful shutdown. // Stop signals the registry for a graceful shutdown.
func (i *InvoiceRegistry) Stop() error { func (i *InvoiceRegistry) Stop() error {
log.Info("InvoiceRegistry shutting down...")
if i.stopped.Swap(true) {
return fmt.Errorf("InvoiceRegistry stopped more than once")
}
log.Info("InvoiceRegistry shutting down...") log.Info("InvoiceRegistry shutting down...")
defer log.Debug("InvoiceRegistry shutdown complete") defer log.Debug("InvoiceRegistry shutdown complete")
@@ -248,6 +266,9 @@ func (i *InvoiceRegistry) Stop() error {
close(i.quit) close(i.quit)
i.wg.Wait() i.wg.Wait()
log.Debug("InvoiceRegistry shutdown complete")
return nil return nil
} }

View File

@@ -261,6 +261,9 @@ type TxPublisherConfig struct {
// until the tx is confirmed or the fee rate reaches the maximum fee rate // until the tx is confirmed or the fee rate reaches the maximum fee rate
// specified by the caller. // specified by the caller.
type TxPublisher struct { type TxPublisher struct {
started atomic.Bool
stopped atomic.Bool
wg sync.WaitGroup wg sync.WaitGroup
// cfg specifies the configuration of the TxPublisher. // cfg specifies the configuration of the TxPublisher.
@@ -666,7 +669,10 @@ type monitorRecord struct {
// off the monitor loop. // off the monitor loop.
func (t *TxPublisher) Start() error { func (t *TxPublisher) Start() error {
log.Info("TxPublisher starting...") log.Info("TxPublisher starting...")
defer log.Debugf("TxPublisher started")
if t.started.Swap(true) {
return fmt.Errorf("TxPublisher started more than once")
}
blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil) blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil)
if err != nil { if err != nil {
@@ -676,17 +682,26 @@ func (t *TxPublisher) Start() error {
t.wg.Add(1) t.wg.Add(1)
go t.monitor(blockEvent) go t.monitor(blockEvent)
log.Debugf("TxPublisher started")
return nil return nil
} }
// Stop stops the publisher and waits for the monitor loop to exit. // Stop stops the publisher and waits for the monitor loop to exit.
func (t *TxPublisher) Stop() { func (t *TxPublisher) Stop() error {
log.Info("TxPublisher stopping...") log.Info("TxPublisher stopping...")
defer log.Debugf("TxPublisher stopped")
if t.stopped.Swap(true) {
return fmt.Errorf("TxPublisher stopped more than once")
}
close(t.quit) close(t.quit)
t.wg.Wait() t.wg.Wait()
log.Debug("TxPublisher stopped")
return nil
} }
// monitor is the main loop driven by new blocks. Whevenr a new block arrives, // monitor is the main loop driven by new blocks. Whevenr a new block arrives,