kvdb: add postgres

This commit is contained in:
Joost Jager
2021-07-13 12:07:48 +02:00
parent 9264185f5b
commit 3eb80cac97
19 changed files with 2055 additions and 10 deletions

9
kvdb/postgres/config.go Normal file
View File

@@ -0,0 +1,9 @@
package postgres
import "time"
// Config holds postgres configuration data.
type Config struct {
Dsn string `long:"dsn" description:"Database connection string."`
Timeout time.Duration `long:"timeout" description:"Database connection timeout. Set to zero to disable."`
}

241
kvdb/postgres/db.go Normal file
View File

@@ -0,0 +1,241 @@
// +build kvdb_postgres
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/btcsuite/btcwallet/walletdb"
_ "github.com/jackc/pgx/v4/stdlib"
)
const (
// kvTableName is the name of the table that will contain all the kv
// pairs.
kvTableName = "kv"
)
// KV stores a key/value pair.
type KV struct {
key string
val string
}
// db holds a reference to the postgres connection connection.
type db struct {
// cfg is the postgres connection config.
cfg *Config
// prefix is the table name prefix that is used to simulate namespaces.
// We don't use schemas because at least sqlite does not support that.
prefix string
// ctx is the overall context for the database driver.
//
// TODO: This is an anti-pattern that is in place until the kvdb
// interface supports a context.
ctx context.Context
// db is the underlying database connection instance.
db *sql.DB
// lock is the global write lock that ensures single writer.
lock sync.RWMutex
// table is the name of the table that contains the data for all
// top-level buckets that have keys that cannot be mapped to a distinct
// sql table.
table string
}
// Enforce db implements the walletdb.DB interface.
var _ walletdb.DB = (*db)(nil)
// newPostgresBackend returns a db object initialized with the passed backend
// config. If postgres connection cannot be estabished, then returns error.
func newPostgresBackend(ctx context.Context, config *Config, prefix string) (
*db, error) {
if prefix == "" {
return nil, errors.New("empty postgres prefix")
}
dbConn, err := sql.Open("pgx", config.Dsn)
if err != nil {
return nil, err
}
// Compose system table names.
table := fmt.Sprintf(
"%s_%s", prefix, kvTableName,
)
// Execute the create statements to set up a kv table in postgres. Every
// row points to the bucket that it is one via its parent_id field. A
// NULL parent_id means that the key belongs to the upper-most bucket in
// this table. A constraint on parent_id is enforcing referential
// integrity.
//
// Furthermore there is a <table>_p index on parent_id that is required
// for the foreign key constraint.
//
// Finally there are unique indices on (parent_id, key) to prevent the
// same key being present in a bucket more than once (<table>_up and
// <table>_unp). In postgres, a single index wouldn't enforce the unique
// constraint on rows with a NULL parent_id. Therefore two indices are
// defined.
_, err = dbConn.ExecContext(ctx, `
CREATE SCHEMA IF NOT EXISTS public;
CREATE TABLE IF NOT EXISTS public.`+table+`
(
key bytea NOT NULL,
value bytea,
parent_id bigint,
id bigserial PRIMARY KEY,
sequence bigint,
CONSTRAINT `+table+`_parent FOREIGN KEY (parent_id)
REFERENCES public.`+table+` (id)
ON UPDATE NO ACTION
ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS `+table+`_p
ON public.`+table+` (parent_id);
CREATE UNIQUE INDEX IF NOT EXISTS `+table+`_up
ON public.`+table+`
(parent_id, key) WHERE parent_id IS NOT NULL;
CREATE UNIQUE INDEX IF NOT EXISTS `+table+`_unp
ON public.`+table+` (key) WHERE parent_id IS NULL;
`)
if err != nil {
_ = dbConn.Close()
return nil, err
}
backend := &db{
cfg: config,
prefix: prefix,
ctx: ctx,
db: dbConn,
table: table,
}
return backend, nil
}
// getTimeoutCtx gets a timeout context for database requests.
func (db *db) getTimeoutCtx() (context.Context, func()) {
if db.cfg.Timeout == time.Duration(0) {
return db.ctx, func() {}
}
return context.WithTimeout(db.ctx, db.cfg.Timeout)
}
// getPrefixedTableName returns a table name for this prefix (namespace).
func (db *db) getPrefixedTableName(table string) string {
return fmt.Sprintf("%s_%s", db.prefix, table)
}
// catchPanic executes the specified function. If a panic occurs, it is returned
// as an error value.
func catchPanic(f func() error) (err error) {
defer func() {
if r := recover(); r != nil {
err = r.(error)
log.Criticalf("Caught unhandled error: %v", err)
}
}()
err = f()
return
}
// View opens a database read transaction and executes the function f with the
// transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any
// occur). The passed reset function is called before the start of the
// transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error {
return db.executeTransaction(
func(tx walletdb.ReadWriteTx) error {
return f(tx.(walletdb.ReadTx))
},
reset, true,
)
}
// Update opens a database read/write transaction and executes the function f
// with the transaction passed as a parameter. After f exits, if f did not
// error, the transaction is committed. Otherwise, if f did error, the
// transaction is rolled back. If the rollback fails, the original error
// returned by f is still returned. If the commit fails, the commit error is
// returned. As callers may expect retries of the f closure, the reset function
// will be called before each retry respectively.
func (db *db) Update(f func(tx walletdb.ReadWriteTx) error, reset func()) (err error) {
return db.executeTransaction(f, reset, false)
}
// executeTransaction creates a new read-only or read-write transaction and
// executes the given function within it.
func (db *db) executeTransaction(f func(tx walletdb.ReadWriteTx) error,
reset func(), readOnly bool) error {
reset()
tx, err := newReadWriteTx(db, readOnly)
if err != nil {
return err
}
err = catchPanic(func() error { return f(tx) })
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
log.Errorf("Error rolling back tx: %v", rollbackErr)
}
return err
}
return tx.Commit()
}
// PrintStats returns all collected stats pretty printed into a string.
func (db *db) PrintStats() string {
return "stats not supported by Postgres driver"
}
// BeginReadWriteTx opens a database read+write transaction.
func (db *db) BeginReadWriteTx() (walletdb.ReadWriteTx, error) {
return newReadWriteTx(db, false)
}
// BeginReadTx opens a database read transaction.
func (db *db) BeginReadTx() (walletdb.ReadTx, error) {
return newReadWriteTx(db, true)
}
// Copy writes a copy of the database to the provided writer. This call will
// start a read-only transaction to perform all operations.
// This function is part of the walletdb.Db interface implementation.
func (db *db) Copy(w io.Writer) error {
return errors.New("not implemented")
}
// Close cleanly shuts down the database and syncs all data.
// This function is part of the walletdb.Db interface implementation.
func (db *db) Close() error {
return db.db.Close()
}

