diff --git a/kvdb/postgres/db_test.go b/kvdb/postgres/db_test.go index faf853dd2..dab6bf447 100644 --- a/kvdb/postgres/db_test.go +++ b/kvdb/postgres/db_test.go @@ -14,15 +14,19 @@ import ( // TestInterface performs all interfaces tests for this database driver. func TestInterface(t *testing.T) { - f := NewFixture(t) - defer f.Cleanup() + stop, err := StartEmbeddedPostgres() + require.NoError(t, err) + defer stop() + + f, err := NewFixture("") + require.NoError(t, err) // dbType is the database type name for this driver. const dbType = "postgres" ctx := context.Background() cfg := &Config{ - Dsn: testDsn, + Dsn: f.Dsn, } walletdbtest.TestInterface(t, dbType, ctx, cfg, prefix) @@ -30,17 +34,19 @@ func TestInterface(t *testing.T) { // TestPanic tests recovery from panic conditions. func TestPanic(t *testing.T) { - f := NewFixture(t) - defer f.Cleanup() + stop, err := StartEmbeddedPostgres() + require.NoError(t, err) + defer stop() - d := f.NewBackend() + f, err := NewFixture("") + require.NoError(t, err) - err := d.(*db).Update(func(tx walletdb.ReadWriteTx) error { + err = f.Db.(*db).Update(func(tx walletdb.ReadWriteTx) error { bucket, err := tx.CreateTopLevelBucket([]byte("test")) require.NoError(t, err) // Stop database server. - f.Cleanup() + stop() // Keep trying to get data until Get panics because the // connection is lost. diff --git a/kvdb/postgres/fixture.go b/kvdb/postgres/fixture.go index 8b3ecae73..929563220 100644 --- a/kvdb/postgres/fixture.go +++ b/kvdb/postgres/fixture.go @@ -4,87 +4,118 @@ package postgres import ( "context" + "crypto/rand" "database/sql" - "testing" + "encoding/hex" + "fmt" + "strings" "github.com/btcsuite/btcwallet/walletdb" embeddedpostgres "github.com/fergusstrange/embedded-postgres" - "github.com/stretchr/testify/require" ) const ( - testDsn = "postgres://postgres:postgres@localhost:9876/postgres?sslmode=disable" - prefix = "test" + testDsnTemplate = "postgres://postgres:postgres@localhost:9876/%v?sslmode=disable" + prefix = "test" ) -func clearTestDb(t *testing.T) { - dbConn, err := sql.Open("pgx", testDsn) - require.NoError(t, err) - - _, err = dbConn.ExecContext(context.Background(), "DROP SCHEMA IF EXISTS public CASCADE;") - require.NoError(t, err) +func getTestDsn(dbName string) string { + return fmt.Sprintf(testDsnTemplate, dbName) } -func openTestDb(t *testing.T) *db { - clearTestDb(t) +var testPostgres *embeddedpostgres.EmbeddedPostgres - db, err := newPostgresBackend( - context.Background(), - &Config{ - Dsn: testDsn, - }, - prefix, - ) - require.NoError(t, err) - - return db -} - -type fixture struct { - t *testing.T - tempDir string - postgres *embeddedpostgres.EmbeddedPostgres -} - -func NewFixture(t *testing.T) *fixture { +// StartEmbeddedPostgres starts an embedded postgres instance. This only needs +// to be done once, because NewFixture will create random new databases on every +// call. It returns a stop closure that stops the database if called. +func StartEmbeddedPostgres() (func() error, error) { postgres := embeddedpostgres.NewDatabase( embeddedpostgres.DefaultConfig(). Port(9876)) err := postgres.Start() - require.NoError(t, err) + if err != nil { + return nil, err + } + + testPostgres = postgres + + return testPostgres.Stop, nil +} + +// NewFixture returns a new postgres test database. The database name is +// randomly generated. +func NewFixture(dbName string) (*fixture, error) { + if dbName == "" { + // Create random database name. + randBytes := make([]byte, 8) + _, err := rand.Read(randBytes) + if err != nil { + return nil, err + } + + dbName = "test_" + hex.EncodeToString(randBytes) + } + + // Create database if it doesn't exist yet. + dbConn, err := sql.Open("pgx", getTestDsn("postgres")) + if err != nil { + return nil, err + } + defer dbConn.Close() + + _, err = dbConn.ExecContext( + context.Background(), "CREATE DATABASE "+dbName, + ) + if err != nil && !strings.Contains(err.Error(), "already exists") { + return nil, err + } + + // Open database + dsn := getTestDsn(dbName) + db, err := newPostgresBackend( + context.Background(), + &Config{ + Dsn: dsn, + }, + prefix, + ) + if err != nil { + return nil, err + } return &fixture{ - t: t, - postgres: postgres, + Dsn: dsn, + Db: db, + }, nil +} + +type fixture struct { + Dsn string + Db walletdb.DB +} + +// Dump returns the raw contents of the database. +func (b *fixture) Dump() (map[string]interface{}, error) { + dbConn, err := sql.Open("pgx", b.Dsn) + if err != nil { + return nil, err } -} - -func (b *fixture) Cleanup() { - b.postgres.Stop() -} - -func (b *fixture) NewBackend() walletdb.DB { - clearTestDb(b.t) - db := openTestDb(b.t) - - return db -} - -func (b *fixture) Dump() map[string]interface{} { - dbConn, err := sql.Open("pgx", testDsn) - require.NoError(b.t, err) rows, err := dbConn.Query( "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'", ) - require.NoError(b.t, err) + if err != nil { + return nil, err + } var tables []string for rows.Next() { var table string err := rows.Scan(&table) - require.NoError(b.t, err) + if err != nil { + return nil, err + } tables = append(tables, table) } @@ -93,10 +124,14 @@ func (b *fixture) Dump() map[string]interface{} { for _, table := range tables { rows, err := dbConn.Query("SELECT * FROM " + table) - require.NoError(b.t, err) + if err != nil { + return nil, err + } cols, err := rows.Columns() - require.NoError(b.t, err) + if err != nil { + return nil, err + } colCount := len(cols) var tableRows []map[string]interface{} @@ -108,7 +143,9 @@ func (b *fixture) Dump() map[string]interface{} { } err := rows.Scan(valuePtrs...) - require.NoError(b.t, err) + if err != nil { + return nil, err + } tableData := make(map[string]interface{}) for i, v := range values { @@ -127,5 +164,5 @@ func (b *fixture) Dump() map[string]interface{} { result[table] = tableRows } - return result + return result, nil } diff --git a/kvdb/postgres_test.go b/kvdb/postgres_test.go index a880ceea7..445f71b73 100644 --- a/kvdb/postgres_test.go +++ b/kvdb/postgres_test.go @@ -13,8 +13,9 @@ import ( type m = map[string]interface{} func TestPostgres(t *testing.T) { - f := postgres.NewFixture(t) - defer f.Cleanup() + stop, err := postgres.StartEmbeddedPostgres() + require.NoError(t, err) + defer stop() tests := []struct { name string @@ -173,10 +174,14 @@ func TestPostgres(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - test.test(t, f.NewBackend()) + f, err := postgres.NewFixture("") + require.NoError(t, err) + + test.test(t, f.Db) if test.expectedDb != nil { - dump := f.Dump() + dump, err := f.Dump() + require.NoError(t, err) require.Equal(t, test.expectedDb, dump) } })