autopilot: remove the ForEachChannel method from Node interface

And instead make use of the new ForEachNodesChannels method which
uses a much more efficient method for iterating through nodes&channels.
This commit is contained in:
Elle Mouton
2025-08-03 17:22:56 +02:00
parent ce7fe84da7
commit 699e335954
5 changed files with 82 additions and 168 deletions

View File

@@ -10,7 +10,6 @@ import (
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcutil"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
@@ -51,7 +50,6 @@ func ChannelGraphFromDatabase(db GraphSource) ChannelGraph {
// channeldb.LightningNode. The wrapper method implement the autopilot.Node
// interface.
type dbNode struct {
tx graphdb.NodeRTx
pub [33]byte
addrs []net.Addr
}
@@ -77,39 +75,6 @@ func (d *dbNode) Addrs() []net.Addr {
return d.addrs
}
// ForEachChannel is a higher-order function that will be used to iterate
// through all edges emanating from/to the target node. For each active
// channel, this function should be called with the populated ChannelEdge that
// describes the active channel.
//
// NOTE: Part of the autopilot.Node interface.
func (d *dbNode) ForEachChannel(ctx context.Context,
cb func(context.Context, ChannelEdge) error) error {
return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep,
_ *models.ChannelEdgePolicy) error {
// Skip channels for which no outgoing edge policy is available.
//
// TODO(joostjager): Ideally the case where channels have a nil
// policy should be supported, as autopilot is not looking at
// the policies. For now, it is not easily possible to get a
// reference to the other end LightningNode object without
// retrieving the policy.
if ep == nil {
return nil
}
edge := ChannelEdge{
ChanID: lnwire.NewShortChanIDFromInt(ep.ChannelID),
Capacity: ei.Capacity,
Peer: ep.ToNode,
}
return cb(ctx, edge)
})
}
// ForEachNode is a higher-order function that should be called once for each
// connected node within the channel graph. If the passed callback returns an
// error, then execution should be terminated.
@@ -127,7 +92,6 @@ func (d *databaseChannelGraph) ForEachNode(ctx context.Context,
}
node := &dbNode{
tx: nodeTx,
pub: nodeTx.Node().PubKeyBytes,
addrs: nodeTx.Node().Addresses,
}
@@ -223,30 +187,6 @@ func (nc dbNodeCached) Addrs() []net.Addr {
return []net.Addr{}
}
// ForEachChannel is a higher-order function that will be used to iterate
// through all edges emanating from/to the target node. For each active
// channel, this function should be called with the populated ChannelEdge that
// describes the active channel.
//
// NOTE: Part of the autopilot.Node interface.
func (nc dbNodeCached) ForEachChannel(ctx context.Context,
cb func(context.Context, ChannelEdge) error) error {
for cid, channel := range nc.channels {
edge := ChannelEdge{
ChanID: lnwire.NewShortChanIDFromInt(cid),
Capacity: channel.Capacity,
Peer: channel.OtherNode,
}
if err := cb(ctx, edge); err != nil {
return err
}
}
return nil
}
// ForEachNode is a higher-order function that should be called once for each
// connected node within the channel graph. If the passed callback returns an
// error, then execution should be terminated.
@@ -343,24 +283,6 @@ func (m memNode) Addrs() []net.Addr {
return m.addrs
}
// ForEachChannel is a higher-order function that will be used to iterate
// through all edges emanating from/to the target node. For each active
// channel, this function should be called with the populated ChannelEdge that
// describes the active channel.
//
// NOTE: Part of the autopilot.Node interface.
func (m memNode) ForEachChannel(ctx context.Context,
cb func(context.Context, ChannelEdge) error) error {
for _, channel := range m.chans {
if err := cb(ctx, channel); err != nil {
return err
}
}
return nil
}
// Median returns the median value in the slice of Amounts.
func Median(vals []btcutil.Amount) btcutil.Amount {
sort.Slice(vals, func(i, j int) bool {

View File

@@ -31,13 +31,6 @@ type Node interface {
// Addrs returns a slice of publicly reachable public TCP addresses
// that the peer is known to be listening on.
Addrs() []net.Addr
// ForEachChannel is a higher-order function that will be used to
// iterate through all edges emanating from/to the target node. For
// each active channel, this function should be called with the
// populated ChannelEdge that describes the active channel.
ForEachChannel(context.Context, func(context.Context,
ChannelEdge) error) error
}
// LocalChannel is a simple struct which contains relevant details of a

View File

@@ -89,26 +89,26 @@ func (p *PrefAttachment) NodeScores(ctx context.Context, g ChannelGraph,
allChans []btcutil.Amount
seenChans = make(map[uint64]struct{})
)
if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error {
err := n.ForEachChannel(ctx, func(_ context.Context,
e ChannelEdge) error {
err := g.ForEachNodesChannels(
ctx, func(_ context.Context, node Node,
channels []*ChannelEdge) error {
if _, ok := seenChans[e.ChanID.ToUint64()]; ok {
return nil
for _, e := range channels {
if _, ok := seenChans[e.ChanID.ToUint64()]; ok {
continue
}
seenChans[e.ChanID.ToUint64()] = struct{}{}
allChans = append(allChans, e.Capacity)
}
seenChans[e.ChanID.ToUint64()] = struct{}{}
allChans = append(allChans, e.Capacity)
return nil
})
if err != nil {
return err
}
return nil
}, func() {
allChans = nil
clear(seenChans)
}); err != nil {
return nil
},
func() {
allChans = nil
clear(seenChans)
},
)
if err != nil {
return nil, err
}
@@ -120,55 +120,59 @@ func (p *PrefAttachment) NodeScores(ctx context.Context, g ChannelGraph,
// the graph.
var maxChans int
nodeChanNum := make(map[NodeID]int)
if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error {
var nodeChans int
err := n.ForEachChannel(ctx, func(_ context.Context,
e ChannelEdge) error {
err = g.ForEachNodesChannels(
ctx, func(ctx context.Context, node Node,
edges []*ChannelEdge) error {
// Since connecting to nodes with a lot of small
// channels actually worsens our connectivity in the
// graph (we will potentially waste time trying to use
// these useless channels in path finding), we decrease
// the counter for such channels.
if e.Capacity <
medianChanSize/minMedianChanSizeFraction {
var nodeChans int
for _, e := range edges {
// Since connecting to nodes with a lot of small
// channels actually worsens our connectivity in
// the graph (we will potentially waste time
// trying to use these useless channels in path
// finding), we decrease the counter for such
// channels.
//
//nolint:ll
if e.Capacity <
medianChanSize/minMedianChanSizeFraction {
nodeChans--
nodeChans--
continue
}
// Larger channels we count.
nodeChans++
}
// We keep track of the highest-degree node we've seen,
// as this will be given the max score.
if nodeChans > maxChans {
maxChans = nodeChans
}
// If this node is not among our nodes to score, we can
// return early.
nID := NodeID(node.PubKey())
if _, ok := nodes[nID]; !ok {
log.Tracef("Node %x not among nodes to score, "+
"ignoring", nID[:])
return nil
}
// Larger channels we count.
nodeChans++
// Otherwise we'll record the number of channels.
nodeChanNum[nID] = nodeChans
log.Tracef("Counted %v channels for node %x", nodeChans,
nID[:])
return nil
})
if err != nil {
return err
}
// We keep track of the highest-degree node we've seen, as this
// will be given the max score.
if nodeChans > maxChans {
maxChans = nodeChans
}
// If this node is not among our nodes to score, we can return
// early.
nID := NodeID(n.PubKey())
if _, ok := nodes[nID]; !ok {
log.Tracef("Node %x not among nodes to score, "+
"ignoring", nID[:])
return nil
}
// Otherwise we'll record the number of channels.
nodeChanNum[nID] = nodeChans
log.Tracef("Counted %v channels for node %x", nodeChans, nID[:])
return nil
}, func() {
maxChans = 0
clear(nodeChanNum)
}); err != nil {
}, func() {
maxChans = 0
clear(nodeChanNum)
},
)
if err != nil {
return nil, err
}

View File

@@ -242,25 +242,23 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) {
numNodes := 0
twoChans := false
nodes := make(map[NodeID]struct{})
err = graph.ForEachNode(
ctx, func(ctx context.Context, n Node) error {
err = graph.ForEachNodesChannels(
ctx, func(_ context.Context, node Node,
edges []*ChannelEdge) error {
numNodes++
nodes[n.PubKey()] = struct{}{}
nodes[node.PubKey()] = struct{}{}
numChans := 0
err := n.ForEachChannel(ctx,
func(_ context.Context, c ChannelEdge) error { //nolint:ll
numChans++
return nil
},
)
if err != nil {
return err
for range edges {
numChans++
}
twoChans = twoChans || (numChans == 2)
return nil
}, func() {
},
func() {
numNodes = 0
twoChans = false
clear(nodes)

View File

@@ -50,20 +50,17 @@ func NewSimpleGraph(ctx context.Context, g ChannelGraph) (*SimpleGraph, error) {
// Iterate over each node and each channel and update the adj and the
// node index.
err := g.ForEachNode(ctx, func(ctx context.Context, node Node) error {
err := g.ForEachNodesChannels(ctx, func(_ context.Context,
node Node, channels []*ChannelEdge) error {
u := getNodeIndex(node.PubKey())
return node.ForEachChannel(
ctx, func(_ context.Context,
edge ChannelEdge) error {
for _, edge := range channels {
v := getNodeIndex(edge.Peer)
adj[u] = append(adj[u], v)
}
v := getNodeIndex(edge.Peer)
adj[u] = append(adj[u], v)
return nil
},
)
return nil
}, func() {
clear(adj)
clear(nodes)