diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index c7366e155..886d1ddcc 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -806,49 +806,18 @@ func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context, func (s *SQLStore) ForEachNode(ctx context.Context, cb func(tx NodeRTx) error, reset func()) error { - var lastID int64 - return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - nodeCB := func(dbID int64, node *models.LightningNode) error { - err := cb(newSQLGraphNodeTx( - db, s.cfg.ChainHash, dbID, node, - )) - if err != nil { - return fmt.Errorf("callback failed for "+ - "node(id=%d): %w", dbID, err) - } - lastID = dbID + return forEachNodePaginated( + ctx, s.cfg.QueryCfg, db, + ProtocolV1, + func(ctx context.Context, dbNodeID int64, + node *models.LightningNode) error { - return nil - } - - for { - nodes, err := db.ListNodesPaginated( - ctx, sqlc.ListNodesPaginatedParams{ - Version: int16(ProtocolV1), - ID: lastID, - Limit: s.cfg.QueryCfg.MaxPageSize, - }, - ) - if err != nil { - return fmt.Errorf("unable to fetch nodes: %w", - err) - } - - if len(nodes) == 0 { - break - } - - err = forEachNodeInBatch( - ctx, s.cfg.QueryCfg, db, nodes, nodeCB, - ) - if err != nil { - return fmt.Errorf("unable to iterate over "+ - "nodes: %w", err) - } - } - - return nil + return cb(newSQLGraphNodeTx( + db, s.cfg.ChainHash, dbNodeID, node, + )) + }, + ) }, reset) } @@ -5082,3 +5051,55 @@ func batchLoadChannelPolicyExtrasHelper(ctx context.Context, }, ) } + +// forEachNodePaginated executes a paginated query to process each node in the +// graph. It uses the provided SQLQueries interface to fetch nodes in batches +// and applies the provided processNode function to each node. +func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig, + db SQLQueries, protocol ProtocolVersion, + processNode func(context.Context, int64, + *models.LightningNode) error) error { + + pageQueryFunc := func(ctx context.Context, lastID int64, + limit int32) ([]sqlc.GraphNode, error) { + + return db.ListNodesPaginated( + ctx, sqlc.ListNodesPaginatedParams{ + Version: int16(protocol), + ID: lastID, + Limit: limit, + }, + ) + } + + extractPageCursor := func(node sqlc.GraphNode) int64 { + return node.ID + } + + collectFunc := func(node sqlc.GraphNode) (int64, error) { + return node.ID, nil + } + + batchQueryFunc := func(ctx context.Context, + nodeIDs []int64) (*batchNodeData, error) { + + return batchLoadNodeData(ctx, cfg, db, nodeIDs) + } + + processItem := func(ctx context.Context, dbNode sqlc.GraphNode, + batchData *batchNodeData) error { + + node, err := buildNodeWithBatchData(&dbNode, batchData) + if err != nil { + return fmt.Errorf("unable to build "+ + "node(id=%d): %w", dbNode.ID, err) + } + + return processNode(ctx, dbNode.ID, node) + } + + return sqldb.ExecuteCollectAndBatchWithSharedDataQuery( + ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor, + collectFunc, batchQueryFunc, processItem, + ) +}