56
kvdb/postgres/db_test.go Normal file
View File

@@ -0,0 +1,56 @@
// +build kvdb_postgres
package postgres
import (
"testing"
"time"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/btcsuite/btcwallet/walletdb/walletdbtest"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
)
// TestInterface performs all interfaces tests for this database driver.
func TestInterface(t *testing.T) {
f := NewFixture(t)
defer f.Cleanup()
// dbType is the database type name for this driver.
const dbType = "postgres"
ctx := context.Background()
cfg := &Config{
Dsn: testDsn,
}
walletdbtest.TestInterface(t, dbType, ctx, cfg, prefix)
}
// TestPanic tests recovery from panic conditions.
func TestPanic(t *testing.T) {
f := NewFixture(t)
defer f.Cleanup()
d := f.NewBackend()
err := d.(*db).Update(func(tx walletdb.ReadWriteTx) error {
bucket, err := tx.CreateTopLevelBucket([]byte("test"))
require.NoError(t, err)
// Stop database server.
f.Cleanup()
// Keep trying to get data until Get panics because the
// connection is lost.
for i := 0; i < 50; i++ {
bucket.Get([]byte("key"))
time.Sleep(100 * time.Millisecond)
}
return nil
}, func() {})
require.Contains(t, err.Error(), "terminating connection")
}

86
kvdb/postgres/driver.go Normal file
View File

