From 2192bf41551cca5cb7951749fde014ae4e7d91d4 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 13 Nov 2024 10:12:27 +0200 Subject: [PATCH] lnd+chanbackup: thread contexts through Remove four context.TODO()s --- chanbackup/backup.go | 13 +++-- chanbackup/backup_test.go | 10 ++-- chanbackup/pubsub.go | 11 +++-- chanbackup/pubsub_test.go | 16 +++++-- chanbackup/recover.go | 16 ++++--- chanbackup/recover_test.go | 22 +++++---- channel_notifier.go | 5 +- chanrestore.go | 9 +++- discovery/bootstrapper.go | 5 +- lnd.go | 6 +-- pilot.go | 6 ++- rpcserver.go | 27 ++++++----- server.go | 98 +++++++++++++++++++++----------------- 13 files changed, 138 insertions(+), 106 deletions(-) diff --git a/chanbackup/backup.go b/chanbackup/backup.go index 8248e547c..8f5318513 100644 --- a/chanbackup/backup.go +++ b/chanbackup/backup.go @@ -25,11 +25,9 @@ type LiveChannelSource interface { // passed open channel. The backup includes all information required to restore // the channel, as well as addressing information so we can find the peer and // reconnect to them to initiate the protocol. -func assembleChanBackup(addrSource channeldb.AddrSource, +func assembleChanBackup(ctx context.Context, addrSource channeldb.AddrSource, openChan *channeldb.OpenChannel) (*Single, error) { - ctx := context.TODO() - log.Debugf("Crafting backup for ChannelPoint(%v)", openChan.FundingOutpoint) @@ -95,7 +93,8 @@ func buildCloseTxInputs( // FetchBackupForChan attempts to create a plaintext static channel backup for // the target channel identified by its channel point. If we're unable to find // the target channel, then an error will be returned. -func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource, +func FetchBackupForChan(ctx context.Context, chanPoint wire.OutPoint, + chanSource LiveChannelSource, addrSource channeldb.AddrSource) (*Single, error) { // First, we'll query the channel source to see if the channel is known @@ -109,7 +108,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource, // Once we have the target channel, we can assemble the backup using // the source to obtain any extra information that we may need. - staticChanBackup, err := assembleChanBackup(addrSource, targetChan) + staticChanBackup, err := assembleChanBackup(ctx, addrSource, targetChan) if err != nil { return nil, fmt.Errorf("unable to create chan backup: %w", err) } @@ -119,7 +118,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource, // FetchStaticChanBackups will return a plaintext static channel back up for // all known active/open channels within the passed channel source. -func FetchStaticChanBackups(chanSource LiveChannelSource, +func FetchStaticChanBackups(ctx context.Context, chanSource LiveChannelSource, addrSource channeldb.AddrSource) ([]Single, error) { // First, we'll query the backup source for information concerning all @@ -134,7 +133,7 @@ func FetchStaticChanBackups(chanSource LiveChannelSource, // channel. staticChanBackups := make([]Single, 0, len(openChans)) for _, openChan := range openChans { - chanBackup, err := assembleChanBackup(addrSource, openChan) + chanBackup, err := assembleChanBackup(ctx, addrSource, openChan) if err != nil { return nil, err } diff --git a/chanbackup/backup_test.go b/chanbackup/backup_test.go index c0d80c2ef..2da50b7e6 100644 --- a/chanbackup/backup_test.go +++ b/chanbackup/backup_test.go @@ -82,6 +82,7 @@ func (m *mockChannelSource) AddrsForNode(_ context.Context, // can find addresses for and otherwise. func TestFetchBackupForChan(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll make two channels, only one of them will have all the // information we need to construct set of backups for them. @@ -121,7 +122,7 @@ func TestFetchBackupForChan(t *testing.T) { } for i, testCase := range testCases { _, err := FetchBackupForChan( - testCase.chanPoint, chanSource, chanSource, + ctx, testCase.chanPoint, chanSource, chanSource, ) switch { // If this is a valid test case, and we failed, then we'll @@ -142,6 +143,7 @@ func TestFetchBackupForChan(t *testing.T) { // channel source for all channels and construct a Single for each channel. func TestFetchStaticChanBackups(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll make the set of channels that we want to seed the // channel source with. Both channels will be fully populated in the @@ -161,7 +163,7 @@ func TestFetchStaticChanBackups(t *testing.T) { // With the channel source populated, we'll now attempt to create a set // of backups for all the channels. This should succeed, as all items // are populated within the channel source. - backups, err := FetchStaticChanBackups(chanSource, chanSource) + backups, err := FetchStaticChanBackups(ctx, chanSource, chanSource) require.NoError(t, err, "unable to create chan back ups") if len(backups) != numChans { @@ -176,7 +178,7 @@ func TestFetchStaticChanBackups(t *testing.T) { copy(n[:], randomChan2.IdentityPub.SerializeCompressed()) delete(chanSource.addrs, n) - _, err = FetchStaticChanBackups(chanSource, chanSource) + _, err = FetchStaticChanBackups(ctx, chanSource, chanSource) if err == nil { t.Fatalf("query with incomplete information should fail") } @@ -185,7 +187,7 @@ func TestFetchStaticChanBackups(t *testing.T) { // source at all, then we'll fail as well. chanSource = newMockChannelSource() chanSource.failQuery = true - _, err = FetchStaticChanBackups(chanSource, chanSource) + _, err = FetchStaticChanBackups(ctx, chanSource, chanSource) if err == nil { t.Fatalf("query should fail") } diff --git a/chanbackup/pubsub.go b/chanbackup/pubsub.go index 8fa1d5f34..2b5872f18 100644 --- a/chanbackup/pubsub.go +++ b/chanbackup/pubsub.go @@ -2,6 +2,7 @@ package chanbackup import ( "bytes" + "context" "fmt" "net" "os" @@ -81,7 +82,8 @@ type ChannelNotifier interface { // synchronization point to ensure that the chanbackup.SubSwapper does // not miss any channel open or close events in the period between when // it's created, and when it requests the channel subscription. - SubscribeChans(map[wire.OutPoint]struct{}) (*ChannelSubscription, error) + SubscribeChans(context.Context, map[wire.OutPoint]struct{}) ( + *ChannelSubscription, error) } // SubSwapper subscribes to new updates to the open channel state, and then @@ -119,8 +121,9 @@ type SubSwapper struct { // set of channels, and the required interfaces to be notified of new channel // updates, pack a multi backup, and swap the current best backup from its // storage location. -func NewSubSwapper(startingChans []Single, chanNotifier ChannelNotifier, - keyRing keychain.KeyRing, backupSwapper Swapper) (*SubSwapper, error) { +func NewSubSwapper(ctx context.Context, startingChans []Single, + chanNotifier ChannelNotifier, keyRing keychain.KeyRing, + backupSwapper Swapper) (*SubSwapper, error) { // First, we'll subscribe to the latest set of channel updates given // the set of channels we already know of. @@ -128,7 +131,7 @@ func NewSubSwapper(startingChans []Single, chanNotifier ChannelNotifier, for _, chanBackup := range startingChans { knownChans[chanBackup.FundingOutpoint] = struct{}{} } - chanEvents, err := chanNotifier.SubscribeChans(knownChans) + chanEvents, err := chanNotifier.SubscribeChans(ctx, knownChans) if err != nil { return nil, err } diff --git a/chanbackup/pubsub_test.go b/chanbackup/pubsub_test.go index 32694e5a7..c134b91fc 100644 --- a/chanbackup/pubsub_test.go +++ b/chanbackup/pubsub_test.go @@ -1,6 +1,7 @@ package chanbackup import ( + "context" "fmt" "testing" "time" @@ -62,8 +63,8 @@ func newMockChannelNotifier() *mockChannelNotifier { } } -func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) ( - *ChannelSubscription, error) { +func (m *mockChannelNotifier) SubscribeChans(_ context.Context, + _ map[wire.OutPoint]struct{}) (*ChannelSubscription, error) { if m.fail { return nil, fmt.Errorf("fail") @@ -80,6 +81,7 @@ func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) ( // channel subscription, then the entire sub-swapper will fail to start. func TestNewSubSwapperSubscribeFail(t *testing.T) { t.Parallel() + ctx := context.Background() keyRing := &lnencrypt.MockKeyRing{} @@ -88,7 +90,7 @@ func TestNewSubSwapperSubscribeFail(t *testing.T) { fail: true, } - _, err := NewSubSwapper(nil, &chanNotifier, keyRing, &swapper) + _, err := NewSubSwapper(ctx, nil, &chanNotifier, keyRing, &swapper) if err == nil { t.Fatalf("expected fail due to lack of subscription") } @@ -152,13 +154,16 @@ func assertExpectedBackupSwap(t *testing.T, swapper *mockSwapper, // multiple time is permitted. func TestSubSwapperIdempotentStartStop(t *testing.T) { t.Parallel() + ctx := context.Background() keyRing := &lnencrypt.MockKeyRing{} var chanNotifier mockChannelNotifier swapper := newMockSwapper(keyRing) - subSwapper, err := NewSubSwapper(nil, &chanNotifier, keyRing, swapper) + subSwapper, err := NewSubSwapper( + ctx, nil, &chanNotifier, keyRing, swapper, + ) require.NoError(t, err, "unable to init subSwapper") if err := subSwapper.Start(); err != nil { @@ -181,6 +186,7 @@ func TestSubSwapperIdempotentStartStop(t *testing.T) { // the master multi file backup. func TestSubSwapperUpdater(t *testing.T) { t.Parallel() + ctx := context.Background() keyRing := &lnencrypt.MockKeyRing{} chanNotifier := newMockChannelNotifier() @@ -224,7 +230,7 @@ func TestSubSwapperUpdater(t *testing.T) { // With our channel set created, we'll make a fresh sub swapper // instance to begin our test. subSwapper, err := NewSubSwapper( - initialChanSet, chanNotifier, keyRing, swapper, + ctx, initialChanSet, chanNotifier, keyRing, swapper, ) require.NoError(t, err, "unable to make swapper") if err := subSwapper.Start(); err != nil { diff --git a/chanbackup/recover.go b/chanbackup/recover.go index 033bd695f..daaad6248 100644 --- a/chanbackup/recover.go +++ b/chanbackup/recover.go @@ -1,6 +1,7 @@ package chanbackup import ( + "context" "net" "github.com/btcsuite/btcd/btcec/v2" @@ -29,7 +30,8 @@ type PeerConnector interface { // available addresses. Once this method returns with a non-nil error, // the connector should attempt to persistently connect to the target // peer in the background as a persistent attempt. - ConnectPeer(node *btcec.PublicKey, addrs []net.Addr) error + ConnectPeer(ctx context.Context, node *btcec.PublicKey, + addrs []net.Addr) error } // Recover attempts to recover the static channel state from a set of static @@ -41,7 +43,7 @@ type PeerConnector interface { // well, in order to expose the addressing information required to locate to // and connect to each peer in order to initiate the recovery protocol. // The number of channels that were successfully restored is returned. -func Recover(backups []Single, restorer ChannelRestorer, +func Recover(ctx context.Context, backups []Single, restorer ChannelRestorer, peerConnector PeerConnector) (int, error) { var numRestored int @@ -70,7 +72,7 @@ func Recover(backups []Single, restorer ChannelRestorer, backup.FundingOutpoint) err = peerConnector.ConnectPeer( - backup.RemoteNodePub, backup.Addresses, + ctx, backup.RemoteNodePub, backup.Addresses, ) if err != nil { return numRestored, err @@ -95,7 +97,7 @@ func Recover(backups []Single, restorer ChannelRestorer, // established, then the PeerConnector will continue to attempt to re-establish // a persistent connection in the background. The number of channels that were // successfully restored is returned. -func UnpackAndRecoverSingles(singles PackedSingles, +func UnpackAndRecoverSingles(ctx context.Context, singles PackedSingles, keyChain keychain.KeyRing, restorer ChannelRestorer, peerConnector PeerConnector) (int, error) { @@ -104,7 +106,7 @@ func UnpackAndRecoverSingles(singles PackedSingles, return 0, err } - return Recover(chanBackups, restorer, peerConnector) + return Recover(ctx, chanBackups, restorer, peerConnector) } // UnpackAndRecoverMulti is a one-shot method, that given a set of packed @@ -114,7 +116,7 @@ func UnpackAndRecoverSingles(singles PackedSingles, // established, then the PeerConnector will continue to attempt to re-establish // a persistent connection in the background. The number of channels that were // successfully restored is returned. -func UnpackAndRecoverMulti(packedMulti PackedMulti, +func UnpackAndRecoverMulti(ctx context.Context, packedMulti PackedMulti, keyChain keychain.KeyRing, restorer ChannelRestorer, peerConnector PeerConnector) (int, error) { @@ -123,5 +125,5 @@ func UnpackAndRecoverMulti(packedMulti PackedMulti, return 0, err } - return Recover(chanBackups.StaticBackups, restorer, peerConnector) + return Recover(ctx, chanBackups.StaticBackups, restorer, peerConnector) } diff --git a/chanbackup/recover_test.go b/chanbackup/recover_test.go index c8719cb3f..c3a7f45e1 100644 --- a/chanbackup/recover_test.go +++ b/chanbackup/recover_test.go @@ -2,6 +2,7 @@ package chanbackup import ( "bytes" + "context" "errors" "net" "testing" @@ -39,7 +40,7 @@ type mockPeerConnector struct { callCount int } -func (m *mockPeerConnector) ConnectPeer(_ *btcec.PublicKey, +func (m *mockPeerConnector) ConnectPeer(_ context.Context, _ *btcec.PublicKey, _ []net.Addr) error { if m.fail { @@ -55,6 +56,7 @@ func (m *mockPeerConnector) ConnectPeer(_ *btcec.PublicKey, // recover a set of packed singles. func TestUnpackAndRecoverSingles(t *testing.T) { t.Parallel() + ctx := context.Background() keyRing := &lnencrypt.MockKeyRing{} @@ -87,7 +89,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) { // as well chanRestorer.fail = true _, err := UnpackAndRecoverSingles( - packedBackups, keyRing, &chanRestorer, &peerConnector, + ctx, packedBackups, keyRing, &chanRestorer, &peerConnector, ) require.ErrorIs(t, err, errRestoreFail) @@ -97,7 +99,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) { // well peerConnector.fail = true _, err = UnpackAndRecoverSingles( - packedBackups, keyRing, &chanRestorer, &peerConnector, + ctx, packedBackups, keyRing, &chanRestorer, &peerConnector, ) require.ErrorIs(t, err, errConnectFail) @@ -107,7 +109,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) { // Next, we'll ensure that if all the interfaces function as expected, // then the channels will properly be unpacked and restored. numRestored, err := UnpackAndRecoverSingles( - packedBackups, keyRing, &chanRestorer, &peerConnector, + ctx, packedBackups, keyRing, &chanRestorer, &peerConnector, ) require.NoError(t, err) require.EqualValues(t, numSingles, numRestored) @@ -124,7 +126,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) { // If we modify the keyRing, then unpacking should fail. keyRing.Fail = true _, err = UnpackAndRecoverSingles( - packedBackups, keyRing, &chanRestorer, &peerConnector, + ctx, packedBackups, keyRing, &chanRestorer, &peerConnector, ) require.ErrorContains(t, err, "fail") @@ -135,7 +137,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) { // recover a packed multi. func TestUnpackAndRecoverMulti(t *testing.T) { t.Parallel() - + ctx := context.Background() keyRing := &lnencrypt.MockKeyRing{} // First, we'll create a number of single chan backups that we'll @@ -171,7 +173,7 @@ func TestUnpackAndRecoverMulti(t *testing.T) { // as well chanRestorer.fail = true _, err = UnpackAndRecoverMulti( - packedMulti, keyRing, &chanRestorer, &peerConnector, + ctx, packedMulti, keyRing, &chanRestorer, &peerConnector, ) require.ErrorIs(t, err, errRestoreFail) @@ -181,7 +183,7 @@ func TestUnpackAndRecoverMulti(t *testing.T) { // well peerConnector.fail = true _, err = UnpackAndRecoverMulti( - packedMulti, keyRing, &chanRestorer, &peerConnector, + ctx, packedMulti, keyRing, &chanRestorer, &peerConnector, ) require.ErrorIs(t, err, errConnectFail) @@ -191,7 +193,7 @@ func TestUnpackAndRecoverMulti(t *testing.T) { // Next, we'll ensure that if all the interfaces function as expected, // then the channels will properly be unpacked and restored. numRestored, err := UnpackAndRecoverMulti( - packedMulti, keyRing, &chanRestorer, &peerConnector, + ctx, packedMulti, keyRing, &chanRestorer, &peerConnector, ) require.NoError(t, err) require.EqualValues(t, numSingles, numRestored) @@ -208,7 +210,7 @@ func TestUnpackAndRecoverMulti(t *testing.T) { // If we modify the keyRing, then unpacking should fail. keyRing.Fail = true _, err = UnpackAndRecoverMulti( - packedMulti, keyRing, &chanRestorer, &peerConnector, + ctx, packedMulti, keyRing, &chanRestorer, &peerConnector, ) require.ErrorContains(t, err, "fail") diff --git a/channel_notifier.go b/channel_notifier.go index 62d02e5ac..8affd48f0 100644 --- a/channel_notifier.go +++ b/channel_notifier.go @@ -32,11 +32,10 @@ type channelNotifier struct { // the channel subscription. // // NOTE: This is part of the chanbackup.ChannelNotifier interface. -func (c *channelNotifier) SubscribeChans(startingChans map[wire.OutPoint]struct{}) ( +func (c *channelNotifier) SubscribeChans(ctx context.Context, + startingChans map[wire.OutPoint]struct{}) ( *chanbackup.ChannelSubscription, error) { - ctx := context.TODO() - ltndLog.Infof("Channel backup proxy channel notifier starting") // TODO(roasbeef): read existing set of chans and diff diff --git a/chanrestore.go b/chanrestore.go index 5b221c105..6daf3922c 100644 --- a/chanrestore.go +++ b/chanrestore.go @@ -1,6 +1,7 @@ package lnd import ( + "context" "fmt" "math" "net" @@ -309,7 +310,9 @@ var _ chanbackup.ChannelRestorer = (*chanDBRestorer)(nil) // as a persistent attempt. // // NOTE: Part of the chanbackup.PeerConnector interface. -func (s *server) ConnectPeer(nodePub *btcec.PublicKey, addrs []net.Addr) error { +func (s *server) ConnectPeer(ctx context.Context, nodePub *btcec.PublicKey, + addrs []net.Addr) error { + // Before we connect to the remote peer, we'll remove any connections // to ensure the new connection is created after this new link/channel // is known. @@ -333,7 +336,9 @@ func (s *server) ConnectPeer(nodePub *btcec.PublicKey, addrs []net.Addr) error { // Attempt to connect to the peer using this full address. If // we're unable to connect to them, then we'll try the next // address in place of it. - err := s.ConnectToPeer(netAddr, true, s.cfg.ConnectionTimeout) + err := s.ConnectToPeer( + ctx, netAddr, true, s.cfg.ConnectionTimeout, + ) // If we're already connected to this peer, then we don't // consider this an error, so we'll exit here. diff --git a/discovery/bootstrapper.go b/discovery/bootstrapper.go index 2e50798db..8eea900a9 100644 --- a/discovery/bootstrapper.go +++ b/discovery/bootstrapper.go @@ -52,11 +52,10 @@ type NetworkPeerBootstrapper interface { // bootstrapper will be queried successively until the target amount is met. If // the ignore map is populated, then the bootstrappers will be instructed to // skip those nodes. -func MultiSourceBootstrap(ignore map[autopilot.NodeID]struct{}, numAddrs uint32, +func MultiSourceBootstrap(ctx context.Context, + ignore map[autopilot.NodeID]struct{}, numAddrs uint32, bootstrappers ...NetworkPeerBootstrapper) ([]*lnwire.NetAddress, error) { - ctx := context.TODO() - // We'll randomly shuffle our bootstrappers before querying them in // order to avoid from querying the same bootstrapper method over and // over, as some of these might tend to provide better/worse results diff --git a/lnd.go b/lnd.go index 07cf463e4..55d5b6386 100644 --- a/lnd.go +++ b/lnd.go @@ -603,7 +603,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // Set up the core server which will listen for incoming peer // connections. server, err := newServer( - cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, + ctx, cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, activeChainControl.Cfg.WalletUnlockParams.ChansToRestore, multiAcceptor, torController, tlsManager, leaderElector, graphSource, implCfg, @@ -616,7 +616,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // used to manage the underlying autopilot agent, starting and stopping // it at will. atplCfg, err := initAutoPilot( - server, cfg.Autopilot, activeChainControl.MinHtlcIn, + ctx, server, cfg.Autopilot, activeChainControl.MinHtlcIn, cfg.ActiveNetParams, ) if err != nil { @@ -736,7 +736,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // case the startup of the subservers do not behave as expected. errChan := make(chan error) go func() { - errChan <- server.Start() + errChan <- server.Start(ctx) }() defer func() { diff --git a/pilot.go b/pilot.go index 11333a072..f91dd21a3 100644 --- a/pilot.go +++ b/pilot.go @@ -1,6 +1,7 @@ package lnd import ( + "context" "errors" "fmt" "net" @@ -136,7 +137,7 @@ var _ autopilot.ChannelController = (*chanController)(nil) // Agent instance based on the passed configuration structs. The agent and all // interfaces needed to drive it won't be launched before the Manager's // StartAgent method is called. -func initAutoPilot(svr *server, cfg *lncfg.AutoPilot, +func initAutoPilot(ctx context.Context, svr *server, cfg *lncfg.AutoPilot, minHTLCIn lnwire.MilliSatoshi, netParams chainreg.BitcoinNetParams) ( *autopilot.ManagerCfg, error) { @@ -224,7 +225,8 @@ func initAutoPilot(svr *server, cfg *lncfg.AutoPilot, } err := svr.ConnectToPeer( - lnAddr, false, svr.cfg.ConnectionTimeout, + ctx, lnAddr, false, + svr.cfg.ConnectionTimeout, ) if err != nil { // If we weren't able to connect to the diff --git a/rpcserver.go b/rpcserver.go index e3ea23f85..ab7d78d17 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -1822,9 +1822,8 @@ func (r *rpcServer) ConnectPeer(ctx context.Context, timeout) } - if err := r.server.ConnectToPeer( - peerAddr, in.Perm, timeout, - ); err != nil { + err = r.server.ConnectToPeer(ctx, peerAddr, in.Perm, timeout) + if err != nil { rpcsLog.Errorf("[connectpeer]: error connecting to peer: %v", err) return nil, err @@ -4539,7 +4538,7 @@ func (r *rpcServer) ListChannels(ctx context.Context, // our list depending on the type of channels requested to us. isActive := peerOnline && linkActive channel, err := createRPCOpenChannel( - r, dbChannel, isActive, in.PeerAliasLookup, + ctx, r, dbChannel, isActive, in.PeerAliasLookup, ) if err != nil { return nil, err @@ -4655,10 +4654,10 @@ func encodeCustomChanData(lnChan *channeldb.OpenChannel) ([]byte, error) { } // createRPCOpenChannel creates an *lnrpc.Channel from the *channeldb.Channel. -func createRPCOpenChannel(r *rpcServer, dbChannel *channeldb.OpenChannel, - isActive, peerAliasLookup bool) (*lnrpc.Channel, error) { +func createRPCOpenChannel(ctx context.Context, r *rpcServer, + dbChannel *channeldb.OpenChannel, isActive, peerAliasLookup bool) ( + *lnrpc.Channel, error) { - ctx := context.TODO() nodePub := dbChannel.IdentityPub nodeID := hex.EncodeToString(nodePub.SerializeCompressed()) chanPoint := dbChannel.FundingOutpoint @@ -5174,7 +5173,8 @@ func (r *rpcServer) SubscribeChannelEvents(req *lnrpc.ChannelEventSubscription, } case channelnotifier.OpenChannelEvent: channel, err := createRPCOpenChannel( - r, event.Channel, true, false, + updateStream.Context(), r, + event.Channel, true, false, ) if err != nil { return err @@ -7852,7 +7852,7 @@ func (r *rpcServer) ExportChannelBackup(ctx context.Context, // the database. If this channel has been closed, or the outpoint is // unknown, then we'll return an error unpackedBackup, err := chanbackup.FetchBackupForChan( - chanPoint, r.server.chanStateDB, r.server.addrSource, + ctx, chanPoint, r.server.chanStateDB, r.server.addrSource, ) if err != nil { return nil, err @@ -8032,7 +8032,7 @@ func (r *rpcServer) ExportAllChannelBackups(ctx context.Context, // First, we'll attempt to read back ups for ALL currently opened // channels from disk. allUnpackedBackups, err := chanbackup.FetchStaticChanBackups( - r.server.chanStateDB, r.server.addrSource, + ctx, r.server.chanStateDB, r.server.addrSource, ) if err != nil { return nil, fmt.Errorf("unable to fetch all static chan "+ @@ -8090,7 +8090,7 @@ func (r *rpcServer) RestoreChannelBackups(ctx context.Context, // out to any peers that we know of which were our prior // channel peers. numRestored, err = chanbackup.UnpackAndRecoverSingles( - chanbackup.PackedSingles(packedBackups), + ctx, chanbackup.PackedSingles(packedBackups), r.server.cc.KeyRing, chanRestorer, r.server, ) if err != nil { @@ -8107,7 +8107,7 @@ func (r *rpcServer) RestoreChannelBackups(ctx context.Context, // channel peers. packedMulti := chanbackup.PackedMulti(packedMultiBackup) numRestored, err = chanbackup.UnpackAndRecoverMulti( - packedMulti, r.server.cc.KeyRing, chanRestorer, + ctx, packedMulti, r.server.cc.KeyRing, chanRestorer, r.server, ) if err != nil { @@ -8167,7 +8167,8 @@ func (r *rpcServer) SubscribeChannelBackups(req *lnrpc.ChannelBackupSubscription // we'll obtains the current set of single channel // backups from disk. chanBackups, err := chanbackup.FetchStaticChanBackups( - r.server.chanStateDB, r.server.addrSource, + updateStream.Context(), r.server.chanStateDB, + r.server.addrSource, ) if err != nil { return fmt.Errorf("unable to fetch all "+ diff --git a/server.go b/server.go index a6416b4ce..431a3ee24 100644 --- a/server.go +++ b/server.go @@ -506,7 +506,7 @@ func noiseDial(idKey keychain.SingleKeyECDH, // newServer creates a new instance of the server which is to listen using the // passed listener address. -func newServer(cfg *Config, listenAddrs []net.Addr, +func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, dbs *DatabaseInstances, cc *chainreg.ChainControl, nodeKeyDesc *keychain.KeyDescriptor, chansToRestore walletunlocker.ChannelsToRecover, @@ -1637,13 +1637,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath) startingChans, err := chanbackup.FetchStaticChanBackups( - s.chanStateDB, s.addrSource, + ctx, s.chanStateDB, s.addrSource, ) if err != nil { return nil, err } s.chanSubSwapper, err = chanbackup.NewSubSwapper( - startingChans, chanNotifier, s.cc.KeyRing, backupFile, + ctx, startingChans, chanNotifier, s.cc.KeyRing, backupFile, ) if err != nil { return nil, err @@ -1805,14 +1805,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // maintaining persistent outbound connections and also accepting new // incoming connections cmgr, err := connmgr.New(&connmgr.Config{ - Listeners: listeners, - OnAccept: s.InboundPeerConnected, + Listeners: listeners, + OnAccept: func(conn net.Conn) { + s.InboundPeerConnected(ctx, conn) + }, RetryDuration: time.Second * 5, TargetOutbound: 100, Dial: noiseDial( nodeKeyECDH, s.cfg.net, s.cfg.ConnectionTimeout, ), - OnConnection: s.OutboundPeerConnected, + OnConnection: func(req *connmgr.ConnReq, conn net.Conn) { + s.OutboundPeerConnected(ctx, req, conn) + }, }) if err != nil { return nil, err @@ -2078,7 +2082,7 @@ func (c cleaner) run() { // NOTE: This function is safe for concurrent access. // //nolint:funlen -func (s *server) Start() error { +func (s *server) Start(ctx context.Context) error { var startErr error // If one sub system fails to start, the following code ensures that the @@ -2289,7 +2293,7 @@ func (s *server) Start() error { } if len(s.chansToRestore.PackedSingleChanBackups) != 0 { _, err := chanbackup.UnpackAndRecoverSingles( - s.chansToRestore.PackedSingleChanBackups, + ctx, s.chansToRestore.PackedSingleChanBackups, s.cc.KeyRing, chanRestorer, s, ) if err != nil { @@ -2300,7 +2304,7 @@ func (s *server) Start() error { } if len(s.chansToRestore.PackedMultiChanBackup) != 0 { _, err := chanbackup.UnpackAndRecoverMulti( - s.chansToRestore.PackedMultiChanBackup, + ctx, s.chansToRestore.PackedMultiChanBackup, s.cc.KeyRing, chanRestorer, s, ) if err != nil { @@ -2365,8 +2369,7 @@ func (s *server) Start() error { } err = s.ConnectToPeer( - peerAddr, true, - s.cfg.ConnectionTimeout, + ctx, peerAddr, true, s.cfg.ConnectionTimeout, ) if err != nil { startErr = fmt.Errorf("unable to connect to "+ @@ -2453,14 +2456,16 @@ func (s *server) Start() error { // dedicated goroutine to maintain a set of persistent // connections. if shouldPeerBootstrap(s.cfg) { - bootstrappers, err := initNetworkBootstrappers(s) + bootstrappers, err := initNetworkBootstrappers(ctx, s) if err != nil { startErr = err return } s.wg.Add(1) - go s.peerBootstrapper(defaultMinPeers, bootstrappers) + go s.peerBootstrapper( + ctx, defaultMinPeers, bootstrappers, + ) } else { srvrLog.Infof("Auto peer bootstrapping is disabled") } @@ -2482,6 +2487,7 @@ func (s *server) Start() error { // NOTE: This function is safe for concurrent access. func (s *server) Stop() error { s.stop.Do(func() { + ctx := context.Background() atomic.StoreInt32(&s.stopping, 1) close(s.quit) @@ -2551,7 +2557,7 @@ func (s *server) Stop() error { // Update channel.backup file. Make sure to do it before // stopping chanSubSwapper. singles, err := chanbackup.FetchStaticChanBackups( - s.chanStateDB, s.addrSource, + ctx, s.chanStateDB, s.addrSource, ) if err != nil { srvrLog.Warnf("failed to fetch channel states: %v", @@ -2816,8 +2822,9 @@ out: // initNetworkBootstrappers initializes a set of network peer bootstrappers // based on the server, and currently active bootstrap mechanisms as defined // within the current configuration. -func initNetworkBootstrappers(s *server) ([]discovery.NetworkPeerBootstrapper, error) { - ctx := context.TODO() +func initNetworkBootstrappers(ctx context.Context, + s *server) ([]discovery.NetworkPeerBootstrapper, error) { + srvrLog.Infof("Initializing peer network bootstrappers!") var bootStrappers []discovery.NetworkPeerBootstrapper @@ -2890,7 +2897,7 @@ func (s *server) createBootstrapIgnorePeers() map[autopilot.NodeID]struct{} { // invariant, we ensure that our node is connected to a diverse set of peers // and that nodes newly joining the network receive an up to date network view // as soon as possible. -func (s *server) peerBootstrapper(numTargetPeers uint32, +func (s *server) peerBootstrapper(ctx context.Context, numTargetPeers uint32, bootstrappers []discovery.NetworkPeerBootstrapper) { defer s.wg.Done() @@ -2900,7 +2907,7 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, // We'll start off by aggressively attempting connections to peers in // order to be a part of the network as soon as possible. - s.initialPeerBootstrap(ignoreList, numTargetPeers, bootstrappers) + s.initialPeerBootstrap(ctx, ignoreList, numTargetPeers, bootstrappers) // Once done, we'll attempt to maintain our target minimum number of // peers. @@ -2978,7 +2985,7 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, ignoreList = s.createBootstrapIgnorePeers() peerAddrs, err := discovery.MultiSourceBootstrap( - ignoreList, numNeeded*2, bootstrappers..., + ctx, ignoreList, numNeeded*2, bootstrappers..., ) if err != nil { srvrLog.Errorf("Unable to retrieve bootstrap "+ @@ -2996,7 +3003,7 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, // country diversity, etc errChan := make(chan error, 1) s.connectToPeer( - a, errChan, + ctx, a, errChan, s.cfg.ConnectionTimeout, ) select { @@ -3027,8 +3034,8 @@ const bootstrapBackOffCeiling = time.Minute * 5 // initialPeerBootstrap attempts to continuously connect to peers on startup // until the target number of peers has been reached. This ensures that nodes // receive an up to date network view as soon as possible. -func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{}, - numTargetPeers uint32, +func (s *server) initialPeerBootstrap(ctx context.Context, + ignore map[autopilot.NodeID]struct{}, numTargetPeers uint32, bootstrappers []discovery.NetworkPeerBootstrapper) { srvrLog.Debugf("Init bootstrap with targetPeers=%v, bootstrappers=%v, "+ @@ -3087,7 +3094,7 @@ func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{}, // in order to reach our target. peersNeeded := numTargetPeers - numActivePeers bootstrapAddrs, err := discovery.MultiSourceBootstrap( - ignore, peersNeeded, bootstrappers..., + ctx, ignore, peersNeeded, bootstrappers..., ) if err != nil { srvrLog.Errorf("Unable to retrieve initial bootstrap "+ @@ -3105,7 +3112,8 @@ func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{}, errChan := make(chan error, 1) go s.connectToPeer( - addr, errChan, s.cfg.ConnectionTimeout, + ctx, addr, errChan, + s.cfg.ConnectionTimeout, ) // We'll only allow this connection attempt to @@ -3783,7 +3791,7 @@ func shouldDropLocalConnection(local, remote *btcec.PublicKey) bool { // connection. // // NOTE: This function is safe for concurrent access. -func (s *server) InboundPeerConnected(conn net.Conn) { +func (s *server) InboundPeerConnected(ctx context.Context, conn net.Conn) { // Exit early if we have already been instructed to shutdown, this // prevents any delayed callbacks from accidentally registering peers. if s.Stopped() { @@ -3853,7 +3861,7 @@ func (s *server) InboundPeerConnected(conn net.Conn) { // We were unable to locate an existing connection with the // target peer, proceed to connect. s.cancelConnReqs(pubStr, nil) - s.peerConnected(conn, nil, true) + s.peerConnected(ctx, conn, nil, true) case nil: // We already have a connection with the incoming peer. If the @@ -3885,7 +3893,7 @@ func (s *server) InboundPeerConnected(conn net.Conn) { s.removePeer(connectedPeer) s.ignorePeerTermination[connectedPeer] = struct{}{} s.scheduledPeerConnection[pubStr] = func() { - s.peerConnected(conn, nil, true) + s.peerConnected(ctx, conn, nil, true) } } } @@ -3893,7 +3901,9 @@ func (s *server) InboundPeerConnected(conn net.Conn) { // OutboundPeerConnected initializes a new peer in response to a new outbound // connection. // NOTE: This function is safe for concurrent access. -func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) { +func (s *server) OutboundPeerConnected(ctx context.Context, + connReq *connmgr.ConnReq, conn net.Conn) { + // Exit early if we have already been instructed to shutdown, this // prevents any delayed callbacks from accidentally registering peers. if s.Stopped() { @@ -3991,7 +4001,7 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) case ErrPeerNotConnected: // We were unable to locate an existing connection with the // target peer, proceed to connect. - s.peerConnected(conn, connReq, false) + s.peerConnected(ctx, conn, connReq, false) case nil: // We already have a connection with the incoming peer. If the @@ -4025,7 +4035,7 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) s.removePeer(connectedPeer) s.ignorePeerTermination[connectedPeer] = struct{}{} s.scheduledPeerConnection[pubStr] = func() { - s.peerConnected(conn, connReq, false) + s.peerConnected(ctx, conn, connReq, false) } } } @@ -4103,8 +4113,8 @@ func (s *server) SubscribeCustomMessages() (*subscribe.Client, error) { // peer by adding it to the server's global list of all active peers, and // starting all the goroutines the peer needs to function properly. The inbound // boolean should be true if the peer initiated the connection to us. -func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, - inbound bool) { +func (s *server) peerConnected(ctx context.Context, conn net.Conn, + connReq *connmgr.ConnReq, inbound bool) { brontideConn := conn.(*brontide.Conn) addr := conn.RemoteAddr() @@ -4258,7 +4268,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, // includes sending and receiving Init messages, which would be a DOS // vector if we held the server's mutex throughout the procedure. s.wg.Add(1) - go s.peerInitializer(p) + go s.peerInitializer(ctx, p) } // addPeer adds the passed peer to the server's global state of all active @@ -4313,7 +4323,7 @@ func (s *server) addPeer(p *peer.Brontide) { // be signaled of the new peer once the method returns. // // NOTE: This MUST be launched as a goroutine. -func (s *server) peerInitializer(p *peer.Brontide) { +func (s *server) peerInitializer(ctx context.Context, p *peer.Brontide) { defer s.wg.Done() pubBytes := p.IdentityKey().SerializeCompressed() @@ -4337,7 +4347,7 @@ func (s *server) peerInitializer(p *peer.Brontide) { // the peer is ever added to the ignorePeerTermination map, indicating // that the server has already handled the removal of this peer. s.wg.Add(1) - go s.peerTerminationWatcher(p, ready) + go s.peerTerminationWatcher(ctx, p, ready) // Start the peer! If an error occurs, we Disconnect the peer, which // will unblock the peerTerminationWatcher. @@ -4382,7 +4392,9 @@ func (s *server) peerInitializer(p *peer.Brontide) { // successfully, otherwise the peer should be disconnected instead. // // NOTE: This MUST be launched as a goroutine. -func (s *server) peerTerminationWatcher(p *peer.Brontide, ready chan struct{}) { +func (s *server) peerTerminationWatcher(ctx context.Context, p *peer.Brontide, + ready chan struct{}) { + defer s.wg.Done() p.WaitForDisconnect(ready) @@ -4471,7 +4483,7 @@ func (s *server) peerTerminationWatcher(p *peer.Brontide, ready chan struct{}) { // We'll ensure that we locate all the peers advertised addresses for // reconnection purposes. - advertisedAddrs, err := s.fetchNodeAdvertisedAddrs(pubKey) + advertisedAddrs, err := s.fetchNodeAdvertisedAddrs(ctx, pubKey) switch { // We found advertised addresses, so use them. case err == nil: @@ -4720,7 +4732,7 @@ func (s *server) removePeer(p *peer.Brontide) { // connection is established, or the initial handshake process fails. // // NOTE: This function is safe for concurrent access. -func (s *server) ConnectToPeer(addr *lnwire.NetAddress, +func (s *server) ConnectToPeer(ctx context.Context, addr *lnwire.NetAddress, perm bool, timeout time.Duration) error { targetPub := string(addr.IdentityKey.SerializeCompressed()) @@ -4782,7 +4794,7 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, // the crypto negotiation breaks down, then return an error to the // caller. errChan := make(chan error, 1) - s.connectToPeer(addr, errChan, timeout) + s.connectToPeer(ctx, addr, errChan, timeout) select { case err := <-errChan: @@ -4795,7 +4807,7 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, // connectToPeer establishes a connection to a remote peer. errChan is used to // notify the caller if the connection attempt has failed. Otherwise, it will be // closed. -func (s *server) connectToPeer(addr *lnwire.NetAddress, +func (s *server) connectToPeer(ctx context.Context, addr *lnwire.NetAddress, errChan chan<- error, timeout time.Duration) { conn, err := brontide.Dial( @@ -4815,7 +4827,7 @@ func (s *server) connectToPeer(addr *lnwire.NetAddress, srvrLog.Tracef("Brontide dialer made local=%v, remote=%v", conn.LocalAddr(), conn.RemoteAddr()) - s.OutboundPeerConnected(nil, conn) + s.OutboundPeerConnected(ctx, nil, conn) } // DisconnectPeer sends the request to server to close the connection with peer @@ -4961,8 +4973,8 @@ func computeNextBackoff(currBackoff, maxBackoff time.Duration) time.Duration { var errNoAdvertisedAddr = errors.New("no advertised address found") // fetchNodeAdvertisedAddrs attempts to fetch the advertised addresses of a node. -func (s *server) fetchNodeAdvertisedAddrs(pub *btcec.PublicKey) ([]net.Addr, error) { - ctx := context.TODO() +func (s *server) fetchNodeAdvertisedAddrs(ctx context.Context, + pub *btcec.PublicKey) ([]net.Addr, error) { vertex, err := route.NewVertexFromBytes(pub.SerializeCompressed()) if err != nil {