diff --git a/autopilot/graph.go b/autopilot/graph.go index d9913fb0f..68b964776 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -51,7 +51,9 @@ func ChannelGraphFromDatabase(db GraphSource) ChannelGraph { // channeldb.LightningNode. The wrapper method implement the autopilot.Node // interface. type dbNode struct { - tx graphdb.NodeRTx + tx graphdb.NodeRTx + pub [33]byte + addrs []net.Addr } // A compile time assertion to ensure dbNode meets the autopilot.Node @@ -64,7 +66,7 @@ var _ Node = (*dbNode)(nil) // // NOTE: Part of the autopilot.Node interface. func (d *dbNode) PubKey() [33]byte { - return d.tx.Node().PubKeyBytes + return d.pub } // Addrs returns a slice of publicly reachable public TCP addresses that the @@ -72,7 +74,7 @@ func (d *dbNode) PubKey() [33]byte { // // NOTE: Part of the autopilot.Node interface. func (d *dbNode) Addrs() []net.Addr { - return d.tx.Node().Addresses + return d.addrs } // ForEachChannel is a higher-order function that will be used to iterate @@ -125,13 +127,56 @@ func (d *databaseChannelGraph) ForEachNode(ctx context.Context, } node := &dbNode{ - tx: nodeTx, + tx: nodeTx, + pub: nodeTx.Node().PubKeyBytes, + addrs: nodeTx.Node().Addresses, } return cb(ctx, node) }, reset) } +// ForEachNodesChannels iterates through all connected nodes, and for each node, +// all the channels that connect to it. The passed callback will be called with +// the context, the Node itself, and a slice of ChannelEdge that connect to the +// node. +// +// NOTE: Part of the autopilot.ChannelGraph interface. +func (d *databaseChannelGraph) ForEachNodesChannels(ctx context.Context, + cb func(context.Context, Node, []*ChannelEdge) error, + reset func()) error { + + return d.db.ForEachNodeCached( + ctx, true, func(ctx context.Context, node route.Vertex, + addrs []net.Addr, + chans map[uint64]*graphdb.DirectedChannel) error { + + // We'll skip over any node that doesn't have any + // advertised addresses. As we won't be able to reach + // them to actually open any channels. + if len(addrs) == 0 { + return nil + } + + edges := make([]*ChannelEdge, 0, len(chans)) + for _, channel := range chans { + edges = append(edges, &ChannelEdge{ + ChanID: lnwire.NewShortChanIDFromInt( + channel.ChannelID, + ), + Capacity: channel.Capacity, + Peer: channel.OtherNode, + }) + } + + return cb(ctx, &dbNode{ + pub: node, + addrs: addrs, + }, edges) + }, reset, + ) +} + // databaseChannelGraphCached wraps a channeldb.ChannelGraph instance with the // necessary API to properly implement the autopilot.ChannelGraph interface. type databaseChannelGraphCached struct { @@ -227,6 +272,44 @@ func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context, }, reset) } +// ForEachNodesChannels iterates through all connected nodes, and for each node, +// all the channels that connect to it. The passed callback will be called with +// the context, the Node itself, and a slice of ChannelEdge that connect to the +// node. +// +// NOTE: Part of the autopilot.ChannelGraph interface. +func (dc *databaseChannelGraphCached) ForEachNodesChannels(ctx context.Context, + cb func(context.Context, Node, []*ChannelEdge) error, + reset func()) error { + + return dc.db.ForEachNodeCached(ctx, false, func(ctx context.Context, + n route.Vertex, _ []net.Addr, + channels map[uint64]*graphdb.DirectedChannel) error { + + edges := make([]*ChannelEdge, 0, len(channels)) + for cid, channel := range channels { + edges = append(edges, &ChannelEdge{ + ChanID: lnwire.NewShortChanIDFromInt(cid), + Capacity: channel.Capacity, + Peer: channel.OtherNode, + }) + } + + if len(channels) > 0 { + node := dbNodeCached{ + node: n, + channels: channels, + } + + if err := cb(ctx, node, edges); err != nil { + return err + } + } + + return nil + }, reset) +} + // memNode is a purely in-memory implementation of the autopilot.Node // interface. type memNode struct { diff --git a/autopilot/interface.go b/autopilot/interface.go index 7a543c0d8..f35f319c5 100644 --- a/autopilot/interface.go +++ b/autopilot/interface.go @@ -87,6 +87,14 @@ type ChannelGraph interface { // callback returns an error, then execution should be terminated. ForEachNode(context.Context, func(context.Context, Node) error, func()) error + + // ForEachNodesChannels iterates through all connected nodes, and for + // each node, all the channels that connect to it. The passed callback + // will be called with the context, the Node itself, and a slice of + // ChannelEdge that connect to the node. + ForEachNodesChannels(ctx context.Context, + cb func(context.Context, Node, []*ChannelEdge) error, + reset func()) error } // NodeScore is a tuple mapping a NodeID to a score indicating the preference diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index 58e5f8f12..a9f67f79a 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -606,6 +606,29 @@ func (m *memChannelGraph) ForEachNode(ctx context.Context, return nil } +// ForEachNodesChannels iterates through all connected nodes, and for each node, +// all the channels that connect to it. The passed callback will be called with +// the context, the Node itself, and a slice of ChannelEdge that connect to the +// node. +// +// NOTE: Part of the autopilot.ChannelGraph interface. +func (m *memChannelGraph) ForEachNodesChannels(ctx context.Context, + cb func(context.Context, Node, []*ChannelEdge) error, _ func()) error { + + for _, node := range m.graph { + edges := make([]*ChannelEdge, 0, len(node.chans)) + for i := range node.chans { + edges = append(edges, &node.chans[i]) + } + + if err := cb(ctx, node, edges); err != nil { + return err + } + } + + return nil +} + // randChanID generates a new random channel ID. func randChanID() lnwire.ShortChannelID { id := atomic.AddUint64(&chanIDCounter, 1)