diff --git a/channeldb/kvdb/etcd/readwrite_bucket.go b/channeldb/kvdb/etcd/readwrite_bucket.go index 94fa53213..4adfa3545 100644 --- a/channeldb/kvdb/etcd/readwrite_bucket.go +++ b/channeldb/kvdb/etcd/readwrite_bucket.go @@ -1,6 +1,7 @@ package etcd import ( + "bytes" "strconv" "github.com/btcsuite/btcwallet/walletdb" @@ -20,7 +21,12 @@ type readWriteBucket struct { // newReadWriteBucket creates a new rw bucket with the passed transaction // and bucket id. -func newReadWriteBucket(tx *readWriteTx, id []byte) *readWriteBucket { +func newReadWriteBucket(tx *readWriteTx, key, id []byte) *readWriteBucket { + if !bytes.Equal(id, rootBucketID()) { + // Add the bucket key/value to the lock set. + tx.lock(string(key), string(id)) + } + return &readWriteBucket{ id: id, tx: tx, @@ -119,7 +125,8 @@ func (b *readWriteBucket) NestedReadWriteBucket(key []byte) walletdb.ReadWriteBu } // Get the bucket id (and return nil if bucket doesn't exist). - bucketVal, err := b.tx.stm.Get(string(makeBucketKey(b.id, key))) + bucketKey := makeBucketKey(b.id, key) + bucketVal, err := b.tx.stm.Get(string(bucketKey)) if err != nil { // TODO: we should return the error once the // kvdb inteface is extended. @@ -131,7 +138,7 @@ func (b *readWriteBucket) NestedReadWriteBucket(key []byte) walletdb.ReadWriteBu } // Return the bucket with the fetched bucket id. - return newReadWriteBucket(b.tx, bucketVal) + return newReadWriteBucket(b.tx, bucketKey, bucketVal) } // CreateBucket creates and returns a new nested bucket with the given @@ -163,9 +170,9 @@ func (b *readWriteBucket) CreateBucket(key []byte) ( newID := makeBucketID(bucketKey) // Create the bucket. - b.tx.stm.Put(string(bucketKey), string(newID[:])) + b.tx.put(string(bucketKey), string(newID[:])) - return newReadWriteBucket(b.tx, newID[:]), nil + return newReadWriteBucket(b.tx, bucketKey, newID[:]), nil } // CreateBucketIfNotExists creates and returns a new nested bucket with @@ -181,22 +188,22 @@ func (b *readWriteBucket) CreateBucketIfNotExists(key []byte) ( } // Check for the bucket and create if it doesn't exist. - bucketKey := string(makeBucketKey(b.id, key)) + bucketKey := makeBucketKey(b.id, key) - bucketVal, err := b.tx.stm.Get(bucketKey) + bucketVal, err := b.tx.stm.Get(string(bucketKey)) if err != nil { return nil, err } if !isValidBucketID(bucketVal) { - newID := makeBucketID([]byte(bucketKey)) - b.tx.stm.Put(bucketKey, string(newID[:])) + newID := makeBucketID(bucketKey) + b.tx.put(string(bucketKey), string(newID[:])) - return newReadWriteBucket(b.tx, newID[:]), nil + return newReadWriteBucket(b.tx, bucketKey, newID[:]), nil } // Otherwise return the bucket with the fetched bucket id. - return newReadWriteBucket(b.tx, bucketVal), nil + return newReadWriteBucket(b.tx, bucketKey, bucketVal), nil } // DeleteNestedBucket deletes the nested bucket and its sub-buckets @@ -241,7 +248,7 @@ func (b *readWriteBucket) DeleteNestedBucket(key []byte) error { } for kv != nil { - b.tx.stm.Del(kv.key) + b.tx.del(kv.key) kv, err = b.tx.stm.Next(valuePrefix, kv.key) if err != nil { @@ -259,7 +266,7 @@ func (b *readWriteBucket) DeleteNestedBucket(key []byte) error { for kv != nil { // Delete sub bucket key. - b.tx.stm.Del(kv.key) + b.tx.del(kv.key) // Queue it for traversal. queue = append(queue, []byte(kv.val)) @@ -271,7 +278,7 @@ func (b *readWriteBucket) DeleteNestedBucket(key []byte) error { } // Delete the top level bucket. - b.tx.stm.Del(bucketKey) + b.tx.del(bucketKey) return nil } @@ -284,7 +291,7 @@ func (b *readWriteBucket) Put(key, value []byte) error { } // Update the transaction with the new value. - b.tx.stm.Put(string(makeValueKey(b.id, key)), string(value)) + b.tx.put(string(makeValueKey(b.id, key)), string(value)) return nil } @@ -297,7 +304,7 @@ func (b *readWriteBucket) Delete(key []byte) error { } // Update the transaction to delete the key/value. - b.tx.stm.Del(string(makeValueKey(b.id, key))) + b.tx.del(string(makeValueKey(b.id, key))) return nil } @@ -327,7 +334,7 @@ func (b *readWriteBucket) SetSequence(v uint64) error { val := strconv.FormatUint(v, 10) // Update the transaction with the new value for the sequence key. - b.tx.stm.Put(string(makeSequenceKey(b.id)), val) + b.tx.put(string(makeSequenceKey(b.id)), val) return nil } diff --git a/channeldb/kvdb/etcd/readwrite_tx.go b/channeldb/kvdb/etcd/readwrite_tx.go index 591ff55de..88e0a273e 100644 --- a/channeldb/kvdb/etcd/readwrite_tx.go +++ b/channeldb/kvdb/etcd/readwrite_tx.go @@ -9,9 +9,16 @@ type readWriteTx struct { // stm is the reference to the parent STM. stm STM - // active is true if the transaction hasn't been - // committed yet. + // active is true if the transaction hasn't been committed yet. active bool + + // dirty is true if we intent to update a value in this transaction. + dirty bool + + // lset holds key/value set that we want to lock on. If upon commit the + // transaction is dirty and the lset is not empty, we'll bump the mod + // version of these key/values. + lset map[string]string } // newReadWriteTx creates an rw transaction with the passed STM. @@ -19,13 +26,58 @@ func newReadWriteTx(stm STM) *readWriteTx { return &readWriteTx{ stm: stm, active: true, + lset: make(map[string]string), } } // rooBucket is a helper function to return the always present // root bucket. func rootBucket(tx *readWriteTx) *readWriteBucket { - return newReadWriteBucket(tx, rootBucketID()) + return newReadWriteBucket(tx, rootBucketID(), rootBucketID()) +} + +// lock adds a key value to the lock set. +func (tx *readWriteTx) lock(key, val string) { + tx.stm.Lock(key) + if !tx.dirty { + tx.lset[key] = val + } else { + // Bump the mod version of the key, + // leaving the value intact. + tx.stm.Put(key, val) + } +} + +// put updates the passed key/value. +func (tx *readWriteTx) put(key, val string) { + tx.stm.Put(key, val) + tx.setDirty() +} + +// del marks the passed key deleted. +func (tx *readWriteTx) del(key string) { + tx.stm.Del(key) + tx.setDirty() +} + +// setDirty marks the transaction dirty and bumps +// mod version for the existing lock set if it is +// not empty. +func (tx *readWriteTx) setDirty() { + // Bump the lock set. + if !tx.dirty && len(tx.lset) > 0 { + for key, val := range tx.lset { + // Bump the mod version of the key, + // leaving the value intact. + tx.stm.Put(key, val) + } + + // Clear the lock set. + tx.lset = make(map[string]string) + } + + // Set dirty. + tx.dirty = true } // ReadBucket opens the root bucket for read only access. If the bucket diff --git a/channeldb/kvdb/etcd/stm.go b/channeldb/kvdb/etcd/stm.go index 124f36164..a79cb9bad 100644 --- a/channeldb/kvdb/etcd/stm.go +++ b/channeldb/kvdb/etcd/stm.go @@ -30,6 +30,11 @@ type STM interface { // set. Returns nil if there's no matching key, or the key is empty. Get(key string) ([]byte, error) + // Lock adds a key to the lock set. If the lock set is not empty, we'll + // only check for conflicts in the lock set and the write set, instead + // of all read keys plus the write set. + Lock(key string) + // Put adds a value for a key to the txn's write set. Put(key, val string) @@ -144,6 +149,9 @@ type stm struct { // wset holds overwritten keys and their values. wset writeSet + // lset holds keys we intent to lock on. + lset map[string]interface{} + // getOpts are the opts used for gets. getOpts []v3.OpOption @@ -307,7 +315,20 @@ func (rs readSet) gets() []v3.Op { } // cmps returns a cmp list testing values in read set didn't change. -func (rs readSet) cmps() []v3.Cmp { +func (rs readSet) cmps(lset map[string]interface{}) []v3.Cmp { + if len(lset) > 0 { + cmps := make([]v3.Cmp, 0, len(lset)) + for key, _ := range lset { + if getValue, ok := rs[key]; ok { + cmps = append( + cmps, + v3.Compare(v3.ModRevision(key), "=", getValue.rev), + ) + } + } + return cmps + } + cmps := make([]v3.Cmp, 0, len(rs)) for key, getValue := range rs { cmps = append(cmps, v3.Compare(v3.ModRevision(key), "=", getValue.rev)) @@ -433,6 +454,13 @@ func (s *stm) Get(key string) ([]byte, error) { return nil, nil } +// Lock adds a key to the lock set. If the lock set is +// not empty, we'll only check conflicts for the keys +// in the lock set. +func (s *stm) Lock(key string) { + s.lset[key] = nil +} + // First returns the first key/value matching prefix. If there's no key starting // with prefix, Last will return nil. func (s *stm) First(prefix string) (*KV, error) { @@ -702,13 +730,16 @@ func (s *stm) OnCommit(cb func()) { // because the keys have changed return a CommitError, otherwise return a // DatabaseError. func (s *stm) commit() (CommitStats, error) { + rset := s.rset.cmps(s.lset) + wset := s.wset.cmps(s.revision + 1) + stats := CommitStats{ - Rset: len(s.rset), - Wset: len(s.wset), + Rset: len(rset), + Wset: len(wset), } // Create the compare set. - cmps := append(s.rset.cmps(), s.wset.cmps(s.revision+1)...) + cmps := append(rset, wset...) // Create a transaction with the optional abort context. txn := s.client.Txn(s.options.ctx) @@ -763,6 +794,7 @@ func (s *stm) Commit() error { func (s *stm) Rollback() { s.rset = make(map[string]stmGet) s.wset = make(map[string]stmPut) + s.lset = make(map[string]interface{}) s.getOpts = nil s.revision = math.MaxInt64 - 1 }