autopilot: add ForEachNodesChannels method to ChannelGraph interface

This is in preparation for removing the ForEachChannel method from the
Node interface.
This commit is contained in:
Elle Mouton
2025-08-03 17:18:35 +02:00
parent 5727bfa688
commit ce7fe84da7
3 changed files with 118 additions and 4 deletions

View File

@@ -51,7 +51,9 @@ func ChannelGraphFromDatabase(db GraphSource) ChannelGraph {
// channeldb.LightningNode. The wrapper method implement the autopilot.Node
// interface.
type dbNode struct {
tx graphdb.NodeRTx
tx graphdb.NodeRTx
pub [33]byte
addrs []net.Addr
}
// A compile time assertion to ensure dbNode meets the autopilot.Node
@@ -64,7 +66,7 @@ var _ Node = (*dbNode)(nil)
//
// NOTE: Part of the autopilot.Node interface.
func (d *dbNode) PubKey() [33]byte {
return d.tx.Node().PubKeyBytes
return d.pub
}
// Addrs returns a slice of publicly reachable public TCP addresses that the
@@ -72,7 +74,7 @@ func (d *dbNode) PubKey() [33]byte {
//
// NOTE: Part of the autopilot.Node interface.
func (d *dbNode) Addrs() []net.Addr {
return d.tx.Node().Addresses
return d.addrs
}
// ForEachChannel is a higher-order function that will be used to iterate
@@ -125,13 +127,56 @@ func (d *databaseChannelGraph) ForEachNode(ctx context.Context,
}
node := &dbNode{
tx: nodeTx,
tx: nodeTx,
pub: nodeTx.Node().PubKeyBytes,
addrs: nodeTx.Node().Addresses,
}
return cb(ctx, node)
}, reset)
}
// ForEachNodesChannels iterates through all connected nodes, and for each node,
// all the channels that connect to it. The passed callback will be called with
// the context, the Node itself, and a slice of ChannelEdge that connect to the
// node.
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (d *databaseChannelGraph) ForEachNodesChannels(ctx context.Context,
cb func(context.Context, Node, []*ChannelEdge) error,
reset func()) error {
return d.db.ForEachNodeCached(
ctx, true, func(ctx context.Context, node route.Vertex,
addrs []net.Addr,
chans map[uint64]*graphdb.DirectedChannel) error {
// We'll skip over any node that doesn't have any
// advertised addresses. As we won't be able to reach
// them to actually open any channels.
if len(addrs) == 0 {
return nil
}
edges := make([]*ChannelEdge, 0, len(chans))
for _, channel := range chans {
edges = append(edges, &ChannelEdge{
ChanID: lnwire.NewShortChanIDFromInt(
channel.ChannelID,
),
Capacity: channel.Capacity,
Peer: channel.OtherNode,
})
}
return cb(ctx, &dbNode{
pub: node,
addrs: addrs,
}, edges)
}, reset,
)
}
// databaseChannelGraphCached wraps a channeldb.ChannelGraph instance with the
// necessary API to properly implement the autopilot.ChannelGraph interface.
type databaseChannelGraphCached struct {
@@ -227,6 +272,44 @@ func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context,
}, reset)
}
// ForEachNodesChannels iterates through all connected nodes, and for each node,
// all the channels that connect to it. The passed callback will be called with
// the context, the Node itself, and a slice of ChannelEdge that connect to the
// node.
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (dc *databaseChannelGraphCached) ForEachNodesChannels(ctx context.Context,
cb func(context.Context, Node, []*ChannelEdge) error,
reset func()) error {
return dc.db.ForEachNodeCached(ctx, false, func(ctx context.Context,
n route.Vertex, _ []net.Addr,
channels map[uint64]*graphdb.DirectedChannel) error {
edges := make([]*ChannelEdge, 0, len(channels))
for cid, channel := range channels {
edges = append(edges, &ChannelEdge{
ChanID: lnwire.NewShortChanIDFromInt(cid),
Capacity: channel.Capacity,
Peer: channel.OtherNode,
})
}
if len(channels) > 0 {
node := dbNodeCached{
node: n,
channels: channels,
}
if err := cb(ctx, node, edges); err != nil {
return err
}
}
return nil
}, reset)
}
// memNode is a purely in-memory implementation of the autopilot.Node
// interface.
type memNode struct {

View File

@@ -87,6 +87,14 @@ type ChannelGraph interface {
// callback returns an error, then execution should be terminated.
ForEachNode(context.Context, func(context.Context, Node) error,
func()) error
// ForEachNodesChannels iterates through all connected nodes, and for
// each node, all the channels that connect to it. The passed callback
// will be called with the context, the Node itself, and a slice of
// ChannelEdge that connect to the node.
ForEachNodesChannels(ctx context.Context,
cb func(context.Context, Node, []*ChannelEdge) error,
reset func()) error
}
// NodeScore is a tuple mapping a NodeID to a score indicating the preference

View File

@@ -606,6 +606,29 @@ func (m *memChannelGraph) ForEachNode(ctx context.Context,
return nil
}
// ForEachNodesChannels iterates through all connected nodes, and for each node,
// all the channels that connect to it. The passed callback will be called with
// the context, the Node itself, and a slice of ChannelEdge that connect to the
// node.
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (m *memChannelGraph) ForEachNodesChannels(ctx context.Context,
cb func(context.Context, Node, []*ChannelEdge) error, _ func()) error {
for _, node := range m.graph {
edges := make([]*ChannelEdge, 0, len(node.chans))
for i := range node.chans {
edges = append(edges, &node.chans[i])
}
if err := cb(ctx, node, edges); err != nil {
return err
}
}
return nil
}
// randChanID generates a new random channel ID.
func randChanID() lnwire.ShortChannelID {
id := atomic.AddUint64(&chanIDCounter, 1)