From 3a3b5413b99d258030095425291f83874bf8ba90 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 7 May 2020 18:28:52 +0800 Subject: [PATCH 001/218] lncfg: allow no auth on private addresses --- lncfg/address.go | 39 +++++++++++++++++++++++++++--- lncfg/address_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 3 deletions(-) diff --git a/lncfg/address.go b/lncfg/address.go index 7fe2d71d5..2c1770e4f 100644 --- a/lncfg/address.go +++ b/lncfg/address.go @@ -54,10 +54,10 @@ func NormalizeAddresses(addrs []string, defaultPort string, // interface. func EnforceSafeAuthentication(addrs []net.Addr, macaroonsActive bool) error { // We'll now examine all addresses that this RPC server is listening - // on. If it's a localhost address, we'll skip it, otherwise, we'll - // return an error if macaroons are inactive. + // on. If it's a localhost address or a private address, we'll skip it, + // otherwise, we'll return an error if macaroons are inactive. for _, addr := range addrs { - if IsLoopback(addr.String()) || IsUnix(addr) { + if IsLoopback(addr.String()) || IsUnix(addr) || IsPrivate(addr) { continue } @@ -117,6 +117,39 @@ func IsUnix(addr net.Addr) bool { return strings.HasPrefix(addr.Network(), "unix") } +// IsPrivate returns true if the address is private. The definitions are, +// https://en.wikipedia.org/wiki/Link-local_address +// https://en.wikipedia.org/wiki/Multicast_address +// Local IPv4 addresses, https://tools.ietf.org/html/rfc1918 +// Local IPv6 addresses, https://tools.ietf.org/html/rfc4193 +func IsPrivate(addr net.Addr) bool { + switch addr := addr.(type) { + case *net.TCPAddr: + // Check 169.254.0.0/16 and fe80::/10. + if addr.IP.IsLinkLocalUnicast() { + return true + } + + // Check 224.0.0.0/4 and ff00::/8. + if addr.IP.IsLinkLocalMulticast() { + return true + } + + // Check 10.0.0.0/8, 172.16.0.0/12 and 192.168.0.0/16. + if ip4 := addr.IP.To4(); ip4 != nil { + return ip4[0] == 10 || + (ip4[0] == 172 && ip4[1]&0xf0 == 16) || + (ip4[0] == 192 && ip4[1] == 168) + } + + // Check fc00::/7. + return len(addr.IP) == net.IPv6len && addr.IP[0]&0xfe == 0xfc + + default: + return false + } +} + // ParseAddressString converts an address in string format to a net.Addr that is // compatible with lnd. UDP is not supported because lnd needs reliable // connections. We accept a custom function to resolve any TCP addresses so diff --git a/lncfg/address_test.go b/lncfg/address_test.go index c35d7199d..208b0407e 100644 --- a/lncfg/address_test.go +++ b/lncfg/address_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/btcsuite/btcd/btcec" + "github.com/stretchr/testify/require" ) // addressTest defines a test vector for an address that contains the non- @@ -265,3 +266,57 @@ func validateAddr(t *testing.T, addr net.Addr, test addressTest) { ) } } + +func TestIsPrivate(t *testing.T) { + nonPrivateIPList := []net.IP{ + net.IPv4(169, 255, 0, 0), + {0xfe, 0x79, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + net.IPv4(225, 0, 0, 0), + {0xff, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + net.IPv4(11, 0, 0, 0), + net.IPv4(172, 15, 0, 0), + net.IPv4(192, 169, 0, 0), + {0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + net.IPv4(8, 8, 8, 8), + {2, 0, 0, 1, 4, 8, 6, 0, 4, 8, 6, 0, 8, 8, 8, 8}, + } + privateIPList := []net.IP{ + net.IPv4(169, 254, 0, 0), + {0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + net.IPv4(224, 0, 0, 0), + {0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + net.IPv4(10, 0, 0, 1), + net.IPv4(172, 16, 0, 1), + net.IPv4(172, 31, 255, 255), + net.IPv4(192, 168, 0, 1), + {0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + } + + testParams := []struct { + name string + ipList []net.IP + private bool + }{ + { + "Non-private addresses should return false", + nonPrivateIPList, false, + }, + { + "Private addresses should return true", + privateIPList, true, + }, + } + + for _, tt := range testParams { + test := tt + t.Run(test.name, func(t *testing.T) { + for _, ip := range test.ipList { + addr := &net.TCPAddr{IP: ip} + require.Equal( + t, test.private, IsPrivate(addr), + "expected IP: %s to be %v", ip, test.private, + ) + } + }) + } +} From 86d5facaa2d55594dee64e24d038ad1925461e32 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 21 Jul 2020 18:12:13 +0800 Subject: [PATCH 002/218] docs: update no-macaroons option in macaroon --- config.go | 2 +- docs/macaroons.md | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 368661bad..77a0ea230 100644 --- a/config.go +++ b/config.go @@ -146,7 +146,7 @@ type Config struct { TLSExtraDomains []string `long:"tlsextradomain" description:"Adds an extra domain to the generated certificate"` TLSAutoRefresh bool `long:"tlsautorefresh" description:"Re-generate TLS certificate and key if the IPs or domains are changed"` - NoMacaroons bool `long:"no-macaroons" description:"Disable macaroon authentication"` + NoMacaroons bool `long:"no-macaroons" description:"Disable macaroon authentication, can only be used if server is not listening on a public interface."` AdminMacPath string `long:"adminmacaroonpath" description:"Path to write the admin macaroon for lnd's RPC and REST services if it doesn't exist"` ReadMacPath string `long:"readonlymacaroonpath" description:"Path to write the read-only macaroon for lnd's RPC and REST services if it doesn't exist"` InvoiceMacPath string `long:"invoicemacaroonpath" description:"Path to the invoice-only macaroon for lnd's RPC and REST services if it doesn't exist"` diff --git a/docs/macaroons.md b/docs/macaroons.md index aae12d1d9..b1ed988cd 100644 --- a/docs/macaroons.md +++ b/docs/macaroons.md @@ -81,7 +81,14 @@ methods. This means a few important things: You can also run `lnd` with the `--no-macaroons` option, which skips the creation of the macaroon files and all macaroon checks within the RPC server. This means you can still pass a macaroon to the RPC server with a client, but -it won't be checked for validity. +it won't be checked for validity. Note that disabling authentication of a server +that's listening on a public interface is not allowed. This means the +`--no-macaroons` option is only permitted when the RPC server is in a private +network. In CIDR notation, the following IPs are considered private, +- [`169.254.0.0/16` and `fe80::/10`](https://en.wikipedia.org/wiki/Link-local_address). +- [`224.0.0.0/4` and `ff00::/8`](https://en.wikipedia.org/wiki/Multicast_address). +- [`10.0.0.0/8`, `172.16.0.0/12` and `192.168.0.0/16`](https://tools.ietf.org/html/rfc1918). +- [`fc00::/7`](https://tools.ietf.org/html/rfc4193). Since `lnd` requires macaroons by default in order to call RPC methods, `lncli` now reads a macaroon and provides it in the RPC call. Unless the path is From 70a69ce99001171334ebbdce09f944ea7fd72e30 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Mon, 29 Jun 2020 14:53:38 +0200 Subject: [PATCH 003/218] kvdb: make etcd tests use testify require instead of assert --- channeldb/kvdb/etcd/db_test.go | 26 +-- channeldb/kvdb/etcd/driver_test.go | 18 +- channeldb/kvdb/etcd/readwrite_bucket_test.go | 200 +++++++++---------- channeldb/kvdb/etcd/readwrite_cursor_test.go | 140 ++++++------- channeldb/kvdb/etcd/readwrite_tx_test.go | 90 ++++----- channeldb/kvdb/etcd/stm_test.go | 160 +++++++-------- 6 files changed, 317 insertions(+), 317 deletions(-) diff --git a/channeldb/kvdb/etcd/db_test.go b/channeldb/kvdb/etcd/db_test.go index 155d912ec..c4332db8a 100644 --- a/channeldb/kvdb/etcd/db_test.go +++ b/channeldb/kvdb/etcd/db_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/btcsuite/btcwallet/walletdb" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCopy(t *testing.T) { @@ -18,30 +18,30 @@ func TestCopy(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, apple) + require.NoError(t, err) + require.NotNil(t, apple) - assert.NoError(t, apple.Put([]byte("key"), []byte("val"))) + require.NoError(t, apple.Put([]byte("key"), []byte("val"))) return nil }) // Expect non-zero copy. var buf bytes.Buffer - assert.NoError(t, db.Copy(&buf)) - assert.Greater(t, buf.Len(), 0) - assert.Nil(t, err) + require.NoError(t, db.Copy(&buf)) + require.Greater(t, buf.Len(), 0) + require.Nil(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), vkey("key", "apple"): "val", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestAbortContext(t *testing.T) { @@ -57,19 +57,19 @@ func TestAbortContext(t *testing.T) { // Pass abort context and abort right away. db, err := newEtcdBackend(config) - assert.NoError(t, err) + require.NoError(t, err) cancel() // Expect that the update will fail. err = db.Update(func(tx walletdb.ReadWriteTx) error { _, err := tx.CreateTopLevelBucket([]byte("bucket")) - assert.NoError(t, err) + require.NoError(t, err) return nil }) - assert.Error(t, err, "context canceled") + require.Error(t, err, "context canceled") // No changes in the DB. - assert.Equal(t, map[string]string{}, f.Dump()) + require.Equal(t, map[string]string{}, f.Dump()) } diff --git a/channeldb/kvdb/etcd/driver_test.go b/channeldb/kvdb/etcd/driver_test.go index 365eda7a0..ea4196eff 100644 --- a/channeldb/kvdb/etcd/driver_test.go +++ b/channeldb/kvdb/etcd/driver_test.go @@ -6,25 +6,25 @@ import ( "testing" "github.com/btcsuite/btcwallet/walletdb" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestOpenCreateFailure(t *testing.T) { t.Parallel() db, err := walletdb.Open(dbType) - assert.Error(t, err) - assert.Nil(t, db) + require.Error(t, err) + require.Nil(t, db) db, err = walletdb.Open(dbType, "wrong") - assert.Error(t, err) - assert.Nil(t, db) + require.Error(t, err) + require.Nil(t, db) db, err = walletdb.Create(dbType) - assert.Error(t, err) - assert.Nil(t, db) + require.Error(t, err) + require.Nil(t, db) db, err = walletdb.Create(dbType, "wrong") - assert.Error(t, err) - assert.Nil(t, db) + require.Error(t, err) + require.Nil(t, db) } diff --git a/channeldb/kvdb/etcd/readwrite_bucket_test.go b/channeldb/kvdb/etcd/readwrite_bucket_test.go index a3a5d6208..f5de23b5e 100644 --- a/channeldb/kvdb/etcd/readwrite_bucket_test.go +++ b/channeldb/kvdb/etcd/readwrite_bucket_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/btcsuite/btcwallet/walletdb" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBucketCreation(t *testing.T) { @@ -18,70 +18,70 @@ func TestBucketCreation(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { // empty bucket name b, err := tx.CreateTopLevelBucket(nil) - assert.Error(t, walletdb.ErrBucketNameRequired, err) - assert.Nil(t, b) + require.Error(t, walletdb.ErrBucketNameRequired, err) + require.Nil(t, b) // empty bucket name b, err = tx.CreateTopLevelBucket([]byte("")) - assert.Error(t, walletdb.ErrBucketNameRequired, err) - assert.Nil(t, b) + require.Error(t, walletdb.ErrBucketNameRequired, err) + require.Nil(t, b) // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, apple) + require.NoError(t, err) + require.NotNil(t, apple) // Check bucket tx. - assert.Equal(t, tx, apple.Tx()) + require.Equal(t, tx, apple.Tx()) // "apple" already created b, err = tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, b) + require.NoError(t, err) + require.NotNil(t, b) // "apple/banana" banana, err := apple.CreateBucket([]byte("banana")) - assert.NoError(t, err) - assert.NotNil(t, banana) + require.NoError(t, err) + require.NotNil(t, banana) banana, err = apple.CreateBucketIfNotExists([]byte("banana")) - assert.NoError(t, err) - assert.NotNil(t, banana) + require.NoError(t, err) + require.NotNil(t, banana) // Try creating "apple/banana" again b, err = apple.CreateBucket([]byte("banana")) - assert.Error(t, walletdb.ErrBucketExists, err) - assert.Nil(t, b) + require.Error(t, walletdb.ErrBucketExists, err) + require.Nil(t, b) // "apple/mango" mango, err := apple.CreateBucket([]byte("mango")) - assert.Nil(t, err) - assert.NotNil(t, mango) + require.Nil(t, err) + require.NotNil(t, mango) // "apple/banana/pear" pear, err := banana.CreateBucket([]byte("pear")) - assert.Nil(t, err) - assert.NotNil(t, pear) + require.Nil(t, err) + require.NotNil(t, pear) // empty bucket - assert.Nil(t, apple.NestedReadWriteBucket(nil)) - assert.Nil(t, apple.NestedReadWriteBucket([]byte(""))) + require.Nil(t, apple.NestedReadWriteBucket(nil)) + require.Nil(t, apple.NestedReadWriteBucket([]byte(""))) // "apple/pear" doesn't exist - assert.Nil(t, apple.NestedReadWriteBucket([]byte("pear"))) + require.Nil(t, apple.NestedReadWriteBucket([]byte("pear"))) // "apple/banana" exits - assert.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) - assert.NotNil(t, apple.NestedReadBucket([]byte("banana"))) + require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) + require.NotNil(t, apple.NestedReadBucket([]byte("banana"))) return nil }) - assert.Nil(t, err) + require.Nil(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), @@ -89,7 +89,7 @@ func TestBucketCreation(t *testing.T) { bkey("apple", "mango"): bval("apple", "mango"), bkey("apple", "banana", "pear"): bval("apple", "banana", "pear"), } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestBucketDeletion(t *testing.T) { @@ -99,99 +99,99 @@ func TestBucketDeletion(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) // "apple/banana" banana, err := apple.CreateBucket([]byte("banana")) - assert.Nil(t, err) - assert.NotNil(t, banana) + require.Nil(t, err) + require.NotNil(t, banana) kvs := []KV{{"key1", "val1"}, {"key2", "val2"}, {"key3", "val3"}} for _, kv := range kvs { - assert.NoError(t, banana.Put([]byte(kv.key), []byte(kv.val))) - assert.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key))) + require.NoError(t, banana.Put([]byte(kv.key), []byte(kv.val))) + require.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key))) } // Delete a k/v from "apple/banana" - assert.NoError(t, banana.Delete([]byte("key2"))) + require.NoError(t, banana.Delete([]byte("key2"))) // Try getting/putting/deleting invalid k/v's. - assert.Nil(t, banana.Get(nil)) - assert.Error(t, walletdb.ErrKeyRequired, banana.Put(nil, []byte("val"))) - assert.Error(t, walletdb.ErrKeyRequired, banana.Delete(nil)) + require.Nil(t, banana.Get(nil)) + require.Error(t, walletdb.ErrKeyRequired, banana.Put(nil, []byte("val"))) + require.Error(t, walletdb.ErrKeyRequired, banana.Delete(nil)) // Try deleting a k/v that doesn't exist. - assert.NoError(t, banana.Delete([]byte("nokey"))) + require.NoError(t, banana.Delete([]byte("nokey"))) // "apple/pear" pear, err := apple.CreateBucket([]byte("pear")) - assert.Nil(t, err) - assert.NotNil(t, pear) + require.Nil(t, err) + require.NotNil(t, pear) // Put some values into "apple/pear" for _, kv := range kvs { - assert.Nil(t, pear.Put([]byte(kv.key), []byte(kv.val))) - assert.Equal(t, []byte(kv.val), pear.Get([]byte(kv.key))) + require.Nil(t, pear.Put([]byte(kv.key), []byte(kv.val))) + require.Equal(t, []byte(kv.val), pear.Get([]byte(kv.key))) } // Create nested bucket "apple/pear/cherry" cherry, err := pear.CreateBucket([]byte("cherry")) - assert.Nil(t, err) - assert.NotNil(t, cherry) + require.Nil(t, err) + require.NotNil(t, cherry) // Put some values into "apple/pear/cherry" for _, kv := range kvs { - assert.NoError(t, cherry.Put([]byte(kv.key), []byte(kv.val))) + require.NoError(t, cherry.Put([]byte(kv.key), []byte(kv.val))) } // Read back values in "apple/pear/cherry" trough a read bucket. cherryReadBucket := pear.NestedReadBucket([]byte("cherry")) for _, kv := range kvs { - assert.Equal( + require.Equal( t, []byte(kv.val), cherryReadBucket.Get([]byte(kv.key)), ) } // Try deleting some invalid buckets. - assert.Error(t, + require.Error(t, walletdb.ErrBucketNameRequired, apple.DeleteNestedBucket(nil), ) // Try deleting a non existing bucket. - assert.Error( + require.Error( t, walletdb.ErrBucketNotFound, apple.DeleteNestedBucket([]byte("missing")), ) // Delete "apple/pear" - assert.Nil(t, apple.DeleteNestedBucket([]byte("pear"))) + require.Nil(t, apple.DeleteNestedBucket([]byte("pear"))) // "apple/pear" deleted - assert.Nil(t, apple.NestedReadWriteBucket([]byte("pear"))) + require.Nil(t, apple.NestedReadWriteBucket([]byte("pear"))) // "apple/pear/cherry" deleted - assert.Nil(t, pear.NestedReadWriteBucket([]byte("cherry"))) + require.Nil(t, pear.NestedReadWriteBucket([]byte("cherry"))) // Values deleted too. for _, kv := range kvs { - assert.Nil(t, pear.Get([]byte(kv.key))) - assert.Nil(t, cherry.Get([]byte(kv.key))) + require.Nil(t, pear.Get([]byte(kv.key))) + require.Nil(t, cherry.Get([]byte(kv.key))) } // "aple/banana" exists - assert.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) + require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) return nil }) - assert.Nil(t, err) + require.Nil(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), @@ -199,7 +199,7 @@ func TestBucketDeletion(t *testing.T) { vkey("key1", "apple", "banana"): "val1", vkey("key3", "apple", "banana"): "val3", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestBucketForEach(t *testing.T) { @@ -209,28 +209,28 @@ func TestBucketForEach(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) // "apple/banana" banana, err := apple.CreateBucket([]byte("banana")) - assert.Nil(t, err) - assert.NotNil(t, banana) + require.Nil(t, err) + require.NotNil(t, banana) kvs := []KV{{"key1", "val1"}, {"key2", "val2"}, {"key3", "val3"}} // put some values into "apple" and "apple/banana" too for _, kv := range kvs { - assert.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val))) - assert.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key))) + require.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val))) + require.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key))) - assert.Nil(t, banana.Put([]byte(kv.key), []byte(kv.val))) - assert.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key))) + require.Nil(t, banana.Put([]byte(kv.key), []byte(kv.val))) + require.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key))) } got := make(map[string]string) @@ -246,8 +246,8 @@ func TestBucketForEach(t *testing.T) { "banana": "", } - assert.NoError(t, err) - assert.Equal(t, expected, got) + require.NoError(t, err) + require.Equal(t, expected, got) got = make(map[string]string) err = banana.ForEach(func(key, val []byte) error { @@ -255,15 +255,15 @@ func TestBucketForEach(t *testing.T) { return nil }) - assert.NoError(t, err) + require.NoError(t, err) // remove the sub-bucket key delete(expected, "banana") - assert.Equal(t, expected, got) + require.Equal(t, expected, got) return nil }) - assert.Nil(t, err) + require.Nil(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), @@ -275,7 +275,7 @@ func TestBucketForEach(t *testing.T) { vkey("key2", "apple", "banana"): "val2", vkey("key3", "apple", "banana"): "val3", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestBucketForEachWithError(t *testing.T) { @@ -285,30 +285,30 @@ func TestBucketForEachWithError(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) // "apple/banana" banana, err := apple.CreateBucket([]byte("banana")) - assert.Nil(t, err) - assert.NotNil(t, banana) + require.Nil(t, err) + require.NotNil(t, banana) // "apple/pear" pear, err := apple.CreateBucket([]byte("pear")) - assert.Nil(t, err) - assert.NotNil(t, pear) + require.Nil(t, err) + require.NotNil(t, pear) kvs := []KV{{"key1", "val1"}, {"key2", "val2"}} // Put some values into "apple" and "apple/banana" too. for _, kv := range kvs { - assert.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val))) - assert.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key))) + require.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val))) + require.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key))) } got := make(map[string]string) @@ -328,8 +328,8 @@ func TestBucketForEachWithError(t *testing.T) { "key1": "val1", } - assert.Equal(t, expected, got) - assert.Error(t, err) + require.Equal(t, expected, got) + require.Error(t, err) got = make(map[string]string) i = 0 @@ -350,12 +350,12 @@ func TestBucketForEachWithError(t *testing.T) { "banana": "", } - assert.Equal(t, expected, got) - assert.Error(t, err) + require.Equal(t, expected, got) + require.Error(t, err) return nil }) - assert.Nil(t, err) + require.Nil(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), @@ -364,7 +364,7 @@ func TestBucketForEachWithError(t *testing.T) { vkey("key1", "apple"): "val1", vkey("key2", "apple"): "val2", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestBucketSequence(t *testing.T) { @@ -374,31 +374,31 @@ func TestBucketSequence(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) banana, err := apple.CreateBucket([]byte("banana")) - assert.Nil(t, err) - assert.NotNil(t, banana) + require.Nil(t, err) + require.NotNil(t, banana) - assert.Equal(t, uint64(0), apple.Sequence()) - assert.Equal(t, uint64(0), banana.Sequence()) + require.Equal(t, uint64(0), apple.Sequence()) + require.Equal(t, uint64(0), banana.Sequence()) - assert.Nil(t, apple.SetSequence(math.MaxUint64)) - assert.Equal(t, uint64(math.MaxUint64), apple.Sequence()) + require.Nil(t, apple.SetSequence(math.MaxUint64)) + require.Equal(t, uint64(math.MaxUint64), apple.Sequence()) for i := uint64(0); i < uint64(5); i++ { s, err := apple.NextSequence() - assert.Nil(t, err) - assert.Equal(t, i, s) + require.Nil(t, err) + require.Equal(t, i, s) } return nil }) - assert.Nil(t, err) + require.Nil(t, err) } diff --git a/channeldb/kvdb/etcd/readwrite_cursor_test.go b/channeldb/kvdb/etcd/readwrite_cursor_test.go index c14de7aa8..16dcc9dfd 100644 --- a/channeldb/kvdb/etcd/readwrite_cursor_test.go +++ b/channeldb/kvdb/etcd/readwrite_cursor_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/btcsuite/btcwallet/walletdb" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestReadCursorEmptyInterval(t *testing.T) { @@ -16,41 +16,41 @@ func TestReadCursorEmptyInterval(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { b, err := tx.CreateTopLevelBucket([]byte("alma")) - assert.NoError(t, err) - assert.NotNil(t, b) + require.NoError(t, err) + require.NotNil(t, b) return nil }) - assert.NoError(t, err) + require.NoError(t, err) err = db.View(func(tx walletdb.ReadTx) error { b := tx.ReadBucket([]byte("alma")) - assert.NotNil(t, b) + require.NotNil(t, b) cursor := b.ReadCursor() k, v := cursor.First() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) k, v = cursor.Next() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) k, v = cursor.Last() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) k, v = cursor.Prev() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) return nil }) - assert.NoError(t, err) + require.NoError(t, err) } func TestReadCursorNonEmptyInterval(t *testing.T) { @@ -60,7 +60,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) testKeyValues := []KV{ {"b", "1"}, @@ -71,20 +71,20 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { err = db.Update(func(tx walletdb.ReadWriteTx) error { b, err := tx.CreateTopLevelBucket([]byte("alma")) - assert.NoError(t, err) - assert.NotNil(t, b) + require.NoError(t, err) + require.NotNil(t, b) for _, kv := range testKeyValues { - assert.NoError(t, b.Put([]byte(kv.key), []byte(kv.val))) + require.NoError(t, b.Put([]byte(kv.key), []byte(kv.val))) } return nil }) - assert.NoError(t, err) + require.NoError(t, err) err = db.View(func(tx walletdb.ReadTx) error { b := tx.ReadBucket([]byte("alma")) - assert.NotNil(t, b) + require.NotNil(t, b) // Iterate from the front. var kvs []KV @@ -95,7 +95,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { kvs = append(kvs, KV{string(k), string(v)}) k, v = cursor.Next() } - assert.Equal(t, testKeyValues, kvs) + require.Equal(t, testKeyValues, kvs) // Iterate from the back. kvs = []KV{} @@ -105,29 +105,29 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { kvs = append(kvs, KV{string(k), string(v)}) k, v = cursor.Prev() } - assert.Equal(t, reverseKVs(testKeyValues), kvs) + require.Equal(t, reverseKVs(testKeyValues), kvs) // Random access perm := []int{3, 0, 2, 1} for _, i := range perm { k, v := cursor.Seek([]byte(testKeyValues[i].key)) - assert.Equal(t, []byte(testKeyValues[i].key), k) - assert.Equal(t, []byte(testKeyValues[i].val), v) + require.Equal(t, []byte(testKeyValues[i].key), k) + require.Equal(t, []byte(testKeyValues[i].val), v) } // Seek to nonexisting key. k, v = cursor.Seek(nil) - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) k, v = cursor.Seek([]byte("x")) - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) return nil }) - assert.NoError(t, err) + require.NoError(t, err) } func TestReadWriteCursor(t *testing.T) { @@ -137,7 +137,7 @@ func TestReadWriteCursor(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) testKeyValues := []KV{ {"b", "1"}, @@ -149,24 +149,24 @@ func TestReadWriteCursor(t *testing.T) { count := len(testKeyValues) // Pre-store the first half of the interval. - assert.NoError(t, db.Update(func(tx walletdb.ReadWriteTx) error { + require.NoError(t, db.Update(func(tx walletdb.ReadWriteTx) error { b, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, b) + require.NoError(t, err) + require.NotNil(t, b) for i := 0; i < count/2; i++ { err = b.Put( []byte(testKeyValues[i].key), []byte(testKeyValues[i].val), ) - assert.NoError(t, err) + require.NoError(t, err) } return nil })) err = db.Update(func(tx walletdb.ReadWriteTx) error { b := tx.ReadWriteBucket([]byte("apple")) - assert.NotNil(t, b) + require.NotNil(t, b) // Store the second half of the interval. for i := count / 2; i < count; i++ { @@ -174,77 +174,77 @@ func TestReadWriteCursor(t *testing.T) { []byte(testKeyValues[i].key), []byte(testKeyValues[i].val), ) - assert.NoError(t, err) + require.NoError(t, err) } cursor := b.ReadWriteCursor() // First on valid interval. fk, fv := cursor.First() - assert.Equal(t, []byte("b"), fk) - assert.Equal(t, []byte("1"), fv) + require.Equal(t, []byte("b"), fk) + require.Equal(t, []byte("1"), fv) // Prev(First()) = nil k, v := cursor.Prev() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) // Last on valid interval. lk, lv := cursor.Last() - assert.Equal(t, []byte("e"), lk) - assert.Equal(t, []byte("4"), lv) + require.Equal(t, []byte("e"), lk) + require.Equal(t, []byte("4"), lv) // Next(Last()) = nil k, v = cursor.Next() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) // Delete first item, then add an item before the // deleted one. Check that First/Next will "jump" // over the deleted item and return the new first. _, _ = cursor.First() - assert.NoError(t, cursor.Delete()) - assert.NoError(t, b.Put([]byte("a"), []byte("0"))) + require.NoError(t, cursor.Delete()) + require.NoError(t, b.Put([]byte("a"), []byte("0"))) fk, fv = cursor.First() - assert.Equal(t, []byte("a"), fk) - assert.Equal(t, []byte("0"), fv) + require.Equal(t, []byte("a"), fk) + require.Equal(t, []byte("0"), fv) k, v = cursor.Next() - assert.Equal(t, []byte("c"), k) - assert.Equal(t, []byte("2"), v) + require.Equal(t, []byte("c"), k) + require.Equal(t, []byte("2"), v) // Similarly test that a new end is returned if // the old end is deleted first. _, _ = cursor.Last() - assert.NoError(t, cursor.Delete()) - assert.NoError(t, b.Put([]byte("f"), []byte("5"))) + require.NoError(t, cursor.Delete()) + require.NoError(t, b.Put([]byte("f"), []byte("5"))) lk, lv = cursor.Last() - assert.Equal(t, []byte("f"), lk) - assert.Equal(t, []byte("5"), lv) + require.Equal(t, []byte("f"), lk) + require.Equal(t, []byte("5"), lv) k, v = cursor.Prev() - assert.Equal(t, []byte("da"), k) - assert.Equal(t, []byte("3"), v) + require.Equal(t, []byte("da"), k) + require.Equal(t, []byte("3"), v) // Overwrite k/v in the middle of the interval. - assert.NoError(t, b.Put([]byte("c"), []byte("3"))) + require.NoError(t, b.Put([]byte("c"), []byte("3"))) k, v = cursor.Prev() - assert.Equal(t, []byte("c"), k) - assert.Equal(t, []byte("3"), v) + require.Equal(t, []byte("c"), k) + require.Equal(t, []byte("3"), v) // Insert new key/values. - assert.NoError(t, b.Put([]byte("cx"), []byte("x"))) - assert.NoError(t, b.Put([]byte("cy"), []byte("y"))) + require.NoError(t, b.Put([]byte("cx"), []byte("x"))) + require.NoError(t, b.Put([]byte("cy"), []byte("y"))) k, v = cursor.Next() - assert.Equal(t, []byte("cx"), k) - assert.Equal(t, []byte("x"), v) + require.Equal(t, []byte("cx"), k) + require.Equal(t, []byte("x"), v) k, v = cursor.Next() - assert.Equal(t, []byte("cy"), k) - assert.Equal(t, []byte("y"), v) + require.Equal(t, []byte("cy"), k) + require.Equal(t, []byte("y"), v) expected := []KV{ {"a", "0"}, @@ -263,7 +263,7 @@ func TestReadWriteCursor(t *testing.T) { kvs = append(kvs, KV{string(k), string(v)}) k, v = cursor.Next() } - assert.Equal(t, expected, kvs) + require.Equal(t, expected, kvs) // Iterate from the back. kvs = []KV{} @@ -273,12 +273,12 @@ func TestReadWriteCursor(t *testing.T) { kvs = append(kvs, KV{string(k), string(v)}) k, v = cursor.Prev() } - assert.Equal(t, reverseKVs(expected), kvs) + require.Equal(t, reverseKVs(expected), kvs) return nil }) - assert.NoError(t, err) + require.NoError(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), @@ -289,5 +289,5 @@ func TestReadWriteCursor(t *testing.T) { vkey("da", "apple"): "3", vkey("f", "apple"): "5", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } diff --git a/channeldb/kvdb/etcd/readwrite_tx_test.go b/channeldb/kvdb/etcd/readwrite_tx_test.go index f65faa545..bab6967f8 100644 --- a/channeldb/kvdb/etcd/readwrite_tx_test.go +++ b/channeldb/kvdb/etcd/readwrite_tx_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/btcsuite/btcwallet/walletdb" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestTxManualCommit(t *testing.T) { @@ -16,11 +16,11 @@ func TestTxManualCommit(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) tx, err := db.BeginReadWriteTx() - assert.NoError(t, err) - assert.NotNil(t, tx) + require.NoError(t, err) + require.NotNil(t, tx) committed := false @@ -29,24 +29,24 @@ func TestTxManualCommit(t *testing.T) { }) apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, apple) - assert.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) + require.NoError(t, err) + require.NotNil(t, apple) + require.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) banana, err := tx.CreateTopLevelBucket([]byte("banana")) - assert.NoError(t, err) - assert.NotNil(t, banana) - assert.NoError(t, banana.Put([]byte("testKey"), []byte("testVal"))) - assert.NoError(t, tx.DeleteTopLevelBucket([]byte("banana"))) + require.NoError(t, err) + require.NotNil(t, banana) + require.NoError(t, banana.Put([]byte("testKey"), []byte("testVal"))) + require.NoError(t, tx.DeleteTopLevelBucket([]byte("banana"))) - assert.NoError(t, tx.Commit()) - assert.True(t, committed) + require.NoError(t, tx.Commit()) + require.True(t, committed) expected := map[string]string{ bkey("apple"): bval("apple"), vkey("testKey", "apple"): "testVal", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestTxRollback(t *testing.T) { @@ -56,21 +56,21 @@ func TestTxRollback(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) tx, err := db.BeginReadWriteTx() - assert.Nil(t, err) - assert.NotNil(t, tx) + require.Nil(t, err) + require.NotNil(t, tx) apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) - assert.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) + require.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) - assert.NoError(t, tx.Rollback()) - assert.Error(t, walletdb.ErrTxClosed, tx.Commit()) - assert.Equal(t, map[string]string{}, f.Dump()) + require.NoError(t, tx.Rollback()) + require.Error(t, walletdb.ErrTxClosed, tx.Commit()) + require.Equal(t, map[string]string{}, f.Dump()) } func TestChangeDuringManualTx(t *testing.T) { @@ -80,24 +80,24 @@ func TestChangeDuringManualTx(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) tx, err := db.BeginReadWriteTx() - assert.Nil(t, err) - assert.NotNil(t, tx) + require.Nil(t, err) + require.NotNil(t, tx) apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) - assert.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) + require.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) // Try overwriting the bucket key. f.Put(bkey("apple"), "banana") // TODO: translate error - assert.NotNil(t, tx.Commit()) - assert.Equal(t, map[string]string{ + require.NotNil(t, tx.Commit()) + require.Equal(t, map[string]string{ bkey("apple"): "banana", }, f.Dump()) } @@ -109,16 +109,16 @@ func TestChangeDuringUpdate(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) count := 0 err = db.Update(func(tx walletdb.ReadWriteTx) error { apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, apple) + require.NoError(t, err) + require.NotNil(t, apple) - assert.NoError(t, apple.Put([]byte("key"), []byte("value"))) + require.NoError(t, apple.Put([]byte("key"), []byte("value"))) if count == 0 { f.Put(vkey("key", "apple"), "new_value") @@ -127,30 +127,30 @@ func TestChangeDuringUpdate(t *testing.T) { cursor := apple.ReadCursor() k, v := cursor.First() - assert.Equal(t, []byte("key"), k) - assert.Equal(t, []byte("value"), v) - assert.Equal(t, v, apple.Get([]byte("key"))) + require.Equal(t, []byte("key"), k) + require.Equal(t, []byte("value"), v) + require.Equal(t, v, apple.Get([]byte("key"))) k, v = cursor.Next() if count == 0 { - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) } else { - assert.Equal(t, []byte("key2"), k) - assert.Equal(t, []byte("value2"), v) + require.Equal(t, []byte("key2"), k) + require.Equal(t, []byte("value2"), v) } count++ return nil }) - assert.Nil(t, err) - assert.Equal(t, count, 2) + require.Nil(t, err) + require.Equal(t, count, 2) expected := map[string]string{ bkey("apple"): bval("apple"), vkey("key", "apple"): "value", vkey("key2", "apple"): "value2", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } diff --git a/channeldb/kvdb/etcd/stm_test.go b/channeldb/kvdb/etcd/stm_test.go index 767963d4f..6beffc284 100644 --- a/channeldb/kvdb/etcd/stm_test.go +++ b/channeldb/kvdb/etcd/stm_test.go @@ -6,7 +6,7 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func reverseKVs(a []KV) []KV { @@ -24,7 +24,7 @@ func TestPutToEmpty(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) apply := func(stm STM) error { stm.Put("123", "abc") @@ -32,9 +32,9 @@ func TestPutToEmpty(t *testing.T) { } err = RunSTM(db.cli, apply) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "abc", f.Get("123")) + require.Equal(t, "abc", f.Get("123")) } func TestGetPutDel(t *testing.T) { @@ -56,64 +56,64 @@ func TestGetPutDel(t *testing.T) { } db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) apply := func(stm STM) error { // Get some non existing keys. v, err := stm.Get("") - assert.NoError(t, err) - assert.Nil(t, v) + require.NoError(t, err) + require.Nil(t, v) v, err = stm.Get("x") - assert.NoError(t, err) - assert.Nil(t, v) + require.NoError(t, err) + require.Nil(t, v) // Get all existing keys. for _, kv := range testKeyValues { v, err = stm.Get(kv.key) - assert.NoError(t, err) - assert.Equal(t, []byte(kv.val), v) + require.NoError(t, err) + require.Equal(t, []byte(kv.val), v) } // Overwrite, then delete an existing key. stm.Put("c", "6") v, err = stm.Get("c") - assert.NoError(t, err) - assert.Equal(t, []byte("6"), v) + require.NoError(t, err) + require.Equal(t, []byte("6"), v) stm.Del("c") v, err = stm.Get("c") - assert.NoError(t, err) - assert.Nil(t, v) + require.NoError(t, err) + require.Nil(t, v) // Re-add the deleted key. stm.Put("c", "7") v, err = stm.Get("c") - assert.NoError(t, err) - assert.Equal(t, []byte("7"), v) + require.NoError(t, err) + require.Equal(t, []byte("7"), v) // Add a new key. stm.Put("x", "x") v, err = stm.Get("x") - assert.NoError(t, err) - assert.Equal(t, []byte("x"), v) + require.NoError(t, err) + require.Equal(t, []byte("x"), v) return nil } err = RunSTM(db.cli, apply) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "1", f.Get("a")) - assert.Equal(t, "2", f.Get("b")) - assert.Equal(t, "7", f.Get("c")) - assert.Equal(t, "4", f.Get("d")) - assert.Equal(t, "5", f.Get("e")) - assert.Equal(t, "x", f.Get("x")) + require.Equal(t, "1", f.Get("a")) + require.Equal(t, "2", f.Get("b")) + require.Equal(t, "7", f.Get("c")) + require.Equal(t, "4", f.Get("d")) + require.Equal(t, "5", f.Get("e")) + require.Equal(t, "x", f.Get("x")) } func TestFirstLastNextPrev(t *testing.T) { @@ -134,44 +134,44 @@ func TestFirstLastNextPrev(t *testing.T) { } db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) apply := func(stm STM) error { // First/Last on valid multi item interval. kv, err := stm.First("k") - assert.NoError(t, err) - assert.Equal(t, &KV{"kb", "1"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kb", "1"}, kv) kv, err = stm.Last("k") - assert.NoError(t, err) - assert.Equal(t, &KV{"ke", "4"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"ke", "4"}, kv) // First/Last on single item interval. kv, err = stm.First("w") - assert.NoError(t, err) - assert.Equal(t, &KV{"w", "w"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"w", "w"}, kv) kv, err = stm.Last("w") - assert.NoError(t, err) - assert.Equal(t, &KV{"w", "w"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"w", "w"}, kv) // Next/Prev on start/end. kv, err = stm.Next("k", "ke") - assert.NoError(t, err) - assert.Nil(t, kv) + require.NoError(t, err) + require.Nil(t, kv) kv, err = stm.Prev("k", "kb") - assert.NoError(t, err) - assert.Nil(t, kv) + require.NoError(t, err) + require.Nil(t, kv) // Next/Prev in the middle. kv, err = stm.Next("k", "kc") - assert.NoError(t, err) - assert.Equal(t, &KV{"kda", "3"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kda", "3"}, kv) kv, err = stm.Prev("k", "ke") - assert.NoError(t, err) - assert.Equal(t, &KV{"kda", "3"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kda", "3"}, kv) // Delete first item, then add an item before the // deleted one. Check that First/Next will "jump" @@ -180,12 +180,12 @@ func TestFirstLastNextPrev(t *testing.T) { stm.Put("ka", "0") kv, err = stm.First("k") - assert.NoError(t, err) - assert.Equal(t, &KV{"ka", "0"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"ka", "0"}, kv) kv, err = stm.Prev("k", "kc") - assert.NoError(t, err) - assert.Equal(t, &KV{"ka", "0"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"ka", "0"}, kv) // Similarly test that a new end is returned if // the old end is deleted first. @@ -193,19 +193,19 @@ func TestFirstLastNextPrev(t *testing.T) { stm.Put("kf", "5") kv, err = stm.Last("k") - assert.NoError(t, err) - assert.Equal(t, &KV{"kf", "5"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kf", "5"}, kv) kv, err = stm.Next("k", "kda") - assert.NoError(t, err) - assert.Equal(t, &KV{"kf", "5"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kf", "5"}, kv) // Overwrite one in the middle. stm.Put("kda", "6") kv, err = stm.Next("k", "kc") - assert.NoError(t, err) - assert.Equal(t, &KV{"kda", "6"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kda", "6"}, kv) // Add three in the middle, then delete one. stm.Put("kdb", "7") @@ -218,12 +218,12 @@ func TestFirstLastNextPrev(t *testing.T) { var kvs []KV curr, err := stm.First("k") - assert.NoError(t, err) + require.NoError(t, err) for curr != nil { kvs = append(kvs, *curr) curr, err = stm.Next("k", curr.key) - assert.NoError(t, err) + require.NoError(t, err) } expected := []KV{ @@ -234,37 +234,37 @@ func TestFirstLastNextPrev(t *testing.T) { {"kdd", "9"}, {"kf", "5"}, } - assert.Equal(t, expected, kvs) + require.Equal(t, expected, kvs) // Similarly check that stepping from last to first // returns the expected sequence. kvs = []KV{} curr, err = stm.Last("k") - assert.NoError(t, err) + require.NoError(t, err) for curr != nil { kvs = append(kvs, *curr) curr, err = stm.Prev("k", curr.key) - assert.NoError(t, err) + require.NoError(t, err) } expected = reverseKVs(expected) - assert.Equal(t, expected, kvs) + require.Equal(t, expected, kvs) return nil } err = RunSTM(db.cli, apply) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "0", f.Get("ka")) - assert.Equal(t, "2", f.Get("kc")) - assert.Equal(t, "6", f.Get("kda")) - assert.Equal(t, "7", f.Get("kdb")) - assert.Equal(t, "9", f.Get("kdd")) - assert.Equal(t, "5", f.Get("kf")) - assert.Equal(t, "w", f.Get("w")) + require.Equal(t, "0", f.Get("ka")) + require.Equal(t, "2", f.Get("kc")) + require.Equal(t, "6", f.Get("kda")) + require.Equal(t, "7", f.Get("kdb")) + require.Equal(t, "9", f.Get("kdd")) + require.Equal(t, "5", f.Get("kf")) + require.Equal(t, "w", f.Get("w")) } func TestCommitError(t *testing.T) { @@ -274,7 +274,7 @@ func TestCommitError(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) // Preset DB state. f.Put("123", "xyz") @@ -285,10 +285,10 @@ func TestCommitError(t *testing.T) { apply := func(stm STM) error { // STM must have the key/value. val, err := stm.Get("123") - assert.NoError(t, err) + require.NoError(t, err) if cnt == 0 { - assert.Equal(t, []byte("xyz"), val) + require.Equal(t, []byte("xyz"), val) // Put a conflicting key/value during the first apply. f.Put("123", "def") @@ -302,10 +302,10 @@ func TestCommitError(t *testing.T) { } err = RunSTM(db.cli, apply) - assert.NoError(t, err) - assert.Equal(t, 2, cnt) + require.NoError(t, err) + require.Equal(t, 2, cnt) - assert.Equal(t, "abc", f.Get("123")) + require.Equal(t, "abc", f.Get("123")) } func TestManualTxError(t *testing.T) { @@ -315,7 +315,7 @@ func TestManualTxError(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) // Preset DB state. f.Put("123", "xyz") @@ -323,22 +323,22 @@ func TestManualTxError(t *testing.T) { stm := NewSTM(db.cli) val, err := stm.Get("123") - assert.NoError(t, err) - assert.Equal(t, []byte("xyz"), val) + require.NoError(t, err) + require.Equal(t, []byte("xyz"), val) // Put a conflicting key/value. f.Put("123", "def") // Should still get the original version. val, err = stm.Get("123") - assert.NoError(t, err) - assert.Equal(t, []byte("xyz"), val) + require.NoError(t, err) + require.Equal(t, []byte("xyz"), val) // Commit will fail with CommitError. err = stm.Commit() var e CommitError - assert.True(t, errors.As(err, &e)) + require.True(t, errors.As(err, &e)) // We expect that the transacton indeed did not commit. - assert.Equal(t, "def", f.Get("123")) + require.Equal(t, "def", f.Get("123")) } From 9173958f9c93837b250d2f1a0adea00b7438610a Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Mon, 29 Jun 2020 15:09:35 +0200 Subject: [PATCH 004/218] kvdb: s/hu/en/g --- channeldb/kvdb/etcd/readwrite_cursor_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/channeldb/kvdb/etcd/readwrite_cursor_test.go b/channeldb/kvdb/etcd/readwrite_cursor_test.go index 16dcc9dfd..bc457e7ed 100644 --- a/channeldb/kvdb/etcd/readwrite_cursor_test.go +++ b/channeldb/kvdb/etcd/readwrite_cursor_test.go @@ -19,7 +19,7 @@ func TestReadCursorEmptyInterval(t *testing.T) { require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { - b, err := tx.CreateTopLevelBucket([]byte("alma")) + b, err := tx.CreateTopLevelBucket([]byte("apple")) require.NoError(t, err) require.NotNil(t, b) @@ -28,7 +28,7 @@ func TestReadCursorEmptyInterval(t *testing.T) { require.NoError(t, err) err = db.View(func(tx walletdb.ReadTx) error { - b := tx.ReadBucket([]byte("alma")) + b := tx.ReadBucket([]byte("apple")) require.NotNil(t, b) cursor := b.ReadCursor() @@ -70,7 +70,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { } err = db.Update(func(tx walletdb.ReadWriteTx) error { - b, err := tx.CreateTopLevelBucket([]byte("alma")) + b, err := tx.CreateTopLevelBucket([]byte("apple")) require.NoError(t, err) require.NotNil(t, b) @@ -83,7 +83,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { require.NoError(t, err) err = db.View(func(tx walletdb.ReadTx) error { - b := tx.ReadBucket([]byte("alma")) + b := tx.ReadBucket([]byte("apple")) require.NotNil(t, b) // Iterate from the front. From 63e9d6102fe87511762ff6fe97a03f22700bcee6 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Fri, 26 Jun 2020 20:08:08 +0200 Subject: [PATCH 005/218] kvdb+etcd: change flattened bucket key derivation algorithm This commit changes the key derivation algo we use to emulate buckets similar to bbolt. The issue with prefixing keys with either a bucket or a value prefix is that the cursor couldn't effectively iterate trough all keys in a bucket, as it skipped the bucket keys. While there are multiple ways to fix that issue (eg. two pointers, iterating value keys then bucket keys, etc), the cleanest is to instead of prefixes in keys we use a postfix indicating whether a key is a bucket or a value. This also simplifies all operations where we (recursively) iterate a bucket and is equivalent with the prefixing key derivation with the addition that bucket and value keys are now continous. --- channeldb/kvdb/etcd/bucket.go | 75 ++++++++++++-------- channeldb/kvdb/etcd/readwrite_bucket.go | 68 +++++------------- channeldb/kvdb/etcd/readwrite_bucket_test.go | 7 +- channeldb/kvdb/etcd/readwrite_cursor.go | 26 ++++--- channeldb/kvdb/etcd/readwrite_cursor_test.go | 75 ++++++++++++++++++++ 5 files changed, 151 insertions(+), 100 deletions(-) diff --git a/channeldb/kvdb/etcd/bucket.go b/channeldb/kvdb/etcd/bucket.go index 3bc087dbf..8a1ff071e 100644 --- a/channeldb/kvdb/etcd/bucket.go +++ b/channeldb/kvdb/etcd/bucket.go @@ -11,9 +11,9 @@ const ( ) var ( - bucketPrefix = []byte("b") - valuePrefix = []byte("v") - sequencePrefix = []byte("$") + valuePostfix = []byte{0x00} + bucketPostfix = []byte{0xFF} + sequencePrefix = []byte("$seq$") ) // makeBucketID returns a deterministic key for the passed byte slice. @@ -28,52 +28,65 @@ func isValidBucketID(s []byte) bool { return len(s) == bucketIDLength } -// makeKey concatenates prefix, parent and key into one byte slice. -// The prefix indicates the use of this key (whether bucket, value or sequence), -// while parentID refers to the parent bucket. -func makeKey(prefix, parent, key []byte) []byte { - keyBuf := make([]byte, len(prefix)+len(parent)+len(key)) - copy(keyBuf, prefix) - copy(keyBuf[len(prefix):], parent) - copy(keyBuf[len(prefix)+len(parent):], key) +// makeKey concatenates parent, key and postfix into one byte slice. +// The postfix indicates the use of this key (whether bucket or value), while +// parent refers to the parent bucket. +func makeKey(parent, key, postfix []byte) []byte { + keyBuf := make([]byte, len(parent)+len(key)+len(postfix)) + copy(keyBuf, parent) + copy(keyBuf[len(parent):], key) + copy(keyBuf[len(parent)+len(key):], postfix) return keyBuf } -// makePrefix concatenates prefix with parent into one byte slice. -func makePrefix(prefix []byte, parent []byte) []byte { - prefixBuf := make([]byte, len(prefix)+len(parent)) - copy(prefixBuf, prefix) - copy(prefixBuf[len(prefix):], parent) - - return prefixBuf -} - // makeBucketKey returns a bucket key from the passed parent bucket id and // the key. func makeBucketKey(parent []byte, key []byte) []byte { - return makeKey(bucketPrefix, parent, key) + return makeKey(parent, key, bucketPostfix) } // makeValueKey returns a value key from the passed parent bucket id and // the key. func makeValueKey(parent []byte, key []byte) []byte { - return makeKey(valuePrefix, parent, key) + return makeKey(parent, key, valuePostfix) } // makeSequenceKey returns a sequence key of the passed parent bucket id. func makeSequenceKey(parent []byte) []byte { - return makeKey(sequencePrefix, parent, nil) + keyBuf := make([]byte, len(sequencePrefix)+len(parent)) + copy(keyBuf, sequencePrefix) + copy(keyBuf[len(sequencePrefix):], parent) + return keyBuf } -// makeBucketPrefix returns the bucket prefix of the passed parent bucket id. -// This prefix is used for all sub buckets. -func makeBucketPrefix(parent []byte) []byte { - return makePrefix(bucketPrefix, parent) +// isBucketKey returns true if the passed key is a bucket key, meaning it +// keys a bucket name. +func isBucketKey(key string) bool { + if len(key) < bucketIDLength+1 { + return false + } + + return key[len(key)-1] == bucketPostfix[0] } -// makeValuePrefix returns the value prefix of the passed parent bucket id. -// This prefix is used for all key/values in the bucket. -func makeValuePrefix(parent []byte) []byte { - return makePrefix(valuePrefix, parent) +// getKey chops out the key from the raw key (by removing the bucket id +// prefixing the key and the postfix indicating whether it is a bucket or +// a value key) +func getKey(rawKey string) []byte { + return []byte(rawKey[bucketIDLength : len(rawKey)-1]) +} + +// getKeyVal chops out the key from the raw key (by removing the bucket id +// prefixing the key and the postfix indicating whether it is a bucket or +// a value key) and also returns the appropriate value for the key, which is +// nil in case of buckets (or the set value otherwise). +func getKeyVal(kv *KV) ([]byte, []byte) { + var val []byte + + if !isBucketKey(kv.key) { + val = []byte(kv.val) + } + + return getKey(kv.key), val } diff --git a/channeldb/kvdb/etcd/readwrite_bucket.go b/channeldb/kvdb/etcd/readwrite_bucket.go index e60d2cec3..20af7d929 100644 --- a/channeldb/kvdb/etcd/readwrite_bucket.go +++ b/channeldb/kvdb/etcd/readwrite_bucket.go @@ -46,44 +46,23 @@ func (b *readWriteBucket) NestedReadBucket(key []byte) walletdb.ReadBucket { // is nil, but it does not include the key/value pairs within those // nested buckets. func (b *readWriteBucket) ForEach(cb func(k, v []byte) error) error { - prefix := makeValuePrefix(b.id) - prefixLen := len(prefix) + prefix := string(b.id) // Get the first matching key that is in the bucket. - kv, err := b.tx.stm.First(string(prefix)) + kv, err := b.tx.stm.First(prefix) if err != nil { return err } for kv != nil { - if err := cb([]byte(kv.key[prefixLen:]), []byte(kv.val)); err != nil { + key, val := getKeyVal(kv) + + if err := cb(key, val); err != nil { return err } // Step to the next key. - kv, err = b.tx.stm.Next(string(prefix), kv.key) - if err != nil { - return err - } - } - - // Make a bucket prefix. This prefixes all sub buckets. - prefix = makeBucketPrefix(b.id) - prefixLen = len(prefix) - - // Get the first bucket. - kv, err = b.tx.stm.First(string(prefix)) - if err != nil { - return err - } - - for kv != nil { - if err := cb([]byte(kv.key[prefixLen:]), nil); err != nil { - return err - } - - // Step to the next bucket. - kv, err = b.tx.stm.Next(string(prefix), kv.key) + kv, err = b.tx.stm.Next(prefix, kv.key) if err != nil { return err } @@ -241,10 +220,7 @@ func (b *readWriteBucket) DeleteNestedBucket(key []byte) error { id := queue[0] queue = queue[1:] - // Delete values in the current bucket - valuePrefix := string(makeValuePrefix(id)) - - kv, err := b.tx.stm.First(valuePrefix) + kv, err := b.tx.stm.First(string(id)) if err != nil { return err } @@ -252,35 +228,23 @@ func (b *readWriteBucket) DeleteNestedBucket(key []byte) error { for kv != nil { b.tx.del(kv.key) - kv, err = b.tx.stm.Next(valuePrefix, kv.key) + if isBucketKey(kv.key) { + queue = append(queue, []byte(kv.val)) + } + + kv, err = b.tx.stm.Next(string(id), kv.key) if err != nil { return err } } - // Iterate sub buckets - bucketPrefix := string(makeBucketPrefix(id)) - - kv, err = b.tx.stm.First(bucketPrefix) - if err != nil { - return err - } - - for kv != nil { - // Delete sub bucket key. - b.tx.del(kv.key) - // Queue it for traversal. - queue = append(queue, []byte(kv.val)) - - kv, err = b.tx.stm.Next(bucketPrefix, kv.key) - if err != nil { - return err - } - } + // Finally delete the sequence key for the bucket. + b.tx.del(string(makeSequenceKey(id))) } - // Delete the top level bucket. + // Delete the top level bucket and sequence key. b.tx.del(bucketKey) + b.tx.del(string(makeSequenceKey(bucketVal))) return nil } diff --git a/channeldb/kvdb/etcd/readwrite_bucket_test.go b/channeldb/kvdb/etcd/readwrite_bucket_test.go index f5de23b5e..6fb321367 100644 --- a/channeldb/kvdb/etcd/readwrite_bucket_test.go +++ b/channeldb/kvdb/etcd/readwrite_bucket_test.go @@ -315,7 +315,7 @@ func TestBucketForEachWithError(t *testing.T) { i := 0 // Error while iterating value keys. err = apple.ForEach(func(key, val []byte) error { - if i == 1 { + if i == 2 { return fmt.Errorf("error") } @@ -325,7 +325,8 @@ func TestBucketForEachWithError(t *testing.T) { }) expected := map[string]string{ - "key1": "val1", + "banana": "", + "key1": "val1", } require.Equal(t, expected, got) @@ -345,9 +346,9 @@ func TestBucketForEachWithError(t *testing.T) { }) expected = map[string]string{ + "banana": "", "key1": "val1", "key2": "val2", - "banana": "", } require.Equal(t, expected, got) diff --git a/channeldb/kvdb/etcd/readwrite_cursor.go b/channeldb/kvdb/etcd/readwrite_cursor.go index 989656933..75c0456d7 100644 --- a/channeldb/kvdb/etcd/readwrite_cursor.go +++ b/channeldb/kvdb/etcd/readwrite_cursor.go @@ -19,7 +19,7 @@ type readWriteCursor struct { func newReadWriteCursor(bucket *readWriteBucket) *readWriteCursor { return &readWriteCursor{ bucket: bucket, - prefix: string(makeValuePrefix(bucket.id)), + prefix: string(bucket.id), } } @@ -35,8 +35,7 @@ func (c *readWriteCursor) First() (key, value []byte) { if kv != nil { c.currKey = kv.key - // Chop the prefix and return the key/value. - return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + return getKeyVal(kv) } return nil, nil @@ -53,8 +52,7 @@ func (c *readWriteCursor) Last() (key, value []byte) { if kv != nil { c.currKey = kv.key - // Chop the prefix and return the key/value. - return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + return getKeyVal(kv) } return nil, nil @@ -71,8 +69,7 @@ func (c *readWriteCursor) Next() (key, value []byte) { if kv != nil { c.currKey = kv.key - // Chop the prefix and return the key/value. - return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + return getKeyVal(kv) } return nil, nil @@ -89,8 +86,7 @@ func (c *readWriteCursor) Prev() (key, value []byte) { if kv != nil { c.currKey = kv.key - // Chop the prefix and return the key/value. - return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + return getKeyVal(kv) } return nil, nil @@ -115,8 +111,7 @@ func (c *readWriteCursor) Seek(seek []byte) (key, value []byte) { if kv != nil { c.currKey = kv.key - // Chop the prefix and return the key/value. - return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + return getKeyVal(kv) } return nil, nil @@ -133,11 +128,14 @@ func (c *readWriteCursor) Delete() error { return err } - // Delete the current key. - c.bucket.tx.stm.Del(c.currKey) + if isBucketKey(c.currKey) { + c.bucket.DeleteNestedBucket(getKey(c.currKey)) + } else { + c.bucket.Delete(getKey(c.currKey)) + } - // Set current key to the next one if possible. if nextKey != nil { + // Set current key to the next one. c.currKey = nextKey.key } diff --git a/channeldb/kvdb/etcd/readwrite_cursor_test.go b/channeldb/kvdb/etcd/readwrite_cursor_test.go index bc457e7ed..216b47c43 100644 --- a/channeldb/kvdb/etcd/readwrite_cursor_test.go +++ b/channeldb/kvdb/etcd/readwrite_cursor_test.go @@ -291,3 +291,78 @@ func TestReadWriteCursor(t *testing.T) { } require.Equal(t, expected, f.Dump()) } + +// TestReadWriteCursorWithBucketAndValue tests that cursors are able to iterate +// over both bucket and value keys if both are present in the iterated bucket. +func TestReadWriteCursorWithBucketAndValue(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + require.NoError(t, err) + + // Pre-store the first half of the interval. + require.NoError(t, db.Update(func(tx walletdb.ReadWriteTx) error { + b, err := tx.CreateTopLevelBucket([]byte("apple")) + require.NoError(t, err) + require.NotNil(t, b) + + require.NoError(t, b.Put([]byte("key"), []byte("val"))) + + b1, err := b.CreateBucket([]byte("banana")) + require.NoError(t, err) + require.NotNil(t, b1) + + b2, err := b.CreateBucket([]byte("pear")) + require.NoError(t, err) + require.NotNil(t, b2) + + return nil + })) + + err = db.View(func(tx walletdb.ReadTx) error { + b := tx.ReadBucket([]byte("apple")) + require.NotNil(t, b) + + cursor := b.ReadCursor() + + // First on valid interval. + k, v := cursor.First() + require.Equal(t, []byte("banana"), k) + require.Nil(t, v) + + k, v = cursor.Next() + require.Equal(t, []byte("key"), k) + require.Equal(t, []byte("val"), v) + + k, v = cursor.Last() + require.Equal(t, []byte("pear"), k) + require.Nil(t, v) + + k, v = cursor.Seek([]byte("k")) + require.Equal(t, []byte("key"), k) + require.Equal(t, []byte("val"), v) + + k, v = cursor.Seek([]byte("banana")) + require.Equal(t, []byte("banana"), k) + require.Nil(t, v) + + k, v = cursor.Next() + require.Equal(t, []byte("key"), k) + require.Equal(t, []byte("val"), v) + + return nil + }) + + require.NoError(t, err) + + expected := map[string]string{ + bkey("apple"): bval("apple"), + bkey("apple", "banana"): bval("apple", "banana"), + bkey("apple", "pear"): bval("apple", "pear"), + vkey("key", "apple"): "val", + } + require.Equal(t, expected, f.Dump()) +} From cbce8e8872171279b3651c4a6ec461e9eb341047 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Wed, 24 Jun 2020 12:50:11 +0200 Subject: [PATCH 006/218] channeldb: move makeTestDB out of test to make it available for other tests This commit moves makeTestDB to db.go and exports it so that we'll be able to use this function in other unit tests to make them testable with etcd if needed. --- channeldb/channel_test.go | 56 ++++++------------------------- channeldb/db.go | 35 +++++++++++++++++++ channeldb/db_test.go | 14 ++++---- channeldb/forwarding_log_test.go | 6 ++-- channeldb/graph_test.go | 54 ++++++++++++++--------------- channeldb/invoice_test.go | 20 +++++------ channeldb/meta_test.go | 4 +-- channeldb/nodes_test.go | 4 +-- channeldb/payment_control_test.go | 14 ++++---- channeldb/payments_test.go | 4 +-- channeldb/reports_test.go | 6 ++-- channeldb/waitingproof_test.go | 2 +- channeldb/witness_cache_test.go | 8 ++--- 13 files changed, 113 insertions(+), 114 deletions(-) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index e0fb3e897..e13f6df37 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -2,10 +2,8 @@ package channeldb import ( "bytes" - "io/ioutil" "math/rand" "net" - "os" "reflect" "runtime" "testing" @@ -86,40 +84,6 @@ var ( } ) -// makeTestDB creates a new instance of the ChannelDB for testing purposes. A -// callback which cleans up the created temporary directories is also returned -// and intended to be executed after the test completes. -func makeTestDB() (*DB, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - return nil, nil, err - } - - // Next, create channeldb for the first time. - backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cdb") - if err != nil { - backendCleanup() - return nil, nil, err - } - - cdb, err := CreateWithBackend(backend, OptionClock(testClock)) - if err != nil { - backendCleanup() - os.RemoveAll(tempDirName) - return nil, nil, err - } - - cleanUp := func() { - cdb.Close() - backendCleanup() - os.RemoveAll(tempDirName) - } - - return cdb, cleanUp, nil -} - // testChannelParams is a struct which details the specifics of how a channel // should be created. type testChannelParams struct { @@ -403,7 +367,7 @@ func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel { func TestOpenChannelPutGetDelete(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -552,7 +516,7 @@ func TestOptionalShutdown(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -609,7 +573,7 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { func TestChannelStateTransition(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -914,7 +878,7 @@ func TestChannelStateTransition(t *testing.T) { func TestFetchPendingChannels(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -993,7 +957,7 @@ func TestFetchPendingChannels(t *testing.T) { func TestFetchClosedChannels(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -1084,7 +1048,7 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // We'll start by creating two channels within our test database. One of // them will have their funding transaction confirmed on-chain, while // the other one will remain unconfirmed. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -1199,7 +1163,7 @@ func TestFetchWaitingCloseChannels(t *testing.T) { func TestRefreshShortChanID(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -1347,7 +1311,7 @@ func TestCloseInitiator(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1392,7 +1356,7 @@ func TestCloseInitiator(t *testing.T) { // TestCloseChannelStatus tests setting of a channel status on the historical // channel on channel close. func TestCloseChannelStatus(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1538,7 +1502,7 @@ func TestBalanceAtHeight(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) diff --git a/channeldb/db.go b/channeldb/db.go index 06d905606..a4c5c5167 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "io/ioutil" "net" "os" @@ -1260,3 +1261,37 @@ func (db *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, err return channel, nil } + +// MakeTestDB creates a new instance of the ChannelDB for testing purposes. +// A callback which cleans up the created temporary directories is also +// returned and intended to be executed after the test completes. +func MakeTestDB(modifiers ...OptionModifier) (*DB, func(), error) { + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + return nil, nil, err + } + + // Next, create channeldb for the first time. + backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cdb") + if err != nil { + backendCleanup() + return nil, nil, err + } + + cdb, err := CreateWithBackend(backend, modifiers...) + if err != nil { + backendCleanup() + os.RemoveAll(tempDirName) + return nil, nil, err + } + + cleanUp := func() { + cdb.Close() + backendCleanup() + os.RemoveAll(tempDirName) + } + + return cdb, cleanUp, nil +} diff --git a/channeldb/db_test.go b/channeldb/db_test.go index e5c57c1de..0984ed1e6 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -115,7 +115,7 @@ func TestFetchClosedChannelForID(t *testing.T) { const numChans = 101 - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -186,7 +186,7 @@ func TestFetchClosedChannelForID(t *testing.T) { func TestAddrsForNode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -247,7 +247,7 @@ func TestAddrsForNode(t *testing.T) { func TestFetchChannel(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -351,7 +351,7 @@ func genRandomChannelShell() (*ChannelShell, error) { func TestRestoreChannelShells(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -445,7 +445,7 @@ func TestRestoreChannelShells(t *testing.T) { func TestAbandonChannel(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -618,7 +618,7 @@ func TestFetchChannels(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test "+ "database: %v", err) @@ -687,7 +687,7 @@ func TestFetchChannels(t *testing.T) { // TestFetchHistoricalChannel tests lookup of historical channels. func TestFetchHistoricalChannel(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } diff --git a/channeldb/forwarding_log_test.go b/channeldb/forwarding_log_test.go index cc06e8867..07dfc902c 100644 --- a/channeldb/forwarding_log_test.go +++ b/channeldb/forwarding_log_test.go @@ -19,7 +19,7 @@ func TestForwardingLogBasicStorageAndQuery(t *testing.T) { // First, we'll set up a test database, and use that to instantiate the // forwarding event log that we'll be using for the duration of the // test. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -91,7 +91,7 @@ func TestForwardingLogQueryOptions(t *testing.T) { // First, we'll set up a test database, and use that to instantiate the // forwarding event log that we'll be using for the duration of the // test. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -196,7 +196,7 @@ func TestForwardingLogQueryLimit(t *testing.T) { // First, we'll set up a test database, and use that to instantiate the // forwarding event log that we'll be using for the duration of the // test. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index a6c1fb0d8..71edc8f88 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -73,7 +73,7 @@ func createTestVertex(db *DB) (*LightningNode, error) { func TestNodeInsertionAndDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -139,7 +139,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { func TestPartialNode(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -201,7 +201,7 @@ func TestPartialNode(t *testing.T) { func TestAliasLookup(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -255,7 +255,7 @@ func TestAliasLookup(t *testing.T) { func TestSourceNode(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -296,7 +296,7 @@ func TestSourceNode(t *testing.T) { func TestEdgeInsertionDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -431,7 +431,7 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, func TestDisconnectBlockAtHeight(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -718,7 +718,7 @@ func createChannelEdge(db *DB, node1, node2 *LightningNode) (*ChannelEdgeInfo, func TestEdgeInfoUpdates(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -848,7 +848,7 @@ func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, func TestGraphTraversal(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1109,7 +1109,7 @@ func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoi func TestGraphPruning(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1317,7 +1317,7 @@ func TestGraphPruning(t *testing.T) { func TestHighestChanID(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1394,7 +1394,7 @@ func TestHighestChanID(t *testing.T) { func TestChanUpdatesInHorizon(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1570,7 +1570,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { func TestNodeUpdatesInHorizon(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1693,7 +1693,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { func TestFilterKnownChanIDs(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1810,7 +1810,7 @@ func TestFilterKnownChanIDs(t *testing.T) { func TestFilterChannelRange(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1929,7 +1929,7 @@ func TestFilterChannelRange(t *testing.T) { func TestFetchChanInfos(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2057,7 +2057,7 @@ func TestFetchChanInfos(t *testing.T) { func TestIncompleteChannelPolicies(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2172,7 +2172,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2327,7 +2327,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { func TestPruneGraphNodes(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2411,7 +2411,7 @@ func TestPruneGraphNodes(t *testing.T) { func TestAddChannelEdgeShellNodes(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2465,7 +2465,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { func TestNodePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2535,7 +2535,7 @@ func TestNodeIsPublic(t *testing.T) { // We'll need to create a separate database and channel graph for each // participant to replicate real-world scenarios (private edges being in // some graphs but not others, etc.). - aliceDB, cleanUp, err := makeTestDB() + aliceDB, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2549,7 +2549,7 @@ func TestNodeIsPublic(t *testing.T) { t.Fatalf("unable to set source node: %v", err) } - bobDB, cleanUp, err := makeTestDB() + bobDB, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2563,7 +2563,7 @@ func TestNodeIsPublic(t *testing.T) { t.Fatalf("unable to set source node: %v", err) } - carolDB, cleanUp, err := makeTestDB() + carolDB, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2684,7 +2684,7 @@ func TestNodeIsPublic(t *testing.T) { func TestDisabledChannelIDs(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -2782,7 +2782,7 @@ func TestDisabledChannelIDs(t *testing.T) { func TestEdgePolicyMissingMaxHtcl(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2962,7 +2962,7 @@ func TestGraphZombieIndex(t *testing.T) { t.Parallel() // We'll start by creating our test graph along with a test edge. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to create test database: %v", err) @@ -3151,7 +3151,7 @@ func TestLightningNodeSigVerification(t *testing.T) { } // Create a LightningNode from the same private key. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 10148917b..64e2dbe62 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -136,7 +136,7 @@ func TestInvoiceWorkflow(t *testing.T) { } func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -290,7 +290,7 @@ func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { // TestAddDuplicatePayAddr asserts that the payment addresses of inserted // invoices are unique. func TestAddDuplicatePayAddr(t *testing.T) { - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() require.NoError(t, err) @@ -317,7 +317,7 @@ func TestAddDuplicatePayAddr(t *testing.T) { // addresses to be inserted if they are blank to support JIT legacy keysend // invoices. func TestAddDuplicateKeysendPayAddr(t *testing.T) { - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() require.NoError(t, err) @@ -358,7 +358,7 @@ func TestAddDuplicateKeysendPayAddr(t *testing.T) { // TestInvRefEquivocation asserts that retrieving or updating an invoice using // an equivocating InvoiceRef results in ErrInvRefEquivocation. func TestInvRefEquivocation(t *testing.T) { - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() require.NoError(t, err) @@ -398,7 +398,7 @@ func TestInvRefEquivocation(t *testing.T) { func TestInvoiceCancelSingleHtlc(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -472,7 +472,7 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) { func TestInvoiceAddTimeSeries(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB(OptionClock(testClock)) defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -627,7 +627,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -732,7 +732,7 @@ func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { func TestDuplicateSettleInvoice(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB(OptionClock(testClock)) defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -797,7 +797,7 @@ func TestDuplicateSettleInvoice(t *testing.T) { func TestQueryInvoices(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB(OptionClock(testClock)) defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -1112,7 +1112,7 @@ func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback { func TestCustomRecords(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) diff --git a/channeldb/meta_test.go b/channeldb/meta_test.go index 956ffb5d6..98e9c88a0 100644 --- a/channeldb/meta_test.go +++ b/channeldb/meta_test.go @@ -15,7 +15,7 @@ import ( func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), migrationFunc migration, shouldFail bool, dryRun bool) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatal(err) @@ -86,7 +86,7 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), func TestVersionFetchPut(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatal(err) diff --git a/channeldb/nodes_test.go b/channeldb/nodes_test.go index 755177aa7..0d649d431 100644 --- a/channeldb/nodes_test.go +++ b/channeldb/nodes_test.go @@ -13,7 +13,7 @@ import ( func TestLinkNodeEncodeDecode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -110,7 +110,7 @@ func TestLinkNodeEncodeDecode(t *testing.T) { func TestDeleteLinkNode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 147e54525..4f9014621 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -56,7 +56,7 @@ func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo, func TestPaymentControlSwitchFail(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { t.Fatalf("unable to init db: %v", err) @@ -203,7 +203,7 @@ func TestPaymentControlSwitchFail(t *testing.T) { func TestPaymentControlSwitchDoubleSend(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { @@ -286,7 +286,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { @@ -319,7 +319,7 @@ func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { func TestPaymentControlFailsWithoutInFlight(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { @@ -347,7 +347,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { @@ -530,7 +530,7 @@ func TestPaymentControlMultiShard(t *testing.T) { } runSubTest := func(t *testing.T, test testCase) { - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { @@ -780,7 +780,7 @@ func TestPaymentControlMultiShard(t *testing.T) { func TestPaymentControlMPPRecordValidation(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 9e790c3e3..0dc059561 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -399,7 +399,7 @@ func TestQueryPayments(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() if err != nil { t.Fatalf("unable to init db: %v", err) } @@ -512,7 +512,7 @@ func TestQueryPayments(t *testing.T) { // case where a specific duplicate is not found and the duplicates bucket is not // present when we expect it to be. func TestFetchPaymentWithSequenceNumber(t *testing.T) { - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() require.NoError(t, err) defer cleanup() diff --git a/channeldb/reports_test.go b/channeldb/reports_test.go index 398d0e6db..a63fe42b0 100644 --- a/channeldb/reports_test.go +++ b/channeldb/reports_test.go @@ -48,7 +48,7 @@ func TestPersistReport(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() require.NoError(t, err) defer cleanup() @@ -85,7 +85,7 @@ func TestPersistReport(t *testing.T) { // channel, testing that the appropriate error is returned based on the state // of the existing bucket. func TestFetchChannelReadBucket(t *testing.T) { - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() require.NoError(t, err) defer cleanup() @@ -197,7 +197,7 @@ func TestFetchChannelWriteBucket(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() require.NoError(t, err) defer cleanup() diff --git a/channeldb/waitingproof_test.go b/channeldb/waitingproof_test.go index fff52b921..12679b69f 100644 --- a/channeldb/waitingproof_test.go +++ b/channeldb/waitingproof_test.go @@ -14,7 +14,7 @@ import ( func TestWaitingProofStore(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() if err != nil { t.Fatalf("failed to make test database: %s", err) } diff --git a/channeldb/witness_cache_test.go b/channeldb/witness_cache_test.go index 8ba1e8355..fb6c9683a 100644 --- a/channeldb/witness_cache_test.go +++ b/channeldb/witness_cache_test.go @@ -12,7 +12,7 @@ import ( func TestWitnessCacheSha256Retrieval(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -57,7 +57,7 @@ func TestWitnessCacheSha256Retrieval(t *testing.T) { func TestWitnessCacheSha256Deletion(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -108,7 +108,7 @@ func TestWitnessCacheSha256Deletion(t *testing.T) { func TestWitnessCacheUnknownWitness(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -127,7 +127,7 @@ func TestWitnessCacheUnknownWitness(t *testing.T) { // TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves // identically to the insertion via the generalized interface. func TestAddSha256Witnesses(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } From f2a08e420e182d497f47e2bc2f5891adcedbc37c Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Wed, 24 Jun 2020 12:55:05 +0200 Subject: [PATCH 007/218] lnd: use channeldb.MakeTestDB in nursery store tests --- nursery_store_test.go | 33 +++------------------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/nursery_store_test.go b/nursery_store_test.go index af5e0a067..dee927011 100644 --- a/nursery_store_test.go +++ b/nursery_store_test.go @@ -3,8 +3,6 @@ package lnd import ( - "io/ioutil" - "os" "reflect" "testing" @@ -12,31 +10,6 @@ import ( "github.com/lightningnetwork/lnd/channeldb" ) -// makeTestDB creates a new instance of the ChannelDB for testing purposes. A -// callback which cleans up the created temporary directories is also returned -// and intended to be executed after the test completes. -func makeTestDB() (*channeldb.DB, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - return nil, nil, err - } - - // Next, create channeldb for the first time. - cdb, err := channeldb.Open(tempDirName) - if err != nil { - return nil, nil, err - } - - cleanUp := func() { - cdb.Close() - os.RemoveAll(tempDirName) - } - - return cdb, cleanUp, nil -} - type incubateTest struct { nOutputs int chanPoint *wire.OutPoint @@ -75,7 +48,7 @@ func initIncubateTests() { // TestNurseryStoreInit verifies basic properties of the nursery store before // any modifying calls are made. func TestNurseryStoreInit(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := channeldb.MakeTestDB() if err != nil { t.Fatalf("unable to open channel db: %v", err) } @@ -95,7 +68,7 @@ func TestNurseryStoreInit(t *testing.T) { // outputs through the nursery store, verifying the properties of the // intermediate states. func TestNurseryStoreIncubate(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := channeldb.MakeTestDB() if err != nil { t.Fatalf("unable to open channel db: %v", err) } @@ -336,7 +309,7 @@ func TestNurseryStoreIncubate(t *testing.T) { // populated entries from the height index as it is purged, and that the last // purged height is set appropriately. func TestNurseryStoreGraduate(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := channeldb.MakeTestDB() if err != nil { t.Fatalf("unable to open channel db: %v", err) } From 905990eb540c77f4743d9b439114c9550bf4b671 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Wed, 24 Jun 2020 12:57:23 +0200 Subject: [PATCH 008/218] sweep: use channeldb.MakeTestDB --- sweep/store_test.go | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/sweep/store_test.go b/sweep/store_test.go index 3738f6c91..b27efb31e 100644 --- a/sweep/store_test.go +++ b/sweep/store_test.go @@ -1,8 +1,6 @@ package sweep import ( - "io/ioutil" - "os" "testing" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -10,38 +8,13 @@ import ( "github.com/lightningnetwork/lnd/channeldb" ) -// makeTestDB creates a new instance of the ChannelDB for testing purposes. A -// callback which cleans up the created temporary directories is also returned -// and intended to be executed after the test completes. -func makeTestDB() (*channeldb.DB, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - return nil, nil, err - } - - // Next, create channeldb for the first time. - cdb, err := channeldb.Open(tempDirName) - if err != nil { - return nil, nil, err - } - - cleanUp := func() { - cdb.Close() - os.RemoveAll(tempDirName) - } - - return cdb, cleanUp, nil -} - // TestStore asserts that the store persists the presented data to disk and is // able to retrieve it again. func TestStore(t *testing.T) { t.Run("bolt", func(t *testing.T) { // Create new store. - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := channeldb.MakeTestDB() if err != nil { t.Fatalf("unable to open channel db: %v", err) } From 3cdbb341da98c15c2195cabe25036efc656f625a Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Wed, 24 Jun 2020 15:41:41 +0200 Subject: [PATCH 009/218] lnd: utxo nursery test to use channeldb.MakeTestDB --- utxonursery_test.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/utxonursery_test.go b/utxonursery_test.go index 1a9323db5..4f9fdab93 100644 --- a/utxonursery_test.go +++ b/utxonursery_test.go @@ -5,7 +5,6 @@ package lnd import ( "bytes" "fmt" - "io/ioutil" "math" "os" "reflect" @@ -407,6 +406,7 @@ type nurseryTestContext struct { sweeper *mockSweeper timeoutChan chan chan time.Time t *testing.T + dbCleanup func() } func createNurseryTestContext(t *testing.T, @@ -416,12 +416,7 @@ func createNurseryTestContext(t *testing.T, // alternative, mocking nurseryStore, is not chosen because there is // still considerable logic in the store. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - t.Fatalf("unable to create temp dir: %v", err) - } - - cdb, err := channeldb.Open(tempDirName) + cdb, cleanup, err := channeldb.MakeTestDB() if err != nil { t.Fatalf("unable to open channeldb: %v", err) } @@ -484,6 +479,7 @@ func createNurseryTestContext(t *testing.T, sweeper: sweeper, timeoutChan: timeoutChan, t: t, + dbCleanup: cleanup, } ctx.receiveTx = func() wire.MsgTx { @@ -531,6 +527,8 @@ func (ctx *nurseryTestContext) notifyEpoch(height int32) { } func (ctx *nurseryTestContext) finish() { + defer ctx.dbCleanup() + // Add a final restart point in this state ctx.restart() From 12a341ba5983b2f3a8fce13177de6eb53e1f73a7 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Fri, 10 Jul 2020 14:28:18 +0200 Subject: [PATCH 010/218] etcd: remove the lock set concept This commit removes the lock set which was used to only add bucket keys to the tx predicate while also bumping their mod version. This was useful to reduce the size of the compare set but wasn't useful to reduce contention as top level buckets were always in the lock set. --- channeldb/kvdb/etcd/readwrite_bucket.go | 6 --- channeldb/kvdb/etcd/readwrite_tx.go | 43 -------------------- channeldb/kvdb/etcd/stm.go | 52 +++++++------------------ 3 files changed, 13 insertions(+), 88 deletions(-) diff --git a/channeldb/kvdb/etcd/readwrite_bucket.go b/channeldb/kvdb/etcd/readwrite_bucket.go index 20af7d929..f97268d92 100644 --- a/channeldb/kvdb/etcd/readwrite_bucket.go +++ b/channeldb/kvdb/etcd/readwrite_bucket.go @@ -3,7 +3,6 @@ package etcd import ( - "bytes" "strconv" "github.com/btcsuite/btcwallet/walletdb" @@ -24,11 +23,6 @@ type readWriteBucket struct { // newReadWriteBucket creates a new rw bucket with the passed transaction // and bucket id. func newReadWriteBucket(tx *readWriteTx, key, id []byte) *readWriteBucket { - if !bytes.Equal(id, tx.rootBucketID[:]) { - // Add the bucket key/value to the lock set. - tx.lock(string(key), string(id)) - } - return &readWriteBucket{ id: id, tx: tx, diff --git a/channeldb/kvdb/etcd/readwrite_tx.go b/channeldb/kvdb/etcd/readwrite_tx.go index 22d0ce421..aed00a17e 100644 --- a/channeldb/kvdb/etcd/readwrite_tx.go +++ b/channeldb/kvdb/etcd/readwrite_tx.go @@ -17,14 +17,6 @@ type readWriteTx struct { // active is true if the transaction hasn't been committed yet. active bool - - // dirty is true if we intent to update a value in this transaction. - dirty bool - - // lset holds key/value set that we want to lock on. If upon commit the - // transaction is dirty and the lset is not empty, we'll bump the mod - // version of these key/values. - lset map[string]string } // newReadWriteTx creates an rw transaction with the passed STM. @@ -33,7 +25,6 @@ func newReadWriteTx(stm STM, prefix string) *readWriteTx { stm: stm, active: true, rootBucketID: makeBucketID([]byte(prefix)), - lset: make(map[string]string), } } @@ -43,48 +34,14 @@ func rootBucket(tx *readWriteTx) *readWriteBucket { return newReadWriteBucket(tx, tx.rootBucketID[:], tx.rootBucketID[:]) } -// lock adds a key value to the lock set. -func (tx *readWriteTx) lock(key, val string) { - tx.stm.Lock(key) - if !tx.dirty { - tx.lset[key] = val - } else { - // Bump the mod version of the key, - // leaving the value intact. - tx.stm.Put(key, val) - } -} - // put updates the passed key/value. func (tx *readWriteTx) put(key, val string) { tx.stm.Put(key, val) - tx.setDirty() } // del marks the passed key deleted. func (tx *readWriteTx) del(key string) { tx.stm.Del(key) - tx.setDirty() -} - -// setDirty marks the transaction dirty and bumps -// mod version for the existing lock set if it is -// not empty. -func (tx *readWriteTx) setDirty() { - // Bump the lock set. - if !tx.dirty && len(tx.lset) > 0 { - for key, val := range tx.lset { - // Bump the mod version of the key, - // leaving the value intact. - tx.stm.Put(key, val) - } - - // Clear the lock set. - tx.lset = make(map[string]string) - } - - // Set dirty. - tx.dirty = true } // ReadBucket opens the root bucket for read only access. If the bucket diff --git a/channeldb/kvdb/etcd/stm.go b/channeldb/kvdb/etcd/stm.go index 7a2f33b51..c13e8f966 100644 --- a/channeldb/kvdb/etcd/stm.go +++ b/channeldb/kvdb/etcd/stm.go @@ -32,11 +32,6 @@ type STM interface { // set. Returns nil if there's no matching key, or the key is empty. Get(key string) ([]byte, error) - // Lock adds a key to the lock set. If the lock set is not empty, we'll - // only check for conflicts in the lock set and the write set, instead - // of all read keys plus the write set. - Lock(key string) - // Put adds a value for a key to the txn's write set. Put(key, val string) @@ -151,9 +146,6 @@ type stm struct { // wset holds overwritten keys and their values. wset writeSet - // lset holds keys we intent to lock on. - lset map[string]interface{} - // getOpts are the opts used for gets. getOpts []v3.OpOption @@ -247,19 +239,19 @@ loop: default: } - - // Apply the transaction closure and abort the STM if there was an - // application error. + // Apply the transaction closure and abort the STM if there was + // an application error. if err = apply(s); err != nil { break loop } stats, err = s.commit() - // Re-apply only upon commit error (meaning the database was changed). + // Retry the apply closure only upon commit error (meaning the + // database was changed). if _, ok := err.(CommitError); !ok { - // Anything that's not a CommitError - // aborts the STM run loop. + // Anything that's not a CommitError aborts the STM + // run loop. break loop } @@ -303,24 +295,14 @@ func (rs readSet) gets() []v3.Op { return ops } -// cmps returns a cmp list testing values in read set didn't change. -func (rs readSet) cmps(lset map[string]interface{}) []v3.Cmp { - if len(lset) > 0 { - cmps := make([]v3.Cmp, 0, len(lset)) - for key := range lset { - if getValue, ok := rs[key]; ok { - cmps = append( - cmps, - v3.Compare(v3.ModRevision(key), "=", getValue.rev), - ) - } - } - return cmps - } - +// cmps returns a compare list which will serve as a precondition testing that +// the values in the read set didn't change. +func (rs readSet) cmps() []v3.Cmp { cmps := make([]v3.Cmp, 0, len(rs)) for key, getValue := range rs { - cmps = append(cmps, v3.Compare(v3.ModRevision(key), "=", getValue.rev)) + cmps = append(cmps, v3.Compare( + v3.ModRevision(key), "=", getValue.rev, + )) } return cmps @@ -435,13 +417,6 @@ func (s *stm) Get(key string) ([]byte, error) { return nil, nil } -// Lock adds a key to the lock set. If the lock set is -// not empty, we'll only check conflicts for the keys -// in the lock set. -func (s *stm) Lock(key string) { - s.lset[key] = nil -} - // First returns the first key/value matching prefix. If there's no key starting // with prefix, Last will return nil. func (s *stm) First(prefix string) (*KV, error) { @@ -711,7 +686,7 @@ func (s *stm) OnCommit(cb func()) { // because the keys have changed return a CommitError, otherwise return a // DatabaseError. func (s *stm) commit() (CommitStats, error) { - rset := s.rset.cmps(s.lset) + rset := s.rset.cmps() wset := s.wset.cmps(s.revision + 1) stats := CommitStats{ @@ -775,7 +750,6 @@ func (s *stm) Commit() error { func (s *stm) Rollback() { s.rset = make(map[string]stmGet) s.wset = make(map[string]stmPut) - s.lset = make(map[string]interface{}) s.getOpts = nil s.revision = math.MaxInt64 - 1 } From 678e5c5736d4cb334186881b94c6647effe304c3 Mon Sep 17 00:00:00 2001 From: Jake Sylvestre Date: Tue, 28 Jul 2020 19:46:52 -0400 Subject: [PATCH 011/218] chore: Update minimum golang version to 1.13 --- docs/INSTALL.md | 4 ++-- go.mod | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/INSTALL.md b/docs/INSTALL.md index 8214bc5f7..496cbbc92 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -95,7 +95,7 @@ * **Go modules:** This project uses [Go modules](https://github.com/golang/go/wiki/Modules) to manage dependencies as well as to provide *reproducible builds*. - Usage of Go modules (with Go 1.12) means that you no longer need to clone + Usage of Go modules (with Go 1.13) means that you no longer need to clone `lnd` into your `$GOPATH` for development purposes. Instead, your `lnd` repo can now live anywhere! @@ -124,7 +124,7 @@ make install **NOTE**: Our instructions still use the `$GOPATH` directory from prior -versions of Go, but with Go 1.12, it's now possible for `lnd` to live +versions of Go, but with Go 1.13, it's now possible for `lnd` to live _anywhere_ on your file system. For Windows WSL users, make will need to be referenced directly via diff --git a/go.mod b/go.mod index ab4d59dd1..cf266e53d 100644 --- a/go.mod +++ b/go.mod @@ -85,4 +85,4 @@ replace github.com/lightningnetwork/lnd/clock => ./clock replace git.schwanenlied.me/yawning/bsaes.git => github.com/Yawning/bsaes v0.0.0-20180720073208-c0276d75487e -go 1.12 +go 1.13 From 77fc1ac68fab813821a42d405dd923f211df9266 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 29 Jul 2020 09:50:37 +0200 Subject: [PATCH 012/218] travis: add sanity check stage --- .travis.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/.travis.yml b/.travis.yml index 9456d2085..5d9e7064a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,6 +28,25 @@ sudo: required jobs: include: + - stage: Sanity Check + name: Lint and compile + before_script: + # Install the RPC tools as a before step so Travis collapses the output + # after it's done. + - ./scripts/install_travis_proto.sh + + script: + # Step 1: Make sure no diff is produced when compiling with the correct + # version. + - make rpc-check + + # Step 2: Make sure the unit tests compile, but don't run them. They run + # in a GitHub Workflow. + - make unit pkg=... case=_NONE_ + + # Step 3: Lint go code. Limit to 1 worker to reduce memory usage. + - make lint workers=1 + - stage: Integration Test name: Btcd Integration script: From 2a614cc596273df729399adad603cd6825cf2865 Mon Sep 17 00:00:00 2001 From: carla Date: Wed, 29 Jul 2020 09:27:22 +0200 Subject: [PATCH 013/218] multi: add labels to lnd native transactions Follow up labelling of external transactions with labels for the transaction types we create within lnd. Since these labels will live a life of string matching, a version number and rigid format is added so that string matching is less painful. We start out with channel ID, where available, and a transaction "type". External labels, added in a previous PR, are not updated to this new versioned label because they are not lnd-initiated transactions. Label matching can check this case, then check for a version number. --- breacharbiter.go | 4 +- contractcourt/chain_arbitrator.go | 6 ++- contractcourt/channel_arbitrator.go | 7 ++- contractcourt/htlc_success_resolver.go | 11 ++++- fundingmanager.go | 41 +++++++++++++++++- fundingmanager_test.go | 6 +++ labels/labels.go | 59 +++++++++++++++++++++++++ lntest/harness.go | 25 ++++++++--- lntest/itest/lnd_test.go | 60 +++++++++++++++++++++++++- lnwallet/chancloser/chancloser.go | 10 ++++- server.go | 7 ++- utxonursery.go | 5 ++- watchtower/lookout/punisher.go | 4 +- 13 files changed, 225 insertions(+), 20 deletions(-) diff --git a/breacharbiter.go b/breacharbiter.go index 6924682f2..865e67ce0 100644 --- a/breacharbiter.go +++ b/breacharbiter.go @@ -20,6 +20,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) @@ -566,7 +567,8 @@ justiceTxBroadcast: // We'll now attempt to broadcast the transaction which finalized the // channel's retribution against the cheating counter party. - err = b.cfg.PublishTransaction(finalTx, "") + label := labels.MakeLabel(labels.LabelTypeJusticeTransaction, nil) + err = b.cfg.PublishTransaction(finalTx, label) if err != nil { brarLog.Errorf("Unable to broadcast justice tx: %v", err) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index dacbc1f59..dceee6c88 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" @@ -715,7 +716,10 @@ func (c *ChainArbitrator) rebroadcast(channel *channeldb.OpenChannel, log.Infof("Re-publishing %s close tx(%v) for channel %v", kind, closeTx.TxHash(), chanPoint) - err = c.cfg.PublishTx(closeTx, "") + label := labels.MakeLabel( + labels.LabelTypeChannelClose, &channel.ShortChannelID, + ) + err = c.cfg.PublishTx(closeTx, label) if err != nil && err != lnwallet.ErrDoubleSpend { log.Warnf("Unable to broadcast %s close tx(%v): %v", kind, closeTx.TxHash(), err) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index b67d2c7cd..2d5009e8a 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -16,6 +16,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -874,7 +875,11 @@ func (c *ChannelArbitrator) stateStep( // At this point, we'll now broadcast the commitment // transaction itself. - if err := c.cfg.PublishTx(closeTx, ""); err != nil { + label := labels.MakeLabel( + labels.LabelTypeChannelClose, &c.cfg.ShortChanID, + ) + + if err := c.cfg.PublishTx(closeTx, label); err != nil { log.Errorf("ChannelArbitrator(%v): unable to broadcast "+ "close tx: %v", c.cfg.ChanPoint, err) if err != lnwallet.ErrDoubleSpend { diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 7b2620915..1a99cc3b6 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -10,6 +10,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/sweep" ) @@ -157,7 +158,10 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { // Regardless of whether an existing transaction was found or newly // constructed, we'll broadcast the sweep transaction to the // network. - err := h.PublishTx(h.sweepTx, "") + label := labels.MakeLabel( + labels.LabelTypeChannelClose, &h.ShortChanID, + ) + err := h.PublishTx(h.sweepTx, label) if err != nil { log.Infof("%T(%x): unable to publish tx: %v", h, h.htlc.RHash[:], err) @@ -206,7 +210,10 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { // the claiming process. // // TODO(roasbeef): after changing sighashes send to tx bundler - err := h.PublishTx(h.htlcResolution.SignedSuccessTx, "") + label := labels.MakeLabel( + labels.LabelTypeChannelClose, &h.ShortChanID, + ) + err := h.PublishTx(h.htlcResolution.SignedSuccessTx, label) if err != nil { return nil, err } diff --git a/fundingmanager.go b/fundingmanager.go index aa3887ed1..826df8abf 100644 --- a/fundingmanager.go +++ b/fundingmanager.go @@ -22,6 +22,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwallet" @@ -243,6 +244,10 @@ type fundingConfig struct { // transaction to the network. PublishTransaction func(*wire.MsgTx, string) error + // UpdateLabel updates the label that a transaction has in our wallet, + // overwriting any existing labels. + UpdateLabel func(chainhash.Hash, string) error + // FeeEstimator calculates appropriate fee rates based on historical // transaction information. FeeEstimator chainfee.Estimator @@ -576,8 +581,15 @@ func (f *fundingManager) start() error { channel.FundingOutpoint, fundingTxBuf.Bytes()) + // Set a nil short channel ID at this stage + // because we do not know it until our funding + // tx confirms. + label := labels.MakeLabel( + labels.LabelTypeChannelOpen, nil, + ) + err = f.cfg.PublishTransaction( - channel.FundingTxn, "", + channel.FundingTxn, label, ) if err != nil { fndgLog.Errorf("Unable to rebroadcast "+ @@ -2032,7 +2044,13 @@ func (f *fundingManager) handleFundingSigned(fmsg *fundingSignedMsg) { fndgLog.Infof("Broadcasting funding tx for ChannelPoint(%v): %x", completeChan.FundingOutpoint, fundingTxBuf.Bytes()) - err = f.cfg.PublishTransaction(fundingTx, "") + // Set a nil short channel ID at this stage because we do not + // know it until our funding tx confirms. + label := labels.MakeLabel( + labels.LabelTypeChannelOpen, nil, + ) + + err = f.cfg.PublishTransaction(fundingTx, label) if err != nil { fndgLog.Errorf("Unable to broadcast funding tx %x for "+ "ChannelPoint(%v): %v", fundingTxBuf.Bytes(), @@ -2372,6 +2390,25 @@ func (f *fundingManager) handleFundingConfirmation( fndgLog.Errorf("unable to report short chan id: %v", err) } + // If we opened the channel, and lnd's wallet published our funding tx + // (which is not the case for some channels) then we update our + // transaction label with our short channel ID, which is known now that + // our funding transaction has confirmed. We do not label transactions + // we did not publish, because our wallet has no knowledge of them. + if completeChan.IsInitiator && completeChan.ChanType.HasFundingTx() { + shortChanID := completeChan.ShortChanID() + label := labels.MakeLabel( + labels.LabelTypeChannelOpen, &shortChanID, + ) + + err = f.cfg.UpdateLabel( + completeChan.FundingOutpoint.Hash, label, + ) + if err != nil { + fndgLog.Errorf("unable to update label: %v", err) + } + } + // Close the discoverySignal channel, indicating to a separate // goroutine that the channel now is marked as open in the database // and that it is acceptable to process funding locked messages diff --git a/fundingmanager_test.go b/fundingmanager_test.go index 9896563dc..c05399592 100644 --- a/fundingmanager_test.go +++ b/fundingmanager_test.go @@ -421,6 +421,9 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, publTxChan <- txn return nil }, + UpdateLabel: func(chainhash.Hash, string) error { + return nil + }, ZombieSweeperInterval: 1 * time.Hour, ReservationTimeout: 1 * time.Nanosecond, MaxPendingChannels: lncfg.DefaultMaxPendingChannels, @@ -524,6 +527,9 @@ func recreateAliceFundingManager(t *testing.T, alice *testNode) { publishChan <- txn return nil }, + UpdateLabel: func(chainhash.Hash, string) error { + return nil + }, ZombieSweeperInterval: oldCfg.ZombieSweeperInterval, ReservationTimeout: oldCfg.ReservationTimeout, OpenChannelPredicate: chainedAcceptor, diff --git a/labels/labels.go b/labels/labels.go index 3bbe5feaa..c35faaf76 100644 --- a/labels/labels.go +++ b/labels/labels.go @@ -1,12 +1,24 @@ // Package labels contains labels used to label transactions broadcast by lnd. // These labels are used across packages, so they are declared in a separate // package to avoid dependency issues. +// +// Labels for transactions broadcast by lnd have two set fields followed by an +// optional set labelled data values, all separated by colons. +// - Label version: an integer that indicates the version lnd used +// - Label type: the type of transaction we are labelling +// - {field name}-{value}: a named field followed by its value, these items are +// optional, and there may be more than field present. +// +// For version 0 we have the following optional data fields defined: +// - shortchanid: the short channel ID that a transaction is associated with, +// with its value set to the uint64 short channel id. package labels import ( "fmt" "github.com/btcsuite/btcwallet/wtxmgr" + "github.com/lightningnetwork/lnd/lnwire" ) // External labels a transaction as user initiated via the api. This @@ -31,3 +43,50 @@ func ValidateAPI(label string) (string, error) { return label, nil } + +// LabelVersion versions our labels so they can be easily update to contain +// new data while still easily string matched. +type LabelVersion uint8 + +// LabelVersionZero is the label version for labels that contain label type and +// channel ID (where available). +const LabelVersionZero LabelVersion = iota + +// LabelType indicates the type of label we are creating. It is a string rather +// than an int for easy string matching and human-readability. +type LabelType string + +const ( + // LabelTypeChannelOpen is used to label channel opens. + LabelTypeChannelOpen LabelType = "openchannel" + + // LabelTypeChannelClose is used to label channel closes. + LabelTypeChannelClose LabelType = "closechannel" + + // LabelTypeJusticeTransaction is used to label justice transactions. + LabelTypeJusticeTransaction LabelType = "justicetx" + + // LabelTypeSweepTransaction is used to label sweeps. + LabelTypeSweepTransaction LabelType = "sweep" +) + +// LabelField is used to tag a value within a label. +type LabelField string + +const ( + // ShortChanID is used to tag short channel id values in our labels. + ShortChanID LabelField = "shortchanid" +) + +// MakeLabel creates a label with the provided type and short channel id. If +// our short channel ID is not known, we simply return version:label_type. If +// we do have a short channel ID set, the label will also contain its value: +// shortchanid-{int64 chan ID}. +func MakeLabel(labelType LabelType, channelID *lnwire.ShortChannelID) string { + if channelID == nil { + return fmt.Sprintf("%v:%v", LabelVersionZero, labelType) + } + + return fmt.Sprintf("%v:%v:%v-%v", LabelVersionZero, labelType, + ShortChanID, channelID.ToUint64()) +} diff --git a/lntest/harness.go b/lntest/harness.go index aa20c3b9c..f999f897c 100644 --- a/lntest/harness.go +++ b/lntest/harness.go @@ -1204,9 +1204,14 @@ func (n *NetworkHarness) WaitForChannelClose(ctx context.Context, } // AssertChannelExists asserts that an active channel identified by the -// specified channel point exists from the point-of-view of the node. +// specified channel point exists from the point-of-view of the node. It takes +// an optional set of check functions which can be used to make further +// assertions using channel's values. These functions are responsible for +// failing the test themselves if they do not pass. +// nolint: interfacer func (n *NetworkHarness) AssertChannelExists(ctx context.Context, - node *HarnessNode, chanPoint *wire.OutPoint) error { + node *HarnessNode, chanPoint *wire.OutPoint, + checks ...func(*lnrpc.Channel)) error { req := &lnrpc.ListChannelsRequest{} @@ -1218,12 +1223,20 @@ func (n *NetworkHarness) AssertChannelExists(ctx context.Context, for _, channel := range resp.Channels { if channel.ChannelPoint == chanPoint.String() { - if channel.Active { - return nil + // First check whether our channel is active, + // failing early if it is not. + if !channel.Active { + return fmt.Errorf("channel %s inactive", + chanPoint) } - return fmt.Errorf("channel %s inactive", - chanPoint) + // Apply any additional checks that we would + // like to verify. + for _, check := range checks { + check(channel) + } + + return nil } } diff --git a/lntest/itest/lnd_test.go b/lntest/itest/lnd_test.go index ee51f47b5..023bb3fcb 100644 --- a/lntest/itest/lnd_test.go +++ b/lntest/itest/lnd_test.go @@ -35,6 +35,7 @@ import ( "github.com/lightningnetwork/lnd/chanbackup" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" @@ -2905,6 +2906,7 @@ func testChannelFundingPersistence(net *lntest.NetworkHarness, t *harnessTest) { t.Fatalf("unable to convert funding txid into chainhash.Hash:"+ " %v", err) } + fundingTxStr := fundingTxID.String() // Mine a block, then wait for Alice's node to notify us that the // channel has been opened. The funding transaction should be found @@ -2912,6 +2914,10 @@ func testChannelFundingPersistence(net *lntest.NetworkHarness, t *harnessTest) { block := mineBlocks(t, net, 1, 1)[0] assertTxInBlock(t, block, fundingTxID) + // Get the height that our transaction confirmed at. + _, height, err := net.Miner.Node.GetBestBlock() + require.NoError(t.t, err, "could not get best block") + // Restart both nodes to test that the appropriate state has been // persisted and that both nodes recover gracefully. if err := net.RestartNode(net.Alice, nil); err != nil { @@ -2934,6 +2940,16 @@ func testChannelFundingPersistence(net *lntest.NetworkHarness, t *harnessTest) { t.Fatalf("unable to mine blocks: %v", err) } + // Assert that our wallet has our opening transaction with a label + // that does not have a channel ID set yet, because we have not + // reached our required confirmations. + tx := findTxAtHeight(ctxt, t, height, fundingTxStr, net, net.Alice) + + // At this stage, we expect the transaction to be labelled, but not with + // our channel ID because our transaction has not yet confirmed. + label := labels.MakeLabel(labels.LabelTypeChannelOpen, nil) + require.Equal(t.t, label, tx.Label, "open channel label wrong") + // Both nodes should still show a single channel as pending. time.Sleep(time.Second * 1) ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) @@ -2957,9 +2973,27 @@ func testChannelFundingPersistence(net *lntest.NetworkHarness, t *harnessTest) { Index: pendingUpdate.OutputIndex, } + // Re-lookup our transaction in the block that it confirmed in. + tx = findTxAtHeight(ctxt, t, height, fundingTxStr, net, net.Alice) + + // Create an additional check for our channel assertion that will + // check that our label is as expected. + check := func(channel *lnrpc.Channel) { + shortChanID := lnwire.NewShortChanIDFromInt( + channel.ChanId, + ) + + label := labels.MakeLabel( + labels.LabelTypeChannelOpen, &shortChanID, + ) + require.Equal(t.t, label, tx.Label, + "open channel label not updated") + } + // Check both nodes to ensure that the channel is ready for operation. ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) - if err := net.AssertChannelExists(ctxt, net.Alice, &outPoint); err != nil { + err = net.AssertChannelExists(ctxt, net.Alice, &outPoint, check) + if err != nil { t.Fatalf("unable to assert channel existence: %v", err) } ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) @@ -2980,6 +3014,30 @@ func testChannelFundingPersistence(net *lntest.NetworkHarness, t *harnessTest) { closeChannelAndAssert(ctxt, t, net, net.Alice, chanPoint, false) } +// findTxAtHeight gets all of the transactions that a node's wallet has a record +// of at the target height, and finds and returns the tx with the target txid, +// failing if it is not found. +func findTxAtHeight(ctx context.Context, t *harnessTest, height int32, + target string, net *lntest.NetworkHarness, + node *lntest.HarnessNode) *lnrpc.Transaction { + + txns, err := node.LightningClient.GetTransactions( + ctx, &lnrpc.GetTransactionsRequest{ + StartHeight: height, + EndHeight: height, + }, + ) + require.NoError(t.t, err, "could not get transactions") + + for _, tx := range txns.Transactions { + if tx.TxHash == target { + return tx + } + } + + return nil +} + // testChannelBalance creates a new channel between Alice and Bob, then // checks channel balance to be equal amount specified while creation of channel. func testChannelBalance(net *lntest.NetworkHarness, t *harnessTest) { diff --git a/lnwallet/chancloser/chancloser.go b/lnwallet/chancloser/chancloser.go index f1380ef7d..0b073bf72 100644 --- a/lnwallet/chancloser/chancloser.go +++ b/lnwallet/chancloser/chancloser.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/htlcswitch" + "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" @@ -551,7 +552,14 @@ func (c *ChanCloser) ProcessCloseMsg(msg lnwire.Message) ([]lnwire.Message, return spew.Sdump(closeTx) }), ) - if err := c.cfg.BroadcastTx(closeTx, ""); err != nil { + + // Create a close channel label. + chanID := c.cfg.Channel.ShortChanID() + closeLabel := labels.MakeLabel( + labels.LabelTypeChannelClose, &chanID, + ) + + if err := c.cfg.BroadcastTx(closeTx, closeLabel); err != nil { return nil, false, err } diff --git a/server.go b/server.go index fe711cec0..07e7e3e84 100644 --- a/server.go +++ b/server.go @@ -971,8 +971,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, chanDB *channeldb.DB, IDKey: nodeKeyECDH.PubKey(), Wallet: cc.wallet, PublishTransaction: cc.wallet.PublishTransaction, - Notifier: cc.chainNotifier, - FeeEstimator: cc.feeEstimator, + UpdateLabel: func(hash chainhash.Hash, label string) error { + return cc.wallet.LabelTransaction(hash, label, true) + }, + Notifier: cc.chainNotifier, + FeeEstimator: cc.feeEstimator, SignMessage: func(pubKey *btcec.PublicKey, msg []byte) (input.Signature, error) { diff --git a/utxonursery.go b/utxonursery.go index a1a207944..217aa2da4 100644 --- a/utxonursery.go +++ b/utxonursery.go @@ -11,10 +11,10 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/sweep" ) @@ -867,7 +867,8 @@ func (u *utxoNursery) sweepCribOutput(classHeight uint32, baby *babyOutput) erro // We'll now broadcast the HTLC transaction, then wait for it to be // confirmed before transitioning it to kindergarten. - err := u.cfg.PublishTransaction(baby.timeoutTx, "") + label := labels.MakeLabel(labels.LabelTypeSweepTransaction, nil) + err := u.cfg.PublishTransaction(baby.timeoutTx, label) if err != nil && err != lnwallet.ErrDoubleSpend { utxnLog.Errorf("Unable to broadcast baby tx: "+ "%v, %v", err, spew.Sdump(baby.timeoutTx)) diff --git a/watchtower/lookout/punisher.go b/watchtower/lookout/punisher.go index 98d2b6c62..0c8f8bb89 100644 --- a/watchtower/lookout/punisher.go +++ b/watchtower/lookout/punisher.go @@ -2,6 +2,7 @@ package lookout import ( "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/labels" ) // PunisherConfig houses the resources required by the Punisher. @@ -42,7 +43,8 @@ func (p *BreachPunisher) Punish(desc *JusticeDescriptor, quit <-chan struct{}) e log.Infof("Publishing justice transaction for client=%s with txid=%s", desc.SessionInfo.ID, justiceTxn.TxHash()) - err = p.cfg.PublishTx(justiceTxn, "") + label := labels.MakeLabel(labels.LabelTypeJusticeTransaction, nil) + err = p.cfg.PublishTx(justiceTxn, label) if err != nil { log.Errorf("Unable to publish justice txn for client=%s"+ "with breach-txid=%s: %v", From b17ed28deaa525219a62a632470508fe67bb9ff4 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 16 Jun 2020 14:41:34 +0200 Subject: [PATCH 014/218] build: add all subservers to dev build --- make/testing_flags.mk | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/make/testing_flags.mk b/make/testing_flags.mk index aab09bb5e..80637a230 100644 --- a/make/testing_flags.mk +++ b/make/testing_flags.mk @@ -1,8 +1,14 @@ DEV_TAGS = dev +RPC_TAGS = autopilotrpc chainrpc invoicesrpc routerrpc signrpc verrpc walletrpc watchtowerrpc wtclientrpc LOG_TAGS = TEST_FLAGS = COVER_PKG = $$(go list -deps ./... | grep '$(PKG)' | grep -v lnrpc) +# If rpc option is set also add all extra RPC tags to DEV_TAGS +ifneq ($(with-rpc),) +DEV_TAGS += $(RPC_TAGS) +endif + # If specific package is being unit tested, construct the full name of the # subpackage. ifneq ($(pkg),) @@ -61,6 +67,6 @@ backend = btcd endif # Construct the integration test command with the added build flags. -ITEST_TAGS := $(DEV_TAGS) rpctest chainrpc walletrpc signrpc invoicesrpc autopilotrpc watchtowerrpc $(backend) +ITEST_TAGS := $(DEV_TAGS) $(RPC_TAGS) rpctest $(backend) ITEST := rm lntest/itest/*.log; date; $(GOTEST) -v ./lntest/itest -tags="$(ITEST_TAGS)" $(TEST_FLAGS) -logoutput -goroutinedump From e4188ba9c2a76924903f4c31eab689492989f395 Mon Sep 17 00:00:00 2001 From: nsa Date: Thu, 2 Jul 2020 02:16:04 -0400 Subject: [PATCH 015/218] channeldb+lnwallet: store updates the peer should sign under new key This fixes a long-standing force close bug. When we receive a revocation, store the updates that the remote should sign next under a new database key. Previously, these were not persisted which would lead to force closure. --- channeldb/channel.go | 93 ++++++++++++++- channeldb/channel_test.go | 4 +- channeldb/db_test.go | 2 +- lnwallet/channel.go | 236 +++++++++++++++++++++++++++++++++++++- 4 files changed, 326 insertions(+), 9 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index 5bec7a475..5d943a11d 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -79,6 +79,11 @@ var ( // for in one of our remote commits. unsignedAckedUpdatesKey = []byte("unsigned-acked-updates-key") + // remoteUnsignedLocalUpdatesKey is an entry in the channel bucket that + // contains the local updates that the remote party has acked, but + // has not yet signed for in one of their local commits. + remoteUnsignedLocalUpdatesKey = []byte("remote-unsigned-local-updates-key") + // revocationStateKey stores their current revocation hash, our // preimage producer and their preimage store. revocationStateKey = []byte("revocation-state-key") @@ -1448,6 +1453,39 @@ func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment, "updates: %v", err) } + // Persist the remote unsigned local updates that are not included + // in our new commitment. + updateBytes := chanBucket.Get(remoteUnsignedLocalUpdatesKey) + if updateBytes == nil { + return nil + } + + r := bytes.NewReader(updateBytes) + updates, err := deserializeLogUpdates(r) + if err != nil { + return err + } + + var validUpdates []LogUpdate + for _, upd := range updates { + // Filter for updates that are not on our local + // commitment. + if upd.LogIndex >= newCommitment.LocalLogIndex { + validUpdates = append(validUpdates, upd) + } + } + + var b2 bytes.Buffer + err = serializeLogUpdates(&b2, validUpdates) + if err != nil { + return fmt.Errorf("unable to serialize log updates: %v", err) + } + + err = chanBucket.Put(remoteUnsignedLocalUpdatesKey, b2.Bytes()) + if err != nil { + return fmt.Errorf("unable to restore chanbucket: %v", err) + } + return nil }) if err != nil { @@ -2065,6 +2103,39 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { return updates, nil } +// RemoteUnsignedLocalUpdates retrieves the persisted, unsigned local log +// updates that the remote still needs to sign for. +func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) { + var updates []LogUpdate + err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + switch err { + case nil: + break + case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: + return nil + default: + return err + } + + updateBytes := chanBucket.Get(remoteUnsignedLocalUpdatesKey) + if updateBytes == nil { + return nil + } + + r := bytes.NewReader(updateBytes) + updates, err = deserializeLogUpdates(r) + return err + }) + if err != nil { + return nil, err + } + + return updates, nil +} + // InsertNextRevocation inserts the _next_ commitment point (revocation) into // the database, and also modifies the internal RemoteNextRevocation attribute // to point to the passed key. This method is to be using during final channel @@ -2101,8 +2172,12 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { // this log can be consulted in order to reconstruct the state needed to // rectify the situation. This method will add the current commitment for the // remote party to the revocation log, and promote the current pending -// commitment to the current remote commitment. -func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg) error { +// commitment to the current remote commitment. The updates parameter is the +// set of local updates that the peer still needs to send us a signature for. +// We store this set of updates in case we go down. +func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg, + updates []LogUpdate) error { + c.Lock() defer c.Unlock() @@ -2226,6 +2301,20 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg) error { return fmt.Errorf("unable to store under unsignedAckedUpdatesKey: %v", err) } + // Persist the local updates the peer hasn't yet signed so they + // can be restored after restart. + var b2 bytes.Buffer + err = serializeLogUpdates(&b2, updates) + if err != nil { + return err + } + + err = chanBucket.Put(remoteUnsignedLocalUpdatesKey, b2.Bytes()) + if err != nil { + return fmt.Errorf("unable to restore remote unsigned "+ + "local updates: %v", err) + } + newRemoteCommit = &newCommit.Commitment return nil diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index e0fb3e897..6bd73d1f0 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -797,7 +797,7 @@ func TestChannelStateTransition(t *testing.T) { fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, diskCommitDiff.LogUpdates, nil) - err = channel.AdvanceCommitChainTail(fwdPkg) + err = channel.AdvanceCommitChainTail(fwdPkg, nil) if err != nil { t.Fatalf("unable to append to revocation log: %v", err) } @@ -845,7 +845,7 @@ func TestChannelStateTransition(t *testing.T) { fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil) - err = channel.AdvanceCommitChainTail(fwdPkg) + err = channel.AdvanceCommitChainTail(fwdPkg, nil) if err != nil { t.Fatalf("unable to append to revocation log: %v", err) } diff --git a/channeldb/db_test.go b/channeldb/db_test.go index e5c57c1de..b05ac1152 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -398,7 +398,7 @@ func TestRestoreChannelShells(t *testing.T) { if err != ErrNoRestoredChannelMutation { t.Fatalf("able to mutate restored channel") } - err = channel.AdvanceCommitChainTail(nil) + err = channel.AdvanceCommitChainTail(nil, nil) if err != ErrNoRestoredChannelMutation { t.Fatalf("able to mutate restored channel") } diff --git a/lnwallet/channel.go b/lnwallet/channel.go index a0ce102ae..563fe8d16 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -1527,6 +1527,87 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, return pd, nil } +// localLogUpdateToPayDesc converts a LogUpdate into a matching PaymentDescriptor +// entry that can be re-inserted into the local update log. This method is used +// when we sent an update+sig, receive a revocation, but drop right before the +// counterparty can sign for the update we just sent. In this case, we need to +// re-insert the original entries back into the update log so we'll be expecting +// the peer to sign them. The height of the remote commitment is expected to be +// provided and we restore all log update entries with this height, even though +// the real height may be lower. In the way these fields are used elsewhere, this +// doesn't change anything. +func (lc *LightningChannel) localLogUpdateToPayDesc(logUpdate *channeldb.LogUpdate, + remoteUpdateLog *updateLog, commitHeight uint64) (*PaymentDescriptor, + error) { + + // Since Add updates aren't saved to disk under this key, the update will + // never be an Add. + switch wireMsg := logUpdate.UpdateMsg.(type) { + + // For HTLCs that we settled, we'll fetch the original offered HTLC from + // the remote update log so we can retrieve the same PaymentDescriptor that + // ReceiveHTLCSettle would produce. + case *lnwire.UpdateFulfillHTLC: + ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) + + return &PaymentDescriptor{ + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + RPreimage: wireMsg.PaymentPreimage, + LogIndex: logUpdate.LogIndex, + ParentIndex: ogHTLC.HtlcIndex, + EntryType: Settle, + removeCommitHeightRemote: commitHeight, + }, nil + + // If we sent a failure for a prior incoming HTLC, then we'll consult the + // remote update log so we can retrieve the information of the original + // HTLC we're failing. + case *lnwire.UpdateFailHTLC: + ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) + + return &PaymentDescriptor{ + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: Fail, + FailReason: wireMsg.Reason[:], + removeCommitHeightRemote: commitHeight, + }, nil + + // HTLC fails due to malformed onion blocks are treated the exact same + // way as regular HTLC fails. + case *lnwire.UpdateFailMalformedHTLC: + ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) + + return &PaymentDescriptor{ + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: MalformedFail, + FailCode: wireMsg.FailureCode, + ShaOnionBlob: wireMsg.ShaOnionBlob, + removeCommitHeightRemote: commitHeight, + }, nil + + case *lnwire.UpdateFee: + return &PaymentDescriptor{ + LogIndex: logUpdate.LogIndex, + Amount: lnwire.NewMSatFromSatoshis( + btcutil.Amount(wireMsg.FeePerKw), + ), + EntryType: FeeUpdate, + addCommitHeightRemote: commitHeight, + removeCommitHeightRemote: commitHeight, + }, nil + + default: + return nil, fmt.Errorf("unknown message type: %T", wireMsg) + } +} + // remoteLogUpdateToPayDesc converts a LogUpdate into a matching // PaymentDescriptor entry that can be re-inserted into the update log. This // method is used when we revoked a local commitment, but the connection was @@ -1736,13 +1817,19 @@ func (lc *LightningChannel) restoreCommitState( return err } + // Fetch the local updates the peer still needs to sign for. + remoteUnsignedLocalUpdates, err := lc.channelState.RemoteUnsignedLocalUpdates() + if err != nil { + return err + } + // Finally, with the commitment states restored, we'll now restore the // state logs based on the current local+remote commit, and any pending // remote commit that exists. err = lc.restoreStateLogs( localCommit, remoteCommit, pendingRemoteCommit, pendingRemoteCommitDiff, pendingRemoteKeyChain, - unsignedAckedUpdates, + unsignedAckedUpdates, remoteUnsignedLocalUpdates, ) if err != nil { return err @@ -1759,7 +1846,8 @@ func (lc *LightningChannel) restoreStateLogs( localCommitment, remoteCommitment, pendingRemoteCommit *commitment, pendingRemoteCommitDiff *channeldb.CommitDiff, pendingRemoteKeys *CommitmentKeyRing, - unsignedAckedUpdates []channeldb.LogUpdate) error { + unsignedAckedUpdates, + remoteUnsignedLocalUpdates []channeldb.LogUpdate) error { // We make a map of incoming HTLCs to the height of the remote // commitment they were first added, and outgoing HTLCs to the height @@ -1817,6 +1905,34 @@ func (lc *LightningChannel) restoreStateLogs( outgoingLocalAddHeights[htlcIdx] = localCommitment.height } + // If there are local updates that the peer needs to sign for, then the + // corresponding add is no longer on the remote commitment, but is still on + // our local commitment. + // ----fail---> + // ----sig----> + // <---rev----- + // To ensure proper channel operation, we restore the add's addCommitHeightRemote + // field to the height of the remote commitment. + for _, logUpdate := range remoteUnsignedLocalUpdates { + + var htlcIdx uint64 + switch wireMsg := logUpdate.UpdateMsg.(type) { + case *lnwire.UpdateFulfillHTLC: + htlcIdx = wireMsg.ID + case *lnwire.UpdateFailHTLC: + htlcIdx = wireMsg.ID + case *lnwire.UpdateFailMalformedHTLC: + htlcIdx = wireMsg.ID + default: + continue + } + + // The htlcIdx is stored in the map with the remote commitment + // height so the related add's addCommitHeightRemote field can be + // restored. + incomingRemoteAddHeights[htlcIdx] = remoteCommitment.height + } + // For each incoming HTLC within the local commitment, we add it to the // remote update log. Since HTLCs are added first to the receiver's // commitment, we don't have to restore outgoing HTLCs, as they will be @@ -1873,7 +1989,11 @@ func (lc *LightningChannel) restoreStateLogs( return err } - return nil + // Restore unsigned acked local log updates so we expect the peer to + // sign for them. + return lc.restorePeerLocalUpdates( + remoteUnsignedLocalUpdates, remoteCommitment.height, + ) } // restorePendingRemoteUpdates restores the acked remote log updates that we @@ -1956,6 +2076,38 @@ func (lc *LightningChannel) restorePendingRemoteUpdates( return nil } +// restorePeerLocalUpdates restores the acked local log updates the peer still +// needs to sign for. +func (lc *LightningChannel) restorePeerLocalUpdates(updates []channeldb.LogUpdate, + remoteCommitmentHeight uint64) error { + + lc.log.Debugf("Restoring %v local updates that the peer should sign", + len(updates)) + + for _, logUpdate := range updates { + logUpdate := logUpdate + + payDesc, err := lc.localLogUpdateToPayDesc( + &logUpdate, lc.remoteUpdateLog, remoteCommitmentHeight, + ) + if err != nil { + return err + } + + lc.localUpdateLog.restoreUpdate(payDesc) + + // Since Add updates are not stored and FeeUpdates don't have a + // corresponding entry in the remote update log, we only need to + // mark the htlc as modified if the update was Settle, Fail, or + // MalformedFail. + if payDesc.EntryType != FeeUpdate { + lc.remoteUpdateLog.markHtlcModified(payDesc.ParentIndex) + } + } + + return nil +} + // restorePendingLocalUpdates restores the local log updates leading up to the // given pending remote commitment. func (lc *LightningChannel) restorePendingLocalUpdates( @@ -4625,6 +4777,15 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( } } + // We use the remote commitment chain's tip as it will soon become the tail + // once advanceTail is called. + remoteMessageIndex := lc.remoteCommitChain.tip().ourMessageIndex + localMessageIndex := lc.localCommitChain.tail().ourMessageIndex + + localPeerUpdates := lc.unsignedLocalUpdates( + remoteMessageIndex, localMessageIndex, chanID, + ) + // Now that we have gathered the set of HTLCs to forward, separated by // type, construct a forwarding package using the height that the remote // commitment chain will be extended after persisting the revocation. @@ -4637,7 +4798,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( // sync now to ensure the revocation producer state is consistent with // the current commitment height and also to advance the on-disk // commitment chain. - err = lc.channelState.AdvanceCommitChainTail(fwdPkg) + err = lc.channelState.AdvanceCommitChainTail(fwdPkg, localPeerUpdates) if err != nil { return nil, nil, nil, nil, err } @@ -6750,3 +6911,70 @@ func (lc *LightningChannel) NextLocalHtlcIndex() (uint64, error) { func (lc *LightningChannel) FwdMinHtlc() lnwire.MilliSatoshi { return lc.channelState.LocalChanCfg.MinHTLC } + +// unsignedLocalUpdates retrieves the unsigned local updates that we should +// store upon receiving a revocation. This function is called from +// ReceiveRevocation. remoteMessageIndex is the height into the local update +// log that the remote commitment chain tip includes. localMessageIndex +// is the height into the local update log that the local commitment tail +// includes. Our local updates that are unsigned by the remote should +// have height greater than or equal to localMessageIndex (not on our commit), +// and height less than remoteMessageIndex (on the remote commit). +// +// NOTE: remoteMessageIndex is the height on the tip because this is called +// before the tail is advanced to the tip during ReceiveRevocation. +func (lc *LightningChannel) unsignedLocalUpdates(remoteMessageIndex, + localMessageIndex uint64, chanID lnwire.ChannelID) []channeldb.LogUpdate { + + var localPeerUpdates []channeldb.LogUpdate + for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() { + pd := e.Value.(*PaymentDescriptor) + + // We don't save add updates as they are restored from the + // remote commitment in restoreStateLogs. + if pd.EntryType == Add { + continue + } + + // This is a settle/fail that is on the remote commitment, but + // not on the local commitment. We expect this update to be + // covered in the next commitment signature that the remote + // sends. + if pd.LogIndex < remoteMessageIndex && pd.LogIndex >= localMessageIndex { + logUpdate := channeldb.LogUpdate{ + LogIndex: pd.LogIndex, + } + + switch pd.EntryType { + case FeeUpdate: + logUpdate.UpdateMsg = &lnwire.UpdateFee{ + ChanID: chanID, + FeePerKw: uint32(pd.Amount.ToSatoshis()), + } + case Settle: + logUpdate.UpdateMsg = &lnwire.UpdateFulfillHTLC{ + ChanID: chanID, + ID: pd.ParentIndex, + PaymentPreimage: pd.RPreimage, + } + case Fail: + logUpdate.UpdateMsg = &lnwire.UpdateFailHTLC{ + ChanID: chanID, + ID: pd.ParentIndex, + Reason: pd.FailReason, + } + case MalformedFail: + logUpdate.UpdateMsg = &lnwire.UpdateFailMalformedHTLC{ + ChanID: chanID, + ID: pd.ParentIndex, + ShaOnionBlob: pd.ShaOnionBlob, + FailureCode: pd.FailCode, + } + } + + localPeerUpdates = append(localPeerUpdates, logUpdate) + } + } + + return localPeerUpdates +} From c36840c2a50c52cf5311c75b10b1d38d29f52027 Mon Sep 17 00:00:00 2001 From: nsa Date: Tue, 28 Jul 2020 11:44:48 -0400 Subject: [PATCH 016/218] lnwallet: add regression test TestChannelLocalUnsignedUpdatesFailure This commit includes a regression test that checks that we remember to restore updates that we sent to the peer but they haven't sent us a signature for yet. --- lnwallet/channel_test.go | 85 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 18b2d760b..f076ae3a0 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9212,3 +9212,88 @@ func TestChannelUnsignedAckedFailure(t *testing.T) { err = newAliceChannel.ReceiveNewCommitment(bobSig, bobHtlcSigs) require.NoError(t, err) } + +// TestChannelLocalUnsignedUpdatesFailure checks that updates from the local +// log are restored if the remote hasn't sent us a signature covering them. +// +// The full state transition is: +// +// Alice Bob +// <----add----- +// <----sig----- +// -----rev----> +// -----sig----> +// <----rev----- +// ----fail----> +// -----sig----> +// <----rev----- +// *reconnect* +// <----sig----- +// +// Alice should reject the last signature since the settle is not restored +// into the local update log and thus calculates Bob's signature as invalid. +func TestChannelLocalUnsignedUpdatesFailure(t *testing.T) { + t.Parallel() + + // Create a test channel so that we can test the buggy behavior. + aliceChannel, bobChannel, cleanUp, err := CreateTestChannels( + channeldb.SingleFunderTweaklessBit, + ) + require.NoError(t, err) + defer cleanUp() + + // First we create an htlc that Bob sends to Alice. + htlc, _ := createHTLC(0, lnwire.MilliSatoshi(500000)) + + // <----add----- + _, err = bobChannel.AddHTLC(htlc, nil) + require.NoError(t, err) + _, err = aliceChannel.ReceiveHTLC(htlc) + require.NoError(t, err) + + // Force a state transition to lock in this add on both commitments. + // <----sig----- + // -----rev----> + // -----sig----> + // <----rev----- + err = ForceStateTransition(bobChannel, aliceChannel) + require.NoError(t, err) + + // Now Alice should fail the htlc back to Bob. + // -----fail---> + err = aliceChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + require.NoError(t, err) + err = bobChannel.ReceiveFailHTLC(0, []byte("bad")) + require.NoError(t, err) + + // Alice should send a commitment signature to Bob. + // -----sig----> + aliceSig, aliceHtlcSigs, _, err := aliceChannel.SignNextCommitment() + require.NoError(t, err) + err = bobChannel.ReceiveNewCommitment(aliceSig, aliceHtlcSigs) + require.NoError(t, err) + + // Bob should reply with a revocation and Alice should save the fail as + // an unsigned local update. + // <----rev----- + bobRevocation, _, err := bobChannel.RevokeCurrentCommitment() + require.NoError(t, err) + _, _, _, _, err = aliceChannel.ReceiveRevocation(bobRevocation) + require.NoError(t, err) + + // Restart Alice and assert that she can receive Bob's next commitment + // signature. + // *reconnect* + newAliceChannel, err := NewLightningChannel( + aliceChannel.Signer, aliceChannel.channelState, + aliceChannel.sigPool, + ) + require.NoError(t, err) + + // Bob sends the final signature and Alice should not reject it. + // <----sig----- + bobSig, bobHtlcSigs, _, err := bobChannel.SignNextCommitment() + require.NoError(t, err) + err = newAliceChannel.ReceiveNewCommitment(bobSig, bobHtlcSigs) + require.NoError(t, err) +} From 608617975a76789f97b354f52a845197783a0aff Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 29 Jul 2020 20:13:24 -0700 Subject: [PATCH 017/218] build/version: bump to 0.11.0-beta-rc1 --- build/version.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/build/version.go b/build/version.go index 444673ecf..8298e9550 100644 --- a/build/version.go +++ b/build/version.go @@ -41,14 +41,14 @@ const ( AppMajor uint = 0 // AppMinor defines the minor version of this binary. - AppMinor uint = 10 + AppMinor uint = 11 // AppPatch defines the application patch for this binary. - AppPatch uint = 99 + AppPatch uint = 0 // AppPreRelease MUST only contain characters from semanticAlphabet // per the semantic versioning spec. - AppPreRelease = "beta" + AppPreRelease = "beta-rc1" ) func init() { From 247b7530caf08a555ffd56f81019031bc1af6565 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 29 Jul 2020 20:56:09 -0700 Subject: [PATCH 018/218] build/version: use dot syntax for rc --- build/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/version.go b/build/version.go index 8298e9550..7954a15b3 100644 --- a/build/version.go +++ b/build/version.go @@ -48,7 +48,7 @@ const ( // AppPreRelease MUST only contain characters from semanticAlphabet // per the semantic versioning spec. - AppPreRelease = "beta-rc1" + AppPreRelease = "beta.rc1" ) func init() { From 5346ed8a5c1aba8bdec44f5dc37544487a215982 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 21 Jul 2020 18:26:13 +0200 Subject: [PATCH 019/218] kvdb+etcd: assert on bucket/value key when putting value/bucket This commit extends compatibility with the bbolt kvdb implementation, which returns ErrIncompatibleValue in case of a bucket/value key collision. Furthermore the commit also adds an extra precondition to the transaction when a key doesn't exist. This is needed as we fix reads to a snapshot revision and other writers may commit the key otherwise. --- channeldb/kvdb/etcd/readwrite_bucket.go | 49 ++++++-- channeldb/kvdb/etcd/readwrite_bucket_test.go | 118 +++++++++++++++++++ channeldb/kvdb/etcd/readwrite_tx.go | 10 -- channeldb/kvdb/etcd/stm.go | 25 +++- 4 files changed, 180 insertions(+), 22 deletions(-) diff --git a/channeldb/kvdb/etcd/readwrite_bucket.go b/channeldb/kvdb/etcd/readwrite_bucket.go index f97268d92..dafab5ff4 100644 --- a/channeldb/kvdb/etcd/readwrite_bucket.go +++ b/channeldb/kvdb/etcd/readwrite_bucket.go @@ -116,6 +116,20 @@ func (b *readWriteBucket) NestedReadWriteBucket(key []byte) walletdb.ReadWriteBu return newReadWriteBucket(b.tx, bucketKey, bucketVal) } +// assertNoValue checks if the value for the passed key exists. +func (b *readWriteBucket) assertNoValue(key []byte) error { + val, err := b.tx.stm.Get(string(makeValueKey(b.id, key))) + if err != nil { + return err + } + + if val != nil { + return walletdb.ErrIncompatibleValue + } + + return nil +} + // CreateBucket creates and returns a new nested bucket with the given // key. Returns ErrBucketExists if the bucket already exists, // ErrBucketNameRequired if the key is empty, or ErrIncompatibleValue @@ -141,11 +155,15 @@ func (b *readWriteBucket) CreateBucket(key []byte) ( return nil, walletdb.ErrBucketExists } + if err := b.assertNoValue(key); err != nil { + return nil, err + } + // Create a deterministic bucket id from the bucket key. newID := makeBucketID(bucketKey) // Create the bucket. - b.tx.put(string(bucketKey), string(newID[:])) + b.tx.stm.Put(string(bucketKey), string(newID[:])) return newReadWriteBucket(b.tx, bucketKey, newID[:]), nil } @@ -171,8 +189,12 @@ func (b *readWriteBucket) CreateBucketIfNotExists(key []byte) ( } if !isValidBucketID(bucketVal) { + if err := b.assertNoValue(key); err != nil { + return nil, err + } + newID := makeBucketID(bucketKey) - b.tx.put(string(bucketKey), string(newID[:])) + b.tx.stm.Put(string(bucketKey), string(newID[:])) return newReadWriteBucket(b.tx, bucketKey, newID[:]), nil } @@ -220,7 +242,7 @@ func (b *readWriteBucket) DeleteNestedBucket(key []byte) error { } for kv != nil { - b.tx.del(kv.key) + b.tx.stm.Del(kv.key) if isBucketKey(kv.key) { queue = append(queue, []byte(kv.val)) @@ -233,12 +255,12 @@ func (b *readWriteBucket) DeleteNestedBucket(key []byte) error { } // Finally delete the sequence key for the bucket. - b.tx.del(string(makeSequenceKey(id))) + b.tx.stm.Del(string(makeSequenceKey(id))) } // Delete the top level bucket and sequence key. - b.tx.del(bucketKey) - b.tx.del(string(makeSequenceKey(bucketVal))) + b.tx.stm.Del(bucketKey) + b.tx.stm.Del(string(makeSequenceKey(bucketVal))) return nil } @@ -250,8 +272,17 @@ func (b *readWriteBucket) Put(key, value []byte) error { return walletdb.ErrKeyRequired } + val, err := b.tx.stm.Get(string(makeBucketKey(b.id, key))) + if err != nil { + return err + } + + if val != nil { + return walletdb.ErrIncompatibleValue + } + // Update the transaction with the new value. - b.tx.put(string(makeValueKey(b.id, key)), string(value)) + b.tx.stm.Put(string(makeValueKey(b.id, key)), string(value)) return nil } @@ -264,7 +295,7 @@ func (b *readWriteBucket) Delete(key []byte) error { } // Update the transaction to delete the key/value. - b.tx.del(string(makeValueKey(b.id, key))) + b.tx.stm.Del(string(makeValueKey(b.id, key))) return nil } @@ -294,7 +325,7 @@ func (b *readWriteBucket) SetSequence(v uint64) error { val := strconv.FormatUint(v, 10) // Update the transaction with the new value for the sequence key. - b.tx.put(string(makeSequenceKey(b.id)), val) + b.tx.stm.Put(string(makeSequenceKey(b.id)), val) return nil } diff --git a/channeldb/kvdb/etcd/readwrite_bucket_test.go b/channeldb/kvdb/etcd/readwrite_bucket_test.go index 6fb321367..2795dce34 100644 --- a/channeldb/kvdb/etcd/readwrite_bucket_test.go +++ b/channeldb/kvdb/etcd/readwrite_bucket_test.go @@ -403,3 +403,121 @@ func TestBucketSequence(t *testing.T) { require.Nil(t, err) } + +// TestKeyClash tests that one cannot create a bucket if a value with the same +// key exists and the same is true in reverse: that a value cannot be put if +// a bucket with the same key exists. +func TestKeyClash(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + require.NoError(t, err) + + // First: + // put: /apple/key -> val + // create bucket: /apple/banana + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + require.Nil(t, err) + require.NotNil(t, apple) + + require.NoError(t, apple.Put([]byte("key"), []byte("val"))) + + banana, err := apple.CreateBucket([]byte("banana")) + require.Nil(t, err) + require.NotNil(t, banana) + + return nil + }) + + require.Nil(t, err) + + // Next try to: + // put: /apple/banana -> val => will fail (as /apple/banana is a bucket) + // create bucket: /apple/key => will fail (as /apple/key is a value) + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + require.Nil(t, err) + require.NotNil(t, apple) + + require.Error(t, + walletdb.ErrIncompatibleValue, + apple.Put([]byte("banana"), []byte("val")), + ) + + b, err := apple.CreateBucket([]byte("key")) + require.Nil(t, b) + require.Error(t, walletdb.ErrIncompatibleValue, b) + + b, err = apple.CreateBucketIfNotExists([]byte("key")) + require.Nil(t, b) + require.Error(t, walletdb.ErrIncompatibleValue, b) + + return nil + }) + + require.Nil(t, err) + + // Except that the only existing items in the db are: + // bucket: /apple + // bucket: /apple/banana + // value: /apple/key -> val + expected := map[string]string{ + bkey("apple"): bval("apple"), + bkey("apple", "banana"): bval("apple", "banana"), + vkey("key", "apple"): "val", + } + require.Equal(t, expected, f.Dump()) + +} + +// TestBucketCreateDelete tests that creating then deleting then creating a +// bucket suceeds. +func TestBucketCreateDelete(t *testing.T) { + t.Parallel() + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + require.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + require.NoError(t, err) + require.NotNil(t, apple) + + banana, err := apple.CreateBucket([]byte("banana")) + require.NoError(t, err) + require.NotNil(t, banana) + + return nil + }) + require.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple := tx.ReadWriteBucket([]byte("apple")) + require.NotNil(t, apple) + require.NoError(t, apple.DeleteNestedBucket([]byte("banana"))) + + return nil + }) + require.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple := tx.ReadWriteBucket([]byte("apple")) + require.NotNil(t, apple) + require.NoError(t, apple.Put([]byte("banana"), []byte("value"))) + + return nil + }) + require.NoError(t, err) + + expected := map[string]string{ + vkey("banana", "apple"): "value", + bkey("apple"): bval("apple"), + } + require.Equal(t, expected, f.Dump()) +} diff --git a/channeldb/kvdb/etcd/readwrite_tx.go b/channeldb/kvdb/etcd/readwrite_tx.go index aed00a17e..81c27323f 100644 --- a/channeldb/kvdb/etcd/readwrite_tx.go +++ b/channeldb/kvdb/etcd/readwrite_tx.go @@ -34,16 +34,6 @@ func rootBucket(tx *readWriteTx) *readWriteBucket { return newReadWriteBucket(tx, tx.rootBucketID[:], tx.rootBucketID[:]) } -// put updates the passed key/value. -func (tx *readWriteTx) put(key, val string) { - tx.stm.Put(key, val) -} - -// del marks the passed key deleted. -func (tx *readWriteTx) del(key string) { - tx.stm.Del(key) -} - // ReadBucket opens the root bucket for read only access. If the bucket // described by the key does not exist, nil is returned. func (tx *readWriteTx) ReadBucket(key []byte) walletdb.ReadBucket { diff --git a/channeldb/kvdb/etcd/stm.go b/channeldb/kvdb/etcd/stm.go index c13e8f966..14bb9ca92 100644 --- a/channeldb/kvdb/etcd/stm.go +++ b/channeldb/kvdb/etcd/stm.go @@ -352,6 +352,15 @@ func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) { } } + if len(resp.Kvs) == 0 { + // Add assertion to the read set which will extend our commit + // constraint such that the commit will fail if the key is + // present in the database. + s.rset[key] = stmGet{ + rev: 0, + } + } + var result []KV // Fill the read set with key/values returned. @@ -395,12 +404,22 @@ func (s *stm) Get(key string) ([]byte, error) { // the prefetch set. if getValue, ok := s.prefetch[key]; ok { delete(s.prefetch, key) - s.rset[key] = getValue + + // Use the prefetched value only if it is for + // an existing key. + if getValue.rev != 0 { + s.rset[key] = getValue + } } // Return value if alread in read set. - if getVal, ok := s.rset[key]; ok { - return []byte(getVal.val), nil + if getValue, ok := s.rset[key]; ok { + // Return the value if the rset contains an existing key. + if getValue.rev != 0 { + return []byte(getValue.val), nil + } else { + return nil, nil + } } // Fetch and return value. From 09b8bee8658ab8565279bf87b3522dfc489916a9 Mon Sep 17 00:00:00 2001 From: "Johan T. Halseth" Date: Tue, 4 Aug 2020 14:43:08 +0200 Subject: [PATCH 020/218] mobile: remember walletunlocker.proto --- mobile/gen_bindings.sh | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/mobile/gen_bindings.sh b/mobile/gen_bindings.sh index 2e54beaac..c325ff3fd 100755 --- a/mobile/gen_bindings.sh +++ b/mobile/gen_bindings.sh @@ -39,14 +39,21 @@ listeners="lightning=lightningLis walletunlocker=walletUnlockerLis" # one proto file is being parsed, it should only be done once. mem_rpc=1 +PROTOS="rpc.proto walletunlocker.proto" + opts="package_name=$pkg,target_package=$target_pkg,listeners=$listeners,mem_rpc=$mem_rpc" -protoc -I/usr/local/include -I. \ - -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis \ - --plugin=protoc-gen-custom=$falafel\ - --custom_out=./build \ - --custom_opt="$opts" \ - --proto_path=../lnrpc \ - rpc.proto + +for file in $PROTOS; do + echo "Generating mobile protos from ${file}" + + protoc -I/usr/local/include -I. \ + -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis \ + --plugin=protoc-gen-custom=$falafel\ + --custom_out=./build \ + --custom_opt="$opts" \ + --proto_path=../lnrpc \ + "${file}" +done # If prefix=1 is specified, prefix the generated methods with subserver name. # This must be enabled to support subservers with name conflicts. From 675c1b95c91f3da51cb997948a16d42d4d8b3f2e Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 4 Aug 2020 15:33:21 -0700 Subject: [PATCH 021/218] lnd: don't set freelist value when creating channeldb This value actually isn't read anywhere, since it's no longer used. Instead, `cfg.Db.Bolt.NoSyncFreeList` is what's evaluated when we go to open the DB. --- lnd.go | 1 - 1 file changed, 1 deletion(-) diff --git a/lnd.go b/lnd.go index cccf30e61..ddce659af 100644 --- a/lnd.go +++ b/lnd.go @@ -269,7 +269,6 @@ func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error { chanDbBackend, channeldb.OptionSetRejectCacheSize(cfg.Caches.RejectCacheSize), channeldb.OptionSetChannelCacheSize(cfg.Caches.ChannelCacheSize), - channeldb.OptionSetSyncFreelist(cfg.SyncFreelist), channeldb.OptionDryRunMigration(cfg.DryRunMigration), ) switch { From e616903d4f62ba3ff026e09a82d85f47bfca991d Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 4 Aug 2020 15:34:29 -0700 Subject: [PATCH 022/218] config: unify old and new config values for sync-freelist In this commit, unify the old and new values for `sync-freelist`, and also ensure that we don't break behavior for any users that're using the _old_ value. To do this, we first rename what was `--db.bolt.no-sync-freelist`, to `--db.bolt.sync-freelist`. This gets rid of the negation on the config level, and lets us override that value if the user is specifying the legacy config option. In the future, we'll deprecate the old config option, in favor of the new DB scoped option. --- channeldb/kvdb/config.go | 2 +- config.go | 8 ++++++++ lncfg/db.go | 22 +++++++++++----------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/channeldb/kvdb/config.go b/channeldb/kvdb/config.go index a4ed68bab..179fde78a 100644 --- a/channeldb/kvdb/config.go +++ b/channeldb/kvdb/config.go @@ -12,7 +12,7 @@ const EtcdBackendName = "etcd" // BoltConfig holds bolt configuration. type BoltConfig struct { - NoFreeListSync bool `long:"nofreelistsync" description:"If true, prevents the database from syncing its freelist to disk"` + SyncFreelist bool `long:"nofreelistsync" description:"Whether the databases used within lnd should sync their freelist to disk. This is disabled by default resulting in improved memory performance during operation, but with an increase in startup time."` } // EtcdConfig holds etcd configuration. diff --git a/config.go b/config.go index 3a52ca964..7913f04ea 100644 --- a/config.go +++ b/config.go @@ -1097,6 +1097,14 @@ func ValidateConfig(cfg Config, usageMessage string) (*Config, error) { "minbackoff") } + // Newer versions of lnd added a new sub-config for bolt-specific + // parameters. However we want to also allow existing users to use the + // value on the top-level config. If the outer config value is set, + // then we'll use that directly. + if cfg.SyncFreelist { + cfg.DB.Bolt.SyncFreelist = cfg.SyncFreelist + } + // Validate the subconfigs for workers, caches, and the tower client. err = lncfg.Validate( cfg.Workers, diff --git a/lncfg/db.go b/lncfg/db.go index d63da8caf..6bc1e9a3b 100644 --- a/lncfg/db.go +++ b/lncfg/db.go @@ -9,8 +9,8 @@ import ( const ( dbName = "channel.db" - boltBackend = "bolt" - etcdBackend = "etcd" + BoltBackend = "bolt" + EtcdBackend = "etcd" ) // DB holds database configuration for LND. @@ -25,26 +25,24 @@ type DB struct { // NewDB creates and returns a new default DB config. func DefaultDB() *DB { return &DB{ - Backend: boltBackend, - Bolt: &kvdb.BoltConfig{ - NoFreeListSync: true, - }, + Backend: BoltBackend, + Bolt: &kvdb.BoltConfig{}, } } // Validate validates the DB config. func (db *DB) Validate() error { switch db.Backend { - case boltBackend: + case BoltBackend: - case etcdBackend: + case EtcdBackend: if db.Etcd.Host == "" { return fmt.Errorf("etcd host must be set") } default: return fmt.Errorf("unknown backend, must be either \"%v\" or \"%v\"", - boltBackend, etcdBackend) + BoltBackend, EtcdBackend) } return nil @@ -54,12 +52,14 @@ func (db *DB) Validate() error { func (db *DB) GetBackend(ctx context.Context, dbPath string, networkName string) (kvdb.Backend, error) { - if db.Backend == etcdBackend { + if db.Backend == EtcdBackend { // Prefix will separate key/values in the db. return kvdb.GetEtcdBackend(ctx, networkName, db.Etcd) } - return kvdb.GetBoltBackend(dbPath, dbName, db.Bolt.NoFreeListSync) + // The implementation by walletdb accepts "noFreelistSync" as the + // second parameter, so we negate here. + return kvdb.GetBoltBackend(dbPath, dbName, !db.Bolt.SyncFreelist) } // Compile-time constraint to ensure Workers implements the Validator interface. From 19f68d2538dc59f502524295cbfeb0ec607da2c0 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 4 Aug 2020 16:17:03 -0700 Subject: [PATCH 023/218] lnd: log bbolt freelist sync config value on start up --- lnd.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lnd.go b/lnd.go index ddce659af..7cfae4c3d 100644 --- a/lnd.go +++ b/lnd.go @@ -255,6 +255,11 @@ func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error { ctx, cancel := context.WithCancel(ctx) defer cancel() + if cfg.DB.Backend == lncfg.BoltBackend { + ltndLog.Infof("Opening bbolt database, sync_freelist=%v", + cfg.DB.Bolt.SyncFreelist) + } + chanDbBackend, err := cfg.DB.GetBackend(ctx, cfg.localDatabaseDir(), cfg.networkName(), ) From 2f1f8561aebcf14575c2e55321d7978829805540 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Fri, 31 Jul 2020 10:32:29 +0200 Subject: [PATCH 024/218] mod: update to latest btcd version --- go.mod | 2 +- go.sum | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index cf266e53d..59e21c2e0 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ require ( github.com/NebulousLabs/fastrand v0.0.0-20181203155948-6fb6489aac4e // indirect github.com/NebulousLabs/go-upnp v0.0.0-20180202185039-29b680b06c82 github.com/Yawning/aez v0.0.0-20180114000226-4dad034d9db2 - github.com/btcsuite/btcd v0.20.1-beta.0.20200515232429-9f0179fd2c46 + github.com/btcsuite/btcd v0.20.1-beta.0.20200730232343-1db1b6f8217f github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f github.com/btcsuite/btcutil v1.0.2 github.com/btcsuite/btcutil/psbt v1.0.2 diff --git a/go.sum b/go.sum index f75d6f0da..46a292615 100644 --- a/go.sum +++ b/go.sum @@ -27,8 +27,8 @@ github.com/btcsuite/btcd v0.0.0-20190824003749-130ea5bddde3/go.mod h1:3J08xEfcug github.com/btcsuite/btcd v0.20.1-beta h1:Ik4hyJqN8Jfyv3S4AGBOmyouMsYE3EdYODkMbQjwPGw= github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= github.com/btcsuite/btcd v0.20.1-beta.0.20200513120220-b470eee47728/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= -github.com/btcsuite/btcd v0.20.1-beta.0.20200515232429-9f0179fd2c46 h1:QyTpiR5nQe94vza2qkvf7Ns8XX2Rjh/vdIhO3RzGj4o= -github.com/btcsuite/btcd v0.20.1-beta.0.20200515232429-9f0179fd2c46/go.mod h1:Yktc19YNjh/Iz2//CX0vfRTS4IJKM/RKO5YZ9Fn+Pgo= +github.com/btcsuite/btcd v0.20.1-beta.0.20200730232343-1db1b6f8217f h1:m/GhMTvDQLbID616c4TYdHyt0MZ9lH5B/nf9Lu3okCY= +github.com/btcsuite/btcd v0.20.1-beta.0.20200730232343-1db1b6f8217f/go.mod h1:ZSWyehm27aAuS9bvkATT+Xte3hjHZ+MRgMY/8NJ7K94= github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f h1:bAs4lUbRJpnnkd9VhRV3jjAVU7DJVjMaK+IsvSeZvFo= github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d h1:yJzD/yFppdVCf6ApMkVy8cUxV0XrxdP9rVf6D87/Mng= @@ -86,6 +86,8 @@ github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/lru v1.0.0 h1:Kbsb1SFDsIlaupWPwsPp+dkxiBY1frcS07PCPgotKz8= +github.com/decred/dcrd/lru v1.0.0/go.mod h1:mxKOwFd7lFjN2GZYsiz/ecgqR6kkYAl+0pz0tEMk218= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= From 6115a7b12bdbc8adf0b762e41f344ccb9b06da65 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Fri, 31 Jul 2020 10:32:30 +0200 Subject: [PATCH 025/218] make+itest: make itest Windows compatible --- Makefile | 7 +++++++ lntest/btcd.go | 1 + lntest/itest/lnd_test.go | 32 ++++++++++++++++++++++++++------ 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index e474fbe5a..6d729bf07 100644 --- a/Makefile +++ b/Makefile @@ -132,6 +132,11 @@ build-itest: $(GOBUILD) -tags="$(ITEST_TAGS)" -o lnd-itest $(ITEST_LDFLAGS) $(PKG)/cmd/lnd $(GOBUILD) -tags="$(ITEST_TAGS)" -o lncli-itest $(ITEST_LDFLAGS) $(PKG)/cmd/lncli +build-itest-windows: + @$(call print, "Building itest lnd and lncli.") + $(GOBUILD) -tags="$(ITEST_TAGS)" -o lnd-itest.exe $(ITEST_LDFLAGS) $(PKG)/cmd/lnd + $(GOBUILD) -tags="$(ITEST_TAGS)" -o lncli-itest.exe $(ITEST_LDFLAGS) $(PKG)/cmd/lncli + install: @$(call print, "Installing lnd and lncli.") $(GOINSTALL) -tags="${tags}" $(LDFLAGS) $(PKG)/cmd/lnd @@ -158,6 +163,8 @@ itest-only: itest: btcd build-itest itest-only +itest-windows: btcd build-itest-windows itest-only + unit: btcd @$(call print, "Running unit tests.") $(UNIT) diff --git a/lntest/btcd.go b/lntest/btcd.go index 3c50e551e..2322d8eed 100644 --- a/lntest/btcd.go +++ b/lntest/btcd.go @@ -81,6 +81,7 @@ func NewBackend(miner string, netParams *chaincfg.Params) ( "--debuglevel=debug", "--logdir=" + logDir, "--connect=" + miner, + "--nowinservice", } chainBackend, err := rpctest.New(netParams, nil, args) if err != nil { diff --git a/lntest/itest/lnd_test.go b/lntest/itest/lnd_test.go index ee51f47b5..9c2bdd235 100644 --- a/lntest/itest/lnd_test.go +++ b/lntest/itest/lnd_test.go @@ -15,6 +15,7 @@ import ( "os" "path/filepath" "reflect" + "runtime" "strings" "sync" "sync/atomic" @@ -2462,9 +2463,14 @@ func testOpenChannelAfterReorg(net *lntest.NetworkHarness, t *harnessTest) { ) // Set up a new miner that we can use to cause a reorg. - args := []string{"--rejectnonstd", "--txindex"} - tempMiner, err := rpctest.New(harnessNetParams, - &rpcclient.NotificationHandlers{}, args) + args := []string{ + "--rejectnonstd", + "--txindex", + "--nowinservice", + } + tempMiner, err := rpctest.New( + harnessNetParams, &rpcclient.NotificationHandlers{}, args, + ) if err != nil { t.Fatalf("unable to create mining node: %v", err) } @@ -15284,6 +15290,7 @@ func TestLightningNetworkDaemon(t *testing.T) { "--debuglevel=debug", "--logdir=" + minerLogDir, "--trickleinterval=100ms", + "--nowinservice", } handlers := &rpcclient.NotificationHandlers{ OnTxAccepted: func(hash *chainhash.Hash, amt btcutil.Amount) { @@ -15329,11 +15336,24 @@ func TestLightningNetworkDaemon(t *testing.T) { ht.Fatalf("unable to request transaction notifications: %v", err) } + binary := itestLndBinary + if runtime.GOOS == "windows" { + // Windows (even in a bash like environment like git bash as on + // Travis) doesn't seem to like relative paths to exe files... + currentDir, err := os.Getwd() + if err != nil { + ht.Fatalf("unable to get working directory: %v", err) + } + targetPath := filepath.Join(currentDir, "../../lnd-itest.exe") + binary, err = filepath.Abs(targetPath) + if err != nil { + ht.Fatalf("unable to get absolute path: %v", err) + } + } + // Now we can set up our test harness (LND instance), with the chain // backend we just created. - lndHarness, err = lntest.NewNetworkHarness( - miner, chainBackend, itestLndBinary, - ) + lndHarness, err = lntest.NewNetworkHarness(miner, chainBackend, binary) if err != nil { ht.Fatalf("unable to create lightning network harness: %v", err) } From a6a7aca8af10745c3bc975bdf85c8906179c6c1b Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Fri, 31 Jul 2020 10:32:31 +0200 Subject: [PATCH 026/218] travis: add itest on Windows This commit adds an integration test that runs on a Windows virtual machine on Travis. The tests run inside of a "Git Bash" environment which supports the same command line syntax as a proper bash but doesn't have all the tooling installed. Some tools also behave differently on Windows. Therefore we also have to simplify the command to upload the logs to termbin and remove the upload to file.io on Windows because both the find and tar command don't work as expected. --- .travis.yml | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 5d9e7064a..d5d9204a8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -68,7 +68,32 @@ jobs: - GOARM=7 GOARCH=arm GOOS=linux CGO_ENABLED=0 make itest-only arch: arm64 + - name: Btcd Integration Windows + script: + - make itest-windows + os: windows + before_install: + - choco upgrade --no-progress -y make netcat curl findutils + - export MAKE=mingw32-make + after_script: + - |- + case $TRAVIS_OS_NAME in + windows) + echo "Uploading to termbin.com..." + for f in ./lntest/itest/*.log; do cat $f | nc termbin.com 9999 | xargs -r0 printf "$f"' uploaded to %s'; done + ;; + esac + after_script: - - LOG_FILES=./lntest/itest/*.log - - echo "Uploading to termbin.com..." && find $LOG_FILES | xargs -I{} sh -c "cat {} | nc termbin.com 9999 | xargs -r0 printf '{} uploaded to %s'" - - echo "Uploading to file.io..." && tar -zcvO $LOG_FILES | curl -s -F 'file=@-;filename=logs.tar.gz' https://file.io | xargs -r0 printf 'logs.tar.gz uploaded to %s\n' + - |- + case $TRAVIS_OS_NAME in + windows) + # Needs other commands, see after_script of the Windows build + ;; + + *) + LOG_FILES=./lntest/itest/*.log + echo "Uploading to termbin.com..." && find $LOG_FILES | xargs -I{} sh -c "cat {} | nc termbin.com 9999 | xargs -r0 printf '{} uploaded to %s'" + echo "Uploading to file.io..." && tar -zcvO $LOG_FILES | curl -s -F 'file=@-;filename=logs.tar.gz' https://file.io | xargs -r0 printf 'logs.tar.gz uploaded to %s\n' + ;; + esac From b21b2ebd6f2e81ca990fb7fe7062c4ce30bd8cf6 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Fri, 31 Jul 2020 10:32:33 +0200 Subject: [PATCH 027/218] lntest: improve fee calculation in multi-hop test --- lntest/itest/lnd_multi-hop-payments.go | 32 +++++++++++++++----------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/lntest/itest/lnd_multi-hop-payments.go b/lntest/itest/lnd_multi-hop-payments.go index a7dfee5a5..cba968834 100644 --- a/lntest/itest/lnd_multi-hop-payments.go +++ b/lntest/itest/lnd_multi-hop-payments.go @@ -171,13 +171,18 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) { // channel edges to relatively large non default values. This makes it // possible to pick up more subtle fee calculation errors. maxHtlc := uint64(calculateMaxHtlc(chanAmt)) + const aliceBaseFeeSat = 1 + const aliceFeeRatePPM = 100000 updateChannelPolicy( - t, net.Alice, chanPointAlice, 1000, 100000, - lnd.DefaultBitcoinTimeLockDelta, maxHtlc, carol, + t, net.Alice, chanPointAlice, aliceBaseFeeSat*1000, + aliceFeeRatePPM, lnd.DefaultBitcoinTimeLockDelta, maxHtlc, + carol, ) + const daveBaseFeeSat = 5 + const daveFeeRatePPM = 150000 updateChannelPolicy( - t, dave, chanPointDave, 5000, 150000, + t, dave, chanPointDave, daveBaseFeeSat*1000, daveFeeRatePPM, lnd.DefaultBitcoinTimeLockDelta, maxHtlc, carol, ) @@ -224,11 +229,6 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) { t.Fatalf("unable to send payments: %v", err) } - // When asserting the amount of satoshis moved, we'll factor in the - // default base fee, as we didn't modify the fee structure when - // creating the seed nodes in the network. - const baseFee = 1 - // At this point all the channels within our proto network should be // shifted by 5k satoshis in the direction of Bob, the sink within the // payment flow generated above. The order of asserts corresponds to @@ -237,7 +237,7 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) { // Alice, David, Carol. // The final node bob expects to get paid five times 1000 sat. - expectedAmountPaidAtoB := int64(5 * 1000) + expectedAmountPaidAtoB := int64(numPayments * paymentAmt) assertAmountPaid(t, "Alice(local) => Bob(remote)", net.Bob, aliceFundPoint, int64(0), expectedAmountPaidAtoB) @@ -246,7 +246,9 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) { // To forward a payment of 1000 sat, Alice is charging a fee of // 1 sat + 10% = 101 sat. - const expectedFeeAlice = 5 * 101 + const aliceFeePerPayment = aliceBaseFeeSat + + (paymentAmt * aliceFeeRatePPM / 1_000_000) + const expectedFeeAlice = numPayments * aliceFeePerPayment // Dave needs to pay what Alice pays plus Alice's fee. expectedAmountPaidDtoA := expectedAmountPaidAtoB + expectedFeeAlice @@ -258,7 +260,10 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) { // To forward a payment of 1101 sat, Dave is charging a fee of // 5 sat + 15% = 170.15 sat. This is rounded down in rpcserver to 170. - const expectedFeeDave = 5 * 170 + const davePaymentAmt = paymentAmt + aliceFeePerPayment + const daveFeePerPayment = daveBaseFeeSat + + (davePaymentAmt * daveFeeRatePPM / 1_000_000) + const expectedFeeDave = numPayments * daveFeePerPayment // Carol needs to pay what Dave pays plus Dave's fee. expectedAmountPaidCtoD := expectedAmountPaidDtoA + expectedFeeDave @@ -303,9 +308,10 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) { if err != nil { t.Fatalf("unable to query for fee report: %v", err) } - if len(fwdingHistory.ForwardingEvents) != 5 { + if len(fwdingHistory.ForwardingEvents) != numPayments { t.Fatalf("wrong number of forwarding event: expected %v, "+ - "got %v", 5, len(fwdingHistory.ForwardingEvents)) + "got %v", numPayments, + len(fwdingHistory.ForwardingEvents)) } expectedForwardingFee := uint64(expectedFeeDave / numPayments) for _, event := range fwdingHistory.ForwardingEvents { From 97c73706b55fb5520d2f3001dfa13dc1da986972 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Fri, 31 Jul 2020 10:32:34 +0200 Subject: [PATCH 028/218] channeldb: fix for Windows clock resolution We use the event timestamp of a forwarding event as its primary storage key. On systems with a bad clock resolution this can lead to collisions of the events if some of the timestamps are identical. We fix this problem by shifting the timestamps on the nanosecond level until only unique values remain. --- channeldb/forwarding_log.go | 105 ++++++++++++++++++++----- channeldb/forwarding_log_test.go | 128 +++++++++++++++++++++++++++++-- 2 files changed, 208 insertions(+), 25 deletions(-) diff --git a/channeldb/forwarding_log.go b/channeldb/forwarding_log.go index a52848dd4..d1216dc46 100644 --- a/channeldb/forwarding_log.go +++ b/channeldb/forwarding_log.go @@ -6,6 +6,7 @@ import ( "sort" "time" + "github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lnwire" ) @@ -104,10 +105,9 @@ func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error { func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error { // Before we create the database transaction, we'll ensure that the set // of forwarding events are properly sorted according to their - // timestamp. - sort.Slice(events, func(i, j int) bool { - return events[i].Timestamp.Before(events[j].Timestamp) - }) + // timestamp and that no duplicate timestamps exist to avoid collisions + // in the key we are going to store the events under. + makeUniqueTimestamps(events) var timestamp [8]byte @@ -124,22 +124,7 @@ func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error { // With the bucket obtained, we can now begin to write out the // series of events. for _, event := range events { - var eventBytes [forwardingEventSize]byte - eventBuf := bytes.NewBuffer(eventBytes[0:0:forwardingEventSize]) - - // First, we'll serialize this timestamp into our - // timestamp buffer. - byteOrder.PutUint64( - timestamp[:], uint64(event.Timestamp.UnixNano()), - ) - - // With the key encoded, we'll then encode the event - // into our buffer, then write it out to disk. - err := encodeForwardingEvent(eventBuf, &event) - if err != nil { - return err - } - err = logBucket.Put(timestamp[:], eventBuf.Bytes()) + err := storeEvent(logBucket, event, timestamp[:]) if err != nil { return err } @@ -149,6 +134,55 @@ func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error { }) } +// storeEvent tries to store a forwarding event into the given bucket by trying +// to avoid collisions. If a key for the event timestamp already exists in the +// database, the timestamp is incremented in nanosecond intervals until a "free" +// slot is found. +func storeEvent(bucket walletdb.ReadWriteBucket, event ForwardingEvent, + timestampScratchSpace []byte) error { + + // First, we'll serialize this timestamp into our + // timestamp buffer. + byteOrder.PutUint64( + timestampScratchSpace, uint64(event.Timestamp.UnixNano()), + ) + + // Next we'll loop until we find a "free" slot in the bucket to store + // the event under. This should almost never happen unless we're running + // on a system that has a very bad system clock that doesn't properly + // resolve to nanosecond scale. We try up to 100 times (which would come + // to a maximum shift of 0.1 microsecond which is acceptable for most + // use cases). If we don't find a free slot, we just give up and let + // the collision happen. Something must be wrong with the data in that + // case, even on a very fast machine forwarding payments _will_ take a + // few microseconds at least so we should find a nanosecond slot + // somewhere. + const maxTries = 100 + tries := 0 + for tries < maxTries { + val := bucket.Get(timestampScratchSpace) + if val == nil { + break + } + + // Collision, try the next nanosecond timestamp. + nextNano := event.Timestamp.UnixNano() + 1 + event.Timestamp = time.Unix(0, nextNano) + byteOrder.PutUint64(timestampScratchSpace, uint64(nextNano)) + tries++ + } + + // With the key encoded, we'll then encode the event + // into our buffer, then write it out to disk. + var eventBytes [forwardingEventSize]byte + eventBuf := bytes.NewBuffer(eventBytes[0:0:forwardingEventSize]) + err := encodeForwardingEvent(eventBuf, &event) + if err != nil { + return err + } + return bucket.Put(timestampScratchSpace, eventBuf.Bytes()) +} + // ForwardingEventQuery represents a query to the forwarding log payment // circuit time series database. The query allows a caller to retrieve all // records for a particular time slice, offset in that time slice, limiting the @@ -272,3 +306,34 @@ func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, e return resp, nil } + +// makeUniqueTimestamps takes a slice of forwarding events, sorts it by the +// event timestamps and then makes sure there are no duplicates in the +// timestamps. If duplicates are found, some of the timestamps are increased on +// the nanosecond scale until only unique values remain. This is a fix to +// address the problem that in some environments (looking at you, Windows) the +// system clock has such a bad resolution that two serial invocations of +// time.Now() might return the same timestamp, even if some time has elapsed +// between the calls. +func makeUniqueTimestamps(events []ForwardingEvent) { + sort.Slice(events, func(i, j int) bool { + return events[i].Timestamp.Before(events[j].Timestamp) + }) + + // Now that we know the events are sorted by timestamp, we can go + // through the list and fix all duplicates until only unique values + // remain. + for outer := 0; outer < len(events)-1; outer++ { + current := events[outer].Timestamp.UnixNano() + next := events[outer+1].Timestamp.UnixNano() + + // We initially sorted the slice. So if the current is now + // greater or equal to the next one, it's either because it's a + // duplicate or because we increased the current in the last + // iteration. + if current >= next { + next = current + 1 + events[outer+1].Timestamp = time.Unix(0, next) + } + } +} diff --git a/channeldb/forwarding_log_test.go b/channeldb/forwarding_log_test.go index 07dfc902c..cd21f12e2 100644 --- a/channeldb/forwarding_log_test.go +++ b/channeldb/forwarding_log_test.go @@ -4,11 +4,11 @@ import ( "math/rand" "reflect" "testing" + "time" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lnwire" - - "time" + "github.com/stretchr/testify/assert" ) // TestForwardingLogBasicStorageAndQuery tests that we're able to store and @@ -20,10 +20,11 @@ func TestForwardingLogBasicStorageAndQuery(t *testing.T) { // forwarding event log that we'll be using for the duration of the // test. db, cleanUp, err := MakeTestDB() - defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) } + defer cleanUp() + log := ForwardingLog{ db: db, } @@ -92,10 +93,11 @@ func TestForwardingLogQueryOptions(t *testing.T) { // forwarding event log that we'll be using for the duration of the // test. db, cleanUp, err := MakeTestDB() - defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) } + defer cleanUp() + log := ForwardingLog{ db: db, } @@ -197,10 +199,11 @@ func TestForwardingLogQueryLimit(t *testing.T) { // forwarding event log that we'll be using for the duration of the // test. db, cleanUp, err := MakeTestDB() - defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) } + defer cleanUp() + log := ForwardingLog{ db: db, } @@ -263,3 +266,118 @@ func TestForwardingLogQueryLimit(t *testing.T) { timeSlice.LastIndexOffset) } } + +// TestForwardingLogMakeUniqueTimestamps makes sure the function that creates +// unique timestamps does it job correctly. +func TestForwardingLogMakeUniqueTimestamps(t *testing.T) { + t.Parallel() + + // Create a list of events where some of the timestamps collide. We + // expect no existing timestamp to be overwritten, instead the "gaps" + // between them should be filled. + inputSlice := []ForwardingEvent{ + {Timestamp: time.Unix(0, 1001)}, + {Timestamp: time.Unix(0, 2001)}, + {Timestamp: time.Unix(0, 1001)}, + {Timestamp: time.Unix(0, 1002)}, + {Timestamp: time.Unix(0, 1004)}, + {Timestamp: time.Unix(0, 1004)}, + {Timestamp: time.Unix(0, 1007)}, + {Timestamp: time.Unix(0, 1001)}, + } + expectedSlice := []ForwardingEvent{ + {Timestamp: time.Unix(0, 1001)}, + {Timestamp: time.Unix(0, 1002)}, + {Timestamp: time.Unix(0, 1003)}, + {Timestamp: time.Unix(0, 1004)}, + {Timestamp: time.Unix(0, 1005)}, + {Timestamp: time.Unix(0, 1006)}, + {Timestamp: time.Unix(0, 1007)}, + {Timestamp: time.Unix(0, 2001)}, + } + + makeUniqueTimestamps(inputSlice) + + for idx, in := range inputSlice { + expect := expectedSlice[idx] + assert.Equal( + t, expect.Timestamp.UnixNano(), in.Timestamp.UnixNano(), + ) + } +} + +// TestForwardingLogStoreEvent makes sure forwarding events are stored without +// colliding on duplicate timestamps. +func TestForwardingLogStoreEvent(t *testing.T) { + t.Parallel() + + // First, we'll set up a test database, and use that to instantiate the + // forwarding event log that we'll be using for the duration of the + // test. + db, cleanUp, err := MakeTestDB() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + defer cleanUp() + + log := ForwardingLog{ + db: db, + } + + // We'll create 20 random events, with each event having a timestamp + // with just one nanosecond apart. + numEvents := 20 + events := make([]ForwardingEvent, numEvents) + ts := time.Now().UnixNano() + for i := 0; i < numEvents; i++ { + events[i] = ForwardingEvent{ + Timestamp: time.Unix(0, ts+int64(i)), + IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + AmtIn: lnwire.MilliSatoshi(rand.Int63()), + AmtOut: lnwire.MilliSatoshi(rand.Int63()), + } + } + + // Now that all of our events are constructed, we'll add them to the + // database in a batched manner. + if err := log.AddForwardingEvents(events); err != nil { + t.Fatalf("unable to add events: %v", err) + } + + // Because timestamps are de-duplicated when adding them in a single + // batch before they even hit the DB, we add the same events again but + // in a new batch. They now have to be de-duplicated on the DB level. + if err := log.AddForwardingEvents(events); err != nil { + t.Fatalf("unable to add second batch of events: %v", err) + } + + // With all of our events added, we should be able to query for all + // events with a range of just 40 nanoseconds (2 times 20 events, all + // spaced one nanosecond apart). + eventQuery := ForwardingEventQuery{ + StartTime: time.Unix(0, ts), + EndTime: time.Unix(0, ts+int64(numEvents*2)), + IndexOffset: 0, + NumMaxEvents: uint32(numEvents * 3), + } + timeSlice, err := log.Query(eventQuery) + if err != nil { + t.Fatalf("unable to query for events: %v", err) + } + + // We should get exactly 40 events back. + if len(timeSlice.ForwardingEvents) != numEvents*2 { + t.Fatalf("wrong number of events: expected %v, got %v", + numEvents*2, len(timeSlice.ForwardingEvents)) + } + + // The timestamps should be spaced out evenly and in order. + for i := 0; i < numEvents*2; i++ { + eventTs := timeSlice.ForwardingEvents[i].Timestamp.UnixNano() + if eventTs != ts+int64(i) { + t.Fatalf("unexpected timestamp of event %d: expected "+ + "%d, got %d", i, ts+int64(i), eventTs) + } + } +} From adfd0dc01588da0c74b6792c8c889db5feb0a9d8 Mon Sep 17 00:00:00 2001 From: Candle <50766841+CandleHater@users.noreply.github.com> Date: Wed, 5 Aug 2020 09:58:05 +0000 Subject: [PATCH 029/218] rpc: add missing space to error message This corrects the output of the chain notifier RPC error. It has been displayed as: "chain notifier RPC *isstill* in the process of starting" --- lnrpc/chainrpc/chainnotifier_server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lnrpc/chainrpc/chainnotifier_server.go b/lnrpc/chainrpc/chainnotifier_server.go index afa53c013..08c5330a2 100644 --- a/lnrpc/chainrpc/chainnotifier_server.go +++ b/lnrpc/chainrpc/chainnotifier_server.go @@ -67,7 +67,7 @@ var ( // ErrChainNotifierServerNotActive indicates that the chain notifier hasn't // finished the startup process. - ErrChainNotifierServerNotActive = errors.New("chain notifier RPC is" + + ErrChainNotifierServerNotActive = errors.New("chain notifier RPC is " + "still in the process of starting") ) From ba3c65bfd6c05cf0733aae7f96b75dc93bcdb2fe Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Fri, 17 Jul 2020 14:24:54 +0200 Subject: [PATCH 030/218] invoices: re-format overreaching code lines --- invoices/invoice_expiry_watcher.go | 25 +++++++++++++++---------- invoices/invoice_expiry_watcher_test.go | 7 +++++-- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index f0db08d11..9df6ca745 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -48,8 +48,8 @@ type InvoiceExpiryWatcher struct { // invoice to expire. expiryQueue queue.PriorityQueue - // newInvoices channel is used to wake up the main loop when a new invoices - // is added. + // newInvoices channel is used to wake up the main loop when a new + // invoices is added. newInvoices chan []*invoiceExpiry wg sync.WaitGroup @@ -109,7 +109,8 @@ func (ew *InvoiceExpiryWatcher) prepareInvoice( paymentHash lntypes.Hash, invoice *channeldb.Invoice) *invoiceExpiry { if invoice.State != channeldb.ContractOpen { - log.Debugf("Invoice not added to expiry watcher: %v", paymentHash) + log.Debugf("Invoice not added to expiry watcher: %v", + paymentHash) return nil } @@ -133,10 +134,13 @@ func (ew *InvoiceExpiryWatcher) AddInvoices( invoicesWithExpiry := make([]*invoiceExpiry, 0, len(invoices)) for _, invoiceWithPaymentHash := range invoices { newInvoiceExpiry := ew.prepareInvoice( - invoiceWithPaymentHash.PaymentHash, &invoiceWithPaymentHash.Invoice, + invoiceWithPaymentHash.PaymentHash, + &invoiceWithPaymentHash.Invoice, ) if newInvoiceExpiry != nil { - invoicesWithExpiry = append(invoicesWithExpiry, newInvoiceExpiry) + invoicesWithExpiry = append( + invoicesWithExpiry, newInvoiceExpiry, + ) } } @@ -160,8 +164,8 @@ func (ew *InvoiceExpiryWatcher) AddInvoice( newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice) if newInvoiceExpiry != nil { - log.Debugf("Adding invoice '%v' to expiry watcher, expiration: %v", - paymentHash, newInvoiceExpiry.Expiry) + log.Debugf("Adding invoice '%v' to expiry watcher,"+ + "expiration: %v", paymentHash, newInvoiceExpiry.Expiry) select { case ew.newInvoices <- []*invoiceExpiry{newInvoiceExpiry}: @@ -202,7 +206,8 @@ func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() { if err != nil && err != channeldb.ErrInvoiceAlreadySettled && err != channeldb.ErrInvoiceAlreadyCanceled { - log.Errorf("Unable to cancel invoice: %v", top.PaymentHash) + log.Errorf("Unable to cancel invoice: %v", + top.PaymentHash) } ew.expiryQueue.Pop() @@ -236,8 +241,8 @@ func (ew *InvoiceExpiryWatcher) mainLoop() { continue case invoicesWithExpiry := <-ew.newInvoices: - for _, invoiceWithExpiry := range invoicesWithExpiry { - ew.expiryQueue.Push(invoiceWithExpiry) + for _, invoice := range invoicesWithExpiry { + ew.expiryQueue.Push(invoice) } case <-ew.quit: diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index 2aa0f87ba..58d6e2d8d 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -37,7 +37,9 @@ func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time, err := test.watcher.Start(func(paymentHash lntypes.Hash, force bool) error { - test.canceledInvoices = append(test.canceledInvoices, paymentHash) + test.canceledInvoices = append( + test.canceledInvoices, paymentHash, + ) test.wg.Done() return nil }) @@ -70,7 +72,8 @@ func (t *invoiceExpiryWatcherTest) checkExpectations() { // that expired. if len(t.canceledInvoices) != len(t.testData.expiredInvoices) { t.t.Fatalf("expected %v cancellations, got %v", - len(t.testData.expiredInvoices), len(t.canceledInvoices)) + len(t.testData.expiredInvoices), + len(t.canceledInvoices)) } for i := range t.canceledInvoices { From 92f3b0a30c5fba4e9a4e5ee32131f95e20d0684a Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 28 Jul 2020 21:22:23 +0200 Subject: [PATCH 031/218] channeldb+invoices: add ScanInvoices and integrate with InvoiceRegistry This commit adds channeldb.ScanInvoices to scan through all invoices in the database. The new call will also replace the already existing channeldb.FetchAllInvoicesWithPaymentHash call in preparation to collect invoices we'd like to delete and watch for expiry in one scan in later commits. --- channeldb/invoice_test.go | 115 ++++++++---------------- channeldb/invoices.go | 53 ++++------- invoices/invoice_expiry_watcher.go | 9 +- invoices/invoice_expiry_watcher_test.go | 16 +--- invoices/invoiceregistry.go | 45 +++++++--- 5 files changed, 90 insertions(+), 148 deletions(-) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 64e2dbe62..9d5aba364 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -622,9 +622,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) { } } -// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their -// corresponding payment hashes. -func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { +// TestScanInvoices tests that ScanInvoices scans trough all stored invoices +// correctly. +func TestScanInvoices(t *testing.T) { t.Parallel() db, cleanup, err := MakeTestDB() @@ -633,97 +633,54 @@ func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { t.Fatalf("unable to make test db: %v", err) } - // With an empty DB we expect to return no error and an empty list. - empty, err := db.FetchAllInvoicesWithPaymentHash(false) - if err != nil { - t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v", - err) + var invoices map[lntypes.Hash]*Invoice + callCount := 0 + resetCount := 0 + + // reset is used to reset/initialize results and is called once + // upon calling ScanInvoices and when the underlying transaction is + // retried. + reset := func() { + invoices = make(map[lntypes.Hash]*Invoice) + callCount = 0 + resetCount++ + } - if len(empty) != 0 { - t.Fatalf("expected empty list as a result, got: %v", empty) + scanFunc := func(paymentHash lntypes.Hash, invoice *Invoice) error { + invoices[paymentHash] = invoice + callCount++ + + return nil } - states := []ContractState{ - ContractOpen, ContractSettled, ContractCanceled, ContractAccepted, - } + // With an empty DB we expect to not scan any invoices. + require.NoError(t, db.ScanInvoices(scanFunc, reset)) + require.Equal(t, 0, len(invoices)) + require.Equal(t, 0, callCount) + require.Equal(t, 1, resetCount) - numInvoices := len(states) * 2 - testPendingInvoices := make(map[lntypes.Hash]*Invoice) - testAllInvoices := make(map[lntypes.Hash]*Invoice) + numInvoices := 5 + testInvoices := make(map[lntypes.Hash]*Invoice) // Now populate the DB and check if we can get all invoices with their // payment hashes as expected. for i := 1; i <= numInvoices; i++ { invoice, err := randInvoice(lnwire.MilliSatoshi(i)) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } + require.NoError(t, err) - // Set the contract state of the next invoice such that there's an equal - // number for all possbile states. - invoice.State = states[i%len(states)] paymentHash := invoice.Terms.PaymentPreimage.Hash() + testInvoices[paymentHash] = invoice - if invoice.IsPending() { - testPendingInvoices[paymentHash] = invoice - } - - testAllInvoices[paymentHash] = invoice - - if _, err := db.AddInvoice(invoice, paymentHash); err != nil { - t.Fatalf("unable to add invoice: %v", err) - } - } - - pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true) - if err != nil { - t.Fatalf("can't fetch invoices with payment hash: %v", err) - } - - if len(testPendingInvoices) != len(pendingInvoices) { - t.Fatalf("expected %v pending invoices, got: %v", - len(testPendingInvoices), len(pendingInvoices)) - } - - allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false) - if err != nil { - t.Fatalf("can't fetch invoices with payment hash: %v", err) - } - - if len(testAllInvoices) != len(allInvoices) { - t.Fatalf("expected %v invoices, got: %v", - len(testAllInvoices), len(allInvoices)) - } - - for i := range pendingInvoices { - expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash] - if !ok { - t.Fatalf("coulnd't find invoice with hash: %v", - pendingInvoices[i].PaymentHash) - } - - // Zero out add index to not confuse require.Equal. - pendingInvoices[i].Invoice.AddIndex = 0 - expected.AddIndex = 0 - - require.Equal(t, *expected, pendingInvoices[i].Invoice) - } - - for i := range allInvoices { - expected, ok := testAllInvoices[allInvoices[i].PaymentHash] - if !ok { - t.Fatalf("coulnd't find invoice with hash: %v", - allInvoices[i].PaymentHash) - } - - // Zero out add index to not confuse require.Equal. - allInvoices[i].Invoice.AddIndex = 0 - expected.AddIndex = 0 - - require.Equal(t, *expected, allInvoices[i].Invoice) + _, err = db.AddInvoice(invoice, paymentHash) + require.NoError(t, err) } + resetCount = 0 + require.NoError(t, db.ScanInvoices(scanFunc, reset)) + require.Equal(t, numInvoices, callCount) + require.Equal(t, testInvoices, invoices) + require.Equal(t, 1, resetCount) } // TestDuplicateSettleInvoice tests that if we add a new invoice and settle it diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 436f194e1..a7ece3c30 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -723,28 +723,21 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.RBucket, } } -// InvoiceWithPaymentHash is used to store an invoice and its corresponding -// payment hash. This struct is only used to store results of -// ChannelDB.FetchAllInvoicesWithPaymentHash() call. -type InvoiceWithPaymentHash struct { - // Invoice holds the invoice as selected from the invoices bucket. - Invoice Invoice +// ScanInvoices scans trough all invoices and calls the passed scanFunc for +// for each invoice with its respective payment hash. Additionally a reset() +// closure is passed which is used to reset/initialize partial results and also +// to signal if the kvdb.View transaction has been retried. +func (d *DB) ScanInvoices( + scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error { - // PaymentHash is the payment hash for the Invoice. - PaymentHash lntypes.Hash -} + return kvdb.View(d, func(tx kvdb.RTx) error { + // Reset partial results. As transaction commit success is not + // guaranteed when using etcd, we need to be prepared to redo + // the whole view transaction. In order to be able to do that + // we need a way to reset existing results. This is also done + // upon first run for initialization. + reset() -// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes -// currently stored within the database. If the pendingOnly param is true, then -// only open or accepted invoices and their payment hashes will be returned, -// skipping all invoices that are fully settled or canceled. Note that the -// returned array is not ordered by add index. -func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( - []InvoiceWithPaymentHash, error) { - - var result []InvoiceWithPaymentHash - - err := kvdb.View(d, func(tx kvdb.RTx) error { invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { return ErrNoInvoicesCreated @@ -775,26 +768,12 @@ func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( return err } - if pendingOnly && !invoice.IsPending() { - return nil - } + var paymentHash lntypes.Hash + copy(paymentHash[:], k) - invoiceWithPaymentHash := InvoiceWithPaymentHash{ - Invoice: invoice, - } - - copy(invoiceWithPaymentHash.PaymentHash[:], k) - result = append(result, invoiceWithPaymentHash) - - return nil + return scanFunc(paymentHash, &invoice) }) }) - - if err != nil { - return nil, err - } - - return result, nil } // InvoiceQuery represents a query to the invoice database. The query allows a diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index 9df6ca745..a46f27f5a 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -129,14 +129,11 @@ func (ew *InvoiceExpiryWatcher) prepareInvoice( // AddInvoices adds multiple invoices to the InvoiceExpiryWatcher. func (ew *InvoiceExpiryWatcher) AddInvoices( - invoices []channeldb.InvoiceWithPaymentHash) { + invoices map[lntypes.Hash]*channeldb.Invoice) { invoicesWithExpiry := make([]*invoiceExpiry, 0, len(invoices)) - for _, invoiceWithPaymentHash := range invoices { - newInvoiceExpiry := ew.prepareInvoice( - invoiceWithPaymentHash.PaymentHash, - &invoiceWithPaymentHash.Invoice, - ) + for paymentHash, invoice := range invoices { + newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice) if newInvoiceExpiry != nil { invoicesWithExpiry = append( invoicesWithExpiry, newInvoiceExpiry, diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index 58d6e2d8d..67ea25256 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -158,24 +158,14 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) - var invoices []channeldb.InvoiceWithPaymentHash + invoices := make(map[lntypes.Hash]*channeldb.Invoice) for hash, invoice := range test.testData.expiredInvoices { - invoices = append(invoices, - channeldb.InvoiceWithPaymentHash{ - Invoice: *invoice, - PaymentHash: hash, - }, - ) + invoices[hash] = invoice } for hash, invoice := range test.testData.pendingInvoices { - invoices = append(invoices, - channeldb.InvoiceWithPaymentHash{ - Invoice: *invoice, - PaymentHash: hash, - }, - ) + invoices[hash] = invoice } test.watcher.AddInvoices(invoices) diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 84d646178..66043ff08 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -147,21 +147,39 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, } } -// populateExpiryWatcher fetches all active invoices and their corresponding -// payment hashes from ChannelDB and adds them to the expiry watcher. -func (i *InvoiceRegistry) populateExpiryWatcher() error { - pendingOnly := true - pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly) - if err != nil && err != channeldb.ErrNoInvoicesCreated { - log.Errorf( - "Error while prefetching active invoices from the database: %v", err, - ) +// scanInvoicesOnStart will scan all invoices on start and add active invoices +// to the invoice expiry watcher. +func (i *InvoiceRegistry) scanInvoicesOnStart() error { + var pending map[lntypes.Hash]*channeldb.Invoice + + reset := func() { + // Zero out our results on start and if the scan is ever run + // more than once. This latter case can happen if the kvdb + // layer needs to retry the View transaction underneath (eg. + // using the etcd driver, where all transactions are allowed + // to retry for serializability). + pending = make(map[lntypes.Hash]*channeldb.Invoice) + } + + scanFunc := func( + paymentHash lntypes.Hash, invoice *channeldb.Invoice) error { + + if invoice.IsPending() { + pending[paymentHash] = invoice + } + + return nil + } + + err := i.cdb.ScanInvoices(scanFunc, reset) + if err != nil { return err } log.Debugf("Adding %d pending invoices to the expiry watcher", - len(pendingInvoices)) - i.expiryWatcher.AddInvoices(pendingInvoices) + len(pending)) + i.expiryWatcher.AddInvoices(pending) + return nil } @@ -178,8 +196,9 @@ func (i *InvoiceRegistry) Start() error { i.wg.Add(1) go i.invoiceEventLoop() - // Now prefetch all pending invoices to the expiry watcher. - err = i.populateExpiryWatcher() + // Now scan all pending and removable invoices to the expiry watcher or + // delete them. + err = i.scanInvoicesOnStart() if err != nil { i.Stop() return err From a5778c4673fb7152a9d373d0c3ef6afd8e0519c4 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 28 Jul 2020 23:14:24 +0200 Subject: [PATCH 032/218] channeldb: add DeleteInvoices call This commit extends channeldb with the DeleteInvoices call which is when passed a slice of delete references will attempt to delete the invoices pointed to by the references and also clean up all our invoice indexes. --- channeldb/invoice_test.go | 93 +++++++++++++++++++++++++++ channeldb/invoices.go | 131 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 224 insertions(+) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 9d5aba364..bb118f715 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -1151,3 +1151,96 @@ func TestInvoiceRef(t *testing.T) { require.Equal(t, payHash, refByHashAndAddr.PayHash()) require.Equal(t, &payAddr, refByHashAndAddr.PayAddr()) } + +// TestDeleteInvoices tests that deleting a list of invoices will succeed +// if all delete references are valid, or will fail otherwise. +func TestDeleteInvoices(t *testing.T) { + t.Parallel() + + db, cleanup, err := MakeTestDB() + defer cleanup() + require.NoError(t, err, "unable to make test db") + + // Add some invoices to the test db. + numInvoices := 3 + invoicesToDelete := make([]InvoiceDeleteRef, numInvoices) + + for i := 0; i < numInvoices; i++ { + invoice, err := randInvoice(lnwire.MilliSatoshi(i + 1)) + require.NoError(t, err) + + paymentHash := invoice.Terms.PaymentPreimage.Hash() + addIndex, err := db.AddInvoice(invoice, paymentHash) + require.NoError(t, err) + + // Settle the second invoice. + if i == 1 { + invoice, err = db.UpdateInvoice( + InvoiceRefByHash(paymentHash), + getUpdateInvoice(invoice.Terms.Value), + ) + require.NoError(t, err, "unable to settle invoice") + } + + // store the delete ref for later. + invoicesToDelete[i] = InvoiceDeleteRef{ + PayHash: paymentHash, + PayAddr: &invoice.Terms.PaymentAddr, + AddIndex: addIndex, + SettleIndex: invoice.SettleIndex, + } + } + + // assertInvoiceCount asserts that the number of invoices equals + // to the passed count. + assertInvoiceCount := func(count int) { + // Query to collect all invoices. + query := InvoiceQuery{ + IndexOffset: 0, + NumMaxInvoices: math.MaxUint64, + } + + // Check that we really have 3 invoices. + response, err := db.QueryInvoices(query) + require.NoError(t, err) + require.Equal(t, count, len(response.Invoices)) + } + + // XOR one byte of one of the references' hash and attempt to delete. + invoicesToDelete[0].PayHash[2] ^= 3 + require.Error(t, db.DeleteInvoice(invoicesToDelete)) + assertInvoiceCount(3) + + // Restore the hash. + invoicesToDelete[0].PayHash[2] ^= 3 + + // XOR one byte of one of the references' payment address and attempt + // to delete. + invoicesToDelete[1].PayAddr[5] ^= 7 + require.Error(t, db.DeleteInvoice(invoicesToDelete)) + assertInvoiceCount(3) + + // Restore the payment address. + invoicesToDelete[1].PayAddr[5] ^= 7 + + // XOR the second invoice's payment settle index as it is settled, and + // attempt to delete. + invoicesToDelete[1].SettleIndex ^= 11 + require.Error(t, db.DeleteInvoice(invoicesToDelete)) + assertInvoiceCount(3) + + // Restore the settle index. + invoicesToDelete[1].SettleIndex ^= 11 + + // XOR the add index for one of the references and attempt to delete. + invoicesToDelete[2].AddIndex ^= 13 + require.Error(t, db.DeleteInvoice(invoicesToDelete)) + assertInvoiceCount(3) + + // Restore the add index. + invoicesToDelete[2].AddIndex ^= 13 + + // Delete should succeed with all the valid references. + require.NoError(t, db.DeleteInvoice(invoicesToDelete)) + assertInvoiceCount(0) +} diff --git a/channeldb/invoices.go b/channeldb/invoices.go index a7ece3c30..5f7b64623 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -1740,3 +1740,134 @@ func setSettleMetaFields(settleIndex kvdb.RwBucket, invoiceNum []byte, return nil } + +// InvoiceDeleteRef holds a refererence to an invoice to be deleted. +type InvoiceDeleteRef struct { + // PayHash is the payment hash of the target invoice. All invoices are + // currently indexed by payment hash. + PayHash lntypes.Hash + + // PayAddr is the payment addr of the target invoice. Newer invoices + // (0.11 and up) are indexed by payment address in addition to payment + // hash, but pre 0.8 invoices do not have one at all. + PayAddr *[32]byte + + // AddIndex is the add index of the invoice. + AddIndex uint64 + + // SettleIndex is the settle index of the invoice. + SettleIndex uint64 +} + +// DeleteInvoice attempts to delete the passed invoices from the database in +// one transaction. The passed delete references hold all keys required to +// delete the invoices without also needing to deserialze them. +func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error { + err := kvdb.Update(d, func(tx kvdb.RwTx) error { + invoices := tx.ReadWriteBucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + + invoiceIndex := invoices.NestedReadWriteBucket( + invoiceIndexBucket, + ) + if invoiceIndex == nil { + return ErrNoInvoicesCreated + } + + invoiceAddIndex := invoices.NestedReadWriteBucket( + addIndexBucket, + ) + if invoiceAddIndex == nil { + return ErrNoInvoicesCreated + } + // settleIndex can be nil, as the bucket is created lazily + // when the first invoice is settled. + settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket) + + payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket) + + for _, ref := range invoicesToDelete { + // Fetch the invoice key for using it to check for + // consistency and also to delete from the invoice index. + invoiceKey := invoiceIndex.Get(ref.PayHash[:]) + if invoiceKey == nil { + return ErrInvoiceNotFound + } + + err := invoiceIndex.Delete(ref.PayHash[:]) + if err != nil { + return err + } + + // Delete payment address index reference if there's a + // valid payment address passed. + if ref.PayAddr != nil { + // To ensure consistency check that the already + // fetched invoice key matches the one in the + // payment address index. + key := payAddrIndex.Get(ref.PayAddr[:]) + if !bytes.Equal(key, invoiceKey) { + return fmt.Errorf("unknown invoice") + } + + // Delete from the payment address index. + err := payAddrIndex.Delete(ref.PayAddr[:]) + if err != nil { + return err + } + } + + var addIndexKey [8]byte + byteOrder.PutUint64(addIndexKey[:], ref.AddIndex) + + // To ensure consistency check that the key stored in + // the add index also matches the previously fetched + // invoice key. + key := invoiceAddIndex.Get(addIndexKey[:]) + if !bytes.Equal(key, invoiceKey) { + return fmt.Errorf("unknown invoice") + } + + // Remove from the add index. + err = invoiceAddIndex.Delete(addIndexKey[:]) + if err != nil { + return err + } + + // Remove from the settle index if available and + // if the invoice is settled. + if settleIndex != nil && ref.SettleIndex > 0 { + var settleIndexKey [8]byte + byteOrder.PutUint64( + settleIndexKey[:], ref.SettleIndex, + ) + + // To ensure consistency check that the already + // fetched invoice key matches the one in the + // settle index + key := settleIndex.Get(settleIndexKey[:]) + if !bytes.Equal(key, invoiceKey) { + return fmt.Errorf("unknown invoice") + } + + err = settleIndex.Delete(settleIndexKey[:]) + if err != nil { + return err + } + } + + // Finally remove the serialized invoice from the + // invoice bucket. + err = invoices.Delete(invoiceKey) + if err != nil { + return err + } + } + + return nil + }) + + return err +} From 0ea763d83c1e50a3583cf5f5382ef5b931fd8be5 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 28 Jul 2020 23:39:11 +0200 Subject: [PATCH 033/218] invoices: attempt to delete old invoices upon start --- invoices/invoiceregistry.go | 30 +++++++++++-- invoices/invoiceregistry_test.go | 75 ++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 66043ff08..e53d7da8f 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -148,9 +148,13 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, } // scanInvoicesOnStart will scan all invoices on start and add active invoices -// to the invoice expiry watcher. +// to the invoice expirt watcher while also attempting to delete all canceled +// invoices. func (i *InvoiceRegistry) scanInvoicesOnStart() error { - var pending map[lntypes.Hash]*channeldb.Invoice + var ( + pending map[lntypes.Hash]*channeldb.Invoice + removable []channeldb.InvoiceDeleteRef + ) reset := func() { // Zero out our results on start and if the scan is ever run @@ -159,6 +163,7 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { // using the etcd driver, where all transactions are allowed // to retry for serializability). pending = make(map[lntypes.Hash]*channeldb.Invoice) + removable = make([]channeldb.InvoiceDeleteRef, 0) } scanFunc := func( @@ -166,8 +171,23 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { if invoice.IsPending() { pending[paymentHash] = invoice - } + } else if invoice.State == channeldb.ContractCanceled { + // Consider invoice for removal if it is already + // canceled. Invoices that are expired but not yet + // canceled, will be queued up for cancellation after + // startup and will be deleted afterwards. + ref := channeldb.InvoiceDeleteRef{ + PayHash: paymentHash, + AddIndex: invoice.AddIndex, + SettleIndex: invoice.SettleIndex, + } + if invoice.Terms.PaymentAddr != channeldb.BlankPayAddr { + ref.PayAddr = &invoice.Terms.PaymentAddr + } + + removable = append(removable, ref) + } return nil } @@ -180,6 +200,10 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { len(pending)) i.expiryWatcher.AddInvoices(pending) + if err := i.cdb.DeleteInvoice(removable); err != nil { + log.Warnf("Deleting old invoices failed: %v", err) + } + return nil } diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index c77b38ed5..0da260a25 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -1,6 +1,7 @@ package invoices import ( + "math" "testing" "time" @@ -1077,3 +1078,77 @@ func TestInvoiceExpiryWithRegistry(t *testing.T) { } } } + +// TestOldInvoiceRemovalOnStart tests that we'll attempt to remove old canceled +// invoices upon start while keeping all settled ones. +func TestOldInvoiceRemovalOnStart(t *testing.T) { + t.Parallel() + + testClock := clock.NewTestClock(testTime) + cdb, cleanup, err := newTestChannelDB(testClock) + defer cleanup() + + require.NoError(t, err) + + cfg := RegistryConfig{ + FinalCltvRejectDelta: testFinalCltvRejectDelta, + Clock: testClock, + } + + expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock) + registry := NewRegistry(cdb, expiryWatcher, &cfg) + + // First prefill the Channel DB with some pre-existing expired invoices. + const numExpired = 5 + const numPending = 0 + existingInvoices := generateInvoiceExpiryTestData( + t, testTime, 0, numExpired, numPending, + ) + + i := 0 + for paymentHash, invoice := range existingInvoices.expiredInvoices { + // Mark half of the invoices as settled, the other hald as + // canceled. + if i%2 == 0 { + invoice.State = channeldb.ContractSettled + } else { + invoice.State = channeldb.ContractCanceled + } + + _, err := cdb.AddInvoice(invoice, paymentHash) + require.NoError(t, err) + i++ + } + + // Collect all settled invoices for our expectation set. + var expected []channeldb.Invoice + + // Perform a scan query to collect all invoices. + query := channeldb.InvoiceQuery{ + IndexOffset: 0, + NumMaxInvoices: math.MaxUint64, + } + + response, err := cdb.QueryInvoices(query) + require.NoError(t, err) + + // Save all settled invoices for our expectation set. + for _, invoice := range response.Invoices { + if invoice.State == channeldb.ContractSettled { + expected = append(expected, invoice) + } + } + + // Start the registry which should collect and delete all canceled + // invoices upon start. + err = registry.Start() + require.NoError(t, err, "cannot start the registry") + + // Perform a scan query to collect all invoices. + response, err = cdb.QueryInvoices(query) + require.NoError(t, err) + + // Check that we really only kept the settled invoices after the + // registry start. + require.Equal(t, expected, response.Invoices) +} From a0d7877d9a327dbd74ac9426551f2ab30ae41753 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 28 Jul 2020 23:57:51 +0200 Subject: [PATCH 034/218] multi: make canceled invoice garbage collection configurable This commit extends the application config with a flag to control canceled invoice garbage collection upon startup. --- config.go | 2 ++ invoices/invoiceregistry.go | 8 +++++++- invoices/invoiceregistry_test.go | 5 +++-- server.go | 11 ++++++----- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/config.go b/config.go index 3a52ca964..89f3fa89f 100644 --- a/config.go +++ b/config.go @@ -242,6 +242,8 @@ type Config struct { KeysendHoldTime time.Duration `long:"keysend-hold-time" description:"If non-zero, keysend payments are accepted but not immediately settled. If the payment isn't settled manually after the specified time, it is canceled automatically. [experimental]"` + GcCanceledInvoicesOnStartup bool `long:"gc-canceled-invoices-on-startup" description:"If true, we'll attempt to garbage collect canceled invoices upon start."` + Routing *routing.Conf `group:"routing" namespace:"routing"` Workers *lncfg.Workers `group:"workers" namespace:"workers"` diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index e53d7da8f..bc40168a8 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -57,6 +57,10 @@ type RegistryConfig struct { // send payments. AcceptKeySend bool + // GcCanceledInvoicesOnStartup if set, we'll attempt to garbage collect + // all canceled invoices upon start. + GcCanceledInvoicesOnStartup bool + // KeysendHoldTime indicates for how long we want to accept and hold // spontaneous keysend payments. KeysendHoldTime time.Duration @@ -171,7 +175,9 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { if invoice.IsPending() { pending[paymentHash] = invoice - } else if invoice.State == channeldb.ContractCanceled { + } else if i.cfg.GcCanceledInvoicesOnStartup && + invoice.State == channeldb.ContractCanceled { + // Consider invoice for removal if it is already // canceled. Invoices that are expired but not yet // canceled, will be queued up for cancellation after diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 0da260a25..6e5ca2212 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -1091,8 +1091,9 @@ func TestOldInvoiceRemovalOnStart(t *testing.T) { require.NoError(t, err) cfg := RegistryConfig{ - FinalCltvRejectDelta: testFinalCltvRejectDelta, - Clock: testClock, + FinalCltvRejectDelta: testFinalCltvRejectDelta, + Clock: testClock, + GcCanceledInvoicesOnStartup: true, } expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock) diff --git a/server.go b/server.go index fe711cec0..d3de88ef9 100644 --- a/server.go +++ b/server.go @@ -396,11 +396,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr, chanDB *channeldb.DB, } registryConfig := invoices.RegistryConfig{ - FinalCltvRejectDelta: lncfg.DefaultFinalCltvRejectDelta, - HtlcHoldDuration: invoices.DefaultHtlcHoldDuration, - Clock: clock.NewDefaultClock(), - AcceptKeySend: cfg.AcceptKeySend, - KeysendHoldTime: cfg.KeysendHoldTime, + FinalCltvRejectDelta: lncfg.DefaultFinalCltvRejectDelta, + HtlcHoldDuration: invoices.DefaultHtlcHoldDuration, + Clock: clock.NewDefaultClock(), + AcceptKeySend: cfg.AcceptKeySend, + GcCanceledInvoicesOnStartup: cfg.GcCanceledInvoicesOnStartup, + KeysendHoldTime: cfg.KeysendHoldTime, } s := &server{ From 2aa680ede2ac2552ec466a00fadc28e4df87a3ab Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 4 Aug 2020 18:28:15 +0200 Subject: [PATCH 035/218] invoices: optionally garbage collect invoices on the fly This commit extends invoice garbage collection to also remove invoices which are canceled when LND is already up and running. When the option GcCanceledInvoicesOnTheFly is false (default) then invoices are kept and the behavior is unchanged. --- config.go | 2 ++ invoices/invoiceregistry.go | 30 ++++++++++++++++++ invoices/invoiceregistry_test.go | 53 +++++++++++++++++++++++++++----- server.go | 1 + 4 files changed, 78 insertions(+), 8 deletions(-) diff --git a/config.go b/config.go index 89f3fa89f..511e012a1 100644 --- a/config.go +++ b/config.go @@ -244,6 +244,8 @@ type Config struct { GcCanceledInvoicesOnStartup bool `long:"gc-canceled-invoices-on-startup" description:"If true, we'll attempt to garbage collect canceled invoices upon start."` + GcCanceledInvoicesOnTheFly bool `long:"gc-canceled-invoices-on-the-fly" description:"If true, we'll delete newly canceled invoices on the fly."` + Routing *routing.Conf `group:"routing" namespace:"routing"` Workers *lncfg.Workers `group:"workers" namespace:"workers"` diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index bc40168a8..c827f1445 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -61,6 +61,10 @@ type RegistryConfig struct { // all canceled invoices upon start. GcCanceledInvoicesOnStartup bool + // GcCanceledInvoicesOnTheFly if set, we'll garbage collect all newly + // canceled invoices on the fly. + GcCanceledInvoicesOnTheFly bool + // KeysendHoldTime indicates for how long we want to accept and hold // spontaneous keysend payments. KeysendHoldTime time.Duration @@ -1124,6 +1128,32 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, } i.notifyClients(payHash, invoice, channeldb.ContractCanceled) + // Attempt to also delete the invoice if requested through the registry + // config. + if i.cfg.GcCanceledInvoicesOnTheFly { + // Assemble the delete reference and attempt to delete through + // the invocice from the DB. + deleteRef := channeldb.InvoiceDeleteRef{ + PayHash: payHash, + AddIndex: invoice.AddIndex, + SettleIndex: invoice.SettleIndex, + } + if invoice.Terms.PaymentAddr != channeldb.BlankPayAddr { + deleteRef.PayAddr = &invoice.Terms.PaymentAddr + } + + err = i.cdb.DeleteInvoice( + []channeldb.InvoiceDeleteRef{deleteRef}, + ) + // If by any chance deletion failed, then log it instead of + // returning the error, as the invoice itsels has already been + // canceled. + if err != nil { + log.Warnf("Invoice%v could not be deleted: %v", + ref, err) + } + } + return nil } diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 6e5ca2212..cb916aeab 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -220,11 +220,14 @@ func TestSettleInvoice(t *testing.T) { } } -// TestCancelInvoice tests cancelation of an invoice and related notifications. -func TestCancelInvoice(t *testing.T) { +func testCancelInvoice(t *testing.T, gc bool) { ctx := newTestContext(t) defer ctx.cleanup() + // If set to true, then also delete the invoice from the DB after + // cancellation. + ctx.registry.cfg.GcCanceledInvoicesOnTheFly = gc + allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) assert.Nil(t, err) defer allSubscriptions.Cancel() @@ -299,13 +302,26 @@ func TestCancelInvoice(t *testing.T) { t.Fatal("no update received") } + if gc { + // Check that the invoice has been deleted from the db. + _, err = ctx.cdb.LookupInvoice( + channeldb.InvoiceRefByHash(testInvoicePaymentHash), + ) + require.Error(t, err) + } + // We expect no cancel notification to be sent to all invoice // subscribers (backwards compatibility). - // Try to cancel again. + // Try to cancel again. Expect that we report ErrInvoiceNotFound if the + // invoice has been garbage collected (since the invoice has been + // deleted when it was canceled), and no error otherwise. err = ctx.registry.CancelInvoice(testInvoicePaymentHash) - if err != nil { - t.Fatal("expected cancelation of a canceled invoice to succeed") + + if gc { + require.Error(t, err, channeldb.ErrInvoiceNotFound) + } else { + require.NoError(t, err) } // Notify arrival of a new htlc paying to this invoice. This should @@ -327,12 +343,33 @@ func TestCancelInvoice(t *testing.T) { t.Fatalf("expected acceptHeight %v, but got %v", testCurrentHeight, failResolution.AcceptHeight) } - if failResolution.Outcome != ResultInvoiceAlreadyCanceled { - t.Fatalf("expected expiry too soon, got: %v", - failResolution.Outcome) + + // If the invoice has been deleted (or not present) then we expect the + // outcome to be ResultInvoiceNotFound instead of when the invoice is + // in our database in which case we expect ResultInvoiceAlreadyCanceled. + if gc { + require.Equal(t, failResolution.Outcome, ResultInvoiceNotFound) + } else { + require.Equal(t, + failResolution.Outcome, + ResultInvoiceAlreadyCanceled, + ) } } +// TestCancelInvoice tests cancelation of an invoice and related notifications. +func TestCancelInvoice(t *testing.T) { + // Test cancellation both with garbage collection (meaning that canceled + // invoice will be deleted) and without (meain it'll be kept). + t.Run("garbage collect", func(t *testing.T) { + testCancelInvoice(t, true) + }) + + t.Run("no garbage collect", func(t *testing.T) { + testCancelInvoice(t, false) + }) +} + // TestSettleHoldInvoice tests settling of a hold invoice and related // notifications. func TestSettleHoldInvoice(t *testing.T) { diff --git a/server.go b/server.go index d3de88ef9..0f2b59d0a 100644 --- a/server.go +++ b/server.go @@ -401,6 +401,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, chanDB *channeldb.DB, Clock: clock.NewDefaultClock(), AcceptKeySend: cfg.AcceptKeySend, GcCanceledInvoicesOnStartup: cfg.GcCanceledInvoicesOnStartup, + GcCanceledInvoicesOnTheFly: cfg.GcCanceledInvoicesOnTheFly, KeysendHoldTime: cfg.KeysendHoldTime, } From c4f739ac8af06663e6ed3322e3142cce7142dde8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Habov=C5=A1tiak?= Date: Wed, 5 Aug 2020 07:29:15 +0200 Subject: [PATCH 036/218] config: Updated deprecation message of noseedbackup According to the recent discussion `noseedbackup` is not deprecated. This change clarifies the message about deprecation. Also fixes a typo. Closes #4499 --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index 3a52ca964..38472c9e1 100644 --- a/config.go +++ b/config.go @@ -205,7 +205,7 @@ type Config struct { NoNetBootstrap bool `long:"nobootstrap" description:"If true, then automatic network bootstrapping will not be attempted."` - NoSeedBackup bool `long:"noseedbackup" description:"If true, NO SEED WILL BE EXPOSED AND THE WALLET WILL BE ENCRYPTED USING THE DEFAULT PASSPHRASE -- EVER. THIS FLAG IS ONLY FOR TESTING AND IS BEING DEPRECATED."` + NoSeedBackup bool `long:"noseedbackup" description:"If true, NO SEED WILL BE EXPOSED -- EVER, AND THE WALLET WILL BE ENCRYPTED USING THE DEFAULT PASSPHRASE. THIS FLAG IS ONLY FOR TESTING AND SHOULD NEVER BE USED ON MAINNET."` PaymentsExpirationGracePeriod time.Duration `long:"payments-expiration-grace-period" description:"A period to wait before force closing channels with outgoing htlcs that have timed-out and are a result of this node initiated payments."` TrickleDelay int `long:"trickledelay" description:"Time in milliseconds between each release of announcements to the network"` From c5c28564e9ae97a41262c0f3a6df959691687ed0 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Thu, 6 Aug 2020 12:07:07 +0200 Subject: [PATCH 037/218] lnrpc: add macaroon workaround for WebSockets in browsers For security reasons, browsers are limited in the header fields they can send when opening a WebSocket connection. Specifically, the macaroon cannot be sent in the Grpc-Metadata-Macaroon header field as that would be possible for normal REST requests. Instead we only have the special field "Sec-Websocket-Protocol" that can be used to transport custom data. We allow the macaroon to be sent there and transform it into a proper header field for the target request. --- lnrpc/websocket_proxy.go | 63 +++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 7 deletions(-) diff --git a/lnrpc/websocket_proxy.go b/lnrpc/websocket_proxy.go index 921b21710..3cb701be0 100644 --- a/lnrpc/websocket_proxy.go +++ b/lnrpc/websocket_proxy.go @@ -21,6 +21,16 @@ const ( // This is necessary because the WebSocket API specifies that a // handshake request must always be done through a GET request. MethodOverrideParam = "method" + + // HeaderWebSocketProtocol is the name of the WebSocket protocol + // exchange header field that we use to transport additional header + // fields. + HeaderWebSocketProtocol = "Sec-Websocket-Protocol" + + // WebSocketProtocolDelimiter is the delimiter we use between the + // additional header field and its value. We use the plus symbol because + // the default delimiters aren't allowed in the protocol names. + WebSocketProtocolDelimiter = "+" ) var ( @@ -32,6 +42,13 @@ var ( "Referer": true, "Grpc-Metadata-Macaroon": true, } + + // defaultProtocolsToAllow are additional header fields that we allow + // to be transported inside of the Sec-Websocket-Protocol field to be + // forwarded to the backend. + defaultProtocolsToAllow = map[string]bool{ + "Grpc-Metadata-Macaroon": true, + } ) // NewWebSocketProxy attempts to expose the underlying handler as a response- @@ -101,13 +118,13 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, p.logger.Errorf("WS: error preparing request:", err) return } - for header := range r.Header { - headerName := textproto.CanonicalMIMEHeaderKey(header) - forward, ok := defaultHeadersToForward[headerName] - if ok && forward { - request.Header.Set(headerName, r.Header.Get(header)) - } - } + + // Allow certain headers to be forwarded, either from source headers + // or the special Sec-Websocket-Protocol header field. + forwardHeaders(r.Header, request.Header) + + // Also allow the target request method to be overwritten, as all + // WebSocket establishment calls MUST be GET requests. if m := r.URL.Query().Get(MethodOverrideParam); m != "" { request.Method = m } @@ -182,6 +199,38 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, } } +// forwardHeaders forwards certain allowed header fields from the source request +// to the target request. Because browsers are limited in what header fields +// they can send on the WebSocket setup call, we also allow additional fields to +// be transported in the special Sec-Websocket-Protocol field. +func forwardHeaders(source, target http.Header) { + // Forward allowed header fields directly. + for header := range source { + headerName := textproto.CanonicalMIMEHeaderKey(header) + forward, ok := defaultHeadersToForward[headerName] + if ok && forward { + target.Set(headerName, source.Get(header)) + } + } + + // Browser aren't allowed to set custom header fields on WebSocket + // requests. We need to allow them to submit the macaroon as a WS + // protocol, which is the only allowed header. Set any "protocols" we + // declare valid as header fields on the forwarded request. + protocol := source.Get(HeaderWebSocketProtocol) + for key := range defaultProtocolsToAllow { + if strings.HasPrefix(protocol, key) { + // The format is "+". We know the + // protocol string starts with the name so we only need + // to set the value. + values := strings.Split( + protocol, WebSocketProtocolDelimiter, + ) + target.Set(key, values[1]) + } + } +} + // newRequestForwardingReader creates a new request forwarding pipe. func newRequestForwardingReader() *requestForwardingReader { r, w := io.Pipe() From c7cb2c0a78cfbe2188af109cbb6cb8c322122352 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Thu, 6 Aug 2020 12:09:15 +0200 Subject: [PATCH 038/218] docs: describe how to use WebSockets with the REST API We add a new document that shows two examples of how to use the WebSocket REST API with JavaScript. --- docs/rest/websockets.md | 99 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 docs/rest/websockets.md diff --git a/docs/rest/websockets.md b/docs/rest/websockets.md new file mode 100644 index 000000000..705a4c731 --- /dev/null +++ b/docs/rest/websockets.md @@ -0,0 +1,99 @@ +# WebSockets with `lnd`'s REST API + +This document describes how streaming response REST calls can be used correctly +by making use of the WebSocket API. + +As an example, we are going to write a simple JavaScript program that subscribes +to `lnd`'s +[block notification RPC](https://api.lightning.community/#v2-chainnotifier-register-blocks). + +The WebSocket will be kept open as long as `lnd` runs and JavaScript program +isn't stopped. + +## Browser environment + +When using WebSockets in a browser, there are certain security limitations of +what header fields are allowed to be sent. Therefore, the macaroon cannot just +be added as a `Grpc-Metadata-Macaroon` header field as it would work with normal +REST calls. The browser will just ignore that header field and not send it. + +Instead we have added a workaround in `lnd`'s WebSocket proxy that allows +sending the macaroon as a WebSocket "protocol": + +```javascript +const host = 'localhost:8080'; // The default REST port of lnd, can be overwritten with --restlisten=ip:port +const macaroon = '0201036c6e6402eb01030a10625e7e60fd00f5a6f9cd53f33fc82a...'; // The hex encoded macaroon to send +const initialRequest = { // The initial request to send (see API docs for each RPC). + hash: "xlkMdV382uNPskw6eEjDGFMQHxHNnZZgL47aVDSwiRQ=", // Just some example to show that all `byte` fields always have to be base64 encoded in the REST API. + height: 144, +} + +// The protocol is our workaround for sending the macaroon because custom header +// fields aren't allowed to be sent by the browser when opening a WebSocket. +const protocolString = 'Grpc-Metadata-Macaroon+' + macaroon; + +// Let's now connect the web socket. Notice that all WebSocket open calls are +// always GET requests. If the RPC expects a call to be POST or DELETE (see API +// docs to find out), the query parameter "method" can be set to overwrite. +const wsUrl = 'wss://' + host + '/v2/chainnotifier/register/blocks?method=POST'; +let ws = new WebSocket(wsUrl, protocolString); +ws.onopen = function (event) { + // After the WS connection is establishes, lnd expects the client to send the + // initial message. If an RPC doesn't have any request parameters, an empty + // JSON object has to be sent as a string, for example: ws.send('{}') + ws.send(JSON.stringify(initialRequest)); +} +ws.onmessage = function (event) { + // We received a new message. + console.log(event); + + // The data we're really interested in is in data and is always a string + // that needs to be parsed as JSON and always contains a "result" field: + console.log("Payload: "); + console.log(JSON.parse(event.data).result); +} +ws.onerror = function (event) { + // An error occured, let's log it to the console. + console.log(event); +} +``` + +## Node.js environment + +With Node.js it is a bit easier to use the streaming response APIs because we +can set the macaroon header field directly. This is the example from the API +docs: + +```javascript +// -------------------------- +// Example with websockets: +// -------------------------- +const WebSocket = require('ws'); +const fs = require('fs'); +const macaroon = fs.readFileSync('LND_DIR/data/chain/bitcoin/simnet/admin.macaroon').toString('hex'); +let ws = new WebSocket('wss://localhost:8080/v2/chainnotifier/register/blocks?method=POST', { + // Work-around for self-signed certificates. + rejectUnauthorized: false, + headers: { + 'Grpc-Metadata-Macaroon': macaroon, + }, +}); +let requestBody = { + hash: "", + height: "", +} +ws.on('open', function() { + ws.send(JSON.stringify(requestBody)); +}); +ws.on('error', function(err) { + console.log('Error: ' + err); +}); +ws.on('message', function(body) { + console.log(body); +}); +// Console output (repeated for every message in the stream): +// { +// "hash": , +// "height": , +// } +``` From af8ffc9764fd94a89f78aa4c8cd66e947adde471 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Thu, 6 Aug 2020 12:11:20 +0200 Subject: [PATCH 039/218] lntest: add WS test case with custom header macaroon --- lntest/itest/rest_api.go | 122 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 117 insertions(+), 5 deletions(-) diff --git a/lntest/itest/rest_api.go b/lntest/itest/rest_api.go index 5eee3d85a..c296e3a0a 100644 --- a/lntest/itest/rest_api.go +++ b/lntest/itest/rest_api.go @@ -201,7 +201,116 @@ func testRestApi(net *lntest.NetworkHarness, ht *harnessTest) { Height: uint32(height), } url := "/v2/chainnotifier/register/blocks" - c, err := openWebSocket(a, url, "POST", req) + c, err := openWebSocket(a, url, "POST", req, nil) + require.Nil(t, err, "websocket") + defer func() { + _ = c.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage( + websocket.CloseNormalClosure, + "done", + ), + ) + _ = c.Close() + }() + + msgChan := make(chan *chainrpc.BlockEpoch) + errChan := make(chan error) + timeout := time.After(defaultTimeout) + + // We want to read exactly one message. + go func() { + defer close(msgChan) + + _, msg, err := c.ReadMessage() + if err != nil { + errChan <- err + return + } + + // The chunked/streamed responses come wrapped + // in either a {"result":{}} or {"error":{}} + // wrapper which we'll get rid of here. + msgStr := string(msg) + if !strings.Contains(msgStr, "\"result\":") { + errChan <- fmt.Errorf("invalid msg: %s", + msgStr) + return + } + msgStr = resultPattern.ReplaceAllString( + msgStr, "${1}", + ) + + // Make sure we can parse the unwrapped message + // into the expected proto message. + protoMsg := &chainrpc.BlockEpoch{} + err = jsonpb.UnmarshalString( + msgStr, protoMsg, + ) + if err != nil { + errChan <- err + return + } + + select { + case msgChan <- protoMsg: + case <-timeout: + } + }() + + // Mine a block and make sure we get a message for it. + blockHashes, err := net.Miner.Node.Generate(1) + require.Nil(t, err, "generate blocks") + assert.Equal(t, 1, len(blockHashes), "num blocks") + select { + case msg := <-msgChan: + assert.Equal( + t, blockHashes[0].CloneBytes(), + msg.Hash, "block hash", + ) + + case err := <-errChan: + t.Fatalf("Received error from WS: %v", err) + + case <-timeout: + t.Fatalf("Timeout before message was received") + } + }, + }, { + name: "websocket subscription with macaroon in protocol", + run: func(t *testing.T, a, b *lntest.HarnessNode) { + // Find out the current best block so we can subscribe + // to the next one. + hash, height, err := net.Miner.Node.GetBestBlock() + require.Nil(t, err, "get best block") + + // Create a new subscription to get block epoch events. + req := &chainrpc.BlockEpoch{ + Hash: hash.CloneBytes(), + Height: uint32(height), + } + url := "/v2/chainnotifier/register/blocks" + + // This time we send the macaroon in the special header + // Sec-Websocket-Protocol which is the only header field + // available to browsers when opening a WebSocket. + mac, err := a.ReadMacaroon( + a.AdminMacPath(), defaultTimeout, + ) + require.NoError(t, err, "read admin mac") + macBytes, err := mac.MarshalBinary() + require.NoError(t, err, "marshal admin mac") + + customHeader := make(http.Header) + customHeader.Set( + lnrpc.HeaderWebSocketProtocol, fmt.Sprintf( + "Grpc-Metadata-Macaroon+%s", + hex.EncodeToString(macBytes), + ), + ) + c, err := openWebSocket( + a, url, "POST", req, customHeader, + ) require.Nil(t, err, "websocket") defer func() { _ = c.WriteMessage( @@ -364,14 +473,17 @@ func makeRequest(node *lntest.HarnessNode, url, method string, // openWebSocket opens a new WebSocket connection to the given URL with the // appropriate macaroon headers and sends the request message over the socket. func openWebSocket(node *lntest.HarnessNode, url, method string, - req proto.Message) (*websocket.Conn, error) { + req proto.Message, customHeader http.Header) (*websocket.Conn, error) { // Prepare our macaroon headers and assemble the full URL from the // node's listening address. WebSockets always work over GET so we need // to append the target request method as a query parameter. - header := make(http.Header) - if err := addAdminMacaroon(node, header); err != nil { - return nil, err + header := customHeader + if header == nil { + header = make(http.Header) + if err := addAdminMacaroon(node, header); err != nil { + return nil, err + } } fullURL := fmt.Sprintf( "wss://%s%s?method=%s", node.Cfg.RESTAddr(), url, method, From de74798c12adb7a8aeb2508f3e93bd7499a67bd8 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 29 Jul 2020 20:09:55 -0700 Subject: [PATCH 040/218] lntest/itest: add whitelist entry for block hash fetch --- lntest/itest/log_error_whitelist.txt | 1 + lntest/itest/log_substitutions.txt | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lntest/itest/log_error_whitelist.txt b/lntest/itest/log_error_whitelist.txt index deb6abc9c..af71ee9bd 100644 --- a/lntest/itest/log_error_whitelist.txt +++ b/lntest/itest/log_error_whitelist.txt @@ -202,3 +202,4 @@