diff --git a/go.mod b/go.mod index a1b1b0c9e..ec7bde36c 100644 --- a/go.mod +++ b/go.mod @@ -46,7 +46,7 @@ require ( github.com/lightningnetwork/lnd/cert v1.1.0 github.com/lightningnetwork/lnd/clock v1.1.0 github.com/lightningnetwork/lnd/healthcheck v1.2.0 - github.com/lightningnetwork/lnd/kvdb v1.2.5 + github.com/lightningnetwork/lnd/kvdb v1.3.0 github.com/lightningnetwork/lnd/queue v1.1.0 github.com/lightningnetwork/lnd/ticker v1.1.0 github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 diff --git a/kvdb/etcd/readwrite_bucket.go b/kvdb/etcd/readwrite_bucket.go index 790074332..0b7782326 100644 --- a/kvdb/etcd/readwrite_bucket.go +++ b/kvdb/etcd/readwrite_bucket.go @@ -406,3 +406,9 @@ func (b *readWriteBucket) Prefetch(paths ...[]string) { b.tx.stm.Prefetch(flattenMap(keys), flattenMap(ranges)) } + +// ForAll is an optimized version of ForEach with the limitation that no +// additional queries can be executed within the callback. +func (b *readWriteBucket) ForAll(cb func(k, v []byte) error) error { + return b.ForEach(cb) +} diff --git a/kvdb/interface.go b/kvdb/interface.go index 6fe944394..892483024 100644 --- a/kvdb/interface.go +++ b/kvdb/interface.go @@ -109,6 +109,12 @@ type ExtendedRBucket interface { // Prefetch will attempt to prefetch all values under a path. Prefetch(paths ...[]string) + + // ForAll is an optimized version of ForEach. + // + // NOTE: ForAll differs from ForEach in that no additional queries can + // be executed within the callback. + ForAll(func(k, v []byte) error) error } // Prefetch will attempt to prefetch all values under a path from the passed @@ -119,6 +125,16 @@ func Prefetch(b RBucket, paths ...[]string) { } } +// ForAll is an optimized version of ForEach with the limitation that no +// additional queries can be executed within the callback. +func ForAll(b RBucket, cb func(k, v []byte) error) error { + if bucket, ok := b.(ExtendedRBucket); ok { + return bucket.ForAll(cb) + } + + return b.ForEach(cb) +} + // RootBucket is a wrapper to ExtendedRTx.RootBucket which does nothing if // the implementation doesn't have ExtendedRTx. func RootBucket(t RTx) RBucket { diff --git a/kvdb/postgres/readwrite_bucket.go b/kvdb/postgres/readwrite_bucket.go index 5e01aa6cd..933769919 100644 --- a/kvdb/postgres/readwrite_bucket.go +++ b/kvdb/postgres/readwrite_bucket.go @@ -427,3 +427,36 @@ func (b *readWriteBucket) Sequence() uint64 { return uint64(seq) } + +// Prefetch will attempt to prefetch all values under a path from the passed +// bucket. +func (b *readWriteBucket) Prefetch(paths ...[]string) {} + +// ForAll is an optimized version of ForEach with the limitation that no +// additional queries can be executed within the callback. +func (b *readWriteBucket) ForAll(cb func(k, v []byte) error) error { + rows, cancel, err := b.tx.Query( + "SELECT key, value FROM " + b.table + " WHERE " + + parentSelector(b.id) + " ORDER BY key", + ) + if err != nil { + return err + } + defer cancel() + + for rows.Next() { + var key, value []byte + + err := rows.Scan(&key, &value) + if err != nil { + return err + } + + err = cb(key, value) + if err != nil { + return err + } + } + + return nil +} diff --git a/kvdb/postgres/readwrite_tx.go b/kvdb/postgres/readwrite_tx.go index 59be148fc..e0c3e2371 100644 --- a/kvdb/postgres/readwrite_tx.go +++ b/kvdb/postgres/readwrite_tx.go @@ -175,6 +175,21 @@ func (tx *readWriteTx) QueryRow(query string, args ...interface{}) (*sql.Row, return tx.tx.QueryRowContext(ctx, query, args...), cancel } +// Query executes a multi-row query call with a timeout context. +func (tx *readWriteTx) Query(query string, args ...interface{}) (*sql.Rows, + func(), error) { + + ctx, cancel := tx.db.getTimeoutCtx() + rows, err := tx.tx.QueryContext(ctx, query, args...) + if err != nil { + cancel() + + return nil, func() {}, err + } + + return rows, cancel, nil +} + // Exec executes a Exec call with a timeout context. func (tx *readWriteTx) Exec(query string, args ...interface{}) (sql.Result, error) { diff --git a/kvdb/postgres_test.go b/kvdb/postgres_test.go index b26bba537..4645e6a30 100644 --- a/kvdb/postgres_test.go +++ b/kvdb/postgres_test.go @@ -84,7 +84,35 @@ func TestPostgres(t *testing.T) { }, { name: "bucket for each", - test: testBucketForEach, + test: func(t *testing.T, db walletdb.DB) { + testBucketIterator(t, db, func(bucket walletdb.ReadWriteBucket, + callback func(key, val []byte) error) error { + + return bucket.ForEach(callback) + }) + }, + expectedDb: m{ + "test_kv": []m{ + {"id": int64(1), "key": "apple", "parent_id": nil, "sequence": nil, "value": nil}, + {"id": int64(2), "key": "banana", "parent_id": int64(1), "sequence": nil, "value": nil}, + {"id": int64(3), "key": "key1", "parent_id": int64(1), "sequence": nil, "value": "val1"}, + {"id": int64(4), "key": "key1", "parent_id": int64(2), "sequence": nil, "value": "val1"}, + {"id": int64(5), "key": "key2", "parent_id": int64(1), "sequence": nil, "value": "val2"}, + {"id": int64(6), "key": "key2", "parent_id": int64(2), "sequence": nil, "value": "val2"}, + {"id": int64(7), "key": "key3", "parent_id": int64(1), "sequence": nil, "value": "val3"}, + {"id": int64(8), "key": "key3", "parent_id": int64(2), "sequence": nil, "value": "val3"}, + }, + }, + }, + { + name: "bucket for all", + test: func(t *testing.T, db walletdb.DB) { + testBucketIterator(t, db, func(bucket walletdb.ReadWriteBucket, + callback func(key, val []byte) error) error { + + return ForAll(bucket, callback) + }) + }, expectedDb: m{ "test_kv": []m{ {"id": int64(1), "key": "apple", "parent_id": nil, "sequence": nil, "value": nil}, diff --git a/kvdb/readwrite_bucket_test.go b/kvdb/readwrite_bucket_test.go index 764d48231..46ac4c12a 100644 --- a/kvdb/readwrite_bucket_test.go +++ b/kvdb/readwrite_bucket_test.go @@ -159,7 +159,20 @@ func testBucketDeletion(t *testing.T, db walletdb.DB) { require.Nil(t, err) } +type bucketIterator = func(walletdb.ReadWriteBucket, + func(key, val []byte) error) error + func testBucketForEach(t *testing.T, db walletdb.DB) { + testBucketIterator(t, db, func(bucket walletdb.ReadWriteBucket, + callback func(key, val []byte) error) error { + + return bucket.ForEach(callback) + }) +} + +func testBucketIterator(t *testing.T, db walletdb.DB, + iterator bucketIterator) { + err := Update(db, func(tx walletdb.ReadWriteTx) error { // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) @@ -199,7 +212,7 @@ func testBucketForEach(t *testing.T, db walletdb.DB) { require.Equal(t, expected, got) got = make(map[string]string) - err = banana.ForEach(func(key, val []byte) error { + err = iterator(banana, func(key, val []byte) error { got[string(key)] = string(val) return nil })