diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index ba79d9682..fd557a02c 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -731,12 +731,12 @@ func (t *testNodeTx) Node() *models.LightningNode { func (t *testNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { - return t.db.db.ForEachNodeChannel(t.node.PubKeyBytes, func( - edge *models.ChannelEdgeInfo, policy1, - policy2 *models.ChannelEdgePolicy) error { + return t.db.db.ForEachNodeChannel(context.Background(), + t.node.PubKeyBytes, func(edge *models.ChannelEdgeInfo, policy1, + policy2 *models.ChannelEdgePolicy) error { - return f(edge, policy1, policy2) - }) + return f(edge, policy1, policy2) + }) } func (t *testNodeTx) FetchNode(pub route.Vertex) (graphdb.NodeRTx, error) { diff --git a/graph/builder.go b/graph/builder.go index df4f3bfdf..d590fb896 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -1278,7 +1278,9 @@ func (b *Builder) FetchLightningNode(ctx context.Context, func (b *Builder) ForAllOutgoingChannels(cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { - return b.cfg.Graph.ForEachNodeChannel(b.cfg.SelfNode, + ctx := context.TODO() + + return b.cfg.Graph.ForEachNodeChannel(ctx, b.cfg.SelfNode, func(c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 33d576df5..b1206281d 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1332,6 +1332,7 @@ func TestForEachSourceNodeChannel(t *testing.T) { func TestGraphTraversal(t *testing.T) { t.Parallel() + ctx := context.Background() graph := MakeTestGraph(t) @@ -1387,7 +1388,7 @@ func TestGraphTraversal(t *testing.T) { // outgoing channels for a particular node. numNodeChans := 0 firstNode, secondNode := nodeList[0], nodeList[1] - err = graph.ForEachNodeChannel(firstNode.PubKeyBytes, + err = graph.ForEachNodeChannel(ctx, firstNode.PubKeyBytes, func(_ *models.ChannelEdgeInfo, outEdge, inEdge *models.ChannelEdgePolicy) error { @@ -3138,7 +3139,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { expectedOut bool) { calls := 0 - err := graph.ForEachNodeChannel(node.PubKeyBytes, + err := graph.ForEachNodeChannel(ctx, node.PubKeyBytes, func(_ *models.ChannelEdgeInfo, outEdge, inEdge *models.ChannelEdgePolicy) error { @@ -4204,6 +4205,7 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { // allocations and the total memory consumed by the full graph traversal. func BenchmarkForEachChannel(b *testing.B) { graph := MakeTestGraph(b) + ctx := context.Background() const numNodes = 100 const numChannels = 4 @@ -4244,7 +4246,7 @@ func BenchmarkForEachChannel(b *testing.B) { return nil } - err := graph.ForEachNodeChannel(n, cb) + err := graph.ForEachNodeChannel(ctx, n, cb) require.NoError(b, err) } } diff --git a/graph/db/interfaces.go b/graph/db/interfaces.go index e04b5ff59..b43e3b004 100644 --- a/graph/db/interfaces.go +++ b/graph/db/interfaces.go @@ -83,7 +83,7 @@ type V1Store interface { //nolint:interfacebloat // to the caller. // // Unknown policies are passed into the callback as nil values. - ForEachNodeChannel(nodePub route.Vertex, + ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index d685695f4..8fb330e81 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -3232,7 +3232,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. -func (c *KVStore) ForEachNodeChannel(nodePub route.Vertex, +func (c *KVStore) ForEachNodeChannel(_ context.Context, nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 302e08b9e..bc9064252 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -958,12 +958,10 @@ func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex, // Unknown policies are passed into the callback as nil values. // // NOTE: part of the V1Store interface. -func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex, +func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { - var ctx = context.TODO() - return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { dbNode, err := db.GetNodeByPubKey( ctx, sqlc.GetNodeByPubKeyParams{ diff --git a/graph/interfaces.go b/graph/interfaces.go index be226a495..9d576e4bc 100644 --- a/graph/interfaces.go +++ b/graph/interfaces.go @@ -257,7 +257,7 @@ type DB interface { // to the caller. // // Unknown policies are passed into the callback as nil values. - ForEachNodeChannel(nodePub route.Vertex, + ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error diff --git a/rpcserver.go b/rpcserver.go index 87bda9633..5b39c47ca 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -7020,7 +7020,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, channels []*lnrpc.ChannelEdge ) - err = graph.ForEachNodeChannel(node.PubKeyBytes, + err = graph.ForEachNodeChannel(ctx, node.PubKeyBytes, func(edge *models.ChannelEdgeInfo, c1, c2 *models.ChannelEdgePolicy) error { @@ -7702,7 +7702,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, } var feeReports []*lnrpc.ChannelFeeReport - err = channelGraph.ForEachNodeChannel(selfNode.PubKeyBytes, + err = channelGraph.ForEachNodeChannel(ctx, selfNode.PubKeyBytes, func(chanInfo *models.ChannelEdgeInfo, edgePolicy, _ *models.ChannelEdgePolicy) error { diff --git a/server.go b/server.go index 7ca8e8ea8..729490e25 100644 --- a/server.go +++ b/server.go @@ -1255,7 +1255,9 @@ func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, ForAllOutgoingChannels: func(cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { - return s.graphDB.ForEachNodeChannel(selfVertex, + ctx := context.TODO() + + return s.graphDB.ForEachNodeChannel(ctx, selfVertex, func(c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error {