diff --git a/graph/db/sql_migration.go b/graph/db/sql_migration.go index d62fe74ac..467f5436e 100644 --- a/graph/db/sql_migration.go +++ b/graph/db/sql_migration.go @@ -67,7 +67,7 @@ func MigrateGraphToSQL(ctx context.Context, cfg *SQLStoreConfig, } // 5) Migrate the closed SCID index. - err = migrateClosedSCIDIndex(ctx, kvBackend, sqlDB) + err = migrateClosedSCIDIndex(ctx, cfg.QueryCfg, kvBackend, sqlDB) if err != nil { return fmt.Errorf("could not migrate closed SCID index: %w", err) @@ -992,12 +992,11 @@ func forEachPruneLogEntry(db kvdb.Backend, cb func(height uint32, } // migrateClosedSCIDIndex migrates the closed SCID index from the KV backend to -// the SQL database. It iterates over each closed SCID in the KV store, inserts -// it into the SQL database, and then verifies that the SCID was inserted -// correctly by checking if the channel with the given SCID is seen as closed in -// the SQL database. -func migrateClosedSCIDIndex(ctx context.Context, kvBackend kvdb.Backend, - sqlDB SQLQueries) error { +// the SQL database. It collects SCIDs in batches, inserts them individually, +// and then validates them in batches using GetClosedChannelsSCIDs for better +// performance. +func migrateClosedSCIDIndex(ctx context.Context, cfg *sqldb.QueryConfig, + kvBackend kvdb.Backend, sqlDB SQLQueries) error { var ( count uint64 @@ -1008,6 +1007,43 @@ func migrateClosedSCIDIndex(ctx context.Context, kvBackend kvdb.Backend, Interval: 10 * time.Second, } ) + + batch := make([][]byte, 0, cfg.MaxBatchSize) + + // validateBatch validates a batch of closed SCIDs using batch query. + validateBatch := func() error { + if len(batch) == 0 { + return nil + } + + // Batch fetch all closed SCIDs from the database. + dbSCIDs, err := sqlDB.GetClosedChannelsSCIDs(ctx, batch) + if err != nil { + return fmt.Errorf("could not batch get closed "+ + "SCIDs: %w", err) + } + + // Create set of SCIDs that exist in the database for quick + // lookup. + dbSCIDSet := make(map[string]struct{}) + for _, scid := range dbSCIDs { + dbSCIDSet[string(scid)] = struct{}{} + } + + // Validate each SCID in the batch. + for _, expectedSCID := range batch { + if _, found := dbSCIDSet[string(expectedSCID)]; !found { + return fmt.Errorf("closed SCID %x not found "+ + "in database", expectedSCID) + } + } + + // Reset the batch for the next iteration. + batch = make([][]byte, 0, cfg.MaxBatchSize) + + return nil + } + migrateSingleClosedSCID := func(scid lnwire.ShortChannelID) error { count++ chunk++ @@ -1019,17 +1055,16 @@ func migrateClosedSCIDIndex(ctx context.Context, kvBackend kvdb.Backend, "with SCID %s: %w", scid, err) } - // Now, verify that the channel with the given SCID is - // seen as closed. - isClosed, err := sqlDB.IsClosedChannel(ctx, chanIDB) - if err != nil { - return fmt.Errorf("could not check if channel %s "+ - "is closed: %w", scid, err) - } + // Add to validation batch. + batch = append(batch, chanIDB) - if !isClosed { - return fmt.Errorf("channel %s should be closed, "+ - "but is not", scid) + // Validate batch when full. + if len(batch) >= cfg.MaxBatchSize { + err := validateBatch() + if err != nil { + return fmt.Errorf("batch validation failed: %w", + err) + } } s.Do(func() { @@ -1051,6 +1086,15 @@ func migrateClosedSCIDIndex(ctx context.Context, kvBackend kvdb.Backend, err) } + // Validate any remaining SCIDs in the batch. + if len(batch) > 0 { + err := validateBatch() + if err != nil { + return fmt.Errorf("final batch validation failed: %w", + err) + } + } + log.Infof("Migrated %d closed SCIDs from KV to SQL", count) return nil diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 1de50165b..fbfee848f 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -151,6 +151,7 @@ type SQLQueries interface { */ InsertClosedChannel(ctx context.Context, scid []byte) error IsClosedChannel(ctx context.Context, scid []byte) (bool, error) + GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error) } // BatchedSQLQueries is a version of SQLQueries that's capable of batched diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 5718f25c1..7b9e472a1 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -1602,6 +1602,45 @@ func (q *Queries) GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDs return items, nil } +const getClosedChannelsSCIDs = `-- name: GetClosedChannelsSCIDs :many +SELECT scid +FROM graph_closed_scids +WHERE scid IN (/*SLICE:scids*/?) +` + +func (q *Queries) GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error) { + query := getClosedChannelsSCIDs + var queryParams []interface{} + if len(scids) > 0 { + for _, v := range scids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:scids*/?", makeQueryParams(len(queryParams), len(scids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:scids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items [][]byte + for rows.Next() { + var scid []byte + if err := rows.Scan(&scid); err != nil { + return nil, err + } + items = append(items, scid) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many SELECT node_id, type, value FROM graph_node_extra_types diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index e7b4f3f52..b3c0857eb 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -48,6 +48,7 @@ type Querier interface { GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error) GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDsParams) ([]GraphChannel, error) + GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error) GetDatabaseVersion(ctx context.Context) (int32, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]GraphNodeExtraType, error) // This method may return more than one invoice if filter using multiple fields diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index ed0eb8c50..f77e5044c 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -984,3 +984,8 @@ SELECT EXISTS ( FROM graph_closed_scids WHERE scid = $1 ); + +-- name: GetClosedChannelsSCIDs :many +SELECT scid +FROM graph_closed_scids +WHERE scid IN (sqlc.slice('scids')/*SLICE:scids*/);