Merge pull request #6111 from bottlepay/cache-loading

kvdb+channeldb: speed up graph cache
This commit is contained in:
Olaoluwa Osuntokun 2022-01-20 17:35:04 -08:00 committed by GitHub
commit d67e6d5414
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 383 additions and 87 deletions

View File

@ -216,15 +216,29 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
startTime := time.Now()
log.Debugf("Populating in-memory channel graph, this might " +
"take a while...")
err := g.ForEachNodeCacheable(
func(tx kvdb.RTx, node GraphCacheNode) error {
return g.graphCache.AddNode(tx, node)
g.graphCache.AddNodeFeatures(node)
return nil
},
)
if err != nil {
return nil, err
}
err = g.ForEachChannel(func(info *ChannelEdgeInfo,
policy1, policy2 *ChannelEdgePolicy) error {
g.graphCache.AddChannel(info, policy1, policy2)
return nil
})
if err != nil {
return nil, err
}
log.Debugf("Finished populating in-memory channel graph (took "+
"%v, %s)", time.Since(startTime), g.graphCache.Stats())
}
@ -232,6 +246,71 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
return g, nil
}
// channelMapKey is the key structure used for storing channel edge policies.
// It mirrors the layout of a raw policy key in the edge bucket, which
// getChannelMap splits into a node key part and a channel ID part.
type channelMapKey struct {
	// nodeKey holds the leading 33 bytes of the raw policy key.
	nodeKey route.Vertex
	// chanID holds the trailing 8 bytes of the raw policy key.
	chanID [8]byte
}
// getChannelMap scans the passed edges bucket and collects every channel edge
// policy found directly in it into a map. The map key is built from the raw
// bucket key: its first 33 bytes become the node key and the remaining 8
// bytes the channel ID. Entries holding the unknown policy marker, and
// policies for which an expected optional field is missing, are omitted from
// the map. Any other decoding failure, or a key of unexpected length, aborts
// the scan and is returned to the caller.
func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) (
	map[channelMapKey]*ChannelEdgePolicy, error) {

	// These keys show up while iterating the edges bucket but are not
	// policy entries, so they are skipped.
	skippedKeys := [][]byte{
		edgeIndexBucket,
		edgeUpdateIndexBucket,
		zombieBucket,
		disabledEdgePolicyBucket,
		channelPointBucket,
	}

	policies := make(map[channelMapKey]*ChannelEdgePolicy)

	collect := func(rawKey, rawPolicy []byte) error {
		for _, skipped := range skippedKeys {
			if bytes.Equal(rawKey, skipped) {
				return nil
			}
		}

		// Everything else must have the exact policy key length.
		if len(rawKey) != 33+8 {
			return fmt.Errorf("invalid edge key %x encountered", rawKey)
		}

		// An unknown policy carries nothing worth decoding.
		if bytes.Equal(rawPolicy, unknownPolicy) {
			return nil
		}

		policy, err := deserializeChanEdgePolicyRaw(
			bytes.NewReader(rawPolicy),
		)

		// A policy that lacks an expected optional field is treated
		// exactly like an unknown policy.
		if err == ErrEdgePolicyOptionalFieldNotFound {
			return nil
		}
		if err != nil {
			return err
		}

		var mapKey channelMapKey
		copy(mapKey.nodeKey[:], rawKey[:33])
		copy(mapKey.chanID[:], rawKey[33:])
		policies[mapKey] = policy

		return nil
	}

	if err := kvdb.ForAll(edges, collect); err != nil {
		return nil, err
	}

	return policies, nil
}
var graphTopLevelBuckets = [][]byte{
nodeBucket,
edgeBucket,
@ -332,50 +411,47 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) {
func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo,
*ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
// TODO(roasbeef): ptr map to reduce # of allocs? no duplicates
return kvdb.View(c.db, func(tx kvdb.RTx) error {
// First, grab the node bucket. This will be used to populate
// the Node pointers in each edge read from disk.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
// Next, grab the edge bucket which stores the edges, and also
// the index itself so we can group the directed edges together
// logically.
return c.db.View(func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
// First, load all edges in memory indexed by node and channel
// id.
channelMap, err := c.getChannelMap(edges)
if err != nil {
return err
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
// For each edge pair within the edge index, we fetch each edge
// itself and also the node information in order to fully
// populate the object.
return edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error {
infoReader := bytes.NewReader(edgeInfoBytes)
edgeInfo, err := deserializeChanEdgeInfo(infoReader)
if err != nil {
return err
}
edgeInfo.db = c.db
// Load edge index, recombine each channel with the policies
// loaded above and invoke the callback.
return kvdb.ForAll(edgeIndex, func(k, edgeInfoBytes []byte) error {
var chanID [8]byte
copy(chanID[:], k)
edge1, edge2, err := fetchChanEdgePolicies(
edgeIndex, edges, nodes, chanID, c.db,
)
edgeInfoReader := bytes.NewReader(edgeInfoBytes)
info, err := deserializeChanEdgeInfo(edgeInfoReader)
if err != nil {
return err
}
// With both edges read, execute the callback. If this
// function returns an error then the transaction will
// be aborted.
return cb(&edgeInfo, edge1, edge2)
policy1 := channelMap[channelMapKey{
nodeKey: info.NodeKey1Bytes,
chanID: chanID,
}]
policy2 := channelMap[channelMapKey{
nodeKey: info.NodeKey2Bytes,
chanID: chanID,
}]
return cb(&info, policy1, policy2)
})
}, func() {})
}
@ -628,7 +704,6 @@ func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx,
return ErrGraphNotFound
}
cacheableNode := newGraphCacheNode(route.Vertex{}, nil)
return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
// If this is the source key, then we skip this
// iteration as the value for this key is a pubKey
@ -638,8 +713,8 @@ func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx,
}
nodeReader := bytes.NewReader(nodeBytes)
err := deserializeLightningNodeCacheable(
nodeReader, cacheableNode,
cacheableNode, err := deserializeLightningNodeCacheable(
nodeReader,
)
if err != nil {
return err
@ -2740,8 +2815,6 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (
type graphCacheNode struct {
pubKeyBytes route.Vertex
features *lnwire.FeatureVector
nodeScratch [8]byte
}
// newGraphCacheNode returns a new cache optimized node.
@ -4090,51 +4163,60 @@ func fetchLightningNode(nodeBucket kvdb.RBucket,
return deserializeLightningNode(nodeReader)
}
func deserializeLightningNodeCacheable(r io.Reader, node *graphCacheNode) error {
func deserializeLightningNodeCacheable(r io.Reader) (*graphCacheNode, error) {
// Always populate a feature vector, even if we don't have a node
// announcement and short circuit below.
node.features = lnwire.EmptyFeatureVector()
node := newGraphCacheNode(
route.Vertex{},
lnwire.EmptyFeatureVector(),
)
var nodeScratch [8]byte
// Skip ahead:
// - LastUpdate (8 bytes)
if _, err := r.Read(node.nodeScratch[:]); err != nil {
return err
if _, err := r.Read(nodeScratch[:]); err != nil {
return nil, err
}
if _, err := io.ReadFull(r, node.pubKeyBytes[:]); err != nil {
return err
return nil, err
}
// Read the node announcement flag.
if _, err := r.Read(node.nodeScratch[:2]); err != nil {
return err
if _, err := r.Read(nodeScratch[:2]); err != nil {
return nil, err
}
hasNodeAnn := byteOrder.Uint16(node.nodeScratch[:2])
hasNodeAnn := byteOrder.Uint16(nodeScratch[:2])
// The rest of the data is optional, and will only be there if we got a
// node announcement for this node.
if hasNodeAnn == 0 {
return nil
return node, nil
}
// We did get a node announcement for this node, so we'll have the rest
// of the data available.
var rgb uint8
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
return nil, err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
return nil, err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
return nil, err
}
if _, err := wire.ReadVarString(r, 0); err != nil {
return err
return nil, err
}
return node.features.Decode(r)
if err := node.features.Decode(r); err != nil {
return nil, err
}
return node, nil
}
func deserializeLightningNode(r io.Reader) (LightningNode, error) {
@ -4652,6 +4734,27 @@ func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy,
// deserializeChanEdgePolicy decodes a channel edge policy from r and then
// replaces the policy's Node with the full node record looked up in the nodes
// bucket.
//
// When the raw decoding reports ErrEdgePolicyOptionalFieldNotFound, the policy
// is still completed and returned together with that error. Any other decoding
// error, or a failed node lookup, yields a nil policy.
func deserializeChanEdgePolicy(r io.Reader,
	nodes kvdb.RBucket) (*ChannelEdgePolicy, error) {

	// Decode the policy itself. A missing optional field is the one
	// error that still comes with a usable policy object, so only bail
	// out on everything else.
	policy, rawErr := deserializeChanEdgePolicyRaw(r)
	fatal := rawErr != nil &&
		rawErr != ErrEdgePolicyOptionalFieldNotFound
	if fatal {
		return nil, rawErr
	}

	// The raw policy only carries the node's public key; swap it for the
	// complete node loaded from the database.
	nodePub := policy.Node.PubKeyBytes[:]
	fullNode, err := fetchLightningNode(nodes, nodePub)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node: %x, %v", nodePub, err)
	}
	policy.Node = &fullNode

	// Pass the (possibly nil) optional-field error on to the caller.
	return policy, rawErr
}
func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy, error) {
edge := &ChannelEdgePolicy{}
var err error
@ -4701,13 +4804,9 @@ func deserializeChanEdgePolicy(r io.Reader,
if _, err := r.Read(pub[:]); err != nil {
return nil, err
}
node, err := fetchLightningNode(nodes, pub[:])
if err != nil {
return nil, fmt.Errorf("unable to fetch node: %x, %v",
pub[:], err)
edge.Node = &LightningNode{
PubKeyBytes: pub,
}
edge.Node = &node
// We'll try and see if there are any opaque bytes left, if not, then
// we'll ignore the EOF error and return the edge as is.

View File

@ -205,9 +205,8 @@ func (c *GraphCache) Stats() string {
numChannels)
}
// AddNode adds a graph node, including all the (directed) channels of that
// node.
func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error {
// AddNodeFeatures adds a graph node and its features to the cache.
func (c *GraphCache) AddNodeFeatures(node GraphCacheNode) {
nodePubKey := node.PubKey()
// Only hold the lock for a short time. The `ForEachChannel()` below is
@ -217,6 +216,12 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error {
c.mtx.Lock()
c.nodeFeatures[nodePubKey] = node.Features()
c.mtx.Unlock()
}
// AddNode adds a graph node, including all the (directed) channels of that
// node.
func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error {
c.AddNodeFeatures(node)
return node.ForEachChannel(
tx, func(tx kvdb.RTx, info *ChannelEdgeInfo,

View File

@ -1207,11 +1207,22 @@ func TestGraphTraversalCacheable(t *testing.T) {
// Iterate through all the known channels within the graph DB by
// iterating over each node, once again if the map is empty that
// indicates that all edges have properly been reached.
var nodes []GraphCacheNode
err = graph.ForEachNodeCacheable(
func(tx kvdb.RTx, node GraphCacheNode) error {
delete(nodeMap, node.PubKey())
return node.ForEachChannel(
nodes = append(nodes, node)
return nil
},
)
require.NoError(t, err)
require.Len(t, nodeMap, 0)
err = graph.db.View(func(tx kvdb.RTx) error {
for _, node := range nodes {
err := node.ForEachChannel(
tx, func(tx kvdb.RTx, info *ChannelEdgeInfo,
policy *ChannelEdgePolicy,
policy2 *ChannelEdgePolicy) error {
@ -1220,10 +1231,15 @@ func TestGraphTraversalCacheable(t *testing.T) {
return nil
},
)
},
)
if err != nil {
return err
}
}
return nil
}, func() {})
require.NoError(t, err)
require.Len(t, nodeMap, 0)
require.Len(t, chanIndex, 0)
}
@ -3695,9 +3711,20 @@ func BenchmarkForEachChannel(b *testing.B) {
totalCapacity btcutil.Amount
maxHTLCs lnwire.MilliSatoshi
)
err := graph.ForEachNodeCacheable(
func(tx kvdb.RTx, n GraphCacheNode) error {
return n.ForEachChannel(
var nodes []GraphCacheNode
err = graph.ForEachNodeCacheable(
func(tx kvdb.RTx, node GraphCacheNode) error {
nodes = append(nodes, node)
return nil
},
)
require.NoError(b, err)
err = graph.db.View(func(tx kvdb.RTx) error {
for _, n := range nodes {
err := n.ForEachChannel(
tx, func(tx kvdb.RTx,
info *ChannelEdgeInfo,
policy *ChannelEdgePolicy,
@ -3715,8 +3742,13 @@ func BenchmarkForEachChannel(b *testing.B) {
return nil
},
)
},
)
if err != nil {
return err
}
}
return nil
}, func() {})
require.NoError(b, err)
}
}
@ -3760,3 +3792,52 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
require.Equal(t, numChans, 1)
}
// TestGraphLoading checks that a channel graph opened on top of an already
// populated database ends up with the same graph cache content as the graph
// instance that wrote the data.
func TestGraphLoading(t *testing.T) {
	// Everything is stored in a scratch directory that is removed again
	// once the test finishes.
	dbDir, err := ioutil.TempDir("", "channelgraph")
	require.NoError(t, err)
	defer os.RemoveAll(dbDir)

	backend, backendCleanup, err := kvdb.GetTestBackend(dbDir, "cgr")
	require.NoError(t, err)
	defer backend.Close()
	defer backendCleanup()

	// openGraph builds a channel graph on the shared backend using the
	// default options. Each call constructs a fresh instance.
	opts := DefaultOptions()
	openGraph := func() *ChannelGraph {
		g, err := NewChannelGraph(
			backend, opts.RejectCacheSize, opts.ChannelCacheSize,
			opts.BatchCommitInterval, opts.PreAllocCacheNumNodes,
			true, false,
		)
		require.NoError(t, err)

		return g
	}

	// Open the graph for the first time and fill it with test data.
	const (
		numNodes    = 100
		numChannels = 4
	)
	original := openGraph()
	fillTestGraph(t, original, numNodes, numChannels)

	// Open it a second time on the same backend. This should cause the
	// graph cache to be populated from the stored data.
	reloaded := openGraph()

	// Both instances must now hold identical cache content.
	require.Equal(
		t, original.graphCache.nodeChannels,
		reloaded.graphCache.nodeChannels,
	)
	require.Equal(
		t, original.graphCache.nodeFeatures,
		reloaded.graphCache.nodeFeatures,
	)
}

View File

@ -19,6 +19,11 @@ connection from the watch-only node.
In other words, freshly-installed LND can now be initialized with multiple
channels from an external (e.g. hardware) wallet *in a single transaction*.
## Database
* [Speed up graph cache loading on startup with
Postgres](https://github.com/lightningnetwork/lnd/pull/6111)
## Build System
* [Clean up Makefile by using go

2
go.mod
View File

@ -46,7 +46,7 @@ require (
github.com/lightningnetwork/lnd/cert v1.1.0
github.com/lightningnetwork/lnd/clock v1.1.0
github.com/lightningnetwork/lnd/healthcheck v1.2.0
github.com/lightningnetwork/lnd/kvdb v1.2.5
github.com/lightningnetwork/lnd/kvdb v1.3.0
github.com/lightningnetwork/lnd/queue v1.1.0
github.com/lightningnetwork/lnd/ticker v1.1.0
github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796

View File

@ -406,3 +406,9 @@ func (b *readWriteBucket) Prefetch(paths ...[]string) {
b.tx.stm.Prefetch(flattenMap(keys), flattenMap(ranges))
}
// ForAll is an optimized version of ForEach with the limitation that no
// additional queries can be executed within the callback.
//
// NOTE: This implementation forwards the callback to ForEach unchanged and
// returns whatever ForEach returns.
func (b *readWriteBucket) ForAll(cb func(k, v []byte) error) error {
	return b.ForEach(cb)
}

View File

@ -109,6 +109,12 @@ type ExtendedRBucket interface {
// Prefetch will attempt to prefetch all values under a path.
Prefetch(paths ...[]string)
// ForAll is an optimized version of ForEach.
//
// NOTE: ForAll differs from ForEach in that no additional queries can
// be executed within the callback.
ForAll(func(k, v []byte) error) error
}
// Prefetch will attempt to prefetch all values under a path from the passed
@ -119,6 +125,16 @@ func Prefetch(b RBucket, paths ...[]string) {
}
}
// ForAll iterates over all entries of the bucket. When the bucket implements
// ExtendedRBucket its ForAll method is used, otherwise the call falls back to
// the bucket's regular ForEach.
//
// NOTE: Callers must not execute additional queries within the callback, as
// the ForAll variant does not support that.
func ForAll(b RBucket, cb func(k, v []byte) error) error {
	extended, ok := b.(ExtendedRBucket)
	if !ok {
		// No optimized iteration available for this bucket.
		return b.ForEach(cb)
	}

	return extended.ForAll(cb)
}
// RootBucket is a wrapper to ExtendedRTx.RootBucket which does nothing if
// the implementation doesn't have ExtendedRTx.
func RootBucket(t RTx) RBucket {

View File

@ -427,3 +427,36 @@ func (b *readWriteBucket) Sequence() uint64 {
return uint64(seq)
}
// Prefetch is a no-op for this bucket implementation: the method body is
// empty, so the passed paths are ignored and nothing is loaded ahead of time.
func (b *readWriteBucket) Prefetch(paths ...[]string) {}
// ForAll is an optimized version of ForEach with the limitation that no
// additional queries can be executed within the callback.
//
// All direct children of the bucket are fetched with a single multi-row
// query, ordered by key, and the callback is invoked once per returned row.
// Iteration stops at the first error returned by the row scan or by the
// callback, and that error is returned. An error that terminates the row
// iteration itself is returned as well, so a prematurely ended scan is never
// reported as success.
func (b *readWriteBucket) ForAll(cb func(k, v []byte) error) error {
	rows, cancel, err := b.tx.Query(
		"SELECT key, value FROM " + b.table + " WHERE " +
			parentSelector(b.id) + " ORDER BY key",
	)
	if err != nil {
		return err
	}
	defer cancel()

	// Release the result set on every exit path, including early returns
	// from the loop below. Closing an already exhausted set is harmless.
	defer rows.Close()

	for rows.Next() {
		var key, value []byte
		if err := rows.Scan(&key, &value); err != nil {
			return err
		}

		if err := cb(key, value); err != nil {
			return err
		}
	}

	// Next returns false both when the rows are exhausted and when
	// fetching the next row failed. Only Err tells the two apart.
	return rows.Err()
}

View File

@ -4,6 +4,7 @@
package postgres
import (
"context"
"database/sql"
"sync"
@ -39,7 +40,16 @@ func newReadWriteTx(db *db, readOnly bool) (*readWriteTx, error) {
}
locker.Lock()
tx, err := db.db.Begin()
// Start the transaction. Don't use the timeout context because it would
// be applied to the transaction as a whole. If possible, mark the
// transaction as read-only to make sure that potential programming
// errors cannot cause changes to the database.
tx, err := db.db.BeginTx(
context.Background(),
&sql.TxOptions{
ReadOnly: readOnly,
},
)
if err != nil {
locker.Unlock()
return nil, err
@ -165,6 +175,21 @@ func (tx *readWriteTx) QueryRow(query string, args ...interface{}) (*sql.Row,
return tx.tx.QueryRowContext(ctx, query, args...), cancel
}
// Query runs a multi-row query inside the transaction, using a context
// obtained from the database's getTimeoutCtx.
//
// On success the rows are returned together with the context's cancel
// function, which the caller is expected to invoke once done with the rows.
// On failure the context is cancelled right away and a no-op function is
// returned in its place, so the second return value is always safe to call.
func (tx *readWriteTx) Query(query string, args ...interface{}) (*sql.Rows,
	func(), error) {

	ctx, cancel := tx.db.getTimeoutCtx()

	rows, err := tx.tx.QueryContext(ctx, query, args...)
	if err == nil {
		return rows, cancel, nil
	}

	// Nothing will consume the context anymore, release it immediately.
	cancel()

	return nil, func() {}, err
}
// Exec executes an Exec call with a timeout context.
func (tx *readWriteTx) Exec(query string, args ...interface{}) (sql.Result,
error) {

View File

@ -84,7 +84,35 @@ func TestPostgres(t *testing.T) {
},
{
name: "bucket for each",
test: testBucketForEach,
test: func(t *testing.T, db walletdb.DB) {
testBucketIterator(t, db, func(bucket walletdb.ReadWriteBucket,
callback func(key, val []byte) error) error {
return bucket.ForEach(callback)
})
},
expectedDb: m{
"test_kv": []m{
{"id": int64(1), "key": "apple", "parent_id": nil, "sequence": nil, "value": nil},
{"id": int64(2), "key": "banana", "parent_id": int64(1), "sequence": nil, "value": nil},
{"id": int64(3), "key": "key1", "parent_id": int64(1), "sequence": nil, "value": "val1"},
{"id": int64(4), "key": "key1", "parent_id": int64(2), "sequence": nil, "value": "val1"},
{"id": int64(5), "key": "key2", "parent_id": int64(1), "sequence": nil, "value": "val2"},
{"id": int64(6), "key": "key2", "parent_id": int64(2), "sequence": nil, "value": "val2"},
{"id": int64(7), "key": "key3", "parent_id": int64(1), "sequence": nil, "value": "val3"},
{"id": int64(8), "key": "key3", "parent_id": int64(2), "sequence": nil, "value": "val3"},
},
},
},
{
name: "bucket for all",
test: func(t *testing.T, db walletdb.DB) {
testBucketIterator(t, db, func(bucket walletdb.ReadWriteBucket,
callback func(key, val []byte) error) error {
return ForAll(bucket, callback)
})
},
expectedDb: m{
"test_kv": []m{
{"id": int64(1), "key": "apple", "parent_id": nil, "sequence": nil, "value": nil},

View File

@ -159,7 +159,20 @@ func testBucketDeletion(t *testing.T, db walletdb.DB) {
require.Nil(t, err)
}
// bucketIterator is the signature of a function that iterates over the given
// bucket and hands each key/value pair to the supplied callback. It lets the
// same test body be run against different iteration methods.
type bucketIterator = func(walletdb.ReadWriteBucket,
	func(key, val []byte) error) error
// testBucketForEach runs the shared bucket iteration test using the bucket's
// own ForEach method as the iterator.
func testBucketForEach(t *testing.T, db walletdb.DB) {
	// The method expression has exactly the bucketIterator signature: it
	// takes the bucket as first argument and calls ForEach on it.
	testBucketIterator(t, db, walletdb.ReadWriteBucket.ForEach)
}
func testBucketIterator(t *testing.T, db walletdb.DB,
iterator bucketIterator) {
err := Update(db, func(tx walletdb.ReadWriteTx) error {
// "apple"
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
@ -199,7 +212,7 @@ func testBucketForEach(t *testing.T, db walletdb.DB) {
require.Equal(t, expected, got)
got = make(map[string]string)
err = banana.ForEach(func(key, val []byte) error {
err = iterator(banana, func(key, val []byte) error {
got[string(key)] = string(val)
return nil
})

View File

@ -157,11 +157,6 @@ type ChannelGraphSource interface {
// ForEachNode is used to iterate over every node in the known graph.
ForEachNode(func(node *channeldb.LightningNode) error) error
// ForEachChannel is used to iterate over every channel in the known
// graph.
ForEachChannel(func(chanInfo *channeldb.ChannelEdgeInfo,
e1, e2 *channeldb.ChannelEdgePolicy) error) error
}
// PaymentAttemptDispatcher is used by the router to send payment attempts onto
@ -2541,16 +2536,6 @@ func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx,
})
}
// ForEachChannel is used to iterate over every known edge (channel) within our
// view of the channel graph.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (r *ChannelRouter) ForEachChannel(cb func(chanInfo *channeldb.ChannelEdgeInfo,
e1, e2 *channeldb.ChannelEdgePolicy) error) error {
return r.cfg.Graph.ForEachChannel(cb)
}
// AddProof updates the channel edge info with proof which is needed to
// properly announce the edge to the rest of the network.
//