@@ -0,0 +1,86 @@
// +build kvdb_postgres
package postgres
import (
"context"
"fmt"
"github.com/btcsuite/btcwallet/walletdb"
)
const (
dbType = "postgres"
)
// parseArgs parses the arguments from the walletdb Open/Create methods.
func parseArgs(funcName string, args ...interface{}) (context.Context,
*Config, string, error) {
if len(args) != 3 {
return nil, nil, "", fmt.Errorf("invalid number of arguments "+
"to %s.%s -- expected: context.Context, "+
"postgres.Config, string", dbType, funcName,
)
}
ctx, ok := args[0].(context.Context)
if !ok {
return nil, nil, "", fmt.Errorf("argument 0 to %s.%s is "+
"invalid -- expected: context.Context",
dbType, funcName,
)
}
config, ok := args[1].(*Config)
if !ok {
return nil, nil, "", fmt.Errorf("argument 1 to %s.%s is "+
"invalid -- expected: postgres.Config",
dbType, funcName,
)
}
prefix, ok := args[2].(string)
if !ok {
return nil, nil, "", fmt.Errorf("argument 2 to %s.%s is "+
"invalid -- expected string", dbType,
funcName)
}
return ctx, config, prefix, nil
}
// createDBDriver is the callback provided during driver registration that
// creates, initializes, and opens a database for use.
func createDBDriver(args ...interface{}) (walletdb.DB, error) {
ctx, config, prefix, err := parseArgs("Create", args...)
if err != nil {
return nil, err
}
return newPostgresBackend(ctx, config, prefix)
}
// openDBDriver is the callback provided during driver registration that opens
// an existing database for use.
func openDBDriver(args ...interface{}) (walletdb.DB, error) {
ctx, config, prefix, err := parseArgs("Open", args...)
if err != nil {
return nil, err
}
return newPostgresBackend(ctx, config, prefix)
}
func init() {
// Register the driver.
driver := walletdb.Driver{
DbType: dbType,
Create: createDBDriver,
Open: openDBDriver,
}
if err := walletdb.RegisterDriver(driver); err != nil {
panic(fmt.Sprintf("Failed to regiser database driver '%s': %v",
dbType, err))
}
}

131
kvdb/postgres/fixture.go Normal file
View File

@@ -0,0 +1,131 @@
// +build kvdb_postgres
package postgres
import (
"context"
"database/sql"
"testing"
"github.com/btcsuite/btcwallet/walletdb"
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
"github.com/stretchr/testify/require"
)
const (
testDsn = "postgres://postgres:postgres@localhost:9876/postgres?sslmode=disable"
prefix = "test"
)
func clearTestDb(t *testing.T) {
dbConn, err := sql.Open("pgx", testDsn)
require.NoError(t, err)
_, err = dbConn.ExecContext(context.Background(), "DROP SCHEMA IF EXISTS public CASCADE;")
require.NoError(t, err)
}
func openTestDb(t *testing.T) *db {
clearTestDb(t)
db, err := newPostgresBackend(
context.Background(),
&Config{
Dsn: testDsn,
},
prefix,
)
require.NoError(t, err)
return db
}
type fixture struct {
t *testing.T
tempDir string
postgres *embeddedpostgres.EmbeddedPostgres
}
func NewFixture(t *testing.T) *fixture {
postgres := embeddedpostgres.NewDatabase(
embeddedpostgres.DefaultConfig().
Port(9876))
err := postgres.Start()
require.NoError(t, err)
return &fixture{
t: t,
postgres: postgres,
}
}
func (b *fixture) Cleanup() {
b.postgres.Stop()
}
func (b *fixture) NewBackend() walletdb.DB {
clearTestDb(b.t)
db := openTestDb(b.t)
return db
}
func (b *fixture) Dump() map[string]interface{} {
dbConn, err := sql.Open("pgx", testDsn)
require.NoError(b.t, err)
rows, err := dbConn.Query(
"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'",
)
require.NoError(b.t, err)
var tables []string
for rows.Next() {
var table string
err := rows.Scan(&table)
require.NoError(b.t, err)
tables = append(tables, table)
}
result := make(map[string]interface{})
for _, table := range tables {
rows, err := dbConn.Query("SELECT * FROM " + table)
require.NoError(b.t, err)
cols, err := rows.Columns()
require.NoError(b.t, err)
colCount := len(cols)
var tableRows []map[string]interface{}
for rows.Next() {
values := make([]interface{}, colCount)
valuePtrs := make([]interface{}, colCount)
for i := range values {
valuePtrs[i] = &values[i]
}
err := rows.Scan(valuePtrs...)
require.NoError(b.t, err)
tableData := make(map[string]interface{})
for i, v := range values {
// Cast byte slices to string to keep the
// expected database contents in test code more
// readable.
if ar, ok := v.([]uint8); ok {
v = string(ar)
}
tableData[cols[i]] = v
}
tableRows = append(tableRows, tableData)
}
result[table] = tableRows
}
return result
}

