diff --git a/autopilot/graph.go b/autopilot/graph.go index 68b964776..a2a2e02ff 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -10,7 +10,6 @@ import ( "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" graphdb "github.com/lightningnetwork/lnd/graph/db" - "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -51,7 +50,6 @@ func ChannelGraphFromDatabase(db GraphSource) ChannelGraph { // channeldb.LightningNode. The wrapper method implement the autopilot.Node // interface. type dbNode struct { - tx graphdb.NodeRTx pub [33]byte addrs []net.Addr } @@ -77,39 +75,6 @@ func (d *dbNode) Addrs() []net.Addr { return d.addrs } -// ForEachChannel is a higher-order function that will be used to iterate -// through all edges emanating from/to the target node. For each active -// channel, this function should be called with the populated ChannelEdge that -// describes the active channel. -// -// NOTE: Part of the autopilot.Node interface. -func (d *dbNode) ForEachChannel(ctx context.Context, - cb func(context.Context, ChannelEdge) error) error { - - return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep, - _ *models.ChannelEdgePolicy) error { - - // Skip channels for which no outgoing edge policy is available. - // - // TODO(joostjager): Ideally the case where channels have a nil - // policy should be supported, as autopilot is not looking at - // the policies. For now, it is not easily possible to get a - // reference to the other end LightningNode object without - // retrieving the policy. - if ep == nil { - return nil - } - - edge := ChannelEdge{ - ChanID: lnwire.NewShortChanIDFromInt(ep.ChannelID), - Capacity: ei.Capacity, - Peer: ep.ToNode, - } - - return cb(ctx, edge) - }) -} - // ForEachNode is a higher-order function that should be called once for each // connected node within the channel graph. If the passed callback returns an // error, then execution should be terminated. @@ -127,7 +92,6 @@ func (d *databaseChannelGraph) ForEachNode(ctx context.Context, } node := &dbNode{ - tx: nodeTx, pub: nodeTx.Node().PubKeyBytes, addrs: nodeTx.Node().Addresses, } @@ -223,30 +187,6 @@ func (nc dbNodeCached) Addrs() []net.Addr { return []net.Addr{} } -// ForEachChannel is a higher-order function that will be used to iterate -// through all edges emanating from/to the target node. For each active -// channel, this function should be called with the populated ChannelEdge that -// describes the active channel. -// -// NOTE: Part of the autopilot.Node interface. -func (nc dbNodeCached) ForEachChannel(ctx context.Context, - cb func(context.Context, ChannelEdge) error) error { - - for cid, channel := range nc.channels { - edge := ChannelEdge{ - ChanID: lnwire.NewShortChanIDFromInt(cid), - Capacity: channel.Capacity, - Peer: channel.OtherNode, - } - - if err := cb(ctx, edge); err != nil { - return err - } - } - - return nil -} - // ForEachNode is a higher-order function that should be called once for each // connected node within the channel graph. If the passed callback returns an // error, then execution should be terminated. @@ -343,24 +283,6 @@ func (m memNode) Addrs() []net.Addr { return m.addrs } -// ForEachChannel is a higher-order function that will be used to iterate -// through all edges emanating from/to the target node. For each active -// channel, this function should be called with the populated ChannelEdge that -// describes the active channel. -// -// NOTE: Part of the autopilot.Node interface. -func (m memNode) ForEachChannel(ctx context.Context, - cb func(context.Context, ChannelEdge) error) error { - - for _, channel := range m.chans { - if err := cb(ctx, channel); err != nil { - return err - } - } - - return nil -} - // Median returns the median value in the slice of Amounts. func Median(vals []btcutil.Amount) btcutil.Amount { sort.Slice(vals, func(i, j int) bool { diff --git a/autopilot/interface.go b/autopilot/interface.go index f35f319c5..1554b03ee 100644 --- a/autopilot/interface.go +++ b/autopilot/interface.go @@ -31,13 +31,6 @@ type Node interface { // Addrs returns a slice of publicly reachable public TCP addresses // that the peer is known to be listening on. Addrs() []net.Addr - - // ForEachChannel is a higher-order function that will be used to - // iterate through all edges emanating from/to the target node. For - // each active channel, this function should be called with the - // populated ChannelEdge that describes the active channel. - ForEachChannel(context.Context, func(context.Context, - ChannelEdge) error) error } // LocalChannel is a simple struct which contains relevant details of a diff --git a/autopilot/prefattach.go b/autopilot/prefattach.go index e9a9c53b4..267c13db3 100644 --- a/autopilot/prefattach.go +++ b/autopilot/prefattach.go @@ -89,26 +89,26 @@ func (p *PrefAttachment) NodeScores(ctx context.Context, g ChannelGraph, allChans []btcutil.Amount seenChans = make(map[uint64]struct{}) ) - if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error { - err := n.ForEachChannel(ctx, func(_ context.Context, - e ChannelEdge) error { + err := g.ForEachNodesChannels( + ctx, func(_ context.Context, node Node, + channels []*ChannelEdge) error { - if _, ok := seenChans[e.ChanID.ToUint64()]; ok { - return nil + for _, e := range channels { + if _, ok := seenChans[e.ChanID.ToUint64()]; ok { + continue + } + seenChans[e.ChanID.ToUint64()] = struct{}{} + allChans = append(allChans, e.Capacity) } - seenChans[e.ChanID.ToUint64()] = struct{}{} - allChans = append(allChans, e.Capacity) - return nil - }) - if err != nil { - return err - } - return nil - }, func() { - allChans = nil - clear(seenChans) - }); err != nil { + return nil + }, + func() { + allChans = nil + clear(seenChans) + }, + ) + if err != nil { return nil, err } @@ -120,55 +120,59 @@ func (p *PrefAttachment) NodeScores(ctx context.Context, g ChannelGraph, // the graph. var maxChans int nodeChanNum := make(map[NodeID]int) - if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error { - var nodeChans int - err := n.ForEachChannel(ctx, func(_ context.Context, - e ChannelEdge) error { + err = g.ForEachNodesChannels( + ctx, func(ctx context.Context, node Node, + edges []*ChannelEdge) error { - // Since connecting to nodes with a lot of small - // channels actually worsens our connectivity in the - // graph (we will potentially waste time trying to use - // these useless channels in path finding), we decrease - // the counter for such channels. - if e.Capacity < - medianChanSize/minMedianChanSizeFraction { + var nodeChans int + for _, e := range edges { + // Since connecting to nodes with a lot of small + // channels actually worsens our connectivity in + // the graph (we will potentially waste time + // trying to use these useless channels in path + // finding), we decrease the counter for such + // channels. + // + //nolint:ll + if e.Capacity < + medianChanSize/minMedianChanSizeFraction { - nodeChans-- + nodeChans-- + + continue + } + + // Larger channels we count. + nodeChans++ + } + + // We keep track of the highest-degree node we've seen, + // as this will be given the max score. + if nodeChans > maxChans { + maxChans = nodeChans + } + + // If this node is not among our nodes to score, we can + // return early. + nID := NodeID(node.PubKey()) + if _, ok := nodes[nID]; !ok { + log.Tracef("Node %x not among nodes to score, "+ + "ignoring", nID[:]) return nil } - // Larger channels we count. - nodeChans++ + // Otherwise we'll record the number of channels. + nodeChanNum[nID] = nodeChans + log.Tracef("Counted %v channels for node %x", nodeChans, + nID[:]) + return nil - }) - if err != nil { - return err - } - - // We keep track of the highest-degree node we've seen, as this - // will be given the max score. - if nodeChans > maxChans { - maxChans = nodeChans - } - - // If this node is not among our nodes to score, we can return - // early. - nID := NodeID(n.PubKey()) - if _, ok := nodes[nID]; !ok { - log.Tracef("Node %x not among nodes to score, "+ - "ignoring", nID[:]) - return nil - } - - // Otherwise we'll record the number of channels. - nodeChanNum[nID] = nodeChans - log.Tracef("Counted %v channels for node %x", nodeChans, nID[:]) - - return nil - }, func() { - maxChans = 0 - clear(nodeChanNum) - }); err != nil { + }, func() { + maxChans = 0 + clear(nodeChanNum) + }, + ) + if err != nil { return nil, err } diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index a9f67f79a..f0f351384 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -242,25 +242,23 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { numNodes := 0 twoChans := false nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode( - ctx, func(ctx context.Context, n Node) error { + err = graph.ForEachNodesChannels( + ctx, func(_ context.Context, node Node, + edges []*ChannelEdge) error { + numNodes++ - nodes[n.PubKey()] = struct{}{} + nodes[node.PubKey()] = struct{}{} numChans := 0 - err := n.ForEachChannel(ctx, - func(_ context.Context, c ChannelEdge) error { //nolint:ll - numChans++ - return nil - }, - ) - if err != nil { - return err + + for range edges { + numChans++ } twoChans = twoChans || (numChans == 2) return nil - }, func() { + }, + func() { numNodes = 0 twoChans = false clear(nodes) diff --git a/autopilot/simple_graph.go b/autopilot/simple_graph.go index 32c4559c8..44f514903 100644 --- a/autopilot/simple_graph.go +++ b/autopilot/simple_graph.go @@ -50,20 +50,17 @@ func NewSimpleGraph(ctx context.Context, g ChannelGraph) (*SimpleGraph, error) { // Iterate over each node and each channel and update the adj and the // node index. - err := g.ForEachNode(ctx, func(ctx context.Context, node Node) error { + err := g.ForEachNodesChannels(ctx, func(_ context.Context, + node Node, channels []*ChannelEdge) error { + u := getNodeIndex(node.PubKey()) - return node.ForEachChannel( - ctx, func(_ context.Context, - edge ChannelEdge) error { + for _, edge := range channels { + v := getNodeIndex(edge.Peer) + adj[u] = append(adj[u], v) + } - v := getNodeIndex(edge.Peer) - - adj[u] = append(adj[u], v) - - return nil - }, - ) + return nil }, func() { clear(adj) clear(nodes)