From 13fcb087944e4674999d10bdadfea847e0ec46d3 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 9 Apr 2025 06:25:29 +0200 Subject: [PATCH] autopilot: start threading contexts through The `GraphSource` interface in the `autopilot` package is directly implemented by the `graphdb.KVStore` and so we will eventually thread contexts through to this interface. So in this commit, we start updating the autopilot system to thread contexts through in preparation for passing the context through to any calls made to the GraphSource. Three context.TODOs are added here (two in autopilot, one in discovery/bootstrapper.go) which will be addressed in follow-up commits. --- autopilot/agent.go | 34 +++++++++----- autopilot/agent_test.go | 3 +- autopilot/betweenness_centrality.go | 5 ++- autopilot/graph.go | 33 +++++++++----- autopilot/interface.go | 6 ++- autopilot/manager.go | 18 ++++++-- autopilot/prefattach.go | 19 +++++--- autopilot/prefattach_test.go | 62 ++++++++++++++++---------- autopilot/simple_graph.go | 25 +++++++---- discovery/bootstrapper.go | 6 ++- lnd.go | 2 +- lnrpc/autopilotrpc/autopilot_server.go | 2 +- rpcserver.go | 2 +- 13 files changed, 146 insertions(+), 71 deletions(-) diff --git a/autopilot/agent.go b/autopilot/agent.go index d9c35a685..7b951a66d 100644 --- a/autopilot/agent.go +++ b/autopilot/agent.go @@ -2,6 +2,7 @@ package autopilot import ( "bytes" + "context" "fmt" "math/rand" "net" @@ -11,6 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) @@ -166,8 +168,9 @@ type Agent struct { pendingOpens map[NodeID]LocalChannel pendingMtx sync.Mutex - quit chan struct{} - wg sync.WaitGroup + quit chan struct{} + wg sync.WaitGroup + cancel fn.Option[context.CancelFunc] } // New creates a new instance of the Agent instantiated using the passed @@ -199,20 +202,23 @@ func New(cfg Config, initialState []LocalChannel) (*Agent, error) { // Start starts the agent along 
with any goroutines it needs to perform its // normal duties. -func (a *Agent) Start() error { +func (a *Agent) Start(ctx context.Context) error { var err error a.started.Do(func() { - err = a.start() + ctx, cancel := context.WithCancel(ctx) + a.cancel = fn.Some(cancel) + + err = a.start(ctx) }) return err } -func (a *Agent) start() error { +func (a *Agent) start(ctx context.Context) error { rand.Seed(time.Now().Unix()) log.Infof("Autopilot Agent starting") a.wg.Add(1) - go a.controller() + go a.controller(ctx) return nil } @@ -230,6 +236,7 @@ func (a *Agent) Stop() error { func (a *Agent) stop() error { log.Infof("Autopilot Agent stopping") + a.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(a.quit) a.wg.Wait() @@ -401,7 +408,7 @@ func mergeChanState(pendingChans map[NodeID]LocalChannel, // and external state changes as a result of decisions it makes w.r.t channel // allocation, or attributes affecting its control loop being updated by the // backing Lightning Node. -func (a *Agent) controller() { +func (a *Agent) controller(ctx context.Context) { defer a.wg.Done() // We'll start off by assigning our starting balance, and injecting @@ -502,6 +509,9 @@ func (a *Agent) controller() { // immediately. case <-a.quit: return + + case <-ctx.Done(): + return } a.pendingMtx.Lock() @@ -539,7 +549,7 @@ func (a *Agent) controller() { log.Infof("Triggering attachment directive dispatch, "+ "total_funds=%v", a.totalBalance) - err := a.openChans(availableFunds, numChans, totalChans) + err := a.openChans(ctx, availableFunds, numChans, totalChans) if err != nil { log.Errorf("Unable to open channels: %v", err) } @@ -548,8 +558,8 @@ func (a *Agent) controller() { // openChans queries the agent's heuristic for a set of channel candidates, and // attempts to open channels to them. 
-func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32, - totalChans []LocalChannel) error { +func (a *Agent) openChans(ctx context.Context, availableFunds btcutil.Amount, + numChans uint32, totalChans []LocalChannel) error { // As channel size we'll use the maximum channel size available. chanSize := a.cfg.Constraints.MaxChanSize() @@ -598,7 +608,9 @@ func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32, selfPubBytes := a.cfg.Self.SerializeCompressed() nodes := make(map[NodeID]struct{}) addresses := make(map[NodeID][]net.Addr) - if err := a.cfg.Graph.ForEachNode(func(node Node) error { + if err := a.cfg.Graph.ForEachNode(ctx, func(_ context.Context, + node Node) error { + nID := NodeID(node.PubKey()) // If we come across ourselves, them we'll continue in diff --git a/autopilot/agent_test.go b/autopilot/agent_test.go index 39e86906e..9b3c30bb5 100644 --- a/autopilot/agent_test.go +++ b/autopilot/agent_test.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "errors" "fmt" "net" @@ -220,7 +221,7 @@ func setup(t *testing.T, initialChans []LocalChannel) *testContext { // With the autopilot agent and all its dependencies we'll start the // primary controller goroutine. - if err := agent.Start(); err != nil { + if err := agent.Start(context.Background()); err != nil { t.Fatalf("unable to start agent: %v", err) } diff --git a/autopilot/betweenness_centrality.go b/autopilot/betweenness_centrality.go index db45bcf66..2a45fe1f1 100644 --- a/autopilot/betweenness_centrality.go +++ b/autopilot/betweenness_centrality.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "fmt" "sync" ) @@ -169,7 +170,9 @@ func betweennessCentrality(g *SimpleGraph, s int, centrality []float64) { // Refresh recalculates and stores centrality values. 
func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error { - cache, err := NewSimpleGraph(graph) + ctx := context.TODO() + + cache, err := NewSimpleGraph(ctx, graph) if err != nil { return err } diff --git a/autopilot/graph.go b/autopilot/graph.go index c8b54082a..d20c6316e 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "encoding/hex" "net" "sort" @@ -80,7 +81,9 @@ func (d *dbNode) Addrs() []net.Addr { // describes the active channel. // // NOTE: Part of the autopilot.Node interface. -func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { +func (d *dbNode) ForEachChannel(ctx context.Context, + cb func(context.Context, ChannelEdge) error) error { + return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep, _ *models.ChannelEdgePolicy) error { @@ -108,7 +111,7 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { }, } - return cb(edge) + return cb(ctx, edge) }) } @@ -117,7 +120,9 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { // error, then execution should be terminated. // // NOTE: Part of the autopilot.ChannelGraph interface. -func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { +func (d *databaseChannelGraph) ForEachNode(ctx context.Context, + cb func(context.Context, Node) error) error { + return d.db.ForEachNode(func(nodeTx graphdb.NodeRTx) error { // We'll skip over any node that doesn't have any advertised // addresses. As we won't be able to reach them to actually @@ -129,7 +134,8 @@ func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { node := &dbNode{ tx: nodeTx, } - return cb(node) + + return cb(ctx, node) }) } @@ -185,7 +191,9 @@ func (nc dbNodeCached) Addrs() []net.Addr { // describes the active channel. // // NOTE: Part of the autopilot.Node interface. 
-func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error { +func (nc dbNodeCached) ForEachChannel(ctx context.Context, + cb func(context.Context, ChannelEdge) error) error { + for cid, channel := range nc.channels { edge := ChannelEdge{ ChanID: lnwire.NewShortChanIDFromInt(cid), @@ -195,7 +203,7 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error { }, } - if err := cb(edge); err != nil { + if err := cb(ctx, edge); err != nil { return err } } @@ -208,7 +216,9 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error { // error, then execution should be terminated. // // NOTE: Part of the autopilot.ChannelGraph interface. -func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error { +func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context, + cb func(context.Context, Node) error) error { + return dc.db.ForEachNodeCached(func(n route.Vertex, channels map[uint64]*graphdb.DirectedChannel) error { @@ -217,7 +227,8 @@ func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error { node: n, channels: channels, } - return cb(node) + + return cb(ctx, node) } return nil }) @@ -262,9 +273,11 @@ func (m memNode) Addrs() []net.Addr { // describes the active channel. // // NOTE: Part of the autopilot.Node interface. -func (m memNode) ForEachChannel(cb func(ChannelEdge) error) error { +func (m memNode) ForEachChannel(ctx context.Context, + cb func(context.Context, ChannelEdge) error) error { + for _, channel := range m.chans { - if err := cb(channel); err != nil { + if err := cb(ctx, channel); err != nil { return err } } diff --git a/autopilot/interface.go b/autopilot/interface.go index 35182a760..0991d9864 100644 --- a/autopilot/interface.go +++ b/autopilot/interface.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "net" "github.com/btcsuite/btcd/btcec/v2" @@ -35,7 +36,8 @@ type Node interface { // iterate through all edges emanating from/to the target node. 
For // each active channel, this function should be called with the // populated ChannelEdge that describes the active channel. - ForEachChannel(func(ChannelEdge) error) error + ForEachChannel(context.Context, func(context.Context, + ChannelEdge) error) error } // LocalChannel is a simple struct which contains relevant details of a @@ -83,7 +85,7 @@ type ChannelGraph interface { // ForEachNode is a higher-order function that should be called once // for each connected node within the channel graph. If the passed // callback returns an error, then execution should be terminated. - ForEachNode(func(Node) error) error + ForEachNode(context.Context, func(context.Context, Node) error) error } // NodeScore is a tuple mapping a NodeID to a score indicating the preference diff --git a/autopilot/manager.go b/autopilot/manager.go index 0463f98d9..036bf3a31 100644 --- a/autopilot/manager.go +++ b/autopilot/manager.go @@ -1,11 +1,13 @@ package autopilot import ( + "context" "fmt" "sync" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -53,8 +55,9 @@ type Manager struct { // disabled. pilot *Agent - quit chan struct{} - wg sync.WaitGroup + quit chan struct{} + wg sync.WaitGroup + cancel fn.Option[context.CancelFunc] sync.Mutex } @@ -80,6 +83,7 @@ func (m *Manager) Stop() error { log.Errorf("Unable to stop pilot: %v", err) } + m.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(m.quit) m.wg.Wait() }) @@ -96,7 +100,7 @@ func (m *Manager) IsActive() bool { // StartAgent creates and starts an autopilot agent from the Manager's // config. 
-func (m *Manager) StartAgent() error { +func (m *Manager) StartAgent(ctx context.Context) error { m.Lock() defer m.Unlock() @@ -104,6 +108,8 @@ func (m *Manager) StartAgent() error { if m.pilot != nil { return nil } + ctx, cancel := context.WithCancel(ctx) + m.cancel = fn.Some(cancel) // Next, we'll fetch the current state of open channels from the // database to use as initial state for the auto-pilot agent. @@ -119,7 +125,7 @@ func (m *Manager) StartAgent() error { return err } - if err := pilot.Start(); err != nil { + if err := pilot.Start(ctx); err != nil { return err } @@ -163,6 +169,8 @@ func (m *Manager) StartAgent() error { return case <-m.quit: return + case <-ctx.Done(): + return } } @@ -233,6 +241,8 @@ func (m *Manager) StartAgent() error { return case <-m.quit: return + case <-ctx.Done(): + return } } }() diff --git a/autopilot/prefattach.go b/autopilot/prefattach.go index 4f4ff635f..4f55e87ea 100644 --- a/autopilot/prefattach.go +++ b/autopilot/prefattach.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" prand "math/rand" "time" @@ -82,14 +83,18 @@ func (p *PrefAttachment) NodeScores(g ChannelGraph, chans []LocalChannel, chanSize btcutil.Amount, nodes map[NodeID]struct{}) ( map[NodeID]*NodeScore, error) { + ctx := context.TODO() + // We first run though the graph once in order to find the median // channel size. var ( allChans []btcutil.Amount seenChans = make(map[uint64]struct{}) ) - if err := g.ForEachNode(func(n Node) error { - err := n.ForEachChannel(func(e ChannelEdge) error { + if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error { + err := n.ForEachChannel(ctx, func(_ context.Context, + e ChannelEdge) error { + if _, ok := seenChans[e.ChanID.ToUint64()]; ok { return nil } @@ -114,15 +119,19 @@ func (p *PrefAttachment) NodeScores(g ChannelGraph, chans []LocalChannel, // the graph. 
var maxChans int nodeChanNum := make(map[NodeID]int) - if err := g.ForEachNode(func(n Node) error { + if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error { var nodeChans int - err := n.ForEachChannel(func(e ChannelEdge) error { + err := n.ForEachChannel(ctx, func(_ context.Context, + e ChannelEdge) error { + // Since connecting to nodes with a lot of small // channels actually worsens our connectivity in the // graph (we will potentially waste time trying to use // these useless channels in path finding), we decrease // the counter for such channels. - if e.Capacity < medianChanSize/minMedianChanSizeFraction { + if e.Capacity < + medianChanSize/minMedianChanSizeFraction { + nodeChans-- return nil } diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index f20c3a480..7dec5f49f 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -2,6 +2,7 @@ package autopilot import ( "bytes" + "context" "errors" prand "math/rand" "net" @@ -126,6 +127,7 @@ func TestPrefAttachmentSelectEmptyGraph(t *testing.T) { // and the funds are appropriately allocated across each peer. func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { t.Parallel() + ctx := context.Background() prand.Seed(time.Now().Unix()) @@ -156,10 +158,12 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { // Get the score for all nodes found in the graph at // this point. nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode(func(n Node) error { - nodes[n.PubKey()] = struct{}{} - return nil - }) + err = graph.ForEachNode(ctx, + func(_ context.Context, n Node) error { + nodes[n.PubKey()] = struct{}{} + return nil + }, + ) require.NoError(t1, err) require.Len(t1, nodes, 3) @@ -207,6 +211,7 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { // allocate all funds to each vertex (up to the max channel size). 
func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { t.Parallel() + ctx := context.Background() prand.Seed(time.Now().Unix()) @@ -245,22 +250,25 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { numNodes := 0 twoChans := false nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode(func(n Node) error { - numNodes++ - nodes[n.PubKey()] = struct{}{} - numChans := 0 - err := n.ForEachChannel(func(c ChannelEdge) error { - numChans++ + err = graph.ForEachNode( + ctx, func(ctx context.Context, n Node) error { + numNodes++ + nodes[n.PubKey()] = struct{}{} + numChans := 0 + err := n.ForEachChannel(ctx, + func(_ context.Context, c ChannelEdge) error { //nolint:ll + numChans++ + return nil + }, + ) + if err != nil { + return err + } + + twoChans = twoChans || (numChans == 2) + return nil }) - if err != nil { - return err - } - - twoChans = twoChans || (numChans == 2) - - return nil - }) require.NoError(t1, err) require.EqualValues(t1, 3, numNodes) @@ -313,6 +321,7 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { // of zero during scoring. func TestPrefAttachmentSelectSkipNodes(t *testing.T) { t.Parallel() + ctx := context.Background() prand.Seed(time.Now().Unix()) @@ -335,10 +344,13 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) { require.NoError(t1, err) nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode(func(n Node) error { - nodes[n.PubKey()] = struct{}{} - return nil - }) + err = graph.ForEachNode( + ctx, func(_ context.Context, n Node) error { + nodes[n.PubKey()] = struct{}{} + + return nil + }, + ) require.NoError(t1, err) require.Len(t1, nodes, 2) @@ -583,9 +595,11 @@ func newMemChannelGraph() *memChannelGraph { // error, then execution should be terminated. // // NOTE: Part of the autopilot.ChannelGraph interface. 
-func (m *memChannelGraph) ForEachNode(cb func(Node) error) error { +func (m *memChannelGraph) ForEachNode(ctx context.Context, + cb func(context.Context, Node) error) error { + for _, node := range m.graph { - if err := cb(node); err != nil { + if err := cb(ctx, node); err != nil { return err } } diff --git a/autopilot/simple_graph.go b/autopilot/simple_graph.go index 4d294b3f2..f028db3c7 100644 --- a/autopilot/simple_graph.go +++ b/autopilot/simple_graph.go @@ -1,5 +1,7 @@ package autopilot +import "context" + // diameterCutoff is used to discard nodes in the diameter calculation. // It is the multiplier for the eccentricity of the highest-degree node, // serving as a cutoff to discard all nodes with a smaller hop distance. This @@ -20,7 +22,7 @@ type SimpleGraph struct { // NewSimpleGraph creates a simplified graph from the current channel graph. // Returns an error if the channel graph iteration fails due to underlying // failure. -func NewSimpleGraph(g ChannelGraph) (*SimpleGraph, error) { +func NewSimpleGraph(ctx context.Context, g ChannelGraph) (*SimpleGraph, error) { nodes := make(map[NodeID]int) adj := make(map[int][]int) nextIndex := 0 @@ -42,17 +44,22 @@ func NewSimpleGraph(g ChannelGraph) (*SimpleGraph, error) { return nodeIndex } - // Iterate over each node and each channel and update the adj and the node - // index. - err := g.ForEachNode(func(node Node) error { + // Iterate over each node and each channel and update the adj and the + // node index. 
+ err := g.ForEachNode(ctx, func(ctx context.Context, node Node) error { u := getNodeIndex(node) - return node.ForEachChannel(func(edge ChannelEdge) error { - v := getNodeIndex(edge.Peer) + return node.ForEachChannel( + ctx, func(_ context.Context, + edge ChannelEdge) error { - adj[u] = append(adj[u], v) - return nil - }) + v := getNodeIndex(edge.Peer) + + adj[u] = append(adj[u], v) + + return nil + }, + ) }) if err != nil { return nil, err diff --git a/discovery/bootstrapper.go b/discovery/bootstrapper.go index d07a5f852..1a0f99735 100644 --- a/discovery/bootstrapper.go +++ b/discovery/bootstrapper.go @@ -161,6 +161,8 @@ func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context, numAddrs uint32, ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) { + ctx := context.TODO() + // We'll merge the ignore map with our currently selected map in order // to ensure we don't return any duplicate nodes. for n := range ignore { @@ -183,7 +185,9 @@ func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context, errFound = fmt.Errorf("found node") ) - err := c.chanGraph.ForEachNode(func(node autopilot.Node) error { + err := c.chanGraph.ForEachNode(ctx, func(_ context.Context, + node autopilot.Node) error { + nID := autopilot.NodeID(node.PubKey()) if _, ok := c.tried[nID]; ok { return nil diff --git a/lnd.go b/lnd.go index 3afa8c2fb..41bd3ca4b 100644 --- a/lnd.go +++ b/lnd.go @@ -788,7 +788,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // active, then we'll start the autopilot agent immediately. It will be // stopped together with the autopilot service. 
if cfg.Autopilot.Active { - if err := atplManager.StartAgent(); err != nil { + if err := atplManager.StartAgent(ctx); err != nil { return mkErr("unable to start autopilot agent", err) } } diff --git a/lnrpc/autopilotrpc/autopilot_server.go b/lnrpc/autopilotrpc/autopilot_server.go index 23d0ff5f3..3e3c6f8f7 100644 --- a/lnrpc/autopilotrpc/autopilot_server.go +++ b/lnrpc/autopilotrpc/autopilot_server.go @@ -205,7 +205,7 @@ func (s *Server) ModifyStatus(ctx context.Context, var err error if in.Enable { - err = s.manager.StartAgent() + err = s.manager.StartAgent(ctx) } else { err = s.manager.StopAgent() } diff --git a/rpcserver.go b/rpcserver.go index 6e288604f..d51cf1cab 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -7140,7 +7140,7 @@ func (r *rpcServer) GetNetworkInfo(ctx context.Context, // Graph diameter. channelGraph := autopilot.ChannelGraphFromCachedDatabase(graph) - simpleGraph, err := autopilot.NewSimpleGraph(channelGraph) + simpleGraph, err := autopilot.NewSimpleGraph(ctx, channelGraph) if err != nil { return nil, err }