12
kvdb/postgres/log.go Normal file
View File

@@ -0,0 +1,12 @@
package postgres
import "github.com/btcsuite/btclog"
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}

View File

@@ -0,0 +1,412 @@
// +build kvdb_postgres
package postgres
import (
"database/sql"
"errors"
"fmt"
"github.com/btcsuite/btcwallet/walletdb"
)
// readWriteBucket stores the bucket id and the buckets transaction.
type readWriteBucket struct {
// id is used to identify the bucket. If id is null, it refers to the
// root bucket.
id *int64
// tx holds the parent transaction.
tx *readWriteTx
table string
}
// newReadWriteBucket creates a new rw bucket with the passed transaction
// and bucket id.
func newReadWriteBucket(tx *readWriteTx, id *int64) *readWriteBucket {
return &readWriteBucket{
id: id,
tx: tx,
table: tx.db.table,
}
}
// NestedReadBucket retrieves a nested read bucket with the given key.
// Returns nil if the bucket does not exist.
func (b *readWriteBucket) NestedReadBucket(key []byte) walletdb.ReadBucket {
return b.NestedReadWriteBucket(key)
}
func parentSelector(id *int64) string {
if id == nil {
return "parent_id IS NULL"
}
return fmt.Sprintf("parent_id=%v", *id)
}
// ForEach invokes the passed function with every key/value pair in
// the bucket. This includes nested buckets, in which case the value
// is nil, but it does not include the key/value pairs within those
// nested buckets.
func (b *readWriteBucket) ForEach(cb func(k, v []byte) error) error {
cursor := b.ReadWriteCursor()
k, v := cursor.First()
for k != nil {
err := cb(k, v)
if err != nil {
return err
}
k, v = cursor.Next()
}
return nil
}
// Get returns the value for the given key. Returns nil if the key does
// not exist in this bucket.
func (b *readWriteBucket) Get(key []byte) []byte {
// Return nil if the key is empty.
if len(key) == 0 {
return nil
}
var value *[]byte
err := b.tx.QueryRow(
"SELECT value FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1", key,
).Scan(&value)
switch {
case err == sql.ErrNoRows:
return nil
case err != nil:
panic(err)
}
return *value
}
// ReadCursor returns a new read-only cursor for this bucket.
func (b *readWriteBucket) ReadCursor() walletdb.ReadCursor {
return newReadWriteCursor(b)
}
// NestedReadWriteBucket retrieves a nested bucket with the given key.
// Returns nil if the bucket does not exist.
func (b *readWriteBucket) NestedReadWriteBucket(
key []byte) walletdb.ReadWriteBucket {
if len(key) == 0 {
return nil
}
var id int64
err := b.tx.QueryRow(
"SELECT id FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1 AND value IS NULL", key,
).Scan(&id)
switch {
case err == sql.ErrNoRows:
return nil
case err != nil:
panic(err)
}
return newReadWriteBucket(b.tx, &id)
}
// CreateBucket creates and returns a new nested bucket with the given key.
// Returns ErrBucketExists if the bucket already exists, ErrBucketNameRequired
// if the key is empty, or ErrIncompatibleValue if the key value is otherwise
// invalid for the particular database implementation. Other errors are
// possible depending on the implementation.
func (b *readWriteBucket) CreateBucket(key []byte) (
walletdb.ReadWriteBucket, error) {
if len(key) == 0 {
return nil, walletdb.ErrBucketNameRequired
}
// Check to see if the bucket already exists.
var (
value *[]byte
id int64
)
err := b.tx.QueryRow(
"SELECT id,value FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1", key,
).Scan(&id, &value)
switch {
case err == sql.ErrNoRows:
case err == nil && value == nil:
return nil, walletdb.ErrBucketExists
case err == nil && value != nil:
return nil, walletdb.ErrIncompatibleValue
case err != nil:
return nil, err
}
// Bucket does not yet exist, so create it. Postgres will generate a
// bucket id for the new bucket.
err = b.tx.QueryRow(
"INSERT INTO "+b.table+" (parent_id, key) "+
"VALUES($1, $2) RETURNING id", b.id, key,
).Scan(&id)
if err != nil {
return nil, err
}
return newReadWriteBucket(b.tx, &id), nil
}
// CreateBucketIfNotExists creates and returns a new nested bucket with
// the given key if it does not already exist. Returns
// ErrBucketNameRequired if the key is empty or ErrIncompatibleValue
// if the key value is otherwise invalid for the particular database
// backend. Other errors are possible depending on the implementation.
func (b *readWriteBucket) CreateBucketIfNotExists(key []byte) (
walletdb.ReadWriteBucket, error) {
if len(key) == 0 {
return nil, walletdb.ErrBucketNameRequired
}
// Check to see if the bucket already exists.
var (
value *[]byte
id int64
)
err := b.tx.QueryRow(
"SELECT id,value FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1", key,
).Scan(&id, &value)
switch {
// Bucket does not yet exist, so create it now. Postgres will generate a
// bucket id for the new bucket.
case err == sql.ErrNoRows:
err = b.tx.QueryRow(
"INSERT INTO "+b.table+" (parent_id, key) "+
"VALUES($1, $2) RETURNING id", b.id, key).
Scan(&id)
if err != nil {
return nil, err
}
case err == nil && value != nil:
return nil, walletdb.ErrIncompatibleValue
case err != nil:
return nil, err
}
return newReadWriteBucket(b.tx, &id), nil
}
// DeleteNestedBucket deletes the nested bucket and its sub-buckets
// pointed to by the passed key. All values in the bucket and sub-buckets
// will be deleted as well.
func (b *readWriteBucket) DeleteNestedBucket(key []byte) error {
if len(key) == 0 {
return walletdb.ErrIncompatibleValue
}
result, err := b.tx.Exec(
"DELETE FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1 AND value IS NULL",
key,
)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return walletdb.ErrBucketNotFound
}
return nil
}
// Put updates the value for the passed key.
// Returns ErrKeyRequired if te passed key is empty.
func (b *readWriteBucket) Put(key, value []byte) error {
if len(key) == 0 {
return walletdb.ErrKeyRequired
}
// Prevent NULL being written for an empty value slice.
if value == nil {
value = []byte{}
}
var (
result sql.Result
err error
)
// We are putting a value in a bucket in this table. Try to insert the
// key first. If the key already exists (ON CONFLICT), update the key.
// Do not update a NULL value, because this indicates that the key
// contains a sub-bucket. This case will be caught via RowsAffected
// below.
if b.id == nil {
// ON CONFLICT requires the WHERE parent_id IS NULL hint to let
// Postgres find the NULL-parent_id unique index (<table>_unp).
result, err = b.tx.Exec(
"INSERT INTO "+b.table+" (key, value) VALUES($1, $2) "+
"ON CONFLICT (key) WHERE parent_id IS NULL "+
"DO UPDATE SET value=$2 "+
"WHERE "+b.table+".value IS NOT NULL",
key, value,
)
} else {
// ON CONFLICT requires the WHERE parent_id NOT IS NULL hint to
// let Postgres find the non-NULL-parent_id unique index
// (<table>_up).
result, err = b.tx.Exec(
"INSERT INTO "+b.table+" (key, value, parent_id) "+
"VALUES($1, $2, $3) "+
"ON CONFLICT (key, parent_id) "+
"WHERE parent_id IS NOT NULL "+
"DO UPDATE SET value=$2 "+
"WHERE "+b.table+".value IS NOT NULL",
key, value, b.id,
)
}
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows != 1 {
return walletdb.ErrIncompatibleValue
}
return nil
}
// Delete deletes the key/value pointed to by the passed key.
// Returns ErrKeyRequired if the passed key is empty.
func (b *readWriteBucket) Delete(key []byte) error {
if key == nil {
return nil
}
if len(key) == 0 {
return walletdb.ErrKeyRequired
}
// Check to see if a bucket with this key exists.
var dummy int
err := b.tx.QueryRow(
"SELECT 1 FROM "+b.table+" WHERE "+parentSelector(b.id)+
" AND key=$1 AND value IS NULL", key,
).Scan(&dummy)
switch {
// No bucket exists, proceed to deletion of the key.
case err == sql.ErrNoRows:
case err != nil:
return err
// Bucket exists.
default:
return walletdb.ErrIncompatibleValue
}
_, err = b.tx.Exec(
"DELETE FROM "+b.table+" WHERE key=$1 AND "+
parentSelector(b.id)+" AND value IS NOT NULL",
key,
)
if err != nil {
return err
}
return nil
}
// ReadWriteCursor returns a new read-write cursor for this bucket.
func (b *readWriteBucket) ReadWriteCursor() walletdb.ReadWriteCursor {
return newReadWriteCursor(b)
}
// Tx returns the buckets transaction.
func (b *readWriteBucket) Tx() walletdb.ReadWriteTx {
return b.tx
}
// NextSequence returns an autoincrementing sequence number for this bucket.
// Note that this is not a thread safe function and as such it must not be used
// for synchronization.
func (b *readWriteBucket) NextSequence() (uint64, error) {
seq := b.Sequence() + 1
return seq, b.SetSequence(seq)
}
// SetSequence updates the sequence number for the bucket.
func (b *readWriteBucket) SetSequence(v uint64) error {
if b.id == nil {
panic("sequence not supported on top level bucket")
}
result, err := b.tx.Exec(
"UPDATE "+b.table+" SET sequence=$2 WHERE id=$1",
b.id, int64(v),
)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows != 1 {
return errors.New("cannot set sequence")
}
return nil
}
// Sequence returns the current sequence number for this bucket without
// incrementing it.
func (b *readWriteBucket) Sequence() uint64 {
if b.id == nil {
panic("sequence not supported on top level bucket")
}
var seq int64
err := b.tx.QueryRow(
"SELECT sequence FROM "+b.table+" WHERE id=$1 "+
"AND sequence IS NOT NULL",
b.id,
).Scan(&seq)
switch {
case err == sql.ErrNoRows:
return 0
case err != nil:
panic(err)
}
return uint64(seq)
}

