mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-12-07 19:32:02 +01:00
graph/db+sqldb: use batch validation for node migration
Restructue the `migrateNodes` function so that it does the validation of migrated nodes in batches. So instead of fetching each node individually after migrating it, we wait for a minimum batch size to be reached and then validate a batch of nodes together. This lets us make way fewer DB round trips.
This commit is contained in:
@@ -107,8 +107,8 @@ func checkGraphExists(db kvdb.Backend) (bool, error) {
|
||||
}
|
||||
|
||||
// migrateNodes migrates all nodes from the KV backend to the SQL database.
|
||||
// This includes doing a sanity check after each migration to ensure that the
|
||||
// migrated node matches the original node.
|
||||
// It collects nodes in batches, inserts them individually, and then validates
|
||||
// them in batches.
|
||||
func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
|
||||
kvBackend kvdb.Backend, sqlDB SQLQueries) error {
|
||||
|
||||
@@ -125,6 +125,108 @@ func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
|
||||
}
|
||||
)
|
||||
|
||||
// batch is a map that holds node objects that have been migrated to
|
||||
// the native SQL store that have yet to be validated. The object's held
|
||||
// by this map were derived from the KVDB store and so when they are
|
||||
// validated, the map index (the SQL store node ID) will be used to
|
||||
// fetch the corresponding node object in the SQL store, and it will
|
||||
// then be compared against the original KVDB node object.
|
||||
batch := make(
|
||||
map[int64]*models.LightningNode, cfg.MaxBatchSize,
|
||||
)
|
||||
|
||||
// validateBatch validates that the batch of nodes in the 'batch' map
|
||||
// have been migrated successfully.
|
||||
validateBatch := func() error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract DB node IDs.
|
||||
dbIDs := make([]int64, 0, len(batch))
|
||||
for dbID := range batch {
|
||||
dbIDs = append(dbIDs, dbID)
|
||||
}
|
||||
|
||||
// Batch fetch all nodes from the database.
|
||||
dbNodes, err := sqlDB.GetNodesByIDs(ctx, dbIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not batch fetch nodes: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
// Make sure that the number of nodes fetched matches the number
|
||||
// of nodes in the batch.
|
||||
if len(dbNodes) != len(batch) {
|
||||
return fmt.Errorf("expected to fetch %d nodes, "+
|
||||
"but got %d", len(batch), len(dbNodes))
|
||||
}
|
||||
|
||||
// Now, batch fetch the normalised data for all the nodes in
|
||||
// the batch.
|
||||
batchData, err := batchLoadNodeData(ctx, cfg, sqlDB, dbIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to batch load node data: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
for _, dbNode := range dbNodes {
|
||||
// Get the KVDB node info from the batch map.
|
||||
node, ok := batch[dbNode.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf("node with ID %d not found "+
|
||||
"in batch", dbNode.ID)
|
||||
}
|
||||
|
||||
// Build the migrated node from the DB node and the
|
||||
// batch node data.
|
||||
migNode, err := buildNodeWithBatchData(
|
||||
dbNode, batchData,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not build migrated "+
|
||||
"node from dbNode(db id: %d, node "+
|
||||
"pub: %x): %w", dbNode.ID,
|
||||
node.PubKeyBytes, err)
|
||||
}
|
||||
|
||||
// Make sure that the node addresses are sorted before
|
||||
// comparing them to ensure that the order of addresses
|
||||
// does not affect the comparison.
|
||||
slices.SortFunc(
|
||||
node.Addresses, func(i, j net.Addr) int {
|
||||
return cmp.Compare(
|
||||
i.String(), j.String(),
|
||||
)
|
||||
},
|
||||
)
|
||||
slices.SortFunc(
|
||||
migNode.Addresses, func(i, j net.Addr) int {
|
||||
return cmp.Compare(
|
||||
i.String(), j.String(),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
err = sqldb.CompareRecords(
|
||||
node, migNode,
|
||||
fmt.Sprintf("node %x", node.PubKeyBytes),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("node mismatch after "+
|
||||
"migration for node %x: %w",
|
||||
node.PubKeyBytes, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the batch map for the next iteration.
|
||||
batch = make(
|
||||
map[int64]*models.LightningNode, cfg.MaxBatchSize,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop through each node in the KV store and insert it into the SQL
|
||||
// database.
|
||||
err := forEachNode(kvBackend, func(_ kvdb.RTx,
|
||||
@@ -172,52 +274,16 @@ func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
|
||||
err)
|
||||
}
|
||||
|
||||
// Fetch it from the SQL store and compare it against the
|
||||
// original node object to ensure the migration was successful.
|
||||
dbNode, err := sqlDB.GetNodeByPubKey(
|
||||
ctx, sqlc.GetNodeByPubKeyParams{
|
||||
PubKey: node.PubKeyBytes[:],
|
||||
Version: int16(ProtocolV1),
|
||||
},
|
||||
)
|
||||
// Add to validation batch.
|
||||
batch[id] = node
|
||||
|
||||
// Validate batch when full.
|
||||
if len(batch) >= cfg.MaxBatchSize {
|
||||
err := validateBatch()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get node by pubkey (%x)"+
|
||||
"after migration: %w", pub, err)
|
||||
return fmt.Errorf("batch validation failed: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
// Sanity check: ensure the migrated node ID matches the one we
|
||||
// just inserted.
|
||||
if dbNode.ID != id {
|
||||
return fmt.Errorf("node ID mismatch for node (%x) "+
|
||||
"after migration: expected %d, got %d",
|
||||
pub, id, dbNode.ID)
|
||||
}
|
||||
|
||||
migratedNode, err := buildNode(ctx, cfg, sqlDB, dbNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not build migrated node "+
|
||||
"from dbNode(db id: %d, node pub: %x): %w",
|
||||
dbNode.ID, pub, err)
|
||||
}
|
||||
|
||||
// Make sure that the node addresses are sorted before
|
||||
// comparing them to ensure that the order of addresses does
|
||||
// not affect the comparison.
|
||||
slices.SortFunc(node.Addresses, func(i, j net.Addr) int {
|
||||
return cmp.Compare(i.String(), j.String())
|
||||
})
|
||||
slices.SortFunc(
|
||||
migratedNode.Addresses, func(i, j net.Addr) int {
|
||||
return cmp.Compare(i.String(), j.String())
|
||||
},
|
||||
)
|
||||
|
||||
err = sqldb.CompareRecords(
|
||||
node, migratedNode, fmt.Sprintf("node %x", pub),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("node mismatch after migration "+
|
||||
"for node %x: %w", pub, err)
|
||||
}
|
||||
|
||||
s.Do(func() {
|
||||
@@ -239,6 +305,15 @@ func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
|
||||
return fmt.Errorf("could not migrate nodes: %w", err)
|
||||
}
|
||||
|
||||
// Validate any remaining nodes in the batch.
|
||||
if len(batch) > 0 {
|
||||
err := validateBatch()
|
||||
if err != nil {
|
||||
return fmt.Errorf("final batch validation failed: %w",
|
||||
err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Migrated %d nodes from KV to SQL (skipped %d nodes due to "+
|
||||
"invalid TLV streams)", count, skipped)
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ type SQLQueries interface {
|
||||
*/
|
||||
UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
|
||||
GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
|
||||
GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
|
||||
GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
|
||||
GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
|
||||
ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
|
||||
|
||||
@@ -1729,6 +1729,53 @@ func (q *Queries) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyPa
|
||||
return id, err
|
||||
}
|
||||
|
||||
const getNodesByIDs = `-- name: GetNodesByIDs :many
|
||||
SELECT id, version, pub_key, alias, last_update, color, signature
|
||||
FROM graph_nodes
|
||||
WHERE id IN (/*SLICE:ids*/?)
|
||||
`
|
||||
|
||||
func (q *Queries) GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, error) {
|
||||
query := getNodesByIDs
|
||||
var queryParams []interface{}
|
||||
if len(ids) > 0 {
|
||||
for _, v := range ids {
|
||||
queryParams = append(queryParams, v)
|
||||
}
|
||||
query = strings.Replace(query, "/*SLICE:ids*/?", makeQueryParams(len(queryParams), len(ids)), 1)
|
||||
} else {
|
||||
query = strings.Replace(query, "/*SLICE:ids*/?", "NULL", 1)
|
||||
}
|
||||
rows, err := q.db.QueryContext(ctx, query, queryParams...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GraphNode
|
||||
for rows.Next() {
|
||||
var i GraphNode
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Version,
|
||||
&i.PubKey,
|
||||
&i.Alias,
|
||||
&i.LastUpdate,
|
||||
&i.Color,
|
||||
&i.Signature,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getNodesByLastUpdateRange = `-- name: GetNodesByLastUpdateRange :many
|
||||
SELECT id, version, pub_key, alias, last_update, color, signature
|
||||
FROM graph_nodes
|
||||
|
||||
@@ -68,6 +68,7 @@ type Querier interface {
|
||||
GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]GraphNodeFeature, error)
|
||||
GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error)
|
||||
GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error)
|
||||
GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, error)
|
||||
GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]GraphNode, error)
|
||||
GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
|
||||
GetPruneTip(ctx context.Context) (GraphPruneLog, error)
|
||||
|
||||
@@ -21,6 +21,11 @@ WHERE graph_nodes.last_update IS NULL
|
||||
OR EXCLUDED.last_update > graph_nodes.last_update
|
||||
RETURNING id;
|
||||
|
||||
-- name: GetNodesByIDs :many
|
||||
SELECT *
|
||||
FROM graph_nodes
|
||||
WHERE id IN (sqlc.slice('ids')/*SLICE:ids*/);
|
||||
|
||||
-- name: GetNodeByPubKey :one
|
||||
SELECT *
|
||||
FROM graph_nodes
|
||||
|
||||
Reference in New Issue
Block a user