channeldb+routing: add in-memory graph

Adds an in-memory channel graph cache for faster pathfinding.

Original PoC by: Joost Jager
Co-Authored by: Oliver Gugger
This commit is contained in:
Joost Jager 2021-09-21 19:18:20 +02:00 committed by Oliver Gugger
parent d6fa912188
commit 369c09be61
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
11 changed files with 652 additions and 133 deletions

View File

@ -179,6 +179,7 @@ type ChannelGraph struct {
cacheMu sync.RWMutex
rejectCache *rejectCache
chanCache *channelCache
graphCache *GraphCache
chanScheduler batch.Scheduler
nodeScheduler batch.Scheduler
@ -197,6 +198,7 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
db: db,
rejectCache: newRejectCache(rejectCacheSize),
chanCache: newChannelCache(chanCacheSize),
graphCache: NewGraphCache(),
}
g.chanScheduler = batch.NewTimeScheduler(
db, &g.cacheMu, batchCommitInterval,
@ -205,6 +207,19 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
db, nil, batchCommitInterval,
)
startTime := time.Now()
log.Debugf("Populating in-memory channel graph, this might take a " +
"while...")
err := g.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error {
return g.graphCache.AddNode(tx, &graphCacheNode{node})
})
if err != nil {
return nil, err
}
log.Debugf("Finished populating in-memory channel graph (took %v, %s)",
time.Since(startTime), g.graphCache.Stats())
return g, nil
}
@ -286,11 +301,6 @@ func initChannelGraph(db kvdb.Backend) error {
return nil
}
// Database returns a pointer to the underlying database.
func (c *ChannelGraph) Database() kvdb.Backend {
return c.db
}
// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges as since this is a directed graph, both the in/out edges are visited.
@ -354,23 +364,22 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo,
// ForEachNodeChannel iterates through all channels of a given node, executing the
// passed callback with an edge info structure and the policies of each end
// of the channel. The first edge policy is the outgoing edge *to* the
// the connecting node, while the second is the incoming edge *from* the
// connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// If the caller wishes to re-use an existing boltdb transaction, then it
// should be passed as the first argument. Otherwise the first argument should
// be nil and a fresh transaction will be created to execute the graph
// traversal.
func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
func (c *ChannelGraph) ForEachNodeChannel(node route.Vertex,
cb func(channel *DirectedChannel) error) error {
db := c.db
return c.graphCache.ForEachChannel(node, cb)
}
return nodeTraversal(tx, nodePub, db, cb)
// FetchNodeFeatures returns the features of a given node.
func (c *ChannelGraph) FetchNodeFeatures(
node route.Vertex) (*lnwire.FeatureVector, error) {
return c.graphCache.GetFeatures(node), nil
}
// DisabledChannelIDs returns the channel ids of disabled channels.
@ -549,6 +558,11 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode,
r := &batch.Request{
Update: func(tx kvdb.RwTx) error {
wNode := &graphCacheNode{node}
if err := c.graphCache.AddNode(tx, wNode); err != nil {
return err
}
return addLightningNode(tx, node)
},
}
@ -627,6 +641,8 @@ func (c *ChannelGraph) DeleteLightningNode(nodePub route.Vertex) error {
return ErrGraphNodeNotFound
}
c.graphCache.RemoveNode(nodePub)
return c.deleteLightningNode(nodes, nodePub[:])
}, func() {})
}
@ -753,6 +769,8 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo) error
return ErrEdgeAlreadyExist
}
c.graphCache.AddChannel(edge, nil, nil)
// Before we insert the channel into the database, we'll ensure that
// both nodes already exist in the channel graph. If either node
// doesn't, then we'll insert a "shell" node that just includes its
@ -952,6 +970,8 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error {
return ErrEdgeNotFound
}
c.graphCache.UpdateChannel(edge)
return putChanEdgeInfo(edgeIndex, edge, chanKey)
}, func() {})
}
@ -1037,7 +1057,7 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
// will be returned if that outpoint isn't known to be
// a channel. If no error is returned, then a channel
// was successfully pruned.
err = delChannelEdge(
err = c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes,
chanID, false, false,
)
@ -1088,6 +1108,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
c.chanCache.remove(channel.ChannelID)
}
log.Debugf("Pruned graph, cache now has %s", c.graphCache.Stats())
return chansClosed, nil
}
@ -1188,6 +1210,8 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket,
continue
}
c.graphCache.RemoveNode(nodePubKey)
// If we reach this point, then there are no longer any edges
// that connect this node, so we can delete it.
if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil {
@ -1286,7 +1310,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf
}
for _, k := range keys {
err = delChannelEdge(
err = c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes,
k, false, false,
)
@ -1394,7 +1418,9 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) {
// true, then when we mark these edges as zombies, we'll set up the keys such
// that we require the node that failed to send the fresh update to be the one
// that resurrects the channel from its zombie state.
func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...uint64) error {
func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool,
chanIDs ...uint64) error {
// TODO(roasbeef): possibly delete from node bucket if node has no more
// channels
// TODO(roasbeef): don't delete both edges?
@ -1427,7 +1453,7 @@ func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...u
var rawChanID [8]byte
for _, chanID := range chanIDs {
byteOrder.PutUint64(rawChanID[:], chanID)
err := delChannelEdge(
err := c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes,
rawChanID[:], true, strictZombiePruning,
)
@ -1556,7 +1582,9 @@ type ChannelEdge struct {
// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) {
func (c *ChannelGraph) ChanUpdatesInHorizon(startTime,
endTime time.Time) ([]ChannelEdge, error) {
// To ensure we don't return duplicate ChannelEdges, we'll use an
// additional map to keep track of the edges already seen to prevent
// re-adding it.
@ -1689,7 +1717,9 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) {
func (c *ChannelGraph) NodeUpdatesInHorizon(startTime,
endTime time.Time) ([]LightningNode, error) {
var nodesInHorizon []LightningNode
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
@ -2017,7 +2047,7 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64,
return nil
}
func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
nodes kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error {
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
@ -2025,6 +2055,11 @@ func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
return err
}
c.graphCache.RemoveChannel(
edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes,
edgeInfo.ChannelID,
)
// We'll also remove the entry in the edge update index bucket before
// we delete the edges themselves so we can access their last update
// times.
@ -2159,7 +2194,9 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy,
},
Update: func(tx kvdb.RwTx) error {
var err error
isUpdate1, err = updateEdgePolicy(tx, edge)
isUpdate1, err = updateEdgePolicy(
tx, edge, c.graphCache,
)
// Silence ErrEdgeNotFound so that the batch can
// succeed, but propagate the error via local state.
@ -2222,7 +2259,9 @@ func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy, isUpdate1 bool) {
// buckets using an existing database transaction. The returned boolean will be
// true if the updated policy belongs to node1, and false if the policy belonged
// to node2.
func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) {
func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy,
graphCache *GraphCache) (bool, error) {
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return false, ErrEdgeNotFound
@ -2270,6 +2309,14 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) {
return false, err
}
var (
fromNodePubKey route.Vertex
toNodePubKey route.Vertex
)
copy(fromNodePubKey[:], fromNode)
copy(toNodePubKey[:], toNode)
graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1)
return isUpdate1, nil
}
@ -2481,6 +2528,39 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (
return node, nil
}
// graphCacheNode is a struct that wraps a LightningNode in a way that it can be
// cached in the graph cache.
type graphCacheNode struct {
lnNode *LightningNode
}
// PubKey returns the node's public identity key.
func (w *graphCacheNode) PubKey() route.Vertex {
return w.lnNode.PubKeyBytes
}
// Features returns the node's features.
func (w *graphCacheNode) Features() *lnwire.FeatureVector {
return w.lnNode.Features
}
// ForEachChannel iterates through all channels of this node, executing the
// passed callback with an edge info structure and the policies of each end
// of the channel. The first edge policy is the outgoing edge *to* the
// connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
func (w *graphCacheNode) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
return w.lnNode.ForEachChannel(tx, cb)
}
var _ GraphCacheNode = (*graphCacheNode)(nil)
// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was lasted updated is returned along
@ -2621,7 +2701,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend,
// ForEachChannel iterates through all channels of this node, executing the
// passed callback with an edge info structure and the policies of each end
// of the channel. The first edge policy is the outgoing edge *to* the
// the connecting node, while the second is the incoming edge *from* the
// connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
@ -2632,7 +2712,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend,
// be nil and a fresh transaction will be created to execute the graph
// traversal.
func (l *LightningNode) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
nodePub := l.PubKeyBytes[:]
db := l.db
@ -3490,6 +3571,8 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64,
"bucket: %w", err)
}
c.graphCache.RemoveChannel(pubKey1, pubKey2, chanID)
return markEdgeZombie(zombieIndex, chanID, pubKey1, pubKey2)
})
if err != nil {
@ -3544,6 +3627,18 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error {
c.rejectCache.remove(chanID)
c.chanCache.remove(chanID)
// We need to add the channel back into our graph cache, otherwise we
// won't use it for path finding.
edgeInfos, err := c.FetchChanInfos([]uint64{chanID})
if err != nil {
return err
}
for _, edgeInfo := range edgeInfos {
c.graphCache.AddChannel(
edgeInfo.Info, edgeInfo.Policy1, edgeInfo.Policy2,
)
}
return nil
}

328
channeldb/graph_cache.go Normal file
View File

@ -0,0 +1,328 @@
package channeldb
import (
"fmt"
"sync"
"github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// GraphCacheNode is an interface for all the information the cache needs to know
// about a lightning node.
type GraphCacheNode interface {
// PubKey is the node's public identity key.
PubKey() route.Vertex
// Features returns the node's p2p features.
Features() *lnwire.FeatureVector
// ForEachChannel iterates through all channels of a given node,
// executing the passed callback with an edge info structure and the
// policies of each end of the channel. The first edge policy is the
// outgoing edge *to* the connecting node, while the second is the
// incoming edge *from* the connecting node. If the callback returns an
// error, then the iteration is halted with the error propagated back up
// to the caller.
ForEachChannel(kvdb.RTx,
func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error
}
// DirectedChannel is a type that stores the channel information as seen from
// one side of the channel.
type DirectedChannel struct {
// ChannelID is the unique identifier of this channel.
ChannelID uint64
// IsNode1 indicates if this is the node with the smaller public key.
IsNode1 bool
// OtherNode is the public key of the node on the other end of this
// channel.
OtherNode route.Vertex
// Capacity is the announced capacity of this channel in satoshis.
Capacity btcutil.Amount
// OutPolicy is the outgoing policy from this node *to* the other node.
OutPolicy *ChannelEdgePolicy
// InPolicy is the incoming policy *from* the other node to this node.
InPolicy *ChannelEdgePolicy
}
// GraphCache is a type that holds a minimal set of information of the public
// channel graph that can be used for pathfinding.
type GraphCache struct {
nodeChannels map[route.Vertex]map[uint64]*DirectedChannel
nodeFeatures map[route.Vertex]*lnwire.FeatureVector
mtx sync.RWMutex
}
// NewGraphCache creates a new graphCache.
func NewGraphCache() *GraphCache {
return &GraphCache{
nodeChannels: make(map[route.Vertex]map[uint64]*DirectedChannel),
nodeFeatures: make(map[route.Vertex]*lnwire.FeatureVector),
}
}
// Stats returns statistics about the current cache size.
func (c *GraphCache) Stats() string {
c.mtx.RLock()
defer c.mtx.RUnlock()
numChannels := 0
for node := range c.nodeChannels {
numChannels += len(c.nodeChannels[node])
}
return fmt.Sprintf("num_node_features=%d, num_nodes=%d, "+
"num_channels=%d", len(c.nodeFeatures), len(c.nodeChannels),
numChannels)
}
// AddNode adds a graph node, including all the (directed) channels of that
// node.
func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error {
nodePubKey := node.PubKey()
// Only hold the lock for a short time. The `ForEachChannel()` below is
// possibly slow as it has to go to the backend, so we can unlock
// between the calls. And the AddChannel() method will acquire its own
// lock anyway.
c.mtx.Lock()
c.nodeFeatures[nodePubKey] = node.Features()
c.mtx.Unlock()
return node.ForEachChannel(
tx, func(tx kvdb.RTx, info *ChannelEdgeInfo,
outPolicy *ChannelEdgePolicy,
inPolicy *ChannelEdgePolicy) error {
c.AddChannel(info, outPolicy, inPolicy)
return nil
},
)
}
// AddChannel adds a non-directed channel, meaning that the order of policy 1
// and policy 2 does not matter, the directionality is extracted from the info
// and policy flags automatically. The policy will be set as the outgoing policy
// on one node and the incoming policy on the peer's side.
func (c *GraphCache) AddChannel(info *ChannelEdgeInfo,
policy1 *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) {
if info == nil {
return
}
if policy1 != nil && policy1.IsDisabled() &&
policy2 != nil && policy2.IsDisabled() {
return
}
// Create the edge entry for both nodes.
c.mtx.Lock()
c.updateOrAddEdge(info.NodeKey1Bytes, &DirectedChannel{
ChannelID: info.ChannelID,
IsNode1: true,
OtherNode: info.NodeKey2Bytes,
Capacity: info.Capacity,
})
c.updateOrAddEdge(info.NodeKey2Bytes, &DirectedChannel{
ChannelID: info.ChannelID,
IsNode1: false,
OtherNode: info.NodeKey1Bytes,
Capacity: info.Capacity,
})
c.mtx.Unlock()
// The policy's node is always the to_node. So if policy 1 has to_node
// of node 2 then we have the policy 1 as seen from node 1.
if policy1 != nil {
fromNode, toNode := info.NodeKey1Bytes, info.NodeKey2Bytes
if policy1.Node.PubKeyBytes != info.NodeKey2Bytes {
fromNode, toNode = toNode, fromNode
}
isEdge1 := policy1.ChannelFlags&lnwire.ChanUpdateDirection == 0
c.UpdatePolicy(policy1, fromNode, toNode, isEdge1)
}
if policy2 != nil {
fromNode, toNode := info.NodeKey2Bytes, info.NodeKey1Bytes
if policy2.Node.PubKeyBytes != info.NodeKey1Bytes {
fromNode, toNode = toNode, fromNode
}
isEdge1 := policy2.ChannelFlags&lnwire.ChanUpdateDirection == 0
c.UpdatePolicy(policy2, fromNode, toNode, isEdge1)
}
}
// updateOrAddEdge makes sure the edge information for a node is either updated
// if it already exists or is added to that node's list of channels.
func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) {
if len(c.nodeChannels[node]) == 0 {
c.nodeChannels[node] = make(map[uint64]*DirectedChannel)
}
c.nodeChannels[node][edge.ChannelID] = edge
}
// UpdatePolicy updates a single policy on both the from and to node. The order
// of the from and to node is not strictly important. But we assume that a
// channel edge was added beforehand so that the directed channel struct already
// exists in the cache.
func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode,
toNode route.Vertex, edge1 bool) {
// If a policy's node is nil, we can't cache it yet as that would lead
// to problems in pathfinding.
if policy.Node == nil {
// TODO(guggero): Fix this problem!
log.Warnf("Cannot cache policy because of missing node (from "+
"%x to %x)", fromNode[:], toNode[:])
return
}
c.mtx.Lock()
defer c.mtx.Unlock()
updatePolicy := func(nodeKey route.Vertex) {
if len(c.nodeChannels[nodeKey]) == 0 {
return
}
channel, ok := c.nodeChannels[nodeKey][policy.ChannelID]
if !ok {
return
}
// Edge 1 is defined as the policy for the direction of node1 to
// node2.
switch {
// This is node 1, and it is edge 1, so this is the outgoing
// policy for node 1.
case channel.IsNode1 && edge1:
channel.OutPolicy = policy
// This is node 2, and it is edge 2, so this is the outgoing
// policy for node 2.
case !channel.IsNode1 && !edge1:
channel.OutPolicy = policy
// The other two cases left mean it's the inbound policy for the
// node.
default:
channel.InPolicy = policy
}
}
updatePolicy(fromNode)
updatePolicy(toNode)
}
// RemoveNode completely removes a node and all its channels (including the
// peer's side).
func (c *GraphCache) RemoveNode(node route.Vertex) {
c.mtx.Lock()
defer c.mtx.Unlock()
delete(c.nodeFeatures, node)
// First remove all channels from the other nodes' lists.
for _, channel := range c.nodeChannels[node] {
c.removeChannelIfFound(channel.OtherNode, channel.ChannelID)
}
// Then remove our whole node completely.
delete(c.nodeChannels, node)
}
// RemoveChannel removes a single channel between two nodes.
func (c *GraphCache) RemoveChannel(node1, node2 route.Vertex, chanID uint64) {
c.mtx.Lock()
defer c.mtx.Unlock()
// Remove that one channel from both sides.
c.removeChannelIfFound(node1, chanID)
c.removeChannelIfFound(node2, chanID)
}
// removeChannelIfFound removes a single channel from one side.
func (c *GraphCache) removeChannelIfFound(node route.Vertex, chanID uint64) {
if len(c.nodeChannels[node]) == 0 {
return
}
delete(c.nodeChannels[node], chanID)
}
// UpdateChannel updates the channel edge information for a specific edge. We
// expect the edge to already exist and be known. If it does not yet exist, this
// call is a no-op.
func (c *GraphCache) UpdateChannel(info *ChannelEdgeInfo) {
c.mtx.Lock()
defer c.mtx.Unlock()
if len(c.nodeChannels[info.NodeKey1Bytes]) == 0 ||
len(c.nodeChannels[info.NodeKey2Bytes]) == 0 {
return
}
channel, ok := c.nodeChannels[info.NodeKey1Bytes][info.ChannelID]
if ok {
// We only expect to be called when the channel is already
// known.
channel.Capacity = info.Capacity
channel.OtherNode = info.NodeKey2Bytes
}
channel, ok = c.nodeChannels[info.NodeKey2Bytes][info.ChannelID]
if ok {
channel.Capacity = info.Capacity
channel.OtherNode = info.NodeKey1Bytes
}
}
// ForEachChannel invokes the given callback for each channel of the given node.
func (c *GraphCache) ForEachChannel(node route.Vertex,
cb func(channel *DirectedChannel) error) error {
c.mtx.RLock()
defer c.mtx.RUnlock()
channels, ok := c.nodeChannels[node]
if !ok {
return nil
}
for _, channel := range channels {
if err := cb(channel); err != nil {
return err
}
}
return nil
}
// GetFeatures returns the features of the node with the given ID.
func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector {
c.mtx.RLock()
defer c.mtx.RUnlock()
features, ok := c.nodeFeatures[node]
if !ok || features == nil {
// The router expects the features to never be nil, so we return
// an empty feature set instead.
return lnwire.EmptyFeatureVector()
}
return features
}

View File

@ -0,0 +1,110 @@
package channeldb
import (
"encoding/hex"
"testing"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/stretchr/testify/require"
)
var (
pubKey1Bytes, _ = hex.DecodeString(
"0248f5cba4c6da2e4c9e01e81d1404dfac0cbaf3ee934a4fc117d2ea9a64" +
"22c91d",
)
pubKey2Bytes, _ = hex.DecodeString(
"038155ba86a8d3b23c806c855097ca5c9fa0f87621f1e7a7d2835ad057f6" +
"f4484f",
)
pubKey1, _ = route.NewVertexFromBytes(pubKey1Bytes)
pubKey2, _ = route.NewVertexFromBytes(pubKey2Bytes)
)
type node struct {
pubKey route.Vertex
features *lnwire.FeatureVector
edgeInfos []*ChannelEdgeInfo
outPolicies []*ChannelEdgePolicy
inPolicies []*ChannelEdgePolicy
}
func (n *node) PubKey() route.Vertex {
return n.pubKey
}
func (n *node) Features() *lnwire.FeatureVector {
return n.features
}
func (n *node) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
for idx := range n.edgeInfos {
err := cb(
tx, n.edgeInfos[idx], n.outPolicies[idx],
n.inPolicies[idx],
)
if err != nil {
return err
}
}
return nil
}
// TestGraphCacheAddNode tests that a channel going from node A to node B can be
// cached correctly, independent of the direction we add the channel as.
func TestGraphCacheAddNode(t *testing.T) {
runTest := func(nodeA, nodeB route.Vertex) {
t.Helper()
outPolicy1 := &ChannelEdgePolicy{
ChannelID: 1000,
ChannelFlags: 0,
Node: &LightningNode{
PubKeyBytes: nodeB,
},
}
inPolicy1 := &ChannelEdgePolicy{
ChannelID: 1000,
ChannelFlags: 1,
Node: &LightningNode{
PubKeyBytes: nodeA,
},
}
node := &node{
pubKey: nodeA,
features: lnwire.EmptyFeatureVector(),
edgeInfos: []*ChannelEdgeInfo{{
ChannelID: 1000,
// Those are direction independent!
NodeKey1Bytes: pubKey1,
NodeKey2Bytes: pubKey2,
Capacity: 500,
}},
outPolicies: []*ChannelEdgePolicy{outPolicy1},
inPolicies: []*ChannelEdgePolicy{inPolicy1},
}
cache := NewGraphCache()
require.NoError(t, cache.AddNode(nil, node))
fromChannels := cache.nodeChannels[nodeA]
toChannels := cache.nodeChannels[nodeB]
require.Len(t, fromChannels, 1)
require.Len(t, toChannels, 1)
require.Equal(t, outPolicy1, fromChannels[0].OutPolicy)
require.Equal(t, inPolicy1, fromChannels[0].InPolicy)
require.Equal(t, inPolicy1, toChannels[0].OutPolicy)
require.Equal(t, outPolicy1, toChannels[0].InPolicy)
}
runTest(pubKey1, pubKey2)
runTest(pubKey2, pubKey1)
}

