From b789fb2db3a5df30f2e9682c4b5358564b7d7491 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Fri, 22 Nov 2024 18:29:54 +0100 Subject: [PATCH] sqldb: add support for custom in-code migrations This commit introduces support for custom, in-code migrations, allowing a specific Go function to be executed at a designated database version during sqlc migrations. If the current database version surpasses the specified version, the migration will be skipped. --- lncfg/db.go | 3 +- sqldb/migrations.go | 204 ++++++++++++++++++++++++++ sqldb/migrations_test.go | 293 ++++++++++++++++++++++++++++++++++++++ sqldb/postgres.go | 43 ++++-- sqldb/postgres_fixture.go | 4 +- sqldb/sqlite.go | 38 ++++- 6 files changed, 565 insertions(+), 20 deletions(-) diff --git a/lncfg/db.go b/lncfg/db.go index 3d45bb78b..a6598e66d 100644 --- a/lncfg/db.go +++ b/lncfg/db.go @@ -452,7 +452,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, var nativeSQLStore *sqldb.BaseDB if db.UseNativeSQL { nativePostgresStore, err := sqldb.NewPostgresStore( - db.Postgres, + db.Postgres, sqldb.GetMigrations(), ) if err != nil { return nil, fmt.Errorf("error opening "+ @@ -576,6 +576,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, nativeSQLiteStore, err := sqldb.NewSqliteStore( db.Sqlite, path.Join(chanDBPath, SqliteNativeDBName), + sqldb.GetMigrations(), ) if err != nil { return nil, fmt.Errorf("error opening "+ diff --git a/sqldb/migrations.go b/sqldb/migrations.go index 9d394ceed..83634d0e5 100644 --- a/sqldb/migrations.go +++ b/sqldb/migrations.go @@ -2,22 +2,104 @@ package sqldb import ( "bytes" + "context" + "database/sql" "errors" + "fmt" "io" "io/fs" "net/http" "strings" + "time" "github.com/btcsuite/btclog/v2" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" "github.com/golang-migrate/migrate/v4/source/httpfs" + "github.com/lightningnetwork/lnd/sqldb/sqlc" ) +var ( + // migrationConfig defines a list of migrations to be applied to the + // database. Each migration is assigned a version number, determining + // its execution order. + // The schema version, tracked by golang-migrate, ensures migrations are + // applied to the correct schema. For migrations involving only schema + // changes, the migration function can be left nil. For custom + // migrations an implemented migration function is required. + // + // NOTE: The migration function may have runtime dependencies, which + // must be injected during runtime. + migrationConfig = []MigrationConfig{ + { + Name: "000001_invoices", + Version: 1, + SchemaVersion: 1, + }, + { + Name: "000002_amp_invoices", + Version: 2, + SchemaVersion: 2, + }, + { + Name: "000003_invoice_events", + Version: 3, + SchemaVersion: 3, + }, + { + Name: "000004_invoice_expiry_fix", + Version: 4, + SchemaVersion: 4, + }, + { + Name: "000005_migration_tracker", + Version: 5, + SchemaVersion: 5, + }, + } +) + +// MigrationConfig is a configuration struct that describes SQL migrations. Each +// migration is associated with a specific schema version and a global database +// version. Migrations are applied in the order of their global database +// version. If a migration includes a non-nil MigrationFn, it is executed after +// the SQL schema has been migrated to the corresponding schema version. +type MigrationConfig struct { + // Name is the name of the migration. + Name string + + // Version represents the "global" database version for this migration. + // Unlike the schema version tracked by golang-migrate, it encompasses + // all migrations, including those managed by golang-migrate as well + // as custom in-code migrations. + Version int + + // SchemaVersion represents the schema version tracked by golang-migrate + // at which the migration is applied. + SchemaVersion int + + // MigrationFn is the function executed for custom migrations at the + // specified version. It is used to handle migrations that cannot be + // performed through SQL alone. If set to nil, no custom migration is + // applied. + MigrationFn func(tx *sqlc.Queries) error +} + // MigrationTarget is a functional option that can be passed to applyMigrations // to specify a target version to migrate to. type MigrationTarget func(mig *migrate.Migrate) error +// MigrationExecutor is an interface that abstracts the migration functionality. +type MigrationExecutor interface { + // ExecuteMigrations runs database migrations up to the specified target + // version or all migrations if no target is specified. A migration may + // include a schema change, a custom migration function, or both. + // Developers must ensure that migrations are defined in the correct + // order. Migration details are stored in the global variable + // migrationConfig. + ExecuteMigrations(target MigrationTarget) error +} + var ( // TargetLatest is a MigrationTarget that migrates to the latest // version available. @@ -34,6 +116,14 @@ var ( } ) +// GetMigrations returns a copy of the migration configuration. +func GetMigrations() []MigrationConfig { + migrations := make([]MigrationConfig, len(migrationConfig)) + copy(migrations, migrationConfig) + + return migrations +} + // migrationLogger is a logger that wraps the passed btclog.Logger so it can be // used to log migrations. type migrationLogger struct { @@ -216,3 +306,117 @@ func (t *replacerFile) Close() error { // instance, so there's nothing to do for us here. return nil } + +// MigrationTxOptions is the implementation of the TxOptions interface for +// migration transactions. +type MigrationTxOptions struct { +} + +// ReadOnly returns false to indicate that migration transactions are not read +// only. +func (m *MigrationTxOptions) ReadOnly() bool { + return false +} + +// ApplyMigrations applies the provided migrations to the database in sequence. +// It ensures migrations are executed in the correct order, applying both custom +// migration functions and SQL migrations as needed. +func ApplyMigrations(ctx context.Context, db *BaseDB, + migrator MigrationExecutor, migrations []MigrationConfig) error { + + // Ensure that the migrations are sorted by version. + for i := 0; i < len(migrations); i++ { + if migrations[i].Version != i+1 { + return fmt.Errorf("migration version %d is out of "+ + "order. Expected %d", migrations[i].Version, + i+1) + } + } + // Construct a transaction executor to apply custom migrations. + executor := NewTransactionExecutor(db, func(tx *sql.Tx) *sqlc.Queries { + return db.WithTx(tx) + }) + + currentVersion := 0 + version, err := db.GetDatabaseVersion(ctx) + if !errors.Is(err, sql.ErrNoRows) { + if err != nil { + return fmt.Errorf("error getting current database "+ + "version: %w", err) + } + + currentVersion = int(version) + } + + for _, migration := range migrations { + if migration.Version <= currentVersion { + log.Infof("Skipping migration '%s' (version %d) as it "+ + "has already been applied", migration.Name, + migration.Version) + + continue + } + + log.Infof("Migrating SQL schema to version %d", + migration.SchemaVersion) + + // Execute SQL schema migrations up to the target version. + err = migrator.ExecuteMigrations( + TargetVersion(uint(migration.SchemaVersion)), + ) + if err != nil { + return fmt.Errorf("error executing schema migrations "+ + "to target version %d: %w", + migration.SchemaVersion, err) + } + + var opts MigrationTxOptions + + // Run the custom migration as a transaction to ensure + // atomicity. If successful, mark the migration as complete in + // the migration tracker table. + err = executor.ExecTx(ctx, &opts, func(tx *sqlc.Queries) error { + // Apply the migration function if one is provided. + if migration.MigrationFn != nil { + log.Infof("Applying custom migration '%v' "+ + "(version %d) to schema version %d", + migration.Name, migration.Version, + migration.SchemaVersion) + + err = migration.MigrationFn(tx) + if err != nil { + return fmt.Errorf("error applying "+ + "migration '%v' (version %d) "+ + "to schema version %d: %w", + migration.Name, + migration.Version, + migration.SchemaVersion, err) + } + + log.Infof("Migration '%v' (version %d) "+ + "applied ", migration.Name, + migration.Version) + } + + // Mark the migration as complete by adding the version + // to the migration tracker table along with the current + // timestamp. + err = tx.SetMigration(ctx, sqlc.SetMigrationParams{ + Version: int32(migration.Version), + MigrationTime: time.Now(), + }) + if err != nil { + return fmt.Errorf("error setting migration "+ + "version %d: %w", migration.Version, + err) + } + + return nil + }, func() {}) + if err != nil { + return err + } + } + + return nil +} diff --git a/sqldb/migrations_test.go b/sqldb/migrations_test.go index cd55e92cb..284ba8e99 100644 --- a/sqldb/migrations_test.go +++ b/sqldb/migrations_test.go @@ -2,8 +2,15 @@ package sqldb import ( "context" + "database/sql" + "fmt" + "path/filepath" "testing" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + pgx_migrate "github.com/golang-migrate/migrate/v4/database/pgx/v5" + sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite" "github.com/lightningnetwork/lnd/sqldb/sqlc" "github.com/stretchr/testify/require" ) @@ -152,3 +159,289 @@ func testInvoiceExpiryMigration(t *testing.T, makeDB makeMigrationTestDB) { require.NoError(t, err) require.Equal(t, expected, invoices) } + +// TestCustomMigration tests that a custom in-code migrations are correctly +// executed during the migration process. +func TestCustomMigration(t *testing.T) { + var customMigrationLog []string + + logMigration := func(name string) { + customMigrationLog = append(customMigrationLog, name) + } + + // Some migrations to use for both the failure and success tests. Note + // that the migrations are not in order to test that they are executed + // in the correct order. + migrations := []MigrationConfig{ + { + Name: "1", + Version: 1, + SchemaVersion: 1, + MigrationFn: func(*sqlc.Queries) error { + logMigration("1") + + return nil + }, + }, + { + Name: "2", + Version: 2, + SchemaVersion: 1, + MigrationFn: func(*sqlc.Queries) error { + logMigration("2") + + return nil + }, + }, + { + Name: "3", + Version: 3, + SchemaVersion: 2, + MigrationFn: func(*sqlc.Queries) error { + logMigration("3") + + return nil + }, + }, + } + + tests := []struct { + name string + migrations []MigrationConfig + expectedSuccess bool + expectedMigrationLog []string + expectedSchemaVersion int + expectedVersion int + }{ + { + name: "success", + migrations: migrations, + expectedSuccess: true, + expectedMigrationLog: []string{"1", "2", "3"}, + expectedSchemaVersion: 2, + expectedVersion: 3, + }, + { + name: "unordered migrations", + migrations: append([]MigrationConfig{ + { + Name: "4", + Version: 4, + SchemaVersion: 3, + MigrationFn: func(*sqlc.Queries) error { + logMigration("4") + + return nil + }, + }, + }, migrations...), + expectedSuccess: false, + expectedMigrationLog: nil, + expectedSchemaVersion: 0, + }, + { + name: "failure of migration 4", + migrations: append(migrations, MigrationConfig{ + Name: "4", + Version: 4, + SchemaVersion: 3, + MigrationFn: func(*sqlc.Queries) error { + return fmt.Errorf("migration 4 failed") + }, + }), + expectedSuccess: false, + expectedMigrationLog: []string{"1", "2", "3"}, + // Since schema migration is a separate step we expect + // that migrating up to 3 succeeded. + expectedSchemaVersion: 3, + // We still remain on version 3 though. + expectedVersion: 3, + }, + { + name: "success of migration 4", + migrations: append(migrations, MigrationConfig{ + Name: "4", + Version: 4, + SchemaVersion: 3, + MigrationFn: func(*sqlc.Queries) error { + logMigration("4") + + return nil + }, + }), + expectedSuccess: true, + expectedMigrationLog: []string{"1", "2", "3", "4"}, + expectedSchemaVersion: 3, + expectedVersion: 4, + }, + } + + ctxb := context.Background() + for _, test := range tests { + // checkSchemaVersion checks the database schema version against + // the expected version. + getSchemaVersion := func(t *testing.T, + driver database.Driver, dbName string) int { + + sqlMigrate, err := migrate.NewWithInstance( + "migrations", nil, dbName, driver, + ) + require.NoError(t, err) + + version, _, err := sqlMigrate.Version() + if err != migrate.ErrNilVersion { + require.NoError(t, err) + } + + return int(version) + } + + t.Run("SQLite "+test.name, func(t *testing.T) { + customMigrationLog = nil + + // First instantiate the database and run the migrations + // including the custom migrations. + t.Logf("Creating new SQLite DB for testing migrations") + + dbFileName := filepath.Join(t.TempDir(), "tmp.db") + var ( + db *SqliteStore + err error + ) + + // Run the migration 3 times to test that the migrations + // are idempotent. + for i := 0; i < 3; i++ { + db, err = NewSqliteStore(&SqliteConfig{ + SkipMigrations: false, + }, dbFileName, test.migrations) + if db != nil { + dbToCleanup := db.DB + t.Cleanup(func() { + require.NoError( + t, dbToCleanup.Close(), + ) + }) + } + + if test.expectedSuccess { + require.NoError(t, err) + } else { + require.Error(t, err) + + // Also repoen the DB without migrations + // so we can read versions. + db, err = NewSqliteStore(&SqliteConfig{ + SkipMigrations: true, + }, dbFileName, nil) + require.NoError(t, err) + } + + require.Equal(t, + test.expectedMigrationLog, + customMigrationLog, + ) + + // Create the migration executor to be able to + // query the current schema version. + driver, err := sqlite_migrate.WithInstance( + db.DB, &sqlite_migrate.Config{}, + ) + require.NoError(t, err) + + require.Equal( + t, test.expectedSchemaVersion, + getSchemaVersion(t, driver, ""), + ) + + // Check the migraton version in the database. + version, err := db.GetDatabaseVersion(ctxb) + if test.expectedSchemaVersion != 0 { + require.NoError(t, err) + } else { + require.Equal(t, sql.ErrNoRows, err) + } + + require.Equal( + t, test.expectedVersion, int(version), + ) + } + }) + + t.Run("Postgres "+test.name, func(t *testing.T) { + customMigrationLog = nil + + // First create a temporary Postgres database to run + // the migrations on. + fixture := NewTestPgFixture( + t, DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + fixture.TearDown(t) + }) + + dbName := randomDBName(t) + + // Next instantiate the database and run the migrations + // including the custom migrations. + t.Logf("Creating new Postgres DB '%s' for testing "+ + "migrations", dbName) + + _, err := fixture.db.ExecContext( + context.Background(), "CREATE DATABASE "+dbName, + ) + require.NoError(t, err) + + cfg := fixture.GetConfig(dbName) + var db *PostgresStore + + // Run the migration 3 times to test that the migrations + // are idempotent. + for i := 0; i < 3; i++ { + cfg.SkipMigrations = false + db, err = NewPostgresStore(cfg, test.migrations) + + if test.expectedSuccess { + require.NoError(t, err) + } else { + require.Error(t, err) + + // Also repoen the DB without migrations + // so we can read versions. + cfg.SkipMigrations = true + db, err = NewPostgresStore(cfg, nil) + require.NoError(t, err) + } + + require.Equal(t, + test.expectedMigrationLog, + customMigrationLog, + ) + + // Create the migration executor to be able to + // query the current version. + driver, err := pgx_migrate.WithInstance( + db.DB, &pgx_migrate.Config{}, + ) + require.NoError(t, err) + + require.Equal( + t, test.expectedSchemaVersion, + getSchemaVersion(t, driver, ""), + ) + + // Check the migraton version in the database. + version, err := db.GetDatabaseVersion(ctxb) + if test.expectedSchemaVersion != 0 { + require.NoError(t, err) + } else { + require.Equal(t, sql.ErrNoRows, err) + } + + require.Equal( + t, test.expectedVersion, int(version), + ) + } + }) + } +} diff --git a/sqldb/postgres.go b/sqldb/postgres.go index c85539157..4884943f0 100644 --- a/sqldb/postgres.go +++ b/sqldb/postgres.go @@ -1,6 +1,7 @@ package sqldb import ( + "context" "database/sql" "fmt" "net/url" @@ -32,6 +33,9 @@ var ( "BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY", "TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE", } + + // Make sure PostgresStore implements the MigrationExecutor interface. + _ MigrationExecutor = (*PostgresStore)(nil) ) // replacePasswordInDSN takes a DSN string and returns it with the password @@ -85,43 +89,62 @@ type PostgresStore struct { // NewPostgresStore creates a new store that is backed by a Postgres database // backend. -func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { +func NewPostgresStore(cfg *PostgresConfig, migrations []MigrationConfig) ( + *PostgresStore, error) { + sanitizedDSN, err := replacePasswordInDSN(cfg.Dsn) if err != nil { return nil, err } log.Infof("Using SQL database '%s'", sanitizedDSN) - rawDB, err := sql.Open("pgx", cfg.Dsn) + db, err := sql.Open("pgx", cfg.Dsn) if err != nil { return nil, err } + // Ensure the migration tracker table exists before running migrations. + // This table tracks migration progress and ensures compatibility with + // SQLC query generation. If the table is already created by an SQLC + // migration, this operation becomes a no-op. + migrationTrackerSQL := ` + CREATE TABLE IF NOT EXISTS migration_tracker ( + version INTEGER UNIQUE NOT NULL, + migration_time TIMESTAMP NOT NULL + );` + + _, err = db.Exec(migrationTrackerSQL) + if err != nil { + return nil, fmt.Errorf("error creating migration tracker: %w", + err) + } + maxConns := defaultMaxConns if cfg.MaxConnections > 0 { maxConns = cfg.MaxConnections } - rawDB.SetMaxOpenConns(maxConns) - rawDB.SetMaxIdleConns(maxConns) - rawDB.SetConnMaxLifetime(connIdleLifetime) + db.SetMaxOpenConns(maxConns) + db.SetMaxIdleConns(maxConns) + db.SetConnMaxLifetime(connIdleLifetime) - queries := sqlc.New(rawDB) + queries := sqlc.New(db) s := &PostgresStore{ cfg: cfg, BaseDB: &BaseDB{ - DB: rawDB, + DB: db, Queries: queries, }, } // Execute migrations unless configured to skip them. if !cfg.SkipMigrations { - err := s.ExecuteMigrations(TargetLatest) + err := ApplyMigrations( + context.Background(), s.BaseDB, s, migrations, + ) if err != nil { - return nil, fmt.Errorf("error executing migrations: %w", - err) + return nil, err } } diff --git a/sqldb/postgres_fixture.go b/sqldb/postgres_fixture.go index da5769c42..284cd0c8c 100644 --- a/sqldb/postgres_fixture.go +++ b/sqldb/postgres_fixture.go @@ -148,7 +148,7 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore { require.NoError(t, err) cfg := fixture.GetConfig(dbName) - store, err := NewPostgresStore(cfg) + store, err := NewPostgresStore(cfg, GetMigrations()) require.NoError(t, err) return store @@ -172,7 +172,7 @@ func NewTestPostgresDBWithVersion(t *testing.T, fixture *TestPgFixture, storeCfg := fixture.GetConfig(dbName) storeCfg.SkipMigrations = true - store, err := NewPostgresStore(storeCfg) + store, err := NewPostgresStore(storeCfg, GetMigrations()) require.NoError(t, err) err = store.ExecuteMigrations(TargetVersion(version)) diff --git a/sqldb/sqlite.go b/sqldb/sqlite.go index 99e55d6ea..bf192eb0f 100644 --- a/sqldb/sqlite.go +++ b/sqldb/sqlite.go @@ -3,6 +3,7 @@ package sqldb import ( + "context" "database/sql" "fmt" "net/url" @@ -34,6 +35,9 @@ var ( sqliteSchemaReplacements = map[string]string{ "BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY", } + + // Make sure SqliteStore implements the MigrationExecutor interface. + _ MigrationExecutor = (*SqliteStore)(nil) ) // SqliteStore is a database store implementation that uses a sqlite backend. @@ -45,7 +49,9 @@ type SqliteStore struct { // NewSqliteStore attempts to open a new sqlite database based on the passed // config. -func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) { +func NewSqliteStore(cfg *SqliteConfig, dbPath string, + migrations []MigrationConfig) (*SqliteStore, error) { + // The set of pragma options are accepted using query options. For now // we only want to ensure that foreign key constraints are properly // enforced. @@ -102,6 +108,23 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) { return nil, err } + // Create the migration tracker table before starting migrations to + // ensure it can be used to track migration progress. Note that a + // corresponding SQLC migration also creates this table, making this + // operation a no-op in that context. Its purpose is to ensure + // compatibility with SQLC query generation. + migrationTrackerSQL := ` + CREATE TABLE IF NOT EXISTS migration_tracker ( + version INTEGER UNIQUE NOT NULL, + migration_time TIMESTAMP NOT NULL + );` + + _, err = db.Exec(migrationTrackerSQL) + if err != nil { + return nil, fmt.Errorf("error creating migration tracker: %w", + err) + } + db.SetMaxOpenConns(defaultMaxConns) db.SetMaxIdleConns(defaultMaxConns) db.SetConnMaxLifetime(connIdleLifetime) @@ -117,10 +140,11 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) { // Execute migrations unless configured to skip them. if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(TargetLatest); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - + err := ApplyMigrations( + context.Background(), s.BaseDB, s, migrations, + ) + if err != nil { + return nil, err } } @@ -157,7 +181,7 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore { dbFileName := filepath.Join(t.TempDir(), "tmp.db") sqlDB, err := NewSqliteStore(&SqliteConfig{ SkipMigrations: false, - }, dbFileName) + }, dbFileName, GetMigrations()) require.NoError(t, err) t.Cleanup(func() { @@ -180,7 +204,7 @@ func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore { dbFileName := filepath.Join(t.TempDir(), "tmp.db") sqlDB, err := NewSqliteStore(&SqliteConfig{ SkipMigrations: true, - }, dbFileName) + }, dbFileName, nil) require.NoError(t, err) err = sqlDB.ExecuteMigrations(TargetVersion(version))