From 1d1c42f9bae49ac5d3dc24eba7a1451668e83464 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 21 Sep 2021 19:18:22 +0200 Subject: [PATCH] multi: use minimal policy in cache --- channeldb/graph.go | 1 - channeldb/graph_cache.go | 155 ++++++++- channeldb/graph_cache_test.go | 53 ++- channeldb/graph_test.go | 462 +++++++++++++++++++------ lnrpc/routerrpc/router_backend.go | 2 +- lnrpc/routerrpc/router_backend_test.go | 2 +- routing/heap.go | 2 +- routing/mock_graph_test.go | 17 +- routing/mock_test.go | 10 +- routing/pathfind.go | 28 +- routing/pathfind_test.go | 58 ++-- routing/payment_lifecycle.go | 2 +- routing/payment_session.go | 10 +- routing/payment_session_source.go | 13 +- routing/payment_session_test.go | 13 +- routing/router.go | 4 +- routing/router_test.go | 4 +- routing/unified_policies.go | 14 +- routing/unified_policies_test.go | 6 +- 19 files changed, 629 insertions(+), 227 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index 8806bcff1..e3ec83113 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2315,7 +2315,6 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy, ) copy(fromNodePubKey[:], fromNode) copy(toNodePubKey[:], toNode) - // TODO(guggero): Fetch lightning nodes before updating the cache! graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1) return isUpdate1, nil diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index d1ec6dd2a..f36d022fb 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -32,6 +32,92 @@ type GraphCacheNode interface { *ChannelEdgePolicy) error) error } +// CachedEdgePolicy is a struct that only caches the information of a +// ChannelEdgePolicy that we actually use for pathfinding and therefore need to +// store in the cache. +type CachedEdgePolicy struct { + // ChannelID is the unique channel ID for the channel. The first 3 + // bytes are the block height, the next 3 the index within the block, + // and the last 2 bytes are the output index for the channel. + ChannelID uint64 + + // MessageFlags is a bitfield which indicates the presence of optional + // fields (like max_htlc) in the policy. + MessageFlags lnwire.ChanUpdateMsgFlags + + // ChannelFlags is a bitfield which signals the capabilities of the + // channel as well as the directed edge this update applies to. + ChannelFlags lnwire.ChanUpdateChanFlags + + // TimeLockDelta is the number of blocks this node will subtract from + // the expiry of an incoming HTLC. This value expresses the time buffer + // the node would like to HTLC exchanges. + TimeLockDelta uint16 + + // MinHTLC is the smallest value HTLC this node will forward, expressed + // in millisatoshi. + MinHTLC lnwire.MilliSatoshi + + // MaxHTLC is the largest value HTLC this node will forward, expressed + // in millisatoshi. + MaxHTLC lnwire.MilliSatoshi + + // FeeBaseMSat is the base HTLC fee that will be charged for forwarding + // ANY HTLC, expressed in mSAT's. + FeeBaseMSat lnwire.MilliSatoshi + + // FeeProportionalMillionths is the rate that the node will charge for + // HTLCs for each millionth of a satoshi forwarded. + FeeProportionalMillionths lnwire.MilliSatoshi + + // ToNodePubKey is a function that returns the to node of a policy. + // Since we only ever store the inbound policy, this is always the node + // that we query the channels for in ForEachChannel(). Therefore, we can + // save a lot of space by not storing this information in the memory and + // instead just set this function when we copy the policy from cache in + // ForEachChannel(). + ToNodePubKey func() route.Vertex + + // ToNodeFeatures are the to node's features. They are never set while + // the edge is in the cache, only on the copy that is returned in + // ForEachChannel(). + ToNodeFeatures *lnwire.FeatureVector +} + +// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over +// the passed active payment channel. This value is currently computed as +// specified in BOLT07, but will likely change in the near future. +func (c *CachedEdgePolicy) ComputeFee( + amt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts +} + +// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming +// amount. +func (c *CachedEdgePolicy) ComputeFeeFromIncoming( + incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return incomingAmt - divideCeil( + feeRateParts*(incomingAmt-c.FeeBaseMSat), + feeRateParts+c.FeeProportionalMillionths, + ) +} + +// NewCachedPolicy turns a full policy into a minimal one that can be cached. +func NewCachedPolicy(policy *ChannelEdgePolicy) *CachedEdgePolicy { + return &CachedEdgePolicy{ + ChannelID: policy.ChannelID, + MessageFlags: policy.MessageFlags, + ChannelFlags: policy.ChannelFlags, + TimeLockDelta: policy.TimeLockDelta, + MinHTLC: policy.MinHTLC, + MaxHTLC: policy.MaxHTLC, + FeeBaseMSat: policy.FeeBaseMSat, + FeeProportionalMillionths: policy.FeeProportionalMillionths, + } +} + // DirectedChannel is a type that stores the channel information as seen from // one side of the channel. type DirectedChannel struct { @@ -48,11 +134,35 @@ type DirectedChannel struct { // Capacity is the announced capacity of this channel in satoshis. Capacity btcutil.Amount - // OutPolicy is the outgoing policy from this node *to* the other node. - OutPolicy *ChannelEdgePolicy + // OutPolicySet is a boolean that indicates whether the node has an + // outgoing policy set. For pathfinding only the existence of the policy + // is important to know, not the actual content. + OutPolicySet bool // InPolicy is the incoming policy *from* the other node to this node. - InPolicy *ChannelEdgePolicy + // In path finding, we're walking backward from the destination to the + // source, so we're always interested in the edge that arrives to us + // from the other node. + InPolicy *CachedEdgePolicy +} + +// DeepCopy creates a deep copy of the channel, including the incoming policy. +func (c *DirectedChannel) DeepCopy() *DirectedChannel { + channelCopy := *c + + if channelCopy.InPolicy != nil { + inPolicyCopy := *channelCopy.InPolicy + channelCopy.InPolicy = &inPolicyCopy + + // The fields for the ToNode can be overwritten by the path + // finding algorithm, which is why we need a deep copy in the + // first place. So we always start out with nil values, just to + // be sure they don't contain any old data. + channelCopy.InPolicy.ToNodePubKey = nil + channelCopy.InPolicy.ToNodeFeatures = nil + } + + return &channelCopy } // GraphCache is a type that holds a minimal set of information of the public @@ -181,15 +291,6 @@ func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) { func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode, toNode route.Vertex, edge1 bool) { - // If a policy's node is nil, we can't cache it yet as that would lead - // to problems in pathfinding. - if policy.Node == nil { - // TODO(guggero): Fix this problem! - log.Warnf("Cannot cache policy because of missing node (from "+ - "%x to %x)", fromNode[:], toNode[:]) - return - } - c.mtx.Lock() defer c.mtx.Unlock() @@ -209,17 +310,17 @@ func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode, // This is node 1, and it is edge 1, so this is the outgoing // policy for node 1. case channel.IsNode1 && edge1: - channel.OutPolicy = policy + channel.OutPolicySet = true // This is node 2, and it is edge 2, so this is the outgoing // policy for node 2. case !channel.IsNode1 && !edge1: - channel.OutPolicy = policy + channel.OutPolicySet = true // The other two cases left mean it's the inbound policy for the // node. default: - channel.InPolicy = policy + channel.InPolicy = NewCachedPolicy(policy) } } @@ -303,8 +404,30 @@ func (c *GraphCache) ForEachChannel(node route.Vertex, return nil } + features, ok := c.nodeFeatures[node] + if !ok { + log.Warnf("Node %v has no features defined, falling back to "+ + "default feature vector for path finding", node) + + features = lnwire.EmptyFeatureVector() + } + + toNodeCallback := func() route.Vertex { + return node + } + for _, channel := range channels { - if err := cb(channel); err != nil { + // We need to copy the channel and policy to avoid it being + // updated in the cache if the path finding algorithm sets + // fields on it (currently only the ToNodeFeatures of the + // policy). + channelCopy := channel.DeepCopy() + if channelCopy.InPolicy != nil { + channelCopy.InPolicy.ToNodePubKey = toNodeCallback + channelCopy.InPolicy.ToNodeFeatures = features + } + + if err := cb(channelCopy); err != nil { return err } } diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go index 71967c68c..57666e1eb 100644 --- a/channeldb/graph_cache_test.go +++ b/channeldb/graph_cache_test.go @@ -63,18 +63,25 @@ func TestGraphCacheAddNode(t *testing.T) { runTest := func(nodeA, nodeB route.Vertex) { t.Helper() + channelFlagA, channelFlagB := 0, 1 + if nodeA == pubKey2 { + channelFlagA, channelFlagB = 1, 0 + } + outPolicy1 := &ChannelEdgePolicy{ ChannelID: 1000, - ChannelFlags: 0, + ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA), Node: &LightningNode{ PubKeyBytes: nodeB, + Features: lnwire.EmptyFeatureVector(), }, } inPolicy1 := &ChannelEdgePolicy{ ChannelID: 1000, - ChannelFlags: 1, + ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB), Node: &LightningNode{ PubKeyBytes: nodeA, + Features: lnwire.EmptyFeatureVector(), }, } node := &node{ @@ -93,18 +100,48 @@ func TestGraphCacheAddNode(t *testing.T) { cache := NewGraphCache() require.NoError(t, cache.AddNode(nil, node)) - fromChannels := cache.nodeChannels[nodeA] - toChannels := cache.nodeChannels[nodeB] + var fromChannels, toChannels []*DirectedChannel + _ = cache.ForEachChannel(nodeA, func(c *DirectedChannel) error { + fromChannels = append(fromChannels, c) + return nil + }) + _ = cache.ForEachChannel(nodeB, func(c *DirectedChannel) error { + toChannels = append(toChannels, c) + return nil + }) require.Len(t, fromChannels, 1) require.Len(t, toChannels, 1) - require.Equal(t, outPolicy1, fromChannels[0].OutPolicy) - require.Equal(t, inPolicy1, fromChannels[0].InPolicy) + require.Equal(t, outPolicy1 != nil, fromChannels[0].OutPolicySet) + assertCachedPolicyEqual(t, inPolicy1, fromChannels[0].InPolicy) - require.Equal(t, inPolicy1, toChannels[0].OutPolicy) - require.Equal(t, outPolicy1, toChannels[0].InPolicy) + require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet) + assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy) } + runTest(pubKey1, pubKey2) runTest(pubKey2, pubKey1) } + +func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicy, + cached *CachedEdgePolicy) { + + require.Equal(t, original.ChannelID, cached.ChannelID) + require.Equal(t, original.MessageFlags, cached.MessageFlags) + require.Equal(t, original.ChannelFlags, cached.ChannelFlags) + require.Equal(t, original.TimeLockDelta, cached.TimeLockDelta) + require.Equal(t, original.MinHTLC, cached.MinHTLC) + require.Equal(t, original.MaxHTLC, cached.MaxHTLC) + require.Equal(t, original.FeeBaseMSat, cached.FeeBaseMSat) + require.Equal( + t, original.FeeProportionalMillionths, + cached.FeeProportionalMillionths, + ) + require.Equal( + t, + route.Vertex(original.Node.PubKeyBytes), + cached.ToNodePubKey(), + ) + require.Equal(t, original.Node.Features, cached.ToNodeFeatures) +} diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index d2953a523..e624105a3 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -42,7 +42,10 @@ var ( _, _ = testSig.R.SetString("63724406601629180062774974542967536251589935445068131219452686511677818569431", 10) _, _ = testSig.S.SetString("18801056069249825825291287104931333862866033135609736119018462340006816851118", 10) - testFeatures = lnwire.NewFeatureVector(nil, lnwire.Features) + testFeatures = lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired), + lnwire.Features, + ) testPub = route.Vertex{2, 202, 4} ) @@ -146,6 +149,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { if err := graph.AddLightningNode(node); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node, testFeatures) // Next, fetch the node from the database to ensure everything was // serialized properly. @@ -170,6 +174,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { if err := graph.DeleteLightningNode(testPub); err != nil { t.Fatalf("unable to delete node; %v", err) } + assertNodeNotInCache(t, graph, testPub) // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. @@ -200,6 +205,7 @@ func TestPartialNode(t *testing.T) { if err := graph.AddLightningNode(node); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node, nil) // Next, fetch the node from the database to ensure everything was // serialized properly. @@ -232,6 +238,7 @@ func TestPartialNode(t *testing.T) { if err := graph.DeleteLightningNode(testPub); err != nil { t.Fatalf("unable to delete node: %v", err) } + assertNodeNotInCache(t, graph, testPub) // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. @@ -390,6 +397,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { if err := graph.AddChannelEdge(&edgeInfo); err != nil { t.Fatalf("unable to create channel edge: %v", err) } + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo) // Ensure that both policies are returned as unknown (nil). _, e1, e2, err := graph.FetchChannelEdgesByID(chanID) @@ -405,6 +413,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { if err := graph.DeleteChannelEdges(false, chanID); err != nil { t.Fatalf("unable to delete edge: %v", err) } + assertNoEdge(t, graph, chanID) // Ensure that any query attempts to lookup the delete channel edge are // properly deleted. @@ -544,6 +553,9 @@ func TestDisconnectBlockAtHeight(t *testing.T) { if err := graph.AddChannelEdge(&edgeInfo3); err != nil { t.Fatalf("unable to create channel edge: %v", err) } + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo) + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo2) + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo3) // Call DisconnectBlockAtHeight, which should prune every channel // that has a funding height of 'height' or greater. @@ -551,6 +563,9 @@ func TestDisconnectBlockAtHeight(t *testing.T) { if err != nil { t.Fatalf("unable to prune %v", err) } + assertNoEdge(t, graph, edgeInfo.ChannelID) + assertNoEdge(t, graph, edgeInfo2.ChannelID) + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo3) // The two edges should have been removed. if len(removed) != 2 { @@ -769,6 +784,7 @@ func TestEdgeInfoUpdates(t *testing.T) { if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node1, testFeatures) node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) @@ -776,6 +792,7 @@ func TestEdgeInfoUpdates(t *testing.T) { if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node2, testFeatures) // Create an edge and add it to the db. edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) @@ -785,11 +802,13 @@ func TestEdgeInfoUpdates(t *testing.T) { if err := graph.UpdateEdgePolicy(edge1); err != ErrEdgeNotFound { t.Fatalf("expected ErrEdgeNotFound, got: %v", err) } + require.Len(t, graph.graphCache.nodeChannels, 0) // Add the edge info. if err := graph.AddChannelEdge(edgeInfo); err != nil { t.Fatalf("unable to create channel edge: %v", err) } + assertEdgeWithNoPoliciesInCache(t, graph, edgeInfo) chanID := edgeInfo.ChannelID outpoint := edgeInfo.ChannelPoint @@ -799,9 +818,11 @@ func TestEdgeInfoUpdates(t *testing.T) { if err := graph.UpdateEdgePolicy(edge1); err != nil { t.Fatalf("unable to update edge: %v", err) } + assertEdgeWithPolicyInCache(t, graph, edgeInfo, edge1, true) if err := graph.UpdateEdgePolicy(edge2); err != nil { t.Fatalf("unable to update edge: %v", err) } + assertEdgeWithPolicyInCache(t, graph, edgeInfo, edge2, false) // Check for existence of the edge within the database, it should be // found. @@ -856,6 +877,191 @@ func TestEdgeInfoUpdates(t *testing.T) { assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) } +func assertNodeInCache(t *testing.T, g *ChannelGraph, n *LightningNode, + expectedFeatures *lnwire.FeatureVector) { + + // Let's check the internal view first. + require.Equal( + t, expectedFeatures, g.graphCache.nodeFeatures[n.PubKeyBytes], + ) + + // The external view should reflect this as well. Except when we expect + // the features to be nil internally, we return an empty feature vector + // on the public interface instead. + if expectedFeatures == nil { + expectedFeatures = lnwire.EmptyFeatureVector() + } + features := g.graphCache.GetFeatures(n.PubKeyBytes) + require.Equal(t, expectedFeatures, features) +} + +func assertNodeNotInCache(t *testing.T, g *ChannelGraph, n route.Vertex) { + _, ok := g.graphCache.nodeFeatures[n] + require.False(t, ok) + + _, ok = g.graphCache.nodeChannels[n] + require.False(t, ok) + + // We should get the default features for this node. + features := g.graphCache.GetFeatures(n) + require.Equal(t, lnwire.EmptyFeatureVector(), features) +} + +func assertEdgeWithNoPoliciesInCache(t *testing.T, g *ChannelGraph, + e *ChannelEdgeInfo) { + + // Let's check the internal view first. + require.NotEmpty(t, g.graphCache.nodeChannels[e.NodeKey1Bytes]) + require.NotEmpty(t, g.graphCache.nodeChannels[e.NodeKey2Bytes]) + + expectedNode1Channel := &DirectedChannel{ + ChannelID: e.ChannelID, + IsNode1: true, + OtherNode: e.NodeKey2Bytes, + Capacity: e.Capacity, + OutPolicySet: false, + InPolicy: nil, + } + require.Contains( + t, g.graphCache.nodeChannels[e.NodeKey1Bytes], e.ChannelID, + ) + require.Equal( + t, expectedNode1Channel, + g.graphCache.nodeChannels[e.NodeKey1Bytes][e.ChannelID], + ) + + expectedNode2Channel := &DirectedChannel{ + ChannelID: e.ChannelID, + IsNode1: false, + OtherNode: e.NodeKey1Bytes, + Capacity: e.Capacity, + OutPolicySet: false, + InPolicy: nil, + } + require.Contains( + t, g.graphCache.nodeChannels[e.NodeKey2Bytes], e.ChannelID, + ) + require.Equal( + t, expectedNode2Channel, + g.graphCache.nodeChannels[e.NodeKey2Bytes][e.ChannelID], + ) + + // The external view should reflect this as well. + var foundChannel *DirectedChannel + err := g.graphCache.ForEachChannel( + e.NodeKey1Bytes, func(c *DirectedChannel) error { + if c.ChannelID == e.ChannelID { + foundChannel = c + } + + return nil + }, + ) + require.NoError(t, err) + require.NotNil(t, foundChannel) + require.Equal(t, expectedNode1Channel, foundChannel) + + err = g.graphCache.ForEachChannel( + e.NodeKey2Bytes, func(c *DirectedChannel) error { + if c.ChannelID == e.ChannelID { + foundChannel = c + } + + return nil + }, + ) + require.NoError(t, err) + require.NotNil(t, foundChannel) + require.Equal(t, expectedNode2Channel, foundChannel) +} + +func assertNoEdge(t *testing.T, g *ChannelGraph, chanID uint64) { + // Make sure no channel in the cache has the given channel ID. If there + // are no channels at all, that is fine as well. + for _, channels := range g.graphCache.nodeChannels { + for _, channel := range channels { + require.NotEqual(t, channel.ChannelID, chanID) + } + } +} + +func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, + e *ChannelEdgeInfo, p *ChannelEdgePolicy, policy1 bool) { + + // Check the internal state first. + c1, ok := g.graphCache.nodeChannels[e.NodeKey1Bytes][e.ChannelID] + require.True(t, ok) + + if policy1 { + require.True(t, c1.OutPolicySet) + } else { + require.NotNil(t, c1.InPolicy) + require.Equal( + t, p.FeeProportionalMillionths, + c1.InPolicy.FeeProportionalMillionths, + ) + } + + c2, ok := g.graphCache.nodeChannels[e.NodeKey2Bytes][e.ChannelID] + require.True(t, ok) + + if policy1 { + require.NotNil(t, c2.InPolicy) + require.Equal( + t, p.FeeProportionalMillionths, + c2.InPolicy.FeeProportionalMillionths, + ) + } else { + require.True(t, c2.OutPolicySet) + } + + // Now for both nodes make sure that the external view is also correct. + var ( + c1Ext *DirectedChannel + c2Ext *DirectedChannel + ) + require.NoError(t, g.graphCache.ForEachChannel( + e.NodeKey1Bytes, func(c *DirectedChannel) error { + c1Ext = c + + return nil + }, + )) + require.NoError(t, g.graphCache.ForEachChannel( + e.NodeKey2Bytes, func(c *DirectedChannel) error { + c2Ext = c + + return nil + }, + )) + + // Only compare the fields that are actually copied, then compare the + // values of the functions separately. + require.Equal(t, c1, c1Ext.DeepCopy()) + require.Equal(t, c2, c2Ext.DeepCopy()) + if policy1 { + require.Equal( + t, p.FeeProportionalMillionths, + c2Ext.InPolicy.FeeProportionalMillionths, + ) + require.Equal( + t, route.Vertex(e.NodeKey2Bytes), + c2Ext.InPolicy.ToNodePubKey(), + ) + require.Equal(t, testFeatures, c2Ext.InPolicy.ToNodeFeatures) + } else { + require.Equal( + t, p.FeeProportionalMillionths, + c1Ext.InPolicy.FeeProportionalMillionths, + ) + require.Equal( + t, route.Vertex(e.NodeKey1Bytes), + c1Ext.InPolicy.ToNodePubKey(), + ) + require.Equal(t, testFeatures, c1Ext.InPolicy.ToNodeFeatures) + } +} + func randEdgePolicy(chanID uint64, db kvdb.Backend) *ChannelEdgePolicy { update := prand.Int63() @@ -890,106 +1096,10 @@ func TestGraphTraversal(t *testing.T) { // We'd like to test some of the graph traversal capabilities within // the DB, so we'll create a series of fake nodes to insert into the - // graph. + // graph. And we'll create 5 channels between each node pair. const numNodes = 20 - nodes := make([]*LightningNode, numNodes) - nodeIndex := map[string]struct{}{} - for i := 0; i < numNodes; i++ { - node, err := createTestVertex(graph.db) - if err != nil { - t.Fatalf("unable to create node: %v", err) - } - - nodes[i] = node - nodeIndex[node.Alias] = struct{}{} - } - - // Add each of the nodes into the graph, they should be inserted - // without error. - for _, node := range nodes { - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - } - - // Iterate over each node as returned by the graph, if all nodes are - // reached, then the map created above should be empty. - err = graph.ForEachNode(func(_ kvdb.RTx, node *LightningNode) error { - delete(nodeIndex, node.Alias) - return nil - }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if len(nodeIndex) != 0 { - t.Fatalf("all nodes not reached within ForEach") - } - - // Determine which node is "smaller", we'll need this in order to - // properly create the edges for the graph. - var firstNode, secondNode *LightningNode - if bytes.Compare(nodes[0].PubKeyBytes[:], nodes[1].PubKeyBytes[:]) == -1 { - firstNode = nodes[0] - secondNode = nodes[1] - } else { - firstNode = nodes[0] - secondNode = nodes[1] - } - - // Create 5 channels between the first two nodes we generated above. const numChannels = 5 - chanIndex := map[uint64]struct{}{} - for i := 0; i < numChannels; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - chanID := uint64(i + 1) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - edgeInfo := ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: op, - Capacity: 1000, - } - copy(edgeInfo.NodeKey1Bytes[:], nodes[0].PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], nodes[1].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey1Bytes[:], nodes[0].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey2Bytes[:], nodes[1].PubKeyBytes[:]) - err := graph.AddChannelEdge(&edgeInfo) - if err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create and add an edge with random data that points from - // node1 -> node2. - edge := randEdgePolicy(chanID, graph.db) - edge.ChannelFlags = 0 - edge.Node = secondNode - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Create another random edge that points from node2 -> node1 - // this time. - edge = randEdgePolicy(chanID, graph.db) - edge.ChannelFlags = 1 - edge.Node = firstNode - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - chanIndex[chanID] = struct{}{} - } + chanIndex, nodeList := fillTestGraph(t, graph, numNodes, numChannels) // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have @@ -1000,16 +1110,13 @@ func TestGraphTraversal(t *testing.T) { delete(chanIndex, ei.ChannelID) return nil }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if len(chanIndex) != 0 { - t.Fatalf("all edges not reached within ForEach") - } + require.NoError(t, err) + require.Len(t, chanIndex, 0) // Finally, we want to test the ability to iterate over all the // outgoing channels for a particular node. numNodeChans := 0 + firstNode, secondNode := nodeList[0], nodeList[1] err = firstNode.ForEachChannel(nil, func(_ kvdb.RTx, _ *ChannelEdgeInfo, outEdge, inEdge *ChannelEdgePolicy) error { @@ -1034,13 +1141,148 @@ func TestGraphTraversal(t *testing.T) { numNodeChans++ return nil }) - if err != nil { - t.Fatalf("for each failure: %v", err) + require.NoError(t, err) + require.Equal(t, numChannels, numNodeChans) +} + +func TestGraphCacheTraversal(t *testing.T) { + t.Parallel() + + graph, cleanUp, err := MakeTestGraph() + defer cleanUp() + require.NoError(t, err) + + // We'd like to test some of the graph traversal capabilities within + // the DB, so we'll create a series of fake nodes to insert into the + // graph. And we'll create 5 channels between each node pair. + const numNodes = 20 + const numChannels = 5 + chanIndex, nodeList := fillTestGraph(t, graph, numNodes, numChannels) + + // Iterate through all the known channels within the graph DB, once + // again if the map is empty that indicates that all edges have + // properly been reached. + numNodeChans := 0 + for _, node := range nodeList { + err = graph.graphCache.ForEachChannel( + node.PubKeyBytes, func(d *DirectedChannel) error { + delete(chanIndex, d.ChannelID) + + if !d.OutPolicySet || d.InPolicy == nil { + return fmt.Errorf("channel policy not " + + "present") + } + + // The incoming edge should also indicate that + // it's pointing to the origin node. + inPolicyNodeKey := d.InPolicy.ToNodePubKey() + if !bytes.Equal( + inPolicyNodeKey[:], node.PubKeyBytes[:], + ) { + return fmt.Errorf("wrong outgoing edge") + } + + numNodeChans++ + + return nil + }, + ) + require.NoError(t, err) } - if numNodeChans != numChannels { - t.Fatalf("all edges for node not reached within ForEach: "+ - "expected %v, got %v", numChannels, numNodeChans) + require.Len(t, chanIndex, 0) + + // We count the channels for both nodes, so there should be double the + // amount now. Except for the very last node, that doesn't have any + // channels to make the loop easier in fillTestGraph(). + require.Equal(t, numChannels*2*(numNodes-1), numNodeChans) +} + +func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes, + numChannels int) (map[uint64]struct{}, []*LightningNode) { + + nodes := make([]*LightningNode, numNodes) + nodeIndex := map[string]struct{}{} + for i := 0; i < numNodes; i++ { + node, err := createTestVertex(graph.db) + require.NoError(t, err) + + nodes[i] = node + nodeIndex[node.Alias] = struct{}{} } + + // Add each of the nodes into the graph, they should be inserted + // without error. + for _, node := range nodes { + require.NoError(t, graph.AddLightningNode(node)) + } + + // Iterate over each node as returned by the graph, if all nodes are + // reached, then the map created above should be empty. + err := graph.ForEachNode(func(_ kvdb.RTx, node *LightningNode) error { + delete(nodeIndex, node.Alias) + return nil + }) + require.NoError(t, err) + require.Len(t, nodeIndex, 0) + + // Create a number of channels between each of the node pairs generated + // above. This will result in numChannels*(numNodes-1) channels. + chanIndex := map[uint64]struct{}{} + for n := 0; n < numNodes-1; n++ { + node1 := nodes[n] + node2 := nodes[n+1] + if bytes.Compare(node1.PubKeyBytes[:], node2.PubKeyBytes[:]) == -1 { + node1, node2 = node2, node1 + } + + for i := 0; i < numChannels; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + chanID := uint64((n << 4) + i + 1) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + edgeInfo := ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: op, + Capacity: 1000, + } + copy(edgeInfo.NodeKey1Bytes[:], node1.PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], node2.PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey1Bytes[:], node1.PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey2Bytes[:], node2.PubKeyBytes[:]) + err := graph.AddChannelEdge(&edgeInfo) + require.NoError(t, err) + + // Create and add an edge with random data that points + // from node1 -> node2. + edge := randEdgePolicy(chanID, graph.db) + edge.ChannelFlags = 0 + edge.Node = node2 + edge.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(edge)) + + // Create another random edge that points from + // node2 -> node1 this time. + edge = randEdgePolicy(chanID, graph.db) + edge.ChannelFlags = 1 + edge.Node = node1 + edge.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(edge)) + + chanIndex[chanID] = struct{}{} + } + } + + return chanIndex, nodes } func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash, diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 1814c6358..d39ff7a7d 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -55,7 +55,7 @@ type RouterBackend struct { FindRoute func(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, destCustomRecords record.CustomSet, - routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, + routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, finalExpiry uint16) (*route.Route, error) MissionControl MissionControl diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index 26a44cbb9..1b05d5f81 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -126,7 +126,7 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, findRoute := func(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, _ record.CustomSet, - routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, + routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, finalExpiry uint16) (*route.Route, error) { if int64(amt) != amtSat*1000 { diff --git a/routing/heap.go b/routing/heap.go index f6869663c..36563bb66 100644 --- a/routing/heap.go +++ b/routing/heap.go @@ -39,7 +39,7 @@ type nodeWithDist struct { weight int64 // nextHop is the edge this route comes from. - nextHop *channeldb.ChannelEdgePolicy + nextHop *channeldb.CachedEdgePolicy // routingInfoSize is the total size requirement for the payloads field // in the onion packet from this hop towards the final destination. diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index badeeebb9..d29c096fd 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -186,20 +186,13 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, IsNode1: nodePub == node1, OtherNode: peer, Capacity: channel.capacity, - OutPolicy: &channeldb.ChannelEdgePolicy{ + OutPolicySet: true, + InPolicy: &channeldb.CachedEdgePolicy{ ChannelID: channel.id, - Node: &channeldb.LightningNode{ - PubKeyBytes: peer, - Features: lnwire.EmptyFeatureVector(), - }, - FeeBaseMSat: node.baseFee, - }, - InPolicy: &channeldb.ChannelEdgePolicy{ - ChannelID: channel.id, - Node: &channeldb.LightningNode{ - PubKeyBytes: nodePub, - Features: lnwire.EmptyFeatureVector(), + ToNodePubKey: func() route.Vertex { + return nodePub }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), FeeBaseMSat: peerNode.baseFee, }, }, diff --git a/routing/mock_test.go b/routing/mock_test.go index 383f89185..a59ae2aa4 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -173,13 +173,13 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, } func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate, - _ *btcec.PublicKey, _ *channeldb.ChannelEdgePolicy) bool { + _ *btcec.PublicKey, _ *channeldb.CachedEdgePolicy) bool { return false } func (m *mockPaymentSessionOld) GetAdditionalEdgePolicy(_ *btcec.PublicKey, - _ uint64) *channeldb.ChannelEdgePolicy { + _ uint64) *channeldb.CachedEdgePolicy { return nil } @@ -637,17 +637,17 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, } func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, - pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool { + pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool { args := m.Called(msg, pubKey, policy) return args.Bool(0) } func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, - channelID uint64) *channeldb.ChannelEdgePolicy { + channelID uint64) *channeldb.CachedEdgePolicy { args := m.Called(pubKey, channelID) - return args.Get(0).(*channeldb.ChannelEdgePolicy) + return args.Get(0).(*channeldb.CachedEdgePolicy) } type mockControlTower struct { diff --git a/routing/pathfind.go b/routing/pathfind.go index fc3be7942..27a67ea7a 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -42,7 +42,7 @@ const ( type pathFinder = func(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( - []*channeldb.ChannelEdgePolicy, error) + []*channeldb.CachedEdgePolicy, error) var ( // DefaultAttemptCost is the default fixed virtual cost in path finding @@ -76,7 +76,7 @@ var ( // of the edge. type edgePolicyWithSource struct { sourceNode route.Vertex - edge *channeldb.ChannelEdgePolicy + edge *channeldb.CachedEdgePolicy } // finalHopParams encapsulates various parameters for route construction that @@ -102,7 +102,7 @@ type finalHopParams struct { // any feature vectors on all hops have been validated for transitive // dependencies. func newRoute(sourceVertex route.Vertex, - pathEdges []*channeldb.ChannelEdgePolicy, currentHeight uint32, + pathEdges []*channeldb.CachedEdgePolicy, currentHeight uint32, finalHop finalHopParams) (*route.Route, error) { var ( @@ -147,10 +147,10 @@ func newRoute(sourceVertex route.Vertex, supports := func(feature lnwire.FeatureBit) bool { // If this edge comes from router hints, the features // could be nil. - if edge.Node.Features == nil { + if edge.ToNodeFeatures == nil { return false } - return edge.Node.Features.HasFeature(feature) + return edge.ToNodeFeatures.HasFeature(feature) } // We start by assuming the node doesn't support TLV. We'll now @@ -225,7 +225,7 @@ func newRoute(sourceVertex route.Vertex, // each new hop such that, the final slice of hops will be in // the forwards order. currentHop := &route.Hop{ - PubKeyBytes: edge.Node.PubKeyBytes, + PubKeyBytes: edge.ToNodePubKey(), ChannelID: edge.ChannelID, AmtToForward: amtToForward, OutgoingTimeLock: outgoingTimeLock, @@ -280,7 +280,7 @@ type graphParams struct { // additionalEdges is an optional set of edges that should be // considered during path finding, that is not already found in the // channel graph. - additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy + additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy // bandwidthHints is an optional map from channels to bandwidths that // can be populated if the caller has a better estimate of the current @@ -360,7 +360,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, var max, total lnwire.MilliSatoshi cb := func(channel *channeldb.DirectedChannel) error { - if channel.OutPolicy == nil { + if !channel.OutPolicySet { return nil } @@ -412,7 +412,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // available bandwidth. func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, - finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { + finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to @@ -519,7 +519,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Build reverse lookup to find incoming edges. Needed because // search is taken place from target to source. for _, outgoingEdgePolicy := range outgoingEdgePolicies { - toVertex := outgoingEdgePolicy.Node.PubKeyBytes + toVertex := outgoingEdgePolicy.ToNodePubKey() incomingEdgePolicy := &edgePolicyWithSource{ sourceNode: vertex, edge: outgoingEdgePolicy, @@ -583,7 +583,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // satisfy our specific requirements. processEdge := func(fromVertex route.Vertex, fromFeatures *lnwire.FeatureVector, - edge *channeldb.ChannelEdgePolicy, toNodeDist *nodeWithDist) { + edge *channeldb.CachedEdgePolicy, toNodeDist *nodeWithDist) { edgesExpanded++ @@ -879,7 +879,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Use the distance map to unravel the forward path from source to // target. - var pathEdges []*channeldb.ChannelEdgePolicy + var pathEdges []*channeldb.CachedEdgePolicy currentNode := source for { // Determine the next hop forward using the next map. @@ -894,7 +894,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, pathEdges = append(pathEdges, currentNodeWithDist.nextHop) // Advance current node. - currentNode = currentNodeWithDist.nextHop.Node.PubKeyBytes + currentNode = currentNodeWithDist.nextHop.ToNodePubKey() // Check stop condition at the end of this loop. This prevents // breaking out too soon for self-payments that have target set @@ -915,7 +915,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // route construction does not care where the features are actually // taken from. In the future we may wish to do route construction within // findPath, and avoid using ChannelEdgePolicy altogether. - pathEdges[len(pathEdges)-1].Node.Features = features + pathEdges[len(pathEdges)-1].ToNodeFeatures = features log.Debugf("Found route: probability=%v, hops=%v, fee=%v", distance[source].probability, len(pathEdges), diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index b353c24ea..7c7c7586b 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -1099,20 +1099,23 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) { // Create the channel edge going from songoku to doge and include it in // our map of additional edges. - songokuToDoge := &channeldb.ChannelEdgePolicy{ - Node: doge, + songokuToDoge := &channeldb.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return doge.PubKeyBytes + }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), ChannelID: 1337, FeeBaseMSat: 1, FeeProportionalMillionths: 1000, TimeLockDelta: 9, } - additionalEdges := map[route.Vertex][]*channeldb.ChannelEdgePolicy{ + additionalEdges := map[route.Vertex][]*channeldb.CachedEdgePolicy{ graph.aliasMap["songoku"]: {songokuToDoge}, } find := func(r *RestrictParams) ( - []*channeldb.ChannelEdgePolicy, error) { + []*channeldb.CachedEdgePolicy, error) { return dbFindPath( graph.graph, additionalEdges, nil, @@ -1179,14 +1182,13 @@ func TestNewRoute(t *testing.T) { createHop := func(baseFee lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi, bandwidth lnwire.MilliSatoshi, - timeLockDelta uint16) *channeldb.ChannelEdgePolicy { + timeLockDelta uint16) *channeldb.CachedEdgePolicy { - return &channeldb.ChannelEdgePolicy{ - Node: &channeldb.LightningNode{ - Features: lnwire.NewFeatureVector( - nil, nil, - ), + return &channeldb.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return route.Vertex{} }, + ToNodeFeatures: lnwire.NewFeatureVector(nil, nil), FeeProportionalMillionths: feeRate, FeeBaseMSat: baseFee, TimeLockDelta: timeLockDelta, @@ -1199,7 +1201,7 @@ func TestNewRoute(t *testing.T) { // hops is the list of hops (the route) that gets passed into // the call to newRoute. - hops []*channeldb.ChannelEdgePolicy + hops []*channeldb.CachedEdgePolicy // paymentAmount is the amount that is send into the route // indicated by hops. @@ -1248,7 +1250,7 @@ func TestNewRoute(t *testing.T) { // For a single hop payment, no fees are expected to be paid. name: "single hop", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(100, 1000, 1000000, 10), }, expectedFees: []lnwire.MilliSatoshi{0}, @@ -1261,7 +1263,7 @@ func TestNewRoute(t *testing.T) { // a fee to receive the payment. name: "two hop", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 1000, 1000000, 10), createHop(30, 1000, 1000000, 5), }, @@ -1276,7 +1278,7 @@ func TestNewRoute(t *testing.T) { name: "two hop tlv onion feature", destFeatures: tlvFeatures, paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 1000, 1000000, 10), createHop(30, 1000, 1000000, 5), }, @@ -1293,7 +1295,7 @@ func TestNewRoute(t *testing.T) { destFeatures: tlvPayAddrFeatures, paymentAddr: &testPaymentAddr, paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 1000, 1000000, 10), createHop(30, 1000, 1000000, 5), }, @@ -1313,7 +1315,7 @@ func TestNewRoute(t *testing.T) { // gets rounded down to 1. name: "three hop", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 10, 1000000, 10), createHop(0, 10, 1000000, 5), createHop(0, 10, 1000000, 3), @@ -1328,7 +1330,7 @@ func TestNewRoute(t *testing.T) { // because of the increase amount to forward. name: "three hop with fee carry over", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 10000, 1000000, 10), createHop(0, 10000, 1000000, 5), createHop(0, 10000, 1000000, 3), @@ -1343,7 +1345,7 @@ func TestNewRoute(t *testing.T) { // effect. name: "three hop with minimal fees for carry over", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 10000, 1000000, 10), // First hop charges 0.1% so the second hop fee @@ -1367,7 +1369,7 @@ func TestNewRoute(t *testing.T) { // custom feature vector. if testCase.destFeatures != nil { finalHop := testCase.hops[len(testCase.hops)-1] - finalHop.Node.Features = testCase.destFeatures + finalHop.ToNodeFeatures = testCase.destFeatures } assertRoute := func(t *testing.T, route *route.Route) { @@ -1594,7 +1596,7 @@ func TestDestTLVGraphFallback(t *testing.T) { } find := func(r *RestrictParams, - target route.Vertex) ([]*channeldb.ChannelEdgePolicy, error) { + target route.Vertex) ([]*channeldb.CachedEdgePolicy, error) { return dbFindPath( ctx.graph, nil, nil, @@ -2325,16 +2327,16 @@ func TestPathFindSpecExample(t *testing.T) { } func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, - path []*channeldb.ChannelEdgePolicy, nodeAliases ...string) { + path []*channeldb.CachedEdgePolicy, nodeAliases ...string) { if len(path) != len(nodeAliases) { t.Fatal("number of hops and number of aliases do not match") } for i, hop := range path { - if hop.Node.PubKeyBytes != aliasMap[nodeAliases[i]] { + if hop.ToNodePubKey() != aliasMap[nodeAliases[i]] { t.Fatalf("expected %v to be pos #%v in hop, instead "+ - "%v was", nodeAliases[i], i, hop.Node.Alias) + "%v was", nodeAliases[i], i, hop.ToNodePubKey()) } } } @@ -2985,7 +2987,7 @@ func (c *pathFindingTestContext) cleanup() { } func (c *pathFindingTestContext) findPath(target route.Vertex, - amt lnwire.MilliSatoshi) ([]*channeldb.ChannelEdgePolicy, + amt lnwire.MilliSatoshi) ([]*channeldb.CachedEdgePolicy, error) { return dbFindPath( @@ -2994,7 +2996,9 @@ func (c *pathFindingTestContext) findPath(target route.Vertex, ) } -func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, expected []uint64) { +func (c *pathFindingTestContext) assertPath(path []*channeldb.CachedEdgePolicy, + expected []uint64) { + if len(path) != len(expected) { c.t.Fatalf("expected path of length %v, but got %v", len(expected), len(path)) @@ -3011,11 +3015,11 @@ func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, // dbFindPath calls findPath after getting a db transaction from the database // graph. func dbFindPath(graph *channeldb.ChannelGraph, - additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy, + additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy, bandwidthHints map[uint64]lnwire.MilliSatoshi, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, - finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { + finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { routingTx, err := newDbRoutingTx(graph) if err != nil { diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index aa856e7b2..945a53466 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -898,7 +898,7 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route, var ( isAdditionalEdge bool - policy *channeldb.ChannelEdgePolicy + policy *channeldb.CachedEdgePolicy ) // Before we apply the channel update, we need to decide whether the diff --git a/routing/payment_session.go b/routing/payment_session.go index 22e88090b..d3024d3ff 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -144,13 +144,13 @@ type PaymentSession interface { // a boolean to indicate whether the update has been applied without // error. UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, pubKey *btcec.PublicKey, - policy *channeldb.ChannelEdgePolicy) bool + policy *channeldb.CachedEdgePolicy) bool // GetAdditionalEdgePolicy uses the public key and channel ID to query // the ephemeral channel edge policy for additional edges. Returns a nil // if nothing found. GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, - channelID uint64) *channeldb.ChannelEdgePolicy + channelID uint64) *channeldb.CachedEdgePolicy } // paymentSession is used during an HTLC routings session to prune the local @@ -162,7 +162,7 @@ type PaymentSession interface { // loop if payment attempts take long enough. An additional set of edges can // also be provided to assist in reaching the payment's destination. type paymentSession struct { - additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy + additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error) @@ -403,7 +403,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // updates to the supplied policy. It returns a boolean to indicate whether // there's an error when applying the updates. func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, - pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool { + pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool { // Validate the message signature. if err := VerifyChannelUpdateSignature(msg, pubKey); err != nil { @@ -428,7 +428,7 @@ func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, // ephemeral channel edge policy for additional edges. Returns a nil if nothing // found. func (p *paymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, - channelID uint64) *channeldb.ChannelEdgePolicy { + channelID uint64) *channeldb.CachedEdgePolicy { target := route.NewVertex(pubKey) diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 661d5861d..fdfccd5f1 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -93,9 +93,9 @@ func (m *SessionSource) NewPaymentSessionEmpty() PaymentSession { // RouteHintsToEdges converts a list of invoice route hints to an edge map that // can be passed into pathfinding. func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( - map[route.Vertex][]*channeldb.ChannelEdgePolicy, error) { + map[route.Vertex][]*channeldb.CachedEdgePolicy, error) { - edges := make(map[route.Vertex][]*channeldb.ChannelEdgePolicy) + edges := make(map[route.Vertex][]*channeldb.CachedEdgePolicy) // Traverse through all of the available hop hints and include them in // our edges map, indexed by the public key of the channel's starting @@ -125,9 +125,12 @@ func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( // Finally, create the channel edge from the hop hint // and add it to list of edges corresponding to the node // at the start of the channel. - edge := &channeldb.ChannelEdgePolicy{ - Node: endNode, - ChannelID: hopHint.ChannelID, + edge := &channeldb.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return endNode.PubKeyBytes + }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), + ChannelID: hopHint.ChannelID, FeeBaseMSat: lnwire.MilliSatoshi( hopHint.FeeBaseMSat, ), diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index edc4515b5..bcfc3b0e9 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -217,7 +217,7 @@ func TestRequestRoute(t *testing.T) { session.pathFinder = func( g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, - finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { + finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { // We expect find path to receive a cltv limit excluding the // final cltv delta (including the block padding). @@ -225,13 +225,14 @@ func TestRequestRoute(t *testing.T) { t.Fatal("wrong cltv limit") } - path := []*channeldb.ChannelEdgePolicy{ + path := []*channeldb.CachedEdgePolicy{ { - Node: &channeldb.LightningNode{ - Features: lnwire.NewFeatureVector( - nil, nil, - ), + ToNodePubKey: func() route.Vertex { + return route.Vertex{} }, + ToNodeFeatures: lnwire.NewFeatureVector( + nil, nil, + ), }, } diff --git a/routing/router.go b/routing/router.go index aa034eea0..1de113056 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1727,7 +1727,7 @@ type routingMsg struct { func (r *ChannelRouter) FindRoute(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *RestrictParams, destCustomRecords record.CustomSet, - routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, + routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, finalExpiry uint16) (*route.Route, error) { log.Debugf("Searching for path to %v, sending %v", target, amt) @@ -2822,7 +2822,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // total amount, we make a forward pass. Because the amount may have // been increased in the backward pass, fees need to be recalculated and // amount ranges re-checked. - var pathEdges []*channeldb.ChannelEdgePolicy + var pathEdges []*channeldb.CachedEdgePolicy receiverAmt := runningAmt for i, edge := range edges { policy := edge.getPolicy(receiverAmt, bandwidthHints) diff --git a/routing/router_test.go b/routing/router_test.go index d263ce738..ed6bfdc6a 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -2478,8 +2478,8 @@ func TestFindPathFeeWeighting(t *testing.T) { if len(path) != 1 { t.Fatalf("expected path length of 1, instead was: %v", len(path)) } - if path[0].Node.Alias != "luoji" { - t.Fatalf("wrong node: %v", path[0].Node.Alias) + if path[0].ToNodePubKey() != ctx.aliases["luoji"] { + t.Fatalf("wrong node: %v", path[0].ToNodePubKey()) } } diff --git a/routing/unified_policies.go b/routing/unified_policies.go index 4a6e5e00b..fe7cc1ec4 100644 --- a/routing/unified_policies.go +++ b/routing/unified_policies.go @@ -40,7 +40,7 @@ func newUnifiedPolicies(sourceNode, toNode route.Vertex, // addPolicy adds a single channel policy. Capacity may be zero if unknown // (light clients). func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, - edge *channeldb.ChannelEdgePolicy, capacity btcutil.Amount) { + edge *channeldb.CachedEdgePolicy, capacity btcutil.Amount) { localChan := fromNode == u.sourceNode @@ -92,7 +92,7 @@ func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error { // unifiedPolicyEdge is the individual channel data that is kept inside an // unifiedPolicy object. type unifiedPolicyEdge struct { - policy *channeldb.ChannelEdgePolicy + policy *channeldb.CachedEdgePolicy capacity btcutil.Amount } @@ -133,7 +133,7 @@ type unifiedPolicy struct { // specific amount to send. It differentiates between local and network // channels. func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, - bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { if u.localChan { return u.getPolicyLocal(amt, bandwidthHints) @@ -145,10 +145,10 @@ func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, // getPolicyLocal returns the optimal policy to use for this local connection // given a specific amount to send. func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, - bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { var ( - bestPolicy *channeldb.ChannelEdgePolicy + bestPolicy *channeldb.CachedEdgePolicy maxBandwidth lnwire.MilliSatoshi ) @@ -200,10 +200,10 @@ func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, // a specific amount to send. The goal is to return a policy that maximizes the // probability of a successful forward in a non-strict forwarding context. func (u *unifiedPolicy) getPolicyNetwork( - amt lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + amt lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { var ( - bestPolicy *channeldb.ChannelEdgePolicy + bestPolicy *channeldb.CachedEdgePolicy maxFee lnwire.MilliSatoshi maxTimelock uint16 ) diff --git a/routing/unified_policies_test.go b/routing/unified_policies_test.go index e89a3cb12..ac915f99a 100644 --- a/routing/unified_policies_test.go +++ b/routing/unified_policies_test.go @@ -20,7 +20,7 @@ func TestUnifiedPolicies(t *testing.T) { u := newUnifiedPolicies(source, toNode, nil) // Add two channels between the pair of nodes. - p1 := channeldb.ChannelEdgePolicy{ + p1 := channeldb.CachedEdgePolicy{ FeeProportionalMillionths: 100000, FeeBaseMSat: 30, TimeLockDelta: 60, @@ -28,7 +28,7 @@ func TestUnifiedPolicies(t *testing.T) { MaxHTLC: 500, MinHTLC: 100, } - p2 := channeldb.ChannelEdgePolicy{ + p2 := channeldb.CachedEdgePolicy{ FeeProportionalMillionths: 190000, FeeBaseMSat: 10, TimeLockDelta: 40, @@ -39,7 +39,7 @@ func TestUnifiedPolicies(t *testing.T) { u.addPolicy(fromNode, &p1, 7) u.addPolicy(fromNode, &p2, 7) - checkPolicy := func(policy *channeldb.ChannelEdgePolicy, + checkPolicy := func(policy *channeldb.CachedEdgePolicy, feeBase lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi, timeLockDelta uint16) {