View File

@@ -0,0 +1,220 @@
// +build kvdb_postgres
package postgres
import (
"database/sql"
"github.com/btcsuite/btcwallet/walletdb"
)
// readWriteCursor holds a reference to the cursors bucket, the value
// prefix and the current key used while iterating.
type readWriteCursor struct {
bucket *readWriteBucket
// currKey holds the current key of the cursor.
currKey []byte
}
func newReadWriteCursor(b *readWriteBucket) *readWriteCursor {
return &readWriteCursor{
bucket: b,
}
}
// First positions the cursor at the first key/value pair and returns
// the pair.
func (c *readWriteCursor) First() ([]byte, []byte) {
var (
key []byte
value []byte
)
err := c.bucket.tx.QueryRow(
"SELECT key, value FROM "+c.bucket.table+" WHERE "+
parentSelector(c.bucket.id)+
" ORDER BY key LIMIT 1",
).Scan(&key, &value)
switch {
case err == sql.ErrNoRows:
return nil, nil
case err != nil:
panic(err)
}
// Copy current key to prevent modification by the caller.
c.currKey = make([]byte, len(key))
copy(c.currKey, key)
return key, value
}
// Last positions the cursor at the last key/value pair and returns the
// pair.
func (c *readWriteCursor) Last() ([]byte, []byte) {
var (
key []byte
value []byte
)
err := c.bucket.tx.QueryRow(
"SELECT key, value FROM "+c.bucket.table+" WHERE "+
parentSelector(c.bucket.id)+
" ORDER BY key DESC LIMIT 1",
).Scan(&key, &value)
switch {
case err == sql.ErrNoRows:
return nil, nil
case err != nil:
panic(err)
}
// Copy current key to prevent modification by the caller.
c.currKey = make([]byte, len(key))
copy(c.currKey, key)
return key, value
}
// Next moves the cursor one key/value pair forward and returns the new
// pair.
func (c *readWriteCursor) Next() ([]byte, []byte) {
var (
key []byte
value []byte
)
err := c.bucket.tx.QueryRow(
"SELECT key, value FROM "+c.bucket.table+" WHERE "+
parentSelector(c.bucket.id)+
" AND key>$1 ORDER BY key LIMIT 1",
c.currKey,
).Scan(&key, &value)
switch {
case err == sql.ErrNoRows:
return nil, nil
case err != nil:
panic(err)
}
// Copy current key to prevent modification by the caller.
c.currKey = make([]byte, len(key))
copy(c.currKey, key)
return key, value
}
// Prev moves the cursor one key/value pair backward and returns the new
// pair.
func (c *readWriteCursor) Prev() ([]byte, []byte) {
var (
key []byte
value []byte
)
err := c.bucket.tx.QueryRow(
"SELECT key, value FROM "+c.bucket.table+" WHERE "+
parentSelector(c.bucket.id)+
" AND key<$1 ORDER BY key DESC LIMIT 1",
c.currKey,
).Scan(&key, &value)
switch {
case err == sql.ErrNoRows:
return nil, nil
case err != nil:
panic(err)
}
// Copy current key to prevent modification by the caller.
c.currKey = make([]byte, len(key))
copy(c.currKey, key)
return key, value
}
// Seek positions the cursor at the passed seek key. If the key does
// not exist, the cursor is moved to the next key after seek. Returns
// the new pair.
func (c *readWriteCursor) Seek(seek []byte) ([]byte, []byte) {
// Convert nil to empty slice, otherwise sql mapping won't be correct
// and no keys are found.
if seek == nil {
seek = []byte{}
}
var (
key []byte
value []byte
)
err := c.bucket.tx.QueryRow(
"SELECT key, value FROM "+c.bucket.table+" WHERE "+
parentSelector(c.bucket.id)+
" AND key>=$1 ORDER BY key LIMIT 1",
seek,
).Scan(&key, &value)
switch {
case err == sql.ErrNoRows:
return nil, nil
case err != nil:
panic(err)
}
// Copy current key to prevent modification by the caller.
c.currKey = make([]byte, len(key))
copy(c.currKey, key)
return key, value
}
// Delete removes the current key/value pair the cursor is at without
// invalidating the cursor. Returns ErrIncompatibleValue if attempted
// when the cursor points to a nested bucket.
func (c *readWriteCursor) Delete() error {
// Get first record at or after cursor.
var key []byte
err := c.bucket.tx.QueryRow(
"SELECT key FROM "+c.bucket.table+" WHERE "+
parentSelector(c.bucket.id)+
" AND key>=$1 ORDER BY key LIMIT 1",
c.currKey,
).Scan(&key)
switch {
case err == sql.ErrNoRows:
return nil
case err != nil:
panic(err)
}
// Delete record.
result, err := c.bucket.tx.Exec(
"DELETE FROM "+c.bucket.table+" WHERE "+
parentSelector(c.bucket.id)+
" AND key=$1 AND value IS NOT NULL",
key,
)
if err != nil {
panic(err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
// The key exists but nothing has been deleted. This means that the key
// must have been a bucket key.
if rows != 1 {
return walletdb.ErrIncompatibleValue
}
return err
}

View File

@@ -0,0 +1,185 @@
// +build kvdb_postgres
package postgres
import (
"database/sql"
"sync"
"github.com/btcsuite/btcwallet/walletdb"
)
// readWriteTx holds a reference to an open postgres transaction.
type readWriteTx struct {
db *db
tx *sql.Tx
// onCommit gets called upon commit.
onCommit func()
// active is true if the transaction hasn't been committed yet.
active bool
// locker is a pointer to the global db lock.
locker sync.Locker
}
// newReadWriteTx creates an rw transaction using a connection from the
// specified pool.
func newReadWriteTx(db *db, readOnly bool) (*readWriteTx, error) {
// Obtain the global lock instance. An alternative here is to obtain a
// database lock from Postgres. Unfortunately there is no database-level
// lock in Postgres, meaning that each table would need to be locked
// individually. Perhaps an advisory lock could perform this function
// too.
var locker sync.Locker = &db.lock
if readOnly {
locker = db.lock.RLocker()
}
locker.Lock()
tx, err := db.db.Begin()
if err != nil {
locker.Unlock()
return nil, err
}
return &readWriteTx{
db: db,
tx: tx,
active: true,
locker: locker,
}, nil
}
// ReadBucket opens the root bucket for read only access. If the bucket
// described by the key does not exist, nil is returned.
func (tx *readWriteTx) ReadBucket(key []byte) walletdb.ReadBucket {
return tx.ReadWriteBucket(key)
}
// ForEachBucket iterates through all top level buckets.
func (tx *readWriteTx) ForEachBucket(fn func(key []byte) error) error {
// Fetch binary top level buckets.
bucket := newReadWriteBucket(tx, nil)
err := bucket.ForEach(func(k, _ []byte) error {
return fn(k)
})
return err
}
// Rollback closes the transaction, discarding changes (if any) if the
// database was modified by a write transaction.
func (tx *readWriteTx) Rollback() error {
// If the transaction has been closed roolback will fail.
if !tx.active {
return walletdb.ErrTxClosed
}
err := tx.tx.Rollback()
// Unlock the transaction regardless of the error result.
tx.active = false
tx.locker.Unlock()
return err
}
// ReadWriteBucket opens the root bucket for read/write access. If the
// bucket described by the key does not exist, nil is returned.
func (tx *readWriteTx) ReadWriteBucket(key []byte) walletdb.ReadWriteBucket {
if len(key) == 0 {
return nil
}
bucket := newReadWriteBucket(tx, nil)
return bucket.NestedReadWriteBucket(key)
}
// CreateTopLevelBucket creates the top level bucket for a key if it
// does not exist. The newly-created bucket it returned.
func (tx *readWriteTx) CreateTopLevelBucket(key []byte) (walletdb.ReadWriteBucket, error) {
if len(key) == 0 {
return nil, walletdb.ErrBucketNameRequired
}
bucket := newReadWriteBucket(tx, nil)
return bucket.CreateBucketIfNotExists(key)
}
// DeleteTopLevelBucket deletes the top level bucket for a key. This
// errors if the bucket can not be found or the key keys a single value
// instead of a bucket.
func (tx *readWriteTx) DeleteTopLevelBucket(key []byte) error {
// Execute a cascading delete on the key.
result, err := tx.Exec(
"DELETE FROM "+tx.db.table+" WHERE key=$1 "+
"AND parent_id IS NULL",
key,
)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return walletdb.ErrBucketNotFound
}
return nil
}
// Commit commits the transaction if not already committed.
func (tx *readWriteTx) Commit() error {
// Commit will fail if the transaction is already committed.
if !tx.active {
return walletdb.ErrTxClosed
}
// Try committing the transaction.
err := tx.tx.Commit()
if err == nil && tx.onCommit != nil {
tx.onCommit()
}
// Unlock the transaction regardless of the error result.
tx.active = false
tx.locker.Unlock()
return err
}
// OnCommit sets the commit callback (overriding if already set).
func (tx *readWriteTx) OnCommit(cb func()) {
tx.onCommit = cb
}
// QueryRow executes a QueryRow call with a timeout context.
func (tx *readWriteTx) QueryRow(query string, args ...interface{}) *sql.Row {
ctx, cancel := tx.db.getTimeoutCtx()
defer cancel()
return tx.tx.QueryRowContext(ctx, query, args...)
}
// Query executes a Query call with a timeout context.
func (tx *readWriteTx) Query(query string, args ...interface{}) (*sql.Rows,
error) {
ctx, cancel := tx.db.getTimeoutCtx()
defer cancel()
return tx.tx.QueryContext(ctx, query, args...)
}
// Exec executes a Exec call with a timeout context.
func (tx *readWriteTx) Exec(query string, args ...interface{}) (sql.Result,
error) {
ctx, cancel := tx.db.getTimeoutCtx()
defer cancel()
return tx.tx.ExecContext(ctx, query, args...)
}