mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-05-09 11:20:19 +02:00
kvdb/postgres/test: single instance embedded postgres database
Prepare for parallel tests that use a postgres backend. We don't want a high number of embedded postgres instances running simultaneously.
This commit is contained in:
parent
bd291286f7
commit
7c048efa21
@ -14,15 +14,19 @@ import (
|
|||||||
|
|
||||||
// TestInterface performs all interfaces tests for this database driver.
|
// TestInterface performs all interfaces tests for this database driver.
|
||||||
func TestInterface(t *testing.T) {
|
func TestInterface(t *testing.T) {
|
||||||
f := NewFixture(t)
|
stop, err := StartEmbeddedPostgres()
|
||||||
defer f.Cleanup()
|
require.NoError(t, err)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
f, err := NewFixture("")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// dbType is the database type name for this driver.
|
// dbType is the database type name for this driver.
|
||||||
const dbType = "postgres"
|
const dbType = "postgres"
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
cfg := &Config{
|
cfg := &Config{
|
||||||
Dsn: testDsn,
|
Dsn: f.Dsn,
|
||||||
}
|
}
|
||||||
|
|
||||||
walletdbtest.TestInterface(t, dbType, ctx, cfg, prefix)
|
walletdbtest.TestInterface(t, dbType, ctx, cfg, prefix)
|
||||||
@ -30,17 +34,19 @@ func TestInterface(t *testing.T) {
|
|||||||
|
|
||||||
// TestPanic tests recovery from panic conditions.
|
// TestPanic tests recovery from panic conditions.
|
||||||
func TestPanic(t *testing.T) {
|
func TestPanic(t *testing.T) {
|
||||||
f := NewFixture(t)
|
stop, err := StartEmbeddedPostgres()
|
||||||
defer f.Cleanup()
|
require.NoError(t, err)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
d := f.NewBackend()
|
f, err := NewFixture("")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
err := d.(*db).Update(func(tx walletdb.ReadWriteTx) error {
|
err = f.Db.(*db).Update(func(tx walletdb.ReadWriteTx) error {
|
||||||
bucket, err := tx.CreateTopLevelBucket([]byte("test"))
|
bucket, err := tx.CreateTopLevelBucket([]byte("test"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Stop database server.
|
// Stop database server.
|
||||||
f.Cleanup()
|
stop()
|
||||||
|
|
||||||
// Keep trying to get data until Get panics because the
|
// Keep trying to get data until Get panics because the
|
||||||
// connection is lost.
|
// connection is lost.
|
||||||
|
@ -4,87 +4,118 @@ package postgres
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"testing"
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/btcsuite/btcwallet/walletdb"
|
"github.com/btcsuite/btcwallet/walletdb"
|
||||||
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
|
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testDsn = "postgres://postgres:postgres@localhost:9876/postgres?sslmode=disable"
|
testDsnTemplate = "postgres://postgres:postgres@localhost:9876/%v?sslmode=disable"
|
||||||
prefix = "test"
|
prefix = "test"
|
||||||
)
|
)
|
||||||
|
|
||||||
func clearTestDb(t *testing.T) {
|
func getTestDsn(dbName string) string {
|
||||||
dbConn, err := sql.Open("pgx", testDsn)
|
return fmt.Sprintf(testDsnTemplate, dbName)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = dbConn.ExecContext(context.Background(), "DROP SCHEMA IF EXISTS public CASCADE;")
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func openTestDb(t *testing.T) *db {
|
var testPostgres *embeddedpostgres.EmbeddedPostgres
|
||||||
clearTestDb(t)
|
|
||||||
|
|
||||||
db, err := newPostgresBackend(
|
// StartEmbeddedPostgres starts an embedded postgres instance. This only needs
|
||||||
context.Background(),
|
// to be done once, because NewFixture will create random new databases on every
|
||||||
&Config{
|
// call. It returns a stop closure that stops the database if called.
|
||||||
Dsn: testDsn,
|
func StartEmbeddedPostgres() (func() error, error) {
|
||||||
},
|
|
||||||
prefix,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
|
|
||||||
type fixture struct {
|
|
||||||
t *testing.T
|
|
||||||
tempDir string
|
|
||||||
postgres *embeddedpostgres.EmbeddedPostgres
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewFixture(t *testing.T) *fixture {
|
|
||||||
postgres := embeddedpostgres.NewDatabase(
|
postgres := embeddedpostgres.NewDatabase(
|
||||||
embeddedpostgres.DefaultConfig().
|
embeddedpostgres.DefaultConfig().
|
||||||
Port(9876))
|
Port(9876))
|
||||||
|
|
||||||
err := postgres.Start()
|
err := postgres.Start()
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
testPostgres = postgres
|
||||||
|
|
||||||
|
return testPostgres.Stop, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFixture returns a new postgres test database. The database name is
|
||||||
|
// randomly generated.
|
||||||
|
func NewFixture(dbName string) (*fixture, error) {
|
||||||
|
if dbName == "" {
|
||||||
|
// Create random database name.
|
||||||
|
randBytes := make([]byte, 8)
|
||||||
|
_, err := rand.Read(randBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dbName = "test_" + hex.EncodeToString(randBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create database if it doesn't exist yet.
|
||||||
|
dbConn, err := sql.Open("pgx", getTestDsn("postgres"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer dbConn.Close()
|
||||||
|
|
||||||
|
_, err = dbConn.ExecContext(
|
||||||
|
context.Background(), "CREATE DATABASE "+dbName,
|
||||||
|
)
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "already exists") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open database
|
||||||
|
dsn := getTestDsn(dbName)
|
||||||
|
db, err := newPostgresBackend(
|
||||||
|
context.Background(),
|
||||||
|
&Config{
|
||||||
|
Dsn: dsn,
|
||||||
|
},
|
||||||
|
prefix,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &fixture{
|
return &fixture{
|
||||||
t: t,
|
Dsn: dsn,
|
||||||
postgres: postgres,
|
Db: db,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fixture struct {
|
||||||
|
Dsn string
|
||||||
|
Db walletdb.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dump returns the raw contents of the database.
|
||||||
|
func (b *fixture) Dump() (map[string]interface{}, error) {
|
||||||
|
dbConn, err := sql.Open("pgx", b.Dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (b *fixture) Cleanup() {
|
|
||||||
b.postgres.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *fixture) NewBackend() walletdb.DB {
|
|
||||||
clearTestDb(b.t)
|
|
||||||
db := openTestDb(b.t)
|
|
||||||
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *fixture) Dump() map[string]interface{} {
|
|
||||||
dbConn, err := sql.Open("pgx", testDsn)
|
|
||||||
require.NoError(b.t, err)
|
|
||||||
|
|
||||||
rows, err := dbConn.Query(
|
rows, err := dbConn.Query(
|
||||||
"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'",
|
"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'",
|
||||||
)
|
)
|
||||||
require.NoError(b.t, err)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
var tables []string
|
var tables []string
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var table string
|
var table string
|
||||||
err := rows.Scan(&table)
|
err := rows.Scan(&table)
|
||||||
require.NoError(b.t, err)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
tables = append(tables, table)
|
tables = append(tables, table)
|
||||||
}
|
}
|
||||||
@ -93,10 +124,14 @@ func (b *fixture) Dump() map[string]interface{} {
|
|||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
rows, err := dbConn.Query("SELECT * FROM " + table)
|
rows, err := dbConn.Query("SELECT * FROM " + table)
|
||||||
require.NoError(b.t, err)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
cols, err := rows.Columns()
|
cols, err := rows.Columns()
|
||||||
require.NoError(b.t, err)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
colCount := len(cols)
|
colCount := len(cols)
|
||||||
|
|
||||||
var tableRows []map[string]interface{}
|
var tableRows []map[string]interface{}
|
||||||
@ -108,7 +143,9 @@ func (b *fixture) Dump() map[string]interface{} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err := rows.Scan(valuePtrs...)
|
err := rows.Scan(valuePtrs...)
|
||||||
require.NoError(b.t, err)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
tableData := make(map[string]interface{})
|
tableData := make(map[string]interface{})
|
||||||
for i, v := range values {
|
for i, v := range values {
|
||||||
@ -127,5 +164,5 @@ func (b *fixture) Dump() map[string]interface{} {
|
|||||||
result[table] = tableRows
|
result[table] = tableRows
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result, nil
|
||||||
}
|
}
|
||||||
|
@ -13,8 +13,9 @@ import (
|
|||||||
type m = map[string]interface{}
|
type m = map[string]interface{}
|
||||||
|
|
||||||
func TestPostgres(t *testing.T) {
|
func TestPostgres(t *testing.T) {
|
||||||
f := postgres.NewFixture(t)
|
stop, err := postgres.StartEmbeddedPostgres()
|
||||||
defer f.Cleanup()
|
require.NoError(t, err)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -173,10 +174,14 @@ func TestPostgres(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
test.test(t, f.NewBackend())
|
f, err := postgres.NewFixture("")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
test.test(t, f.Db)
|
||||||
|
|
||||||
if test.expectedDb != nil {
|
if test.expectedDb != nil {
|
||||||
dump := f.Dump()
|
dump, err := f.Dump()
|
||||||
|
require.NoError(t, err)
|
||||||
require.Equal(t, test.expectedDb, dump)
|
require.Equal(t, test.expectedDb, dump)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user