graph/db+sqldb: use batch validation for closed SCID migration

As was done in the previous commits for nodes & channels, we update the
migrateClosedSCIDIndex function here so that it validates migrated
entries in batches rather than one-by-one.
This commit is contained in:
Elle Mouton
2025-08-13 13:56:50 +02:00
parent 8554f17b3f
commit a490e03479
5 changed files with 107 additions and 17 deletions

View File

@@ -67,7 +67,7 @@ func MigrateGraphToSQL(ctx context.Context, cfg *SQLStoreConfig,
} }
// 5) Migrate the closed SCID index. // 5) Migrate the closed SCID index.
err = migrateClosedSCIDIndex(ctx, kvBackend, sqlDB) err = migrateClosedSCIDIndex(ctx, cfg.QueryCfg, kvBackend, sqlDB)
if err != nil { if err != nil {
return fmt.Errorf("could not migrate closed SCID index: %w", return fmt.Errorf("could not migrate closed SCID index: %w",
err) err)
@@ -992,12 +992,11 @@ func forEachPruneLogEntry(db kvdb.Backend, cb func(height uint32,
} }
// migrateClosedSCIDIndex migrates the closed SCID index from the KV backend to // migrateClosedSCIDIndex migrates the closed SCID index from the KV backend to
// the SQL database. It iterates over each closed SCID in the KV store, inserts // the SQL database. It collects SCIDs in batches, inserts them individually,
// it into the SQL database, and then verifies that the SCID was inserted // and then validates them in batches using GetClosedChannelsSCIDs for better
// correctly by checking if the channel with the given SCID is seen as closed in // performance.
// the SQL database. func migrateClosedSCIDIndex(ctx context.Context, cfg *sqldb.QueryConfig,
func migrateClosedSCIDIndex(ctx context.Context, kvBackend kvdb.Backend, kvBackend kvdb.Backend, sqlDB SQLQueries) error {
sqlDB SQLQueries) error {
var ( var (
count uint64 count uint64
@@ -1008,6 +1007,43 @@ func migrateClosedSCIDIndex(ctx context.Context, kvBackend kvdb.Backend,
Interval: 10 * time.Second, Interval: 10 * time.Second,
} }
) )
batch := make([][]byte, 0, cfg.MaxBatchSize)
// validateBatch validates a batch of closed SCIDs using batch query.
validateBatch := func() error {
if len(batch) == 0 {
return nil
}
// Batch fetch all closed SCIDs from the database.
dbSCIDs, err := sqlDB.GetClosedChannelsSCIDs(ctx, batch)
if err != nil {
return fmt.Errorf("could not batch get closed "+
"SCIDs: %w", err)
}
// Create set of SCIDs that exist in the database for quick
// lookup.
dbSCIDSet := make(map[string]struct{})
for _, scid := range dbSCIDs {
dbSCIDSet[string(scid)] = struct{}{}
}
// Validate each SCID in the batch.
for _, expectedSCID := range batch {
if _, found := dbSCIDSet[string(expectedSCID)]; !found {
return fmt.Errorf("closed SCID %x not found "+
"in database", expectedSCID)
}
}
// Reset the batch for the next iteration.
batch = make([][]byte, 0, cfg.MaxBatchSize)
return nil
}
migrateSingleClosedSCID := func(scid lnwire.ShortChannelID) error { migrateSingleClosedSCID := func(scid lnwire.ShortChannelID) error {
count++ count++
chunk++ chunk++
@@ -1019,17 +1055,16 @@ func migrateClosedSCIDIndex(ctx context.Context, kvBackend kvdb.Backend,
"with SCID %s: %w", scid, err) "with SCID %s: %w", scid, err)
} }
// Now, verify that the channel with the given SCID is // Add to validation batch.
// seen as closed. batch = append(batch, chanIDB)
isClosed, err := sqlDB.IsClosedChannel(ctx, chanIDB)
if err != nil {
return fmt.Errorf("could not check if channel %s "+
"is closed: %w", scid, err)
}
if !isClosed { // Validate batch when full.
return fmt.Errorf("channel %s should be closed, "+ if len(batch) >= cfg.MaxBatchSize {
"but is not", scid) err := validateBatch()
if err != nil {
return fmt.Errorf("batch validation failed: %w",
err)
}
} }
s.Do(func() { s.Do(func() {
@@ -1051,6 +1086,15 @@ func migrateClosedSCIDIndex(ctx context.Context, kvBackend kvdb.Backend,
err) err)
} }
// Validate any remaining SCIDs in the batch.
if len(batch) > 0 {
err := validateBatch()
if err != nil {
return fmt.Errorf("final batch validation failed: %w",
err)
}
}
log.Infof("Migrated %d closed SCIDs from KV to SQL", count) log.Infof("Migrated %d closed SCIDs from KV to SQL", count)
return nil return nil

View File

@@ -151,6 +151,7 @@ type SQLQueries interface {
*/ */
InsertClosedChannel(ctx context.Context, scid []byte) error InsertClosedChannel(ctx context.Context, scid []byte) error
IsClosedChannel(ctx context.Context, scid []byte) (bool, error) IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
} }
// BatchedSQLQueries is a version of SQLQueries that's capable of batched // BatchedSQLQueries is a version of SQLQueries that's capable of batched

View File

@@ -1602,6 +1602,45 @@ func (q *Queries) GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDs
return items, nil return items, nil
} }
const getClosedChannelsSCIDs = `-- name: GetClosedChannelsSCIDs :many
SELECT scid
FROM graph_closed_scids
WHERE scid IN (/*SLICE:scids*/?)
`
func (q *Queries) GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error) {
query := getClosedChannelsSCIDs
var queryParams []interface{}
if len(scids) > 0 {
for _, v := range scids {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:scids*/?", makeQueryParams(len(queryParams), len(scids)), 1)
} else {
query = strings.Replace(query, "/*SLICE:scids*/?", "NULL", 1)
}
rows, err := q.db.QueryContext(ctx, query, queryParams...)
if err != nil {
return nil, err
}
defer rows.Close()
var items [][]byte
for rows.Next() {
var scid []byte
if err := rows.Scan(&scid); err != nil {
return nil, err
}
items = append(items, scid)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many
SELECT node_id, type, value SELECT node_id, type, value
FROM graph_node_extra_types FROM graph_node_extra_types

View File

@@ -48,6 +48,7 @@ type Querier interface {
GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error)
GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error) GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error)
GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDsParams) ([]GraphChannel, error) GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDsParams) ([]GraphChannel, error)
GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
GetDatabaseVersion(ctx context.Context) (int32, error) GetDatabaseVersion(ctx context.Context) (int32, error)
GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]GraphNodeExtraType, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]GraphNodeExtraType, error)
// This method may return more than one invoice if filter using multiple fields // This method may return more than one invoice if filter using multiple fields

View File

@@ -984,3 +984,8 @@ SELECT EXISTS (
FROM graph_closed_scids FROM graph_closed_scids
WHERE scid = $1 WHERE scid = $1
); );
-- name: GetClosedChannelsSCIDs :many
SELECT scid
FROM graph_closed_scids
WHERE scid IN (sqlc.slice('scids')/*SLICE:scids*/);