mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-05 17:05:50 +02:00
kvdb: move channeldb/kvdb to top level
This commit is contained in:
92
kvdb/etcd/bucket.go
Normal file
92
kvdb/etcd/bucket.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
)
|
||||
|
||||
const (
|
||||
bucketIDLength = 32
|
||||
)
|
||||
|
||||
var (
|
||||
valuePostfix = []byte{0x00}
|
||||
bucketPostfix = []byte{0xFF}
|
||||
sequencePrefix = []byte("$seq$")
|
||||
)
|
||||
|
||||
// makeBucketID returns a deterministic key for the passed byte slice.
|
||||
// Currently it returns the sha256 hash of the slice.
|
||||
func makeBucketID(key []byte) [bucketIDLength]byte {
|
||||
return sha256.Sum256(key)
|
||||
}
|
||||
|
||||
// isValidBucketID checks if the passed slice is the required length to be a
|
||||
// valid bucket id.
|
||||
func isValidBucketID(s []byte) bool {
|
||||
return len(s) == bucketIDLength
|
||||
}
|
||||
|
||||
// makeKey concatenates parent, key and postfix into one byte slice.
|
||||
// The postfix indicates the use of this key (whether bucket or value), while
|
||||
// parent refers to the parent bucket.
|
||||
func makeKey(parent, key, postfix []byte) []byte {
|
||||
keyBuf := make([]byte, len(parent)+len(key)+len(postfix))
|
||||
copy(keyBuf, parent)
|
||||
copy(keyBuf[len(parent):], key)
|
||||
copy(keyBuf[len(parent)+len(key):], postfix)
|
||||
|
||||
return keyBuf
|
||||
}
|
||||
|
||||
// makeBucketKey returns a bucket key from the passed parent bucket id and
|
||||
// the key.
|
||||
func makeBucketKey(parent []byte, key []byte) []byte {
|
||||
return makeKey(parent, key, bucketPostfix)
|
||||
}
|
||||
|
||||
// makeValueKey returns a value key from the passed parent bucket id and
|
||||
// the key.
|
||||
func makeValueKey(parent []byte, key []byte) []byte {
|
||||
return makeKey(parent, key, valuePostfix)
|
||||
}
|
||||
|
||||
// makeSequenceKey returns a sequence key of the passed parent bucket id.
|
||||
func makeSequenceKey(parent []byte) []byte {
|
||||
keyBuf := make([]byte, len(sequencePrefix)+len(parent))
|
||||
copy(keyBuf, sequencePrefix)
|
||||
copy(keyBuf[len(sequencePrefix):], parent)
|
||||
return keyBuf
|
||||
}
|
||||
|
||||
// isBucketKey returns true if the passed key is a bucket key, meaning it
|
||||
// keys a bucket name.
|
||||
func isBucketKey(key string) bool {
|
||||
if len(key) < bucketIDLength+1 {
|
||||
return false
|
||||
}
|
||||
|
||||
return key[len(key)-1] == bucketPostfix[0]
|
||||
}
|
||||
|
||||
// getKey chops out the key from the raw key (by removing the bucket id
|
||||
// prefixing the key and the postfix indicating whether it is a bucket or
|
||||
// a value key)
|
||||
func getKey(rawKey string) []byte {
|
||||
return []byte(rawKey[bucketIDLength : len(rawKey)-1])
|
||||
}
|
||||
|
||||
// getKeyVal chops out the key from the raw key (by removing the bucket id
|
||||
// prefixing the key and the postfix indicating whether it is a bucket or
|
||||
// a value key) and also returns the appropriate value for the key, which is
|
||||
// nil in case of buckets (or the set value otherwise).
|
||||
func getKeyVal(kv *KV) ([]byte, []byte) {
|
||||
var val []byte
|
||||
|
||||
if !isBucketKey(kv.key) {
|
||||
val = []byte(kv.val)
|
||||
}
|
||||
|
||||
return getKey(kv.key), val
|
||||
}
|
42
kvdb/etcd/bucket_test.go
Normal file
42
kvdb/etcd/bucket_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
// bkey is a helper functon used in tests to create a bucket key from passed
|
||||
// bucket list.
|
||||
func bkey(buckets ...string) string {
|
||||
var bucketKey []byte
|
||||
|
||||
rootID := makeBucketID([]byte(etcdDefaultRootBucketId))
|
||||
parent := rootID[:]
|
||||
|
||||
for _, bucketName := range buckets {
|
||||
bucketKey = makeBucketKey(parent, []byte(bucketName))
|
||||
id := makeBucketID(bucketKey)
|
||||
parent = id[:]
|
||||
}
|
||||
|
||||
return string(bucketKey)
|
||||
}
|
||||
|
||||
// bval is a helper function used in tests to create a bucket value (the value
|
||||
// for a bucket key) from the passed bucket list.
|
||||
func bval(buckets ...string) string {
|
||||
id := makeBucketID([]byte(bkey(buckets...)))
|
||||
return string(id[:])
|
||||
}
|
||||
|
||||
// vkey is a helper function used in tests to create a value key from the
|
||||
// passed key and bucket list.
|
||||
func vkey(key string, buckets ...string) string {
|
||||
rootID := makeBucketID([]byte(etcdDefaultRootBucketId))
|
||||
bucket := rootID[:]
|
||||
|
||||
for _, bucketName := range buckets {
|
||||
bucketKey := makeBucketKey(bucket, []byte(bucketName))
|
||||
id := makeBucketID(bucketKey)
|
||||
bucket = id[:]
|
||||
}
|
||||
|
||||
return string(makeValueKey(bucket, []byte(key)))
|
||||
}
|
150
kvdb/etcd/commit_queue.go
Normal file
150
kvdb/etcd/commit_queue.go
Normal file
@@ -0,0 +1,150 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// commitQueueSize is the maximum number of commits we let to queue up. All
|
||||
// remaining commits will block on commitQueue.Add().
|
||||
const commitQueueSize = 100
|
||||
|
||||
// commitQueue is a simple execution queue to manage conflicts for transactions
|
||||
// and thereby reduce the number of times conflicting transactions need to be
|
||||
// retried. When a new transaction is added to the queue, we first upgrade the
|
||||
// read/write counts in the queue's own accounting to decide whether the new
|
||||
// transaction has any conflicting dependencies. If the transaction does not
|
||||
// conflict with any other, then it is comitted immediately, otherwise it'll be
|
||||
// queued up for later exection.
|
||||
// The algorithm is described in: http://www.cs.umd.edu/~abadi/papers/vll-vldb13.pdf
|
||||
type commitQueue struct {
|
||||
ctx context.Context
|
||||
mx sync.Mutex
|
||||
readerMap map[string]int
|
||||
writerMap map[string]int
|
||||
|
||||
commitMutex sync.RWMutex
|
||||
queue chan (func())
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewCommitQueue creates a new commit queue, with the passed abort context.
|
||||
func NewCommitQueue(ctx context.Context) *commitQueue {
|
||||
q := &commitQueue{
|
||||
ctx: ctx,
|
||||
readerMap: make(map[string]int),
|
||||
writerMap: make(map[string]int),
|
||||
queue: make(chan func(), commitQueueSize),
|
||||
}
|
||||
|
||||
// Start the queue consumer loop.
|
||||
q.wg.Add(1)
|
||||
go q.mainLoop()
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// Wait waits for the queue to stop (after the queue context has been canceled).
|
||||
func (c *commitQueue) Wait() {
|
||||
c.wg.Wait()
|
||||
}
|
||||
|
||||
// Add increases lock counts and queues up tx commit closure for execution.
|
||||
// Transactions that don't have any conflicts are executed immediately by
|
||||
// "downgrading" the count mutex to allow concurrency.
|
||||
func (c *commitQueue) Add(commitLoop func(), rset readSet, wset writeSet) {
|
||||
c.mx.Lock()
|
||||
blocked := false
|
||||
|
||||
// Mark as blocked if there's any writer changing any of the keys in
|
||||
// the read set. Do not increment the reader counts yet as we'll need to
|
||||
// use the original reader counts when scanning through the write set.
|
||||
for key := range rset {
|
||||
if c.writerMap[key] > 0 {
|
||||
blocked = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Mark as blocked if there's any writer or reader for any of the keys
|
||||
// in the write set.
|
||||
for key := range wset {
|
||||
blocked = blocked || c.readerMap[key] > 0 || c.writerMap[key] > 0
|
||||
|
||||
// Increment the writer count.
|
||||
c.writerMap[key] += 1
|
||||
}
|
||||
|
||||
// Finally we can increment the reader counts for keys in the read set.
|
||||
for key := range rset {
|
||||
c.readerMap[key] += 1
|
||||
}
|
||||
|
||||
if blocked {
|
||||
// Add the transaction to the queue if conflicts with an already
|
||||
// queued one.
|
||||
c.mx.Unlock()
|
||||
|
||||
select {
|
||||
case c.queue <- commitLoop:
|
||||
case <-c.ctx.Done():
|
||||
}
|
||||
} else {
|
||||
// To make sure we don't add a new tx to the queue that depends
|
||||
// on this "unblocked" tx, grab the commitMutex before lifting
|
||||
// the mutex guarding the lock maps.
|
||||
c.commitMutex.RLock()
|
||||
c.mx.Unlock()
|
||||
|
||||
// At this point we're safe to execute the "unblocked" tx, as
|
||||
// we cannot execute blocked tx that may have been read from the
|
||||
// queue until the commitMutex is held.
|
||||
commitLoop()
|
||||
|
||||
c.commitMutex.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Done decreases lock counts of the keys in the read/write sets.
|
||||
func (c *commitQueue) Done(rset readSet, wset writeSet) {
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
|
||||
for key := range rset {
|
||||
c.readerMap[key] -= 1
|
||||
if c.readerMap[key] == 0 {
|
||||
delete(c.readerMap, key)
|
||||
}
|
||||
}
|
||||
|
||||
for key := range wset {
|
||||
c.writerMap[key] -= 1
|
||||
if c.writerMap[key] == 0 {
|
||||
delete(c.writerMap, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mainLoop executes queued transaction commits for transactions that have
|
||||
// dependencies. The queue ensures that the top element doesn't conflict with
|
||||
// any other transactions and therefore can be executed freely.
|
||||
func (c *commitQueue) mainLoop() {
|
||||
defer c.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case top := <-c.queue:
|
||||
// Execute the next blocked transaction. As it is
|
||||
// the top element in the queue it means that it doesn't
|
||||
// depend on any other transactions anymore.
|
||||
c.commitMutex.Lock()
|
||||
top()
|
||||
c.commitMutex.Unlock()
|
||||
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
115
kvdb/etcd/commit_queue_test.go
Normal file
115
kvdb/etcd/commit_queue_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCommitQueue tests that non-conflicting transactions commit concurrently,
|
||||
// while conflicting transactions are queued up.
|
||||
func TestCommitQueue(t *testing.T) {
|
||||
// The duration of each commit.
|
||||
const commitDuration = time.Millisecond * 500
|
||||
const numCommits = 4
|
||||
|
||||
var wg sync.WaitGroup
|
||||
commits := make([]string, numCommits)
|
||||
idx := int32(-1)
|
||||
|
||||
commit := func(tag string, sleep bool) func() {
|
||||
return func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Update our log of commit order. Avoid blocking
|
||||
// by preallocating the commit log and increasing
|
||||
// the log index atomically.
|
||||
i := atomic.AddInt32(&idx, 1)
|
||||
commits[i] = tag
|
||||
|
||||
if sleep {
|
||||
time.Sleep(commitDuration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a read set from the passed keys.
|
||||
makeReadSet := func(keys []string) readSet {
|
||||
rs := make(map[string]stmGet)
|
||||
|
||||
for _, key := range keys {
|
||||
rs[key] = stmGet{}
|
||||
}
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
// Helper function to create a write set from the passed keys.
|
||||
makeWriteSet := func(keys []string) writeSet {
|
||||
ws := make(map[string]stmPut)
|
||||
|
||||
for _, key := range keys {
|
||||
ws[key] = stmPut{}
|
||||
}
|
||||
|
||||
return ws
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
q := NewCommitQueue(ctx)
|
||||
defer q.Wait()
|
||||
defer cancel()
|
||||
|
||||
wg.Add(numCommits)
|
||||
t1 := time.Now()
|
||||
|
||||
// Tx1: reads: key1, key2, writes: key3, conflict: none
|
||||
q.Add(
|
||||
commit("free", true),
|
||||
makeReadSet([]string{"key1", "key2"}),
|
||||
makeWriteSet([]string{"key3"}),
|
||||
)
|
||||
// Tx2: reads: key1, key2, writes: key3, conflict: Tx1
|
||||
q.Add(
|
||||
commit("blocked1", false),
|
||||
makeReadSet([]string{"key1", "key2"}),
|
||||
makeWriteSet([]string{"key3"}),
|
||||
)
|
||||
// Tx3: reads: key1, writes: key4, conflict: none
|
||||
q.Add(
|
||||
commit("free", true),
|
||||
makeReadSet([]string{"key1", "key2"}),
|
||||
makeWriteSet([]string{"key4"}),
|
||||
)
|
||||
// Tx4: reads: key2, writes: key4 conflict: Tx3
|
||||
q.Add(
|
||||
commit("blocked2", false),
|
||||
makeReadSet([]string{"key2"}),
|
||||
makeWriteSet([]string{"key4"}),
|
||||
)
|
||||
|
||||
// Wait for all commits.
|
||||
wg.Wait()
|
||||
t2 := time.Now()
|
||||
|
||||
// Expected total execution time: delta.
|
||||
// 2 * commitDuration <= delta < 3 * commitDuration
|
||||
delta := t2.Sub(t1)
|
||||
require.LessOrEqual(t, int64(commitDuration*2), int64(delta))
|
||||
require.Greater(t, int64(commitDuration*3), int64(delta))
|
||||
|
||||
// Expect that the non-conflicting "free" transactions are executed
|
||||
// before the blocking ones, and the blocking ones are executed in
|
||||
// the order of addition.
|
||||
require.Equal(t,
|
||||
[]string{"free", "free", "blocked1", "blocked2"},
|
||||
commits,
|
||||
)
|
||||
}
|
28
kvdb/etcd/config.go
Normal file
28
kvdb/etcd/config.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package etcd
|
||||
|
||||
// Config holds etcd configuration alongside with configuration related to our higher level interface.
|
||||
type Config struct {
|
||||
Embedded bool `long:"embedded" description:"Use embedded etcd instance instead of the external one. Note: use for testing only."`
|
||||
|
||||
EmbeddedClientPort uint16 `long:"embedded_client_port" description:"Client port to use for the embedded instance. Note: use for testing only."`
|
||||
|
||||
EmbeddedPeerPort uint16 `long:"embedded_peer_port" description:"Peer port to use for the embedded instance. Note: use for testing only."`
|
||||
|
||||
Host string `long:"host" description:"Etcd database host."`
|
||||
|
||||
User string `long:"user" description:"Etcd database user."`
|
||||
|
||||
Pass string `long:"pass" description:"Password for the database user."`
|
||||
|
||||
Namespace string `long:"namespace" description:"The etcd namespace to use."`
|
||||
|
||||
DisableTLS bool `long:"disabletls" description:"Disable TLS for etcd connection. Caution: use for development only."`
|
||||
|
||||
CertFile string `long:"cert_file" description:"Path to the TLS certificate for etcd RPC."`
|
||||
|
||||
KeyFile string `long:"key_file" description:"Path to the TLS private key for etcd RPC."`
|
||||
|
||||
InsecureSkipVerify bool `long:"insecure_skip_verify" description:"Whether we intend to skip TLS verification"`
|
||||
|
||||
CollectStats bool `long:"collect_stats" description:"Whether to collect etcd commit stats."`
|
||||
}
|
277
kvdb/etcd/db.go
Normal file
277
kvdb/etcd/db.go
Normal file
@@ -0,0 +1,277 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
"go.etcd.io/etcd/clientv3/namespace"
|
||||
"go.etcd.io/etcd/pkg/transport"
|
||||
)
|
||||
|
||||
const (
|
||||
// etcdConnectionTimeout is the timeout until successful connection to
|
||||
// the etcd instance.
|
||||
etcdConnectionTimeout = 10 * time.Second
|
||||
|
||||
// etcdLongTimeout is a timeout for longer taking etcd operatons.
|
||||
etcdLongTimeout = 30 * time.Second
|
||||
|
||||
// etcdDefaultRootBucketId is used as the root bucket key. Note that
|
||||
// the actual key is not visible, since all bucket keys are hashed.
|
||||
etcdDefaultRootBucketId = "@"
|
||||
)
|
||||
|
||||
// callerStats holds commit stats for a specific caller. Currently it only
|
||||
// holds the max stat, meaning that for a particular caller the largest
|
||||
// commit set is recorded.
|
||||
type callerStats struct {
|
||||
count int
|
||||
commitStats CommitStats
|
||||
}
|
||||
|
||||
func (s callerStats) String() string {
|
||||
return fmt.Sprintf("count: %d, retries: %d, rset: %d, wset: %d",
|
||||
s.count, s.commitStats.Retries, s.commitStats.Rset,
|
||||
s.commitStats.Wset)
|
||||
}
|
||||
|
||||
// commitStatsCollector collects commit stats for commits succeeding
|
||||
// and also for commits failing.
|
||||
type commitStatsCollector struct {
|
||||
sync.RWMutex
|
||||
succ map[string]*callerStats
|
||||
fail map[string]*callerStats
|
||||
}
|
||||
|
||||
// newCommitStatsColletor creates a new commitStatsCollector instance.
|
||||
func newCommitStatsColletor() *commitStatsCollector {
|
||||
return &commitStatsCollector{
|
||||
succ: make(map[string]*callerStats),
|
||||
fail: make(map[string]*callerStats),
|
||||
}
|
||||
}
|
||||
|
||||
// PrintStats returns collected stats pretty printed into a string.
|
||||
func (c *commitStatsCollector) PrintStats() string {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
s := "\nFailure:\n"
|
||||
for k, v := range c.fail {
|
||||
s += fmt.Sprintf("%s\t%s\n", k, v)
|
||||
}
|
||||
|
||||
s += "\nSuccess:\n"
|
||||
for k, v := range c.succ {
|
||||
s += fmt.Sprintf("%s\t%s\n", k, v)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// updateStatsMap updatess commit stats map for a caller.
|
||||
func updateStatMap(
|
||||
caller string, stats CommitStats, m map[string]*callerStats) {
|
||||
|
||||
if _, ok := m[caller]; !ok {
|
||||
m[caller] = &callerStats{}
|
||||
}
|
||||
|
||||
curr := m[caller]
|
||||
curr.count++
|
||||
|
||||
// Update only if the total commit set is greater or equal.
|
||||
currTotal := curr.commitStats.Rset + curr.commitStats.Wset
|
||||
if currTotal <= (stats.Rset + stats.Wset) {
|
||||
curr.commitStats = stats
|
||||
}
|
||||
}
|
||||
|
||||
// callback is an STM commit stats callback passed which can be passed
|
||||
// using a WithCommitStatsCallback to the STM upon construction.
|
||||
func (c *commitStatsCollector) callback(succ bool, stats CommitStats) {
|
||||
caller := "unknown"
|
||||
|
||||
// Get the caller. As this callback is called from
|
||||
// the backend interface that means we need to ascend
|
||||
// 4 frames in the callstack.
|
||||
_, file, no, ok := runtime.Caller(4)
|
||||
if ok {
|
||||
caller = fmt.Sprintf("%s#%d", file, no)
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if succ {
|
||||
updateStatMap(caller, stats, c.succ)
|
||||
} else {
|
||||
updateStatMap(caller, stats, c.fail)
|
||||
}
|
||||
}
|
||||
|
||||
// db holds a reference to the etcd client connection.
|
||||
type db struct {
|
||||
cfg Config
|
||||
ctx context.Context
|
||||
cli *clientv3.Client
|
||||
commitStatsCollector *commitStatsCollector
|
||||
txQueue *commitQueue
|
||||
}
|
||||
|
||||
// Enforce db implements the walletdb.DB interface.
|
||||
var _ walletdb.DB = (*db)(nil)
|
||||
|
||||
// newEtcdBackend returns a db object initialized with the passed backend
|
||||
// config. If etcd connection cannot be estabished, then returns error.
|
||||
func newEtcdBackend(ctx context.Context, cfg Config) (*db, error) {
|
||||
clientCfg := clientv3.Config{
|
||||
Context: ctx,
|
||||
Endpoints: []string{cfg.Host},
|
||||
DialTimeout: etcdConnectionTimeout,
|
||||
Username: cfg.User,
|
||||
Password: cfg.Pass,
|
||||
MaxCallSendMsgSize: 16384*1024 - 1,
|
||||
}
|
||||
|
||||
if !cfg.DisableTLS {
|
||||
tlsInfo := transport.TLSInfo{
|
||||
CertFile: cfg.CertFile,
|
||||
KeyFile: cfg.KeyFile,
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
tlsConfig, err := tlsInfo.ClientConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientCfg.TLS = tlsConfig
|
||||
}
|
||||
|
||||
cli, err := clientv3.New(clientCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply the namespace.
|
||||
cli.KV = namespace.NewKV(cli.KV, cfg.Namespace)
|
||||
cli.Watcher = namespace.NewWatcher(cli.Watcher, cfg.Namespace)
|
||||
cli.Lease = namespace.NewLease(cli.Lease, cfg.Namespace)
|
||||
|
||||
backend := &db{
|
||||
cfg: cfg,
|
||||
ctx: ctx,
|
||||
cli: cli,
|
||||
txQueue: NewCommitQueue(ctx),
|
||||
}
|
||||
|
||||
if cfg.CollectStats {
|
||||
backend.commitStatsCollector = newCommitStatsColletor()
|
||||
}
|
||||
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// getSTMOptions creats all STM options based on the backend config.
|
||||
func (db *db) getSTMOptions() []STMOptionFunc {
|
||||
opts := []STMOptionFunc{
|
||||
WithAbortContext(db.ctx),
|
||||
}
|
||||
|
||||
if db.cfg.CollectStats {
|
||||
opts = append(opts,
|
||||
WithCommitStatsCallback(db.commitStatsCollector.callback),
|
||||
)
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
// View opens a database read transaction and executes the function f with the
|
||||
// transaction passed as a parameter. After f exits, the transaction is rolled
|
||||
// back. If f errors, its error is returned, not a rollback error (if any
|
||||
// occur). The passed reset function is called before the start of the
|
||||
// transaction and can be used to reset intermediate state. As callers may
|
||||
// expect retries of the f closure (depending on the database backend used), the
|
||||
// reset function will be called before each retry respectively.
|
||||
func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error {
|
||||
apply := func(stm STM) error {
|
||||
reset()
|
||||
return f(newReadWriteTx(stm, etcdDefaultRootBucketId))
|
||||
}
|
||||
|
||||
return RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...)
|
||||
}
|
||||
|
||||
// Update opens a database read/write transaction and executes the function f
|
||||
// with the transaction passed as a parameter. After f exits, if f did not
|
||||
// error, the transaction is committed. Otherwise, if f did error, the
|
||||
// transaction is rolled back. If the rollback fails, the original error
|
||||
// returned by f is still returned. If the commit fails, the commit error is
|
||||
// returned. As callers may expect retries of the f closure, the reset function
|
||||
// will be called before each retry respectively.
|
||||
func (db *db) Update(f func(tx walletdb.ReadWriteTx) error, reset func()) error {
|
||||
apply := func(stm STM) error {
|
||||
reset()
|
||||
return f(newReadWriteTx(stm, etcdDefaultRootBucketId))
|
||||
}
|
||||
|
||||
return RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...)
|
||||
}
|
||||
|
||||
// PrintStats returns all collected stats pretty printed into a string.
|
||||
func (db *db) PrintStats() string {
|
||||
if db.commitStatsCollector != nil {
|
||||
return db.commitStatsCollector.PrintStats()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// BeginReadWriteTx opens a database read+write transaction.
|
||||
func (db *db) BeginReadWriteTx() (walletdb.ReadWriteTx, error) {
|
||||
return newReadWriteTx(
|
||||
NewSTM(db.cli, db.txQueue, db.getSTMOptions()...),
|
||||
etcdDefaultRootBucketId,
|
||||
), nil
|
||||
}
|
||||
|
||||
// BeginReadTx opens a database read transaction.
|
||||
func (db *db) BeginReadTx() (walletdb.ReadTx, error) {
|
||||
return newReadWriteTx(
|
||||
NewSTM(db.cli, db.txQueue, db.getSTMOptions()...),
|
||||
etcdDefaultRootBucketId,
|
||||
), nil
|
||||
}
|
||||
|
||||
// Copy writes a copy of the database to the provided writer. This call will
|
||||
// start a read-only transaction to perform all operations.
|
||||
// This function is part of the walletdb.Db interface implementation.
|
||||
func (db *db) Copy(w io.Writer) error {
|
||||
ctx, cancel := context.WithTimeout(db.ctx, etcdLongTimeout)
|
||||
defer cancel()
|
||||
|
||||
readCloser, err := db.cli.Snapshot(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(w, readCloser)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Close cleanly shuts down the database and syncs all data.
|
||||
// This function is part of the walletdb.Db interface implementation.
|
||||
func (db *db) Close() error {
|
||||
return db.cli.Close()
|
||||
}
|
74
kvdb/etcd/db_test.go
Normal file
74
kvdb/etcd/db_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCopy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
// "apple"
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
require.NoError(t, apple.Put([]byte("key"), []byte("val")))
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
// Expect non-zero copy.
|
||||
var buf bytes.Buffer
|
||||
|
||||
require.NoError(t, db.Copy(&buf))
|
||||
require.Greater(t, buf.Len(), 0)
|
||||
require.Nil(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
bkey("apple"): bval("apple"),
|
||||
vkey("key", "apple"): "val",
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
}
|
||||
|
||||
func TestAbortContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
config := f.BackendConfig()
|
||||
|
||||
// Pass abort context and abort right away.
|
||||
db, err := newEtcdBackend(ctx, config)
|
||||
require.NoError(t, err)
|
||||
cancel()
|
||||
|
||||
// Expect that the update will fail.
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
_, err := tx.CreateTopLevelBucket([]byte("bucket"))
|
||||
require.Error(t, err, "context canceled")
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.Error(t, err, "context canceled")
|
||||
|
||||
// No changes in the DB.
|
||||
require.Equal(t, map[string]string{}, f.Dump())
|
||||
}
|
79
kvdb/etcd/driver.go
Normal file
79
kvdb/etcd/driver.go
Normal file
@@ -0,0 +1,79 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
)
|
||||
|
||||
const (
|
||||
dbType = "etcd"
|
||||
)
|
||||
|
||||
// parseArgs parses the arguments from the walletdb Open/Create methods.
|
||||
func parseArgs(funcName string, args ...interface{}) (context.Context,
|
||||
*Config, error) {
|
||||
|
||||
if len(args) != 2 {
|
||||
return nil, nil, fmt.Errorf("invalid number of arguments to "+
|
||||
"%s.%s -- expected: context.Context, etcd.Config",
|
||||
dbType, funcName,
|
||||
)
|
||||
}
|
||||
|
||||
ctx, ok := args[0].(context.Context)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("argument 0 to %s.%s is invalid "+
|
||||
"-- expected: context.Context",
|
||||
dbType, funcName,
|
||||
)
|
||||
}
|
||||
|
||||
config, ok := args[1].(*Config)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("argument 1 to %s.%s is invalid -- "+
|
||||
"expected: etcd.Config",
|
||||
dbType, funcName,
|
||||
)
|
||||
}
|
||||
|
||||
return ctx, config, nil
|
||||
}
|
||||
|
||||
// createDBDriver is the callback provided during driver registration that
|
||||
// creates, initializes, and opens a database for use.
|
||||
func createDBDriver(args ...interface{}) (walletdb.DB, error) {
|
||||
ctx, config, err := parseArgs("Create", args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newEtcdBackend(ctx, *config)
|
||||
}
|
||||
|
||||
// openDBDriver is the callback provided during driver registration that opens
|
||||
// an existing database for use.
|
||||
func openDBDriver(args ...interface{}) (walletdb.DB, error) {
|
||||
ctx, config, err := parseArgs("Open", args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newEtcdBackend(ctx, *config)
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Register the driver.
|
||||
driver := walletdb.Driver{
|
||||
DbType: dbType,
|
||||
Create: createDBDriver,
|
||||
Open: openDBDriver,
|
||||
}
|
||||
if err := walletdb.RegisterDriver(driver); err != nil {
|
||||
panic(fmt.Sprintf("Failed to regiser database driver '%s': %v",
|
||||
dbType, err))
|
||||
}
|
||||
}
|
30
kvdb/etcd/driver_test.go
Normal file
30
kvdb/etcd/driver_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpenCreateFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := walletdb.Open(dbType)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, db)
|
||||
|
||||
db, err = walletdb.Open(dbType, "wrong")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, db)
|
||||
|
||||
db, err = walletdb.Create(dbType)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, db)
|
||||
|
||||
db, err = walletdb.Create(dbType, "wrong")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, db)
|
||||
}
|
109
kvdb/etcd/embed.go
Normal file
109
kvdb/etcd/embed.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.etcd.io/etcd/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
// readyTimeout is the time until the embedded etcd instance should start.
|
||||
readyTimeout = 10 * time.Second
|
||||
|
||||
// defaultEtcdPort is the start of the range for listening ports of
|
||||
// embedded etcd servers. Ports are monotonically increasing starting
|
||||
// from this number and are determined by the results of getFreePort().
|
||||
defaultEtcdPort = 2379
|
||||
|
||||
// defaultNamespace is the namespace we'll use in our embedded etcd
|
||||
// instance. Since it is only used for testing, we'll use the namespace
|
||||
// name "test/" for this. Note that the namespace can be any string,
|
||||
// the trailing / is not required.
|
||||
defaultNamespace = "test/"
|
||||
)
|
||||
|
||||
var (
|
||||
// lastPort is the last port determined to be free for use by a new
|
||||
// embedded etcd server. It should be used atomically.
|
||||
lastPort uint32 = defaultEtcdPort
|
||||
)
|
||||
|
||||
// getFreePort returns the first port that is available for listening by a new
|
||||
// embedded etcd server. It panics if no port is found and the maximum available
|
||||
// TCP port is reached.
|
||||
func getFreePort() int {
|
||||
port := atomic.AddUint32(&lastPort, 1)
|
||||
for port < 65535 {
|
||||
// If there are no errors while attempting to listen on this
|
||||
// port, close the socket and return it as available.
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
l, err := net.Listen("tcp4", addr)
|
||||
if err == nil {
|
||||
err := l.Close()
|
||||
if err == nil {
|
||||
return int(port)
|
||||
}
|
||||
}
|
||||
port = atomic.AddUint32(&lastPort, 1)
|
||||
}
|
||||
|
||||
// No ports available? Must be a mistake.
|
||||
panic("no ports available for listening")
|
||||
}
|
||||
|
||||
// NewEmbeddedEtcdInstance creates an embedded etcd instance for testing,
|
||||
// listening on random open ports. Returns the backend config and a cleanup
|
||||
// func that will stop the etcd instance.
|
||||
func NewEmbeddedEtcdInstance(path string, clientPort, peerPort uint16) (
|
||||
*Config, func(), error) {
|
||||
|
||||
cfg := embed.NewConfig()
|
||||
cfg.Dir = path
|
||||
|
||||
// To ensure that we can submit large transactions.
|
||||
cfg.MaxTxnOps = 8192
|
||||
cfg.MaxRequestBytes = 16384 * 1024
|
||||
|
||||
// Listen on random free ports if no ports were specified.
|
||||
if clientPort == 0 {
|
||||
clientPort = uint16(getFreePort())
|
||||
}
|
||||
|
||||
if peerPort == 0 {
|
||||
peerPort = uint16(getFreePort())
|
||||
}
|
||||
|
||||
clientURL := fmt.Sprintf("127.0.0.1:%d", clientPort)
|
||||
peerURL := fmt.Sprintf("127.0.0.1:%d", peerPort)
|
||||
cfg.LCUrls = []url.URL{{Host: clientURL}}
|
||||
cfg.LPUrls = []url.URL{{Host: peerURL}}
|
||||
|
||||
etcd, err := embed.StartEtcd(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-etcd.Server.ReadyNotify():
|
||||
case <-time.After(readyTimeout):
|
||||
etcd.Close()
|
||||
return nil, nil,
|
||||
fmt.Errorf("etcd failed to start after: %v", readyTimeout)
|
||||
}
|
||||
|
||||
connConfig := &Config{
|
||||
Host: "http://" + clientURL,
|
||||
InsecureSkipVerify: true,
|
||||
Namespace: defaultNamespace,
|
||||
}
|
||||
|
||||
return connConfig, func() {
|
||||
etcd.Close()
|
||||
}, nil
|
||||
}
|
135
kvdb/etcd/fixture_test.go
Normal file
135
kvdb/etcd/fixture_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
"go.etcd.io/etcd/clientv3/namespace"
|
||||
)
|
||||
|
||||
const (
|
||||
// testEtcdTimeout is used for all RPC calls initiated by the test fixture.
|
||||
testEtcdTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// EtcdTestFixture holds internal state of the etcd test fixture.
|
||||
type EtcdTestFixture struct {
|
||||
t *testing.T
|
||||
cli *clientv3.Client
|
||||
config *Config
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
// NewTestEtcdInstance creates an embedded etcd instance for testing, listening
|
||||
// on random open ports. Returns the connection config and a cleanup func that
|
||||
// will stop the etcd instance.
|
||||
func NewTestEtcdInstance(t *testing.T, path string) (*Config, func()) {
|
||||
t.Helper()
|
||||
|
||||
config, cleanup, err := NewEmbeddedEtcdInstance(path, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("error while staring embedded etcd instance: %v", err)
|
||||
}
|
||||
|
||||
return config, cleanup
|
||||
}
|
||||
|
||||
// NewTestEtcdTestFixture creates a new etcd-test fixture. This is helper
|
||||
// object to facilitate etcd tests and ensure pre and post conditions.
|
||||
func NewEtcdTestFixture(t *testing.T) *EtcdTestFixture {
|
||||
tmpDir, err := ioutil.TempDir("", "etcd")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
config, etcdCleanup := NewTestEtcdInstance(t, tmpDir)
|
||||
|
||||
cli, err := clientv3.New(clientv3.Config{
|
||||
Endpoints: []string{config.Host},
|
||||
Username: config.User,
|
||||
Password: config.Pass,
|
||||
})
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
t.Fatalf("unable to create etcd test fixture: %v", err)
|
||||
}
|
||||
|
||||
// Apply the default namespace (since that's what we use in tests).
|
||||
cli.KV = namespace.NewKV(cli.KV, defaultNamespace)
|
||||
cli.Watcher = namespace.NewWatcher(cli.Watcher, defaultNamespace)
|
||||
cli.Lease = namespace.NewLease(cli.Lease, defaultNamespace)
|
||||
|
||||
return &EtcdTestFixture{
|
||||
t: t,
|
||||
cli: cli,
|
||||
config: config,
|
||||
cleanup: func() {
|
||||
etcdCleanup()
|
||||
os.RemoveAll(tmpDir)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Put puts a string key/value into the test etcd database.
|
||||
func (f *EtcdTestFixture) Put(key, value string) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testEtcdTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := f.cli.Put(ctx, key, value)
|
||||
if err != nil {
|
||||
f.t.Fatalf("etcd test fixture failed to put: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get queries a key and returns the stored value from the test etcd database.
|
||||
func (f *EtcdTestFixture) Get(key string) string {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testEtcdTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := f.cli.Get(ctx, key)
|
||||
if err != nil {
|
||||
f.t.Fatalf("etcd test fixture failed to get: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Kvs) > 0 {
|
||||
return string(resp.Kvs[0].Value)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Dump scans and returns all key/values from the test etcd database.
|
||||
func (f *EtcdTestFixture) Dump() map[string]string {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testEtcdTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := f.cli.Get(ctx, "\x00", clientv3.WithFromKey())
|
||||
if err != nil {
|
||||
f.t.Fatalf("etcd test fixture failed to get: %v", err)
|
||||
}
|
||||
|
||||
result := make(map[string]string)
|
||||
for _, kv := range resp.Kvs {
|
||||
result[string(kv.Key)] = string(kv.Value)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// BackendConfig returns the backend config for connecting to theembedded
|
||||
// etcd instance.
|
||||
func (f *EtcdTestFixture) BackendConfig() Config {
|
||||
return *f.config
|
||||
}
|
||||
|
||||
// Cleanup should be called at test fixture teardown to stop the embedded
|
||||
// etcd instance and remove all temp db files form the filesystem.
|
||||
func (f *EtcdTestFixture) Cleanup() {
|
||||
f.cleanup()
|
||||
}
|
356
kvdb/etcd/readwrite_bucket.go
Normal file
356
kvdb/etcd/readwrite_bucket.go
Normal file
@@ -0,0 +1,356 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
)
|
||||
|
||||
// readWriteBucket stores the bucket id and the buckets transaction.
|
||||
type readWriteBucket struct {
|
||||
// id is used to identify the bucket and is created by
|
||||
// hashing the parent id with the bucket key. For each key/value,
|
||||
// sub-bucket or the bucket sequence the bucket id is used with the
|
||||
// appropriate prefix to prefix the key.
|
||||
id []byte
|
||||
|
||||
// tx holds the parent transaction.
|
||||
tx *readWriteTx
|
||||
}
|
||||
|
||||
// newReadWriteBucket creates a new rw bucket with the passed transaction
|
||||
// and bucket id.
|
||||
func newReadWriteBucket(tx *readWriteTx, key, id []byte) *readWriteBucket {
|
||||
return &readWriteBucket{
|
||||
id: id,
|
||||
tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
// NestedReadBucket retrieves a nested read bucket with the given key.
|
||||
// Returns nil if the bucket does not exist.
|
||||
func (b *readWriteBucket) NestedReadBucket(key []byte) walletdb.ReadBucket {
|
||||
return b.NestedReadWriteBucket(key)
|
||||
}
|
||||
|
||||
// ForEach invokes the passed function with every key/value pair in
|
||||
// the bucket. This includes nested buckets, in which case the value
|
||||
// is nil, but it does not include the key/value pairs within those
|
||||
// nested buckets.
|
||||
func (b *readWriteBucket) ForEach(cb func(k, v []byte) error) error {
|
||||
prefix := string(b.id)
|
||||
|
||||
// Get the first matching key that is in the bucket.
|
||||
kv, err := b.tx.stm.First(prefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for kv != nil {
|
||||
key, val := getKeyVal(kv)
|
||||
|
||||
if err := cb(key, val); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Step to the next key.
|
||||
kv, err = b.tx.stm.Next(prefix, kv.key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the value for the given key. Returns nil if the key does
|
||||
// not exist in this bucket.
|
||||
func (b *readWriteBucket) Get(key []byte) []byte {
|
||||
// Return nil if the key is empty.
|
||||
if len(key) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fetch the associated value.
|
||||
val, err := b.tx.stm.Get(string(makeValueKey(b.id, key)))
|
||||
if err != nil {
|
||||
// TODO: we should return the error once the
|
||||
// kvdb inteface is extended.
|
||||
return nil
|
||||
}
|
||||
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
func (b *readWriteBucket) ReadCursor() walletdb.ReadCursor {
|
||||
return newReadWriteCursor(b)
|
||||
}
|
||||
|
||||
// NestedReadWriteBucket retrieves a nested bucket with the given key.
|
||||
// Returns nil if the bucket does not exist.
|
||||
func (b *readWriteBucket) NestedReadWriteBucket(key []byte) walletdb.ReadWriteBucket {
|
||||
if len(key) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the bucket id (and return nil if bucket doesn't exist).
|
||||
bucketKey := makeBucketKey(b.id, key)
|
||||
bucketVal, err := b.tx.stm.Get(string(bucketKey))
|
||||
if err != nil {
|
||||
// TODO: we should return the error once the
|
||||
// kvdb inteface is extended.
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isValidBucketID(bucketVal) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the bucket with the fetched bucket id.
|
||||
return newReadWriteBucket(b.tx, bucketKey, bucketVal)
|
||||
}
|
||||
|
||||
// assertNoValue checks if the value for the passed key exists.
|
||||
func (b *readWriteBucket) assertNoValue(key []byte) error {
|
||||
val, err := b.tx.stm.Get(string(makeValueKey(b.id, key)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val != nil {
|
||||
return walletdb.ErrIncompatibleValue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateBucket creates and returns a new nested bucket with the given
|
||||
// key. Returns ErrBucketExists if the bucket already exists,
|
||||
// ErrBucketNameRequired if the key is empty, or ErrIncompatibleValue
|
||||
// if the key value is otherwise invalid for the particular database
|
||||
// implementation. Other errors are possible depending on the
|
||||
// implementation.
|
||||
func (b *readWriteBucket) CreateBucket(key []byte) (
|
||||
walletdb.ReadWriteBucket, error) {
|
||||
|
||||
if len(key) == 0 {
|
||||
return nil, walletdb.ErrBucketNameRequired
|
||||
}
|
||||
|
||||
// Check if the bucket already exists.
|
||||
bucketKey := makeBucketKey(b.id, key)
|
||||
|
||||
bucketVal, err := b.tx.stm.Get(string(bucketKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isValidBucketID(bucketVal) {
|
||||
return nil, walletdb.ErrBucketExists
|
||||
}
|
||||
|
||||
if err := b.assertNoValue(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a deterministic bucket id from the bucket key.
|
||||
newID := makeBucketID(bucketKey)
|
||||
|
||||
// Create the bucket.
|
||||
b.tx.stm.Put(string(bucketKey), string(newID[:]))
|
||||
|
||||
return newReadWriteBucket(b.tx, bucketKey, newID[:]), nil
|
||||
}
|
||||
|
||||
// CreateBucketIfNotExists creates and returns a new nested bucket with
|
||||
// the given key if it does not already exist. Returns
|
||||
// ErrBucketNameRequired if the key is empty or ErrIncompatibleValue
|
||||
// if the key value is otherwise invalid for the particular database
|
||||
// backend. Other errors are possible depending on the implementation.
|
||||
func (b *readWriteBucket) CreateBucketIfNotExists(key []byte) (
|
||||
walletdb.ReadWriteBucket, error) {
|
||||
|
||||
if len(key) == 0 {
|
||||
return nil, walletdb.ErrBucketNameRequired
|
||||
}
|
||||
|
||||
// Check for the bucket and create if it doesn't exist.
|
||||
bucketKey := makeBucketKey(b.id, key)
|
||||
|
||||
bucketVal, err := b.tx.stm.Get(string(bucketKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isValidBucketID(bucketVal) {
|
||||
if err := b.assertNoValue(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newID := makeBucketID(bucketKey)
|
||||
b.tx.stm.Put(string(bucketKey), string(newID[:]))
|
||||
|
||||
return newReadWriteBucket(b.tx, bucketKey, newID[:]), nil
|
||||
}
|
||||
|
||||
// Otherwise return the bucket with the fetched bucket id.
|
||||
return newReadWriteBucket(b.tx, bucketKey, bucketVal), nil
|
||||
}
|
||||
|
||||
// DeleteNestedBucket deletes the nested bucket and its sub-buckets
|
||||
// pointed to by the passed key. All values in the bucket and sub-buckets
|
||||
// will be deleted as well.
|
||||
func (b *readWriteBucket) DeleteNestedBucket(key []byte) error {
|
||||
// TODO shouldn't empty key return ErrBucketNameRequired ?
|
||||
if len(key) == 0 {
|
||||
return walletdb.ErrIncompatibleValue
|
||||
}
|
||||
|
||||
// Get the bucket first.
|
||||
bucketKey := string(makeBucketKey(b.id, key))
|
||||
|
||||
bucketVal, err := b.tx.stm.Get(bucketKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !isValidBucketID(bucketVal) {
|
||||
return walletdb.ErrBucketNotFound
|
||||
}
|
||||
|
||||
// Enqueue the top level bucket id.
|
||||
queue := [][]byte{bucketVal}
|
||||
|
||||
// Traverse the buckets breadth first.
|
||||
for len(queue) != 0 {
|
||||
if !isValidBucketID(queue[0]) {
|
||||
return walletdb.ErrBucketNotFound
|
||||
}
|
||||
|
||||
id := queue[0]
|
||||
queue = queue[1:]
|
||||
|
||||
kv, err := b.tx.stm.First(string(id))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for kv != nil {
|
||||
b.tx.stm.Del(kv.key)
|
||||
|
||||
if isBucketKey(kv.key) {
|
||||
queue = append(queue, []byte(kv.val))
|
||||
}
|
||||
|
||||
kv, err = b.tx.stm.Next(string(id), kv.key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Finally delete the sequence key for the bucket.
|
||||
b.tx.stm.Del(string(makeSequenceKey(id)))
|
||||
}
|
||||
|
||||
// Delete the top level bucket and sequence key.
|
||||
b.tx.stm.Del(bucketKey)
|
||||
b.tx.stm.Del(string(makeSequenceKey(bucketVal)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put updates the value for the passed key.
|
||||
// Returns ErrKeyRequred if te passed key is empty.
|
||||
func (b *readWriteBucket) Put(key, value []byte) error {
|
||||
if len(key) == 0 {
|
||||
return walletdb.ErrKeyRequired
|
||||
}
|
||||
|
||||
val, err := b.tx.stm.Get(string(makeBucketKey(b.id, key)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val != nil {
|
||||
return walletdb.ErrIncompatibleValue
|
||||
}
|
||||
|
||||
// Update the transaction with the new value.
|
||||
b.tx.stm.Put(string(makeValueKey(b.id, key)), string(value))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes the key/value pointed to by the passed key.
|
||||
// Returns ErrKeyRequred if the passed key is empty.
|
||||
func (b *readWriteBucket) Delete(key []byte) error {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
if len(key) == 0 {
|
||||
return walletdb.ErrKeyRequired
|
||||
}
|
||||
|
||||
// Update the transaction to delete the key/value.
|
||||
b.tx.stm.Del(string(makeValueKey(b.id, key)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadWriteCursor returns a new read-write cursor for this bucket.
|
||||
func (b *readWriteBucket) ReadWriteCursor() walletdb.ReadWriteCursor {
|
||||
return newReadWriteCursor(b)
|
||||
}
|
||||
|
||||
// Tx returns the buckets transaction.
|
||||
func (b *readWriteBucket) Tx() walletdb.ReadWriteTx {
|
||||
return b.tx
|
||||
}
|
||||
|
||||
// NextSequence returns an autoincrementing sequence number for this bucket.
|
||||
// Note that this is not a thread safe function and as such it must not be used
|
||||
// for synchronization.
|
||||
func (b *readWriteBucket) NextSequence() (uint64, error) {
|
||||
seq := b.Sequence() + 1
|
||||
|
||||
return seq, b.SetSequence(seq)
|
||||
}
|
||||
|
||||
// SetSequence updates the sequence number for the bucket.
|
||||
func (b *readWriteBucket) SetSequence(v uint64) error {
|
||||
// Convert the number to string.
|
||||
val := strconv.FormatUint(v, 10)
|
||||
|
||||
// Update the transaction with the new value for the sequence key.
|
||||
b.tx.stm.Put(string(makeSequenceKey(b.id)), val)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sequence returns the current sequence number for this bucket without
|
||||
// incrementing it.
|
||||
func (b *readWriteBucket) Sequence() uint64 {
|
||||
val, err := b.tx.stm.Get(string(makeSequenceKey(b.id)))
|
||||
if err != nil {
|
||||
// TODO: This update kvdb interface such that error
|
||||
// may be returned here.
|
||||
return 0
|
||||
}
|
||||
|
||||
if val == nil {
|
||||
// If the sequence number is not yet
|
||||
// stored, then take the default value.
|
||||
return 0
|
||||
}
|
||||
|
||||
// Otherwise try to parse a 64 bit unsigned integer from the value.
|
||||
num, _ := strconv.ParseUint(string(val), 10, 64)
|
||||
|
||||
return num
|
||||
}
|
524
kvdb/etcd/readwrite_bucket_test.go
Normal file
524
kvdb/etcd/readwrite_bucket_test.go
Normal file
@@ -0,0 +1,524 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBucketCreation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
// empty bucket name
|
||||
b, err := tx.CreateTopLevelBucket(nil)
|
||||
require.Error(t, walletdb.ErrBucketNameRequired, err)
|
||||
require.Nil(t, b)
|
||||
|
||||
// empty bucket name
|
||||
b, err = tx.CreateTopLevelBucket([]byte(""))
|
||||
require.Error(t, walletdb.ErrBucketNameRequired, err)
|
||||
require.Nil(t, b)
|
||||
|
||||
// "apple"
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
// Check bucket tx.
|
||||
require.Equal(t, tx, apple.Tx())
|
||||
|
||||
// "apple" already created
|
||||
b, err = tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, b)
|
||||
|
||||
// "apple/banana"
|
||||
banana, err := apple.CreateBucket([]byte("banana"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, banana)
|
||||
|
||||
banana, err = apple.CreateBucketIfNotExists([]byte("banana"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, banana)
|
||||
|
||||
// Try creating "apple/banana" again
|
||||
b, err = apple.CreateBucket([]byte("banana"))
|
||||
require.Error(t, walletdb.ErrBucketExists, err)
|
||||
require.Nil(t, b)
|
||||
|
||||
// "apple/mango"
|
||||
mango, err := apple.CreateBucket([]byte("mango"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, mango)
|
||||
|
||||
// "apple/banana/pear"
|
||||
pear, err := banana.CreateBucket([]byte("pear"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, pear)
|
||||
|
||||
// empty bucket
|
||||
require.Nil(t, apple.NestedReadWriteBucket(nil))
|
||||
require.Nil(t, apple.NestedReadWriteBucket([]byte("")))
|
||||
|
||||
// "apple/pear" doesn't exist
|
||||
require.Nil(t, apple.NestedReadWriteBucket([]byte("pear")))
|
||||
|
||||
// "apple/banana" exits
|
||||
require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana")))
|
||||
require.NotNil(t, apple.NestedReadBucket([]byte("banana")))
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
bkey("apple"): bval("apple"),
|
||||
bkey("apple", "banana"): bval("apple", "banana"),
|
||||
bkey("apple", "mango"): bval("apple", "mango"),
|
||||
bkey("apple", "banana", "pear"): bval("apple", "banana", "pear"),
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
}
|
||||
|
||||
func TestBucketDeletion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
// "apple"
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
// "apple/banana"
|
||||
banana, err := apple.CreateBucket([]byte("banana"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, banana)
|
||||
|
||||
kvs := []KV{{"key1", "val1"}, {"key2", "val2"}, {"key3", "val3"}}
|
||||
|
||||
for _, kv := range kvs {
|
||||
require.NoError(t, banana.Put([]byte(kv.key), []byte(kv.val)))
|
||||
require.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key)))
|
||||
}
|
||||
|
||||
// Delete a k/v from "apple/banana"
|
||||
require.NoError(t, banana.Delete([]byte("key2")))
|
||||
// Try getting/putting/deleting invalid k/v's.
|
||||
require.Nil(t, banana.Get(nil))
|
||||
require.Error(t, walletdb.ErrKeyRequired, banana.Put(nil, []byte("val")))
|
||||
require.Error(t, walletdb.ErrKeyRequired, banana.Delete(nil))
|
||||
|
||||
// Try deleting a k/v that doesn't exist.
|
||||
require.NoError(t, banana.Delete([]byte("nokey")))
|
||||
|
||||
// "apple/pear"
|
||||
pear, err := apple.CreateBucket([]byte("pear"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, pear)
|
||||
|
||||
// Put some values into "apple/pear"
|
||||
for _, kv := range kvs {
|
||||
require.Nil(t, pear.Put([]byte(kv.key), []byte(kv.val)))
|
||||
require.Equal(t, []byte(kv.val), pear.Get([]byte(kv.key)))
|
||||
}
|
||||
|
||||
// Create nested bucket "apple/pear/cherry"
|
||||
cherry, err := pear.CreateBucket([]byte("cherry"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, cherry)
|
||||
|
||||
// Put some values into "apple/pear/cherry"
|
||||
for _, kv := range kvs {
|
||||
require.NoError(t, cherry.Put([]byte(kv.key), []byte(kv.val)))
|
||||
}
|
||||
|
||||
// Read back values in "apple/pear/cherry" trough a read bucket.
|
||||
cherryReadBucket := pear.NestedReadBucket([]byte("cherry"))
|
||||
for _, kv := range kvs {
|
||||
require.Equal(
|
||||
t, []byte(kv.val),
|
||||
cherryReadBucket.Get([]byte(kv.key)),
|
||||
)
|
||||
}
|
||||
|
||||
// Try deleting some invalid buckets.
|
||||
require.Error(t,
|
||||
walletdb.ErrBucketNameRequired, apple.DeleteNestedBucket(nil),
|
||||
)
|
||||
|
||||
// Try deleting a non existing bucket.
|
||||
require.Error(
|
||||
t,
|
||||
walletdb.ErrBucketNotFound,
|
||||
apple.DeleteNestedBucket([]byte("missing")),
|
||||
)
|
||||
|
||||
// Delete "apple/pear"
|
||||
require.Nil(t, apple.DeleteNestedBucket([]byte("pear")))
|
||||
|
||||
// "apple/pear" deleted
|
||||
require.Nil(t, apple.NestedReadWriteBucket([]byte("pear")))
|
||||
|
||||
// "apple/pear/cherry" deleted
|
||||
require.Nil(t, pear.NestedReadWriteBucket([]byte("cherry")))
|
||||
|
||||
// Values deleted too.
|
||||
for _, kv := range kvs {
|
||||
require.Nil(t, pear.Get([]byte(kv.key)))
|
||||
require.Nil(t, cherry.Get([]byte(kv.key)))
|
||||
}
|
||||
|
||||
// "aple/banana" exists
|
||||
require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana")))
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
bkey("apple"): bval("apple"),
|
||||
bkey("apple", "banana"): bval("apple", "banana"),
|
||||
vkey("key1", "apple", "banana"): "val1",
|
||||
vkey("key3", "apple", "banana"): "val3",
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
}
|
||||
|
||||
func TestBucketForEach(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
// "apple"
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
// "apple/banana"
|
||||
banana, err := apple.CreateBucket([]byte("banana"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, banana)
|
||||
|
||||
kvs := []KV{{"key1", "val1"}, {"key2", "val2"}, {"key3", "val3"}}
|
||||
|
||||
// put some values into "apple" and "apple/banana" too
|
||||
for _, kv := range kvs {
|
||||
require.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val)))
|
||||
require.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key)))
|
||||
|
||||
require.Nil(t, banana.Put([]byte(kv.key), []byte(kv.val)))
|
||||
require.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key)))
|
||||
}
|
||||
|
||||
got := make(map[string]string)
|
||||
err = apple.ForEach(func(key, val []byte) error {
|
||||
got[string(key)] = string(val)
|
||||
return nil
|
||||
})
|
||||
|
||||
expected := map[string]string{
|
||||
"key1": "val1",
|
||||
"key2": "val2",
|
||||
"key3": "val3",
|
||||
"banana": "",
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, got)
|
||||
|
||||
got = make(map[string]string)
|
||||
err = banana.ForEach(func(key, val []byte) error {
|
||||
got[string(key)] = string(val)
|
||||
return nil
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
// remove the sub-bucket key
|
||||
delete(expected, "banana")
|
||||
require.Equal(t, expected, got)
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
bkey("apple"): bval("apple"),
|
||||
bkey("apple", "banana"): bval("apple", "banana"),
|
||||
vkey("key1", "apple"): "val1",
|
||||
vkey("key2", "apple"): "val2",
|
||||
vkey("key3", "apple"): "val3",
|
||||
vkey("key1", "apple", "banana"): "val1",
|
||||
vkey("key2", "apple", "banana"): "val2",
|
||||
vkey("key3", "apple", "banana"): "val3",
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
}
|
||||
|
||||
func TestBucketForEachWithError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
// "apple"
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
// "apple/banana"
|
||||
banana, err := apple.CreateBucket([]byte("banana"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, banana)
|
||||
|
||||
// "apple/pear"
|
||||
pear, err := apple.CreateBucket([]byte("pear"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, pear)
|
||||
|
||||
kvs := []KV{{"key1", "val1"}, {"key2", "val2"}}
|
||||
|
||||
// Put some values into "apple" and "apple/banana" too.
|
||||
for _, kv := range kvs {
|
||||
require.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val)))
|
||||
require.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key)))
|
||||
}
|
||||
|
||||
got := make(map[string]string)
|
||||
i := 0
|
||||
// Error while iterating value keys.
|
||||
err = apple.ForEach(func(key, val []byte) error {
|
||||
if i == 2 {
|
||||
return fmt.Errorf("error")
|
||||
}
|
||||
|
||||
got[string(key)] = string(val)
|
||||
i++
|
||||
return nil
|
||||
})
|
||||
|
||||
expected := map[string]string{
|
||||
"banana": "",
|
||||
"key1": "val1",
|
||||
}
|
||||
|
||||
require.Equal(t, expected, got)
|
||||
require.Error(t, err)
|
||||
|
||||
got = make(map[string]string)
|
||||
i = 0
|
||||
// Erro while iterating buckets.
|
||||
err = apple.ForEach(func(key, val []byte) error {
|
||||
if i == 3 {
|
||||
return fmt.Errorf("error")
|
||||
}
|
||||
|
||||
got[string(key)] = string(val)
|
||||
i++
|
||||
return nil
|
||||
})
|
||||
|
||||
expected = map[string]string{
|
||||
"banana": "",
|
||||
"key1": "val1",
|
||||
"key2": "val2",
|
||||
}
|
||||
|
||||
require.Equal(t, expected, got)
|
||||
require.Error(t, err)
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
bkey("apple"): bval("apple"),
|
||||
bkey("apple", "banana"): bval("apple", "banana"),
|
||||
bkey("apple", "pear"): bval("apple", "pear"),
|
||||
vkey("key1", "apple"): "val1",
|
||||
vkey("key2", "apple"): "val2",
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
}
|
||||
|
||||
func TestBucketSequence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
banana, err := apple.CreateBucket([]byte("banana"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, banana)
|
||||
|
||||
require.Equal(t, uint64(0), apple.Sequence())
|
||||
require.Equal(t, uint64(0), banana.Sequence())
|
||||
|
||||
require.Nil(t, apple.SetSequence(math.MaxUint64))
|
||||
require.Equal(t, uint64(math.MaxUint64), apple.Sequence())
|
||||
|
||||
for i := uint64(0); i < uint64(5); i++ {
|
||||
s, err := apple.NextSequence()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, i, s)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// TestKeyClash tests that one cannot create a bucket if a value with the same
|
||||
// key exists and the same is true in reverse: that a value cannot be put if
|
||||
// a bucket with the same key exists.
|
||||
func TestKeyClash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
// First:
|
||||
// put: /apple/key -> val
|
||||
// create bucket: /apple/banana
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
require.NoError(t, apple.Put([]byte("key"), []byte("val")))
|
||||
|
||||
banana, err := apple.CreateBucket([]byte("banana"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, banana)
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
// Next try to:
|
||||
// put: /apple/banana -> val => will fail (as /apple/banana is a bucket)
|
||||
// create bucket: /apple/key => will fail (as /apple/key is a value)
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
require.Error(t,
|
||||
walletdb.ErrIncompatibleValue,
|
||||
apple.Put([]byte("banana"), []byte("val")),
|
||||
)
|
||||
|
||||
b, err := apple.CreateBucket([]byte("key"))
|
||||
require.Nil(t, b)
|
||||
require.Error(t, walletdb.ErrIncompatibleValue, b)
|
||||
|
||||
b, err = apple.CreateBucketIfNotExists([]byte("key"))
|
||||
require.Nil(t, b)
|
||||
require.Error(t, walletdb.ErrIncompatibleValue, b)
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
// Except that the only existing items in the db are:
|
||||
// bucket: /apple
|
||||
// bucket: /apple/banana
|
||||
// value: /apple/key -> val
|
||||
expected := map[string]string{
|
||||
bkey("apple"): bval("apple"),
|
||||
bkey("apple", "banana"): bval("apple", "banana"),
|
||||
vkey("key", "apple"): "val",
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
|
||||
}
|
||||
|
||||
// TestBucketCreateDelete tests that creating then deleting then creating a
|
||||
// bucket suceeds.
|
||||
func TestBucketCreateDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
banana, err := apple.CreateBucket([]byte("banana"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, banana)
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
apple := tx.ReadWriteBucket([]byte("apple"))
|
||||
require.NotNil(t, apple)
|
||||
require.NoError(t, apple.DeleteNestedBucket([]byte("banana")))
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
apple := tx.ReadWriteBucket([]byte("apple"))
|
||||
require.NotNil(t, apple)
|
||||
require.NoError(t, apple.Put([]byte("banana"), []byte("value")))
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
vkey("banana", "apple"): "value",
|
||||
bkey("apple"): bval("apple"),
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
}
|
143
kvdb/etcd/readwrite_cursor.go
Normal file
143
kvdb/etcd/readwrite_cursor.go
Normal file
@@ -0,0 +1,143 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
// readWriteCursor holds a reference to the cursors bucket, the value
|
||||
// prefix and the current key used while iterating.
|
||||
type readWriteCursor struct {
|
||||
// bucket holds the reference to the parent bucket.
|
||||
bucket *readWriteBucket
|
||||
|
||||
// prefix holds the value prefix which is in front of each
|
||||
// value key in the bucket.
|
||||
prefix string
|
||||
|
||||
// currKey holds the current key of the cursor.
|
||||
currKey string
|
||||
}
|
||||
|
||||
func newReadWriteCursor(bucket *readWriteBucket) *readWriteCursor {
|
||||
return &readWriteCursor{
|
||||
bucket: bucket,
|
||||
prefix: string(bucket.id),
|
||||
}
|
||||
}
|
||||
|
||||
// First positions the cursor at the first key/value pair and returns
|
||||
// the pair.
|
||||
func (c *readWriteCursor) First() (key, value []byte) {
|
||||
// Get the first key with the value prefix.
|
||||
kv, err := c.bucket.tx.stm.First(c.prefix)
|
||||
if err != nil {
|
||||
// TODO: revise this once kvdb interface supports errors
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if kv != nil {
|
||||
c.currKey = kv.key
|
||||
return getKeyVal(kv)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Last positions the cursor at the last key/value pair and returns the
|
||||
// pair.
|
||||
func (c *readWriteCursor) Last() (key, value []byte) {
|
||||
kv, err := c.bucket.tx.stm.Last(c.prefix)
|
||||
if err != nil {
|
||||
// TODO: revise this once kvdb interface supports errors
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if kv != nil {
|
||||
c.currKey = kv.key
|
||||
return getKeyVal(kv)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Next moves the cursor one key/value pair forward and returns the new
|
||||
// pair.
|
||||
func (c *readWriteCursor) Next() (key, value []byte) {
|
||||
kv, err := c.bucket.tx.stm.Next(c.prefix, c.currKey)
|
||||
if err != nil {
|
||||
// TODO: revise this once kvdb interface supports errors
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if kv != nil {
|
||||
c.currKey = kv.key
|
||||
return getKeyVal(kv)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Prev moves the cursor one key/value pair backward and returns the new
|
||||
// pair.
|
||||
func (c *readWriteCursor) Prev() (key, value []byte) {
|
||||
kv, err := c.bucket.tx.stm.Prev(c.prefix, c.currKey)
|
||||
if err != nil {
|
||||
// TODO: revise this once kvdb interface supports errors
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if kv != nil {
|
||||
c.currKey = kv.key
|
||||
return getKeyVal(kv)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Seek positions the cursor at the passed seek key. If the key does
|
||||
// not exist, the cursor is moved to the next key after seek. Returns
|
||||
// the new pair.
|
||||
func (c *readWriteCursor) Seek(seek []byte) (key, value []byte) {
|
||||
// Return nil if trying to seek to an empty key.
|
||||
if seek == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Seek to the first key with prefix + seek. If that key is not present
|
||||
// STM will seek to the next matching key with prefix.
|
||||
kv, err := c.bucket.tx.stm.Seek(c.prefix, c.prefix+string(seek))
|
||||
if err != nil {
|
||||
// TODO: revise this once kvdb interface supports errors
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if kv != nil {
|
||||
c.currKey = kv.key
|
||||
return getKeyVal(kv)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Delete removes the current key/value pair the cursor is at without
|
||||
// invalidating the cursor. Returns ErrIncompatibleValue if attempted
|
||||
// when the cursor points to a nested bucket.
|
||||
func (c *readWriteCursor) Delete() error {
|
||||
// Get the next key after the current one. We could do this
|
||||
// after deletion too but it's one step more efficient here.
|
||||
nextKey, err := c.bucket.tx.stm.Next(c.prefix, c.currKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isBucketKey(c.currKey) {
|
||||
c.bucket.DeleteNestedBucket(getKey(c.currKey))
|
||||
} else {
|
||||
c.bucket.Delete(getKey(c.currKey))
|
||||
}
|
||||
|
||||
if nextKey != nil {
|
||||
// Set current key to the next one.
|
||||
c.currKey = nextKey.key
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
369
kvdb/etcd/readwrite_cursor_test.go
Normal file
369
kvdb/etcd/readwrite_cursor_test.go
Normal file
@@ -0,0 +1,369 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestReadCursorEmptyInterval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
b, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, b)
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.View(func(tx walletdb.ReadTx) error {
|
||||
b := tx.ReadBucket([]byte("apple"))
|
||||
require.NotNil(t, b)
|
||||
|
||||
cursor := b.ReadCursor()
|
||||
k, v := cursor.First()
|
||||
require.Nil(t, k)
|
||||
require.Nil(t, v)
|
||||
|
||||
k, v = cursor.Next()
|
||||
require.Nil(t, k)
|
||||
require.Nil(t, v)
|
||||
|
||||
k, v = cursor.Last()
|
||||
require.Nil(t, k)
|
||||
require.Nil(t, v)
|
||||
|
||||
k, v = cursor.Prev()
|
||||
require.Nil(t, k)
|
||||
require.Nil(t, v)
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestReadCursorNonEmptyInterval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
testKeyValues := []KV{
|
||||
{"b", "1"},
|
||||
{"c", "2"},
|
||||
{"da", "3"},
|
||||
{"e", "4"},
|
||||
}
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
b, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, b)
|
||||
|
||||
for _, kv := range testKeyValues {
|
||||
require.NoError(t, b.Put([]byte(kv.key), []byte(kv.val)))
|
||||
}
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.View(func(tx walletdb.ReadTx) error {
|
||||
b := tx.ReadBucket([]byte("apple"))
|
||||
require.NotNil(t, b)
|
||||
|
||||
// Iterate from the front.
|
||||
var kvs []KV
|
||||
cursor := b.ReadCursor()
|
||||
k, v := cursor.First()
|
||||
|
||||
for k != nil && v != nil {
|
||||
kvs = append(kvs, KV{string(k), string(v)})
|
||||
k, v = cursor.Next()
|
||||
}
|
||||
require.Equal(t, testKeyValues, kvs)
|
||||
|
||||
// Iterate from the back.
|
||||
kvs = []KV{}
|
||||
k, v = cursor.Last()
|
||||
|
||||
for k != nil && v != nil {
|
||||
kvs = append(kvs, KV{string(k), string(v)})
|
||||
k, v = cursor.Prev()
|
||||
}
|
||||
require.Equal(t, reverseKVs(testKeyValues), kvs)
|
||||
|
||||
// Random access
|
||||
perm := []int{3, 0, 2, 1}
|
||||
for _, i := range perm {
|
||||
k, v := cursor.Seek([]byte(testKeyValues[i].key))
|
||||
require.Equal(t, []byte(testKeyValues[i].key), k)
|
||||
require.Equal(t, []byte(testKeyValues[i].val), v)
|
||||
}
|
||||
|
||||
// Seek to nonexisting key.
|
||||
k, v = cursor.Seek(nil)
|
||||
require.Nil(t, k)
|
||||
require.Nil(t, v)
|
||||
|
||||
k, v = cursor.Seek([]byte("x"))
|
||||
require.Nil(t, k)
|
||||
require.Nil(t, v)
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestReadWriteCursor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
testKeyValues := []KV{
|
||||
{"b", "1"},
|
||||
{"c", "2"},
|
||||
{"da", "3"},
|
||||
{"e", "4"},
|
||||
}
|
||||
|
||||
count := len(testKeyValues)
|
||||
|
||||
// Pre-store the first half of the interval.
|
||||
require.NoError(t, db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
b, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, b)
|
||||
|
||||
for i := 0; i < count/2; i++ {
|
||||
err = b.Put(
|
||||
[]byte(testKeyValues[i].key),
|
||||
[]byte(testKeyValues[i].val),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
return nil
|
||||
}, func() {}))
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
b := tx.ReadWriteBucket([]byte("apple"))
|
||||
require.NotNil(t, b)
|
||||
|
||||
// Store the second half of the interval.
|
||||
for i := count / 2; i < count; i++ {
|
||||
err = b.Put(
|
||||
[]byte(testKeyValues[i].key),
|
||||
[]byte(testKeyValues[i].val),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
cursor := b.ReadWriteCursor()
|
||||
|
||||
// First on valid interval.
|
||||
fk, fv := cursor.First()
|
||||
require.Equal(t, []byte("b"), fk)
|
||||
require.Equal(t, []byte("1"), fv)
|
||||
|
||||
// Prev(First()) = nil
|
||||
k, v := cursor.Prev()
|
||||
require.Nil(t, k)
|
||||
require.Nil(t, v)
|
||||
|
||||
// Last on valid interval.
|
||||
lk, lv := cursor.Last()
|
||||
require.Equal(t, []byte("e"), lk)
|
||||
require.Equal(t, []byte("4"), lv)
|
||||
|
||||
// Next(Last()) = nil
|
||||
k, v = cursor.Next()
|
||||
require.Nil(t, k)
|
||||
require.Nil(t, v)
|
||||
|
||||
// Delete first item, then add an item before the
|
||||
// deleted one. Check that First/Next will "jump"
|
||||
// over the deleted item and return the new first.
|
||||
_, _ = cursor.First()
|
||||
require.NoError(t, cursor.Delete())
|
||||
require.NoError(t, b.Put([]byte("a"), []byte("0")))
|
||||
fk, fv = cursor.First()
|
||||
|
||||
require.Equal(t, []byte("a"), fk)
|
||||
require.Equal(t, []byte("0"), fv)
|
||||
|
||||
k, v = cursor.Next()
|
||||
require.Equal(t, []byte("c"), k)
|
||||
require.Equal(t, []byte("2"), v)
|
||||
|
||||
// Similarly test that a new end is returned if
|
||||
// the old end is deleted first.
|
||||
_, _ = cursor.Last()
|
||||
require.NoError(t, cursor.Delete())
|
||||
require.NoError(t, b.Put([]byte("f"), []byte("5")))
|
||||
|
||||
lk, lv = cursor.Last()
|
||||
require.Equal(t, []byte("f"), lk)
|
||||
require.Equal(t, []byte("5"), lv)
|
||||
|
||||
k, v = cursor.Prev()
|
||||
require.Equal(t, []byte("da"), k)
|
||||
require.Equal(t, []byte("3"), v)
|
||||
|
||||
// Overwrite k/v in the middle of the interval.
|
||||
require.NoError(t, b.Put([]byte("c"), []byte("3")))
|
||||
k, v = cursor.Prev()
|
||||
require.Equal(t, []byte("c"), k)
|
||||
require.Equal(t, []byte("3"), v)
|
||||
|
||||
// Insert new key/values.
|
||||
require.NoError(t, b.Put([]byte("cx"), []byte("x")))
|
||||
require.NoError(t, b.Put([]byte("cy"), []byte("y")))
|
||||
|
||||
k, v = cursor.Next()
|
||||
require.Equal(t, []byte("cx"), k)
|
||||
require.Equal(t, []byte("x"), v)
|
||||
|
||||
k, v = cursor.Next()
|
||||
require.Equal(t, []byte("cy"), k)
|
||||
require.Equal(t, []byte("y"), v)
|
||||
|
||||
expected := []KV{
|
||||
{"a", "0"},
|
||||
{"c", "3"},
|
||||
{"cx", "x"},
|
||||
{"cy", "y"},
|
||||
{"da", "3"},
|
||||
{"f", "5"},
|
||||
}
|
||||
|
||||
// Iterate from the front.
|
||||
var kvs []KV
|
||||
k, v = cursor.First()
|
||||
|
||||
for k != nil && v != nil {
|
||||
kvs = append(kvs, KV{string(k), string(v)})
|
||||
k, v = cursor.Next()
|
||||
}
|
||||
require.Equal(t, expected, kvs)
|
||||
|
||||
// Iterate from the back.
|
||||
kvs = []KV{}
|
||||
k, v = cursor.Last()
|
||||
|
||||
for k != nil && v != nil {
|
||||
kvs = append(kvs, KV{string(k), string(v)})
|
||||
k, v = cursor.Prev()
|
||||
}
|
||||
require.Equal(t, reverseKVs(expected), kvs)
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
bkey("apple"): bval("apple"),
|
||||
vkey("a", "apple"): "0",
|
||||
vkey("c", "apple"): "3",
|
||||
vkey("cx", "apple"): "x",
|
||||
vkey("cy", "apple"): "y",
|
||||
vkey("da", "apple"): "3",
|
||||
vkey("f", "apple"): "5",
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
}
|
||||
|
||||
// TestReadWriteCursorWithBucketAndValue tests that cursors are able to iterate
|
||||
// over both bucket and value keys if both are present in the iterated bucket.
|
||||
func TestReadWriteCursorWithBucketAndValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Pre-store the first half of the interval.
|
||||
require.NoError(t, db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
b, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, b)
|
||||
|
||||
require.NoError(t, b.Put([]byte("key"), []byte("val")))
|
||||
|
||||
b1, err := b.CreateBucket([]byte("banana"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, b1)
|
||||
|
||||
b2, err := b.CreateBucket([]byte("pear"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, b2)
|
||||
|
||||
return nil
|
||||
}, func() {}))
|
||||
|
||||
err = db.View(func(tx walletdb.ReadTx) error {
|
||||
b := tx.ReadBucket([]byte("apple"))
|
||||
require.NotNil(t, b)
|
||||
|
||||
cursor := b.ReadCursor()
|
||||
|
||||
// First on valid interval.
|
||||
k, v := cursor.First()
|
||||
require.Equal(t, []byte("banana"), k)
|
||||
require.Nil(t, v)
|
||||
|
||||
k, v = cursor.Next()
|
||||
require.Equal(t, []byte("key"), k)
|
||||
require.Equal(t, []byte("val"), v)
|
||||
|
||||
k, v = cursor.Last()
|
||||
require.Equal(t, []byte("pear"), k)
|
||||
require.Nil(t, v)
|
||||
|
||||
k, v = cursor.Seek([]byte("k"))
|
||||
require.Equal(t, []byte("key"), k)
|
||||
require.Equal(t, []byte("val"), v)
|
||||
|
||||
k, v = cursor.Seek([]byte("banana"))
|
||||
require.Equal(t, []byte("banana"), k)
|
||||
require.Nil(t, v)
|
||||
|
||||
k, v = cursor.Next()
|
||||
require.Equal(t, []byte("key"), k)
|
||||
require.Equal(t, []byte("val"), v)
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
bkey("apple"): bval("apple"),
|
||||
bkey("apple", "banana"): bval("apple", "banana"),
|
||||
bkey("apple", "pear"): bval("apple", "pear"),
|
||||
vkey("key", "apple"): "val",
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
}
|
99
kvdb/etcd/readwrite_tx.go
Normal file
99
kvdb/etcd/readwrite_tx.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
)
|
||||
|
||||
// readWriteTx holds a reference to the STM transaction.
|
||||
type readWriteTx struct {
|
||||
// stm is the reference to the parent STM.
|
||||
stm STM
|
||||
|
||||
// rootBucketID holds the sha256 hash of the root bucket id, which is used
|
||||
// for key space spearation.
|
||||
rootBucketID [bucketIDLength]byte
|
||||
|
||||
// active is true if the transaction hasn't been committed yet.
|
||||
active bool
|
||||
}
|
||||
|
||||
// newReadWriteTx creates an rw transaction with the passed STM.
|
||||
func newReadWriteTx(stm STM, prefix string) *readWriteTx {
|
||||
return &readWriteTx{
|
||||
stm: stm,
|
||||
active: true,
|
||||
rootBucketID: makeBucketID([]byte(prefix)),
|
||||
}
|
||||
}
|
||||
|
||||
// rooBucket is a helper function to return the always present
|
||||
// pseudo root bucket.
|
||||
func rootBucket(tx *readWriteTx) *readWriteBucket {
|
||||
return newReadWriteBucket(tx, tx.rootBucketID[:], tx.rootBucketID[:])
|
||||
}
|
||||
|
||||
// ReadBucket opens the root bucket for read only access. If the bucket
|
||||
// described by the key does not exist, nil is returned.
|
||||
func (tx *readWriteTx) ReadBucket(key []byte) walletdb.ReadBucket {
|
||||
return rootBucket(tx).NestedReadWriteBucket(key)
|
||||
}
|
||||
|
||||
// Rollback closes the transaction, discarding changes (if any) if the
|
||||
// database was modified by a write transaction.
|
||||
func (tx *readWriteTx) Rollback() error {
|
||||
// If the transaction has been closed roolback will fail.
|
||||
if !tx.active {
|
||||
return walletdb.ErrTxClosed
|
||||
}
|
||||
|
||||
// Rollback the STM and set the tx to inactive.
|
||||
tx.stm.Rollback()
|
||||
tx.active = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadWriteBucket opens the root bucket for read/write access. If the
|
||||
// bucket described by the key does not exist, nil is returned.
|
||||
func (tx *readWriteTx) ReadWriteBucket(key []byte) walletdb.ReadWriteBucket {
|
||||
return rootBucket(tx).NestedReadWriteBucket(key)
|
||||
}
|
||||
|
||||
// CreateTopLevelBucket creates the top level bucket for a key if it
|
||||
// does not exist. The newly-created bucket it returned.
|
||||
func (tx *readWriteTx) CreateTopLevelBucket(key []byte) (walletdb.ReadWriteBucket, error) {
|
||||
return rootBucket(tx).CreateBucketIfNotExists(key)
|
||||
}
|
||||
|
||||
// DeleteTopLevelBucket deletes the top level bucket for a key. This
|
||||
// errors if the bucket can not be found or the key keys a single value
|
||||
// instead of a bucket.
|
||||
func (tx *readWriteTx) DeleteTopLevelBucket(key []byte) error {
|
||||
return rootBucket(tx).DeleteNestedBucket(key)
|
||||
}
|
||||
|
||||
// Commit commits the transaction if not already committed. Will return
|
||||
// error if the underlying STM fails.
|
||||
func (tx *readWriteTx) Commit() error {
|
||||
// Commit will fail if the transaction is already committed.
|
||||
if !tx.active {
|
||||
return walletdb.ErrTxClosed
|
||||
}
|
||||
|
||||
// Try committing the transaction.
|
||||
if err := tx.stm.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark the transaction as not active after commit.
|
||||
tx.active = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnCommit sets the commit callback (overriding if already set).
|
||||
func (tx *readWriteTx) OnCommit(cb func()) {
|
||||
tx.stm.OnCommit(cb)
|
||||
}
|
157
kvdb/etcd/readwrite_tx_test.go
Normal file
157
kvdb/etcd/readwrite_tx_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTxManualCommit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
tx, err := db.BeginReadWriteTx()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tx)
|
||||
|
||||
committed := false
|
||||
|
||||
tx.OnCommit(func() {
|
||||
committed = true
|
||||
})
|
||||
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, apple)
|
||||
require.NoError(t, apple.Put([]byte("testKey"), []byte("testVal")))
|
||||
|
||||
banana, err := tx.CreateTopLevelBucket([]byte("banana"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, banana)
|
||||
require.NoError(t, banana.Put([]byte("testKey"), []byte("testVal")))
|
||||
require.NoError(t, tx.DeleteTopLevelBucket([]byte("banana")))
|
||||
|
||||
require.NoError(t, tx.Commit())
|
||||
require.True(t, committed)
|
||||
|
||||
expected := map[string]string{
|
||||
bkey("apple"): bval("apple"),
|
||||
vkey("testKey", "apple"): "testVal",
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
}
|
||||
|
||||
func TestTxRollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
tx, err := db.BeginReadWriteTx()
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, tx)
|
||||
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
require.NoError(t, apple.Put([]byte("testKey"), []byte("testVal")))
|
||||
|
||||
require.NoError(t, tx.Rollback())
|
||||
require.Error(t, walletdb.ErrTxClosed, tx.Commit())
|
||||
require.Equal(t, map[string]string{}, f.Dump())
|
||||
}
|
||||
|
||||
func TestChangeDuringManualTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
tx, err := db.BeginReadWriteTx()
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, tx)
|
||||
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
require.NoError(t, apple.Put([]byte("testKey"), []byte("testVal")))
|
||||
|
||||
// Try overwriting the bucket key.
|
||||
f.Put(bkey("apple"), "banana")
|
||||
|
||||
// TODO: translate error
|
||||
require.NotNil(t, tx.Commit())
|
||||
require.Equal(t, map[string]string{
|
||||
bkey("apple"): "banana",
|
||||
}, f.Dump())
|
||||
}
|
||||
|
||||
func TestChangeDuringUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(context.TODO(), f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
|
||||
err = db.Update(func(tx walletdb.ReadWriteTx) error {
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, apple)
|
||||
|
||||
require.NoError(t, apple.Put([]byte("key"), []byte("value")))
|
||||
|
||||
if count == 0 {
|
||||
f.Put(vkey("key", "apple"), "new_value")
|
||||
f.Put(vkey("key2", "apple"), "value2")
|
||||
}
|
||||
|
||||
cursor := apple.ReadCursor()
|
||||
k, v := cursor.First()
|
||||
require.Equal(t, []byte("key"), k)
|
||||
require.Equal(t, []byte("value"), v)
|
||||
require.Equal(t, v, apple.Get([]byte("key")))
|
||||
|
||||
k, v = cursor.Next()
|
||||
if count == 0 {
|
||||
require.Nil(t, k)
|
||||
require.Nil(t, v)
|
||||
} else {
|
||||
require.Equal(t, []byte("key2"), k)
|
||||
require.Equal(t, []byte("value2"), v)
|
||||
}
|
||||
|
||||
count++
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, count, 2)
|
||||
|
||||
expected := map[string]string{
|
||||
bkey("apple"): bval("apple"),
|
||||
vkey("key", "apple"): "value",
|
||||
vkey("key2", "apple"): "value2",
|
||||
}
|
||||
require.Equal(t, expected, f.Dump())
|
||||
}
|
805
kvdb/etcd/stm.go
Normal file
805
kvdb/etcd/stm.go
Normal file
@@ -0,0 +1,805 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
v3 "go.etcd.io/etcd/clientv3"
|
||||
)
|
||||
|
||||
type CommitStats struct {
|
||||
Rset int
|
||||
Wset int
|
||||
Retries int
|
||||
}
|
||||
|
||||
// KV stores a key/value pair.
|
||||
type KV struct {
|
||||
key string
|
||||
val string
|
||||
}
|
||||
|
||||
// STM is an interface for software transactional memory.
|
||||
// All calls that return error will do so only if STM is manually handled and
|
||||
// abort the apply closure otherwise. In both case the returned error is a
|
||||
// DatabaseError.
|
||||
type STM interface {
|
||||
// Get returns the value for a key and inserts the key in the txn's read
|
||||
// set. Returns nil if there's no matching key, or the key is empty.
|
||||
Get(key string) ([]byte, error)
|
||||
|
||||
// Put adds a value for a key to the txn's write set.
|
||||
Put(key, val string)
|
||||
|
||||
// Del adds a delete operation for the key to the txn's write set.
|
||||
Del(key string)
|
||||
|
||||
// First returns the first k/v that begins with prefix or nil if there's
|
||||
// no such k/v pair. If the key is found it is inserted to the txn's
|
||||
// read set. Returns nil if there's no match.
|
||||
First(prefix string) (*KV, error)
|
||||
|
||||
// Last returns the last k/v that begins with prefix or nil if there's
|
||||
// no such k/v pair. If the key is found it is inserted to the txn's
|
||||
// read set. Returns nil if there's no match.
|
||||
Last(prefix string) (*KV, error)
|
||||
|
||||
// Prev returns the previous k/v before key that begins with prefix or
|
||||
// nil if there's no such k/v. If the key is found it is inserted to the
|
||||
// read set. Returns nil if there's no match.
|
||||
Prev(prefix, key string) (*KV, error)
|
||||
|
||||
// Next returns the next k/v after key that begins with prefix or nil
|
||||
// if there's no such k/v. If the key is found it is inserted to the
|
||||
// txn's read set. Returns nil if there's no match.
|
||||
Next(prefix, key string) (*KV, error)
|
||||
|
||||
// Seek will return k/v at key beginning with prefix. If the key doesn't
|
||||
// exists Seek will return the next k/v after key beginning with prefix.
|
||||
// If a matching k/v is found it is inserted to the txn's read set. Returns
|
||||
// nil if there's no match.
|
||||
Seek(prefix, key string) (*KV, error)
|
||||
|
||||
// OnCommit calls the passed callback func upon commit.
|
||||
OnCommit(func())
|
||||
|
||||
// Commit attempts to apply the txn's changes to the server.
|
||||
// Commit may return CommitError if transaction is outdated and needs retry.
|
||||
Commit() error
|
||||
|
||||
// Rollback emties the read and write sets such that a subsequent commit
|
||||
// won't alter the database.
|
||||
Rollback()
|
||||
}
|
||||
|
||||
// CommitError is used to check if there was an error
|
||||
// due to stale data in the transaction.
|
||||
type CommitError struct{}
|
||||
|
||||
// Error returns a static string for CommitError for
|
||||
// debugging/logging purposes.
|
||||
func (e CommitError) Error() string {
|
||||
return "commit failed"
|
||||
}
|
||||
|
||||
// DatabaseError is used to wrap errors that are not
|
||||
// related to stale data in the transaction.
|
||||
type DatabaseError struct {
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
// Unwrap returns the wrapped error in a DatabaseError.
|
||||
func (e *DatabaseError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// Error simply converts DatabaseError to a string that
|
||||
// includes both the message and the wrapped error.
|
||||
func (e DatabaseError) Error() string {
|
||||
return fmt.Sprintf("etcd error: %v - %v", e.msg, e.err)
|
||||
}
|
||||
|
||||
// stmGet is the result of a read operation,
|
||||
// a value and the mod revision of the key/value.
|
||||
type stmGet struct {
|
||||
val string
|
||||
rev int64
|
||||
}
|
||||
|
||||
// readSet stores all reads done in an STM.
|
||||
type readSet map[string]stmGet
|
||||
|
||||
// stmPut stores a value and an operation (put/delete).
|
||||
type stmPut struct {
|
||||
val string
|
||||
op v3.Op
|
||||
}
|
||||
|
||||
// writeSet stroes all writes done in an STM.
|
||||
type writeSet map[string]stmPut
|
||||
|
||||
// stm implements repeatable-read software transactional memory
|
||||
// over etcd.
|
||||
type stm struct {
|
||||
// client is an etcd client handling all RPC communications
|
||||
// to the etcd instance/cluster.
|
||||
client *v3.Client
|
||||
|
||||
// manual is set to true for manual transactions which don't
|
||||
// execute in the STM run loop.
|
||||
manual bool
|
||||
|
||||
// txQueue is lightweight contention manager, which is used to detect
|
||||
// transaction conflicts and reduce retries.
|
||||
txQueue *commitQueue
|
||||
|
||||
// options stores optional settings passed by the user.
|
||||
options *STMOptions
|
||||
|
||||
// prefetch hold prefetched key values and revisions.
|
||||
prefetch readSet
|
||||
|
||||
// rset holds read key values and revisions.
|
||||
rset readSet
|
||||
|
||||
// wset holds overwritten keys and their values.
|
||||
wset writeSet
|
||||
|
||||
// getOpts are the opts used for gets.
|
||||
getOpts []v3.OpOption
|
||||
|
||||
// revision stores the snapshot revision after first read.
|
||||
revision int64
|
||||
|
||||
// onCommit gets called upon commit.
|
||||
onCommit func()
|
||||
}
|
||||
|
||||
// STMOptions can be used to pass optional settings
|
||||
// when an STM is created.
|
||||
type STMOptions struct {
|
||||
// ctx holds an externally provided abort context.
|
||||
ctx context.Context
|
||||
commitStatsCallback func(bool, CommitStats)
|
||||
}
|
||||
|
||||
// STMOptionFunc is a function that updates the passed STMOptions.
|
||||
type STMOptionFunc func(*STMOptions)
|
||||
|
||||
// WithAbortContext specifies the context for permanently
|
||||
// aborting the transaction.
|
||||
func WithAbortContext(ctx context.Context) STMOptionFunc {
|
||||
return func(so *STMOptions) {
|
||||
so.ctx = ctx
|
||||
}
|
||||
}
|
||||
|
||||
func WithCommitStatsCallback(cb func(bool, CommitStats)) STMOptionFunc {
|
||||
return func(so *STMOptions) {
|
||||
so.commitStatsCallback = cb
|
||||
}
|
||||
}
|
||||
|
||||
// RunSTM runs the apply function by creating an STM using serializable snapshot
|
||||
// isolation, passing it to the apply and handling commit errors and retries.
|
||||
func RunSTM(cli *v3.Client, apply func(STM) error, txQueue *commitQueue,
|
||||
so ...STMOptionFunc) error {
|
||||
|
||||
return runSTM(makeSTM(cli, false, txQueue, so...), apply)
|
||||
}
|
||||
|
||||
// NewSTM creates a new STM instance, using serializable snapshot isolation.
|
||||
func NewSTM(cli *v3.Client, txQueue *commitQueue, so ...STMOptionFunc) STM {
|
||||
return makeSTM(cli, true, txQueue, so...)
|
||||
}
|
||||
|
||||
// makeSTM is the actual constructor of the stm. It first apply all passed
|
||||
// options then creates the stm object and resets it before returning.
|
||||
func makeSTM(cli *v3.Client, manual bool, txQueue *commitQueue,
|
||||
so ...STMOptionFunc) *stm {
|
||||
|
||||
opts := &STMOptions{
|
||||
ctx: cli.Ctx(),
|
||||
}
|
||||
|
||||
// Apply all functional options.
|
||||
for _, fo := range so {
|
||||
fo(opts)
|
||||
}
|
||||
|
||||
s := &stm{
|
||||
client: cli,
|
||||
manual: manual,
|
||||
txQueue: txQueue,
|
||||
options: opts,
|
||||
prefetch: make(map[string]stmGet),
|
||||
}
|
||||
|
||||
// Reset read and write set.
|
||||
s.Rollback()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// runSTM implements the run loop of the STM, running the apply func, catching
|
||||
// errors and handling commit. The loop will quit on every error except
|
||||
// CommitError which is used to indicate a necessary retry.
|
||||
func runSTM(s *stm, apply func(STM) error) error {
|
||||
var (
|
||||
retries int
|
||||
stats CommitStats
|
||||
executeErr error
|
||||
)
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
execute := func() {
|
||||
defer close(done)
|
||||
|
||||
for {
|
||||
select {
|
||||
// Check if the STM is aborted and break the retry loop
|
||||
// if it is.
|
||||
case <-s.options.ctx.Done():
|
||||
executeErr = fmt.Errorf("aborted")
|
||||
return
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
stats, executeErr = s.commit()
|
||||
|
||||
// Re-apply only upon commit error (meaning the
|
||||
// keys were changed).
|
||||
if _, ok := executeErr.(CommitError); !ok {
|
||||
// Anything that's not a CommitError
|
||||
// aborts the transaction.
|
||||
return
|
||||
}
|
||||
|
||||
// Rollback before trying to re-apply.
|
||||
s.Rollback()
|
||||
retries++
|
||||
|
||||
// Re-apply the transaction closure.
|
||||
if executeErr = apply(s); executeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run the tx closure to construct the read and write sets.
|
||||
// Also we expect that if there are no conflicting transactions
|
||||
// in the queue, then we only run apply once.
|
||||
if preApplyErr := apply(s); preApplyErr != nil {
|
||||
return preApplyErr
|
||||
}
|
||||
|
||||
// Queue up the transaction for execution.
|
||||
s.txQueue.Add(execute, s.rset, s.wset)
|
||||
|
||||
// Wait for the transaction to execute, or break if aborted.
|
||||
select {
|
||||
case <-done:
|
||||
case <-s.options.ctx.Done():
|
||||
}
|
||||
|
||||
s.txQueue.Done(s.rset, s.wset)
|
||||
|
||||
if s.options.commitStatsCallback != nil {
|
||||
stats.Retries = retries
|
||||
s.options.commitStatsCallback(executeErr == nil, stats)
|
||||
}
|
||||
|
||||
return executeErr
|
||||
}
|
||||
|
||||
// add inserts a txn response to the read set. This is useful when the txn
|
||||
// fails due to conflict where the txn response can be used to prefetch
|
||||
// key/values.
|
||||
func (rs readSet) add(txnResp *v3.TxnResponse) {
|
||||
for _, resp := range txnResp.Responses {
|
||||
getResp := (*v3.GetResponse)(resp.GetResponseRange())
|
||||
for _, kv := range getResp.Kvs {
|
||||
rs[string(kv.Key)] = stmGet{
|
||||
val: string(kv.Value),
|
||||
rev: kv.ModRevision,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gets is a helper to create an op slice for transaction
|
||||
// construction.
|
||||
func (rs readSet) gets() []v3.Op {
|
||||
ops := make([]v3.Op, 0, len(rs))
|
||||
|
||||
for k := range rs {
|
||||
ops = append(ops, v3.OpGet(k))
|
||||
}
|
||||
|
||||
return ops
|
||||
}
|
||||
|
||||
// cmps returns a compare list which will serve as a precondition testing that
|
||||
// the values in the read set didn't change.
|
||||
func (rs readSet) cmps() []v3.Cmp {
|
||||
cmps := make([]v3.Cmp, 0, len(rs))
|
||||
for key, getValue := range rs {
|
||||
cmps = append(cmps, v3.Compare(
|
||||
v3.ModRevision(key), "=", getValue.rev,
|
||||
))
|
||||
}
|
||||
|
||||
return cmps
|
||||
}
|
||||
|
||||
// cmps returns a cmp list testing no writes have happened past rev.
|
||||
func (ws writeSet) cmps(rev int64) []v3.Cmp {
|
||||
cmps := make([]v3.Cmp, 0, len(ws))
|
||||
for key := range ws {
|
||||
cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
|
||||
}
|
||||
|
||||
return cmps
|
||||
}
|
||||
|
||||
// puts is the list of ops for all pending writes.
|
||||
func (ws writeSet) puts() []v3.Op {
|
||||
puts := make([]v3.Op, 0, len(ws))
|
||||
for _, v := range ws {
|
||||
puts = append(puts, v.op)
|
||||
}
|
||||
|
||||
return puts
|
||||
}
|
||||
|
||||
// fetch is a helper to fetch key/value given options. If a value is returned
|
||||
// then fetch will try to fix the STM's snapshot revision (if not already set).
|
||||
// We'll also cache the returned key/value in the read set.
|
||||
func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) {
|
||||
resp, err := s.client.Get(
|
||||
s.options.ctx, key, append(opts, s.getOpts...)...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, DatabaseError{
|
||||
msg: "stm.fetch() failed",
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Set revison and serializable options upon first fetch
|
||||
// for any subsequent fetches.
|
||||
if s.getOpts == nil {
|
||||
s.revision = resp.Header.Revision
|
||||
s.getOpts = []v3.OpOption{
|
||||
v3.WithRev(s.revision),
|
||||
v3.WithSerializable(),
|
||||
}
|
||||
}
|
||||
|
||||
if len(resp.Kvs) == 0 {
|
||||
// Add assertion to the read set which will extend our commit
|
||||
// constraint such that the commit will fail if the key is
|
||||
// present in the database.
|
||||
s.rset[key] = stmGet{
|
||||
rev: 0,
|
||||
}
|
||||
}
|
||||
|
||||
var result []KV
|
||||
|
||||
// Fill the read set with key/values returned.
|
||||
for _, kv := range resp.Kvs {
|
||||
// Remove from prefetch.
|
||||
key := string(kv.Key)
|
||||
val := string(kv.Value)
|
||||
|
||||
delete(s.prefetch, key)
|
||||
|
||||
// Add to read set.
|
||||
s.rset[key] = stmGet{
|
||||
val: val,
|
||||
rev: kv.ModRevision,
|
||||
}
|
||||
|
||||
result = append(result, KV{key, val})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Get returns the value for key. If there's no such
|
||||
// key/value in the database or the passed key is empty
|
||||
// Get will return nil.
|
||||
func (s *stm) Get(key string) ([]byte, error) {
|
||||
if key == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Return freshly written value if present.
|
||||
if put, ok := s.wset[key]; ok {
|
||||
if put.op.IsDelete() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []byte(put.val), nil
|
||||
}
|
||||
|
||||
// Populate read set if key is present in
|
||||
// the prefetch set.
|
||||
if getValue, ok := s.prefetch[key]; ok {
|
||||
delete(s.prefetch, key)
|
||||
|
||||
// Use the prefetched value only if it is for
|
||||
// an existing key.
|
||||
if getValue.rev != 0 {
|
||||
s.rset[key] = getValue
|
||||
}
|
||||
}
|
||||
|
||||
// Return value if alread in read set.
|
||||
if getValue, ok := s.rset[key]; ok {
|
||||
// Return the value if the rset contains an existing key.
|
||||
if getValue.rev != 0 {
|
||||
return []byte(getValue.val), nil
|
||||
} else {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch and return value.
|
||||
kvs, err := s.fetch(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(kvs) > 0 {
|
||||
return []byte(kvs[0].val), nil
|
||||
}
|
||||
|
||||
// Return empty result if key not in DB.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// First returns the first key/value matching prefix. If there's no key starting
|
||||
// with prefix, Last will return nil.
|
||||
func (s *stm) First(prefix string) (*KV, error) {
|
||||
return s.next(prefix, prefix, true)
|
||||
}
|
||||
|
||||
// Last returns the last key/value with prefix. If there's no key starting with
|
||||
// prefix, Last will return nil.
|
||||
func (s *stm) Last(prefix string) (*KV, error) {
|
||||
// As we don't know the full range, fetch the last
|
||||
// key/value with this prefix first.
|
||||
resp, err := s.fetch(prefix, v3.WithLastKey()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
kv KV
|
||||
found bool
|
||||
)
|
||||
|
||||
if len(resp) > 0 {
|
||||
kv = resp[0]
|
||||
found = true
|
||||
}
|
||||
|
||||
// Now make sure there's nothing in the write set
|
||||
// that is a better match, meaning it has the same
|
||||
// prefix but is greater or equal than the current
|
||||
// best candidate. Note that this is not efficient
|
||||
// when the write set is large!
|
||||
for k, put := range s.wset {
|
||||
if put.op.IsDelete() {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(k, prefix) && k >= kv.key {
|
||||
kv.key = k
|
||||
kv.val = put.val
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
return &kv, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Prev returns the prior key/value before key (with prefix). If there's no such
|
||||
// key Next will return nil.
|
||||
func (s *stm) Prev(prefix, startKey string) (*KV, error) {
|
||||
var result KV
|
||||
|
||||
fetchKey := startKey
|
||||
matchFound := false
|
||||
|
||||
for {
|
||||
// Ask etcd to retrieve one key that is a
|
||||
// match in descending order from the passed key.
|
||||
opts := []v3.OpOption{
|
||||
v3.WithRange(fetchKey),
|
||||
v3.WithSort(v3.SortByKey, v3.SortDescend),
|
||||
v3.WithLimit(1),
|
||||
}
|
||||
|
||||
kvs, err := s.fetch(prefix, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(kvs) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
kv := &kvs[0]
|
||||
|
||||
// WithRange and WithPrefix can't be used
|
||||
// together, so check prefix here. If the
|
||||
// returned key no longer has the prefix,
|
||||
// then break out.
|
||||
if !strings.HasPrefix(kv.key, prefix) {
|
||||
break
|
||||
}
|
||||
|
||||
// Fetch the prior key if this is deleted.
|
||||
if put, ok := s.wset[kv.key]; ok && put.op.IsDelete() {
|
||||
fetchKey = kv.key
|
||||
continue
|
||||
}
|
||||
|
||||
result = *kv
|
||||
matchFound = true
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
// Closre holding all checks to find a possibly
|
||||
// better match.
|
||||
matches := func(key string) bool {
|
||||
if !strings.HasPrefix(key, prefix) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !matchFound {
|
||||
return key < startKey
|
||||
}
|
||||
|
||||
// matchFound == true
|
||||
return result.key <= key && key < startKey
|
||||
}
|
||||
|
||||
// Now go trough the write set and check
|
||||
// if there's an even better match.
|
||||
for k, put := range s.wset {
|
||||
if !put.op.IsDelete() && matches(k) {
|
||||
result.key = k
|
||||
result.val = put.val
|
||||
matchFound = true
|
||||
}
|
||||
}
|
||||
|
||||
if !matchFound {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Next returns the next key/value after key (with prefix). If there's no such
|
||||
// key Next will return nil.
|
||||
func (s *stm) Next(prefix string, key string) (*KV, error) {
|
||||
return s.next(prefix, key, false)
|
||||
}
|
||||
|
||||
// Seek "seeks" to the key (with prefix). If the key doesn't exists it'll get
|
||||
// the next key with the same prefix. If no key fills this criteria, Seek will
|
||||
// return nil.
|
||||
func (s *stm) Seek(prefix, key string) (*KV, error) {
|
||||
return s.next(prefix, key, true)
|
||||
}
|
||||
|
||||
// next will try to retrieve the next match that has prefix and starts with the
|
||||
// passed startKey. If includeStartKey is set to true, it'll return the value
|
||||
// of startKey (essentially implementing seek).
|
||||
func (s *stm) next(prefix, startKey string, includeStartKey bool) (*KV, error) {
|
||||
var result KV
|
||||
|
||||
fetchKey := startKey
|
||||
firstFetch := true
|
||||
matchFound := false
|
||||
|
||||
for {
|
||||
// Ask etcd to retrieve one key that is a
|
||||
// match in ascending order from the passed key.
|
||||
opts := []v3.OpOption{
|
||||
v3.WithFromKey(),
|
||||
v3.WithSort(v3.SortByKey, v3.SortAscend),
|
||||
v3.WithLimit(1),
|
||||
}
|
||||
|
||||
// By default we include the start key too
|
||||
// if it is a full match.
|
||||
if includeStartKey && firstFetch {
|
||||
firstFetch = false
|
||||
} else {
|
||||
// If we'd like to retrieve the first key
|
||||
// after the start key.
|
||||
fetchKey += "\x00"
|
||||
}
|
||||
|
||||
kvs, err := s.fetch(fetchKey, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(kvs) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
kv := &kvs[0]
|
||||
// WithRange and WithPrefix can't be used
|
||||
// together, so check prefix here. If the
|
||||
// returned key no longer has the prefix,
|
||||
// then break the fetch loop.
|
||||
if !strings.HasPrefix(kv.key, prefix) {
|
||||
break
|
||||
}
|
||||
|
||||
// Move on to fetch starting with the next
|
||||
// key if this one is marked deleted.
|
||||
if put, ok := s.wset[kv.key]; ok && put.op.IsDelete() {
|
||||
fetchKey = kv.key
|
||||
continue
|
||||
}
|
||||
|
||||
result = *kv
|
||||
matchFound = true
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
// Closure holding all checks to find a possibly
|
||||
// better match.
|
||||
matches := func(k string) bool {
|
||||
if !strings.HasPrefix(k, prefix) {
|
||||
return false
|
||||
}
|
||||
|
||||
if includeStartKey && !matchFound {
|
||||
return startKey <= k
|
||||
}
|
||||
|
||||
if !includeStartKey && !matchFound {
|
||||
return startKey < k
|
||||
}
|
||||
|
||||
if includeStartKey && matchFound {
|
||||
return startKey <= k && k <= result.key
|
||||
}
|
||||
|
||||
// !includeStartKey && matchFound.
|
||||
return startKey < k && k <= result.key
|
||||
}
|
||||
|
||||
// Now go trough the write set and check
|
||||
// if there's an even better match.
|
||||
for k, put := range s.wset {
|
||||
if !put.op.IsDelete() && matches(k) {
|
||||
result.key = k
|
||||
result.val = put.val
|
||||
matchFound = true
|
||||
}
|
||||
}
|
||||
|
||||
if !matchFound {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Put sets the value of the passed key. The actual put will happen upon commit.
|
||||
func (s *stm) Put(key, val string) {
|
||||
s.wset[key] = stmPut{
|
||||
val: val,
|
||||
op: v3.OpPut(key, val),
|
||||
}
|
||||
}
|
||||
|
||||
// Del marks a key as deleted. The actual delete will happen upon commit.
|
||||
func (s *stm) Del(key string) {
|
||||
s.wset[key] = stmPut{
|
||||
val: "",
|
||||
op: v3.OpDelete(key),
|
||||
}
|
||||
}
|
||||
|
||||
// OnCommit sets the callback that is called upon committing the STM
|
||||
// transaction.
|
||||
func (s *stm) OnCommit(cb func()) {
|
||||
s.onCommit = cb
|
||||
}
|
||||
|
||||
// commit builds the final transaction and tries to execute it. If commit fails
|
||||
// because the keys have changed return a CommitError, otherwise return a
|
||||
// DatabaseError.
|
||||
func (s *stm) commit() (CommitStats, error) {
|
||||
rset := s.rset.cmps()
|
||||
wset := s.wset.cmps(s.revision + 1)
|
||||
|
||||
stats := CommitStats{
|
||||
Rset: len(rset),
|
||||
Wset: len(wset),
|
||||
}
|
||||
|
||||
// Create the compare set.
|
||||
cmps := append(rset, wset...)
|
||||
// Create a transaction with the optional abort context.
|
||||
txn := s.client.Txn(s.options.ctx)
|
||||
|
||||
// If the compare set holds, try executing the puts.
|
||||
txn = txn.If(cmps...)
|
||||
txn = txn.Then(s.wset.puts()...)
|
||||
|
||||
// Prefetch keys in case of conflict to save
|
||||
// a round trip to etcd.
|
||||
txn = txn.Else(s.rset.gets()...)
|
||||
|
||||
txnresp, err := txn.Commit()
|
||||
if err != nil {
|
||||
return stats, DatabaseError{
|
||||
msg: "stm.Commit() failed",
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Call the commit callback if the transaction
|
||||
// was successful.
|
||||
if txnresp.Succeeded {
|
||||
if s.onCommit != nil {
|
||||
s.onCommit()
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// Load prefetch before if commit failed.
|
||||
s.rset.add(txnresp)
|
||||
s.prefetch = s.rset
|
||||
|
||||
// Return CommitError indicating that the transaction
|
||||
// can be retried.
|
||||
return stats, CommitError{}
|
||||
}
|
||||
|
||||
// Commit simply calls commit and the commit stats callback if set.
|
||||
func (s *stm) Commit() error {
|
||||
stats, err := s.commit()
|
||||
|
||||
if s.options.commitStatsCallback != nil {
|
||||
s.options.commitStatsCallback(err == nil, stats)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Rollback resets the STM. This is useful for uncommitted transaction rollback
|
||||
// and also used in the STM main loop to reset state if commit fails.
|
||||
func (s *stm) Rollback() {
|
||||
s.rset = make(map[string]stmGet)
|
||||
s.wset = make(map[string]stmPut)
|
||||
s.getOpts = nil
|
||||
s.revision = math.MaxInt64 - 1
|
||||
}
|
380
kvdb/etcd/stm_test.go
Normal file
380
kvdb/etcd/stm_test.go
Normal file
@@ -0,0 +1,380 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func reverseKVs(a []KV) []KV {
|
||||
for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 {
|
||||
a[i], a[j] = a[j], a[i]
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func TestPutToEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
txQueue := NewCommitQueue(ctx)
|
||||
defer func() {
|
||||
cancel()
|
||||
f.Cleanup()
|
||||
txQueue.Wait()
|
||||
}()
|
||||
|
||||
db, err := newEtcdBackend(ctx, f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
apply := func(stm STM) error {
|
||||
stm.Put("123", "abc")
|
||||
return nil
|
||||
}
|
||||
|
||||
err = RunSTM(db.cli, apply, txQueue)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "abc", f.Get("123"))
|
||||
}
|
||||
|
||||
func TestGetPutDel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
txQueue := NewCommitQueue(ctx)
|
||||
defer func() {
|
||||
cancel()
|
||||
f.Cleanup()
|
||||
txQueue.Wait()
|
||||
}()
|
||||
|
||||
testKeyValues := []KV{
|
||||
{"a", "1"},
|
||||
{"b", "2"},
|
||||
{"c", "3"},
|
||||
{"d", "4"},
|
||||
{"e", "5"},
|
||||
}
|
||||
|
||||
for _, kv := range testKeyValues {
|
||||
f.Put(kv.key, kv.val)
|
||||
}
|
||||
|
||||
db, err := newEtcdBackend(ctx, f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
apply := func(stm STM) error {
|
||||
// Get some non existing keys.
|
||||
v, err := stm.Get("")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, v)
|
||||
|
||||
v, err = stm.Get("x")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, v)
|
||||
|
||||
// Get all existing keys.
|
||||
for _, kv := range testKeyValues {
|
||||
v, err = stm.Get(kv.key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte(kv.val), v)
|
||||
}
|
||||
|
||||
// Overwrite, then delete an existing key.
|
||||
stm.Put("c", "6")
|
||||
|
||||
v, err = stm.Get("c")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("6"), v)
|
||||
|
||||
stm.Del("c")
|
||||
|
||||
v, err = stm.Get("c")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, v)
|
||||
|
||||
// Re-add the deleted key.
|
||||
stm.Put("c", "7")
|
||||
|
||||
v, err = stm.Get("c")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("7"), v)
|
||||
|
||||
// Add a new key.
|
||||
stm.Put("x", "x")
|
||||
|
||||
v, err = stm.Get("x")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("x"), v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err = RunSTM(db.cli, apply, txQueue)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "1", f.Get("a"))
|
||||
require.Equal(t, "2", f.Get("b"))
|
||||
require.Equal(t, "7", f.Get("c"))
|
||||
require.Equal(t, "4", f.Get("d"))
|
||||
require.Equal(t, "5", f.Get("e"))
|
||||
require.Equal(t, "x", f.Get("x"))
|
||||
}
|
||||
|
||||
func TestFirstLastNextPrev(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
txQueue := NewCommitQueue(ctx)
|
||||
defer func() {
|
||||
cancel()
|
||||
f.Cleanup()
|
||||
txQueue.Wait()
|
||||
}()
|
||||
|
||||
testKeyValues := []KV{
|
||||
{"kb", "1"},
|
||||
{"kc", "2"},
|
||||
{"kda", "3"},
|
||||
{"ke", "4"},
|
||||
{"w", "w"},
|
||||
}
|
||||
for _, kv := range testKeyValues {
|
||||
f.Put(kv.key, kv.val)
|
||||
}
|
||||
|
||||
db, err := newEtcdBackend(ctx, f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
apply := func(stm STM) error {
|
||||
// First/Last on valid multi item interval.
|
||||
kv, err := stm.First("k")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"kb", "1"}, kv)
|
||||
|
||||
kv, err = stm.Last("k")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"ke", "4"}, kv)
|
||||
|
||||
// First/Last on single item interval.
|
||||
kv, err = stm.First("w")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"w", "w"}, kv)
|
||||
|
||||
kv, err = stm.Last("w")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"w", "w"}, kv)
|
||||
|
||||
// Next/Prev on start/end.
|
||||
kv, err = stm.Next("k", "ke")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, kv)
|
||||
|
||||
kv, err = stm.Prev("k", "kb")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, kv)
|
||||
|
||||
// Next/Prev in the middle.
|
||||
kv, err = stm.Next("k", "kc")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"kda", "3"}, kv)
|
||||
|
||||
kv, err = stm.Prev("k", "ke")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"kda", "3"}, kv)
|
||||
|
||||
// Delete first item, then add an item before the
|
||||
// deleted one. Check that First/Next will "jump"
|
||||
// over the deleted item and return the new first.
|
||||
stm.Del("kb")
|
||||
stm.Put("ka", "0")
|
||||
|
||||
kv, err = stm.First("k")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"ka", "0"}, kv)
|
||||
|
||||
kv, err = stm.Prev("k", "kc")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"ka", "0"}, kv)
|
||||
|
||||
// Similarly test that a new end is returned if
|
||||
// the old end is deleted first.
|
||||
stm.Del("ke")
|
||||
stm.Put("kf", "5")
|
||||
|
||||
kv, err = stm.Last("k")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"kf", "5"}, kv)
|
||||
|
||||
kv, err = stm.Next("k", "kda")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"kf", "5"}, kv)
|
||||
|
||||
// Overwrite one in the middle.
|
||||
stm.Put("kda", "6")
|
||||
|
||||
kv, err = stm.Next("k", "kc")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &KV{"kda", "6"}, kv)
|
||||
|
||||
// Add three in the middle, then delete one.
|
||||
stm.Put("kdb", "7")
|
||||
stm.Put("kdc", "8")
|
||||
stm.Put("kdd", "9")
|
||||
stm.Del("kdc")
|
||||
|
||||
// Check that stepping from first to last returns
|
||||
// the expected sequence.
|
||||
var kvs []KV
|
||||
|
||||
curr, err := stm.First("k")
|
||||
require.NoError(t, err)
|
||||
|
||||
for curr != nil {
|
||||
kvs = append(kvs, *curr)
|
||||
curr, err = stm.Next("k", curr.key)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
expected := []KV{
|
||||
{"ka", "0"},
|
||||
{"kc", "2"},
|
||||
{"kda", "6"},
|
||||
{"kdb", "7"},
|
||||
{"kdd", "9"},
|
||||
{"kf", "5"},
|
||||
}
|
||||
require.Equal(t, expected, kvs)
|
||||
|
||||
// Similarly check that stepping from last to first
|
||||
// returns the expected sequence.
|
||||
kvs = []KV{}
|
||||
|
||||
curr, err = stm.Last("k")
|
||||
require.NoError(t, err)
|
||||
|
||||
for curr != nil {
|
||||
kvs = append(kvs, *curr)
|
||||
curr, err = stm.Prev("k", curr.key)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
expected = reverseKVs(expected)
|
||||
require.Equal(t, expected, kvs)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err = RunSTM(db.cli, apply, txQueue)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "0", f.Get("ka"))
|
||||
require.Equal(t, "2", f.Get("kc"))
|
||||
require.Equal(t, "6", f.Get("kda"))
|
||||
require.Equal(t, "7", f.Get("kdb"))
|
||||
require.Equal(t, "9", f.Get("kdd"))
|
||||
require.Equal(t, "5", f.Get("kf"))
|
||||
require.Equal(t, "w", f.Get("w"))
|
||||
}
|
||||
|
||||
func TestCommitError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
txQueue := NewCommitQueue(ctx)
|
||||
defer func() {
|
||||
cancel()
|
||||
f.Cleanup()
|
||||
txQueue.Wait()
|
||||
}()
|
||||
|
||||
db, err := newEtcdBackend(ctx, f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Preset DB state.
|
||||
f.Put("123", "xyz")
|
||||
|
||||
// Count the number of applies.
|
||||
cnt := 0
|
||||
|
||||
apply := func(stm STM) error {
|
||||
// STM must have the key/value.
|
||||
val, err := stm.Get("123")
|
||||
require.NoError(t, err)
|
||||
|
||||
if cnt == 0 {
|
||||
require.Equal(t, []byte("xyz"), val)
|
||||
|
||||
// Put a conflicting key/value during the first apply.
|
||||
f.Put("123", "def")
|
||||
}
|
||||
|
||||
// We'd expect to
|
||||
stm.Put("123", "abc")
|
||||
|
||||
cnt++
|
||||
return nil
|
||||
}
|
||||
|
||||
err = RunSTM(db.cli, apply, txQueue)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, cnt)
|
||||
|
||||
require.Equal(t, "abc", f.Get("123"))
|
||||
}
|
||||
|
||||
func TestManualTxError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
txQueue := NewCommitQueue(ctx)
|
||||
defer func() {
|
||||
cancel()
|
||||
f.Cleanup()
|
||||
txQueue.Wait()
|
||||
}()
|
||||
|
||||
db, err := newEtcdBackend(ctx, f.BackendConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Preset DB state.
|
||||
f.Put("123", "xyz")
|
||||
|
||||
stm := NewSTM(db.cli, txQueue)
|
||||
|
||||
val, err := stm.Get("123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("xyz"), val)
|
||||
|
||||
// Put a conflicting key/value.
|
||||
f.Put("123", "def")
|
||||
|
||||
// Should still get the original version.
|
||||
val, err = stm.Get("123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("xyz"), val)
|
||||
|
||||
// Commit will fail with CommitError.
|
||||
err = stm.Commit()
|
||||
var e CommitError
|
||||
require.True(t, errors.As(err, &e))
|
||||
|
||||
// We expect that the transacton indeed did not commit.
|
||||
require.Equal(t, "def", f.Get("123"))
|
||||
}
|
19
kvdb/etcd/walletdb_interface_test.go
Normal file
19
kvdb/etcd/walletdb_interface_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// +build kvdb_etcd
|
||||
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb/walletdbtest"
|
||||
)
|
||||
|
||||
// TestWalletDBInterface performs the WalletDB interface test suite for the
|
||||
// etcd database driver.
|
||||
func TestWalletDBInterface(t *testing.T) {
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
cfg := f.BackendConfig()
|
||||
walletdbtest.TestInterface(t, dbType, context.TODO(), &cfg)
|
||||
}
|
Reference in New Issue
Block a user