View File

@ -2,7 +2,6 @@ package routing
import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
@ -12,8 +11,7 @@ import (
type routingGraph interface {
// forEachNodeChannel calls the callback for every channel of the given node.
forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy,
*channeldb.ChannelEdgePolicy) error) error
cb func(channel *channeldb.DirectedChannel) error) error
// sourceNode returns the source node of the graph.
sourceNode() route.Vertex
@ -26,7 +24,6 @@ type routingGraph interface {
// database.
type dbRoutingTx struct {
graph *channeldb.ChannelGraph
tx kvdb.RTx
source route.Vertex
}
@ -38,37 +35,19 @@ func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) {
return nil, err
}
tx, err := graph.Database().BeginReadTx()
if err != nil {
return nil, err
}
return &dbRoutingTx{
graph: graph,
tx: tx,
source: sourceNode.PubKeyBytes,
}, nil
}
// close closes the underlying db transaction.
func (g *dbRoutingTx) close() error {
return g.tx.Rollback()
}
// forEachNodeChannel calls the callback for every channel of the given node.
//
// NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy,
*channeldb.ChannelEdgePolicy) error) error {
cb func(channel *channeldb.DirectedChannel) error) error {
txCb := func(_ kvdb.RTx, info *channeldb.ChannelEdgeInfo,
p1, p2 *channeldb.ChannelEdgePolicy) error {
return cb(info, p1, p2)
}
return g.graph.ForEachNodeChannel(g.tx, nodePub[:], txCb)
return g.graph.ForEachNodeChannel(nodePub, cb)
}
// sourceNode returns the source node of the graph.
@ -85,20 +64,5 @@ func (g *dbRoutingTx) sourceNode() route.Vertex {
func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {
targetNode, err := g.graph.FetchLightningNode(nodePub)
switch err {
// If the node exists and has features, return them directly.
case nil:
return targetNode.Features, nil
// If we couldn't find a node announcement, populate a blank feature
// vector.
case channeldb.ErrGraphNodeNotFound:
return lnwire.EmptyFeatureVector(), nil
// Otherwise bubble the error up.
default:
return nil, err
}
return g.graph.FetchNodeFeatures(nodePub)
}

View File

@ -159,8 +159,7 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte,
//
// NOTE: Part of the routingGraph interface.
func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy,
*channeldb.ChannelEdgePolicy) error) error {
cb func(channel *channeldb.DirectedChannel) error) error {
// Look up the mock node.
node, ok := m.nodes[nodePub]
@ -171,36 +170,38 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex,
// Iterate over all of its channels.
for peer, channel := range node.channels {
// Lexicographically sort the pubkeys.
var node1, node2 route.Vertex
var node1 route.Vertex
if bytes.Compare(nodePub[:], peer[:]) == -1 {
node1, node2 = peer, nodePub
node1 = peer
} else {
node1, node2 = nodePub, peer
node1 = nodePub
}
peerNode := m.nodes[peer]
// Call the per channel callback.
err := cb(
&channeldb.ChannelEdgeInfo{
NodeKey1Bytes: node1,
NodeKey2Bytes: node2,
},
&channeldb.ChannelEdgePolicy{
&channeldb.DirectedChannel{
ChannelID: channel.id,
Node: &channeldb.LightningNode{
PubKeyBytes: peer,
Features: lnwire.EmptyFeatureVector(),
IsNode1: nodePub == node1,
OtherNode: peer,
Capacity: channel.capacity,
OutPolicy: &channeldb.ChannelEdgePolicy{
ChannelID: channel.id,
Node: &channeldb.LightningNode{
PubKeyBytes: peer,
Features: lnwire.EmptyFeatureVector(),
},
FeeBaseMSat: node.baseFee,
},
FeeBaseMSat: node.baseFee,
},
&channeldb.ChannelEdgePolicy{
ChannelID: channel.id,
Node: &channeldb.LightningNode{
PubKeyBytes: nodePub,
Features: lnwire.EmptyFeatureVector(),
InPolicy: &channeldb.ChannelEdgePolicy{
ChannelID: channel.id,
Node: &channeldb.LightningNode{
PubKeyBytes: nodePub,
Features: lnwire.EmptyFeatureVector(),
},
FeeBaseMSat: peerNode.baseFee,
},
FeeBaseMSat: peerNode.baseFee,
},
)
if err != nil {

View File

@ -359,14 +359,12 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {
var max, total lnwire.MilliSatoshi
cb := func(edgeInfo *channeldb.ChannelEdgeInfo, outEdge,
_ *channeldb.ChannelEdgePolicy) error {
if outEdge == nil {
cb := func(channel *channeldb.DirectedChannel) error {
if channel.OutPolicy == nil {
return nil
}
chanID := outEdge.ChannelID
chanID := channel.ChannelID
// Enforce outgoing channel restriction.
if outgoingChans != nil {
@ -381,9 +379,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
// This can happen when a channel is added to the graph after
// we've already queried the bandwidth hints.
if !ok {
bandwidth = lnwire.NewMSatFromSatoshis(
edgeInfo.Capacity,
)
bandwidth = lnwire.NewMSatFromSatoshis(channel.Capacity)
}
if bandwidth > max {
@ -889,7 +885,8 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Determine the next hop forward using the next map.
currentNodeWithDist, ok := distance[currentNode]
if !ok {
// If the node doesnt have a next hop it means we didn't find a path.
// If the node doesn't have a next hop it means we
// didn't find a path.
return nil, errNoPathFound
}

View File

@ -304,6 +304,16 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
}
}
aliasForNode := func(node route.Vertex) string {
for alias, pubKey := range aliasMap {
if pubKey == node {
return alias
}
}
return ""
}
// With all the vertexes inserted, we can now insert the edges into the
// test graph.
for _, edge := range g.Edges {
@ -353,10 +363,17 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
return nil, err
}
channelFlags := lnwire.ChanUpdateChanFlags(edge.ChannelFlags)
isUpdate1 := channelFlags&lnwire.ChanUpdateDirection == 0
targetNode := edgeInfo.NodeKey1Bytes
if isUpdate1 {
targetNode = edgeInfo.NodeKey2Bytes
}
edgePolicy := &channeldb.ChannelEdgePolicy{
SigBytes: testSig.Serialize(),
MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags),
ChannelFlags: lnwire.ChanUpdateChanFlags(edge.ChannelFlags),
ChannelFlags: channelFlags,
ChannelID: edge.ChannelID,
LastUpdate: testTime,
TimeLockDelta: edge.Expiry,
@ -364,6 +381,10 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
MaxHTLC: lnwire.MilliSatoshi(edge.MaxHTLC),
FeeBaseMSat: lnwire.MilliSatoshi(edge.FeeBaseMsat),
FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate),
Node: &channeldb.LightningNode{
Alias: aliasForNode(targetNode),
PubKeyBytes: targetNode,
},
}
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, err
@ -635,6 +656,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
channelFlags |= lnwire.ChanUpdateDisabled
}
node2Features := lnwire.EmptyFeatureVector()
if node2.testChannelPolicy != nil {
node2Features = node2.Features
}
edgePolicy := &channeldb.ChannelEdgePolicy{
SigBytes: testSig.Serialize(),
MessageFlags: msgFlags,
@ -646,6 +672,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
MaxHTLC: node1.MaxHTLC,
FeeBaseMSat: node1.FeeBaseMsat,
FeeProportionalMillionths: node1.FeeRate,
Node: &channeldb.LightningNode{
Alias: node2.Alias,
PubKeyBytes: node2Vertex,
Features: node2Features,
},
}
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, err
@ -663,6 +694,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
}
channelFlags |= lnwire.ChanUpdateDirection
node1Features := lnwire.EmptyFeatureVector()
if node1.testChannelPolicy != nil {
node1Features = node1.Features
}
edgePolicy := &channeldb.ChannelEdgePolicy{
SigBytes: testSig.Serialize(),
MessageFlags: msgFlags,
@ -674,6 +710,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
MaxHTLC: node2.MaxHTLC,
FeeBaseMSat: node2.FeeBaseMsat,
FeeProportionalMillionths: node2.FeeRate,
Node: &channeldb.LightningNode{
Alias: node1.Alias,
PubKeyBytes: node1Vertex,
Features: node1Features,
},
}
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, err
@ -2980,12 +3021,6 @@ func dbFindPath(graph *channeldb.ChannelGraph,
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
return findPath(
&graphParams{

View File

@ -47,12 +47,7 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
if err != nil {
return nil, nil, err
}
return routingTx, func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}, nil
return routingTx, func() {}, nil
}
// NewPaymentSession creates a new payment session backed by the latest prune

View File

@ -1756,12 +1756,6 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
path, err := findPath(
&graphParams{
@ -2763,12 +2757,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
// Traverse hops backwards to accumulate fees in the running amounts.
source := r.selfNode.PubKeyBytes

View File

@ -1393,6 +1393,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1,
FeeBaseMSat: 10,
FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey2Bytes,
},
}
edgePolicy.ChannelFlags = 0
@ -1409,6 +1412,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1,
FeeBaseMSat: 10,
FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey1Bytes,
},
}
edgePolicy.ChannelFlags = 1
@ -1490,6 +1496,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1,
FeeBaseMSat: 10,
FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey2Bytes,
},
}
edgePolicy.ChannelFlags = 0
@ -1505,6 +1514,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1,
FeeBaseMSat: 10,
FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey1Bytes,
},
}
edgePolicy.ChannelFlags = 1

View File

@ -69,24 +69,18 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex,
// addGraphPolicies adds all policies that are known for the toNode in the
// graph.
func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error {
cb := func(edgeInfo *channeldb.ChannelEdgeInfo, _,
inEdge *channeldb.ChannelEdgePolicy) error {
cb := func(channel *channeldb.DirectedChannel) error {
// If there is no edge policy for this candidate node, skip.
// Note that we are searching backwards so this node would have
// come prior to the pivot node in the route.
if inEdge == nil {
if channel.InPolicy == nil {
return nil
}
// The node on the other end of this channel is the from node.
fromNode, err := edgeInfo.OtherNodeKeyBytes(u.toNode[:])
if err != nil {
return err
}
// Add this policy to the unified policies map.
u.addPolicy(fromNode, inEdge, edgeInfo.Capacity)
u.addPolicy(
channel.OtherNode, channel.InPolicy, channel.Capacity,
)
return nil
}