discovery: thread contexts through sync manager

Here, we remove one context.TODO() by threading a context through to the
SyncManager.
This commit is contained in:
Elle Mouton
2025-04-07 10:22:24 +02:00
parent 5193a9f82c
commit 3101f2a66e
3 changed files with 23 additions and 16 deletions

View File

@ -678,7 +678,7 @@ func (d *AuthenticatedGossiper) start(ctx context.Context) error {
return err return err
} }
d.syncMgr.Start() d.syncMgr.Start(ctx)
d.banman.start() d.banman.start()

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
@ -200,8 +201,9 @@ type SyncManager struct {
// number of queries. // number of queries.
rateLimiter *rate.Limiter rateLimiter *rate.Limiter
wg sync.WaitGroup wg sync.WaitGroup
quit chan struct{} quit chan struct{}
cancel fn.Option[context.CancelFunc]
} }
// newSyncManager constructs a new SyncManager backed by the given config. // newSyncManager constructs a new SyncManager backed by the given config.
@ -246,10 +248,13 @@ func newSyncManager(cfg *SyncManagerCfg) *SyncManager {
} }
// Start starts the SyncManager in order to properly carry out its duties. // Start starts the SyncManager in order to properly carry out its duties.
func (m *SyncManager) Start() { func (m *SyncManager) Start(ctx context.Context) {
m.start.Do(func() { m.start.Do(func() {
ctx, cancel := context.WithCancel(ctx)
m.cancel = fn.Some(cancel)
m.wg.Add(1) m.wg.Add(1)
go m.syncerHandler() go m.syncerHandler(ctx)
}) })
} }
@ -259,6 +264,7 @@ func (m *SyncManager) Stop() {
log.Debugf("SyncManager is stopping") log.Debugf("SyncManager is stopping")
defer log.Debugf("SyncManager stopped") defer log.Debugf("SyncManager stopped")
m.cancel.WhenSome(func(fn context.CancelFunc) { fn() })
close(m.quit) close(m.quit)
m.wg.Wait() m.wg.Wait()
@ -282,7 +288,7 @@ func (m *SyncManager) Stop() {
// much of the public network as possible. // much of the public network as possible.
// //
// NOTE: This must be run as a goroutine. // NOTE: This must be run as a goroutine.
func (m *SyncManager) syncerHandler() { func (m *SyncManager) syncerHandler(ctx context.Context) {
defer m.wg.Done() defer m.wg.Done()
m.cfg.RotateTicker.Resume() m.cfg.RotateTicker.Resume()
@ -380,7 +386,7 @@ func (m *SyncManager) syncerHandler() {
} }
m.syncersMu.Unlock() m.syncersMu.Unlock()
s.Start(context.TODO()) s.Start(ctx)
// Once we create the GossipSyncer, we'll signal to the // Once we create the GossipSyncer, we'll signal to the
// caller that they can proceed since the SyncManager's // caller that they can proceed since the SyncManager's

View File

@ -2,6 +2,7 @@ package discovery
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"reflect" "reflect"
@ -82,7 +83,7 @@ func TestSyncManagerNumActiveSyncers(t *testing.T) {
} }
syncMgr := newPinnedTestSyncManager(numActiveSyncers, pinnedSyncers) syncMgr := newPinnedTestSyncManager(numActiveSyncers, pinnedSyncers)
syncMgr.Start() syncMgr.Start(context.Background())
defer syncMgr.Stop() defer syncMgr.Stop()
// First we'll start by adding the pinned syncers. These should // First we'll start by adding the pinned syncers. These should
@ -134,7 +135,7 @@ func TestSyncManagerNewActiveSyncerAfterDisconnect(t *testing.T) {
// We'll create our test sync manager to have two active syncers. // We'll create our test sync manager to have two active syncers.
syncMgr := newTestSyncManager(2) syncMgr := newTestSyncManager(2)
syncMgr.Start() syncMgr.Start(context.Background())
defer syncMgr.Stop() defer syncMgr.Stop()
// The first will be an active syncer that performs a historical sync // The first will be an active syncer that performs a historical sync
@ -187,7 +188,7 @@ func TestSyncManagerRotateActiveSyncerCandidate(t *testing.T) {
// We'll create our sync manager with three active syncers. // We'll create our sync manager with three active syncers.
syncMgr := newTestSyncManager(1) syncMgr := newTestSyncManager(1)
syncMgr.Start() syncMgr.Start(context.Background())
defer syncMgr.Stop() defer syncMgr.Stop()
// The first syncer registered always performs a historical sync. // The first syncer registered always performs a historical sync.
@ -235,7 +236,7 @@ func TestSyncManagerNoInitialHistoricalSync(t *testing.T) {
t.Parallel() t.Parallel()
syncMgr := newTestSyncManager(0) syncMgr := newTestSyncManager(0)
syncMgr.Start() syncMgr.Start(context.Background())
defer syncMgr.Stop() defer syncMgr.Stop()
// We should not expect any messages from the peer. // We should not expect any messages from the peer.
@ -269,7 +270,7 @@ func TestSyncManagerInitialHistoricalSync(t *testing.T) {
t.Fatal("expected graph to not be considered as synced") t.Fatal("expected graph to not be considered as synced")
} }
syncMgr.Start() syncMgr.Start(context.Background())
defer syncMgr.Stop() defer syncMgr.Stop()
// We should expect to see a QueryChannelRange message with a // We should expect to see a QueryChannelRange message with a
@ -338,7 +339,7 @@ func TestSyncManagerHistoricalSyncOnReconnect(t *testing.T) {
t.Parallel() t.Parallel()
syncMgr := newTestSyncManager(2) syncMgr := newTestSyncManager(2)
syncMgr.Start() syncMgr.Start(context.Background())
defer syncMgr.Stop() defer syncMgr.Stop()
// We should expect to see a QueryChannelRange message with a // We should expect to see a QueryChannelRange message with a
@ -372,7 +373,7 @@ func TestSyncManagerForceHistoricalSync(t *testing.T) {
t.Parallel() t.Parallel()
syncMgr := newTestSyncManager(1) syncMgr := newTestSyncManager(1)
syncMgr.Start() syncMgr.Start(context.Background())
defer syncMgr.Stop() defer syncMgr.Stop()
// We should expect to see a QueryChannelRange message with a // We should expect to see a QueryChannelRange message with a
@ -410,7 +411,7 @@ func TestSyncManagerGraphSyncedAfterHistoricalSyncReplacement(t *testing.T) {
t.Parallel() t.Parallel()
syncMgr := newTestSyncManager(1) syncMgr := newTestSyncManager(1)
syncMgr.Start() syncMgr.Start(context.Background())
defer syncMgr.Stop() defer syncMgr.Stop()
// We should expect to see a QueryChannelRange message with a // We should expect to see a QueryChannelRange message with a
@ -468,7 +469,7 @@ func TestSyncManagerWaitUntilInitialHistoricalSync(t *testing.T) {
// We'll start by creating our test sync manager which will hold up to // We'll start by creating our test sync manager which will hold up to
// 2 active syncers. // 2 active syncers.
syncMgr := newTestSyncManager(numActiveSyncers) syncMgr := newTestSyncManager(numActiveSyncers)
syncMgr.Start() syncMgr.Start(context.Background())
defer syncMgr.Stop() defer syncMgr.Stop()
// We'll go ahead and create our syncers. // We'll go ahead and create our syncers.