mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-31 16:09:02 +02:00
Merge pull request #6111 from bottlepay/cache-loading
kvdb+channeldb: speed up graph cache
This commit is contained in:
commit
d67e6d5414
@ -216,15 +216,29 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
|
||||
startTime := time.Now()
|
||||
log.Debugf("Populating in-memory channel graph, this might " +
|
||||
"take a while...")
|
||||
|
||||
err := g.ForEachNodeCacheable(
|
||||
func(tx kvdb.RTx, node GraphCacheNode) error {
|
||||
return g.graphCache.AddNode(tx, node)
|
||||
g.graphCache.AddNodeFeatures(node)
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = g.ForEachChannel(func(info *ChannelEdgeInfo,
|
||||
policy1, policy2 *ChannelEdgePolicy) error {
|
||||
|
||||
g.graphCache.AddChannel(info, policy1, policy2)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("Finished populating in-memory channel graph (took "+
|
||||
"%v, %s)", time.Since(startTime), g.graphCache.Stats())
|
||||
}
|
||||
@ -232,6 +246,71 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// channelMapKey is the key structure used for storing channel edge policies.
|
||||
type channelMapKey struct {
|
||||
nodeKey route.Vertex
|
||||
chanID [8]byte
|
||||
}
|
||||
|
||||
// getChannelMap loads all channel edge policies from the database and stores
|
||||
// them in a map.
|
||||
func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) (
|
||||
map[channelMapKey]*ChannelEdgePolicy, error) {
|
||||
|
||||
// Create a map to store all channel edge policies.
|
||||
channelMap := make(map[channelMapKey]*ChannelEdgePolicy)
|
||||
|
||||
err := kvdb.ForAll(edges, func(k, edgeBytes []byte) error {
|
||||
// Skip embedded buckets.
|
||||
if bytes.Equal(k, edgeIndexBucket) ||
|
||||
bytes.Equal(k, edgeUpdateIndexBucket) ||
|
||||
bytes.Equal(k, zombieBucket) ||
|
||||
bytes.Equal(k, disabledEdgePolicyBucket) ||
|
||||
bytes.Equal(k, channelPointBucket) {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate key length.
|
||||
if len(k) != 33+8 {
|
||||
return fmt.Errorf("invalid edge key %x encountered", k)
|
||||
}
|
||||
|
||||
var key channelMapKey
|
||||
copy(key.nodeKey[:], k[:33])
|
||||
copy(key.chanID[:], k[33:])
|
||||
|
||||
// No need to deserialize unknown policy.
|
||||
if bytes.Equal(edgeBytes, unknownPolicy) {
|
||||
return nil
|
||||
}
|
||||
|
||||
edgeReader := bytes.NewReader(edgeBytes)
|
||||
edge, err := deserializeChanEdgePolicyRaw(
|
||||
edgeReader,
|
||||
)
|
||||
|
||||
switch {
|
||||
// If the db policy was missing an expected optional field, we
|
||||
// return nil as if the policy was unknown.
|
||||
case err == ErrEdgePolicyOptionalFieldNotFound:
|
||||
return nil
|
||||
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
channelMap[key] = edge
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return channelMap, nil
|
||||
}
|
||||
|
||||
var graphTopLevelBuckets = [][]byte{
|
||||
nodeBucket,
|
||||
edgeBucket,
|
||||
@ -332,50 +411,47 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) {
|
||||
func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo,
|
||||
*ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
|
||||
|
||||
// TODO(roasbeef): ptr map to reduce # of allocs? no duplicates
|
||||
|
||||
return kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||
// First, grab the node bucket. This will be used to populate
|
||||
// the Node pointers in each edge read from disk.
|
||||
nodes := tx.ReadBucket(nodeBucket)
|
||||
if nodes == nil {
|
||||
return ErrGraphNotFound
|
||||
}
|
||||
|
||||
// Next, grab the edge bucket which stores the edges, and also
|
||||
// the index itself so we can group the directed edges together
|
||||
// logically.
|
||||
return c.db.View(func(tx kvdb.RTx) error {
|
||||
edges := tx.ReadBucket(edgeBucket)
|
||||
if edges == nil {
|
||||
return ErrGraphNoEdgesFound
|
||||
}
|
||||
|
||||
// First, load all edges in memory indexed by node and channel
|
||||
// id.
|
||||
channelMap, err := c.getChannelMap(edges)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
|
||||
if edgeIndex == nil {
|
||||
return ErrGraphNoEdgesFound
|
||||
}
|
||||
|
||||
// For each edge pair within the edge index, we fetch each edge
|
||||
// itself and also the node information in order to fully
|
||||
// populated the object.
|
||||
return edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error {
|
||||
infoReader := bytes.NewReader(edgeInfoBytes)
|
||||
edgeInfo, err := deserializeChanEdgeInfo(infoReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
edgeInfo.db = c.db
|
||||
// Load edge index, recombine each channel with the policies
|
||||
// loaded above and invoke the callback.
|
||||
return kvdb.ForAll(edgeIndex, func(k, edgeInfoBytes []byte) error {
|
||||
var chanID [8]byte
|
||||
copy(chanID[:], k)
|
||||
|
||||
edge1, edge2, err := fetchChanEdgePolicies(
|
||||
edgeIndex, edges, nodes, chanID, c.db,
|
||||
)
|
||||
edgeInfoReader := bytes.NewReader(edgeInfoBytes)
|
||||
info, err := deserializeChanEdgeInfo(edgeInfoReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// With both edges read, execute the call back. IF this
|
||||
// function returns an error then the transaction will
|
||||
// be aborted.
|
||||
return cb(&edgeInfo, edge1, edge2)
|
||||
policy1 := channelMap[channelMapKey{
|
||||
nodeKey: info.NodeKey1Bytes,
|
||||
chanID: chanID,
|
||||
}]
|
||||
|
||||
policy2 := channelMap[channelMapKey{
|
||||
nodeKey: info.NodeKey2Bytes,
|
||||
chanID: chanID,
|
||||
}]
|
||||
|
||||
return cb(&info, policy1, policy2)
|
||||
})
|
||||
}, func() {})
|
||||
}
|
||||
@ -628,7 +704,6 @@ func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx,
|
||||
return ErrGraphNotFound
|
||||
}
|
||||
|
||||
cacheableNode := newGraphCacheNode(route.Vertex{}, nil)
|
||||
return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
|
||||
// If this is the source key, then we skip this
|
||||
// iteration as the value for this key is a pubKey
|
||||
@ -638,8 +713,8 @@ func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx,
|
||||
}
|
||||
|
||||
nodeReader := bytes.NewReader(nodeBytes)
|
||||
err := deserializeLightningNodeCacheable(
|
||||
nodeReader, cacheableNode,
|
||||
cacheableNode, err := deserializeLightningNodeCacheable(
|
||||
nodeReader,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -2740,8 +2815,6 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (
|
||||
type graphCacheNode struct {
|
||||
pubKeyBytes route.Vertex
|
||||
features *lnwire.FeatureVector
|
||||
|
||||
nodeScratch [8]byte
|
||||
}
|
||||
|
||||
// newGraphCacheNode returns a new cache optimized node.
|
||||
@ -4090,51 +4163,60 @@ func fetchLightningNode(nodeBucket kvdb.RBucket,
|
||||
return deserializeLightningNode(nodeReader)
|
||||
}
|
||||
|
||||
func deserializeLightningNodeCacheable(r io.Reader, node *graphCacheNode) error {
|
||||
func deserializeLightningNodeCacheable(r io.Reader) (*graphCacheNode, error) {
|
||||
// Always populate a feature vector, even if we don't have a node
|
||||
// announcement and short circuit below.
|
||||
node.features = lnwire.EmptyFeatureVector()
|
||||
node := newGraphCacheNode(
|
||||
route.Vertex{},
|
||||
lnwire.EmptyFeatureVector(),
|
||||
)
|
||||
|
||||
var nodeScratch [8]byte
|
||||
|
||||
// Skip ahead:
|
||||
// - LastUpdate (8 bytes)
|
||||
if _, err := r.Read(node.nodeScratch[:]); err != nil {
|
||||
return err
|
||||
if _, err := r.Read(nodeScratch[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(r, node.pubKeyBytes[:]); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read the node announcement flag.
|
||||
if _, err := r.Read(node.nodeScratch[:2]); err != nil {
|
||||
return err
|
||||
if _, err := r.Read(nodeScratch[:2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hasNodeAnn := byteOrder.Uint16(node.nodeScratch[:2])
|
||||
hasNodeAnn := byteOrder.Uint16(nodeScratch[:2])
|
||||
|
||||
// The rest of the data is optional, and will only be there if we got a
|
||||
// node announcement for this node.
|
||||
if hasNodeAnn == 0 {
|
||||
return nil
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// We did get a node announcement for this node, so we'll have the rest
|
||||
// of the data available.
|
||||
var rgb uint8
|
||||
if err := binary.Read(r, byteOrder, &rgb); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(r, byteOrder, &rgb); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(r, byteOrder, &rgb); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := wire.ReadVarString(r, 0); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.features.Decode(r)
|
||||
if err := node.features.Decode(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node, nil
|
||||
}
|
||||
|
||||
func deserializeLightningNode(r io.Reader) (LightningNode, error) {
|
||||
@ -4652,6 +4734,27 @@ func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy,
|
||||
func deserializeChanEdgePolicy(r io.Reader,
|
||||
nodes kvdb.RBucket) (*ChannelEdgePolicy, error) {
|
||||
|
||||
// Deserialize the policy. Note that in case an optional field is not
|
||||
// found, both an error and a populated policy object are returned.
|
||||
edge, deserializeErr := deserializeChanEdgePolicyRaw(r)
|
||||
if deserializeErr != nil &&
|
||||
deserializeErr != ErrEdgePolicyOptionalFieldNotFound {
|
||||
|
||||
return nil, deserializeErr
|
||||
}
|
||||
|
||||
// Populate full LightningNode struct.
|
||||
pub := edge.Node.PubKeyBytes[:]
|
||||
node, err := fetchLightningNode(nodes, pub)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to fetch node: %x, %v", pub, err)
|
||||
}
|
||||
edge.Node = &node
|
||||
|
||||
return edge, deserializeErr
|
||||
}
|
||||
|
||||
func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy, error) {
|
||||
edge := &ChannelEdgePolicy{}
|
||||
|
||||
var err error
|
||||
@ -4701,13 +4804,9 @@ func deserializeChanEdgePolicy(r io.Reader,
|
||||
if _, err := r.Read(pub[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node, err := fetchLightningNode(nodes, pub[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to fetch node: %x, %v",
|
||||
pub[:], err)
|
||||
edge.Node = &LightningNode{
|
||||
PubKeyBytes: pub,
|
||||
}
|
||||
edge.Node = &node
|
||||
|
||||
// We'll try and see if there are any opaque bytes left, if not, then
|
||||
// we'll ignore the EOF error and return the edge as is.
|
||||
|
@ -205,9 +205,8 @@ func (c *GraphCache) Stats() string {
|
||||
numChannels)
|
||||
}
|
||||
|
||||
// AddNode adds a graph node, including all the (directed) channels of that
|
||||
// node.
|
||||
func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error {
|
||||
// AddNodeFeatures adds a graph node and its features to the cache.
|
||||
func (c *GraphCache) AddNodeFeatures(node GraphCacheNode) {
|
||||
nodePubKey := node.PubKey()
|
||||
|
||||
// Only hold the lock for a short time. The `ForEachChannel()` below is
|
||||
@ -217,6 +216,12 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error {
|
||||
c.mtx.Lock()
|
||||
c.nodeFeatures[nodePubKey] = node.Features()
|
||||
c.mtx.Unlock()
|
||||
}
|
||||
|
||||
// AddNode adds a graph node, including all the (directed) channels of that
|
||||
// node.
|
||||
func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error {
|
||||
c.AddNodeFeatures(node)
|
||||
|
||||
return node.ForEachChannel(
|
||||
tx, func(tx kvdb.RTx, info *ChannelEdgeInfo,
|
||||
|
@ -1207,11 +1207,22 @@ func TestGraphTraversalCacheable(t *testing.T) {
|
||||
// Iterate through all the known channels within the graph DB by
|
||||
// iterating over each node, once again if the map is empty that
|
||||
// indicates that all edges have properly been reached.
|
||||
var nodes []GraphCacheNode
|
||||
err = graph.ForEachNodeCacheable(
|
||||
func(tx kvdb.RTx, node GraphCacheNode) error {
|
||||
delete(nodeMap, node.PubKey())
|
||||
|
||||
return node.ForEachChannel(
|
||||
nodes = append(nodes, node)
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nodeMap, 0)
|
||||
|
||||
err = graph.db.View(func(tx kvdb.RTx) error {
|
||||
for _, node := range nodes {
|
||||
err := node.ForEachChannel(
|
||||
tx, func(tx kvdb.RTx, info *ChannelEdgeInfo,
|
||||
policy *ChannelEdgePolicy,
|
||||
policy2 *ChannelEdgePolicy) error {
|
||||
@ -1220,10 +1231,15 @@ func TestGraphTraversalCacheable(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nodeMap, 0)
|
||||
require.Len(t, chanIndex, 0)
|
||||
}
|
||||
|
||||
@ -3695,9 +3711,20 @@ func BenchmarkForEachChannel(b *testing.B) {
|
||||
totalCapacity btcutil.Amount
|
||||
maxHTLCs lnwire.MilliSatoshi
|
||||
)
|
||||
err := graph.ForEachNodeCacheable(
|
||||
func(tx kvdb.RTx, n GraphCacheNode) error {
|
||||
return n.ForEachChannel(
|
||||
|
||||
var nodes []GraphCacheNode
|
||||
err = graph.ForEachNodeCacheable(
|
||||
func(tx kvdb.RTx, node GraphCacheNode) error {
|
||||
nodes = append(nodes, node)
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
require.NoError(b, err)
|
||||
|
||||
err = graph.db.View(func(tx kvdb.RTx) error {
|
||||
for _, n := range nodes {
|
||||
err := n.ForEachChannel(
|
||||
tx, func(tx kvdb.RTx,
|
||||
info *ChannelEdgeInfo,
|
||||
policy *ChannelEdgePolicy,
|
||||
@ -3715,8 +3742,13 @@ func BenchmarkForEachChannel(b *testing.B) {
|
||||
return nil
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
require.NoError(b, err)
|
||||
}
|
||||
}
|
||||
@ -3760,3 +3792,52 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
|
||||
|
||||
require.Equal(t, numChans, 1)
|
||||
}
|
||||
|
||||
// TestGraphLoading asserts that the cache is properly reconstructed after a
|
||||
// restart.
|
||||
func TestGraphLoading(t *testing.T) {
|
||||
// First, create a temporary directory to be used for the duration of
|
||||
// this test.
|
||||
tempDirName, err := ioutil.TempDir("", "channelgraph")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDirName)
|
||||
|
||||
// Next, create the graph for the first time.
|
||||
backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cgr")
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
defer backendCleanup()
|
||||
|
||||
opts := DefaultOptions()
|
||||
graph, err := NewChannelGraph(
|
||||
backend, opts.RejectCacheSize, opts.ChannelCacheSize,
|
||||
opts.BatchCommitInterval, opts.PreAllocCacheNumNodes,
|
||||
true, false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Populate the graph with test data.
|
||||
const numNodes = 100
|
||||
const numChannels = 4
|
||||
_, _ = fillTestGraph(t, graph, numNodes, numChannels)
|
||||
|
||||
// Recreate the graph. This should cause the graph cache to be
|
||||
// populated.
|
||||
graphReloaded, err := NewChannelGraph(
|
||||
backend, opts.RejectCacheSize, opts.ChannelCacheSize,
|
||||
opts.BatchCommitInterval, opts.PreAllocCacheNumNodes,
|
||||
true, false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert that the cache content is identical.
|
||||
require.Equal(
|
||||
t, graph.graphCache.nodeChannels,
|
||||
graphReloaded.graphCache.nodeChannels,
|
||||
)
|
||||
|
||||
require.Equal(
|
||||
t, graph.graphCache.nodeFeatures,
|
||||
graphReloaded.graphCache.nodeFeatures,
|
||||
)
|
||||
}
|
||||
|
@ -19,6 +19,11 @@ connection from the watch-only node.
|
||||
In other words, freshly-installed LND can now be initialized with multiple
|
||||
channels from an external (e.g. hardware) wallet *in a single transaction*.
|
||||
|
||||
## Database
|
||||
|
||||
* [Speed up graph cache loading on startup with
|
||||
Postgres](https://github.com/lightningnetwork/lnd/pull/6111)
|
||||
|
||||
## Build System
|
||||
|
||||
* [Clean up Makefile by using go
|
||||
|
2
go.mod
2
go.mod
@ -46,7 +46,7 @@ require (
|
||||
github.com/lightningnetwork/lnd/cert v1.1.0
|
||||
github.com/lightningnetwork/lnd/clock v1.1.0
|
||||
github.com/lightningnetwork/lnd/healthcheck v1.2.0
|
||||
github.com/lightningnetwork/lnd/kvdb v1.2.5
|
||||
github.com/lightningnetwork/lnd/kvdb v1.3.0
|
||||
github.com/lightningnetwork/lnd/queue v1.1.0
|
||||
github.com/lightningnetwork/lnd/ticker v1.1.0
|
||||
github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796
|
||||
|
@ -406,3 +406,9 @@ func (b *readWriteBucket) Prefetch(paths ...[]string) {
|
||||
|
||||
b.tx.stm.Prefetch(flattenMap(keys), flattenMap(ranges))
|
||||
}
|
||||
|
||||
// ForAll is an optimized version of ForEach with the limitation that no
|
||||
// additional queries can be executed within the callback.
|
||||
func (b *readWriteBucket) ForAll(cb func(k, v []byte) error) error {
|
||||
return b.ForEach(cb)
|
||||
}
|
||||
|
@ -109,6 +109,12 @@ type ExtendedRBucket interface {
|
||||
|
||||
// Prefetch will attempt to prefetch all values under a path.
|
||||
Prefetch(paths ...[]string)
|
||||
|
||||
// ForAll is an optimized version of ForEach.
|
||||
//
|
||||
// NOTE: ForAll differs from ForEach in that no additional queries can
|
||||
// be executed within the callback.
|
||||
ForAll(func(k, v []byte) error) error
|
||||
}
|
||||
|
||||
// Prefetch will attempt to prefetch all values under a path from the passed
|
||||
@ -119,6 +125,16 @@ func Prefetch(b RBucket, paths ...[]string) {
|
||||
}
|
||||
}
|
||||
|
||||
// ForAll is an optimized version of ForEach with the limitation that no
|
||||
// additional queries can be executed within the callback.
|
||||
func ForAll(b RBucket, cb func(k, v []byte) error) error {
|
||||
if bucket, ok := b.(ExtendedRBucket); ok {
|
||||
return bucket.ForAll(cb)
|
||||
}
|
||||
|
||||
return b.ForEach(cb)
|
||||
}
|
||||
|
||||
// RootBucket is a wrapper to ExtendedRTx.RootBucket which does nothing if
|
||||
// the implementation doesn't have ExtendedRTx.
|
||||
func RootBucket(t RTx) RBucket {
|
||||
|
@ -427,3 +427,36 @@ func (b *readWriteBucket) Sequence() uint64 {
|
||||
|
||||
return uint64(seq)
|
||||
}
|
||||
|
||||
// Prefetch will attempt to prefetch all values under a path from the passed
|
||||
// bucket.
|
||||
func (b *readWriteBucket) Prefetch(paths ...[]string) {}
|
||||
|
||||
// ForAll is an optimized version of ForEach with the limitation that no
|
||||
// additional queries can be executed within the callback.
|
||||
func (b *readWriteBucket) ForAll(cb func(k, v []byte) error) error {
|
||||
rows, cancel, err := b.tx.Query(
|
||||
"SELECT key, value FROM " + b.table + " WHERE " +
|
||||
parentSelector(b.id) + " ORDER BY key",
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
for rows.Next() {
|
||||
var key, value []byte
|
||||
|
||||
err := rows.Scan(&key, &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = cb(key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
|
||||
@ -39,7 +40,16 @@ func newReadWriteTx(db *db, readOnly bool) (*readWriteTx, error) {
|
||||
}
|
||||
locker.Lock()
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
// Start the transaction. Don't use the timeout context because it would
|
||||
// be applied to the transaction as a whole. If possible, mark the
|
||||
// transaction as read-only to make sure that potential programming
|
||||
// errors cannot cause changes to the database.
|
||||
tx, err := db.db.BeginTx(
|
||||
context.Background(),
|
||||
&sql.TxOptions{
|
||||
ReadOnly: readOnly,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
locker.Unlock()
|
||||
return nil, err
|
||||
@ -165,6 +175,21 @@ func (tx *readWriteTx) QueryRow(query string, args ...interface{}) (*sql.Row,
|
||||
return tx.tx.QueryRowContext(ctx, query, args...), cancel
|
||||
}
|
||||
|
||||
// Query executes a multi-row query call with a timeout context.
|
||||
func (tx *readWriteTx) Query(query string, args ...interface{}) (*sql.Rows,
|
||||
func(), error) {
|
||||
|
||||
ctx, cancel := tx.db.getTimeoutCtx()
|
||||
rows, err := tx.tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
cancel()
|
||||
|
||||
return nil, func() {}, err
|
||||
}
|
||||
|
||||
return rows, cancel, nil
|
||||
}
|
||||
|
||||
// Exec executes a Exec call with a timeout context.
|
||||
func (tx *readWriteTx) Exec(query string, args ...interface{}) (sql.Result,
|
||||
error) {
|
||||
|
@ -84,7 +84,35 @@ func TestPostgres(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "bucket for each",
|
||||
test: testBucketForEach,
|
||||
test: func(t *testing.T, db walletdb.DB) {
|
||||
testBucketIterator(t, db, func(bucket walletdb.ReadWriteBucket,
|
||||
callback func(key, val []byte) error) error {
|
||||
|
||||
return bucket.ForEach(callback)
|
||||
})
|
||||
},
|
||||
expectedDb: m{
|
||||
"test_kv": []m{
|
||||
{"id": int64(1), "key": "apple", "parent_id": nil, "sequence": nil, "value": nil},
|
||||
{"id": int64(2), "key": "banana", "parent_id": int64(1), "sequence": nil, "value": nil},
|
||||
{"id": int64(3), "key": "key1", "parent_id": int64(1), "sequence": nil, "value": "val1"},
|
||||
{"id": int64(4), "key": "key1", "parent_id": int64(2), "sequence": nil, "value": "val1"},
|
||||
{"id": int64(5), "key": "key2", "parent_id": int64(1), "sequence": nil, "value": "val2"},
|
||||
{"id": int64(6), "key": "key2", "parent_id": int64(2), "sequence": nil, "value": "val2"},
|
||||
{"id": int64(7), "key": "key3", "parent_id": int64(1), "sequence": nil, "value": "val3"},
|
||||
{"id": int64(8), "key": "key3", "parent_id": int64(2), "sequence": nil, "value": "val3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bucket for all",
|
||||
test: func(t *testing.T, db walletdb.DB) {
|
||||
testBucketIterator(t, db, func(bucket walletdb.ReadWriteBucket,
|
||||
callback func(key, val []byte) error) error {
|
||||
|
||||
return ForAll(bucket, callback)
|
||||
})
|
||||
},
|
||||
expectedDb: m{
|
||||
"test_kv": []m{
|
||||
{"id": int64(1), "key": "apple", "parent_id": nil, "sequence": nil, "value": nil},
|
||||
|
@ -159,7 +159,20 @@ func testBucketDeletion(t *testing.T, db walletdb.DB) {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
type bucketIterator = func(walletdb.ReadWriteBucket,
|
||||
func(key, val []byte) error) error
|
||||
|
||||
func testBucketForEach(t *testing.T, db walletdb.DB) {
|
||||
testBucketIterator(t, db, func(bucket walletdb.ReadWriteBucket,
|
||||
callback func(key, val []byte) error) error {
|
||||
|
||||
return bucket.ForEach(callback)
|
||||
})
|
||||
}
|
||||
|
||||
func testBucketIterator(t *testing.T, db walletdb.DB,
|
||||
iterator bucketIterator) {
|
||||
|
||||
err := Update(db, func(tx walletdb.ReadWriteTx) error {
|
||||
// "apple"
|
||||
apple, err := tx.CreateTopLevelBucket([]byte("apple"))
|
||||
@ -199,7 +212,7 @@ func testBucketForEach(t *testing.T, db walletdb.DB) {
|
||||
require.Equal(t, expected, got)
|
||||
|
||||
got = make(map[string]string)
|
||||
err = banana.ForEach(func(key, val []byte) error {
|
||||
err = iterator(banana, func(key, val []byte) error {
|
||||
got[string(key)] = string(val)
|
||||
return nil
|
||||
})
|
||||
|
@ -157,11 +157,6 @@ type ChannelGraphSource interface {
|
||||
|
||||
// ForEachNode is used to iterate over every node in the known graph.
|
||||
ForEachNode(func(node *channeldb.LightningNode) error) error
|
||||
|
||||
// ForEachChannel is used to iterate over every channel in the known
|
||||
// graph.
|
||||
ForEachChannel(func(chanInfo *channeldb.ChannelEdgeInfo,
|
||||
e1, e2 *channeldb.ChannelEdgePolicy) error) error
|
||||
}
|
||||
|
||||
// PaymentAttemptDispatcher is used by the router to send payment attempts onto
|
||||
@ -2541,16 +2536,6 @@ func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx,
|
||||
})
|
||||
}
|
||||
|
||||
// ForEachChannel is used to iterate over every known edge (channel) within our
|
||||
// view of the channel graph.
|
||||
//
|
||||
// NOTE: This method is part of the ChannelGraphSource interface.
|
||||
func (r *ChannelRouter) ForEachChannel(cb func(chanInfo *channeldb.ChannelEdgeInfo,
|
||||
e1, e2 *channeldb.ChannelEdgePolicy) error) error {
|
||||
|
||||
return r.cfg.Graph.ForEachChannel(cb)
|
||||
}
|
||||
|
||||
// AddProof updates the channel edge info with proof which is needed to
|
||||
// properly announce the edge to the rest of the network.
|
||||
//
|
||||
|
Loading…
x
Reference in New Issue
Block a user