diff --git a/discovery/gossiper.go b/discovery/gossiper.go index c1f81df53..6232a92ce 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -678,7 +678,7 @@ func (d *AuthenticatedGossiper) start(ctx context.Context) error { return err } - d.syncMgr.Start() + d.syncMgr.Start(ctx) d.banman.start() diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go index a38065d7e..c825d27fd 100644 --- a/discovery/sync_manager.go +++ b/discovery/sync_manager.go @@ -8,6 +8,7 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -200,8 +201,9 @@ type SyncManager struct { // number of queries. rateLimiter *rate.Limiter - wg sync.WaitGroup - quit chan struct{} + wg sync.WaitGroup + quit chan struct{} + cancel fn.Option[context.CancelFunc] } // newSyncManager constructs a new SyncManager backed by the given config. @@ -246,10 +248,13 @@ func newSyncManager(cfg *SyncManagerCfg) *SyncManager { } // Start starts the SyncManager in order to properly carry out its duties. -func (m *SyncManager) Start() { +func (m *SyncManager) Start(ctx context.Context) { m.start.Do(func() { + ctx, cancel := context.WithCancel(ctx) + m.cancel = fn.Some(cancel) + m.wg.Add(1) - go m.syncerHandler() + go m.syncerHandler(ctx) }) } @@ -259,6 +264,7 @@ func (m *SyncManager) Stop() { log.Debugf("SyncManager is stopping") defer log.Debugf("SyncManager stopped") + m.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(m.quit) m.wg.Wait() @@ -282,7 +288,7 @@ func (m *SyncManager) Stop() { // much of the public network as possible. // // NOTE: This must be run as a goroutine. -func (m *SyncManager) syncerHandler() { +func (m *SyncManager) syncerHandler(ctx context.Context) { defer m.wg.Done() m.cfg.RotateTicker.Resume() @@ -380,7 +386,7 @@ func (m *SyncManager) syncerHandler() { } m.syncersMu.Unlock() - s.Start(context.TODO()) + s.Start(ctx) // Once we create the GossipSyncer, we'll signal to the // caller that they can proceed since the SyncManager's diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go index 4aff5b631..b8ef93197 100644 --- a/discovery/sync_manager_test.go +++ b/discovery/sync_manager_test.go @@ -2,6 +2,7 @@ package discovery import ( "bytes" + "context" "fmt" "io" "reflect" @@ -82,7 +83,7 @@ func TestSyncManagerNumActiveSyncers(t *testing.T) { } syncMgr := newPinnedTestSyncManager(numActiveSyncers, pinnedSyncers) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // First we'll start by adding the pinned syncers. These should @@ -134,7 +135,7 @@ func TestSyncManagerNewActiveSyncerAfterDisconnect(t *testing.T) { // We'll create our test sync manager to have two active syncers. syncMgr := newTestSyncManager(2) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // The first will be an active syncer that performs a historical sync @@ -187,7 +188,7 @@ func TestSyncManagerRotateActiveSyncerCandidate(t *testing.T) { // We'll create our sync manager with three active syncers. syncMgr := newTestSyncManager(1) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // The first syncer registered always performs a historical sync. @@ -235,7 +236,7 @@ func TestSyncManagerNoInitialHistoricalSync(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(0) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We should not expect any messages from the peer. @@ -269,7 +270,7 @@ func TestSyncManagerInitialHistoricalSync(t *testing.T) { t.Fatal("expected graph to not be considered as synced") } - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -338,7 +339,7 @@ func TestSyncManagerHistoricalSyncOnReconnect(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(2) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -372,7 +373,7 @@ func TestSyncManagerForceHistoricalSync(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(1) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -410,7 +411,7 @@ func TestSyncManagerGraphSyncedAfterHistoricalSyncReplacement(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(1) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -468,7 +469,7 @@ func TestSyncManagerWaitUntilInitialHistoricalSync(t *testing.T) { // We'll start by creating our test sync manager which will hold up to // 2 active syncers. syncMgr := newTestSyncManager(numActiveSyncers) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We'll go ahead and create our syncers.