diff --git a/autopilot/graph.go b/autopilot/graph.go index fe74ad1fe..f1506b5d7 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -121,7 +121,7 @@ func (d *dbNode) ForEachChannel(ctx context.Context, // // NOTE: Part of the autopilot.ChannelGraph interface. func (d *databaseChannelGraph) ForEachNode(ctx context.Context, - cb func(context.Context, Node) error, _ func()) error { + cb func(context.Context, Node) error, reset func()) error { return d.db.ForEachNode(ctx, func(nodeTx graphdb.NodeRTx) error { // We'll skip over any node that doesn't have any advertised @@ -136,7 +136,7 @@ func (d *databaseChannelGraph) ForEachNode(ctx context.Context, } return cb(ctx, node) - }) + }, reset) } // databaseChannelGraphCached wraps a channeldb.ChannelGraph instance with the diff --git a/autopilot/interface.go b/autopilot/interface.go index 69eb4e6f1..a4efa2eec 100644 --- a/autopilot/interface.go +++ b/autopilot/interface.go @@ -229,7 +229,7 @@ type GraphSource interface { // the callback returns an error, then the transaction is aborted and // the iteration stops early. Any operations performed on the NodeTx // passed to the call-back are executed under the same read transaction. - ForEachNode(context.Context, func(graphdb.NodeRTx) error) error + ForEachNode(context.Context, func(graphdb.NodeRTx) error, func()) error // ForEachNodeCached is similar to ForEachNode, but it utilizes the // channel graph cache if one is available. It is less consistent than diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 20b097541..3a7f9b9de 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1454,7 +1454,7 @@ func TestGraphTraversalCacheable(t *testing.T) { nodeMap[tx.Node().PubKeyBytes] = struct{}{} return nil - }) + }, func() {}) require.NoError(t, err) require.Len(t, nodeMap, numNodes) @@ -1583,7 +1583,7 @@ func fillTestGraph(t testing.TB, graph *ChannelGraph, numNodes, err := graph.ForEachNode(ctx, func(tx NodeRTx) error { delete(nodeIndex, tx.Node().Alias) return nil - }) + }, func() {}) require.NoError(t, err) require.Len(t, nodeIndex, 0) @@ -1693,7 +1693,7 @@ func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { numNodes++ return nil - }) + }, func() {}) if err != nil { _, _, line, _ := runtime.Caller(1) t.Fatalf("line %v: unable to scan nodes: %v", line, err) diff --git a/graph/db/interfaces.go b/graph/db/interfaces.go index e47935c11..616aabc9f 100644 --- a/graph/db/interfaces.go +++ b/graph/db/interfaces.go @@ -102,7 +102,8 @@ type V1Store interface { //nolint:interfacebloat // passed to the call-back are executed under the same read transaction // and so, methods on the NodeTx object _MUST_ only be called from // within the call-back. - ForEachNode(ctx context.Context, cb func(tx NodeRTx) error) error + ForEachNode(ctx context.Context, cb func(tx NodeRTx) error, + reset func()) error // ForEachNodeCacheable iterates through all the stored vertices/nodes // in the graph, executing the passed callback with each node diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 45625437c..eff37ff21 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -802,9 +802,7 @@ func (c *KVStore) DisabledChannelIDs() ([]uint64, error) { // executed under the same read transaction and so, methods on the NodeTx object // _MUST_ only be called from within the call-back. func (c *KVStore) ForEachNode(_ context.Context, - cb func(tx NodeRTx) error) error { - - reset := func() {} + cb func(tx NodeRTx) error, reset func()) error { return forEachNode(c.db, func(tx kvdb.RTx, node *models.LightningNode) error { diff --git a/graph/db/sql_migration_test.go b/graph/db/sql_migration_test.go index 336a4d1c4..9f32603fd 100644 --- a/graph/db/sql_migration_test.go +++ b/graph/db/sql_migration_test.go @@ -401,6 +401,8 @@ func fetchAllNodes(t *testing.T, store V1Store) []*models.LightningNode { nodes = append(nodes, node) return nil + }, func() { + nodes = nil }) require.NoError(t, err) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index ad7bfc0fe..aef2f2464 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -785,9 +785,7 @@ func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context, // // NOTE: part of the V1Store interface. func (s *SQLStore) ForEachNode(ctx context.Context, - cb func(tx NodeRTx) error) error { - - reset := func() {} + cb func(tx NodeRTx) error, reset func()) error { var lastID int64 = 0 handleNode := func(db SQLQueries, dbNode sqlc.Node) error { diff --git a/itest/lnd_graph_migration_test.go b/itest/lnd_graph_migration_test.go index 412d23f57..d7417abe4 100644 --- a/itest/lnd_graph_migration_test.go +++ b/itest/lnd_graph_migration_test.go @@ -82,6 +82,9 @@ func testGraphMigration(ht *lntest.HarnessTest) { return nil }, ) + }, func() { + clear(edges) + numNodes = 0 }) require.NoError(ht, err) require.Equal(ht, expNumNodes, numNodes) diff --git a/rpcserver.go b/rpcserver.go index 39169a05b..8c134cbcd 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6745,6 +6745,8 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, resp.Nodes = append(resp.Nodes, lnNode) return nil + }, func() { + resp.Nodes = nil }) if err != nil { return nil, err