From 369c09be6152b76a24915050a8a5fe6bccf2b8f0 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Tue, 21 Sep 2021 19:18:20 +0200 Subject: [PATCH] channeldb+routing: add in-memory graph Adds an in-memory channel graph cache for faster pathfinding. Original PoC by: Joost Jager Co-Authored by: Oliver Gugger --- channeldb/graph.go | 149 +++++++++++--- channeldb/graph_cache.go | 328 ++++++++++++++++++++++++++++++ channeldb/graph_cache_test.go | 110 ++++++++++ routing/graph.go | 44 +--- routing/mock_graph_test.go | 43 ++-- routing/pathfind.go | 15 +- routing/pathfind_test.go | 49 ++++- routing/payment_session_source.go | 7 +- routing/router.go | 12 -- routing/router_test.go | 12 ++ routing/unified_policies.go | 16 +- 11 files changed, 652 insertions(+), 133 deletions(-) create mode 100644 channeldb/graph_cache.go create mode 100644 channeldb/graph_cache_test.go diff --git a/channeldb/graph.go b/channeldb/graph.go index 92b04bd93..e3ec83113 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -179,6 +179,7 @@ type ChannelGraph struct { cacheMu sync.RWMutex rejectCache *rejectCache chanCache *channelCache + graphCache *GraphCache chanScheduler batch.Scheduler nodeScheduler batch.Scheduler @@ -197,6 +198,7 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, db: db, rejectCache: newRejectCache(rejectCacheSize), chanCache: newChannelCache(chanCacheSize), + graphCache: NewGraphCache(), } g.chanScheduler = batch.NewTimeScheduler( db, &g.cacheMu, batchCommitInterval, @@ -205,6 +207,19 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, db, nil, batchCommitInterval, ) + startTime := time.Now() + log.Debugf("Populating in-memory channel graph, this might take a " + + "while...") + err := g.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error { + return g.graphCache.AddNode(tx, &graphCacheNode{node}) + }) + if err != nil { + return nil, err + } + + log.Debugf("Finished populating in-memory channel graph (took %v, %s)", + time.Since(startTime), g.graphCache.Stats()) + return g, nil } @@ -286,11 +301,6 @@ func initChannelGraph(db kvdb.Backend) error { return nil } -// Database returns a pointer to the underlying database. -func (c *ChannelGraph) Database() kvdb.Backend { - return c.db -} - // ForEachChannel iterates through all the channel edges stored within the // graph and invokes the passed callback for each edge. The callback takes two // edges as since this is a directed graph, both the in/out edges are visited. @@ -354,23 +364,22 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, // ForEachNodeChannel iterates through all channels of a given node, executing the // passed callback with an edge info structure and the policies of each end // of the channel. The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the +// connecting node, while the second is the incoming edge *from* the // connecting node. If the callback returns an error, then the iteration is // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal. -func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte, - cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, - *ChannelEdgePolicy) error) error { +func (c *ChannelGraph) ForEachNodeChannel(node route.Vertex, + cb func(channel *DirectedChannel) error) error { - db := c.db + return c.graphCache.ForEachChannel(node, cb) +} - return nodeTraversal(tx, nodePub, db, cb) +// FetchNodeFeatures returns the features of a given node. +func (c *ChannelGraph) FetchNodeFeatures( + node route.Vertex) (*lnwire.FeatureVector, error) { + + return c.graphCache.GetFeatures(node), nil } // DisabledChannelIDs returns the channel ids of disabled channels. @@ -549,6 +558,11 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode, r := &batch.Request{ Update: func(tx kvdb.RwTx) error { + wNode := &graphCacheNode{node} + if err := c.graphCache.AddNode(tx, wNode); err != nil { + return err + } + return addLightningNode(tx, node) }, } @@ -627,6 +641,8 @@ func (c *ChannelGraph) DeleteLightningNode(nodePub route.Vertex) error { return ErrGraphNodeNotFound } + c.graphCache.RemoveNode(nodePub) + return c.deleteLightningNode(nodes, nodePub[:]) }, func() {}) } @@ -753,6 +769,8 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo) error return ErrEdgeAlreadyExist } + c.graphCache.AddChannel(edge, nil, nil) + // Before we insert the channel into the database, we'll ensure that // both nodes already exist in the channel graph. If either node // doesn't, then we'll insert a "shell" node that just includes its @@ -952,6 +970,8 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error { return ErrEdgeNotFound } + c.graphCache.UpdateChannel(edge) + return putChanEdgeInfo(edgeIndex, edge, chanKey) }, func() {}) } @@ -1037,7 +1057,7 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, // will be returned if that outpoint isn't known to be // a channel. If no error is returned, then a channel // was successfully pruned. - err = delChannelEdge( + err = c.delChannelEdge( edges, edgeIndex, chanIndex, zombieIndex, nodes, chanID, false, false, ) @@ -1088,6 +1108,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, c.chanCache.remove(channel.ChannelID) } + log.Debugf("Pruned graph, cache now has %s", c.graphCache.Stats()) + return chansClosed, nil } @@ -1188,6 +1210,8 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, continue } + c.graphCache.RemoveNode(nodePubKey) + // If we reach this point, then there are no longer any edges // that connect this node, so we can delete it. if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil { @@ -1286,7 +1310,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf } for _, k := range keys { - err = delChannelEdge( + err = c.delChannelEdge( edges, edgeIndex, chanIndex, zombieIndex, nodes, k, false, false, ) @@ -1394,7 +1418,9 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { // true, then when we mark these edges as zombies, we'll set up the keys such // that we require the node that failed to send the fresh update to be the one // that resurrects the channel from its zombie state. -func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...uint64) error { +func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, + chanIDs ...uint64) error { + // TODO(roasbeef): possibly delete from node bucket if node has no more // channels // TODO(roasbeef): don't delete both edges? @@ -1427,7 +1453,7 @@ func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...u var rawChanID [8]byte for _, chanID := range chanIDs { byteOrder.PutUint64(rawChanID[:], chanID) - err := delChannelEdge( + err := c.delChannelEdge( edges, edgeIndex, chanIndex, zombieIndex, nodes, rawChanID[:], true, strictZombiePruning, ) @@ -1556,7 +1582,9 @@ type ChannelEdge struct { // ChanUpdatesInHorizon returns all the known channel edges which have at least // one edge that has an update timestamp within the specified horizon. -func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) { +func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, + endTime time.Time) ([]ChannelEdge, error) { + // To ensure we don't return duplicate ChannelEdges, we'll use an // additional map to keep track of the edges already seen to prevent // re-adding it. @@ -1689,7 +1717,9 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha // update timestamp within the passed range. This method can be used by two // nodes to quickly determine if they have the same set of up to date node // announcements. -func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) { +func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, + endTime time.Time) ([]LightningNode, error) { + var nodesInHorizon []LightningNode err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -2017,7 +2047,7 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64, return nil } -func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, +func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, nodes kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error { edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) @@ -2025,6 +2055,11 @@ func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, return err } + c.graphCache.RemoveChannel( + edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes, + edgeInfo.ChannelID, + ) + // We'll also remove the entry in the edge update index bucket before // we delete the edges themselves so we can access their last update // times. @@ -2159,7 +2194,9 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy, }, Update: func(tx kvdb.RwTx) error { var err error - isUpdate1, err = updateEdgePolicy(tx, edge) + isUpdate1, err = updateEdgePolicy( + tx, edge, c.graphCache, + ) // Silence ErrEdgeNotFound so that the batch can // succeed, but propagate the error via local state. @@ -2222,7 +2259,9 @@ func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy, isUpdate1 bool) { // buckets using an existing database transaction. The returned boolean will be // true if the updated policy belongs to node1, and false if the policy belonged // to node2. -func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) { +func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy, + graphCache *GraphCache) (bool, error) { + edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { return false, ErrEdgeNotFound @@ -2270,6 +2309,14 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) { return false, err } + var ( + fromNodePubKey route.Vertex + toNodePubKey route.Vertex + ) + copy(fromNodePubKey[:], fromNode) + copy(toNodePubKey[:], toNode) + graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1) + return isUpdate1, nil } @@ -2481,6 +2528,39 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) ( return node, nil } +// graphCacheNode is a struct that wraps a LightningNode in a way that it can be +// cached in the graph cache. +type graphCacheNode struct { + lnNode *LightningNode +} + +// PubKey returns the node's public identity key. +func (w *graphCacheNode) PubKey() route.Vertex { + return w.lnNode.PubKeyBytes +} + +// Features returns the node's features. +func (w *graphCacheNode) Features() *lnwire.FeatureVector { + return w.lnNode.Features +} + +// ForEachChannel iterates through all channels of this node, executing the +// passed callback with an edge info structure and the policies of each end +// of the channel. The first edge policy is the outgoing edge *to* the +// connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. +func (w *graphCacheNode) ForEachChannel(tx kvdb.RTx, + cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { + + return w.lnNode.ForEachChannel(tx, cb) +} + +var _ GraphCacheNode = (*graphCacheNode)(nil) + // HasLightningNode determines if the graph has a vertex identified by the // target node identity public key. If the node exists in the database, a // timestamp of when the data for the node was lasted updated is returned along @@ -2621,7 +2701,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // ForEachChannel iterates through all channels of this node, executing the // passed callback with an edge info structure and the policies of each end // of the channel. The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the +// connecting node, while the second is the incoming edge *from* the // connecting node. If the callback returns an error, then the iteration is // halted with the error propagated back up to the caller. // @@ -2632,7 +2712,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // be nil and a fresh transaction will be created to execute the graph // traversal. func (l *LightningNode) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { nodePub := l.PubKeyBytes[:] db := l.db @@ -3490,6 +3571,8 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64, "bucket: %w", err) } + c.graphCache.RemoveChannel(pubKey1, pubKey2, chanID) + return markEdgeZombie(zombieIndex, chanID, pubKey1, pubKey2) }) if err != nil { @@ -3544,6 +3627,18 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { c.rejectCache.remove(chanID) c.chanCache.remove(chanID) + // We need to add the channel back into our graph cache, otherwise we + // won't use it for path finding. + edgeInfos, err := c.FetchChanInfos([]uint64{chanID}) + if err != nil { + return err + } + for _, edgeInfo := range edgeInfos { + c.graphCache.AddChannel( + edgeInfo.Info, edgeInfo.Policy1, edgeInfo.Policy2, + ) + } + return nil } diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go new file mode 100644 index 000000000..d1ec6dd2a --- /dev/null +++ b/channeldb/graph_cache.go @@ -0,0 +1,328 @@ +package channeldb + +import ( + "fmt" + "sync" + + "github.com/btcsuite/btcutil" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// GraphCacheNode is an interface for all the information the cache needs to know +// about a lightning node. +type GraphCacheNode interface { + // PubKey is the node's public identity key. + PubKey() route.Vertex + + // Features returns the node's p2p features. + Features() *lnwire.FeatureVector + + // ForEachChannel iterates through all channels of a given node, + // executing the passed callback with an edge info structure and the + // policies of each end of the channel. The first edge policy is the + // outgoing edge *to* the connecting node, while the second is the + // incoming edge *from* the connecting node. If the callback returns an + // error, then the iteration is halted with the error propagated back up + // to the caller. + ForEachChannel(kvdb.RTx, + func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error +} + +// DirectedChannel is a type that stores the channel information as seen from +// one side of the channel. +type DirectedChannel struct { + // ChannelID is the unique identifier of this channel. + ChannelID uint64 + + // IsNode1 indicates if this is the node with the smaller public key. + IsNode1 bool + + // OtherNode is the public key of the node on the other end of this + // channel. + OtherNode route.Vertex + + // Capacity is the announced capacity of this channel in satoshis. + Capacity btcutil.Amount + + // OutPolicy is the outgoing policy from this node *to* the other node. + OutPolicy *ChannelEdgePolicy + + // InPolicy is the incoming policy *from* the other node to this node. + InPolicy *ChannelEdgePolicy +} + +// GraphCache is a type that holds a minimal set of information of the public +// channel graph that can be used for pathfinding. +type GraphCache struct { + nodeChannels map[route.Vertex]map[uint64]*DirectedChannel + nodeFeatures map[route.Vertex]*lnwire.FeatureVector + + mtx sync.RWMutex +} + +// NewGraphCache creates a new graphCache. +func NewGraphCache() *GraphCache { + return &GraphCache{ + nodeChannels: make(map[route.Vertex]map[uint64]*DirectedChannel), + nodeFeatures: make(map[route.Vertex]*lnwire.FeatureVector), + } +} + +// Stats returns statistics about the current cache size. +func (c *GraphCache) Stats() string { + c.mtx.RLock() + defer c.mtx.RUnlock() + + numChannels := 0 + for node := range c.nodeChannels { + numChannels += len(c.nodeChannels[node]) + } + return fmt.Sprintf("num_node_features=%d, num_nodes=%d, "+ + "num_channels=%d", len(c.nodeFeatures), len(c.nodeChannels), + numChannels) +} + +// AddNode adds a graph node, including all the (directed) channels of that +// node. +func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { + nodePubKey := node.PubKey() + + // Only hold the lock for a short time. The `ForEachChannel()` below is + // possibly slow as it has to go to the backend, so we can unlock + // between the calls. And the AddChannel() method will acquire its own + // lock anyway. + c.mtx.Lock() + c.nodeFeatures[nodePubKey] = node.Features() + c.mtx.Unlock() + + return node.ForEachChannel( + tx, func(tx kvdb.RTx, info *ChannelEdgeInfo, + outPolicy *ChannelEdgePolicy, + inPolicy *ChannelEdgePolicy) error { + + c.AddChannel(info, outPolicy, inPolicy) + + return nil + }, + ) +} + +// AddChannel adds a non-directed channel, meaning that the order of policy 1 +// and policy 2 does not matter, the directionality is extracted from the info +// and policy flags automatically. The policy will be set as the outgoing policy +// on one node and the incoming policy on the peer's side. +func (c *GraphCache) AddChannel(info *ChannelEdgeInfo, + policy1 *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) { + + if info == nil { + return + } + + if policy1 != nil && policy1.IsDisabled() && + policy2 != nil && policy2.IsDisabled() { + + return + } + + // Create the edge entry for both nodes. + c.mtx.Lock() + c.updateOrAddEdge(info.NodeKey1Bytes, &DirectedChannel{ + ChannelID: info.ChannelID, + IsNode1: true, + OtherNode: info.NodeKey2Bytes, + Capacity: info.Capacity, + }) + c.updateOrAddEdge(info.NodeKey2Bytes, &DirectedChannel{ + ChannelID: info.ChannelID, + IsNode1: false, + OtherNode: info.NodeKey1Bytes, + Capacity: info.Capacity, + }) + c.mtx.Unlock() + + // The policy's node is always the to_node. So if policy 1 has to_node + // of node 2 then we have the policy 1 as seen from node 1. + if policy1 != nil { + fromNode, toNode := info.NodeKey1Bytes, info.NodeKey2Bytes + if policy1.Node.PubKeyBytes != info.NodeKey2Bytes { + fromNode, toNode = toNode, fromNode + } + isEdge1 := policy1.ChannelFlags&lnwire.ChanUpdateDirection == 0 + c.UpdatePolicy(policy1, fromNode, toNode, isEdge1) + } + if policy2 != nil { + fromNode, toNode := info.NodeKey2Bytes, info.NodeKey1Bytes + if policy2.Node.PubKeyBytes != info.NodeKey1Bytes { + fromNode, toNode = toNode, fromNode + } + isEdge1 := policy2.ChannelFlags&lnwire.ChanUpdateDirection == 0 + c.UpdatePolicy(policy2, fromNode, toNode, isEdge1) + } +} + +// updateOrAddEdge makes sure the edge information for a node is either updated +// if it already exists or is added to that node's list of channels. +func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) { + if len(c.nodeChannels[node]) == 0 { + c.nodeChannels[node] = make(map[uint64]*DirectedChannel) + } + + c.nodeChannels[node][edge.ChannelID] = edge +} + +// UpdatePolicy updates a single policy on both the from and to node. The order +// of the from and to node is not strictly important. But we assume that a +// channel edge was added beforehand so that the directed channel struct already +// exists in the cache. +func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode, + toNode route.Vertex, edge1 bool) { + + // If a policy's node is nil, we can't cache it yet as that would lead + // to problems in pathfinding. + if policy.Node == nil { + // TODO(guggero): Fix this problem! + log.Warnf("Cannot cache policy because of missing node (from "+ + "%x to %x)", fromNode[:], toNode[:]) + return + } + + c.mtx.Lock() + defer c.mtx.Unlock() + + updatePolicy := func(nodeKey route.Vertex) { + if len(c.nodeChannels[nodeKey]) == 0 { + return + } + + channel, ok := c.nodeChannels[nodeKey][policy.ChannelID] + if !ok { + return + } + + // Edge 1 is defined as the policy for the direction of node1 to + // node2. + switch { + // This is node 1, and it is edge 1, so this is the outgoing + // policy for node 1. + case channel.IsNode1 && edge1: + channel.OutPolicy = policy + + // This is node 2, and it is edge 2, so this is the outgoing + // policy for node 2. + case !channel.IsNode1 && !edge1: + channel.OutPolicy = policy + + // The other two cases left mean it's the inbound policy for the + // node. + default: + channel.InPolicy = policy + } + } + + updatePolicy(fromNode) + updatePolicy(toNode) +} + +// RemoveNode completely removes a node and all its channels (including the +// peer's side). +func (c *GraphCache) RemoveNode(node route.Vertex) { + c.mtx.Lock() + defer c.mtx.Unlock() + + delete(c.nodeFeatures, node) + + // First remove all channels from the other nodes' lists. + for _, channel := range c.nodeChannels[node] { + c.removeChannelIfFound(channel.OtherNode, channel.ChannelID) + } + + // Then remove our whole node completely. + delete(c.nodeChannels, node) +} + +// RemoveChannel removes a single channel between two nodes. +func (c *GraphCache) RemoveChannel(node1, node2 route.Vertex, chanID uint64) { + c.mtx.Lock() + defer c.mtx.Unlock() + + // Remove that one channel from both sides. + c.removeChannelIfFound(node1, chanID) + c.removeChannelIfFound(node2, chanID) +} + +// removeChannelIfFound removes a single channel from one side. +func (c *GraphCache) removeChannelIfFound(node route.Vertex, chanID uint64) { + if len(c.nodeChannels[node]) == 0 { + return + } + + delete(c.nodeChannels[node], chanID) +} + +// UpdateChannel updates the channel edge information for a specific edge. We +// expect the edge to already exist and be known. If it does not yet exist, this +// call is a no-op. +func (c *GraphCache) UpdateChannel(info *ChannelEdgeInfo) { + c.mtx.Lock() + defer c.mtx.Unlock() + + if len(c.nodeChannels[info.NodeKey1Bytes]) == 0 || + len(c.nodeChannels[info.NodeKey2Bytes]) == 0 { + + return + } + + channel, ok := c.nodeChannels[info.NodeKey1Bytes][info.ChannelID] + if ok { + // We only expect to be called when the channel is already + // known. + channel.Capacity = info.Capacity + channel.OtherNode = info.NodeKey2Bytes + } + + channel, ok = c.nodeChannels[info.NodeKey2Bytes][info.ChannelID] + if ok { + channel.Capacity = info.Capacity + channel.OtherNode = info.NodeKey1Bytes + } +} + +// ForEachChannel invokes the given callback for each channel of the given node. +func (c *GraphCache) ForEachChannel(node route.Vertex, + cb func(channel *DirectedChannel) error) error { + + c.mtx.RLock() + defer c.mtx.RUnlock() + + channels, ok := c.nodeChannels[node] + if !ok { + return nil + } + + for _, channel := range channels { + if err := cb(channel); err != nil { + return err + } + } + + return nil +} + +// GetFeatures returns the features of the node with the given ID. +func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector { + c.mtx.RLock() + defer c.mtx.RUnlock() + + features, ok := c.nodeFeatures[node] + if !ok || features == nil { + // The router expects the features to never be nil, so we return + // an empty feature set instead. + return lnwire.EmptyFeatureVector() + } + + return features +} diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go new file mode 100644 index 000000000..71967c68c --- /dev/null +++ b/channeldb/graph_cache_test.go @@ -0,0 +1,110 @@ +package channeldb + +import ( + "encoding/hex" + "testing" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" +) + +var ( + pubKey1Bytes, _ = hex.DecodeString( + "0248f5cba4c6da2e4c9e01e81d1404dfac0cbaf3ee934a4fc117d2ea9a64" + + "22c91d", + ) + pubKey2Bytes, _ = hex.DecodeString( + "038155ba86a8d3b23c806c855097ca5c9fa0f87621f1e7a7d2835ad057f6" + + "f4484f", + ) + + pubKey1, _ = route.NewVertexFromBytes(pubKey1Bytes) + pubKey2, _ = route.NewVertexFromBytes(pubKey2Bytes) +) + +type node struct { + pubKey route.Vertex + features *lnwire.FeatureVector + + edgeInfos []*ChannelEdgeInfo + outPolicies []*ChannelEdgePolicy + inPolicies []*ChannelEdgePolicy +} + +func (n *node) PubKey() route.Vertex { + return n.pubKey +} +func (n *node) Features() *lnwire.FeatureVector { + return n.features +} + +func (n *node) ForEachChannel(tx kvdb.RTx, + cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { + + for idx := range n.edgeInfos { + err := cb( + tx, n.edgeInfos[idx], n.outPolicies[idx], + n.inPolicies[idx], + ) + if err != nil { + return err + } + } + + return nil +} + +// TestGraphCacheAddNode tests that a channel going from node A to node B can be +// cached correctly, independent of the direction we add the channel as. +func TestGraphCacheAddNode(t *testing.T) { + runTest := func(nodeA, nodeB route.Vertex) { + t.Helper() + + outPolicy1 := &ChannelEdgePolicy{ + ChannelID: 1000, + ChannelFlags: 0, + Node: &LightningNode{ + PubKeyBytes: nodeB, + }, + } + inPolicy1 := &ChannelEdgePolicy{ + ChannelID: 1000, + ChannelFlags: 1, + Node: &LightningNode{ + PubKeyBytes: nodeA, + }, + } + node := &node{ + pubKey: nodeA, + features: lnwire.EmptyFeatureVector(), + edgeInfos: []*ChannelEdgeInfo{{ + ChannelID: 1000, + // Those are direction independent! + NodeKey1Bytes: pubKey1, + NodeKey2Bytes: pubKey2, + Capacity: 500, + }}, + outPolicies: []*ChannelEdgePolicy{outPolicy1}, + inPolicies: []*ChannelEdgePolicy{inPolicy1}, + } + cache := NewGraphCache() + require.NoError(t, cache.AddNode(nil, node)) + + fromChannels := cache.nodeChannels[nodeA] + toChannels := cache.nodeChannels[nodeB] + + require.Len(t, fromChannels, 1) + require.Len(t, toChannels, 1) + + require.Equal(t, outPolicy1, fromChannels[0].OutPolicy) + require.Equal(t, inPolicy1, fromChannels[0].InPolicy) + + require.Equal(t, inPolicy1, toChannels[0].OutPolicy) + require.Equal(t, outPolicy1, toChannels[0].InPolicy) + } + runTest(pubKey1, pubKey2) + runTest(pubKey2, pubKey1) +} diff --git a/routing/graph.go b/routing/graph.go index be58698f4..578f480ab 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -2,7 +2,6 @@ package routing import ( "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -12,8 +11,7 @@ import ( type routingGraph interface { // forEachNodeChannel calls the callback for every channel of the given node. forEachNodeChannel(nodePub route.Vertex, - cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy) error) error + cb func(channel *channeldb.DirectedChannel) error) error // sourceNode returns the source node of the graph. sourceNode() route.Vertex @@ -26,7 +24,6 @@ type routingGraph interface { // database. type dbRoutingTx struct { graph *channeldb.ChannelGraph - tx kvdb.RTx source route.Vertex } @@ -38,37 +35,19 @@ func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) { return nil, err } - tx, err := graph.Database().BeginReadTx() - if err != nil { - return nil, err - } - return &dbRoutingTx{ graph: graph, - tx: tx, source: sourceNode.PubKeyBytes, }, nil } -// close closes the underlying db transaction. -func (g *dbRoutingTx) close() error { - return g.tx.Rollback() -} - // forEachNodeChannel calls the callback for every channel of the given node. // // NOTE: Part of the routingGraph interface. func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex, - cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy) error) error { + cb func(channel *channeldb.DirectedChannel) error) error { - txCb := func(_ kvdb.RTx, info *channeldb.ChannelEdgeInfo, - p1, p2 *channeldb.ChannelEdgePolicy) error { - - return cb(info, p1, p2) - } - - return g.graph.ForEachNodeChannel(g.tx, nodePub[:], txCb) + return g.graph.ForEachNodeChannel(nodePub, cb) } // sourceNode returns the source node of the graph. @@ -85,20 +64,5 @@ func (g *dbRoutingTx) sourceNode() route.Vertex { func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { - targetNode, err := g.graph.FetchLightningNode(nodePub) - switch err { - - // If the node exists and has features, return them directly. - case nil: - return targetNode.Features, nil - - // If we couldn't find a node announcement, populate a blank feature - // vector. - case channeldb.ErrGraphNodeNotFound: - return lnwire.EmptyFeatureVector(), nil - - // Otherwise bubble the error up. - default: - return nil, err - } + return g.graph.FetchNodeFeatures(nodePub) } diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index 3834d9e51..badeeebb9 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -159,8 +159,7 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte, // // NOTE: Part of the routingGraph interface. func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, - cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy) error) error { + cb func(channel *channeldb.DirectedChannel) error) error { // Look up the mock node. node, ok := m.nodes[nodePub] @@ -171,36 +170,38 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, // Iterate over all of its channels. for peer, channel := range node.channels { // Lexicographically sort the pubkeys. - var node1, node2 route.Vertex + var node1 route.Vertex if bytes.Compare(nodePub[:], peer[:]) == -1 { - node1, node2 = peer, nodePub + node1 = peer } else { - node1, node2 = nodePub, peer + node1 = nodePub } peerNode := m.nodes[peer] // Call the per channel callback. err := cb( - &channeldb.ChannelEdgeInfo{ - NodeKey1Bytes: node1, - NodeKey2Bytes: node2, - }, - &channeldb.ChannelEdgePolicy{ + &channeldb.DirectedChannel{ ChannelID: channel.id, - Node: &channeldb.LightningNode{ - PubKeyBytes: peer, - Features: lnwire.EmptyFeatureVector(), + IsNode1: nodePub == node1, + OtherNode: peer, + Capacity: channel.capacity, + OutPolicy: &channeldb.ChannelEdgePolicy{ + ChannelID: channel.id, + Node: &channeldb.LightningNode{ + PubKeyBytes: peer, + Features: lnwire.EmptyFeatureVector(), + }, + FeeBaseMSat: node.baseFee, }, - FeeBaseMSat: node.baseFee, - }, - &channeldb.ChannelEdgePolicy{ - ChannelID: channel.id, - Node: &channeldb.LightningNode{ - PubKeyBytes: nodePub, - Features: lnwire.EmptyFeatureVector(), + InPolicy: &channeldb.ChannelEdgePolicy{ + ChannelID: channel.id, + Node: &channeldb.LightningNode{ + PubKeyBytes: nodePub, + Features: lnwire.EmptyFeatureVector(), + }, + FeeBaseMSat: peerNode.baseFee, }, - FeeBaseMSat: peerNode.baseFee, }, ) if err != nil { diff --git a/routing/pathfind.go b/routing/pathfind.go index 3d722c822..fc3be7942 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -359,14 +359,12 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi - cb := func(edgeInfo *channeldb.ChannelEdgeInfo, outEdge, - _ *channeldb.ChannelEdgePolicy) error { - - if outEdge == nil { + cb := func(channel *channeldb.DirectedChannel) error { + if channel.OutPolicy == nil { return nil } - chanID := outEdge.ChannelID + chanID := channel.ChannelID // Enforce outgoing channel restriction. if outgoingChans != nil { @@ -381,9 +379,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // This can happen when a channel is added to the graph after // we've already queried the bandwidth hints. if !ok { - bandwidth = lnwire.NewMSatFromSatoshis( - edgeInfo.Capacity, - ) + bandwidth = lnwire.NewMSatFromSatoshis(channel.Capacity) } if bandwidth > max { @@ -889,7 +885,8 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Determine the next hop forward using the next map. currentNodeWithDist, ok := distance[currentNode] if !ok { - // If the node doesnt have a next hop it means we didn't find a path. + // If the node doesn't have a next hop it means we + // didn't find a path. return nil, errNoPathFound } diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index d098429c1..b353c24ea 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -304,6 +304,16 @@ func parseTestGraph(path string) (*testGraphInstance, error) { } } + aliasForNode := func(node route.Vertex) string { + for alias, pubKey := range aliasMap { + if pubKey == node { + return alias + } + } + + return "" + } + // With all the vertexes inserted, we can now insert the edges into the // test graph. for _, edge := range g.Edges { @@ -353,10 +363,17 @@ func parseTestGraph(path string) (*testGraphInstance, error) { return nil, err } + channelFlags := lnwire.ChanUpdateChanFlags(edge.ChannelFlags) + isUpdate1 := channelFlags&lnwire.ChanUpdateDirection == 0 + targetNode := edgeInfo.NodeKey1Bytes + if isUpdate1 { + targetNode = edgeInfo.NodeKey2Bytes + } + edgePolicy := &channeldb.ChannelEdgePolicy{ SigBytes: testSig.Serialize(), MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags), - ChannelFlags: lnwire.ChanUpdateChanFlags(edge.ChannelFlags), + ChannelFlags: channelFlags, ChannelID: edge.ChannelID, LastUpdate: testTime, TimeLockDelta: edge.Expiry, @@ -364,6 +381,10 @@ func parseTestGraph(path string) (*testGraphInstance, error) { MaxHTLC: lnwire.MilliSatoshi(edge.MaxHTLC), FeeBaseMSat: lnwire.MilliSatoshi(edge.FeeBaseMsat), FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate), + Node: &channeldb.LightningNode{ + Alias: aliasForNode(targetNode), + PubKeyBytes: targetNode, + }, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { return nil, err @@ -635,6 +656,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( channelFlags |= lnwire.ChanUpdateDisabled } + node2Features := lnwire.EmptyFeatureVector() + if node2.testChannelPolicy != nil { + node2Features = node2.Features + } + edgePolicy := &channeldb.ChannelEdgePolicy{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, @@ -646,6 +672,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( MaxHTLC: node1.MaxHTLC, FeeBaseMSat: node1.FeeBaseMsat, FeeProportionalMillionths: node1.FeeRate, + Node: &channeldb.LightningNode{ + Alias: node2.Alias, + PubKeyBytes: node2Vertex, + Features: node2Features, + }, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { return nil, err @@ -663,6 +694,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( } channelFlags |= lnwire.ChanUpdateDirection + node1Features := lnwire.EmptyFeatureVector() + if node1.testChannelPolicy != nil { + node1Features = node1.Features + } + edgePolicy := &channeldb.ChannelEdgePolicy{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, @@ -674,6 +710,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( MaxHTLC: node2.MaxHTLC, FeeBaseMSat: node2.FeeBaseMsat, FeeProportionalMillionths: node2.FeeRate, + Node: &channeldb.LightningNode{ + Alias: node1.Alias, + PubKeyBytes: node1Vertex, + Features: node1Features, + }, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { return nil, err @@ -2980,12 +3021,6 @@ func dbFindPath(graph *channeldb.ChannelGraph, if err != nil { return nil, err } - defer func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }() return findPath( &graphParams{ diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 8122ff711..f08005909 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -47,12 +47,7 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { if err != nil { return nil, nil, err } - return routingTx, func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }, nil + return routingTx, func() {}, nil } // NewPaymentSession creates a new payment session backed by the latest prune diff --git a/routing/router.go b/routing/router.go index 00fa4d316..9864a991d 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1756,12 +1756,6 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, if err != nil { return nil, err } - defer func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }() path, err := findPath( &graphParams{ @@ -2763,12 +2757,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, if err != nil { return nil, err } - defer func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }() // Traverse hops backwards to accumulate fees in the running amounts. source := r.selfNode.PubKeyBytes diff --git a/routing/router_test.go b/routing/router_test.go index 510d18bf5..1633d3810 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1393,6 +1393,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey2Bytes, + }, } edgePolicy.ChannelFlags = 0 @@ -1409,6 +1412,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey1Bytes, + }, } edgePolicy.ChannelFlags = 1 @@ -1490,6 +1496,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey2Bytes, + }, } edgePolicy.ChannelFlags = 0 @@ -1505,6 +1514,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey1Bytes, + }, } edgePolicy.ChannelFlags = 1 diff --git a/routing/unified_policies.go b/routing/unified_policies.go index 0ff509382..4a6e5e00b 100644 --- a/routing/unified_policies.go +++ b/routing/unified_policies.go @@ -69,24 +69,18 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, // addGraphPolicies adds all policies that are known for the toNode in the // graph. func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error { - cb := func(edgeInfo *channeldb.ChannelEdgeInfo, _, - inEdge *channeldb.ChannelEdgePolicy) error { - + cb := func(channel *channeldb.DirectedChannel) error { // If there is no edge policy for this candidate node, skip. // Note that we are searching backwards so this node would have // come prior to the pivot node in the route. - if inEdge == nil { + if channel.InPolicy == nil { return nil } - // The node on the other end of this channel is the from node. - fromNode, err := edgeInfo.OtherNodeKeyBytes(u.toNode[:]) - if err != nil { - return err - } - // Add this policy to the unified policies map. - u.addPolicy(fromNode, inEdge, edgeInfo.Capacity) + u.addPolicy( + channel.OtherNode, channel.InPolicy, channel.Capacity, + ) return nil }