diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 5e445f10d..825096b06 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -618,7 +618,7 @@ func createChannelEdge(node1, node2 *models.LightningNode) ( chanID := uint64(prand.Int63()) outpoint := wire.OutPoint{ Hash: rev, - Index: 9, + Index: prand.Uint32(), } // Add the new edge to the database, this should proceed without any @@ -991,6 +991,97 @@ func newEdgePolicy(chanID uint64, updateTime int64) *models.ChannelEdgePolicy { } } +// TestForEachSourceNodeChannel tests that the ForEachSourceNodeChannel +// correctly iterates through the channels of the set source node. +func TestForEachSourceNodeChannel(t *testing.T) { + t.Parallel() + + graph, err := MakeTestGraph(t) + require.NoError(t, err, "unable to make test database") + + // Create a source node (A) and set it as such in the DB. + nodeA := createTestVertex(t) + require.NoError(t, graph.SetSourceNode(nodeA)) + + // Now, create a few more nodes (B, C, D) along with some channels + // between them. We'll create the following graph: + // + // A -- B -- D + // | + // C + // + // The graph includes a channel (B-D) that does not belong to the source + // node along with 2 channels (A-B and A-C) that do belong to the source + // node. For the A-B channel, we will let the source node set an + // outgoing policy but for the A-C channel, we will set only an incoming + // policy. + + nodeB := createTestVertex(t) + nodeC := createTestVertex(t) + nodeD := createTestVertex(t) + + abEdge, abPolicy1, abPolicy2 := createChannelEdge(nodeA, nodeB) + require.NoError(t, graph.AddChannelEdge(abEdge)) + acEdge, acPolicy1, acPolicy2 := createChannelEdge(nodeA, nodeC) + require.NoError(t, graph.AddChannelEdge(acEdge)) + bdEdge, _, _ := createChannelEdge(nodeB, nodeD) + require.NoError(t, graph.AddChannelEdge(bdEdge)) + + // Figure out which of the policies returned above are node A's so that + // we know which to persist. + // + // First, set the outgoing policy for the A-B channel. + abPolicyAOutgoing := abPolicy1 + if !bytes.Equal(abPolicy1.ToNode[:], nodeB.PubKeyBytes[:]) { + abPolicyAOutgoing = abPolicy2 + } + require.NoError(t, graph.UpdateEdgePolicy(abPolicyAOutgoing)) + + // Now, set the incoming policy for the A-C channel. + acPolicyAIncoming := acPolicy1 + if !bytes.Equal(acPolicy1.ToNode[:], nodeA.PubKeyBytes[:]) { + acPolicyAIncoming = acPolicy2 + } + require.NoError(t, graph.UpdateEdgePolicy(acPolicyAIncoming)) + + type sourceNodeChan struct { + otherNode route.Vertex + havePolicy bool + } + + // Put together our expected source node channels. + expectedSrcChans := map[wire.OutPoint]*sourceNodeChan{ + abEdge.ChannelPoint: { + otherNode: nodeB.PubKeyBytes, + havePolicy: true, + }, + acEdge.ChannelPoint: { + otherNode: nodeC.PubKeyBytes, + havePolicy: false, + }, + } + + // Now, we'll use the ForEachSourceNodeChannel and assert that it + // returns the expected data in the call-back. + err = graph.ForEachSourceNodeChannel(func(chanPoint wire.OutPoint, + havePolicy bool, otherNode *models.LightningNode) error { + + require.Contains(t, expectedSrcChans, chanPoint) + expected := expectedSrcChans[chanPoint] + + require.Equal( + t, expected.otherNode[:], otherNode.PubKeyBytes[:], + ) + require.Equal(t, expected.havePolicy, havePolicy) + + delete(expectedSrcChans, chanPoint) + + return nil + }) + require.NoError(t, err) + require.Empty(t, expectedSrcChans) +} + func TestGraphTraversal(t *testing.T) { t.Parallel() diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 50d507185..68beb7a5d 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -3109,6 +3109,44 @@ func (c *KVStore) ForEachNodeChannel(nodePub route.Vertex, return nodeTraversal(nil, nodePub[:], c.db, cb) } +// ForEachSourceNodeChannel iterates through all channels of the source node, +// executing the passed callback on each. The callback is provided with the +// channel's outpoint, whether we have a policy for the channel and the channel +// peer's node information. +func (c *KVStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint, + havePolicy bool, otherNode *models.LightningNode) error) error { + + return kvdb.View(c.db, func(tx kvdb.RTx) error { + nodes := tx.ReadBucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + node, err := c.sourceNode(nodes) + if err != nil { + return err + } + + return nodeTraversal( + tx, node.PubKeyBytes[:], c.db, func(tx kvdb.RTx, + info *models.ChannelEdgeInfo, + policy, _ *models.ChannelEdgePolicy) error { + + peer, err := c.FetchOtherNode( + tx, info, node.PubKeyBytes[:], + ) + if err != nil { + return err + } + + return cb( + info.ChannelPoint, policy != nil, peer, + ) + }, + ) + }, func() {}) +} + // ForEachNodeChannelTx iterates through all channels of the given node, // executing the passed callback with an edge info structure and the policies // of each end of the channel. The first edge policy is the outgoing edge *to* diff --git a/server.go b/server.go index 20ae1b984..039e7f4d7 100644 --- a/server.go +++ b/server.go @@ -3553,36 +3553,17 @@ func (s *server) establishPersistentConnections() error { // After checking our previous connections for addresses to connect to, // iterate through the nodes in our channel graph to find addresses // that have been added via NodeAnnouncement messages. - sourceNode, err := s.graphDB.SourceNode() - if err != nil { - return fmt.Errorf("failed to fetch source node: %w", err) - } - // TODO(roasbeef): instead iterate over link nodes and query graph for // each of the nodes. - selfPub := s.identityECDH.PubKey().SerializeCompressed() - err = s.graphDB.ForEachNodeChannel(sourceNode.PubKeyBytes, func( - tx kvdb.RTx, - chanInfo *models.ChannelEdgeInfo, - policy, _ *models.ChannelEdgePolicy) error { + err = s.graphDB.ForEachSourceNodeChannel(func(chanPoint wire.OutPoint, + havePolicy bool, channelPeer *models.LightningNode) error { // If the remote party has announced the channel to us, but we // haven't yet, then we won't have a policy. However, we don't // need this to connect to the peer, so we'll log it and move on. - if policy == nil { + if !havePolicy { srvrLog.Warnf("No channel policy found for "+ - "ChannelPoint(%v): ", chanInfo.ChannelPoint) - } - - // We'll now fetch the peer opposite from us within this - // channel so we can queue up a direct connection to them. - channelPeer, err := s.graphDB.FetchOtherNode( - tx, chanInfo, selfPub, - ) - if err != nil { - return fmt.Errorf("unable to fetch channel peer for "+ - "ChannelPoint(%v): %v", chanInfo.ChannelPoint, - err) + "ChannelPoint(%v): ", chanPoint) } pubStr := string(channelPeer.PubKeyBytes[:]) @@ -3642,8 +3623,8 @@ func (s *server) establishPersistentConnections() error { return nil }) if err != nil { - srvrLog.Errorf("Failed to iterate channels for node %x", - sourceNode.PubKeyBytes) + srvrLog.Errorf("Failed to iterate over source node channels: "+ + "%v", err) if !errors.Is(err, graphdb.ErrGraphNoEdgesFound) && !errors.Is(err, graphdb.ErrEdgeNotFound) {