diff --git a/autopilot/graph.go b/autopilot/graph.go index b14d48231..0d98bebd2 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -10,7 +10,6 @@ import ( "github.com/btcsuite/btcd/btcutil" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -51,11 +50,7 @@ func ChannelGraphFromDatabase(db *graphdb.ChannelGraph) ChannelGraph { // channeldb.LightningNode. The wrapper method implement the autopilot.Node // interface. type dbNode struct { - db *graphdb.ChannelGraph - - tx kvdb.RTx - - node *models.LightningNode + tx graphdb.NodeRTx } // A compile time assertion to ensure dbNode meets the autopilot.Node @@ -68,7 +63,7 @@ var _ Node = (*dbNode)(nil) // // NOTE: Part of the autopilot.Node interface. func (d *dbNode) PubKey() [33]byte { - return d.node.PubKeyBytes + return d.tx.Node().PubKeyBytes } // Addrs returns a slice of publicly reachable public TCP addresses that the @@ -76,7 +71,7 @@ func (d *dbNode) PubKey() [33]byte { // // NOTE: Part of the autopilot.Node interface. func (d *dbNode) Addrs() []net.Addr { - return d.node.Addresses + return d.tx.Node().Addresses } // ForEachChannel is a higher-order function that will be used to iterate @@ -86,43 +81,35 @@ func (d *dbNode) Addrs() []net.Addr { // // NOTE: Part of the autopilot.Node interface. func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { - return d.db.ForEachNodeChannelTx(d.tx, d.node.PubKeyBytes, - func(tx kvdb.RTx, ei *models.ChannelEdgeInfo, ep, - _ *models.ChannelEdgePolicy) error { + return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep, + _ *models.ChannelEdgePolicy) error { - // Skip channels for which no outgoing edge policy is - // available. - // - // TODO(joostjager): Ideally the case where channels - // have a nil policy should be supported, as autopilot - // is not looking at the policies. For now, it is not - // easily possible to get a reference to the other end - // LightningNode object without retrieving the policy. - if ep == nil { - return nil - } + // Skip channels for which no outgoing edge policy is available. + // + // TODO(joostjager): Ideally the case where channels have a nil + // policy should be supported, as autopilot is not looking at + // the policies. For now, it is not easily possible to get a + // reference to the other end LightningNode object without + // retrieving the policy. + if ep == nil { + return nil + } - node, err := d.db.FetchLightningNodeTx( - tx, ep.ToNode, - ) - if err != nil { - return err - } + node, err := d.tx.FetchNode(ep.ToNode) + if err != nil { + return err + } - edge := ChannelEdge{ - ChanID: lnwire.NewShortChanIDFromInt( - ep.ChannelID, - ), - Capacity: ei.Capacity, - Peer: &dbNode{ - tx: tx, - db: d.db, - node: node, - }, - } + edge := ChannelEdge{ + ChanID: lnwire.NewShortChanIDFromInt(ep.ChannelID), + Capacity: ei.Capacity, + Peer: &dbNode{ + tx: node, + }, + } - return cb(edge) - }) + return cb(edge) + }) } // ForEachNode is a higher-order function that should be called once for each @@ -131,20 +118,16 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { // // NOTE: Part of the autopilot.ChannelGraph interface. func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { - return d.db.ForEachNode(func(tx kvdb.RTx, - n *models.LightningNode) error { - + return d.db.ForEachNode(func(nodeTx graphdb.NodeRTx) error { // We'll skip over any node that doesn't have any advertised // addresses. As we won't be able to reach them to actually // open any channels. - if len(n.Addresses) == 0 { + if len(nodeTx.Node().Addresses) == 0 { return nil } node := &dbNode{ - db: d.db, - tx: tx, - node: n, + tx: nodeTx, } return cb(node) }) diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index 038fcbf35..d30174031 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -514,18 +514,18 @@ func (d *testDBGraph) addRandChannel(node1, node2 *btcec.PublicKey, return &ChannelEdge{ ChanID: chanID, Capacity: capacity, - Peer: &dbNode{ - db: d.db, + Peer: &dbNode{tx: &testNodeTx{ + db: d, node: vertex1, - }, + }}, }, &ChannelEdge{ ChanID: chanID, Capacity: capacity, - Peer: &dbNode{ - db: d.db, + Peer: &dbNode{tx: &testNodeTx{ + db: d, node: vertex2, - }, + }}, }, nil } @@ -702,3 +702,37 @@ func (m *memChannelGraph) addRandNode() (*btcec.PublicKey, error) { return newPub, nil } + +type testNodeTx struct { + db *testDBGraph + node *models.LightningNode +} + +func (t *testNodeTx) Node() *models.LightningNode { + return t.node +} + +func (t *testNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { + + return t.db.db.ForEachNodeChannel(t.node.PubKeyBytes, func(_ kvdb.RTx, + edge *models.ChannelEdgeInfo, policy1, + policy2 *models.ChannelEdgePolicy) error { + + return f(edge, policy1, policy2) + }) +} + +func (t *testNodeTx) FetchNode(pub route.Vertex) (graphdb.NodeRTx, error) { + node, err := t.db.db.FetchLightningNode(pub) + if err != nil { + return nil, err + } + + return &testNodeTx{ + db: t.db, + node: node, + }, nil +} + +var _ graphdb.NodeRTx = (*testNodeTx)(nil) diff --git a/graph/db/graph.go b/graph/db/graph.go index fb1dd941f..06bd0257d 100644 --- a/graph/db/graph.go +++ b/graph/db/graph.go @@ -588,9 +588,9 @@ func (c *ChannelGraph) FetchNodeFeatures( } } -// ForEachNodeCached is similar to ForEachNode, but it utilizes the channel +// ForEachNodeCached is similar to forEachNode, but it utilizes the channel // graph cache instead. Note that this doesn't return all the information the -// regular ForEachNode method does. +// regular forEachNode method does. // // NOTE: The callback contents MUST not be modified. func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, @@ -604,7 +604,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, // We'll iterate over each node, then the set of channels for each // node, and construct a similar callback functiopn signature as the // main funcotin expects. - return c.ForEachNode(func(tx kvdb.RTx, + return c.forEachNode(func(tx kvdb.RTx, node *models.LightningNode) error { channels := make(map[uint64]*DirectedChannel) @@ -716,11 +716,25 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { // ForEachNode iterates through all the stored vertices/nodes in the graph, // executing the passed callback with each node encountered. If the callback // returns an error, then the transaction is aborted and the iteration stops +// early. Any operations performed on the NodeTx passed to the call-back are +// executed under the same read transaction and so, methods on the NodeTx object +// _MUST_ only be called from within the call-back. +func (c *ChannelGraph) ForEachNode(cb func(tx NodeRTx) error) error { + return c.forEachNode(func(tx kvdb.RTx, + node *models.LightningNode) error { + + return cb(newChanGraphNodeTx(tx, c, node)) + }) +} + +// forEachNode iterates through all the stored vertices/nodes in the graph, +// executing the passed callback with each node encountered. If the callback +// returns an error, then the transaction is aborted and the iteration stops // early. // // TODO(roasbeef): add iterator interface to allow for memory efficient graph // traversal when graph gets mega -func (c *ChannelGraph) ForEachNode( +func (c *ChannelGraph) forEachNode( cb func(kvdb.RTx, *models.LightningNode) error) error { traversal := func(tx kvdb.RTx) error { diff --git a/graph/db/graph_cache_test.go b/graph/db/graph_cache_test.go index 3f140c4c5..69abb2597 100644 --- a/graph/db/graph_cache_test.go +++ b/graph/db/graph_cache_test.go @@ -121,7 +121,7 @@ func TestGraphCacheAddNode(t *testing.T) { assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy) // Now that we've inserted two nodes into the graph, check that - // we'll recover the same set of channels during ForEachNode. + // we'll recover the same set of channels during forEachNode. nodes := make(map[route.Vertex]struct{}) chans := make(map[uint64]struct{}) _ = cache.ForEachNode(func(node route.Vertex, diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index d048cdafd..3b3454b6b 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1092,7 +1092,7 @@ func TestGraphTraversalCacheable(t *testing.T) { // Create a map of all nodes with the iteration we know works (because // it is tested in another test). nodeMap := make(map[route.Vertex]struct{}) - err = graph.ForEachNode( + err = graph.forEachNode( func(tx kvdb.RTx, n *models.LightningNode) error { nodeMap[n.PubKeyBytes] = struct{}{} @@ -1217,7 +1217,7 @@ func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, // Iterate over each node as returned by the graph, if all nodes are // reached, then the map created above should be empty. - err := graph.ForEachNode( + err := graph.forEachNode( func(_ kvdb.RTx, node *models.LightningNode) error { delete(nodeIndex, node.Alias) return nil @@ -1329,7 +1329,7 @@ func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { numNodes := 0 - err := graph.ForEachNode( + err := graph.forEachNode( func(_ kvdb.RTx, _ *models.LightningNode) error { numNodes++ return nil diff --git a/rpcserver.go b/rpcserver.go index 72e2fa4af..2e17eb75c 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6533,10 +6533,8 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // First iterate through all the known nodes (connected or unconnected // within the graph), collating their current state into the RPC // response. - err := graph.ForEachNode(func(_ kvdb.RTx, - node *models.LightningNode) error { - - lnNode := marshalNode(node) + err := graph.ForEachNode(func(nodeTx graphdb.NodeRTx) error { + lnNode := marshalNode(nodeTx.Node()) resp.Nodes = append(resp.Nodes, lnNode)