sqldb: extract migration into method

Based on: https://github.com/lightninglabs/taproot-assets/pull/707
This commit is contained in:
Andras Banki-Horvath
2024-06-25 15:18:28 +02:00
parent 323af946e0
commit 5292c76e10
3 changed files with 162 additions and 72 deletions

View File

@@ -8,16 +8,71 @@ import (
"net/http"
"strings"
"github.com/btcsuite/btclog"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/source/httpfs"
)
// MigrationTarget is a functional option that can be passed to applyMigrations
// to specify a target version to migrate to.
type MigrationTarget func(mig *migrate.Migrate) error
var (
// TargetLatest is a MigrationTarget that migrates to the latest
// version available.
TargetLatest = func(mig *migrate.Migrate) error {
return mig.Up()
}
// TargetVersion is a MigrationTarget that migrates to the given
// version.
TargetVersion = func(version uint) MigrationTarget {
return func(mig *migrate.Migrate) error {
return mig.Migrate(version)
}
}
)
// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
// used to log migrations.
type migrationLogger struct {
log btclog.Logger
}
// Printf is like fmt.Printf. We map this to the target logger based on the
// current log level.
func (m *migrationLogger) Printf(format string, v ...interface{}) {
// Trim trailing newlines from the format.
format = strings.TrimRight(format, "\n")
switch m.log.Level() {
case btclog.LevelTrace:
m.log.Tracef(format, v...)
case btclog.LevelDebug:
m.log.Debugf(format, v...)
case btclog.LevelInfo:
m.log.Infof(format, v...)
case btclog.LevelWarn:
m.log.Warnf(format, v...)
case btclog.LevelError:
m.log.Errorf(format, v...)
case btclog.LevelCritical:
m.log.Criticalf(format, v...)
case btclog.LevelOff:
}
}
// Verbose should return true when verbose logging output is wanted
func (m *migrationLogger) Verbose() bool {
return m.log.Level() <= btclog.LevelDebug
}
// applyMigrations executes all database migration files found in the given file
// system under the given path, using the passed database driver and database
// name.
func applyMigrations(fs fs.FS, driver database.Driver, path,
dbName string) error {
dbName string, targetVersion MigrationTarget) error {
// With the migrate instance open, we'll create a new migration source
// using the embedded file system stored in sqlSchemas. The library
@@ -37,7 +92,22 @@ func applyMigrations(fs fs.FS, driver database.Driver, path,
if err != nil {
return err
}
err = sqlMigrate.Up()
migrationVersion, _, err := sqlMigrate.Version()
if err != nil && !errors.Is(err, migrate.ErrNilVersion) {
log.Errorf("Unable to determine current migration version: %v",
err)
return err
}
log.Infof("Applying migrations from version=%v", migrationVersion)
// Apply our local logger to the migration instance.
sqlMigrate.Log = &migrationLogger{log}
// Execute the migration based on the target given.
err = targetVersion(sqlMigrate)
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return err
}