From 03ef2740a66034a7b58da8943aa1450e8d1f9eef Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 13 Aug 2025 13:07:11 +0200 Subject: [PATCH] graph/db+sqldb: use batch validation for node migration Restructure the `migrateNodes` function so that it does the validation of migrated nodes in batches. So instead of fetching each node individually after migrating it, we wait for a minimum batch size to be reached and then validate a batch of nodes together. This lets us make way fewer DB round trips. --- graph/db/sql_migration.go | 169 +++++++++++++++++++++++++---------- graph/db/sql_store.go | 1 + sqldb/sqlc/graph.sql.go | 47 ++++++++++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/graph.sql | 5 ++ 5 files changed, 176 insertions(+), 47 deletions(-) diff --git a/graph/db/sql_migration.go b/graph/db/sql_migration.go index 9b22fbf03..72947bd3f 100644 --- a/graph/db/sql_migration.go +++ b/graph/db/sql_migration.go @@ -107,8 +107,8 @@ func checkGraphExists(db kvdb.Backend) (bool, error) { } // migrateNodes migrates all nodes from the KV backend to the SQL database. -// This includes doing a sanity check after each migration to ensure that the -// migrated node matches the original node. +// It collects nodes in batches, inserts them individually, and then validates +// them in batches. func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig, kvBackend kvdb.Backend, sqlDB SQLQueries) error { @@ -125,6 +125,108 @@ func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig, } ) + // batch is a map that holds node objects that have been migrated to + // the native SQL store that have yet to be validated. The objects held + // by this map were derived from the KVDB store and so when they are + // validated, the map index (the SQL store node ID) will be used to + // fetch the corresponding node object in the SQL store, and it will + // then be compared against the original KVDB node object.
+ batch := make( + map[int64]*models.LightningNode, cfg.MaxBatchSize, + ) + + // validateBatch validates that the batch of nodes in the 'batch' map + // have been migrated successfully. + validateBatch := func() error { + if len(batch) == 0 { + return nil + } + + // Extract DB node IDs. + dbIDs := make([]int64, 0, len(batch)) + for dbID := range batch { + dbIDs = append(dbIDs, dbID) + } + + // Batch fetch all nodes from the database. + dbNodes, err := sqlDB.GetNodesByIDs(ctx, dbIDs) + if err != nil { + return fmt.Errorf("could not batch fetch nodes: %w", + err) + } + + // Make sure that the number of nodes fetched matches the number + // of nodes in the batch. + if len(dbNodes) != len(batch) { + return fmt.Errorf("expected to fetch %d nodes, "+ + "but got %d", len(batch), len(dbNodes)) + } + + // Now, batch fetch the normalised data for all the nodes in + // the batch. + batchData, err := batchLoadNodeData(ctx, cfg, sqlDB, dbIDs) + if err != nil { + return fmt.Errorf("unable to batch load node data: %w", + err) + } + + for _, dbNode := range dbNodes { + // Get the KVDB node info from the batch map. + node, ok := batch[dbNode.ID] + if !ok { + return fmt.Errorf("node with ID %d not found "+ + "in batch", dbNode.ID) + } + + // Build the migrated node from the DB node and the + // batch node data. + migNode, err := buildNodeWithBatchData( + dbNode, batchData, + ) + if err != nil { + return fmt.Errorf("could not build migrated "+ + "node from dbNode(db id: %d, node "+ + "pub: %x): %w", dbNode.ID, + node.PubKeyBytes, err) + } + + // Make sure that the node addresses are sorted before + // comparing them to ensure that the order of addresses + // does not affect the comparison. 
+ slices.SortFunc( + node.Addresses, func(i, j net.Addr) int { + return cmp.Compare( + i.String(), j.String(), + ) + }, + ) + slices.SortFunc( + migNode.Addresses, func(i, j net.Addr) int { + return cmp.Compare( + i.String(), j.String(), + ) + }, + ) + + err = sqldb.CompareRecords( + node, migNode, + fmt.Sprintf("node %x", node.PubKeyBytes), + ) + if err != nil { + return fmt.Errorf("node mismatch after "+ + "migration for node %x: %w", + node.PubKeyBytes, err) + } + } + + // Clear the batch map for the next iteration. + batch = make( + map[int64]*models.LightningNode, cfg.MaxBatchSize, + ) + + return nil + } + // Loop through each node in the KV store and insert it into the SQL // database. err := forEachNode(kvBackend, func(_ kvdb.RTx, @@ -172,52 +274,16 @@ func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig, err) } - // Fetch it from the SQL store and compare it against the - // original node object to ensure the migration was successful. - dbNode, err := sqlDB.GetNodeByPubKey( - ctx, sqlc.GetNodeByPubKeyParams{ - PubKey: node.PubKeyBytes[:], - Version: int16(ProtocolV1), - }, - ) - if err != nil { - return fmt.Errorf("could not get node by pubkey (%x)"+ - "after migration: %w", pub, err) - } + // Add to validation batch. + batch[id] = node - // Sanity check: ensure the migrated node ID matches the one we - // just inserted. - if dbNode.ID != id { - return fmt.Errorf("node ID mismatch for node (%x) "+ - "after migration: expected %d, got %d", - pub, id, dbNode.ID) - } - - migratedNode, err := buildNode(ctx, cfg, sqlDB, dbNode) - if err != nil { - return fmt.Errorf("could not build migrated node "+ - "from dbNode(db id: %d, node pub: %x): %w", - dbNode.ID, pub, err) - } - - // Make sure that the node addresses are sorted before - // comparing them to ensure that the order of addresses does - // not affect the comparison. 
- slices.SortFunc(node.Addresses, func(i, j net.Addr) int { - return cmp.Compare(i.String(), j.String()) - }) - slices.SortFunc( - migratedNode.Addresses, func(i, j net.Addr) int { - return cmp.Compare(i.String(), j.String()) - }, - ) - - err = sqldb.CompareRecords( - node, migratedNode, fmt.Sprintf("node %x", pub), - ) - if err != nil { - return fmt.Errorf("node mismatch after migration "+ - "for node %x: %w", pub, err) + // Validate batch when full. + if len(batch) >= cfg.MaxBatchSize { + err := validateBatch() + if err != nil { + return fmt.Errorf("batch validation failed: %w", + err) + } } s.Do(func() { @@ -239,6 +305,15 @@ func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig, return fmt.Errorf("could not migrate nodes: %w", err) } + // Validate any remaining nodes in the batch. + if len(batch) > 0 { + err := validateBatch() + if err != nil { + return fmt.Errorf("final batch validation failed: %w", + err) + } + } + log.Infof("Migrated %d nodes from KV to SQL (skipped %d nodes due to "+ "invalid TLV streams)", count, skipped) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 40d5525d2..de6c150b3 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -55,6 +55,7 @@ type SQLQueries interface { */ UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error) GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error) + GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error) GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error) GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error) ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error) diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 6db5d32c8..5a81e5217 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -1729,6 +1729,53 @@ func (q *Queries) 
GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyPa return id, err } +const getNodesByIDs = `-- name: GetNodesByIDs :many +SELECT id, version, pub_key, alias, last_update, color, signature +FROM graph_nodes +WHERE id IN (/*SLICE:ids*/?) +` + +func (q *Queries) GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, error) { + query := getNodesByIDs + var queryParams []interface{} + if len(ids) > 0 { + for _, v := range ids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:ids*/?", makeQueryParams(len(queryParams), len(ids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GraphNode + for rows.Next() { + var i GraphNode + if err := rows.Scan( + &i.ID, + &i.Version, + &i.PubKey, + &i.Alias, + &i.LastUpdate, + &i.Color, + &i.Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getNodesByLastUpdateRange = `-- name: GetNodesByLastUpdateRange :many SELECT id, version, pub_key, alias, last_update, color, signature FROM graph_nodes diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 55fd74b75..7f1ad9abb 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -68,6 +68,7 @@ type Querier interface { GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]GraphNodeFeature, error) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error) + GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, error) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]GraphNode, error) GetPruneHashByHeight(ctx 
context.Context, blockHeight int64) ([]byte, error) GetPruneTip(ctx context.Context) (GraphPruneLog, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index ae276a5cb..55629a9ac 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -21,6 +21,11 @@ WHERE graph_nodes.last_update IS NULL OR EXCLUDED.last_update > graph_nodes.last_update RETURNING id; +-- name: GetNodesByIDs :many +SELECT * +FROM graph_nodes +WHERE id IN (sqlc.slice('ids')/*SLICE:ids*/); + -- name: GetNodeByPubKey :one SELECT * FROM graph_nodes