diff --git a/autopilot/agent.go b/autopilot/agent.go index d9c35a685..7b951a66d 100644 --- a/autopilot/agent.go +++ b/autopilot/agent.go @@ -2,6 +2,7 @@ package autopilot import ( "bytes" + "context" "fmt" "math/rand" "net" @@ -11,6 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) @@ -166,8 +168,9 @@ type Agent struct { pendingOpens map[NodeID]LocalChannel pendingMtx sync.Mutex - quit chan struct{} - wg sync.WaitGroup + quit chan struct{} + wg sync.WaitGroup + cancel fn.Option[context.CancelFunc] } // New creates a new instance of the Agent instantiated using the passed @@ -199,20 +202,23 @@ func New(cfg Config, initialState []LocalChannel) (*Agent, error) { // Start starts the agent along with any goroutines it needs to perform its // normal duties. -func (a *Agent) Start() error { +func (a *Agent) Start(ctx context.Context) error { var err error a.started.Do(func() { - err = a.start() + ctx, cancel := context.WithCancel(ctx) + a.cancel = fn.Some(cancel) + + err = a.start(ctx) }) return err } -func (a *Agent) start() error { +func (a *Agent) start(ctx context.Context) error { rand.Seed(time.Now().Unix()) log.Infof("Autopilot Agent starting") a.wg.Add(1) - go a.controller() + go a.controller(ctx) return nil } @@ -230,6 +236,7 @@ func (a *Agent) Stop() error { func (a *Agent) stop() error { log.Infof("Autopilot Agent stopping") + a.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(a.quit) a.wg.Wait() @@ -401,7 +408,7 @@ func mergeChanState(pendingChans map[NodeID]LocalChannel, // and external state changes as a result of decisions it makes w.r.t channel // allocation, or attributes affecting its control loop being updated by the // backing Lightning Node. -func (a *Agent) controller() { +func (a *Agent) controller(ctx context.Context) { defer a.wg.Done() // We'll start off by assigning our starting balance, and injecting @@ -502,6 +509,9 @@ func (a *Agent) controller() { // immediately. case <-a.quit: return + + case <-ctx.Done(): + return } a.pendingMtx.Lock() @@ -539,7 +549,7 @@ func (a *Agent) controller() { log.Infof("Triggering attachment directive dispatch, "+ "total_funds=%v", a.totalBalance) - err := a.openChans(availableFunds, numChans, totalChans) + err := a.openChans(ctx, availableFunds, numChans, totalChans) if err != nil { log.Errorf("Unable to open channels: %v", err) } @@ -548,8 +558,8 @@ func (a *Agent) controller() { // openChans queries the agent's heuristic for a set of channel candidates, and // attempts to open channels to them. -func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32, - totalChans []LocalChannel) error { +func (a *Agent) openChans(ctx context.Context, availableFunds btcutil.Amount, + numChans uint32, totalChans []LocalChannel) error { // As channel size we'll use the maximum channel size available. chanSize := a.cfg.Constraints.MaxChanSize() @@ -598,7 +608,9 @@ func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32, selfPubBytes := a.cfg.Self.SerializeCompressed() nodes := make(map[NodeID]struct{}) addresses := make(map[NodeID][]net.Addr) - if err := a.cfg.Graph.ForEachNode(func(node Node) error { + if err := a.cfg.Graph.ForEachNode(ctx, func(_ context.Context, + node Node) error { + nID := NodeID(node.PubKey()) // If we come across ourselves, them we'll continue in diff --git a/autopilot/agent_test.go b/autopilot/agent_test.go index 39e86906e..9b3c30bb5 100644 --- a/autopilot/agent_test.go +++ b/autopilot/agent_test.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "errors" "fmt" "net" @@ -220,7 +221,7 @@ func setup(t *testing.T, initialChans []LocalChannel) *testContext { // With the autopilot agent and all its dependencies we'll start the // primary controller goroutine. - if err := agent.Start(); err != nil { + if err := agent.Start(context.Background()); err != nil { t.Fatalf("unable to start agent: %v", err) } diff --git a/autopilot/betweenness_centrality.go b/autopilot/betweenness_centrality.go index db45bcf66..2a45fe1f1 100644 --- a/autopilot/betweenness_centrality.go +++ b/autopilot/betweenness_centrality.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "fmt" "sync" ) @@ -169,7 +170,9 @@ func betweennessCentrality(g *SimpleGraph, s int, centrality []float64) { // Refresh recalculates and stores centrality values. func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error { - cache, err := NewSimpleGraph(graph) + ctx := context.TODO() + + cache, err := NewSimpleGraph(ctx, graph) if err != nil { return err } diff --git a/autopilot/graph.go b/autopilot/graph.go index c8b54082a..d20c6316e 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "encoding/hex" "net" "sort" @@ -80,7 +81,9 @@ func (d *dbNode) Addrs() []net.Addr { // describes the active channel. // // NOTE: Part of the autopilot.Node interface. -func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { +func (d *dbNode) ForEachChannel(ctx context.Context, + cb func(context.Context, ChannelEdge) error) error { + return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep, _ *models.ChannelEdgePolicy) error { @@ -108,7 +111,7 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { }, } - return cb(edge) + return cb(ctx, edge) }) } @@ -117,7 +120,9 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { // error, then execution should be terminated. // // NOTE: Part of the autopilot.ChannelGraph interface. -func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { +func (d *databaseChannelGraph) ForEachNode(ctx context.Context, + cb func(context.Context, Node) error) error { + return d.db.ForEachNode(func(nodeTx graphdb.NodeRTx) error { // We'll skip over any node that doesn't have any advertised // addresses. As we won't be able to reach them to actually @@ -129,7 +134,8 @@ func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { node := &dbNode{ tx: nodeTx, } - return cb(node) + + return cb(ctx, node) }) } @@ -185,7 +191,9 @@ func (nc dbNodeCached) Addrs() []net.Addr { // describes the active channel. // // NOTE: Part of the autopilot.Node interface. -func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error { +func (nc dbNodeCached) ForEachChannel(ctx context.Context, + cb func(context.Context, ChannelEdge) error) error { + for cid, channel := range nc.channels { edge := ChannelEdge{ ChanID: lnwire.NewShortChanIDFromInt(cid), @@ -195,7 +203,7 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error { }, } - if err := cb(edge); err != nil { + if err := cb(ctx, edge); err != nil { return err } } @@ -208,7 +216,9 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error { // error, then execution should be terminated. // // NOTE: Part of the autopilot.ChannelGraph interface. -func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error { +func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context, + cb func(context.Context, Node) error) error { + return dc.db.ForEachNodeCached(func(n route.Vertex, channels map[uint64]*graphdb.DirectedChannel) error { @@ -217,7 +227,8 @@ func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error { node: n, channels: channels, } - return cb(node) + + return cb(ctx, node) } return nil }) @@ -262,9 +273,11 @@ func (m memNode) Addrs() []net.Addr { // describes the active channel. // // NOTE: Part of the autopilot.Node interface. -func (m memNode) ForEachChannel(cb func(ChannelEdge) error) error { +func (m memNode) ForEachChannel(ctx context.Context, + cb func(context.Context, ChannelEdge) error) error { + for _, channel := range m.chans { - if err := cb(channel); err != nil { + if err := cb(ctx, channel); err != nil { return err } } diff --git a/autopilot/interface.go b/autopilot/interface.go index 35182a760..0991d9864 100644 --- a/autopilot/interface.go +++ b/autopilot/interface.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "net" "github.com/btcsuite/btcd/btcec/v2" @@ -35,7 +36,8 @@ type Node interface { // iterate through all edges emanating from/to the target node. For // each active channel, this function should be called with the // populated ChannelEdge that describes the active channel. - ForEachChannel(func(ChannelEdge) error) error + ForEachChannel(context.Context, func(context.Context, + ChannelEdge) error) error } // LocalChannel is a simple struct which contains relevant details of a @@ -83,7 +85,7 @@ type ChannelGraph interface { // ForEachNode is a higher-order function that should be called once // for each connected node within the channel graph. If the passed // callback returns an error, then execution should be terminated. - ForEachNode(func(Node) error) error + ForEachNode(context.Context, func(context.Context, Node) error) error } // NodeScore is a tuple mapping a NodeID to a score indicating the preference diff --git a/autopilot/manager.go b/autopilot/manager.go index 0463f98d9..036bf3a31 100644 --- a/autopilot/manager.go +++ b/autopilot/manager.go @@ -1,11 +1,13 @@ package autopilot import ( + "context" "fmt" "sync" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -53,8 +55,9 @@ type Manager struct { // disabled. pilot *Agent - quit chan struct{} - wg sync.WaitGroup + quit chan struct{} + wg sync.WaitGroup + cancel fn.Option[context.CancelFunc] sync.Mutex } @@ -80,6 +83,7 @@ func (m *Manager) Stop() error { log.Errorf("Unable to stop pilot: %v", err) } + m.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(m.quit) m.wg.Wait() }) @@ -96,7 +100,7 @@ func (m *Manager) IsActive() bool { // StartAgent creates and starts an autopilot agent from the Manager's // config. -func (m *Manager) StartAgent() error { +func (m *Manager) StartAgent(ctx context.Context) error { m.Lock() defer m.Unlock() @@ -104,6 +108,8 @@ func (m *Manager) StartAgent() error { if m.pilot != nil { return nil } + ctx, cancel := context.WithCancel(ctx) + m.cancel = fn.Some(cancel) // Next, we'll fetch the current state of open channels from the // database to use as initial state for the auto-pilot agent. @@ -119,7 +125,7 @@ func (m *Manager) StartAgent() error { return err } - if err := pilot.Start(); err != nil { + if err := pilot.Start(ctx); err != nil { return err } @@ -163,6 +169,8 @@ func (m *Manager) StartAgent() error { return case <-m.quit: return + case <-ctx.Done(): + return } } @@ -233,6 +241,8 @@ func (m *Manager) StartAgent() error { return case <-m.quit: return + case <-ctx.Done(): + return } } }() diff --git a/autopilot/prefattach.go b/autopilot/prefattach.go index 4f4ff635f..4f55e87ea 100644 --- a/autopilot/prefattach.go +++ b/autopilot/prefattach.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" prand "math/rand" "time" @@ -82,14 +83,18 @@ func (p *PrefAttachment) NodeScores(g ChannelGraph, chans []LocalChannel, chanSize btcutil.Amount, nodes map[NodeID]struct{}) ( map[NodeID]*NodeScore, error) { + ctx := context.TODO() + // We first run though the graph once in order to find the median // channel size. var ( allChans []btcutil.Amount seenChans = make(map[uint64]struct{}) ) - if err := g.ForEachNode(func(n Node) error { - err := n.ForEachChannel(func(e ChannelEdge) error { + if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error { + err := n.ForEachChannel(ctx, func(_ context.Context, + e ChannelEdge) error { + if _, ok := seenChans[e.ChanID.ToUint64()]; ok { return nil } @@ -114,15 +119,19 @@ func (p *PrefAttachment) NodeScores(g ChannelGraph, chans []LocalChannel, // the graph. var maxChans int nodeChanNum := make(map[NodeID]int) - if err := g.ForEachNode(func(n Node) error { + if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error { var nodeChans int - err := n.ForEachChannel(func(e ChannelEdge) error { + err := n.ForEachChannel(ctx, func(_ context.Context, + e ChannelEdge) error { + // Since connecting to nodes with a lot of small // channels actually worsens our connectivity in the // graph (we will potentially waste time trying to use // these useless channels in path finding), we decrease // the counter for such channels. - if e.Capacity < medianChanSize/minMedianChanSizeFraction { + if e.Capacity < + medianChanSize/minMedianChanSizeFraction { + nodeChans-- return nil } diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index f20c3a480..7dec5f49f 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -2,6 +2,7 @@ package autopilot import ( "bytes" + "context" "errors" prand "math/rand" "net" @@ -126,6 +127,7 @@ func TestPrefAttachmentSelectEmptyGraph(t *testing.T) { // and the funds are appropriately allocated across each peer. func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { t.Parallel() + ctx := context.Background() prand.Seed(time.Now().Unix()) @@ -156,10 +158,12 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { // Get the score for all nodes found in the graph at // this point. nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode(func(n Node) error { - nodes[n.PubKey()] = struct{}{} - return nil - }) + err = graph.ForEachNode(ctx, + func(_ context.Context, n Node) error { + nodes[n.PubKey()] = struct{}{} + return nil + }, + ) require.NoError(t1, err) require.Len(t1, nodes, 3) @@ -207,6 +211,7 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { // allocate all funds to each vertex (up to the max channel size). func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { t.Parallel() + ctx := context.Background() prand.Seed(time.Now().Unix()) @@ -245,22 +250,25 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { numNodes := 0 twoChans := false nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode(func(n Node) error { - numNodes++ - nodes[n.PubKey()] = struct{}{} - numChans := 0 - err := n.ForEachChannel(func(c ChannelEdge) error { - numChans++ + err = graph.ForEachNode( + ctx, func(ctx context.Context, n Node) error { + numNodes++ + nodes[n.PubKey()] = struct{}{} + numChans := 0 + err := n.ForEachChannel(ctx, + func(_ context.Context, c ChannelEdge) error { //nolint:ll + numChans++ + return nil + }, + ) + if err != nil { + return err + } + + twoChans = twoChans || (numChans == 2) + return nil }) - if err != nil { - return err - } - - twoChans = twoChans || (numChans == 2) - - return nil - }) require.NoError(t1, err) require.EqualValues(t1, 3, numNodes) @@ -313,6 +321,7 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { // of zero during scoring. func TestPrefAttachmentSelectSkipNodes(t *testing.T) { t.Parallel() + ctx := context.Background() prand.Seed(time.Now().Unix()) @@ -335,10 +344,13 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) { require.NoError(t1, err) nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode(func(n Node) error { - nodes[n.PubKey()] = struct{}{} - return nil - }) + err = graph.ForEachNode( + ctx, func(_ context.Context, n Node) error { + nodes[n.PubKey()] = struct{}{} + + return nil + }, + ) require.NoError(t1, err) require.Len(t1, nodes, 2) @@ -583,9 +595,11 @@ func newMemChannelGraph() *memChannelGraph { // error, then execution should be terminated. // // NOTE: Part of the autopilot.ChannelGraph interface. -func (m *memChannelGraph) ForEachNode(cb func(Node) error) error { +func (m *memChannelGraph) ForEachNode(ctx context.Context, + cb func(context.Context, Node) error) error { + for _, node := range m.graph { - if err := cb(node); err != nil { + if err := cb(ctx, node); err != nil { return err } } diff --git a/autopilot/simple_graph.go b/autopilot/simple_graph.go index 4d294b3f2..f028db3c7 100644 --- a/autopilot/simple_graph.go +++ b/autopilot/simple_graph.go @@ -1,5 +1,7 @@ package autopilot +import "context" + // diameterCutoff is used to discard nodes in the diameter calculation. // It is the multiplier for the eccentricity of the highest-degree node, // serving as a cutoff to discard all nodes with a smaller hop distance. This @@ -20,7 +22,7 @@ type SimpleGraph struct { // NewSimpleGraph creates a simplified graph from the current channel graph. // Returns an error if the channel graph iteration fails due to underlying // failure. -func NewSimpleGraph(g ChannelGraph) (*SimpleGraph, error) { +func NewSimpleGraph(ctx context.Context, g ChannelGraph) (*SimpleGraph, error) { nodes := make(map[NodeID]int) adj := make(map[int][]int) nextIndex := 0 @@ -42,17 +44,22 @@ func NewSimpleGraph(g ChannelGraph) (*SimpleGraph, error) { return nodeIndex } - // Iterate over each node and each channel and update the adj and the node - // index. - err := g.ForEachNode(func(node Node) error { + // Iterate over each node and each channel and update the adj and the + // node index. + err := g.ForEachNode(ctx, func(ctx context.Context, node Node) error { u := getNodeIndex(node) - return node.ForEachChannel(func(edge ChannelEdge) error { - v := getNodeIndex(edge.Peer) + return node.ForEachChannel( + ctx, func(_ context.Context, + edge ChannelEdge) error { - adj[u] = append(adj[u], v) - return nil - }) + v := getNodeIndex(edge.Peer) + + adj[u] = append(adj[u], v) + + return nil + }, + ) }) if err != nil { return nil, err diff --git a/discovery/bootstrapper.go b/discovery/bootstrapper.go index d07a5f852..1a0f99735 100644 --- a/discovery/bootstrapper.go +++ b/discovery/bootstrapper.go @@ -161,6 +161,8 @@ func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context, numAddrs uint32, ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) { + ctx := context.TODO() + // We'll merge the ignore map with our currently selected map in order // to ensure we don't return any duplicate nodes. for n := range ignore { @@ -183,7 +185,9 @@ func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context, errFound = fmt.Errorf("found node") ) - err := c.chanGraph.ForEachNode(func(node autopilot.Node) error { + err := c.chanGraph.ForEachNode(ctx, func(_ context.Context, + node autopilot.Node) error { + nID := autopilot.NodeID(node.PubKey()) if _, ok := c.tried[nID]; ok { return nil diff --git a/lnd.go b/lnd.go index 3afa8c2fb..41bd3ca4b 100644 --- a/lnd.go +++ b/lnd.go @@ -788,7 +788,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // active, then we'll start the autopilot agent immediately. It will be // stopped together with the autopilot service. if cfg.Autopilot.Active { - if err := atplManager.StartAgent(); err != nil { + if err := atplManager.StartAgent(ctx); err != nil { return mkErr("unable to start autopilot agent", err) } } diff --git a/lnrpc/autopilotrpc/autopilot_server.go b/lnrpc/autopilotrpc/autopilot_server.go index 23d0ff5f3..3e3c6f8f7 100644 --- a/lnrpc/autopilotrpc/autopilot_server.go +++ b/lnrpc/autopilotrpc/autopilot_server.go @@ -205,7 +205,7 @@ func (s *Server) ModifyStatus(ctx context.Context, var err error if in.Enable { - err = s.manager.StartAgent() + err = s.manager.StartAgent(ctx) } else { err = s.manager.StopAgent() } diff --git a/rpcserver.go b/rpcserver.go index 6e288604f..d51cf1cab 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -7140,7 +7140,7 @@ func (r *rpcServer) GetNetworkInfo(ctx context.Context, // Graph diameter. channelGraph := autopilot.ChannelGraphFromCachedDatabase(graph) - simpleGraph, err := autopilot.NewSimpleGraph(channelGraph) + simpleGraph, err := autopilot.NewSimpleGraph(ctx, channelGraph) if err != nil { return nil, err }