graph/db+sqldb: use batch validation for node migration

Restructue the `migrateNodes` function so that it does the validation of
migrated nodes in batches. So instead of fetching each node individually
after migrating it, we wait for a minimum batch size to be reached and
then validate a batch of nodes together. This lets us make way fewer DB
round trips.
This commit is contained in:
Elle Mouton
2025-08-13 13:07:11 +02:00
parent 218aa9eaa8
commit 03ef2740a6
5 changed files with 176 additions and 47 deletions

View File

@@ -107,8 +107,8 @@ func checkGraphExists(db kvdb.Backend) (bool, error) {
}
// migrateNodes migrates all nodes from the KV backend to the SQL database.
// This includes doing a sanity check after each migration to ensure that the
// migrated node matches the original node.
// It collects nodes in batches, inserts them individually, and then validates
// them in batches.
func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
kvBackend kvdb.Backend, sqlDB SQLQueries) error {
@@ -125,6 +125,108 @@ func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
}
)
// batch is a map that holds node objects that have been migrated to
// the native SQL store that have yet to be validated. The object's held
// by this map were derived from the KVDB store and so when they are
// validated, the map index (the SQL store node ID) will be used to
// fetch the corresponding node object in the SQL store, and it will
// then be compared against the original KVDB node object.
batch := make(
map[int64]*models.LightningNode, cfg.MaxBatchSize,
)
// validateBatch validates that the batch of nodes in the 'batch' map
// have been migrated successfully.
validateBatch := func() error {
if len(batch) == 0 {
return nil
}
// Extract DB node IDs.
dbIDs := make([]int64, 0, len(batch))
for dbID := range batch {
dbIDs = append(dbIDs, dbID)
}
// Batch fetch all nodes from the database.
dbNodes, err := sqlDB.GetNodesByIDs(ctx, dbIDs)
if err != nil {
return fmt.Errorf("could not batch fetch nodes: %w",
err)
}
// Make sure that the number of nodes fetched matches the number
// of nodes in the batch.
if len(dbNodes) != len(batch) {
return fmt.Errorf("expected to fetch %d nodes, "+
"but got %d", len(batch), len(dbNodes))
}
// Now, batch fetch the normalised data for all the nodes in
// the batch.
batchData, err := batchLoadNodeData(ctx, cfg, sqlDB, dbIDs)
if err != nil {
return fmt.Errorf("unable to batch load node data: %w",
err)
}
for _, dbNode := range dbNodes {
// Get the KVDB node info from the batch map.
node, ok := batch[dbNode.ID]
if !ok {
return fmt.Errorf("node with ID %d not found "+
"in batch", dbNode.ID)
}
// Build the migrated node from the DB node and the
// batch node data.
migNode, err := buildNodeWithBatchData(
dbNode, batchData,
)
if err != nil {
return fmt.Errorf("could not build migrated "+
"node from dbNode(db id: %d, node "+
"pub: %x): %w", dbNode.ID,
node.PubKeyBytes, err)
}
// Make sure that the node addresses are sorted before
// comparing them to ensure that the order of addresses
// does not affect the comparison.
slices.SortFunc(
node.Addresses, func(i, j net.Addr) int {
return cmp.Compare(
i.String(), j.String(),
)
},
)
slices.SortFunc(
migNode.Addresses, func(i, j net.Addr) int {
return cmp.Compare(
i.String(), j.String(),
)
},
)
err = sqldb.CompareRecords(
node, migNode,
fmt.Sprintf("node %x", node.PubKeyBytes),
)
if err != nil {
return fmt.Errorf("node mismatch after "+
"migration for node %x: %w",
node.PubKeyBytes, err)
}
}
// Clear the batch map for the next iteration.
batch = make(
map[int64]*models.LightningNode, cfg.MaxBatchSize,
)
return nil
}
// Loop through each node in the KV store and insert it into the SQL
// database.
err := forEachNode(kvBackend, func(_ kvdb.RTx,
@@ -172,52 +274,16 @@ func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
err)
}
// Fetch it from the SQL store and compare it against the
// original node object to ensure the migration was successful.
dbNode, err := sqlDB.GetNodeByPubKey(
ctx, sqlc.GetNodeByPubKeyParams{
PubKey: node.PubKeyBytes[:],
Version: int16(ProtocolV1),
},
)
// Add to validation batch.
batch[id] = node
// Validate batch when full.
if len(batch) >= cfg.MaxBatchSize {
err := validateBatch()
if err != nil {
return fmt.Errorf("could not get node by pubkey (%x)"+
"after migration: %w", pub, err)
return fmt.Errorf("batch validation failed: %w",
err)
}
// Sanity check: ensure the migrated node ID matches the one we
// just inserted.
if dbNode.ID != id {
return fmt.Errorf("node ID mismatch for node (%x) "+
"after migration: expected %d, got %d",
pub, id, dbNode.ID)
}
migratedNode, err := buildNode(ctx, cfg, sqlDB, dbNode)
if err != nil {
return fmt.Errorf("could not build migrated node "+
"from dbNode(db id: %d, node pub: %x): %w",
dbNode.ID, pub, err)
}
// Make sure that the node addresses are sorted before
// comparing them to ensure that the order of addresses does
// not affect the comparison.
slices.SortFunc(node.Addresses, func(i, j net.Addr) int {
return cmp.Compare(i.String(), j.String())
})
slices.SortFunc(
migratedNode.Addresses, func(i, j net.Addr) int {
return cmp.Compare(i.String(), j.String())
},
)
err = sqldb.CompareRecords(
node, migratedNode, fmt.Sprintf("node %x", pub),
)
if err != nil {
return fmt.Errorf("node mismatch after migration "+
"for node %x: %w", pub, err)
}
s.Do(func() {
@@ -239,6 +305,15 @@ func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
return fmt.Errorf("could not migrate nodes: %w", err)
}
// Validate any remaining nodes in the batch.
if len(batch) > 0 {
err := validateBatch()
if err != nil {
return fmt.Errorf("final batch validation failed: %w",
err)
}
}
log.Infof("Migrated %d nodes from KV to SQL (skipped %d nodes due to "+
"invalid TLV streams)", count, skipped)

View File

@@ -55,6 +55,7 @@ type SQLQueries interface {
*/
UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)

View File

@@ -1729,6 +1729,53 @@ func (q *Queries) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyPa
return id, err
}
const getNodesByIDs = `-- name: GetNodesByIDs :many
SELECT id, version, pub_key, alias, last_update, color, signature
FROM graph_nodes
WHERE id IN (/*SLICE:ids*/?)
`
func (q *Queries) GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, error) {
query := getNodesByIDs
var queryParams []interface{}
if len(ids) > 0 {
for _, v := range ids {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:ids*/?", makeQueryParams(len(queryParams), len(ids)), 1)
} else {
query = strings.Replace(query, "/*SLICE:ids*/?", "NULL", 1)
}
rows, err := q.db.QueryContext(ctx, query, queryParams...)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GraphNode
for rows.Next() {
var i GraphNode
if err := rows.Scan(
&i.ID,
&i.Version,
&i.PubKey,
&i.Alias,
&i.LastUpdate,
&i.Color,
&i.Signature,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getNodesByLastUpdateRange = `-- name: GetNodesByLastUpdateRange :many
SELECT id, version, pub_key, alias, last_update, color, signature
FROM graph_nodes

View File

@@ -68,6 +68,7 @@ type Querier interface {
GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]GraphNodeFeature, error)
GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error)
GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error)
GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, error)
GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]GraphNode, error)
GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
GetPruneTip(ctx context.Context) (GraphPruneLog, error)

View File

@@ -21,6 +21,11 @@ WHERE graph_nodes.last_update IS NULL
OR EXCLUDED.last_update > graph_nodes.last_update
RETURNING id;
-- name: GetNodesByIDs :many
SELECT *
FROM graph_nodes
WHERE id IN (sqlc.slice('ids')/*SLICE:ids*/);
-- name: GetNodeByPubKey :one
SELECT *
FROM graph_nodes