mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-26 20:26:34 +02:00
channeldb+kvdb: an extended STM on top of etcd clientv3
This commit adds an extended STM, similar to what available in etcd's clientv3 module. This incarnation of said STM supports additional features, like positioning in key intervals while taking into account deletes and writes as well. This is a preliminary work to support all features of the kvdb interface.
This commit is contained in:
55
channeldb/kvdb/etcd/db.go
Normal file
55
channeldb/kvdb/etcd/db.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
)
|
||||
|
||||
const (
|
||||
// etcdConnectionTimeout is the timeout until successful connection to the
|
||||
// etcd instance.
|
||||
etcdConnectionTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// db holds a reference to the etcd client connection.
|
||||
type db struct {
|
||||
cli *clientv3.Client
|
||||
}
|
||||
|
||||
// BackendConfig holds and etcd backend config and connection parameters.
|
||||
type BackendConfig struct {
|
||||
// Host holds the peer url of the etcd instance.
|
||||
Host string
|
||||
|
||||
// User is the username for the etcd peer.
|
||||
User string
|
||||
|
||||
// Pass is the password for the etcd peer.
|
||||
Pass string
|
||||
}
|
||||
|
||||
// newEtcdBackend returns a db object initialized with the passed backend
|
||||
// config. If etcd connection cannot be estabished, then returns error.
|
||||
func newEtcdBackend(config BackendConfig) (*db, error) {
|
||||
cli, err := clientv3.New(clientv3.Config{
|
||||
Endpoints: []string{config.Host},
|
||||
DialTimeout: etcdConnectionTimeout,
|
||||
Username: config.User,
|
||||
Password: config.Pass,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backend := &db{
|
||||
cli: cli,
|
||||
}
|
||||
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// Close closes the db, but closing the underlying etcd client connection.
|
||||
func (db *db) Close() error {
|
||||
return db.cli.Close()
|
||||
}
|
72
channeldb/kvdb/etcd/embed.go
Normal file
72
channeldb/kvdb/etcd/embed.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/etcd/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
// readyTimeout is the time until the embedded etcd instance should start.
|
||||
readyTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// getFreePort returns a random open TCP port.
|
||||
func getFreePort() int {
|
||||
ln, err := net.Listen("tcp", "[::]:0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
port := ln.Addr().(*net.TCPAddr).Port
|
||||
|
||||
err = ln.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return port
|
||||
}
|
||||
|
||||
// NewEmbeddedEtcdInstance creates an embedded etcd instance for testing,
|
||||
// listening on random open ports. Returns the backend config and a cleanup
|
||||
// func that will stop the etcd instance.
|
||||
func NewEmbeddedEtcdInstance(path string) (*BackendConfig, func(), error) {
|
||||
cfg := embed.NewConfig()
|
||||
cfg.Dir = path
|
||||
|
||||
// To ensure that we can submit large transactions.
|
||||
cfg.MaxTxnOps = 1000
|
||||
|
||||
// Listen on random free ports.
|
||||
clientURL := fmt.Sprintf("127.0.0.1:%d", getFreePort())
|
||||
peerURL := fmt.Sprintf("127.0.0.1:%d", getFreePort())
|
||||
cfg.LCUrls = []url.URL{{Host: clientURL}}
|
||||
cfg.LPUrls = []url.URL{{Host: peerURL}}
|
||||
|
||||
etcd, err := embed.StartEtcd(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-etcd.Server.ReadyNotify():
|
||||
case <-time.After(readyTimeout):
|
||||
etcd.Close()
|
||||
return nil, nil,
|
||||
fmt.Errorf("etcd failed to start after: %v", readyTimeout)
|
||||
}
|
||||
|
||||
connConfig := &BackendConfig{
|
||||
Host: "http://" + peerURL,
|
||||
User: "user",
|
||||
Pass: "pass",
|
||||
}
|
||||
|
||||
return connConfig, func() {
|
||||
etcd.Close()
|
||||
}, nil
|
||||
}
|
127
channeldb/kvdb/etcd/fixture_test.go
Normal file
127
channeldb/kvdb/etcd/fixture_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
)
|
||||
|
||||
const (
|
||||
// testEtcdTimeout is used for all RPC calls initiated by the test fixture.
|
||||
testEtcdTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// EtcdTestFixture holds internal state of the etcd test fixture.
|
||||
type EtcdTestFixture struct {
|
||||
t *testing.T
|
||||
cli *clientv3.Client
|
||||
config *BackendConfig
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
// NewTestEtcdInstance creates an embedded etcd instance for testing, listening
|
||||
// on random open ports. Returns the connection config and a cleanup func that
|
||||
// will stop the etcd instance.
|
||||
func NewTestEtcdInstance(t *testing.T, path string) (*BackendConfig, func()) {
|
||||
t.Helper()
|
||||
|
||||
config, cleanup, err := NewEmbeddedEtcdInstance(path)
|
||||
if err != nil {
|
||||
t.Fatalf("error while staring embedded etcd instance: %v", err)
|
||||
}
|
||||
|
||||
return config, cleanup
|
||||
}
|
||||
|
||||
// NewTestEtcdTestFixture creates a new etcd-test fixture. This is helper
|
||||
// object to facilitate etcd tests and ensure pre and post conditions.
|
||||
func NewEtcdTestFixture(t *testing.T) *EtcdTestFixture {
|
||||
tmpDir, err := ioutil.TempDir("", "etcd")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
config, etcdCleanup := NewTestEtcdInstance(t, tmpDir)
|
||||
|
||||
cli, err := clientv3.New(clientv3.Config{
|
||||
Endpoints: []string{config.Host},
|
||||
Username: config.User,
|
||||
Password: config.Pass,
|
||||
})
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
t.Fatalf("unable to create etcd test fixture: %v", err)
|
||||
}
|
||||
|
||||
return &EtcdTestFixture{
|
||||
t: t,
|
||||
cli: cli,
|
||||
config: config,
|
||||
cleanup: func() {
|
||||
etcdCleanup()
|
||||
os.RemoveAll(tmpDir)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Put puts a string key/value into the test etcd database.
|
||||
func (f *EtcdTestFixture) Put(key, value string) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testEtcdTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := f.cli.Put(ctx, key, value)
|
||||
if err != nil {
|
||||
f.t.Fatalf("etcd test fixture failed to put: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get queries a key and returns the stored value from the test etcd database.
|
||||
func (f *EtcdTestFixture) Get(key string) string {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testEtcdTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := f.cli.Get(ctx, key)
|
||||
if err != nil {
|
||||
f.t.Fatalf("etcd test fixture failed to put: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Kvs) > 0 {
|
||||
return string(resp.Kvs[0].Value)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Dump scans and returns all key/values from the test etcd database.
|
||||
func (f *EtcdTestFixture) Dump() map[string]string {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testEtcdTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := f.cli.Get(ctx, "", clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
f.t.Fatalf("etcd test fixture failed to put: %v", err)
|
||||
}
|
||||
|
||||
result := make(map[string]string)
|
||||
for _, kv := range resp.Kvs {
|
||||
result[string(kv.Key)] = string(kv.Value)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// BackendConfig returns the backend config for connecting to theembedded
|
||||
// etcd instance.
|
||||
func (f *EtcdTestFixture) BackendConfig() BackendConfig {
|
||||
return *f.config
|
||||
}
|
||||
|
||||
// Cleanup should be called at test fixture teardown to stop the embedded
|
||||
// etcd instance and remove all temp db files form the filesystem.
|
||||
func (f *EtcdTestFixture) Cleanup() {
|
||||
f.cleanup()
|
||||
}
|
728
channeldb/kvdb/etcd/stm.go
Normal file
728
channeldb/kvdb/etcd/stm.go
Normal file
@@ -0,0 +1,728 @@
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
v3 "github.com/coreos/etcd/clientv3"
|
||||
)
|
||||
|
||||
// KV stores a key/value pair.
|
||||
type KV struct {
|
||||
key string
|
||||
val string
|
||||
}
|
||||
|
||||
// STM is an interface for software transactional memory.
|
||||
// All calls that return error will do so only if STM is manually handled and
|
||||
// abort the apply closure otherwise. In both case the returned error is a
|
||||
// DatabaseError.
|
||||
type STM interface {
|
||||
// Get returns the value for a key and inserts the key in the txn's read
|
||||
// set. Returns nil if there's no matching key, or the key is empty.
|
||||
Get(key string) ([]byte, error)
|
||||
|
||||
// Put adds a value for a key to the txn's write set.
|
||||
Put(key, val string)
|
||||
|
||||
// Del adds a delete operation for the key to the txn's write set.
|
||||
Del(key string)
|
||||
|
||||
// First returns the first k/v that begins with prefix or nil if there's
|
||||
// no such k/v pair. If the key is found it is inserted to the txn's
|
||||
// read set. Returns nil if there's no match.
|
||||
First(prefix string) (*KV, error)
|
||||
|
||||
// Last returns the last k/v that begins with prefix or nil if there's
|
||||
// no such k/v pair. If the key is found it is inserted to the txn's
|
||||
// read set. Returns nil if there's no match.
|
||||
Last(prefix string) (*KV, error)
|
||||
|
||||
// Prev returns the previous k/v before key that begins with prefix or
|
||||
// nil if there's no such k/v. If the key is found it is inserted to the
|
||||
// read set. Returns nil if there's no match.
|
||||
Prev(prefix, key string) (*KV, error)
|
||||
|
||||
// Next returns the next k/v after key that begins with prefix or nil
|
||||
// if there's no such k/v. If the key is found it is inserted to the
|
||||
// txn's read set. Returns nil if there's no match.
|
||||
Next(prefix, key string) (*KV, error)
|
||||
|
||||
// Seek will return k/v at key beginning with prefix. If the key doesn't
|
||||
// exists Seek will return the next k/v after key beginning with prefix.
|
||||
// If a matching k/v is found it is inserted to the txn's read set. Returns
|
||||
// nil if there's no match.
|
||||
Seek(prefix, key string) (*KV, error)
|
||||
|
||||
// OnCommit calls the passed callback func upon commit.
|
||||
OnCommit(func())
|
||||
|
||||
// Commit attempts to apply the txn's changes to the server.
|
||||
// Commit may return CommitError if transaction is outdated and needs retry.
|
||||
Commit() error
|
||||
|
||||
// Rollback emties the read and write sets such that a subsequent commit
|
||||
// won't alter the database.
|
||||
Rollback()
|
||||
}
|
||||
|
||||
// CommitError is used to check if there was an error
|
||||
// due to stale data in the transaction.
|
||||
type CommitError struct{}
|
||||
|
||||
// Error returns a static string for CommitError for
|
||||
// debugging/logging purposes.
|
||||
func (e CommitError) Error() string {
|
||||
return "commit failed"
|
||||
}
|
||||
|
||||
// DatabaseError is used to wrap errors that are not
|
||||
// related to stale data in the transaction.
|
||||
type DatabaseError struct {
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
// Unwrap returns the wrapped error in a DatabaseError.
|
||||
func (e *DatabaseError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// Error simply converts DatabaseError to a string that
|
||||
// includes both the message and the wrapped error.
|
||||
func (e DatabaseError) Error() string {
|
||||
return fmt.Sprintf("etcd error: %v - %v", e.msg, e.err)
|
||||
}
|
||||
|
||||
// stmGet is the result of a read operation,
|
||||
// a value and the mod revision of the key/value.
|
||||
type stmGet struct {
|
||||
val string
|
||||
rev int64
|
||||
}
|
||||
|
||||
// readSet stores all reads done in an STM.
|
||||
type readSet map[string]stmGet
|
||||
|
||||
// stmPut stores a value and an operation (put/delete).
|
||||
type stmPut struct {
|
||||
val string
|
||||
op v3.Op
|
||||
}
|
||||
|
||||
// writeSet stroes all writes done in an STM.
|
||||
type writeSet map[string]stmPut
|
||||
|
||||
// stm implements repeatable-read software transactional memory
|
||||
// over etcd.
|
||||
type stm struct {
|
||||
// client is an etcd client handling all RPC communications
|
||||
// to the etcd instance/cluster.
|
||||
client *v3.Client
|
||||
|
||||
// manual is set to true for manual transactions which don't
|
||||
// execute in the STM run loop.
|
||||
manual bool
|
||||
|
||||
// options stores optional settings passed by the user.
|
||||
options *STMOptions
|
||||
|
||||
// prefetch hold prefetched key values and revisions.
|
||||
prefetch readSet
|
||||
|
||||
// rset holds read key values and revisions.
|
||||
rset readSet
|
||||
|
||||
// wset holds overwritten keys and their values.
|
||||
wset writeSet
|
||||
|
||||
// getOpts are the opts used for gets.
|
||||
getOpts []v3.OpOption
|
||||
|
||||
// revision stores the snapshot revision after first read.
|
||||
revision int64
|
||||
|
||||
// onCommit gets called upon commit.
|
||||
onCommit func()
|
||||
}
|
||||
|
||||
// STMOptions can be used to pass optional settings
|
||||
// when an STM is created.
|
||||
type STMOptions struct {
|
||||
// ctx holds an externally provided abort context.
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// STMOptionFunc is a function that updates the passed STMOptions.
|
||||
type STMOptionFunc func(*STMOptions)
|
||||
|
||||
// WithAbortContext specifies the context for permanently
|
||||
// aborting the transaction.
|
||||
func WithAbortContext(ctx context.Context) STMOptionFunc {
|
||||
return func(so *STMOptions) {
|
||||
so.ctx = ctx
|
||||
}
|
||||
}
|
||||
|
||||
// RunSTM runs the apply function by creating an STM using serializable snapshot
|
||||
// isolation, passing it to the apply and handling commit errors and retries.
|
||||
func RunSTM(cli *v3.Client, apply func(STM) error, so ...STMOptionFunc) error {
|
||||
return runSTM(makeSTM(cli, false, so...), apply)
|
||||
}
|
||||
|
||||
// NewSTM creates a new STM instance, using serializable snapshot isolation.
|
||||
func NewSTM(cli *v3.Client, so ...STMOptionFunc) STM {
|
||||
return makeSTM(cli, true, so...)
|
||||
}
|
||||
|
||||
// makeSTM is the actual constructor of the stm. It first apply all passed
|
||||
// options then creates the stm object and resets it before returning.
|
||||
func makeSTM(cli *v3.Client, manual bool, so ...STMOptionFunc) *stm {
|
||||
opts := &STMOptions{
|
||||
ctx: cli.Ctx(),
|
||||
}
|
||||
|
||||
// Apply all functional options.
|
||||
for _, fo := range so {
|
||||
fo(opts)
|
||||
}
|
||||
|
||||
s := &stm{
|
||||
client: cli,
|
||||
manual: manual,
|
||||
options: opts,
|
||||
prefetch: make(map[string]stmGet),
|
||||
}
|
||||
|
||||
// Reset read and write set.
|
||||
s.Rollback()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// runSTM implements the run loop of the STM, running the apply func, catching
|
||||
// errors and handling commit. The loop will quit on every error except
|
||||
// CommitError which is used to indicate a necessary retry.
|
||||
func runSTM(s *stm, apply func(STM) error) error {
|
||||
out := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
// Recover DatabaseError panics so
|
||||
// we can return them.
|
||||
if r := recover(); r != nil {
|
||||
e, ok := r.(DatabaseError)
|
||||
if !ok {
|
||||
// Unknown panic.
|
||||
panic(r)
|
||||
}
|
||||
|
||||
// Return the error.
|
||||
out <- e.Unwrap()
|
||||
}
|
||||
}()
|
||||
|
||||
var err error
|
||||
|
||||
// In a loop try to apply and commit and roll back
|
||||
// if the database has changed (CommitError).
|
||||
for {
|
||||
// Abort STM if there was an application error.
|
||||
if err = apply(s); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
err = s.Commit()
|
||||
|
||||
// Re-apply only upon commit error
|
||||
// (meaning the database was changed).
|
||||
if _, ok := err.(CommitError); !ok {
|
||||
// Anything that's not a CommitError
|
||||
// aborts the STM run loop.
|
||||
break
|
||||
}
|
||||
|
||||
// Rollback before trying to re-apply.
|
||||
s.Rollback()
|
||||
}
|
||||
|
||||
// Return the error to the caller.
|
||||
out <- err
|
||||
}()
|
||||
|
||||
return <-out
|
||||
}
|
||||
|
||||
// add inserts a txn response to the read set. This is useful when the txn
|
||||
// fails due to conflict where the txn response can be used to prefetch
|
||||
// key/values.
|
||||
func (rs readSet) add(txnResp *v3.TxnResponse) {
|
||||
for _, resp := range txnResp.Responses {
|
||||
getResp := (*v3.GetResponse)(resp.GetResponseRange())
|
||||
for _, kv := range getResp.Kvs {
|
||||
rs[string(kv.Key)] = stmGet{
|
||||
val: string(kv.Value),
|
||||
rev: kv.ModRevision,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gets is a helper to create an op slice for transaction
|
||||
// construction.
|
||||
func (rs readSet) gets() []v3.Op {
|
||||
ops := make([]v3.Op, 0, len(rs))
|
||||
|
||||
for k := range rs {
|
||||
ops = append(ops, v3.OpGet(k))
|
||||
}
|
||||
|
||||
return ops
|
||||
}
|
||||
|
||||
// cmps returns a cmp list testing values in read set didn't change.
|
||||
func (rs readSet) cmps() []v3.Cmp {
|
||||
cmps := make([]v3.Cmp, 0, len(rs))
|
||||
for key, getValue := range rs {
|
||||
cmps = append(cmps, v3.Compare(v3.ModRevision(key), "=", getValue.rev))
|
||||
}
|
||||
|
||||
return cmps
|
||||
}
|
||||
|
||||
// cmps returns a cmp list testing no writes have happened past rev.
|
||||
func (ws writeSet) cmps(rev int64) []v3.Cmp {
|
||||
cmps := make([]v3.Cmp, 0, len(ws))
|
||||
for key := range ws {
|
||||
cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
|
||||
}
|
||||
|
||||
return cmps
|
||||
}
|
||||
|
||||
// puts is the list of ops for all pending writes.
|
||||
func (ws writeSet) puts() []v3.Op {
|
||||
puts := make([]v3.Op, 0, len(ws))
|
||||
for _, v := range ws {
|
||||
puts = append(puts, v.op)
|
||||
}
|
||||
|
||||
return puts
|
||||
}
|
||||
|
||||
// fetch is a helper to fetch key/value given options. If a value is returned
|
||||
// then fetch will try to fix the STM's snapshot revision (if not already set).
|
||||
// We'll also cache the returned key/value in the read set.
|
||||
func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) {
|
||||
resp, err := s.client.Get(
|
||||
s.options.ctx, key, append(opts, s.getOpts...)...,
|
||||
)
|
||||
if err != nil {
|
||||
dbErr := DatabaseError{
|
||||
msg: "stm.fetch() failed",
|
||||
err: err,
|
||||
}
|
||||
|
||||
// Do not panic when executing a manual transaction.
|
||||
if s.manual {
|
||||
return nil, dbErr
|
||||
}
|
||||
|
||||
// Panic when executing inside the STM runloop.
|
||||
panic(dbErr)
|
||||
}
|
||||
|
||||
// Set revison and serializable options upon first fetch
|
||||
// for any subsequent fetches.
|
||||
if s.getOpts == nil {
|
||||
s.revision = resp.Header.Revision
|
||||
s.getOpts = []v3.OpOption{
|
||||
v3.WithRev(s.revision),
|
||||
v3.WithSerializable(),
|
||||
}
|
||||
}
|
||||
|
||||
var result []KV
|
||||
|
||||
// Fill the read set with key/values returned.
|
||||
for _, kv := range resp.Kvs {
|
||||
// Remove from prefetch.
|
||||
key := string(kv.Key)
|
||||
val := string(kv.Value)
|
||||
|
||||
delete(s.prefetch, key)
|
||||
|
||||
// Add to read set.
|
||||
s.rset[key] = stmGet{
|
||||
val: val,
|
||||
rev: kv.ModRevision,
|
||||
}
|
||||
|
||||
result = append(result, KV{key, val})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Get returns the value for key. If there's no such
|
||||
// key/value in the database or the passed key is empty
|
||||
// Get will return nil.
|
||||
func (s *stm) Get(key string) ([]byte, error) {
|
||||
if key == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Return freshly written value if present.
|
||||
if put, ok := s.wset[key]; ok {
|
||||
if put.op.IsDelete() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []byte(put.val), nil
|
||||
}
|
||||
|
||||
// Populate read set if key is present in
|
||||
// the prefetch set.
|
||||
if getValue, ok := s.prefetch[key]; ok {
|
||||
delete(s.prefetch, key)
|
||||
s.rset[key] = getValue
|
||||
}
|
||||
|
||||
// Return value if alread in read set.
|
||||
if getVal, ok := s.rset[key]; ok {
|
||||
return []byte(getVal.val), nil
|
||||
}
|
||||
|
||||
// Fetch and return value.
|
||||
kvs, err := s.fetch(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(kvs) > 0 {
|
||||
return []byte(kvs[0].val), nil
|
||||
}
|
||||
|
||||
// Return empty result if key not in DB.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// First returns the first key/value matching prefix. If there's no key starting
|
||||
// with prefix, Last will return nil.
|
||||
func (s *stm) First(prefix string) (*KV, error) {
|
||||
return s.next(prefix, prefix, true)
|
||||
}
|
||||
|
||||
// Last returns the last key/value with prefix. If there's no key starting with
|
||||
// prefix, Last will return nil.
|
||||
func (s *stm) Last(prefix string) (*KV, error) {
|
||||
// As we don't know the full range, fetch the last
|
||||
// key/value with this prefix first.
|
||||
resp, err := s.fetch(prefix, v3.WithLastKey()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
kv KV
|
||||
found bool
|
||||
)
|
||||
|
||||
if len(resp) > 0 {
|
||||
kv = resp[0]
|
||||
found = true
|
||||
}
|
||||
|
||||
// Now make sure there's nothing in the write set
|
||||
// that is a better match, meaning it has the same
|
||||
// prefix but is greater or equal than the current
|
||||
// best candidate. Note that this is not efficient
|
||||
// when the write set is large!
|
||||
for k, put := range s.wset {
|
||||
if put.op.IsDelete() {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(k, prefix) && k >= kv.key {
|
||||
kv.key = k
|
||||
kv.val = put.val
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
return &kv, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Prev returns the prior key/value before key (with prefix). If there's no such
|
||||
// key Next will return nil.
|
||||
func (s *stm) Prev(prefix, startKey string) (*KV, error) {
|
||||
var result KV
|
||||
|
||||
fetchKey := startKey
|
||||
matchFound := false
|
||||
|
||||
for {
|
||||
// Ask etcd to retrieve one key that is a
|
||||
// match in descending order from the passed key.
|
||||
opts := []v3.OpOption{
|
||||
v3.WithRange(fetchKey),
|
||||
v3.WithSort(v3.SortByKey, v3.SortDescend),
|
||||
v3.WithLimit(1),
|
||||
}
|
||||
|
||||
kvs, err := s.fetch(prefix, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(kvs) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
kv := &kvs[0]
|
||||
|
||||
// WithRange and WithPrefix can't be used
|
||||
// together, so check prefix here. If the
|
||||
// returned key no longer has the prefix,
|
||||
// then break out.
|
||||
if !strings.HasPrefix(kv.key, prefix) {
|
||||
break
|
||||
}
|
||||
|
||||
// Fetch the prior key if this is deleted.
|
||||
if put, ok := s.wset[kv.key]; ok && put.op.IsDelete() {
|
||||
fetchKey = kv.key
|
||||
continue
|
||||
}
|
||||
|
||||
result = *kv
|
||||
matchFound = true
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
// Closre holding all checks to find a possibly
|
||||
// better match.
|
||||
matches := func(key string) bool {
|
||||
if !strings.HasPrefix(key, prefix) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !matchFound {
|
||||
return key < startKey
|
||||
}
|
||||
|
||||
// matchFound == true
|
||||
return result.key <= key && key < startKey
|
||||
}
|
||||
|
||||
// Now go trough the write set and check
|
||||
// if there's an even better match.
|
||||
for k, put := range s.wset {
|
||||
if !put.op.IsDelete() && matches(k) {
|
||||
result.key = k
|
||||
result.val = put.val
|
||||
matchFound = true
|
||||
}
|
||||
}
|
||||
|
||||
if !matchFound {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Next returns the next key/value after key (with prefix). If there's no such
|
||||
// key Next will return nil.
|
||||
func (s *stm) Next(prefix string, key string) (*KV, error) {
|
||||
return s.next(prefix, key, false)
|
||||
}
|
||||
|
||||
// Seek "seeks" to the key (with prefix). If the key doesn't exists it'll get
|
||||
// the next key with the same prefix. If no key fills this criteria, Seek will
|
||||
// return nil.
|
||||
func (s *stm) Seek(prefix, key string) (*KV, error) {
|
||||
return s.next(prefix, key, true)
|
||||
}
|
||||
|
||||
// next will try to retrieve the next match that has prefix and starts with the
|
||||
// passed startKey. If includeStartKey is set to true, it'll return the value
|
||||
// of startKey (essentially implementing seek).
|
||||
func (s *stm) next(prefix, startKey string, includeStartKey bool) (*KV, error) {
|
||||
var result KV
|
||||
|
||||
fetchKey := startKey
|
||||
firstFetch := true
|
||||
matchFound := false
|
||||
|
||||
for {
|
||||
// Ask etcd to retrieve one key that is a
|
||||
// match in ascending order from the passed key.
|
||||
opts := []v3.OpOption{
|
||||
v3.WithFromKey(),
|
||||
v3.WithSort(v3.SortByKey, v3.SortAscend),
|
||||
v3.WithLimit(1),
|
||||
}
|
||||
|
||||
// By default we include the start key too
|
||||
// if it is a full match.
|
||||
if includeStartKey && firstFetch {
|
||||
firstFetch = false
|
||||
} else {
|
||||
// If we'd like to retrieve the first key
|
||||
// after the start key.
|
||||
fetchKey += "\x00"
|
||||
}
|
||||
|
||||
kvs, err := s.fetch(fetchKey, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(kvs) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
kv := &kvs[0]
|
||||
// WithRange and WithPrefix can't be used
|
||||
// together, so check prefix here. If the
|
||||
// returned key no longer has the prefix,
|
||||
// then break the fetch loop.
|
||||
if !strings.HasPrefix(kv.key, prefix) {
|
||||
break
|
||||
}
|
||||
|
||||
// Move on to fetch starting with the next
|
||||
// key if this one is marked deleted.
|
||||
if put, ok := s.wset[kv.key]; ok && put.op.IsDelete() {
|
||||
fetchKey = kv.key
|
||||
continue
|
||||
}
|
||||
|
||||
result = *kv
|
||||
matchFound = true
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
// Closure holding all checks to find a possibly
|
||||
// better match.
|
||||
matches := func(k string) bool {
|
||||
if !strings.HasPrefix(k, prefix) {
|
||||
return false
|
||||
}
|
||||
|
||||
if includeStartKey && !matchFound {
|
||||
return startKey <= k
|
||||
}
|
||||
|
||||
if !includeStartKey && !matchFound {
|
||||
return startKey < k
|
||||
}
|
||||
|
||||
if includeStartKey && matchFound {
|
||||
return startKey <= k && k <= result.key
|
||||
}
|
||||
|
||||
// !includeStartKey && matchFound.
|
||||
return startKey < k && k <= result.key
|
||||
}
|
||||
|
||||
// Now go trough the write set and check
|
||||
// if there's an even better match.
|
||||
for k, put := range s.wset {
|
||||
if !put.op.IsDelete() && matches(k) {
|
||||
result.key = k
|
||||
result.val = put.val
|
||||
matchFound = true
|
||||
}
|
||||
}
|
||||
|
||||
if !matchFound {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Put sets the value of the passed key. The actual put will happen upon commit.
|
||||
func (s *stm) Put(key, val string) {
|
||||
s.wset[key] = stmPut{
|
||||
val: val,
|
||||
op: v3.OpPut(key, val),
|
||||
}
|
||||
}
|
||||
|
||||
// Del marks a key as deleted. The actual delete will happen upon commit.
|
||||
func (s *stm) Del(key string) {
|
||||
s.wset[key] = stmPut{
|
||||
val: "",
|
||||
op: v3.OpDelete(key),
|
||||
}
|
||||
}
|
||||
|
||||
// OnCommit sets the callback that is called upon committing the STM
|
||||
// transaction.
|
||||
func (s *stm) OnCommit(cb func()) {
|
||||
s.onCommit = cb
|
||||
}
|
||||
|
||||
// Commit builds the final transaction and tries to execute it. If commit fails
|
||||
// because the keys have changed return a CommitError, otherwise return a
|
||||
// DatabaseError.
|
||||
func (s *stm) Commit() error {
|
||||
// Create the compare set.
|
||||
cmps := append(s.rset.cmps(), s.wset.cmps(s.revision+1)...)
|
||||
// Create a transaction with the optional abort context.
|
||||
txn := s.client.Txn(s.options.ctx)
|
||||
|
||||
// If the compare set holds, try executing the puts.
|
||||
txn = txn.If(cmps...)
|
||||
txn = txn.Then(s.wset.puts()...)
|
||||
|
||||
// Prefetch keys in case of conflict to save
|
||||
// a round trip to etcd.
|
||||
txn = txn.Else(s.rset.gets()...)
|
||||
|
||||
txnresp, err := txn.Commit()
|
||||
if err != nil {
|
||||
return DatabaseError{
|
||||
msg: "stm.Commit() failed",
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Call the commit callback if the transaction
|
||||
// was successful.
|
||||
if txnresp.Succeeded {
|
||||
if s.onCommit != nil {
|
||||
s.onCommit()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load prefetch before if commit failed.
|
||||
s.rset.add(txnresp)
|
||||
s.prefetch = s.rset
|
||||
|
||||
// Return CommitError indicating that the transaction
|
||||
// can be retried.
|
||||
return CommitError{}
|
||||
}
|
||||
|
||||
// Rollback resets the STM. This is useful for uncommitted transaction rollback
|
||||
// and also used in the STM main loop to reset state if commit fails.
|
||||
func (s *stm) Rollback() {
|
||||
s.rset = make(map[string]stmGet)
|
||||
s.wset = make(map[string]stmPut)
|
||||
s.getOpts = nil
|
||||
s.revision = math.MaxInt64 - 1
|
||||
}
|
342
channeldb/kvdb/etcd/stm_test.go
Normal file
342
channeldb/kvdb/etcd/stm_test.go
Normal file
@@ -0,0 +1,342 @@
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func reverseKVs(a []KV) []KV {
|
||||
for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 {
|
||||
a[i], a[j] = a[j], a[i]
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func TestPutToEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(f.BackendConfig())
|
||||
assert.NoError(t, err)
|
||||
|
||||
apply := func(stm STM) error {
|
||||
stm.Put("123", "abc")
|
||||
return nil
|
||||
}
|
||||
|
||||
err = RunSTM(db.cli, apply)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "abc", f.Get("123"))
|
||||
}
|
||||
|
||||
func TestGetPutDel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.cleanup()
|
||||
|
||||
testKeyValues := []KV{
|
||||
{"a", "1"},
|
||||
{"b", "2"},
|
||||
{"c", "3"},
|
||||
{"d", "4"},
|
||||
{"e", "5"},
|
||||
}
|
||||
|
||||
for _, kv := range testKeyValues {
|
||||
f.Put(kv.key, kv.val)
|
||||
}
|
||||
|
||||
db, err := newEtcdBackend(f.BackendConfig())
|
||||
assert.NoError(t, err)
|
||||
|
||||
apply := func(stm STM) error {
|
||||
// Get some non existing keys.
|
||||
v, err := stm.Get("")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, v)
|
||||
|
||||
v, err = stm.Get("x")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, v)
|
||||
|
||||
// Get all existing keys.
|
||||
for _, kv := range testKeyValues {
|
||||
v, err = stm.Get(kv.key)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte(kv.val), v)
|
||||
}
|
||||
|
||||
// Overwrite, then delete an existing key.
|
||||
stm.Put("c", "6")
|
||||
|
||||
v, err = stm.Get("c")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("6"), v)
|
||||
|
||||
stm.Del("c")
|
||||
|
||||
v, err = stm.Get("c")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, v)
|
||||
|
||||
// Re-add the deleted key.
|
||||
stm.Put("c", "7")
|
||||
|
||||
v, err = stm.Get("c")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("7"), v)
|
||||
|
||||
// Add a new key.
|
||||
stm.Put("x", "x")
|
||||
|
||||
v, err = stm.Get("x")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("x"), v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err = RunSTM(db.cli, apply)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "1", f.Get("a"))
|
||||
assert.Equal(t, "2", f.Get("b"))
|
||||
assert.Equal(t, "7", f.Get("c"))
|
||||
assert.Equal(t, "4", f.Get("d"))
|
||||
assert.Equal(t, "5", f.Get("e"))
|
||||
assert.Equal(t, "x", f.Get("x"))
|
||||
}
|
||||
|
||||
func TestFirstLastNextPrev(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
testKeyValues := []KV{
|
||||
{"kb", "1"},
|
||||
{"kc", "2"},
|
||||
{"kda", "3"},
|
||||
{"ke", "4"},
|
||||
{"w", "w"},
|
||||
}
|
||||
for _, kv := range testKeyValues {
|
||||
f.Put(kv.key, kv.val)
|
||||
}
|
||||
|
||||
db, err := newEtcdBackend(f.BackendConfig())
|
||||
assert.NoError(t, err)
|
||||
|
||||
apply := func(stm STM) error {
|
||||
// First/Last on valid multi item interval.
|
||||
kv, err := stm.First("k")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"kb", "1"}, kv)
|
||||
|
||||
kv, err = stm.Last("k")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"ke", "4"}, kv)
|
||||
|
||||
// First/Last on single item interval.
|
||||
kv, err = stm.First("w")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"w", "w"}, kv)
|
||||
|
||||
kv, err = stm.Last("w")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"w", "w"}, kv)
|
||||
|
||||
// Next/Prev on start/end.
|
||||
kv, err = stm.Next("k", "ke")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, kv)
|
||||
|
||||
kv, err = stm.Prev("k", "kb")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, kv)
|
||||
|
||||
// Next/Prev in the middle.
|
||||
kv, err = stm.Next("k", "kc")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"kda", "3"}, kv)
|
||||
|
||||
kv, err = stm.Prev("k", "ke")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"kda", "3"}, kv)
|
||||
|
||||
// Delete first item, then add an item before the
|
||||
// deleted one. Check that First/Next will "jump"
|
||||
// over the deleted item and return the new first.
|
||||
stm.Del("kb")
|
||||
stm.Put("ka", "0")
|
||||
|
||||
kv, err = stm.First("k")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"ka", "0"}, kv)
|
||||
|
||||
kv, err = stm.Prev("k", "kc")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"ka", "0"}, kv)
|
||||
|
||||
// Similarly test that a new end is returned if
|
||||
// the old end is deleted first.
|
||||
stm.Del("ke")
|
||||
stm.Put("kf", "5")
|
||||
|
||||
kv, err = stm.Last("k")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"kf", "5"}, kv)
|
||||
|
||||
kv, err = stm.Next("k", "kda")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"kf", "5"}, kv)
|
||||
|
||||
// Overwrite one in the middle.
|
||||
stm.Put("kda", "6")
|
||||
|
||||
kv, err = stm.Next("k", "kc")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &KV{"kda", "6"}, kv)
|
||||
|
||||
// Add three in the middle, then delete one.
|
||||
stm.Put("kdb", "7")
|
||||
stm.Put("kdc", "8")
|
||||
stm.Put("kdd", "9")
|
||||
stm.Del("kdc")
|
||||
|
||||
// Check that stepping from first to last returns
|
||||
// the expected sequence.
|
||||
var kvs []KV
|
||||
|
||||
curr, err := stm.First("k")
|
||||
assert.NoError(t, err)
|
||||
|
||||
for curr != nil {
|
||||
kvs = append(kvs, *curr)
|
||||
curr, err = stm.Next("k", curr.key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
expected := []KV{
|
||||
{"ka", "0"},
|
||||
{"kc", "2"},
|
||||
{"kda", "6"},
|
||||
{"kdb", "7"},
|
||||
{"kdd", "9"},
|
||||
{"kf", "5"},
|
||||
}
|
||||
assert.Equal(t, expected, kvs)
|
||||
|
||||
// Similarly check that stepping from last to first
|
||||
// returns the expected sequence.
|
||||
kvs = []KV{}
|
||||
|
||||
curr, err = stm.Last("k")
|
||||
assert.NoError(t, err)
|
||||
|
||||
for curr != nil {
|
||||
kvs = append(kvs, *curr)
|
||||
curr, err = stm.Prev("k", curr.key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
expected = reverseKVs(expected)
|
||||
assert.Equal(t, expected, kvs)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err = RunSTM(db.cli, apply)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "0", f.Get("ka"))
|
||||
assert.Equal(t, "2", f.Get("kc"))
|
||||
assert.Equal(t, "6", f.Get("kda"))
|
||||
assert.Equal(t, "7", f.Get("kdb"))
|
||||
assert.Equal(t, "9", f.Get("kdd"))
|
||||
assert.Equal(t, "5", f.Get("kf"))
|
||||
assert.Equal(t, "w", f.Get("w"))
|
||||
}
|
||||
|
||||
func TestCommitError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(f.BackendConfig())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Preset DB state.
|
||||
f.Put("123", "xyz")
|
||||
|
||||
// Count the number of applies.
|
||||
cnt := 0
|
||||
|
||||
apply := func(stm STM) error {
|
||||
// STM must have the key/value.
|
||||
val, err := stm.Get("123")
|
||||
assert.NoError(t, err)
|
||||
|
||||
if cnt == 0 {
|
||||
assert.Equal(t, []byte("xyz"), val)
|
||||
|
||||
// Put a conflicting key/value during the first apply.
|
||||
f.Put("123", "def")
|
||||
}
|
||||
|
||||
// We'd expect to
|
||||
stm.Put("123", "abc")
|
||||
|
||||
cnt++
|
||||
return nil
|
||||
}
|
||||
|
||||
err = RunSTM(db.cli, apply)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, cnt)
|
||||
|
||||
assert.Equal(t, "abc", f.Get("123"))
|
||||
}
|
||||
|
||||
func TestManualTxError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := NewEtcdTestFixture(t)
|
||||
defer f.Cleanup()
|
||||
|
||||
db, err := newEtcdBackend(f.BackendConfig())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Preset DB state.
|
||||
f.Put("123", "xyz")
|
||||
|
||||
stm := NewSTM(db.cli)
|
||||
|
||||
val, err := stm.Get("123")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("xyz"), val)
|
||||
|
||||
// Put a conflicting key/value.
|
||||
f.Put("123", "def")
|
||||
|
||||
// Should still get the original version.
|
||||
val, err = stm.Get("123")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("xyz"), val)
|
||||
|
||||
// Commit will fail with CommitError.
|
||||
err = stm.Commit()
|
||||
var e CommitError
|
||||
assert.True(t, errors.As(err, &e))
|
||||
|
||||
// We expect that the transacton indeed did not commit.
|
||||
assert.Equal(t, "def", f.Get("123"))
|
||||
}
|
Reference in New Issue
Block a user