From 5b064747447c1108b6b308ef213d6109dc4965a8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 13 Aug 2025 14:11:40 +0200 Subject: [PATCH] graph/db+sqldb: batch validation for zombie index migration Finally, we update the migrateZombieIndex function to use batch validation just like was done in the previous commits. Here, we additionally make sure to validate the entire zombie index entry and not just the SCID. --- graph/db/sql_migration.go | 122 +++++++++++++++++++++++++++++------ graph/db/sql_store.go | 1 + sqldb/sqlc/graph.sql.go | 51 +++++++++++++++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/graph.sql | 6 ++ 5 files changed, 160 insertions(+), 21 deletions(-) diff --git a/graph/db/sql_migration.go b/graph/db/sql_migration.go index 467f5436e..17a751f87 100644 --- a/graph/db/sql_migration.go +++ b/graph/db/sql_migration.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/sqldb" "github.com/lightningnetwork/lnd/sqldb/sqlc" "golang.org/x/time/rate" @@ -74,7 +75,8 @@ func MigrateGraphToSQL(ctx context.Context, cfg *SQLStoreConfig, } // 6) Migrate the zombie index. - if err := migrateZombieIndex(ctx, kvBackend, sqlDB); err != nil { + err = migrateZombieIndex(ctx, cfg.QueryCfg, kvBackend, sqlDB) + if err != nil { return fmt.Errorf("could not migrate zombie index: %w", err) } @@ -1100,18 +1102,17 @@ func migrateClosedSCIDIndex(ctx context.Context, cfg *sqldb.QueryConfig, return nil } -// migrateZombieIndex migrates the zombie index from the KV backend to -// the SQL database. It iterates over each zombie channel in the KV store, -// inserts it into the SQL database, and then verifies that the channel is -// indeed marked as a zombie channel in the SQL database. +// migrateZombieIndex migrates the zombie index from the KV backend to the SQL +// database. It collects zombie channels in batches, inserts them individually, +// and validates them in batches. // // NOTE: before inserting an entry into the zombie index, the function checks // if the channel is already marked as closed in the SQL store. If it is, // the entry is skipped. This means that the resulting zombie index count in // the SQL store may well be less than the count of zombie channels in the KV // store. -func migrateZombieIndex(ctx context.Context, kvBackend kvdb.Backend, - sqlDB SQLQueries) error { +func migrateZombieIndex(ctx context.Context, cfg *sqldb.QueryConfig, + kvBackend kvdb.Backend, sqlDB SQLQueries) error { var ( count uint64 @@ -1122,6 +1123,79 @@ func migrateZombieIndex(ctx context.Context, kvBackend kvdb.Backend, Interval: 10 * time.Second, } ) + + type zombieEntry struct { + pub1 route.Vertex + pub2 route.Vertex + } + + batch := make(map[uint64]*zombieEntry, cfg.MaxBatchSize) + + // validateBatch validates a batch of zombie SCIDs using batch query. + validateBatch := func() error { + if len(batch) == 0 { + return nil + } + + scids := make([][]byte, 0, len(batch)) + for scid := range batch { + scids = append(scids, channelIDToBytes(scid)) + } + + // Batch fetch all zombie channels from the database. 
+ rows, err := sqlDB.GetZombieChannelsSCIDs( + ctx, sqlc.GetZombieChannelsSCIDsParams{ + Version: int16(ProtocolV1), + Scids: scids, + }, + ) + if err != nil { + return fmt.Errorf("could not batch get zombie "+ + "SCIDs: %w", err) + } + + // Make sure that the number of rows returned matches + // the number of SCIDs we requested. + if len(rows) != len(scids) { + return fmt.Errorf("expected to fetch %d zombie "+ + "SCIDs, but got %d", len(scids), len(rows)) + } + + // Validate each row is in the batch. + for _, row := range rows { + scid := byteOrder.Uint64(row.Scid) + + kvdbZombie, ok := batch[scid] + if !ok { + return fmt.Errorf("zombie SCID %x not found "+ + "in batch", scid) + } + + err = sqldb.CompareRecords( + kvdbZombie.pub1[:], row.NodeKey1, + fmt.Sprintf("zombie pub key 1 (%s) for "+ + "channel %d", kvdbZombie.pub1, scid), + ) + if err != nil { + return err + } + + err = sqldb.CompareRecords( + kvdbZombie.pub2[:], row.NodeKey2, + fmt.Sprintf("zombie pub key 2 (%s) for "+ + "channel %d", kvdbZombie.pub2, scid), + ) + if err != nil { + return err + } + } + + // Reset the batch for the next iteration. + batch = make(map[uint64]*zombieEntry, cfg.MaxBatchSize) + + return nil + } + err := forEachZombieEntry(kvBackend, func(chanID uint64, pubKey1, pubKey2 [33]byte) error { @@ -1158,22 +1232,19 @@ func migrateZombieIndex(ctx context.Context, kvBackend kvdb.Backend, "channel %d: %w", chanID, err) } - // Finally, verify that the channel is indeed marked as a - // zombie channel. - isZombie, err := sqlDB.IsZombieChannel( - ctx, sqlc.IsZombieChannelParams{ - Version: int16(ProtocolV1), - Scid: chanIDB, - }, - ) - if err != nil { - return fmt.Errorf("could not check if "+ - "channel %d is zombie: %w", chanID, err) + // Add to validation batch only after successful insertion. + batch[chanID] = &zombieEntry{ + pub1: pubKey1, + pub2: pubKey2, } - if !isZombie { - return fmt.Errorf("channel %d should be "+ - "a zombie, but is not", chanID) + // Validate batch when full. + if len(batch) >= cfg.MaxBatchSize { + err := validateBatch() + if err != nil { + return fmt.Errorf("batch validation failed: %w", + err) + } } s.Do(func() { @@ -1192,6 +1263,15 @@ func migrateZombieIndex(ctx context.Context, kvBackend kvdb.Backend, return fmt.Errorf("could not migrate zombie index: %w", err) } + // Validate any remaining zombie SCIDs in the batch. 
+ if len(batch) > 0 { + err := validateBatch() + if err != nil { + return fmt.Errorf("final batch validation failed: %w", + err) + } + } + log.Infof("Migrated %d zombie channels from KV to SQL", count) return nil diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index fbfee848f..956b8406d 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -133,6 +133,7 @@ type SQLQueries interface { */ UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error) + GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error) CountZombieChannels(ctx context.Context, version int16) (int64, error) DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error) IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error) diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 7b9e472a1..d27143ca8 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -2261,6 +2261,57 @@ func (q *Queries) GetZombieChannel(ctx context.Context, arg GetZombieChannelPara return i, err } +const getZombieChannelsSCIDs = `-- name: GetZombieChannelsSCIDs :many +SELECT scid, version, node_key_1, node_key_2 +FROM graph_zombie_channels +WHERE version = $1 + AND scid IN (/*SLICE:scids*/?) +` + +type GetZombieChannelsSCIDsParams struct { + Version int16 + Scids [][]byte +} + +func (q *Queries) GetZombieChannelsSCIDs(ctx context.Context, arg GetZombieChannelsSCIDsParams) ([]GraphZombieChannel, error) { + query := getZombieChannelsSCIDs + var queryParams []interface{} + queryParams = append(queryParams, arg.Version) + if len(arg.Scids) > 0 { + for _, v := range arg.Scids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:scids*/?", makeQueryParams(len(queryParams), len(arg.Scids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:scids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GraphZombieChannel + for rows.Next() { + var i GraphZombieChannel + if err := rows.Scan( + &i.Scid, + &i.Version, + &i.NodeKey1, + &i.NodeKey2, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const highestSCID = `-- name: HighestSCID :one SELECT scid FROM graph_channels diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index b3c0857eb..eaef3c777 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -84,6 +84,7 @@ type Querier interface { // and so the query for V2 may differ. 
GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error) GetZombieChannel(ctx context.Context, arg GetZombieChannelParams) (GraphZombieChannel, error) + GetZombieChannelsSCIDs(ctx context.Context, arg GetZombieChannelsSCIDsParams) ([]GraphZombieChannel, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index f77e5044c..04f814083 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -908,6 +908,12 @@ DO UPDATE SET node_key_1 = COALESCE(EXCLUDED.node_key_1, graph_zombie_channels.node_key_1), node_key_2 = COALESCE(EXCLUDED.node_key_2, graph_zombie_channels.node_key_2); +-- name: GetZombieChannelsSCIDs :many +SELECT * +FROM graph_zombie_channels +WHERE version = @version + AND scid IN (sqlc.slice('scids')/*SLICE:scids*/); + -- name: DeleteZombieChannel :execresult DELETE FROM graph_zombie_channels WHERE scid = $1
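
To make the control flow easier to see in isolation, below is a minimal, self-contained Go sketch of the collect/flush batch-validation pattern that the patch introduces in migrateZombieIndex. It is not lnd code: the zombie struct, the fetchZombies helper, the maxBatchSize constant, and the in-memory maps are hypothetical stand-ins for SQLQueries, the generated GetZombieChannelsSCIDs method, and cfg.MaxBatchSize; only the shape of the batching and the per-row comparison mirror the migration above.

// NOTE: illustrative sketch only; every identifier here is a stand-in and
// none of it touches the lnd packages changed by this patch.
package main

import (
	"bytes"
	"fmt"
)

// zombie is a stand-in for the two public keys stored per zombie SCID.
type zombie struct {
	pub1, pub2 [33]byte
}

// maxBatchSize is deliberately tiny so the example flushes more than once.
const maxBatchSize = 2

// fetchZombies fakes the batched SQL lookup: it returns the stored entries
// for the requested SCIDs in a single call.
func fetchZombies(store map[uint64]zombie, scids []uint64) map[uint64]zombie {
	out := make(map[uint64]zombie, len(scids))
	for _, scid := range scids {
		if z, ok := store[scid]; ok {
			out[scid] = z
		}
	}
	return out
}

func main() {
	// kvSide plays the role of the KV zombie index being migrated;
	// sqlSide is the store the entries are inserted into one by one.
	kvSide := map[uint64]zombie{
		1: {pub1: [33]byte{0x02}},
		2: {pub1: [33]byte{0x03}},
		3: {pub1: [33]byte{0x04}},
	}
	sqlSide := make(map[uint64]zombie, len(kvSide))

	batch := make(map[uint64]zombie, maxBatchSize)

	// validateBatch mirrors the closure in the patch: fetch every SCID in
	// the current batch with one query, check the row count, compare each
	// row against the expected entry, then reset the batch.
	validateBatch := func() error {
		if len(batch) == 0 {
			return nil
		}

		scids := make([]uint64, 0, len(batch))
		for scid := range batch {
			scids = append(scids, scid)
		}

		rows := fetchZombies(sqlSide, scids)
		if len(rows) != len(scids) {
			return fmt.Errorf("expected %d rows, got %d",
				len(scids), len(rows))
		}

		for scid, row := range rows {
			want := batch[scid]
			if !bytes.Equal(want.pub1[:], row.pub1[:]) ||
				!bytes.Equal(want.pub2[:], row.pub2[:]) {

				return fmt.Errorf("mismatch for SCID %d", scid)
			}
		}

		// Reset the batch for the next iteration.
		batch = make(map[uint64]zombie, maxBatchSize)

		return nil
	}

	// Per-entry loop: insert individually, queue the entry for
	// validation, and flush whenever the batch is full.
	for scid, z := range kvSide {
		sqlSide[scid] = z

		batch[scid] = z
		if len(batch) >= maxBatchSize {
			if err := validateBatch(); err != nil {
				panic(err)
			}
		}
	}

	// Flush whatever is left over at the end.
	if err := validateBatch(); err != nil {
		panic(err)
	}

	fmt.Println("validated all zombie entries in batches")
}

The patch applies the same shape against the real store: each zombie entry is inserted individually, queued in a map keyed by SCID, and once cfg.MaxBatchSize entries have accumulated the generated GetZombieChannelsSCIDs query fetches all queued SCIDs in one IN (...) lookup (the /*SLICE:scids*/ placeholder expands to one bound parameter per SCID, or NULL when the slice is empty). The row count and both node keys are then compared against the KV entries before the batch map is reset, with a final flush after forEachZombieEntry completes.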