autopilot: start threading contexts through

The `GraphSource` interface in the `autopilot` package is directly
implemented by the `graphdb.KVStore` and so we will eventually thread
contexts through to this interface. So in this commit, we start updating
the autopilot system to thread contexts through in preparation for
passing the context through to any calls made to the GraphSource.

Two context.TODOs are added here which will be addressed in follow up
commits.
This commit is contained in:
Elle Mouton
2025-04-09 06:25:29 +02:00
committed by ziggie
parent ca52834795
commit cd4a59071d
13 changed files with 146 additions and 71 deletions

View File

@@ -2,6 +2,7 @@ package autopilot
import (
"bytes"
"context"
"fmt"
"math/rand"
"net"
@@ -11,6 +12,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
@@ -166,8 +168,9 @@ type Agent struct {
pendingOpens map[NodeID]LocalChannel
pendingMtx sync.Mutex
quit chan struct{}
wg sync.WaitGroup
quit chan struct{}
wg sync.WaitGroup
cancel fn.Option[context.CancelFunc]
}
// New creates a new instance of the Agent instantiated using the passed
@@ -199,20 +202,23 @@ func New(cfg Config, initialState []LocalChannel) (*Agent, error) {
// Start starts the agent along with any goroutines it needs to perform its
// normal duties.
func (a *Agent) Start() error {
func (a *Agent) Start(ctx context.Context) error {
var err error
a.started.Do(func() {
err = a.start()
ctx, cancel := context.WithCancel(ctx)
a.cancel = fn.Some(cancel)
err = a.start(ctx)
})
return err
}
func (a *Agent) start() error {
func (a *Agent) start(ctx context.Context) error {
rand.Seed(time.Now().Unix())
log.Infof("Autopilot Agent starting")
a.wg.Add(1)
go a.controller()
go a.controller(ctx)
return nil
}
@@ -230,6 +236,7 @@ func (a *Agent) Stop() error {
func (a *Agent) stop() error {
log.Infof("Autopilot Agent stopping")
a.cancel.WhenSome(func(fn context.CancelFunc) { fn() })
close(a.quit)
a.wg.Wait()
@@ -401,7 +408,7 @@ func mergeChanState(pendingChans map[NodeID]LocalChannel,
// and external state changes as a result of decisions it makes w.r.t channel
// allocation, or attributes affecting its control loop being updated by the
// backing Lightning Node.
func (a *Agent) controller() {
func (a *Agent) controller(ctx context.Context) {
defer a.wg.Done()
// We'll start off by assigning our starting balance, and injecting
@@ -502,6 +509,9 @@ func (a *Agent) controller() {
// immediately.
case <-a.quit:
return
case <-ctx.Done():
return
}
a.pendingMtx.Lock()
@@ -539,7 +549,7 @@ func (a *Agent) controller() {
log.Infof("Triggering attachment directive dispatch, "+
"total_funds=%v", a.totalBalance)
err := a.openChans(availableFunds, numChans, totalChans)
err := a.openChans(ctx, availableFunds, numChans, totalChans)
if err != nil {
log.Errorf("Unable to open channels: %v", err)
}
@@ -548,8 +558,8 @@ func (a *Agent) controller() {
// openChans queries the agent's heuristic for a set of channel candidates, and
// attempts to open channels to them.
func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32,
totalChans []LocalChannel) error {
func (a *Agent) openChans(ctx context.Context, availableFunds btcutil.Amount,
numChans uint32, totalChans []LocalChannel) error {
// As channel size we'll use the maximum channel size available.
chanSize := a.cfg.Constraints.MaxChanSize()
@@ -598,7 +608,9 @@ func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32,
selfPubBytes := a.cfg.Self.SerializeCompressed()
nodes := make(map[NodeID]struct{})
addresses := make(map[NodeID][]net.Addr)
if err := a.cfg.Graph.ForEachNode(func(node Node) error {
if err := a.cfg.Graph.ForEachNode(ctx, func(_ context.Context,
node Node) error {
nID := NodeID(node.PubKey())
// If we come across ourselves, them we'll continue in

View File

@@ -1,6 +1,7 @@
package autopilot
import (
"context"
"errors"
"fmt"
"net"
@@ -220,7 +221,7 @@ func setup(t *testing.T, initialChans []LocalChannel) *testContext {
// With the autopilot agent and all its dependencies we'll start the
// primary controller goroutine.
if err := agent.Start(); err != nil {
if err := agent.Start(context.Background()); err != nil {
t.Fatalf("unable to start agent: %v", err)
}

View File

@@ -1,6 +1,7 @@
package autopilot
import (
"context"
"fmt"
"sync"
)
@@ -169,7 +170,9 @@ func betweennessCentrality(g *SimpleGraph, s int, centrality []float64) {
// Refresh recalculates and stores centrality values.
func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error {
cache, err := NewSimpleGraph(graph)
ctx := context.TODO()
cache, err := NewSimpleGraph(ctx, graph)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package autopilot
import (
"context"
"encoding/hex"
"net"
"sort"
@@ -80,7 +81,9 @@ func (d *dbNode) Addrs() []net.Addr {
// describes the active channel.
//
// NOTE: Part of the autopilot.Node interface.
func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error {
func (d *dbNode) ForEachChannel(ctx context.Context,
cb func(context.Context, ChannelEdge) error) error {
return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep,
_ *models.ChannelEdgePolicy) error {
@@ -108,7 +111,7 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error {
},
}
return cb(edge)
return cb(ctx, edge)
})
}
@@ -117,7 +120,9 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error {
// error, then execution should be terminated.
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error {
func (d *databaseChannelGraph) ForEachNode(ctx context.Context,
cb func(context.Context, Node) error) error {
return d.db.ForEachNode(func(nodeTx graphdb.NodeRTx) error {
// We'll skip over any node that doesn't have any advertised
// addresses. As we won't be able to reach them to actually
@@ -129,7 +134,8 @@ func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error {
node := &dbNode{
tx: nodeTx,
}
return cb(node)
return cb(ctx, node)
})
}
@@ -185,7 +191,9 @@ func (nc dbNodeCached) Addrs() []net.Addr {
// describes the active channel.
//
// NOTE: Part of the autopilot.Node interface.
func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error {
func (nc dbNodeCached) ForEachChannel(ctx context.Context,
cb func(context.Context, ChannelEdge) error) error {
for cid, channel := range nc.channels {
edge := ChannelEdge{
ChanID: lnwire.NewShortChanIDFromInt(cid),
@@ -195,7 +203,7 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error {
},
}
if err := cb(edge); err != nil {
if err := cb(ctx, edge); err != nil {
return err
}
}
@@ -208,7 +216,9 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error {
// error, then execution should be terminated.
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error {
func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context,
cb func(context.Context, Node) error) error {
return dc.db.ForEachNodeCached(func(n route.Vertex,
channels map[uint64]*graphdb.DirectedChannel) error {
@@ -217,7 +227,8 @@ func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error {
node: n,
channels: channels,
}
return cb(node)
return cb(ctx, node)
}
return nil
})
@@ -262,9 +273,11 @@ func (m memNode) Addrs() []net.Addr {
// describes the active channel.
//
// NOTE: Part of the autopilot.Node interface.
func (m memNode) ForEachChannel(cb func(ChannelEdge) error) error {
func (m memNode) ForEachChannel(ctx context.Context,
cb func(context.Context, ChannelEdge) error) error {
for _, channel := range m.chans {
if err := cb(channel); err != nil {
if err := cb(ctx, channel); err != nil {
return err
}
}

View File

@@ -1,6 +1,7 @@
package autopilot
import (
"context"
"net"
"github.com/btcsuite/btcd/btcec/v2"
@@ -35,7 +36,8 @@ type Node interface {
// iterate through all edges emanating from/to the target node. For
// each active channel, this function should be called with the
// populated ChannelEdge that describes the active channel.
ForEachChannel(func(ChannelEdge) error) error
ForEachChannel(context.Context, func(context.Context,
ChannelEdge) error) error
}
// LocalChannel is a simple struct which contains relevant details of a
@@ -83,7 +85,7 @@ type ChannelGraph interface {
// ForEachNode is a higher-order function that should be called once
// for each connected node within the channel graph. If the passed
// callback returns an error, then execution should be terminated.
ForEachNode(func(Node) error) error
ForEachNode(context.Context, func(context.Context, Node) error) error
}
// NodeScore is a tuple mapping a NodeID to a score indicating the preference

View File

@@ -1,11 +1,13 @@
package autopilot
import (
"context"
"fmt"
"sync"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
@@ -53,8 +55,9 @@ type Manager struct {
// disabled.
pilot *Agent
quit chan struct{}
wg sync.WaitGroup
quit chan struct{}
wg sync.WaitGroup
cancel fn.Option[context.CancelFunc]
sync.Mutex
}
@@ -80,6 +83,7 @@ func (m *Manager) Stop() error {
log.Errorf("Unable to stop pilot: %v", err)
}
m.cancel.WhenSome(func(fn context.CancelFunc) { fn() })
close(m.quit)
m.wg.Wait()
})
@@ -96,7 +100,7 @@ func (m *Manager) IsActive() bool {
// StartAgent creates and starts an autopilot agent from the Manager's
// config.
func (m *Manager) StartAgent() error {
func (m *Manager) StartAgent(ctx context.Context) error {
m.Lock()
defer m.Unlock()
@@ -104,6 +108,8 @@ func (m *Manager) StartAgent() error {
if m.pilot != nil {
return nil
}
ctx, cancel := context.WithCancel(ctx)
m.cancel = fn.Some(cancel)
// Next, we'll fetch the current state of open channels from the
// database to use as initial state for the auto-pilot agent.
@@ -119,7 +125,7 @@ func (m *Manager) StartAgent() error {
return err
}
if err := pilot.Start(); err != nil {
if err := pilot.Start(ctx); err != nil {
return err
}
@@ -163,6 +169,8 @@ func (m *Manager) StartAgent() error {
return
case <-m.quit:
return
case <-ctx.Done():
return
}
}
@@ -233,6 +241,8 @@ func (m *Manager) StartAgent() error {
return
case <-m.quit:
return
case <-ctx.Done():
return
}
}
}()

View File

@@ -1,6 +1,7 @@
package autopilot
import (
"context"
prand "math/rand"
"time"
@@ -82,14 +83,18 @@ func (p *PrefAttachment) NodeScores(g ChannelGraph, chans []LocalChannel,
chanSize btcutil.Amount, nodes map[NodeID]struct{}) (
map[NodeID]*NodeScore, error) {
ctx := context.TODO()
// We first run though the graph once in order to find the median
// channel size.
var (
allChans []btcutil.Amount
seenChans = make(map[uint64]struct{})
)
if err := g.ForEachNode(func(n Node) error {
err := n.ForEachChannel(func(e ChannelEdge) error {
if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error {
err := n.ForEachChannel(ctx, func(_ context.Context,
e ChannelEdge) error {
if _, ok := seenChans[e.ChanID.ToUint64()]; ok {
return nil
}
@@ -114,15 +119,19 @@ func (p *PrefAttachment) NodeScores(g ChannelGraph, chans []LocalChannel,
// the graph.
var maxChans int
nodeChanNum := make(map[NodeID]int)
if err := g.ForEachNode(func(n Node) error {
if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error {
var nodeChans int
err := n.ForEachChannel(func(e ChannelEdge) error {
err := n.ForEachChannel(ctx, func(_ context.Context,
e ChannelEdge) error {
// Since connecting to nodes with a lot of small
// channels actually worsens our connectivity in the
// graph (we will potentially waste time trying to use
// these useless channels in path finding), we decrease
// the counter for such channels.
if e.Capacity < medianChanSize/minMedianChanSizeFraction {
if e.Capacity <
medianChanSize/minMedianChanSizeFraction {
nodeChans--
return nil
}

View File

@@ -2,6 +2,7 @@ package autopilot
import (
"bytes"
"context"
"errors"
prand "math/rand"
"net"
@@ -126,6 +127,7 @@ func TestPrefAttachmentSelectEmptyGraph(t *testing.T) {
// and the funds are appropriately allocated across each peer.
func TestPrefAttachmentSelectTwoVertexes(t *testing.T) {
t.Parallel()
ctx := context.Background()
prand.Seed(time.Now().Unix())
@@ -156,10 +158,12 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) {
// Get the score for all nodes found in the graph at
// this point.
nodes := make(map[NodeID]struct{})
err = graph.ForEachNode(func(n Node) error {
nodes[n.PubKey()] = struct{}{}
return nil
})
err = graph.ForEachNode(ctx,
func(_ context.Context, n Node) error {
nodes[n.PubKey()] = struct{}{}
return nil
},
)
require.NoError(t1, err)
require.Len(t1, nodes, 3)
@@ -207,6 +211,7 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) {
// allocate all funds to each vertex (up to the max channel size).
func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) {
t.Parallel()
ctx := context.Background()
prand.Seed(time.Now().Unix())
@@ -245,22 +250,25 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) {
numNodes := 0
twoChans := false
nodes := make(map[NodeID]struct{})
err = graph.ForEachNode(func(n Node) error {
numNodes++
nodes[n.PubKey()] = struct{}{}
numChans := 0
err := n.ForEachChannel(func(c ChannelEdge) error {
numChans++
err = graph.ForEachNode(
ctx, func(ctx context.Context, n Node) error {
numNodes++
nodes[n.PubKey()] = struct{}{}
numChans := 0
err := n.ForEachChannel(ctx,
func(_ context.Context, c ChannelEdge) error { //nolint:ll
numChans++
return nil
},
)
if err != nil {
return err
}
twoChans = twoChans || (numChans == 2)
return nil
})
if err != nil {
return err
}
twoChans = twoChans || (numChans == 2)
return nil
})
require.NoError(t1, err)
require.EqualValues(t1, 3, numNodes)
@@ -313,6 +321,7 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) {
// of zero during scoring.
func TestPrefAttachmentSelectSkipNodes(t *testing.T) {
t.Parallel()
ctx := context.Background()
prand.Seed(time.Now().Unix())
@@ -335,10 +344,13 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) {
require.NoError(t1, err)
nodes := make(map[NodeID]struct{})
err = graph.ForEachNode(func(n Node) error {
nodes[n.PubKey()] = struct{}{}
return nil
})
err = graph.ForEachNode(
ctx, func(_ context.Context, n Node) error {
nodes[n.PubKey()] = struct{}{}
return nil
},
)
require.NoError(t1, err)
require.Len(t1, nodes, 2)
@@ -583,9 +595,11 @@ func newMemChannelGraph() *memChannelGraph {
// error, then execution should be terminated.
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (m *memChannelGraph) ForEachNode(cb func(Node) error) error {
func (m *memChannelGraph) ForEachNode(ctx context.Context,
cb func(context.Context, Node) error) error {
for _, node := range m.graph {
if err := cb(node); err != nil {
if err := cb(ctx, node); err != nil {
return err
}
}

View File

@@ -1,5 +1,7 @@
package autopilot
import "context"
// diameterCutoff is used to discard nodes in the diameter calculation.
// It is the multiplier for the eccentricity of the highest-degree node,
// serving as a cutoff to discard all nodes with a smaller hop distance. This
@@ -20,7 +22,7 @@ type SimpleGraph struct {
// NewSimpleGraph creates a simplified graph from the current channel graph.
// Returns an error if the channel graph iteration fails due to underlying
// failure.
func NewSimpleGraph(g ChannelGraph) (*SimpleGraph, error) {
func NewSimpleGraph(ctx context.Context, g ChannelGraph) (*SimpleGraph, error) {
nodes := make(map[NodeID]int)
adj := make(map[int][]int)
nextIndex := 0
@@ -42,17 +44,22 @@ func NewSimpleGraph(g ChannelGraph) (*SimpleGraph, error) {
return nodeIndex
}
// Iterate over each node and each channel and update the adj and the node
// index.
err := g.ForEachNode(func(node Node) error {
// Iterate over each node and each channel and update the adj and the
// node index.
err := g.ForEachNode(ctx, func(ctx context.Context, node Node) error {
u := getNodeIndex(node)
return node.ForEachChannel(func(edge ChannelEdge) error {
v := getNodeIndex(edge.Peer)
return node.ForEachChannel(
ctx, func(_ context.Context,
edge ChannelEdge) error {
adj[u] = append(adj[u], v)
return nil
})
v := getNodeIndex(edge.Peer)
adj[u] = append(adj[u], v)
return nil
},
)
})
if err != nil {
return nil, err