graph/db: use sqldb helper for ForEachNode

A pure refactor commit which updates the ForEachNode method to make use
of the new sqldb.ExecuteCollectAndBatchWithSharedDataQuery helper.
This commit is contained in:
Elle Mouton
2025-07-31 12:43:10 +02:00
parent 905941067e
commit 1219cdb7f1

View File

@@ -806,49 +806,18 @@ func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
func (s *SQLStore) ForEachNode(ctx context.Context,
cb func(tx NodeRTx) error, reset func()) error {
var lastID int64
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
nodeCB := func(dbID int64, node *models.LightningNode) error {
err := cb(newSQLGraphNodeTx(
db, s.cfg.ChainHash, dbID, node,
))
if err != nil {
return fmt.Errorf("callback failed for "+
"node(id=%d): %w", dbID, err)
}
lastID = dbID
return forEachNodePaginated(
ctx, s.cfg.QueryCfg, db,
ProtocolV1,
func(ctx context.Context, dbNodeID int64,
node *models.LightningNode) error {
return nil
}
for {
nodes, err := db.ListNodesPaginated(
ctx, sqlc.ListNodesPaginatedParams{
Version: int16(ProtocolV1),
ID: lastID,
Limit: s.cfg.QueryCfg.MaxPageSize,
},
)
if err != nil {
return fmt.Errorf("unable to fetch nodes: %w",
err)
}
if len(nodes) == 0 {
break
}
err = forEachNodeInBatch(
ctx, s.cfg.QueryCfg, db, nodes, nodeCB,
)
if err != nil {
return fmt.Errorf("unable to iterate over "+
"nodes: %w", err)
}
}
return nil
return cb(newSQLGraphNodeTx(
db, s.cfg.ChainHash, dbNodeID, node,
))
},
)
}, reset)
}
@@ -5082,3 +5051,55 @@ func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
},
)
}
// forEachNodePaginated executes a paginated query to process each node in the
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
// and applies the provided processNode function to each node.
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
db SQLQueries, protocol ProtocolVersion,
processNode func(context.Context, int64,
*models.LightningNode) error) error {
pageQueryFunc := func(ctx context.Context, lastID int64,
limit int32) ([]sqlc.GraphNode, error) {
return db.ListNodesPaginated(
ctx, sqlc.ListNodesPaginatedParams{
Version: int16(protocol),
ID: lastID,
Limit: limit,
},
)
}
extractPageCursor := func(node sqlc.GraphNode) int64 {
return node.ID
}
collectFunc := func(node sqlc.GraphNode) (int64, error) {
return node.ID, nil
}
batchQueryFunc := func(ctx context.Context,
nodeIDs []int64) (*batchNodeData, error) {
return batchLoadNodeData(ctx, cfg, db, nodeIDs)
}
processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
batchData *batchNodeData) error {
node, err := buildNodeWithBatchData(&dbNode, batchData)
if err != nil {
return fmt.Errorf("unable to build "+
"node(id=%d): %w", dbNode.ID, err)
}
return processNode(ctx, dbNode.ID, node)
}
return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
collectFunc, batchQueryFunc, processItem,
)
}