diff --git a/graph/db/sql_migration.go b/graph/db/sql_migration.go index 467f5436e..17a751f87 100644 --- a/graph/db/sql_migration.go +++ b/graph/db/sql_migration.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/sqldb" "github.com/lightningnetwork/lnd/sqldb/sqlc" "golang.org/x/time/rate" @@ -74,7 +75,8 @@ func MigrateGraphToSQL(ctx context.Context, cfg *SQLStoreConfig, } // 6) Migrate the zombie index. - if err := migrateZombieIndex(ctx, kvBackend, sqlDB); err != nil { + err = migrateZombieIndex(ctx, cfg.QueryCfg, kvBackend, sqlDB) + if err != nil { return fmt.Errorf("could not migrate zombie index: %w", err) } @@ -1100,18 +1102,17 @@ func migrateClosedSCIDIndex(ctx context.Context, cfg *sqldb.QueryConfig, return nil } -// migrateZombieIndex migrates the zombie index from the KV backend to -// the SQL database. It iterates over each zombie channel in the KV store, -// inserts it into the SQL database, and then verifies that the channel is -// indeed marked as a zombie channel in the SQL database. +// migrateZombieIndex migrates the zombie index from the KV backend to the SQL +// database. It collects zombie channels in batches, inserts them individually, +// and validates them in batches. // // NOTE: before inserting an entry into the zombie index, the function checks // if the channel is already marked as closed in the SQL store. If it is, // the entry is skipped. This means that the resulting zombie index count in // the SQL store may well be less than the count of zombie channels in the KV // store. -func migrateZombieIndex(ctx context.Context, kvBackend kvdb.Backend, - sqlDB SQLQueries) error { +func migrateZombieIndex(ctx context.Context, cfg *sqldb.QueryConfig, + kvBackend kvdb.Backend, sqlDB SQLQueries) error { var ( count uint64 @@ -1122,6 +1123,79 @@ func migrateZombieIndex(ctx context.Context, kvBackend kvdb.Backend, Interval: 10 * time.Second, } ) + + type zombieEntry struct { + pub1 route.Vertex + pub2 route.Vertex + } + + batch := make(map[uint64]*zombieEntry, cfg.MaxBatchSize) + + // validateBatch validates a batch of zombie SCIDs using batch query. + validateBatch := func() error { + if len(batch) == 0 { + return nil + } + + scids := make([][]byte, 0, len(batch)) + for scid := range batch { + scids = append(scids, channelIDToBytes(scid)) + } + + // Batch fetch all zombie channels from the database. + rows, err := sqlDB.GetZombieChannelsSCIDs( + ctx, sqlc.GetZombieChannelsSCIDsParams{ + Version: int16(ProtocolV1), + Scids: scids, + }, + ) + if err != nil { + return fmt.Errorf("could not batch get zombie "+ + "SCIDs: %w", err) + } + + // Make sure that the number of rows returned matches + // the number of SCIDs we requested. + if len(rows) != len(scids) { + return fmt.Errorf("expected to fetch %d zombie "+ + "SCIDs, but got %d", len(scids), len(rows)) + } + + // Validate each row is in the batch. + for _, row := range rows { + scid := byteOrder.Uint64(row.Scid) + + kvdbZombie, ok := batch[scid] + if !ok { + return fmt.Errorf("zombie SCID %x not found "+ + "in batch", scid) + } + + err = sqldb.CompareRecords( + kvdbZombie.pub1[:], row.NodeKey1, + fmt.Sprintf("zombie pub key 1 (%s) for "+ + "channel %d", kvdbZombie.pub1, scid), + ) + if err != nil { + return err + } + + err = sqldb.CompareRecords( + kvdbZombie.pub2[:], row.NodeKey2, + fmt.Sprintf("zombie pub key 2 (%s) for "+ + "channel %d", kvdbZombie.pub2, scid), + ) + if err != nil { + return err + } + } + + // Reset the batch for the next iteration. + batch = make(map[uint64]*zombieEntry, cfg.MaxBatchSize) + + return nil + } + err := forEachZombieEntry(kvBackend, func(chanID uint64, pubKey1, pubKey2 [33]byte) error { @@ -1158,22 +1232,19 @@ func migrateZombieIndex(ctx context.Context, kvBackend kvdb.Backend, "channel %d: %w", chanID, err) } - // Finally, verify that the channel is indeed marked as a - // zombie channel. - isZombie, err := sqlDB.IsZombieChannel( - ctx, sqlc.IsZombieChannelParams{ - Version: int16(ProtocolV1), - Scid: chanIDB, - }, - ) - if err != nil { - return fmt.Errorf("could not check if "+ - "channel %d is zombie: %w", chanID, err) + // Add to validation batch only after successful insertion. + batch[chanID] = &zombieEntry{ + pub1: pubKey1, + pub2: pubKey2, } - if !isZombie { - return fmt.Errorf("channel %d should be "+ - "a zombie, but is not", chanID) + // Validate batch when full. + if len(batch) >= cfg.MaxBatchSize { + err := validateBatch() + if err != nil { + return fmt.Errorf("batch validation failed: %w", + err) + } } s.Do(func() { @@ -1192,6 +1263,15 @@ func migrateZombieIndex(ctx context.Context, kvBackend kvdb.Backend, return fmt.Errorf("could not migrate zombie index: %w", err) } + // Validate any remaining zombie SCIDs in the batch. + if len(batch) > 0 { + err := validateBatch() + if err != nil { + return fmt.Errorf("final batch validation failed: %w", + err) + } + } + log.Infof("Migrated %d zombie channels from KV to SQL", count) return nil diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index fbfee848f..956b8406d 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -133,6 +133,7 @@ type SQLQueries interface { */ UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error) + GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error) CountZombieChannels(ctx context.Context, version int16) (int64, error) DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error) IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error) diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 7b9e472a1..d27143ca8 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -2261,6 +2261,57 @@ func (q *Queries) GetZombieChannel(ctx context.Context, arg GetZombieChannelPara return i, err } +const getZombieChannelsSCIDs = `-- name: GetZombieChannelsSCIDs :many +SELECT scid, version, node_key_1, node_key_2 +FROM graph_zombie_channels +WHERE version = $1 + AND scid IN (/*SLICE:scids*/?) +` + +type GetZombieChannelsSCIDsParams struct { + Version int16 + Scids [][]byte +} + +func (q *Queries) GetZombieChannelsSCIDs(ctx context.Context, arg GetZombieChannelsSCIDsParams) ([]GraphZombieChannel, error) { + query := getZombieChannelsSCIDs + var queryParams []interface{} + queryParams = append(queryParams, arg.Version) + if len(arg.Scids) > 0 { + for _, v := range arg.Scids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:scids*/?", makeQueryParams(len(queryParams), len(arg.Scids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:scids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GraphZombieChannel + for rows.Next() { + var i GraphZombieChannel + if err := rows.Scan( + &i.Scid, + &i.Version, + &i.NodeKey1, + &i.NodeKey2, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const highestSCID = `-- name: HighestSCID :one SELECT scid FROM graph_channels diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index b3c0857eb..eaef3c777 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -84,6 +84,7 @@ type Querier interface { // and so the query for V2 may differ. GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error) GetZombieChannel(ctx context.Context, arg GetZombieChannelParams) (GraphZombieChannel, error) + GetZombieChannelsSCIDs(ctx context.Context, arg GetZombieChannelsSCIDsParams) ([]GraphZombieChannel, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index f77e5044c..04f814083 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -908,6 +908,12 @@ DO UPDATE SET node_key_1 = COALESCE(EXCLUDED.node_key_1, graph_zombie_channels.node_key_1), node_key_2 = COALESCE(EXCLUDED.node_key_2, graph_zombie_channels.node_key_2); +-- name: GetZombieChannelsSCIDs :many +SELECT * +FROM graph_zombie_channels +WHERE version = @version + AND scid IN (sqlc.slice('scids')/*SLICE:scids*/); + -- name: DeleteZombieChannel :execresult DELETE FROM graph_zombie_channels WHERE scid = $1