diff --git a/channeldb/graph.go b/channeldb/graph.go index cd746adc7..118f1882d 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -458,6 +458,72 @@ func (c *ChannelGraph) FetchNodeFeatures( } } +// ForEachNodeCached is similar to ForEachNode, but it utilizes the channel +// graph cache instead. Note that this doesn't return all the information the +// regular ForEachNode method does. +// +// NOTE: The callback contents MUST not be modified. +func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, + chans map[uint64]*DirectedChannel) error) error { + + if c.graphCache != nil { + return c.graphCache.ForEachNode(cb) + } + + // Otherwise call back to a version that uses the database directly. + // We'll iterate over each node, then the set of channels for each + // node, and construct a similar callback functiopn signature as the + // main funcotin expects. + return c.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error { + channels := make(map[uint64]*DirectedChannel) + + err := node.ForEachChannel(tx, func(tx kvdb.RTx, + e *ChannelEdgeInfo, p1 *ChannelEdgePolicy, + p2 *ChannelEdgePolicy) error { + + toNodeCallback := func() route.Vertex { + return node.PubKeyBytes + } + toNodeFeatures, err := c.FetchNodeFeatures( + node.PubKeyBytes, + ) + if err != nil { + return err + } + + var cachedInPolicy *CachedEdgePolicy + if p2 != nil { + cachedInPolicy := NewCachedPolicy(p2) + cachedInPolicy.ToNodePubKey = toNodeCallback + cachedInPolicy.ToNodeFeatures = toNodeFeatures + } + + directedChannel := &DirectedChannel{ + ChannelID: e.ChannelID, + IsNode1: node.PubKeyBytes == e.NodeKey1Bytes, + OtherNode: e.NodeKey2Bytes, + Capacity: e.Capacity, + OutPolicySet: p1 != nil, + InPolicy: cachedInPolicy, + } + + if node.PubKeyBytes == e.NodeKey2Bytes { + directedChannel.OtherNode = e.NodeKey1Bytes + } + + channels[e.ChannelID] = directedChannel + + return nil + + }) + if err != nil { + return err + } + + return cb(node.PubKeyBytes, channels) + }) +} + // DisabledChannelIDs returns the channel ids of disabled channels. // A channel is disabled when two of the associated ChanelEdgePolicies // have their disabled bit on. diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index 891004bfc..3618b6d43 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -447,6 +447,30 @@ func (c *GraphCache) ForEachChannel(node route.Vertex, return nil } +// ForEachNode iterates over the adjacency list of the graph, executing the +// call back for each node and the set of channels that emanate from the given +// node. +// +// NOTE: This method should be considered _read only_, the channels or nodes +// passed in MUST NOT be modified. +func (c *GraphCache) ForEachNode(cb func(node route.Vertex, + channels map[uint64]*DirectedChannel) error) error { + + c.mtx.RLock() + defer c.mtx.RUnlock() + + for node, channels := range c.nodeChannels { + // We don't make a copy here since this is a read-only RPC + // call. We also don't need the node features either for this + // call. + if err := cb(node, channels); err != nil { + return err + } + } + + return nil +} + // GetFeatures returns the features of the node with the given ID. If no // features are known for the node, an empty feature vector is returned. func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector { diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go index f8d2b9a5f..b408ec36d 100644 --- a/channeldb/graph_cache_test.go +++ b/channeldb/graph_cache_test.go @@ -120,6 +120,24 @@ func TestGraphCacheAddNode(t *testing.T) { require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet) assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy) + + // Now that we've inserted two nodes into the graph, check that + // we'll recover the same set of channels during ForEachNode. + nodes := make(map[route.Vertex]struct{}) + chans := make(map[uint64]struct{}) + _ = cache.ForEachNode(func(node route.Vertex, + edges map[uint64]*DirectedChannel) error { + + nodes[node] = struct{}{} + for chanID := range edges { + chans[chanID] = struct{}{} + } + + return nil + }) + + require.Len(t, nodes, 2) + require.Len(t, chans, 1) } runTest(pubKey1, pubKey2) diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index eb81026c1..36ee20095 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -1103,6 +1103,34 @@ func TestGraphTraversal(t *testing.T) { const numChannels = 5 chanIndex, nodeList := fillTestGraph(t, graph, numNodes, numChannels) + // Make an index of the node list for easy look up below. + nodeIndex := make(map[route.Vertex]struct{}) + for _, node := range nodeList { + nodeIndex[node.PubKeyBytes] = struct{}{} + } + + // If we turn the channel graph cache _off_, then iterate through the + // set of channels (to force the fall back), we should find all the + // channel as well as the nodes included. + graph.graphCache = nil + err = graph.ForEachNodeCached(func(node route.Vertex, + chans map[uint64]*DirectedChannel) error { + + if _, ok := nodeIndex[node]; !ok { + return fmt.Errorf("node %x not found in graph", node) + } + + for chanID := range chans { + if _, ok := chanIndex[chanID]; !ok { + return fmt.Errorf("chan %v not found in "+ + "graph", chanID) + } + } + + return nil + }) + require.NoError(t, err) + // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have // properly been reached.