sqldb: add support for custom in-code migrations

This commit introduces support for custom, in-code migrations, allowing
a specific Go function to be executed at a designated database version
during sqlc migrations. If the current database version surpasses the
specified version, the migration will be skipped.
This commit is contained in:
Andras Banki-Horvath
2024-11-22 18:29:54 +01:00
parent 9acd06d296
commit b789fb2db3
6 changed files with 565 additions and 20 deletions

View File

@@ -2,22 +2,104 @@ package sqldb
import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"strings"
"time"
"github.com/btcsuite/btclog/v2"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/source/httpfs"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
)
var (
// migrationConfig defines a list of migrations to be applied to the
// database. Each migration is assigned a version number, determining
// its execution order.
// The schema version, tracked by golang-migrate, ensures migrations are
// applied to the correct schema. For migrations involving only schema
// changes, the migration function can be left nil. For custom
// migrations an implemented migration function is required.
//
// NOTE: The migration function may have runtime dependencies, which
// must be injected during runtime.
migrationConfig = []MigrationConfig{
{
Name: "000001_invoices",
Version: 1,
SchemaVersion: 1,
},
{
Name: "000002_amp_invoices",
Version: 2,
SchemaVersion: 2,
},
{
Name: "000003_invoice_events",
Version: 3,
SchemaVersion: 3,
},
{
Name: "000004_invoice_expiry_fix",
Version: 4,
SchemaVersion: 4,
},
{
Name: "000005_migration_tracker",
Version: 5,
SchemaVersion: 5,
},
}
)
// MigrationConfig is a configuration struct that describes SQL migrations. Each
// migration is associated with a specific schema version and a global database
// version. Migrations are applied in the order of their global database
// version. If a migration includes a non-nil MigrationFn, it is executed after
// the SQL schema has been migrated to the corresponding schema version.
type MigrationConfig struct {
// Name is the name of the migration.
Name string
// Version represents the "global" database version for this migration.
// Unlike the schema version tracked by golang-migrate, it encompasses
// all migrations, including those managed by golang-migrate as well
// as custom in-code migrations.
Version int
// SchemaVersion represents the schema version tracked by golang-migrate
// at which the migration is applied.
SchemaVersion int
// MigrationFn is the function executed for custom migrations at the
// specified version. It is used to handle migrations that cannot be
// performed through SQL alone. If set to nil, no custom migration is
// applied.
MigrationFn func(tx *sqlc.Queries) error
}
// MigrationTarget is a functional option that can be passed to applyMigrations
// to specify a target version to migrate to.
type MigrationTarget func(mig *migrate.Migrate) error
// MigrationExecutor is an interface that abstracts the migration functionality.
type MigrationExecutor interface {
// ExecuteMigrations runs database migrations up to the specified target
// version or all migrations if no target is specified. A migration may
// include a schema change, a custom migration function, or both.
// Developers must ensure that migrations are defined in the correct
// order. Migration details are stored in the global variable
// migrationConfig.
ExecuteMigrations(target MigrationTarget) error
}
var (
// TargetLatest is a MigrationTarget that migrates to the latest
// version available.
@@ -34,6 +116,14 @@ var (
}
)
// GetMigrations returns a copy of the migration configuration.
func GetMigrations() []MigrationConfig {
migrations := make([]MigrationConfig, len(migrationConfig))
copy(migrations, migrationConfig)
return migrations
}
// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
// used to log migrations.
type migrationLogger struct {
@@ -216,3 +306,117 @@ func (t *replacerFile) Close() error {
// instance, so there's nothing to do for us here.
return nil
}
// MigrationTxOptions is the implementation of the TxOptions interface for
// migration transactions.
type MigrationTxOptions struct {
}
// ReadOnly returns false to indicate that migration transactions are not read
// only.
func (m *MigrationTxOptions) ReadOnly() bool {
return false
}
// ApplyMigrations applies the provided migrations to the database in sequence.
// It ensures migrations are executed in the correct order, applying both custom
// migration functions and SQL migrations as needed.
func ApplyMigrations(ctx context.Context, db *BaseDB,
migrator MigrationExecutor, migrations []MigrationConfig) error {
// Ensure that the migrations are sorted by version.
for i := 0; i < len(migrations); i++ {
if migrations[i].Version != i+1 {
return fmt.Errorf("migration version %d is out of "+
"order. Expected %d", migrations[i].Version,
i+1)
}
}
// Construct a transaction executor to apply custom migrations.
executor := NewTransactionExecutor(db, func(tx *sql.Tx) *sqlc.Queries {
return db.WithTx(tx)
})
currentVersion := 0
version, err := db.GetDatabaseVersion(ctx)
if !errors.Is(err, sql.ErrNoRows) {
if err != nil {
return fmt.Errorf("error getting current database "+
"version: %w", err)
}
currentVersion = int(version)
}
for _, migration := range migrations {
if migration.Version <= currentVersion {
log.Infof("Skipping migration '%s' (version %d) as it "+
"has already been applied", migration.Name,
migration.Version)
continue
}
log.Infof("Migrating SQL schema to version %d",
migration.SchemaVersion)
// Execute SQL schema migrations up to the target version.
err = migrator.ExecuteMigrations(
TargetVersion(uint(migration.SchemaVersion)),
)
if err != nil {
return fmt.Errorf("error executing schema migrations "+
"to target version %d: %w",
migration.SchemaVersion, err)
}
var opts MigrationTxOptions
// Run the custom migration as a transaction to ensure
// atomicity. If successful, mark the migration as complete in
// the migration tracker table.
err = executor.ExecTx(ctx, &opts, func(tx *sqlc.Queries) error {
// Apply the migration function if one is provided.
if migration.MigrationFn != nil {
log.Infof("Applying custom migration '%v' "+
"(version %d) to schema version %d",
migration.Name, migration.Version,
migration.SchemaVersion)
err = migration.MigrationFn(tx)
if err != nil {
return fmt.Errorf("error applying "+
"migration '%v' (version %d) "+
"to schema version %d: %w",
migration.Name,
migration.Version,
migration.SchemaVersion, err)
}
log.Infof("Migration '%v' (version %d) "+
"applied ", migration.Name,
migration.Version)
}
// Mark the migration as complete by adding the version
// to the migration tracker table along with the current
// timestamp.
err = tx.SetMigration(ctx, sqlc.SetMigrationParams{
Version: int32(migration.Version),
MigrationTime: time.Now(),
})
if err != nil {
return fmt.Errorf("error setting migration "+
"version %d: %w", migration.Version,
err)
}
return nil
}, func() {})
if err != nil {
return err
}
}
return nil
}

View File

@@ -2,8 +2,15 @@ package sqldb
import (
"context"
"database/sql"
"fmt"
"path/filepath"
"testing"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
pgx_migrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
"github.com/stretchr/testify/require"
)
@@ -152,3 +159,289 @@ func testInvoiceExpiryMigration(t *testing.T, makeDB makeMigrationTestDB) {
require.NoError(t, err)
require.Equal(t, expected, invoices)
}
// TestCustomMigration tests that a custom in-code migrations are correctly
// executed during the migration process.
func TestCustomMigration(t *testing.T) {
var customMigrationLog []string
logMigration := func(name string) {
customMigrationLog = append(customMigrationLog, name)
}
// Some migrations to use for both the failure and success tests. Note
// that the migrations are not in order to test that they are executed
// in the correct order.
migrations := []MigrationConfig{
{
Name: "1",
Version: 1,
SchemaVersion: 1,
MigrationFn: func(*sqlc.Queries) error {
logMigration("1")
return nil
},
},
{
Name: "2",
Version: 2,
SchemaVersion: 1,
MigrationFn: func(*sqlc.Queries) error {
logMigration("2")
return nil
},
},
{
Name: "3",
Version: 3,
SchemaVersion: 2,
MigrationFn: func(*sqlc.Queries) error {
logMigration("3")
return nil
},
},
}
tests := []struct {
name string
migrations []MigrationConfig
expectedSuccess bool
expectedMigrationLog []string
expectedSchemaVersion int
expectedVersion int
}{
{
name: "success",
migrations: migrations,
expectedSuccess: true,
expectedMigrationLog: []string{"1", "2", "3"},
expectedSchemaVersion: 2,
expectedVersion: 3,
},
{
name: "unordered migrations",
migrations: append([]MigrationConfig{
{
Name: "4",
Version: 4,
SchemaVersion: 3,
MigrationFn: func(*sqlc.Queries) error {
logMigration("4")
return nil
},
},
}, migrations...),
expectedSuccess: false,
expectedMigrationLog: nil,
expectedSchemaVersion: 0,
},
{
name: "failure of migration 4",
migrations: append(migrations, MigrationConfig{
Name: "4",
Version: 4,
SchemaVersion: 3,
MigrationFn: func(*sqlc.Queries) error {
return fmt.Errorf("migration 4 failed")
},
}),
expectedSuccess: false,
expectedMigrationLog: []string{"1", "2", "3"},
// Since schema migration is a separate step we expect
// that migrating up to 3 succeeded.
expectedSchemaVersion: 3,
// We still remain on version 3 though.
expectedVersion: 3,
},
{
name: "success of migration 4",
migrations: append(migrations, MigrationConfig{
Name: "4",
Version: 4,
SchemaVersion: 3,
MigrationFn: func(*sqlc.Queries) error {
logMigration("4")
return nil
},
}),
expectedSuccess: true,
expectedMigrationLog: []string{"1", "2", "3", "4"},
expectedSchemaVersion: 3,
expectedVersion: 4,
},
}
ctxb := context.Background()
for _, test := range tests {
// checkSchemaVersion checks the database schema version against
// the expected version.
getSchemaVersion := func(t *testing.T,
driver database.Driver, dbName string) int {
sqlMigrate, err := migrate.NewWithInstance(
"migrations", nil, dbName, driver,
)
require.NoError(t, err)
version, _, err := sqlMigrate.Version()
if err != migrate.ErrNilVersion {
require.NoError(t, err)
}
return int(version)
}
t.Run("SQLite "+test.name, func(t *testing.T) {
customMigrationLog = nil
// First instantiate the database and run the migrations
// including the custom migrations.
t.Logf("Creating new SQLite DB for testing migrations")
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
var (
db *SqliteStore
err error
)
// Run the migration 3 times to test that the migrations
// are idempotent.
for i := 0; i < 3; i++ {
db, err = NewSqliteStore(&SqliteConfig{
SkipMigrations: false,
}, dbFileName, test.migrations)
if db != nil {
dbToCleanup := db.DB
t.Cleanup(func() {
require.NoError(
t, dbToCleanup.Close(),
)
})
}
if test.expectedSuccess {
require.NoError(t, err)
} else {
require.Error(t, err)
// Also repoen the DB without migrations
// so we can read versions.
db, err = NewSqliteStore(&SqliteConfig{
SkipMigrations: true,
}, dbFileName, nil)
require.NoError(t, err)
}
require.Equal(t,
test.expectedMigrationLog,
customMigrationLog,
)
// Create the migration executor to be able to
// query the current schema version.
driver, err := sqlite_migrate.WithInstance(
db.DB, &sqlite_migrate.Config{},
)
require.NoError(t, err)
require.Equal(
t, test.expectedSchemaVersion,
getSchemaVersion(t, driver, ""),
)
// Check the migraton version in the database.
version, err := db.GetDatabaseVersion(ctxb)
if test.expectedSchemaVersion != 0 {
require.NoError(t, err)
} else {
require.Equal(t, sql.ErrNoRows, err)
}
require.Equal(
t, test.expectedVersion, int(version),
)
}
})
t.Run("Postgres "+test.name, func(t *testing.T) {
customMigrationLog = nil
// First create a temporary Postgres database to run
// the migrations on.
fixture := NewTestPgFixture(
t, DefaultPostgresFixtureLifetime,
)
t.Cleanup(func() {
fixture.TearDown(t)
})
dbName := randomDBName(t)
// Next instantiate the database and run the migrations
// including the custom migrations.
t.Logf("Creating new Postgres DB '%s' for testing "+
"migrations", dbName)
_, err := fixture.db.ExecContext(
context.Background(), "CREATE DATABASE "+dbName,
)
require.NoError(t, err)
cfg := fixture.GetConfig(dbName)
var db *PostgresStore
// Run the migration 3 times to test that the migrations
// are idempotent.
for i := 0; i < 3; i++ {
cfg.SkipMigrations = false
db, err = NewPostgresStore(cfg, test.migrations)
if test.expectedSuccess {
require.NoError(t, err)
} else {
require.Error(t, err)
// Also repoen the DB without migrations
// so we can read versions.
cfg.SkipMigrations = true
db, err = NewPostgresStore(cfg, nil)
require.NoError(t, err)
}
require.Equal(t,
test.expectedMigrationLog,
customMigrationLog,
)
// Create the migration executor to be able to
// query the current version.
driver, err := pgx_migrate.WithInstance(
db.DB, &pgx_migrate.Config{},
)
require.NoError(t, err)
require.Equal(
t, test.expectedSchemaVersion,
getSchemaVersion(t, driver, ""),
)
// Check the migraton version in the database.
version, err := db.GetDatabaseVersion(ctxb)
if test.expectedSchemaVersion != 0 {
require.NoError(t, err)
} else {
require.Equal(t, sql.ErrNoRows, err)
}
require.Equal(
t, test.expectedVersion, int(version),
)
}
})
}
}

View File

@@ -1,6 +1,7 @@
package sqldb
import (
"context"
"database/sql"
"fmt"
"net/url"
@@ -32,6 +33,9 @@ var (
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
}
// Make sure PostgresStore implements the MigrationExecutor interface.
_ MigrationExecutor = (*PostgresStore)(nil)
)
// replacePasswordInDSN takes a DSN string and returns it with the password
@@ -85,43 +89,62 @@ type PostgresStore struct {
// NewPostgresStore creates a new store that is backed by a Postgres database
// backend.
func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
func NewPostgresStore(cfg *PostgresConfig, migrations []MigrationConfig) (
*PostgresStore, error) {
sanitizedDSN, err := replacePasswordInDSN(cfg.Dsn)
if err != nil {
return nil, err
}
log.Infof("Using SQL database '%s'", sanitizedDSN)
rawDB, err := sql.Open("pgx", cfg.Dsn)
db, err := sql.Open("pgx", cfg.Dsn)
if err != nil {
return nil, err
}
// Ensure the migration tracker table exists before running migrations.
// This table tracks migration progress and ensures compatibility with
// SQLC query generation. If the table is already created by an SQLC
// migration, this operation becomes a no-op.
migrationTrackerSQL := `
CREATE TABLE IF NOT EXISTS migration_tracker (
version INTEGER UNIQUE NOT NULL,
migration_time TIMESTAMP NOT NULL
);`
_, err = db.Exec(migrationTrackerSQL)
if err != nil {
return nil, fmt.Errorf("error creating migration tracker: %w",
err)
}
maxConns := defaultMaxConns
if cfg.MaxConnections > 0 {
maxConns = cfg.MaxConnections
}
rawDB.SetMaxOpenConns(maxConns)
rawDB.SetMaxIdleConns(maxConns)
rawDB.SetConnMaxLifetime(connIdleLifetime)
db.SetMaxOpenConns(maxConns)
db.SetMaxIdleConns(maxConns)
db.SetConnMaxLifetime(connIdleLifetime)
queries := sqlc.New(rawDB)
queries := sqlc.New(db)
s := &PostgresStore{
cfg: cfg,
BaseDB: &BaseDB{
DB: rawDB,
DB: db,
Queries: queries,
},
}
// Execute migrations unless configured to skip them.
if !cfg.SkipMigrations {
err := s.ExecuteMigrations(TargetLatest)
err := ApplyMigrations(
context.Background(), s.BaseDB, s, migrations,
)
if err != nil {
return nil, fmt.Errorf("error executing migrations: %w",
err)
return nil, err
}
}

View File

@@ -148,7 +148,7 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore {
require.NoError(t, err)
cfg := fixture.GetConfig(dbName)
store, err := NewPostgresStore(cfg)
store, err := NewPostgresStore(cfg, GetMigrations())
require.NoError(t, err)
return store
@@ -172,7 +172,7 @@ func NewTestPostgresDBWithVersion(t *testing.T, fixture *TestPgFixture,
storeCfg := fixture.GetConfig(dbName)
storeCfg.SkipMigrations = true
store, err := NewPostgresStore(storeCfg)
store, err := NewPostgresStore(storeCfg, GetMigrations())
require.NoError(t, err)
err = store.ExecuteMigrations(TargetVersion(version))

View File

@@ -3,6 +3,7 @@
package sqldb
import (
"context"
"database/sql"
"fmt"
"net/url"
@@ -34,6 +35,9 @@ var (
sqliteSchemaReplacements = map[string]string{
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
}
// Make sure SqliteStore implements the MigrationExecutor interface.
_ MigrationExecutor = (*SqliteStore)(nil)
)
// SqliteStore is a database store implementation that uses a sqlite backend.
@@ -45,7 +49,9 @@ type SqliteStore struct {
// NewSqliteStore attempts to open a new sqlite database based on the passed
// config.
func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
func NewSqliteStore(cfg *SqliteConfig, dbPath string,
migrations []MigrationConfig) (*SqliteStore, error) {
// The set of pragma options are accepted using query options. For now
// we only want to ensure that foreign key constraints are properly
// enforced.
@@ -102,6 +108,23 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
return nil, err
}
// Create the migration tracker table before starting migrations to
// ensure it can be used to track migration progress. Note that a
// corresponding SQLC migration also creates this table, making this
// operation a no-op in that context. Its purpose is to ensure
// compatibility with SQLC query generation.
migrationTrackerSQL := `
CREATE TABLE IF NOT EXISTS migration_tracker (
version INTEGER UNIQUE NOT NULL,
migration_time TIMESTAMP NOT NULL
);`
_, err = db.Exec(migrationTrackerSQL)
if err != nil {
return nil, fmt.Errorf("error creating migration tracker: %w",
err)
}
db.SetMaxOpenConns(defaultMaxConns)
db.SetMaxIdleConns(defaultMaxConns)
db.SetConnMaxLifetime(connIdleLifetime)
@@ -117,10 +140,11 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
// Execute migrations unless configured to skip them.
if !cfg.SkipMigrations {
if err := s.ExecuteMigrations(TargetLatest); err != nil {
return nil, fmt.Errorf("error executing migrations: "+
"%w", err)
err := ApplyMigrations(
context.Background(), s.BaseDB, s, migrations,
)
if err != nil {
return nil, err
}
}
@@ -157,7 +181,7 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore {
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
sqlDB, err := NewSqliteStore(&SqliteConfig{
SkipMigrations: false,
}, dbFileName)
}, dbFileName, GetMigrations())
require.NoError(t, err)
t.Cleanup(func() {
@@ -180,7 +204,7 @@ func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore {
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
sqlDB, err := NewSqliteStore(&SqliteConfig{
SkipMigrations: true,
}, dbFileName)
}, dbFileName, nil)
require.NoError(t, err)
err = sqlDB.ExecuteMigrations(TargetVersion(version))