diff --git a/sqldb/postgres_fixture.go b/sqldb/postgres_fixture.go index a0f0893ce..e801d943f 100644 --- a/sqldb/postgres_fixture.go +++ b/sqldb/postgres_fixture.go @@ -124,28 +124,28 @@ func (f *TestPgFixture) TearDown(t *testing.T) { require.NoError(t, err, "Could not purge resource") } +// randomDBName generates a random database name. +func randomDBName(t *testing.T) string { + randBytes := make([]byte, 8) + _, err := rand.Read(randBytes) + require.NoError(t, err) + + return "test_" + hex.EncodeToString(randBytes) +} + // NewTestPostgresDB is a helper function that creates a Postgres database for // testing using the given fixture. func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore { t.Helper() - // Create random database name. - randBytes := make([]byte, 8) - _, err := rand.Read(randBytes) - if err != nil { - t.Fatal(err) - } - - dbName := "test_" + hex.EncodeToString(randBytes) + dbName := randomDBName(t) t.Logf("Creating new Postgres DB '%s' for testing", dbName) - _, err = fixture.db.ExecContext( + _, err := fixture.db.ExecContext( context.Background(), "CREATE DATABASE "+dbName, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) cfg := fixture.GetConfig(dbName) store, err := NewPostgresStore(cfg) @@ -153,3 +153,30 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore { return store } + +// NewTestPostgresDBWithVersion is a helper function that creates a Postgres +// database for testing and migrates it to the given version. +func NewTestPostgresDBWithVersion(t *testing.T, fixture *TestPgFixture, + version uint) *PostgresStore { + + t.Helper() + + t.Logf("Creating new Postgres DB for testing, migrating to version %d", + version) + + dbName := randomDBName(t) + _, err := fixture.db.ExecContext( + context.Background(), "CREATE DATABASE "+dbName, + ) + require.NoError(t, err) + + storeCfg := fixture.GetConfig(dbName) + storeCfg.SkipMigrations = true + store, err := NewPostgresStore(storeCfg) + require.NoError(t, err) + + err = store.ExecuteMigrations(TargetVersion(version)) + require.NoError(t, err) + + return store +} diff --git a/sqldb/postgres_test.go b/sqldb/postgres_test.go index c31daf04b..22a29f885 100644 --- a/sqldb/postgres_test.go +++ b/sqldb/postgres_test.go @@ -9,5 +9,21 @@ import ( // NewTestDB is a helper function that creates a Postgres database for testing. func NewTestDB(t *testing.T) *PostgresStore { - return NewTestPostgresDB(t) + pgFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + return NewTestPostgresDB(t, pgFixture) +} + +// NewTestDBWithVersion is a helper function that creates a Postgres database +// for testing and migrates it to the given version. +func NewTestDBWithVersion(t *testing.T, version uint) *PostgresStore { + pgFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + return NewTestPostgresDBWithVersion(t, pgFixture, version) } diff --git a/sqldb/sqlite.go b/sqldb/sqlite.go index 9058a8765..99e55d6ea 100644 --- a/sqldb/sqlite.go +++ b/sqldb/sqlite.go @@ -166,3 +166,29 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore { return sqlDB } + +// NewTestSqliteDBWithVersion is a helper function that creates an SQLite +// database for testing and migrates it to the given version. +func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore { + t.Helper() + + t.Logf("Creating new SQLite DB for testing, migrating to version %d", + version) + + // TODO(roasbeef): if we pass :memory: for the file name, then we get + // an in mem version to speed up tests + dbFileName := filepath.Join(t.TempDir(), "tmp.db") + sqlDB, err := NewSqliteStore(&SqliteConfig{ + SkipMigrations: true, + }, dbFileName) + require.NoError(t, err) + + err = sqlDB.ExecuteMigrations(TargetVersion(version)) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, sqlDB.DB.Close()) + }) + + return sqlDB +} diff --git a/sqldb/sqlite_test.go b/sqldb/sqlite_test.go index 58c94b351..9dfb875ea 100644 --- a/sqldb/sqlite_test.go +++ b/sqldb/sqlite_test.go @@ -11,3 +11,9 @@ import ( func NewTestDB(t *testing.T) *SqliteStore { return NewTestSqliteDB(t) } + +// NewTestDBWithVersion is a helper function that creates an SQLite database +// for testing and migrates it to the given version. +func NewTestDBWithVersion(t *testing.T, version uint) *SqliteStore { + return NewTestSqliteDBWithVersion(t, version) +}