mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-30 02:21:08 +02:00
sqldb: add the sqldb package
This commit provides the scaffolding for using the new sql stores. The new interfaces, structs and methods are in sync with other projects like Taproot Assets. - Transactional Queries: the sqldb package defines the interfaces required to execute transactional queries to our storage interface. - Migration Files Embedded: the migration files are embedded into the binary. - Database Migrations: I kept the use of 'golang-migrate' to ensure our codebase remains in sync with the other projects, but can be changed. - Build Flags for Conditional DB Target: flexibility to specify our database target at compile-time based on the build flags in the same way we do with our kv stores. - Update modules: ran `go mod tidy`.
This commit is contained in:
258
sqldb/interfaces.go
Normal file
258
sqldb/interfaces.go
Normal file
@ -0,0 +1,258 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
prand "math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultStoreTimeout is the default timeout used for any interaction
|
||||
// with the storage/database.
|
||||
DefaultStoreTimeout = time.Second * 10
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultNumTxRetries is the default number of times we'll retry a
|
||||
// transaction if it fails with an error that permits transaction
|
||||
// repetition.
|
||||
DefaultNumTxRetries = 10
|
||||
|
||||
// DefaultRetryDelay is the default delay between retries. This will be
|
||||
// used to generate a random delay between 0 and this value.
|
||||
DefaultRetryDelay = time.Millisecond * 50
|
||||
)
|
||||
|
||||
// TxOptions represents a set of options one can use to control what type of
|
||||
// database transaction is created. Transaction can wither be read or write.
|
||||
type TxOptions interface {
|
||||
// ReadOnly returns true if the transaction should be read only.
|
||||
ReadOnly() bool
|
||||
}
|
||||
|
||||
// BatchedTx is a generic interface that represents the ability to execute
|
||||
// several operations to a given storage interface in a single atomic
|
||||
// transaction. Typically, Q here will be some subset of the main sqlc.Querier
|
||||
// interface allowing it to only depend on the routines it needs to implement
|
||||
// any additional business logic.
|
||||
type BatchedTx[Q any] interface {
|
||||
// ExecTx will execute the passed txBody, operating upon generic
|
||||
// parameter Q (usually a storage interface) in a single transaction.
|
||||
//
|
||||
// The set of TxOptions are passed in order to allow the caller to
|
||||
// specify if a transaction should be read-only and optionally what
|
||||
// type of concurrency control should be used.
|
||||
ExecTx(ctx context.Context, txOptions TxOptions,
|
||||
txBody func(Q) error) error
|
||||
}
|
||||
|
||||
// Tx represents a database transaction that can be committed or rolled back.
|
||||
type Tx interface {
|
||||
// Commit commits the database transaction, an error should be returned
|
||||
// if the commit isn't possible.
|
||||
Commit() error
|
||||
|
||||
// Rollback rolls back an incomplete database transaction.
|
||||
// Transactions that were able to be committed can still call this as a
|
||||
// noop.
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// QueryCreator is a generic function that's used to create a Querier, which is
|
||||
// a type of interface that implements storage related methods from a database
|
||||
// transaction. This will be used to instantiate an object callers can use to
|
||||
// apply multiple modifications to an object interface in a single atomic
|
||||
// transaction.
|
||||
type QueryCreator[Q any] func(*sql.Tx) Q
|
||||
|
||||
// BatchedQuerier is a generic interface that allows callers to create a new
|
||||
// database transaction based on an abstract type that implements the TxOptions
|
||||
// interface.
|
||||
type BatchedQuerier interface {
|
||||
// Querier is the underlying query source, this is in place so we can
|
||||
// pass a BatchedQuerier implementation directly into objects that
|
||||
// create a batched version of the normal methods they need.
|
||||
sqlc.Querier
|
||||
|
||||
// BeginTx creates a new database transaction given the set of
|
||||
// transaction options.
|
||||
BeginTx(ctx context.Context, options TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// txExecutorOptions is a struct that holds the options for the transaction
|
||||
// executor. This can be used to do things like retry a transaction due to an
|
||||
// error a certain amount of times.
|
||||
type txExecutorOptions struct {
|
||||
numRetries int
|
||||
retryDelay time.Duration
|
||||
}
|
||||
|
||||
// defaultTxExecutorOptions returns the default options for the transaction
|
||||
// executor.
|
||||
func defaultTxExecutorOptions() *txExecutorOptions {
|
||||
return &txExecutorOptions{
|
||||
numRetries: DefaultNumTxRetries,
|
||||
retryDelay: DefaultRetryDelay,
|
||||
}
|
||||
}
|
||||
|
||||
// randRetryDelay returns a random retry delay between 0 and the configured max
|
||||
// delay.
|
||||
func (t *txExecutorOptions) randRetryDelay() time.Duration {
|
||||
return time.Duration(prand.Int63n(int64(t.retryDelay))) //nolint:gosec
|
||||
}
|
||||
|
||||
// TxExecutorOption is a functional option that allows us to pass in optional
|
||||
// argument when creating the executor.
|
||||
type TxExecutorOption func(*txExecutorOptions)
|
||||
|
||||
// WithTxRetries is a functional option that allows us to specify the number of
|
||||
// times a transaction should be retried if it fails with a repeatable error.
|
||||
func WithTxRetries(numRetries int) TxExecutorOption {
|
||||
return func(o *txExecutorOptions) {
|
||||
o.numRetries = numRetries
|
||||
}
|
||||
}
|
||||
|
||||
// WithTxRetryDelay is a functional option that allows us to specify the delay
|
||||
// to wait before a transaction is retried.
|
||||
func WithTxRetryDelay(delay time.Duration) TxExecutorOption {
|
||||
return func(o *txExecutorOptions) {
|
||||
o.retryDelay = delay
|
||||
}
|
||||
}
|
||||
|
||||
// TransactionExecutor is a generic struct that abstracts away from the type of
|
||||
// query a type needs to run under a database transaction, and also the set of
|
||||
// options for that transaction. The QueryCreator is used to create a query
|
||||
// given a database transaction created by the BatchedQuerier.
|
||||
type TransactionExecutor[Query any] struct {
|
||||
BatchedQuerier
|
||||
|
||||
createQuery QueryCreator[Query]
|
||||
|
||||
opts *txExecutorOptions
|
||||
}
|
||||
|
||||
// NewTransactionExecutor creates a new instance of a TransactionExecutor given
|
||||
// a Querier query object and a concrete type for the type of transactions the
|
||||
// Querier understands.
|
||||
func NewTransactionExecutor[Querier any](db BatchedQuerier,
|
||||
createQuery QueryCreator[Querier],
|
||||
opts ...TxExecutorOption) *TransactionExecutor[Querier] {
|
||||
|
||||
txOpts := defaultTxExecutorOptions()
|
||||
for _, optFunc := range opts {
|
||||
optFunc(txOpts)
|
||||
}
|
||||
|
||||
return &TransactionExecutor[Querier]{
|
||||
BatchedQuerier: db,
|
||||
createQuery: createQuery,
|
||||
opts: txOpts,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecTx is a wrapper for txBody to abstract the creation and commit of a db
|
||||
// transaction. The db transaction is embedded in a `*Queries` that txBody
|
||||
// needs to use when executing each one of the queries that need to be applied
|
||||
// atomically. This can be used by other storage interfaces to parameterize the
|
||||
// type of query and options run, in order to have access to batched operations
|
||||
// related to a storage object.
|
||||
func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context,
|
||||
txOptions TxOptions, txBody func(Q) error) error {
|
||||
|
||||
waitBeforeRetry := func(attemptNumber int) {
|
||||
retryDelay := t.opts.randRetryDelay()
|
||||
|
||||
log.Tracef("Retrying transaction due to tx serialization "+
|
||||
"error, attempt_number=%v, delay=%v", attemptNumber,
|
||||
retryDelay)
|
||||
|
||||
// Before we try again, we'll wait with a random backoff based
|
||||
// on the retry delay.
|
||||
time.Sleep(retryDelay)
|
||||
}
|
||||
|
||||
for i := 0; i < t.opts.numRetries; i++ {
|
||||
// Create the db transaction.
|
||||
tx, err := t.BatchedQuerier.BeginTx(ctx, txOptions)
|
||||
if err != nil {
|
||||
dbErr := MapSQLError(err)
|
||||
if IsSerializationError(dbErr) {
|
||||
// Nothing to roll back here, since we didn't
|
||||
// even get a transaction yet.
|
||||
waitBeforeRetry(i)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
return dbErr
|
||||
}
|
||||
|
||||
// Rollback is safe to call even if the tx is already closed,
|
||||
// so if the tx commits successfully, this is a no-op.
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
if err := txBody(t.createQuery(tx)); err != nil {
|
||||
dbErr := MapSQLError(err)
|
||||
if IsSerializationError(dbErr) {
|
||||
// Roll back the transaction, then pop back up
|
||||
// to try once again.
|
||||
_ = tx.Rollback()
|
||||
|
||||
waitBeforeRetry(i)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
return dbErr
|
||||
}
|
||||
|
||||
// Commit transaction.
|
||||
if err = tx.Commit(); err != nil {
|
||||
dbErr := MapSQLError(err)
|
||||
if IsSerializationError(dbErr) {
|
||||
// Roll back the transaction, then pop back up
|
||||
// to try once again.
|
||||
_ = tx.Rollback()
|
||||
|
||||
waitBeforeRetry(i)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
return dbErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we get to this point, then we weren't able to successfully commit
|
||||
// a tx given the max number of retries.
|
||||
return ErrRetriesExceeded
|
||||
}
|
||||
|
||||
// BaseDB is the base database struct that each implementation can embed to
|
||||
// gain some common functionality.
|
||||
type BaseDB struct {
|
||||
*sql.DB
|
||||
|
||||
*sqlc.Queries
|
||||
}
|
||||
|
||||
// BeginTx wraps the normal sql specific BeginTx method with the TxOptions
|
||||
// interface. This interface is then mapped to the concrete sql tx options
|
||||
// struct.
|
||||
func (s *BaseDB) BeginTx(ctx context.Context, opts TxOptions) (*sql.Tx, error) {
|
||||
sqlOptions := sql.TxOptions{
|
||||
ReadOnly: opts.ReadOnly(),
|
||||
}
|
||||
|
||||
return s.DB.BeginTx(ctx, &sqlOptions)
|
||||
}
|
26
sqldb/log.go
Normal file
26
sqldb/log.go
Normal file
@ -0,0 +1,26 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
)
|
||||
|
||||
// Subsystem defines the logging code for this subsystem.
|
||||
const Subsystem = "SQLD"
|
||||
|
||||
// log is a logger that is initialized with no output filters. This
|
||||
// means the package will not perform any logging by default until the caller
|
||||
// requests it.
|
||||
var log = btclog.Disabled
|
||||
|
||||
// DisableLog disables all library log output. Logging output is disabled
|
||||
// by default until UseLogger is called.
|
||||
func DisableLog() {
|
||||
UseLogger(btclog.Disabled)
|
||||
}
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
// This should be used in preference to SetLogWriter if the caller is also
|
||||
// using btclog.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
148
sqldb/migrations.go
Normal file
148
sqldb/migrations.go
Normal file
@ -0,0 +1,148 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
||||
)
|
||||
|
||||
// applyMigrations executes all database migration files found in the given file
|
||||
// system under the given path, using the passed database driver and database
|
||||
// name.
|
||||
func applyMigrations(fs fs.FS, driver database.Driver, path,
|
||||
dbName string) error {
|
||||
|
||||
// With the migrate instance open, we'll create a new migration source
|
||||
// using the embedded file system stored in sqlSchemas. The library
|
||||
// we're using can't handle a raw file system interface, so we wrap it
|
||||
// in this intermediate layer.
|
||||
migrateFileServer, err := httpfs.New(http.FS(fs), path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, we'll run the migration with our driver above based on the
|
||||
// open DB, and also the migration source stored in the file system
|
||||
// above.
|
||||
sqlMigrate, err := migrate.NewWithInstance(
|
||||
"migrations", migrateFileServer, dbName, driver,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = sqlMigrate.Up()
|
||||
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// replacerFS is an implementation of a fs.FS virtual file system that wraps an
|
||||
// existing file system but does a search-and-replace operation on each file
|
||||
// when it is opened.
|
||||
type replacerFS struct {
|
||||
parentFS fs.FS
|
||||
replaces map[string]string
|
||||
}
|
||||
|
||||
// A compile-time assertion to make sure replacerFS implements the fs.FS
|
||||
// interface.
|
||||
var _ fs.FS = (*replacerFS)(nil)
|
||||
|
||||
// newReplacerFS creates a new replacer file system, wrapping the given parent
|
||||
// virtual file system. Each file within the file system is undergoing a
|
||||
// search-and-replace operation when it is opened, using the given map where the
|
||||
// key denotes the search term and the value the term to replace each occurrence
|
||||
// with.
|
||||
func newReplacerFS(parent fs.FS, replaces map[string]string) *replacerFS {
|
||||
return &replacerFS{
|
||||
parentFS: parent,
|
||||
replaces: replaces,
|
||||
}
|
||||
}
|
||||
|
||||
// Open opens a file in the virtual file system.
|
||||
//
|
||||
// NOTE: This is part of the fs.FS interface.
|
||||
func (t *replacerFS) Open(name string) (fs.File, error) {
|
||||
f, err := t.parentFS.Open(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
return f, err
|
||||
}
|
||||
|
||||
return newReplacerFile(f, t.replaces)
|
||||
}
|
||||
|
||||
type replacerFile struct {
|
||||
parentFile fs.File
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
// A compile-time assertion to make sure replacerFile implements the fs.File
|
||||
// interface.
|
||||
var _ fs.File = (*replacerFile)(nil)
|
||||
|
||||
func newReplacerFile(parent fs.File, replaces map[string]string) (*replacerFile,
|
||||
error) {
|
||||
|
||||
content, err := io.ReadAll(parent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentStr := string(content)
|
||||
for from, to := range replaces {
|
||||
contentStr = strings.ReplaceAll(contentStr, from, to)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, err = buf.WriteString(contentStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &replacerFile{
|
||||
parentFile: parent,
|
||||
buf: buf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stat returns statistics/info about the file.
|
||||
//
|
||||
// NOTE: This is part of the fs.File interface.
|
||||
func (t *replacerFile) Stat() (fs.FileInfo, error) {
|
||||
return t.parentFile.Stat()
|
||||
}
|
||||
|
||||
// Read reads as many bytes as possible from the file into the given slice.
|
||||
//
|
||||
// NOTE: This is part of the fs.File interface.
|
||||
func (t *replacerFile) Read(bytes []byte) (int, error) {
|
||||
return t.buf.Read(bytes)
|
||||
}
|
||||
|
||||
// Close closes the underlying file.
|
||||
//
|
||||
// NOTE: This is part of the fs.File interface.
|
||||
func (t *replacerFile) Close() error {
|
||||
// We already fully read and then closed the file when creating this
|
||||
// instance, so there's nothing to do for us here.
|
||||
return nil
|
||||
}
|
141
sqldb/postgres.go
Normal file
141
sqldb/postgres.go
Normal file
@ -0,0 +1,141 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file" // Read migrations from files. // nolint:lll
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
dsnTemplate = "postgres://%v:%v@%v:%d/%v?sslmode=%v"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultPostgresFixtureLifetime is the default maximum time a Postgres
|
||||
// test fixture is being kept alive. After that time the docker
|
||||
// container will be terminated forcefully, even if the tests aren't
|
||||
// fully executed yet. So this time needs to be chosen correctly to be
|
||||
// longer than the longest expected individual test run time.
|
||||
DefaultPostgresFixtureLifetime = 10 * time.Minute
|
||||
)
|
||||
|
||||
// PostgresConfig holds the postgres database configuration.
|
||||
//
|
||||
//nolint:lll
|
||||
type PostgresConfig struct {
|
||||
SkipMigrations bool `long:"skipmigrations" description:"Skip applying migrations on startup."`
|
||||
Host string `long:"host" description:"Database server hostname."`
|
||||
Port int `long:"port" description:"Database server port."`
|
||||
User string `long:"user" description:"Database user."`
|
||||
Password string `long:"password" description:"Database user's password."`
|
||||
DBName string `long:"dbname" description:"Database name to use."`
|
||||
MaxOpenConnections int `long:"maxconnections" description:"Max open connections to keep alive to the database server."`
|
||||
RequireSSL bool `long:"requiressl" description:"Whether to require using SSL (mode: require) when connecting to the server."`
|
||||
}
|
||||
|
||||
// DSN returns the dns to connect to the database.
|
||||
func (s *PostgresConfig) DSN(hidePassword bool) string {
|
||||
var sslMode = "disable"
|
||||
if s.RequireSSL {
|
||||
sslMode = "require"
|
||||
}
|
||||
|
||||
password := s.Password
|
||||
if hidePassword {
|
||||
// Placeholder used for logging the DSN safely.
|
||||
password = "****"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(dsnTemplate, s.User, password, s.Host, s.Port,
|
||||
s.DBName, sslMode)
|
||||
}
|
||||
|
||||
// PostgresStore is a database store implementation that uses a Postgres
|
||||
// backend.
|
||||
type PostgresStore struct {
|
||||
cfg *PostgresConfig
|
||||
|
||||
*BaseDB
|
||||
}
|
||||
|
||||
// NewPostgresStore creates a new store that is backed by a Postgres database
|
||||
// backend.
|
||||
func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
|
||||
log.Infof("Using SQL database '%s'", cfg.DSN(true))
|
||||
|
||||
rawDB, err := sql.Open("pgx", cfg.DSN(false))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxConns := defaultMaxConns
|
||||
if cfg.MaxOpenConnections > 0 {
|
||||
maxConns = cfg.MaxOpenConnections
|
||||
}
|
||||
|
||||
rawDB.SetMaxOpenConns(maxConns)
|
||||
rawDB.SetMaxIdleConns(maxConns)
|
||||
rawDB.SetConnMaxLifetime(connIdleLifetime)
|
||||
|
||||
if !cfg.SkipMigrations {
|
||||
// Now that the database is open, populate the database with
|
||||
// our set of schemas based on our embedded in-memory file
|
||||
// system.
|
||||
//
|
||||
// First, we'll need to open up a new migration instance for
|
||||
// our current target database: sqlite.
|
||||
driver, err := postgres_migrate.WithInstance(
|
||||
rawDB, &postgres_migrate.Config{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
postgresFS := newReplacerFS(sqlSchemas, map[string]string{
|
||||
"BLOB": "BYTEA",
|
||||
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
|
||||
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
|
||||
})
|
||||
|
||||
err = applyMigrations(
|
||||
postgresFS, driver, "sqlc/migrations", cfg.DBName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
queries := sqlc.New(rawDB)
|
||||
|
||||
return &PostgresStore{
|
||||
cfg: cfg,
|
||||
BaseDB: &BaseDB{
|
||||
DB: rawDB,
|
||||
Queries: queries,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewTestPostgresDB is a helper function that creates a Postgres database for
|
||||
// testing.
|
||||
func NewTestPostgresDB(t *testing.T) *PostgresStore {
|
||||
t.Helper()
|
||||
|
||||
t.Logf("Creating new Postgres DB for testing")
|
||||
|
||||
sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime)
|
||||
store, err := NewPostgresStore(sqlFixture.GetConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
sqlFixture.TearDown(t)
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
140
sqldb/postgres_fixture.go
Normal file
140
sqldb/postgres_fixture.go
Normal file
@ -0,0 +1,140 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq" // Import the postgres driver.
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
testPgUser = "test"
|
||||
testPgPass = "test"
|
||||
testPgDBName = "test"
|
||||
PostgresTag = "11"
|
||||
)
|
||||
|
||||
// TestPgFixture is a test fixture that starts a Postgres 11 instance in a
|
||||
// docker container.
|
||||
type TestPgFixture struct {
|
||||
db *sql.DB
|
||||
pool *dockertest.Pool
|
||||
resource *dockertest.Resource
|
||||
host string
|
||||
port int
|
||||
}
|
||||
|
||||
// NewTestPgFixture constructs a new TestPgFixture starting up a docker
|
||||
// container running Postgres 11. The started container will expire in after
|
||||
// the passed duration.
|
||||
func NewTestPgFixture(t *testing.T, expiry time.Duration) *TestPgFixture {
|
||||
// Use a sensible default on Windows (tcp/http) and linux/osx (socket)
|
||||
// by specifying an empty endpoint.
|
||||
pool, err := dockertest.NewPool("")
|
||||
require.NoError(t, err, "Could not connect to docker")
|
||||
|
||||
// Pulls an image, creates a container based on it and runs it.
|
||||
resource, err := pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: "postgres",
|
||||
Tag: PostgresTag,
|
||||
Env: []string{
|
||||
fmt.Sprintf("POSTGRES_USER=%v", testPgUser),
|
||||
fmt.Sprintf("POSTGRES_PASSWORD=%v", testPgPass),
|
||||
fmt.Sprintf("POSTGRES_DB=%v", testPgDBName),
|
||||
"listen_addresses='*'",
|
||||
},
|
||||
Cmd: []string{
|
||||
"postgres",
|
||||
"-c", "log_statement=all",
|
||||
"-c", "log_destination=stderr",
|
||||
},
|
||||
}, func(config *docker.HostConfig) {
|
||||
// Set AutoRemove to true so that stopped container goes away
|
||||
// by itself.
|
||||
config.AutoRemove = true
|
||||
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
|
||||
})
|
||||
require.NoError(t, err, "Could not start resource")
|
||||
|
||||
hostAndPort := resource.GetHostPort("5432/tcp")
|
||||
parts := strings.Split(hostAndPort, ":")
|
||||
host := parts[0]
|
||||
port, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
require.NoError(t, err)
|
||||
|
||||
fixture := &TestPgFixture{
|
||||
host: host,
|
||||
port: int(port),
|
||||
}
|
||||
databaseURL := fixture.GetDSN()
|
||||
log.Infof("Connecting to Postgres fixture: %v\n", databaseURL)
|
||||
|
||||
// Tell docker to hard kill the container in "expiry" seconds.
|
||||
require.NoError(t, resource.Expire(uint(expiry.Seconds())))
|
||||
|
||||
// Exponential backoff-retry, because the application in the container
|
||||
// might not be ready to accept connections yet.
|
||||
pool.MaxWait = 120 * time.Second
|
||||
|
||||
var testDB *sql.DB
|
||||
err = pool.Retry(func() error {
|
||||
testDB, err = sql.Open("postgres", databaseURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return testDB.Ping()
|
||||
})
|
||||
require.NoError(t, err, "Could not connect to docker")
|
||||
|
||||
// Now fill in the rest of the fixture.
|
||||
fixture.db = testDB
|
||||
fixture.pool = pool
|
||||
fixture.resource = resource
|
||||
|
||||
return fixture
|
||||
}
|
||||
|
||||
// GetDSN returns the DSN (Data Source Name) for the started Postgres node.
|
||||
func (f *TestPgFixture) GetDSN() string {
|
||||
return f.GetConfig().DSN(false)
|
||||
}
|
||||
|
||||
// GetConfig returns the full config of the Postgres node.
|
||||
func (f *TestPgFixture) GetConfig() *PostgresConfig {
|
||||
return &PostgresConfig{
|
||||
Host: f.host,
|
||||
Port: f.port,
|
||||
User: testPgUser,
|
||||
Password: testPgPass,
|
||||
DBName: testPgDBName,
|
||||
RequireSSL: false,
|
||||
}
|
||||
}
|
||||
|
||||
// TearDown stops the underlying docker container.
|
||||
func (f *TestPgFixture) TearDown(t *testing.T) {
|
||||
err := f.pool.Purge(f.resource)
|
||||
require.NoError(t, err, "Could not purge resource")
|
||||
}
|
||||
|
||||
// ClearDB clears the database.
|
||||
func (f *TestPgFixture) ClearDB(t *testing.T) {
|
||||
dbConn, err := sql.Open("postgres", f.GetDSN())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = dbConn.ExecContext(
|
||||
context.Background(),
|
||||
`DROP SCHEMA IF EXISTS public CASCADE;
|
||||
CREATE SCHEMA public;`,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
13
sqldb/postgres_test.go
Normal file
13
sqldb/postgres_test.go
Normal file
@ -0,0 +1,13 @@
|
||||
//go:build test_db_postgres
|
||||
// +build test_db_postgres
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// NewTestDB is a helper function that creates a Postgres database for testing.
|
||||
func NewTestDB(t *testing.T) *PostgresStore {
|
||||
return NewTestPostgresDB(t)
|
||||
}
|
8
sqldb/schemas.go
Normal file
8
sqldb/schemas.go
Normal file
@ -0,0 +1,8 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed sqlc/migrations/*.up.sql
|
||||
var sqlSchemas embed.FS
|
113
sqldb/sqlerrors.go
Normal file
113
sqldb/sqlerrors.go
Normal file
@ -0,0 +1,113 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgerrcode"
|
||||
"modernc.org/sqlite"
|
||||
sqlite3 "modernc.org/sqlite/lib"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrRetriesExceeded is returned when a transaction is retried more
|
||||
// than the max allowed valued without a success.
|
||||
ErrRetriesExceeded = errors.New("db tx retries exceeded")
|
||||
)
|
||||
|
||||
// MapSQLError attempts to interpret a given error as a database agnostic SQL
|
||||
// error.
|
||||
func MapSQLError(err error) error {
|
||||
// Attempt to interpret the error as a sqlite error.
|
||||
var sqliteErr *sqlite.Error
|
||||
if errors.As(err, &sqliteErr) {
|
||||
return parseSqliteError(sqliteErr)
|
||||
}
|
||||
|
||||
// Attempt to interpret the error as a postgres error.
|
||||
var pqErr *pgconn.PgError
|
||||
if errors.As(err, &pqErr) {
|
||||
return parsePostgresError(pqErr)
|
||||
}
|
||||
|
||||
// Return original error if it could not be classified as a database
|
||||
// specific error.
|
||||
return err
|
||||
}
|
||||
|
||||
// parsePostgresError attempts to parse a sqlite error as a database agnostic
|
||||
// SQL error.
|
||||
func parseSqliteError(sqliteErr *sqlite.Error) error {
|
||||
switch sqliteErr.Code() {
|
||||
// Handle unique constraint violation error.
|
||||
case sqlite3.SQLITE_CONSTRAINT_UNIQUE:
|
||||
return &ErrSQLUniqueConstraintViolation{
|
||||
DBError: sqliteErr,
|
||||
}
|
||||
|
||||
// Database is currently busy, so we'll need to try again.
|
||||
case sqlite3.SQLITE_BUSY:
|
||||
return &ErrSerializationError{
|
||||
DBError: sqliteErr,
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown sqlite error: %w", sqliteErr)
|
||||
}
|
||||
}
|
||||
|
||||
// parsePostgresError attempts to parse a postgres error as a database agnostic
|
||||
// SQL error.
|
||||
func parsePostgresError(pqErr *pgconn.PgError) error {
|
||||
switch pqErr.Code {
|
||||
// Handle unique constraint violation error.
|
||||
case pgerrcode.UniqueViolation:
|
||||
return &ErrSQLUniqueConstraintViolation{
|
||||
DBError: pqErr,
|
||||
}
|
||||
|
||||
// Unable to serialize the transaction, so we'll need to try again.
|
||||
case pgerrcode.SerializationFailure:
|
||||
return &ErrSerializationError{
|
||||
DBError: pqErr,
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown postgres error: %w", pqErr)
|
||||
}
|
||||
}
|
||||
|
||||
// ErrSQLUniqueConstraintViolation is an error type which represents a database
|
||||
// agnostic SQL unique constraint violation.
|
||||
type ErrSQLUniqueConstraintViolation struct {
|
||||
DBError error
|
||||
}
|
||||
|
||||
func (e ErrSQLUniqueConstraintViolation) Error() string {
|
||||
return fmt.Sprintf("sql unique constraint violation: %v", e.DBError)
|
||||
}
|
||||
|
||||
// ErrSerializationError is an error type which represents a database agnostic
|
||||
// error that a transaction couldn't be serialized with other concurrent db
|
||||
// transactions.
|
||||
type ErrSerializationError struct {
|
||||
DBError error
|
||||
}
|
||||
|
||||
// Unwrap returns the wrapped error.
|
||||
func (e ErrSerializationError) Unwrap() error {
|
||||
return e.DBError
|
||||
}
|
||||
|
||||
// Error returns the error message.
|
||||
func (e ErrSerializationError) Error() string {
|
||||
return e.DBError.Error()
|
||||
}
|
||||
|
||||
// IsSerializationError returns true if the given error is a serialization
|
||||
// error.
|
||||
func IsSerializationError(err error) bool {
|
||||
var serializationError *ErrSerializationError
|
||||
return errors.As(err, &serializationError)
|
||||
}
|
176
sqldb/sqlite.go
Normal file
176
sqldb/sqlite.go
Normal file
@ -0,0 +1,176 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite"
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
"github.com/stretchr/testify/require"
|
||||
_ "modernc.org/sqlite" // Register relevant drivers.
|
||||
)
|
||||
|
||||
const (
|
||||
// sqliteOptionPrefix is the string prefix sqlite uses to set various
|
||||
// options. This is used in the following format:
|
||||
// * sqliteOptionPrefix || option_name = option_value.
|
||||
sqliteOptionPrefix = "_pragma"
|
||||
|
||||
// sqliteTxLockImmediate is a dsn option used to ensure that write
|
||||
// transactions are started immediately.
|
||||
sqliteTxLockImmediate = "_txlock=immediate"
|
||||
|
||||
// defaultMaxConns is the number of permitted active and idle
|
||||
// connections. We want to limit this so it isn't unlimited. We use the
|
||||
// same value for the number of idle connections as, this can speed up
|
||||
// queries given a new connection doesn't need to be established each
|
||||
// time.
|
||||
defaultMaxConns = 25
|
||||
|
||||
// connIdleLifetime is the amount of time a connection can be idle.
|
||||
connIdleLifetime = 5 * time.Minute
|
||||
)
|
||||
|
||||
// SqliteConfig holds all the config arguments needed to interact with our
|
||||
// sqlite DB.
|
||||
//
|
||||
//nolint:lll
|
||||
type SqliteConfig struct {
|
||||
// SkipMigrations if true, then all the tables will be created on start
|
||||
// up if they don't already exist.
|
||||
SkipMigrations bool `long:"skipmigrations" description:"Skip applying migrations on startup."`
|
||||
|
||||
// DatabaseFileName is the full file path where the database file can be
|
||||
// found.
|
||||
DatabaseFileName string `long:"dbfile" description:"The full path to the database."`
|
||||
}
|
||||
|
||||
// SqliteStore is a database store implementation that uses a sqlite backend.
|
||||
type SqliteStore struct {
|
||||
cfg *SqliteConfig
|
||||
|
||||
*BaseDB
|
||||
}
|
||||
|
||||
// NewSqliteStore attempts to open a new sqlite database based on the passed
|
||||
// config.
|
||||
func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) {
|
||||
// The set of pragma options are accepted using query options. For now
|
||||
// we only want to ensure that foreign key constraints are properly
|
||||
// enforced.
|
||||
pragmaOptions := []struct {
|
||||
name string
|
||||
value string
|
||||
}{
|
||||
{
|
||||
name: "foreign_keys",
|
||||
value: "on",
|
||||
},
|
||||
{
|
||||
name: "journal_mode",
|
||||
value: "WAL",
|
||||
},
|
||||
{
|
||||
name: "busy_timeout",
|
||||
value: "5000",
|
||||
},
|
||||
{
|
||||
// With the WAL mode, this ensures that we also do an
|
||||
// extra WAL sync after each transaction. The normal
|
||||
// sync mode skips this and gives better performance,
|
||||
// but risks durability.
|
||||
name: "synchronous",
|
||||
value: "full",
|
||||
},
|
||||
{
|
||||
// This is used to ensure proper durability for users
|
||||
// running on Mac OS. It uses the correct fsync system
|
||||
// call to ensure items are fully flushed to disk.
|
||||
name: "fullfsync",
|
||||
value: "true",
|
||||
},
|
||||
}
|
||||
sqliteOptions := make(url.Values)
|
||||
for _, option := range pragmaOptions {
|
||||
sqliteOptions.Add(
|
||||
sqliteOptionPrefix,
|
||||
fmt.Sprintf("%v=%v", option.name, option.value),
|
||||
)
|
||||
}
|
||||
|
||||
// Construct the DSN which is just the database file name, appended
|
||||
// with the series of pragma options as a query URL string. For more
|
||||
// details on the formatting here, see the modernc.org/sqlite docs:
|
||||
// https://pkg.go.dev/modernc.org/sqlite#Driver.Open.
|
||||
dsn := fmt.Sprintf(
|
||||
"%v?%v&%v", cfg.DatabaseFileName, sqliteOptions.Encode(),
|
||||
sqliteTxLockImmediate,
|
||||
)
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(defaultMaxConns)
|
||||
db.SetMaxIdleConns(defaultMaxConns)
|
||||
db.SetConnMaxLifetime(connIdleLifetime)
|
||||
|
||||
if !cfg.SkipMigrations {
|
||||
// Now that the database is open, populate the database with
|
||||
// our set of schemas based on our embedded in-memory file
|
||||
// system.
|
||||
//
|
||||
// First, we'll need to open up a new migration instance for
|
||||
// our current target database: sqlite.
|
||||
driver, err := sqlite_migrate.WithInstance(
|
||||
db, &sqlite_migrate.Config{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = applyMigrations(
|
||||
sqlSchemas, driver, "sqlc/migrations", "sqlc",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
queries := sqlc.New(db)
|
||||
|
||||
return &SqliteStore{
|
||||
cfg: cfg,
|
||||
BaseDB: &BaseDB{
|
||||
DB: db,
|
||||
Queries: queries,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewTestSqliteDB is a helper function that creates an SQLite database for
|
||||
// testing.
|
||||
func NewTestSqliteDB(t *testing.T) *SqliteStore {
|
||||
t.Helper()
|
||||
|
||||
t.Logf("Creating new SQLite DB for testing")
|
||||
|
||||
// TODO(roasbeef): if we pass :memory: for the file name, then we get
|
||||
// an in mem version to speed up tests
|
||||
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
||||
sqlDB, err := NewSqliteStore(&SqliteConfig{
|
||||
DatabaseFileName: dbFileName,
|
||||
SkipMigrations: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, sqlDB.DB.Close())
|
||||
})
|
||||
|
||||
return sqlDB
|
||||
}
|
13
sqldb/sqlite_test.go
Normal file
13
sqldb/sqlite_test.go
Normal file
@ -0,0 +1,13 @@
|
||||
//go:build !test_db_postgres
|
||||
// +build !test_db_postgres
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// NewTestDB is a helper function that creates an SQLite database for testing.
|
||||
func NewTestDB(t *testing.T) *SqliteStore {
|
||||
return NewTestSqliteDB(t)
|
||||
}
|
Reference in New Issue
Block a user