mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-30 23:53:41 +02:00
sqldb: add support for custom in-code migrations
This commit introduces support for custom, in-code migrations, allowing a specific Go function to be executed at a designated database version during sqlc migrations. If the current database version surpasses the specified version, the migration will be skipped.
This commit is contained in:
@@ -2,22 +2,104 @@ package sqldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btclog/v2"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
)
|
||||
|
||||
var (
|
||||
// migrationConfig defines a list of migrations to be applied to the
|
||||
// database. Each migration is assigned a version number, determining
|
||||
// its execution order.
|
||||
// The schema version, tracked by golang-migrate, ensures migrations are
|
||||
// applied to the correct schema. For migrations involving only schema
|
||||
// changes, the migration function can be left nil. For custom
|
||||
// migrations an implemented migration function is required.
|
||||
//
|
||||
// NOTE: The migration function may have runtime dependencies, which
|
||||
// must be injected during runtime.
|
||||
migrationConfig = []MigrationConfig{
|
||||
{
|
||||
Name: "000001_invoices",
|
||||
Version: 1,
|
||||
SchemaVersion: 1,
|
||||
},
|
||||
{
|
||||
Name: "000002_amp_invoices",
|
||||
Version: 2,
|
||||
SchemaVersion: 2,
|
||||
},
|
||||
{
|
||||
Name: "000003_invoice_events",
|
||||
Version: 3,
|
||||
SchemaVersion: 3,
|
||||
},
|
||||
{
|
||||
Name: "000004_invoice_expiry_fix",
|
||||
Version: 4,
|
||||
SchemaVersion: 4,
|
||||
},
|
||||
{
|
||||
Name: "000005_migration_tracker",
|
||||
Version: 5,
|
||||
SchemaVersion: 5,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// MigrationConfig is a configuration struct that describes SQL migrations. Each
|
||||
// migration is associated with a specific schema version and a global database
|
||||
// version. Migrations are applied in the order of their global database
|
||||
// version. If a migration includes a non-nil MigrationFn, it is executed after
|
||||
// the SQL schema has been migrated to the corresponding schema version.
|
||||
type MigrationConfig struct {
|
||||
// Name is the name of the migration.
|
||||
Name string
|
||||
|
||||
// Version represents the "global" database version for this migration.
|
||||
// Unlike the schema version tracked by golang-migrate, it encompasses
|
||||
// all migrations, including those managed by golang-migrate as well
|
||||
// as custom in-code migrations.
|
||||
Version int
|
||||
|
||||
// SchemaVersion represents the schema version tracked by golang-migrate
|
||||
// at which the migration is applied.
|
||||
SchemaVersion int
|
||||
|
||||
// MigrationFn is the function executed for custom migrations at the
|
||||
// specified version. It is used to handle migrations that cannot be
|
||||
// performed through SQL alone. If set to nil, no custom migration is
|
||||
// applied.
|
||||
MigrationFn func(tx *sqlc.Queries) error
|
||||
}
|
||||
|
||||
// MigrationTarget is a functional option that can be passed to applyMigrations
|
||||
// to specify a target version to migrate to.
|
||||
type MigrationTarget func(mig *migrate.Migrate) error
|
||||
|
||||
// MigrationExecutor is an interface that abstracts the migration functionality.
|
||||
type MigrationExecutor interface {
|
||||
// ExecuteMigrations runs database migrations up to the specified target
|
||||
// version or all migrations if no target is specified. A migration may
|
||||
// include a schema change, a custom migration function, or both.
|
||||
// Developers must ensure that migrations are defined in the correct
|
||||
// order. Migration details are stored in the global variable
|
||||
// migrationConfig.
|
||||
ExecuteMigrations(target MigrationTarget) error
|
||||
}
|
||||
|
||||
var (
|
||||
// TargetLatest is a MigrationTarget that migrates to the latest
|
||||
// version available.
|
||||
@@ -34,6 +116,14 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// GetMigrations returns a copy of the migration configuration.
|
||||
func GetMigrations() []MigrationConfig {
|
||||
migrations := make([]MigrationConfig, len(migrationConfig))
|
||||
copy(migrations, migrationConfig)
|
||||
|
||||
return migrations
|
||||
}
|
||||
|
||||
// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
|
||||
// used to log migrations.
|
||||
type migrationLogger struct {
|
||||
@@ -216,3 +306,117 @@ func (t *replacerFile) Close() error {
|
||||
// instance, so there's nothing to do for us here.
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrationTxOptions is the implementation of the TxOptions interface for
|
||||
// migration transactions.
|
||||
type MigrationTxOptions struct {
|
||||
}
|
||||
|
||||
// ReadOnly returns false to indicate that migration transactions are not read
|
||||
// only.
|
||||
func (m *MigrationTxOptions) ReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ApplyMigrations applies the provided migrations to the database in sequence.
|
||||
// It ensures migrations are executed in the correct order, applying both custom
|
||||
// migration functions and SQL migrations as needed.
|
||||
func ApplyMigrations(ctx context.Context, db *BaseDB,
|
||||
migrator MigrationExecutor, migrations []MigrationConfig) error {
|
||||
|
||||
// Ensure that the migrations are sorted by version.
|
||||
for i := 0; i < len(migrations); i++ {
|
||||
if migrations[i].Version != i+1 {
|
||||
return fmt.Errorf("migration version %d is out of "+
|
||||
"order. Expected %d", migrations[i].Version,
|
||||
i+1)
|
||||
}
|
||||
}
|
||||
// Construct a transaction executor to apply custom migrations.
|
||||
executor := NewTransactionExecutor(db, func(tx *sql.Tx) *sqlc.Queries {
|
||||
return db.WithTx(tx)
|
||||
})
|
||||
|
||||
currentVersion := 0
|
||||
version, err := db.GetDatabaseVersion(ctx)
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting current database "+
|
||||
"version: %w", err)
|
||||
}
|
||||
|
||||
currentVersion = int(version)
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
if migration.Version <= currentVersion {
|
||||
log.Infof("Skipping migration '%s' (version %d) as it "+
|
||||
"has already been applied", migration.Name,
|
||||
migration.Version)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("Migrating SQL schema to version %d",
|
||||
migration.SchemaVersion)
|
||||
|
||||
// Execute SQL schema migrations up to the target version.
|
||||
err = migrator.ExecuteMigrations(
|
||||
TargetVersion(uint(migration.SchemaVersion)),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error executing schema migrations "+
|
||||
"to target version %d: %w",
|
||||
migration.SchemaVersion, err)
|
||||
}
|
||||
|
||||
var opts MigrationTxOptions
|
||||
|
||||
// Run the custom migration as a transaction to ensure
|
||||
// atomicity. If successful, mark the migration as complete in
|
||||
// the migration tracker table.
|
||||
err = executor.ExecTx(ctx, &opts, func(tx *sqlc.Queries) error {
|
||||
// Apply the migration function if one is provided.
|
||||
if migration.MigrationFn != nil {
|
||||
log.Infof("Applying custom migration '%v' "+
|
||||
"(version %d) to schema version %d",
|
||||
migration.Name, migration.Version,
|
||||
migration.SchemaVersion)
|
||||
|
||||
err = migration.MigrationFn(tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error applying "+
|
||||
"migration '%v' (version %d) "+
|
||||
"to schema version %d: %w",
|
||||
migration.Name,
|
||||
migration.Version,
|
||||
migration.SchemaVersion, err)
|
||||
}
|
||||
|
||||
log.Infof("Migration '%v' (version %d) "+
|
||||
"applied ", migration.Name,
|
||||
migration.Version)
|
||||
}
|
||||
|
||||
// Mark the migration as complete by adding the version
|
||||
// to the migration tracker table along with the current
|
||||
// timestamp.
|
||||
err = tx.SetMigration(ctx, sqlc.SetMigrationParams{
|
||||
Version: int32(migration.Version),
|
||||
MigrationTime: time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting migration "+
|
||||
"version %d: %w", migration.Version,
|
||||
err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -2,8 +2,15 @@ package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
pgx_migrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
|
||||
sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite"
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -152,3 +159,289 @@ func testInvoiceExpiryMigration(t *testing.T, makeDB makeMigrationTestDB) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, invoices)
|
||||
}
|
||||
|
||||
// TestCustomMigration tests that a custom in-code migrations are correctly
|
||||
// executed during the migration process.
|
||||
func TestCustomMigration(t *testing.T) {
|
||||
var customMigrationLog []string
|
||||
|
||||
logMigration := func(name string) {
|
||||
customMigrationLog = append(customMigrationLog, name)
|
||||
}
|
||||
|
||||
// Some migrations to use for both the failure and success tests. Note
|
||||
// that the migrations are not in order to test that they are executed
|
||||
// in the correct order.
|
||||
migrations := []MigrationConfig{
|
||||
{
|
||||
Name: "1",
|
||||
Version: 1,
|
||||
SchemaVersion: 1,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
logMigration("1")
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "2",
|
||||
Version: 2,
|
||||
SchemaVersion: 1,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
logMigration("2")
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "3",
|
||||
Version: 3,
|
||||
SchemaVersion: 2,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
logMigration("3")
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
migrations []MigrationConfig
|
||||
expectedSuccess bool
|
||||
expectedMigrationLog []string
|
||||
expectedSchemaVersion int
|
||||
expectedVersion int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
migrations: migrations,
|
||||
expectedSuccess: true,
|
||||
expectedMigrationLog: []string{"1", "2", "3"},
|
||||
expectedSchemaVersion: 2,
|
||||
expectedVersion: 3,
|
||||
},
|
||||
{
|
||||
name: "unordered migrations",
|
||||
migrations: append([]MigrationConfig{
|
||||
{
|
||||
Name: "4",
|
||||
Version: 4,
|
||||
SchemaVersion: 3,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
logMigration("4")
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}, migrations...),
|
||||
expectedSuccess: false,
|
||||
expectedMigrationLog: nil,
|
||||
expectedSchemaVersion: 0,
|
||||
},
|
||||
{
|
||||
name: "failure of migration 4",
|
||||
migrations: append(migrations, MigrationConfig{
|
||||
Name: "4",
|
||||
Version: 4,
|
||||
SchemaVersion: 3,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
return fmt.Errorf("migration 4 failed")
|
||||
},
|
||||
}),
|
||||
expectedSuccess: false,
|
||||
expectedMigrationLog: []string{"1", "2", "3"},
|
||||
// Since schema migration is a separate step we expect
|
||||
// that migrating up to 3 succeeded.
|
||||
expectedSchemaVersion: 3,
|
||||
// We still remain on version 3 though.
|
||||
expectedVersion: 3,
|
||||
},
|
||||
{
|
||||
name: "success of migration 4",
|
||||
migrations: append(migrations, MigrationConfig{
|
||||
Name: "4",
|
||||
Version: 4,
|
||||
SchemaVersion: 3,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
logMigration("4")
|
||||
|
||||
return nil
|
||||
},
|
||||
}),
|
||||
expectedSuccess: true,
|
||||
expectedMigrationLog: []string{"1", "2", "3", "4"},
|
||||
expectedSchemaVersion: 3,
|
||||
expectedVersion: 4,
|
||||
},
|
||||
}
|
||||
|
||||
ctxb := context.Background()
|
||||
for _, test := range tests {
|
||||
// checkSchemaVersion checks the database schema version against
|
||||
// the expected version.
|
||||
getSchemaVersion := func(t *testing.T,
|
||||
driver database.Driver, dbName string) int {
|
||||
|
||||
sqlMigrate, err := migrate.NewWithInstance(
|
||||
"migrations", nil, dbName, driver,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
version, _, err := sqlMigrate.Version()
|
||||
if err != migrate.ErrNilVersion {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return int(version)
|
||||
}
|
||||
|
||||
t.Run("SQLite "+test.name, func(t *testing.T) {
|
||||
customMigrationLog = nil
|
||||
|
||||
// First instantiate the database and run the migrations
|
||||
// including the custom migrations.
|
||||
t.Logf("Creating new SQLite DB for testing migrations")
|
||||
|
||||
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
||||
var (
|
||||
db *SqliteStore
|
||||
err error
|
||||
)
|
||||
|
||||
// Run the migration 3 times to test that the migrations
|
||||
// are idempotent.
|
||||
for i := 0; i < 3; i++ {
|
||||
db, err = NewSqliteStore(&SqliteConfig{
|
||||
SkipMigrations: false,
|
||||
}, dbFileName, test.migrations)
|
||||
if db != nil {
|
||||
dbToCleanup := db.DB
|
||||
t.Cleanup(func() {
|
||||
require.NoError(
|
||||
t, dbToCleanup.Close(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
if test.expectedSuccess {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
|
||||
// Also repoen the DB without migrations
|
||||
// so we can read versions.
|
||||
db, err = NewSqliteStore(&SqliteConfig{
|
||||
SkipMigrations: true,
|
||||
}, dbFileName, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t,
|
||||
test.expectedMigrationLog,
|
||||
customMigrationLog,
|
||||
)
|
||||
|
||||
// Create the migration executor to be able to
|
||||
// query the current schema version.
|
||||
driver, err := sqlite_migrate.WithInstance(
|
||||
db.DB, &sqlite_migrate.Config{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(
|
||||
t, test.expectedSchemaVersion,
|
||||
getSchemaVersion(t, driver, ""),
|
||||
)
|
||||
|
||||
// Check the migraton version in the database.
|
||||
version, err := db.GetDatabaseVersion(ctxb)
|
||||
if test.expectedSchemaVersion != 0 {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Equal(t, sql.ErrNoRows, err)
|
||||
}
|
||||
|
||||
require.Equal(
|
||||
t, test.expectedVersion, int(version),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Postgres "+test.name, func(t *testing.T) {
|
||||
customMigrationLog = nil
|
||||
|
||||
// First create a temporary Postgres database to run
|
||||
// the migrations on.
|
||||
fixture := NewTestPgFixture(
|
||||
t, DefaultPostgresFixtureLifetime,
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
fixture.TearDown(t)
|
||||
})
|
||||
|
||||
dbName := randomDBName(t)
|
||||
|
||||
// Next instantiate the database and run the migrations
|
||||
// including the custom migrations.
|
||||
t.Logf("Creating new Postgres DB '%s' for testing "+
|
||||
"migrations", dbName)
|
||||
|
||||
_, err := fixture.db.ExecContext(
|
||||
context.Background(), "CREATE DATABASE "+dbName,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := fixture.GetConfig(dbName)
|
||||
var db *PostgresStore
|
||||
|
||||
// Run the migration 3 times to test that the migrations
|
||||
// are idempotent.
|
||||
for i := 0; i < 3; i++ {
|
||||
cfg.SkipMigrations = false
|
||||
db, err = NewPostgresStore(cfg, test.migrations)
|
||||
|
||||
if test.expectedSuccess {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
|
||||
// Also repoen the DB without migrations
|
||||
// so we can read versions.
|
||||
cfg.SkipMigrations = true
|
||||
db, err = NewPostgresStore(cfg, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t,
|
||||
test.expectedMigrationLog,
|
||||
customMigrationLog,
|
||||
)
|
||||
|
||||
// Create the migration executor to be able to
|
||||
// query the current version.
|
||||
driver, err := pgx_migrate.WithInstance(
|
||||
db.DB, &pgx_migrate.Config{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(
|
||||
t, test.expectedSchemaVersion,
|
||||
getSchemaVersion(t, driver, ""),
|
||||
)
|
||||
|
||||
// Check the migraton version in the database.
|
||||
version, err := db.GetDatabaseVersion(ctxb)
|
||||
if test.expectedSchemaVersion != 0 {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Equal(t, sql.ErrNoRows, err)
|
||||
}
|
||||
|
||||
require.Equal(
|
||||
t, test.expectedVersion, int(version),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@@ -32,6 +33,9 @@ var (
|
||||
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
|
||||
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
|
||||
}
|
||||
|
||||
// Make sure PostgresStore implements the MigrationExecutor interface.
|
||||
_ MigrationExecutor = (*PostgresStore)(nil)
|
||||
)
|
||||
|
||||
// replacePasswordInDSN takes a DSN string and returns it with the password
|
||||
@@ -85,43 +89,62 @@ type PostgresStore struct {
|
||||
|
||||
// NewPostgresStore creates a new store that is backed by a Postgres database
|
||||
// backend.
|
||||
func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
|
||||
func NewPostgresStore(cfg *PostgresConfig, migrations []MigrationConfig) (
|
||||
*PostgresStore, error) {
|
||||
|
||||
sanitizedDSN, err := replacePasswordInDSN(cfg.Dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("Using SQL database '%s'", sanitizedDSN)
|
||||
|
||||
rawDB, err := sql.Open("pgx", cfg.Dsn)
|
||||
db, err := sql.Open("pgx", cfg.Dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the migration tracker table exists before running migrations.
|
||||
// This table tracks migration progress and ensures compatibility with
|
||||
// SQLC query generation. If the table is already created by an SQLC
|
||||
// migration, this operation becomes a no-op.
|
||||
migrationTrackerSQL := `
|
||||
CREATE TABLE IF NOT EXISTS migration_tracker (
|
||||
version INTEGER UNIQUE NOT NULL,
|
||||
migration_time TIMESTAMP NOT NULL
|
||||
);`
|
||||
|
||||
_, err = db.Exec(migrationTrackerSQL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating migration tracker: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
maxConns := defaultMaxConns
|
||||
if cfg.MaxConnections > 0 {
|
||||
maxConns = cfg.MaxConnections
|
||||
}
|
||||
|
||||
rawDB.SetMaxOpenConns(maxConns)
|
||||
rawDB.SetMaxIdleConns(maxConns)
|
||||
rawDB.SetConnMaxLifetime(connIdleLifetime)
|
||||
db.SetMaxOpenConns(maxConns)
|
||||
db.SetMaxIdleConns(maxConns)
|
||||
db.SetConnMaxLifetime(connIdleLifetime)
|
||||
|
||||
queries := sqlc.New(rawDB)
|
||||
queries := sqlc.New(db)
|
||||
|
||||
s := &PostgresStore{
|
||||
cfg: cfg,
|
||||
BaseDB: &BaseDB{
|
||||
DB: rawDB,
|
||||
DB: db,
|
||||
Queries: queries,
|
||||
},
|
||||
}
|
||||
|
||||
// Execute migrations unless configured to skip them.
|
||||
if !cfg.SkipMigrations {
|
||||
err := s.ExecuteMigrations(TargetLatest)
|
||||
err := ApplyMigrations(
|
||||
context.Background(), s.BaseDB, s, migrations,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error executing migrations: %w",
|
||||
err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -148,7 +148,7 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore {
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := fixture.GetConfig(dbName)
|
||||
store, err := NewPostgresStore(cfg)
|
||||
store, err := NewPostgresStore(cfg, GetMigrations())
|
||||
require.NoError(t, err)
|
||||
|
||||
return store
|
||||
@@ -172,7 +172,7 @@ func NewTestPostgresDBWithVersion(t *testing.T, fixture *TestPgFixture,
|
||||
|
||||
storeCfg := fixture.GetConfig(dbName)
|
||||
storeCfg.SkipMigrations = true
|
||||
store, err := NewPostgresStore(storeCfg)
|
||||
store, err := NewPostgresStore(storeCfg, GetMigrations())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.ExecuteMigrations(TargetVersion(version))
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@@ -34,6 +35,9 @@ var (
|
||||
sqliteSchemaReplacements = map[string]string{
|
||||
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
|
||||
}
|
||||
|
||||
// Make sure SqliteStore implements the MigrationExecutor interface.
|
||||
_ MigrationExecutor = (*SqliteStore)(nil)
|
||||
)
|
||||
|
||||
// SqliteStore is a database store implementation that uses a sqlite backend.
|
||||
@@ -45,7 +49,9 @@ type SqliteStore struct {
|
||||
|
||||
// NewSqliteStore attempts to open a new sqlite database based on the passed
|
||||
// config.
|
||||
func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
||||
func NewSqliteStore(cfg *SqliteConfig, dbPath string,
|
||||
migrations []MigrationConfig) (*SqliteStore, error) {
|
||||
|
||||
// The set of pragma options are accepted using query options. For now
|
||||
// we only want to ensure that foreign key constraints are properly
|
||||
// enforced.
|
||||
@@ -102,6 +108,23 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the migration tracker table before starting migrations to
|
||||
// ensure it can be used to track migration progress. Note that a
|
||||
// corresponding SQLC migration also creates this table, making this
|
||||
// operation a no-op in that context. Its purpose is to ensure
|
||||
// compatibility with SQLC query generation.
|
||||
migrationTrackerSQL := `
|
||||
CREATE TABLE IF NOT EXISTS migration_tracker (
|
||||
version INTEGER UNIQUE NOT NULL,
|
||||
migration_time TIMESTAMP NOT NULL
|
||||
);`
|
||||
|
||||
_, err = db.Exec(migrationTrackerSQL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating migration tracker: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(defaultMaxConns)
|
||||
db.SetMaxIdleConns(defaultMaxConns)
|
||||
db.SetConnMaxLifetime(connIdleLifetime)
|
||||
@@ -117,10 +140,11 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
||||
|
||||
// Execute migrations unless configured to skip them.
|
||||
if !cfg.SkipMigrations {
|
||||
if err := s.ExecuteMigrations(TargetLatest); err != nil {
|
||||
return nil, fmt.Errorf("error executing migrations: "+
|
||||
"%w", err)
|
||||
|
||||
err := ApplyMigrations(
|
||||
context.Background(), s.BaseDB, s, migrations,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,7 +181,7 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore {
|
||||
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
||||
sqlDB, err := NewSqliteStore(&SqliteConfig{
|
||||
SkipMigrations: false,
|
||||
}, dbFileName)
|
||||
}, dbFileName, GetMigrations())
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
@@ -180,7 +204,7 @@ func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore {
|
||||
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
||||
sqlDB, err := NewSqliteStore(&SqliteConfig{
|
||||
SkipMigrations: true,
|
||||
}, dbFileName)
|
||||
}, dbFileName, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sqlDB.ExecuteMigrations(TargetVersion(version))
|
||||
|
Reference in New Issue
Block a user