mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-05 17:05:50 +02:00
sqldb: add support for custom in-code migrations
This commit introduces support for custom, in-code migrations, allowing a specific Go function to be executed at a designated database version during sqlc migrations. If the current database version surpasses the specified version, the migration will be skipped.
This commit is contained in:
@@ -452,7 +452,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
|
|||||||
var nativeSQLStore *sqldb.BaseDB
|
var nativeSQLStore *sqldb.BaseDB
|
||||||
if db.UseNativeSQL {
|
if db.UseNativeSQL {
|
||||||
nativePostgresStore, err := sqldb.NewPostgresStore(
|
nativePostgresStore, err := sqldb.NewPostgresStore(
|
||||||
db.Postgres,
|
db.Postgres, sqldb.GetMigrations(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error opening "+
|
return nil, fmt.Errorf("error opening "+
|
||||||
@@ -576,6 +576,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
|
|||||||
nativeSQLiteStore, err := sqldb.NewSqliteStore(
|
nativeSQLiteStore, err := sqldb.NewSqliteStore(
|
||||||
db.Sqlite,
|
db.Sqlite,
|
||||||
path.Join(chanDBPath, SqliteNativeDBName),
|
path.Join(chanDBPath, SqliteNativeDBName),
|
||||||
|
sqldb.GetMigrations(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error opening "+
|
return nil, fmt.Errorf("error opening "+
|
||||||
|
@@ -2,22 +2,104 @@ package sqldb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btclog/v2"
|
"github.com/btcsuite/btclog/v2"
|
||||||
"github.com/golang-migrate/migrate/v4"
|
"github.com/golang-migrate/migrate/v4"
|
||||||
"github.com/golang-migrate/migrate/v4/database"
|
"github.com/golang-migrate/migrate/v4/database"
|
||||||
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
||||||
|
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// migrationConfig defines a list of migrations to be applied to the
|
||||||
|
// database. Each migration is assigned a version number, determining
|
||||||
|
// its execution order.
|
||||||
|
// The schema version, tracked by golang-migrate, ensures migrations are
|
||||||
|
// applied to the correct schema. For migrations involving only schema
|
||||||
|
// changes, the migration function can be left nil. For custom
|
||||||
|
// migrations an implemented migration function is required.
|
||||||
|
//
|
||||||
|
// NOTE: The migration function may have runtime dependencies, which
|
||||||
|
// must be injected during runtime.
|
||||||
|
migrationConfig = []MigrationConfig{
|
||||||
|
{
|
||||||
|
Name: "000001_invoices",
|
||||||
|
Version: 1,
|
||||||
|
SchemaVersion: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "000002_amp_invoices",
|
||||||
|
Version: 2,
|
||||||
|
SchemaVersion: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "000003_invoice_events",
|
||||||
|
Version: 3,
|
||||||
|
SchemaVersion: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "000004_invoice_expiry_fix",
|
||||||
|
Version: 4,
|
||||||
|
SchemaVersion: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "000005_migration_tracker",
|
||||||
|
Version: 5,
|
||||||
|
SchemaVersion: 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// MigrationConfig is a configuration struct that describes SQL migrations. Each
|
||||||
|
// migration is associated with a specific schema version and a global database
|
||||||
|
// version. Migrations are applied in the order of their global database
|
||||||
|
// version. If a migration includes a non-nil MigrationFn, it is executed after
|
||||||
|
// the SQL schema has been migrated to the corresponding schema version.
|
||||||
|
type MigrationConfig struct {
|
||||||
|
// Name is the name of the migration.
|
||||||
|
Name string
|
||||||
|
|
||||||
|
// Version represents the "global" database version for this migration.
|
||||||
|
// Unlike the schema version tracked by golang-migrate, it encompasses
|
||||||
|
// all migrations, including those managed by golang-migrate as well
|
||||||
|
// as custom in-code migrations.
|
||||||
|
Version int
|
||||||
|
|
||||||
|
// SchemaVersion represents the schema version tracked by golang-migrate
|
||||||
|
// at which the migration is applied.
|
||||||
|
SchemaVersion int
|
||||||
|
|
||||||
|
// MigrationFn is the function executed for custom migrations at the
|
||||||
|
// specified version. It is used to handle migrations that cannot be
|
||||||
|
// performed through SQL alone. If set to nil, no custom migration is
|
||||||
|
// applied.
|
||||||
|
MigrationFn func(tx *sqlc.Queries) error
|
||||||
|
}
|
||||||
|
|
||||||
// MigrationTarget is a functional option that can be passed to applyMigrations
|
// MigrationTarget is a functional option that can be passed to applyMigrations
|
||||||
// to specify a target version to migrate to.
|
// to specify a target version to migrate to.
|
||||||
type MigrationTarget func(mig *migrate.Migrate) error
|
type MigrationTarget func(mig *migrate.Migrate) error
|
||||||
|
|
||||||
|
// MigrationExecutor is an interface that abstracts the migration functionality.
|
||||||
|
type MigrationExecutor interface {
|
||||||
|
// ExecuteMigrations runs database migrations up to the specified target
|
||||||
|
// version or all migrations if no target is specified. A migration may
|
||||||
|
// include a schema change, a custom migration function, or both.
|
||||||
|
// Developers must ensure that migrations are defined in the correct
|
||||||
|
// order. Migration details are stored in the global variable
|
||||||
|
// migrationConfig.
|
||||||
|
ExecuteMigrations(target MigrationTarget) error
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// TargetLatest is a MigrationTarget that migrates to the latest
|
// TargetLatest is a MigrationTarget that migrates to the latest
|
||||||
// version available.
|
// version available.
|
||||||
@@ -34,6 +116,14 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GetMigrations returns a copy of the migration configuration.
|
||||||
|
func GetMigrations() []MigrationConfig {
|
||||||
|
migrations := make([]MigrationConfig, len(migrationConfig))
|
||||||
|
copy(migrations, migrationConfig)
|
||||||
|
|
||||||
|
return migrations
|
||||||
|
}
|
||||||
|
|
||||||
// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
|
// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
|
||||||
// used to log migrations.
|
// used to log migrations.
|
||||||
type migrationLogger struct {
|
type migrationLogger struct {
|
||||||
@@ -216,3 +306,117 @@ func (t *replacerFile) Close() error {
|
|||||||
// instance, so there's nothing to do for us here.
|
// instance, so there's nothing to do for us here.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MigrationTxOptions is the implementation of the TxOptions interface for
|
||||||
|
// migration transactions.
|
||||||
|
type MigrationTxOptions struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadOnly returns false to indicate that migration transactions are not read
|
||||||
|
// only.
|
||||||
|
func (m *MigrationTxOptions) ReadOnly() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyMigrations applies the provided migrations to the database in sequence.
|
||||||
|
// It ensures migrations are executed in the correct order, applying both custom
|
||||||
|
// migration functions and SQL migrations as needed.
|
||||||
|
func ApplyMigrations(ctx context.Context, db *BaseDB,
|
||||||
|
migrator MigrationExecutor, migrations []MigrationConfig) error {
|
||||||
|
|
||||||
|
// Ensure that the migrations are sorted by version.
|
||||||
|
for i := 0; i < len(migrations); i++ {
|
||||||
|
if migrations[i].Version != i+1 {
|
||||||
|
return fmt.Errorf("migration version %d is out of "+
|
||||||
|
"order. Expected %d", migrations[i].Version,
|
||||||
|
i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Construct a transaction executor to apply custom migrations.
|
||||||
|
executor := NewTransactionExecutor(db, func(tx *sql.Tx) *sqlc.Queries {
|
||||||
|
return db.WithTx(tx)
|
||||||
|
})
|
||||||
|
|
||||||
|
currentVersion := 0
|
||||||
|
version, err := db.GetDatabaseVersion(ctx)
|
||||||
|
if !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting current database "+
|
||||||
|
"version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentVersion = int(version)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, migration := range migrations {
|
||||||
|
if migration.Version <= currentVersion {
|
||||||
|
log.Infof("Skipping migration '%s' (version %d) as it "+
|
||||||
|
"has already been applied", migration.Name,
|
||||||
|
migration.Version)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Migrating SQL schema to version %d",
|
||||||
|
migration.SchemaVersion)
|
||||||
|
|
||||||
|
// Execute SQL schema migrations up to the target version.
|
||||||
|
err = migrator.ExecuteMigrations(
|
||||||
|
TargetVersion(uint(migration.SchemaVersion)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error executing schema migrations "+
|
||||||
|
"to target version %d: %w",
|
||||||
|
migration.SchemaVersion, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var opts MigrationTxOptions
|
||||||
|
|
||||||
|
// Run the custom migration as a transaction to ensure
|
||||||
|
// atomicity. If successful, mark the migration as complete in
|
||||||
|
// the migration tracker table.
|
||||||
|
err = executor.ExecTx(ctx, &opts, func(tx *sqlc.Queries) error {
|
||||||
|
// Apply the migration function if one is provided.
|
||||||
|
if migration.MigrationFn != nil {
|
||||||
|
log.Infof("Applying custom migration '%v' "+
|
||||||
|
"(version %d) to schema version %d",
|
||||||
|
migration.Name, migration.Version,
|
||||||
|
migration.SchemaVersion)
|
||||||
|
|
||||||
|
err = migration.MigrationFn(tx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error applying "+
|
||||||
|
"migration '%v' (version %d) "+
|
||||||
|
"to schema version %d: %w",
|
||||||
|
migration.Name,
|
||||||
|
migration.Version,
|
||||||
|
migration.SchemaVersion, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Migration '%v' (version %d) "+
|
||||||
|
"applied ", migration.Name,
|
||||||
|
migration.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark the migration as complete by adding the version
|
||||||
|
// to the migration tracker table along with the current
|
||||||
|
// timestamp.
|
||||||
|
err = tx.SetMigration(ctx, sqlc.SetMigrationParams{
|
||||||
|
Version: int32(migration.Version),
|
||||||
|
MigrationTime: time.Now(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error setting migration "+
|
||||||
|
"version %d: %w", migration.Version,
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}, func() {})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@@ -2,8 +2,15 @@ package sqldb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang-migrate/migrate/v4"
|
||||||
|
"github.com/golang-migrate/migrate/v4/database"
|
||||||
|
pgx_migrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
|
||||||
|
sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite"
|
||||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -152,3 +159,289 @@ func testInvoiceExpiryMigration(t *testing.T, makeDB makeMigrationTestDB) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, expected, invoices)
|
require.Equal(t, expected, invoices)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestCustomMigration tests that a custom in-code migrations are correctly
|
||||||
|
// executed during the migration process.
|
||||||
|
func TestCustomMigration(t *testing.T) {
|
||||||
|
var customMigrationLog []string
|
||||||
|
|
||||||
|
logMigration := func(name string) {
|
||||||
|
customMigrationLog = append(customMigrationLog, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some migrations to use for both the failure and success tests. Note
|
||||||
|
// that the migrations are not in order to test that they are executed
|
||||||
|
// in the correct order.
|
||||||
|
migrations := []MigrationConfig{
|
||||||
|
{
|
||||||
|
Name: "1",
|
||||||
|
Version: 1,
|
||||||
|
SchemaVersion: 1,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
logMigration("1")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "2",
|
||||||
|
Version: 2,
|
||||||
|
SchemaVersion: 1,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
logMigration("2")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "3",
|
||||||
|
Version: 3,
|
||||||
|
SchemaVersion: 2,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
logMigration("3")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
migrations []MigrationConfig
|
||||||
|
expectedSuccess bool
|
||||||
|
expectedMigrationLog []string
|
||||||
|
expectedSchemaVersion int
|
||||||
|
expectedVersion int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
migrations: migrations,
|
||||||
|
expectedSuccess: true,
|
||||||
|
expectedMigrationLog: []string{"1", "2", "3"},
|
||||||
|
expectedSchemaVersion: 2,
|
||||||
|
expectedVersion: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unordered migrations",
|
||||||
|
migrations: append([]MigrationConfig{
|
||||||
|
{
|
||||||
|
Name: "4",
|
||||||
|
Version: 4,
|
||||||
|
SchemaVersion: 3,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
logMigration("4")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, migrations...),
|
||||||
|
expectedSuccess: false,
|
||||||
|
expectedMigrationLog: nil,
|
||||||
|
expectedSchemaVersion: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failure of migration 4",
|
||||||
|
migrations: append(migrations, MigrationConfig{
|
||||||
|
Name: "4",
|
||||||
|
Version: 4,
|
||||||
|
SchemaVersion: 3,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
return fmt.Errorf("migration 4 failed")
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
expectedSuccess: false,
|
||||||
|
expectedMigrationLog: []string{"1", "2", "3"},
|
||||||
|
// Since schema migration is a separate step we expect
|
||||||
|
// that migrating up to 3 succeeded.
|
||||||
|
expectedSchemaVersion: 3,
|
||||||
|
// We still remain on version 3 though.
|
||||||
|
expectedVersion: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success of migration 4",
|
||||||
|
migrations: append(migrations, MigrationConfig{
|
||||||
|
Name: "4",
|
||||||
|
Version: 4,
|
||||||
|
SchemaVersion: 3,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
logMigration("4")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
expectedSuccess: true,
|
||||||
|
expectedMigrationLog: []string{"1", "2", "3", "4"},
|
||||||
|
expectedSchemaVersion: 3,
|
||||||
|
expectedVersion: 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxb := context.Background()
|
||||||
|
for _, test := range tests {
|
||||||
|
// checkSchemaVersion checks the database schema version against
|
||||||
|
// the expected version.
|
||||||
|
getSchemaVersion := func(t *testing.T,
|
||||||
|
driver database.Driver, dbName string) int {
|
||||||
|
|
||||||
|
sqlMigrate, err := migrate.NewWithInstance(
|
||||||
|
"migrations", nil, dbName, driver,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
version, _, err := sqlMigrate.Version()
|
||||||
|
if err != migrate.ErrNilVersion {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(version)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("SQLite "+test.name, func(t *testing.T) {
|
||||||
|
customMigrationLog = nil
|
||||||
|
|
||||||
|
// First instantiate the database and run the migrations
|
||||||
|
// including the custom migrations.
|
||||||
|
t.Logf("Creating new SQLite DB for testing migrations")
|
||||||
|
|
||||||
|
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
||||||
|
var (
|
||||||
|
db *SqliteStore
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
// Run the migration 3 times to test that the migrations
|
||||||
|
// are idempotent.
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
db, err = NewSqliteStore(&SqliteConfig{
|
||||||
|
SkipMigrations: false,
|
||||||
|
}, dbFileName, test.migrations)
|
||||||
|
if db != nil {
|
||||||
|
dbToCleanup := db.DB
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(
|
||||||
|
t, dbToCleanup.Close(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.expectedSuccess {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// Also repoen the DB without migrations
|
||||||
|
// so we can read versions.
|
||||||
|
db, err = NewSqliteStore(&SqliteConfig{
|
||||||
|
SkipMigrations: true,
|
||||||
|
}, dbFileName, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t,
|
||||||
|
test.expectedMigrationLog,
|
||||||
|
customMigrationLog,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create the migration executor to be able to
|
||||||
|
// query the current schema version.
|
||||||
|
driver, err := sqlite_migrate.WithInstance(
|
||||||
|
db.DB, &sqlite_migrate.Config{},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(
|
||||||
|
t, test.expectedSchemaVersion,
|
||||||
|
getSchemaVersion(t, driver, ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check the migraton version in the database.
|
||||||
|
version, err := db.GetDatabaseVersion(ctxb)
|
||||||
|
if test.expectedSchemaVersion != 0 {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Equal(t, sql.ErrNoRows, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(
|
||||||
|
t, test.expectedVersion, int(version),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Postgres "+test.name, func(t *testing.T) {
|
||||||
|
customMigrationLog = nil
|
||||||
|
|
||||||
|
// First create a temporary Postgres database to run
|
||||||
|
// the migrations on.
|
||||||
|
fixture := NewTestPgFixture(
|
||||||
|
t, DefaultPostgresFixtureLifetime,
|
||||||
|
)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
fixture.TearDown(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbName := randomDBName(t)
|
||||||
|
|
||||||
|
// Next instantiate the database and run the migrations
|
||||||
|
// including the custom migrations.
|
||||||
|
t.Logf("Creating new Postgres DB '%s' for testing "+
|
||||||
|
"migrations", dbName)
|
||||||
|
|
||||||
|
_, err := fixture.db.ExecContext(
|
||||||
|
context.Background(), "CREATE DATABASE "+dbName,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := fixture.GetConfig(dbName)
|
||||||
|
var db *PostgresStore
|
||||||
|
|
||||||
|
// Run the migration 3 times to test that the migrations
|
||||||
|
// are idempotent.
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
cfg.SkipMigrations = false
|
||||||
|
db, err = NewPostgresStore(cfg, test.migrations)
|
||||||
|
|
||||||
|
if test.expectedSuccess {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// Also repoen the DB without migrations
|
||||||
|
// so we can read versions.
|
||||||
|
cfg.SkipMigrations = true
|
||||||
|
db, err = NewPostgresStore(cfg, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t,
|
||||||
|
test.expectedMigrationLog,
|
||||||
|
customMigrationLog,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create the migration executor to be able to
|
||||||
|
// query the current version.
|
||||||
|
driver, err := pgx_migrate.WithInstance(
|
||||||
|
db.DB, &pgx_migrate.Config{},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(
|
||||||
|
t, test.expectedSchemaVersion,
|
||||||
|
getSchemaVersion(t, driver, ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check the migraton version in the database.
|
||||||
|
version, err := db.GetDatabaseVersion(ctxb)
|
||||||
|
if test.expectedSchemaVersion != 0 {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Equal(t, sql.ErrNoRows, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(
|
||||||
|
t, test.expectedVersion, int(version),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package sqldb
|
package sqldb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -32,6 +33,9 @@ var (
|
|||||||
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
|
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
|
||||||
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
|
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make sure PostgresStore implements the MigrationExecutor interface.
|
||||||
|
_ MigrationExecutor = (*PostgresStore)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
// replacePasswordInDSN takes a DSN string and returns it with the password
|
// replacePasswordInDSN takes a DSN string and returns it with the password
|
||||||
@@ -85,43 +89,62 @@ type PostgresStore struct {
|
|||||||
|
|
||||||
// NewPostgresStore creates a new store that is backed by a Postgres database
|
// NewPostgresStore creates a new store that is backed by a Postgres database
|
||||||
// backend.
|
// backend.
|
||||||
func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
|
func NewPostgresStore(cfg *PostgresConfig, migrations []MigrationConfig) (
|
||||||
|
*PostgresStore, error) {
|
||||||
|
|
||||||
sanitizedDSN, err := replacePasswordInDSN(cfg.Dsn)
|
sanitizedDSN, err := replacePasswordInDSN(cfg.Dsn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
log.Infof("Using SQL database '%s'", sanitizedDSN)
|
log.Infof("Using SQL database '%s'", sanitizedDSN)
|
||||||
|
|
||||||
rawDB, err := sql.Open("pgx", cfg.Dsn)
|
db, err := sql.Open("pgx", cfg.Dsn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure the migration tracker table exists before running migrations.
|
||||||
|
// This table tracks migration progress and ensures compatibility with
|
||||||
|
// SQLC query generation. If the table is already created by an SQLC
|
||||||
|
// migration, this operation becomes a no-op.
|
||||||
|
migrationTrackerSQL := `
|
||||||
|
CREATE TABLE IF NOT EXISTS migration_tracker (
|
||||||
|
version INTEGER UNIQUE NOT NULL,
|
||||||
|
migration_time TIMESTAMP NOT NULL
|
||||||
|
);`
|
||||||
|
|
||||||
|
_, err = db.Exec(migrationTrackerSQL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating migration tracker: %w",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
maxConns := defaultMaxConns
|
maxConns := defaultMaxConns
|
||||||
if cfg.MaxConnections > 0 {
|
if cfg.MaxConnections > 0 {
|
||||||
maxConns = cfg.MaxConnections
|
maxConns = cfg.MaxConnections
|
||||||
}
|
}
|
||||||
|
|
||||||
rawDB.SetMaxOpenConns(maxConns)
|
db.SetMaxOpenConns(maxConns)
|
||||||
rawDB.SetMaxIdleConns(maxConns)
|
db.SetMaxIdleConns(maxConns)
|
||||||
rawDB.SetConnMaxLifetime(connIdleLifetime)
|
db.SetConnMaxLifetime(connIdleLifetime)
|
||||||
|
|
||||||
queries := sqlc.New(rawDB)
|
queries := sqlc.New(db)
|
||||||
|
|
||||||
s := &PostgresStore{
|
s := &PostgresStore{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
BaseDB: &BaseDB{
|
BaseDB: &BaseDB{
|
||||||
DB: rawDB,
|
DB: db,
|
||||||
Queries: queries,
|
Queries: queries,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute migrations unless configured to skip them.
|
// Execute migrations unless configured to skip them.
|
||||||
if !cfg.SkipMigrations {
|
if !cfg.SkipMigrations {
|
||||||
err := s.ExecuteMigrations(TargetLatest)
|
err := ApplyMigrations(
|
||||||
|
context.Background(), s.BaseDB, s, migrations,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error executing migrations: %w",
|
return nil, err
|
||||||
err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -148,7 +148,7 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cfg := fixture.GetConfig(dbName)
|
cfg := fixture.GetConfig(dbName)
|
||||||
store, err := NewPostgresStore(cfg)
|
store, err := NewPostgresStore(cfg, GetMigrations())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return store
|
return store
|
||||||
@@ -172,7 +172,7 @@ func NewTestPostgresDBWithVersion(t *testing.T, fixture *TestPgFixture,
|
|||||||
|
|
||||||
storeCfg := fixture.GetConfig(dbName)
|
storeCfg := fixture.GetConfig(dbName)
|
||||||
storeCfg.SkipMigrations = true
|
storeCfg.SkipMigrations = true
|
||||||
store, err := NewPostgresStore(storeCfg)
|
store, err := NewPostgresStore(storeCfg, GetMigrations())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = store.ExecuteMigrations(TargetVersion(version))
|
err = store.ExecuteMigrations(TargetVersion(version))
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
package sqldb
|
package sqldb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -34,6 +35,9 @@ var (
|
|||||||
sqliteSchemaReplacements = map[string]string{
|
sqliteSchemaReplacements = map[string]string{
|
||||||
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
|
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make sure SqliteStore implements the MigrationExecutor interface.
|
||||||
|
_ MigrationExecutor = (*SqliteStore)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
// SqliteStore is a database store implementation that uses a sqlite backend.
|
// SqliteStore is a database store implementation that uses a sqlite backend.
|
||||||
@@ -45,7 +49,9 @@ type SqliteStore struct {
|
|||||||
|
|
||||||
// NewSqliteStore attempts to open a new sqlite database based on the passed
|
// NewSqliteStore attempts to open a new sqlite database based on the passed
|
||||||
// config.
|
// config.
|
||||||
func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
func NewSqliteStore(cfg *SqliteConfig, dbPath string,
|
||||||
|
migrations []MigrationConfig) (*SqliteStore, error) {
|
||||||
|
|
||||||
// The set of pragma options are accepted using query options. For now
|
// The set of pragma options are accepted using query options. For now
|
||||||
// we only want to ensure that foreign key constraints are properly
|
// we only want to ensure that foreign key constraints are properly
|
||||||
// enforced.
|
// enforced.
|
||||||
@@ -102,6 +108,23 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create the migration tracker table before starting migrations to
|
||||||
|
// ensure it can be used to track migration progress. Note that a
|
||||||
|
// corresponding SQLC migration also creates this table, making this
|
||||||
|
// operation a no-op in that context. Its purpose is to ensure
|
||||||
|
// compatibility with SQLC query generation.
|
||||||
|
migrationTrackerSQL := `
|
||||||
|
CREATE TABLE IF NOT EXISTS migration_tracker (
|
||||||
|
version INTEGER UNIQUE NOT NULL,
|
||||||
|
migration_time TIMESTAMP NOT NULL
|
||||||
|
);`
|
||||||
|
|
||||||
|
_, err = db.Exec(migrationTrackerSQL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating migration tracker: %w",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
db.SetMaxOpenConns(defaultMaxConns)
|
db.SetMaxOpenConns(defaultMaxConns)
|
||||||
db.SetMaxIdleConns(defaultMaxConns)
|
db.SetMaxIdleConns(defaultMaxConns)
|
||||||
db.SetConnMaxLifetime(connIdleLifetime)
|
db.SetConnMaxLifetime(connIdleLifetime)
|
||||||
@@ -117,10 +140,11 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
|||||||
|
|
||||||
// Execute migrations unless configured to skip them.
|
// Execute migrations unless configured to skip them.
|
||||||
if !cfg.SkipMigrations {
|
if !cfg.SkipMigrations {
|
||||||
if err := s.ExecuteMigrations(TargetLatest); err != nil {
|
err := ApplyMigrations(
|
||||||
return nil, fmt.Errorf("error executing migrations: "+
|
context.Background(), s.BaseDB, s, migrations,
|
||||||
"%w", err)
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +181,7 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore {
|
|||||||
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
||||||
sqlDB, err := NewSqliteStore(&SqliteConfig{
|
sqlDB, err := NewSqliteStore(&SqliteConfig{
|
||||||
SkipMigrations: false,
|
SkipMigrations: false,
|
||||||
}, dbFileName)
|
}, dbFileName, GetMigrations())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -180,7 +204,7 @@ func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore {
|
|||||||
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
||||||
sqlDB, err := NewSqliteStore(&SqliteConfig{
|
sqlDB, err := NewSqliteStore(&SqliteConfig{
|
||||||
SkipMigrations: true,
|
SkipMigrations: true,
|
||||||
}, dbFileName)
|
}, dbFileName, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = sqlDB.ExecuteMigrations(TargetVersion(version))
|
err = sqlDB.ExecuteMigrations(TargetVersion(version))
|
||||||
|
Reference in New Issue
Block a user