routing: refactor unified policies to edges

This commit refactors the semantics of unified policies to unified
edges. The main changes are the following renamings:

* unifiedPolicies -> nodeEdgeUnifier
* unifiedPolicy -> edgeUnifier
* unifiedPolicyEdge -> unifiedEdge

Comments and shortened variable names are changed to reflect the new
semantics.
This commit is contained in:
bitromortac 2022-11-09 18:05:36 +01:00
parent 7d29ab905c
commit 76e711ead0
No known key found for this signature in database
GPG Key ID: 1965063FC13BEBE2
4 changed files with 87 additions and 91 deletions

View File

@ -628,7 +628,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// satisfy our specific requirements. // satisfy our specific requirements.
processEdge := func(fromVertex route.Vertex, processEdge := func(fromVertex route.Vertex,
fromFeatures *lnwire.FeatureVector, fromFeatures *lnwire.FeatureVector,
edge *unifiedPolicyEdge, toNodeDist *nodeWithDist) { edge *unifiedEdge, toNodeDist *nodeWithDist) {
edgesExpanded++ edgesExpanded++
@ -849,8 +849,8 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
pivot := partialPath.node pivot := partialPath.node
// Create unified policies for all incoming connections. // Create unified edges for all incoming connections.
u := newUnifiedPolicies(self, pivot, outgoingChanMap) u := newNodeEdgeUnifier(self, pivot, outgoingChanMap)
err := u.addGraphPolicies(g.graph) err := u.addGraphPolicies(g.graph)
if err != nil { if err != nil {
@ -865,7 +865,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Expand all connections using the optimal policy for each // Expand all connections using the optimal policy for each
// connection. // connection.
for fromNode, unifiedPolicy := range u.policies { for fromNode, edgeUnifier := range u.edgeUnifiers {
// The target node is not recorded in the distance map. // The target node is not recorded in the distance map.
// Therefore we need to have this check to prevent // Therefore we need to have this check to prevent
// creating a cycle. Only when we intend to route to // creating a cycle. Only when we intend to route to
@ -882,11 +882,11 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
continue continue
} }
policy := unifiedPolicy.getPolicy( edge := edgeUnifier.getEdge(
amtToSend, g.bandwidthHints, amtToSend, g.bandwidthHints,
) )
if policy == nil { if edge == nil {
continue continue
} }
@ -903,7 +903,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Check if this candidate node is better than what we // Check if this candidate node is better than what we
// already have. // already have.
processEdge(fromNode, fromFeatures, policy, partialPath) processEdge(fromNode, fromFeatures, edge, partialPath)
} }
if nodeHeap.Len() == 0 { if nodeHeap.Len() == 0 {

View File

@ -2765,9 +2765,8 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
return nil, err return nil, err
} }
// Allocate a list that will contain the unified policies for this // Allocate a list that will contain the edge unifiers for this route.
// route. unifiers := make([]*edgeUnifier, len(hops))
edges := make([]*unifiedPolicy, len(hops))
var runningAmt lnwire.MilliSatoshi var runningAmt lnwire.MilliSatoshi
if useMinAmt { if useMinAmt {
@ -2796,9 +2795,9 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
localChan := i == 0 localChan := i == 0
// Build unified policies for this hop based on the channels // Build unified edges for this hop based on the channels known
// known in the graph. // in the graph.
u := newUnifiedPolicies(source, toNode, outgoingChans) u := newNodeEdgeUnifier(source, toNode, outgoingChans)
err := u.addGraphPolicies(r.cachedGraph) err := u.addGraphPolicies(r.cachedGraph)
if err != nil { if err != nil {
@ -2806,7 +2805,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
} }
// Exit if there are no channels. // Exit if there are no channels.
unifiedPolicy, ok := u.policies[fromNode] edgeUnifier, ok := u.edgeUnifiers[fromNode]
if !ok { if !ok {
log.Errorf("Cannot find policy for node %v", fromNode) log.Errorf("Cannot find policy for node %v", fromNode)
return nil, ErrNoChannel{ return nil, ErrNoChannel{
@ -2817,18 +2816,18 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// If using min amt, increase amt if needed. // If using min amt, increase amt if needed.
if useMinAmt { if useMinAmt {
min := unifiedPolicy.minAmt() min := edgeUnifier.minAmt()
if min > runningAmt { if min > runningAmt {
runningAmt = min runningAmt = min
} }
} }
// Get a forwarding policy for the specific amount that we want // Get an edge for the specific amount that we want to forward.
// to forward. edge := edgeUnifier.getEdge(runningAmt, bandwidthHints)
policy := unifiedPolicy.getPolicy(runningAmt, bandwidthHints) if edge == nil {
if policy == nil {
log.Errorf("Cannot find policy with amt=%v for node %v", log.Errorf("Cannot find policy with amt=%v for node %v",
runningAmt, fromNode) runningAmt, fromNode)
return nil, ErrNoChannel{ return nil, ErrNoChannel{
fromNode: fromNode, fromNode: fromNode,
position: i, position: i,
@ -2837,13 +2836,13 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// Add fee for this hop. // Add fee for this hop.
if !localChan { if !localChan {
runningAmt += policy.policy.ComputeFee(runningAmt) runningAmt += edge.policy.ComputeFee(runningAmt)
} }
log.Tracef("Select channel %v at position %v", log.Tracef("Select channel %v at position %v",
policy.policy.ChannelID, i) edge.policy.ChannelID, i)
edges[i] = unifiedPolicy unifiers[i] = edgeUnifier
} }
// Now that we arrived at the start of the route and found out the route // Now that we arrived at the start of the route and found out the route
@ -2852,9 +2851,9 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// amount ranges re-checked. // amount ranges re-checked.
var pathEdges []*channeldb.CachedEdgePolicy var pathEdges []*channeldb.CachedEdgePolicy
receiverAmt := runningAmt receiverAmt := runningAmt
for i, edge := range edges { for i, unifier := range unifiers {
policy := edge.getPolicy(receiverAmt, bandwidthHints) edge := unifier.getEdge(receiverAmt, bandwidthHints)
if policy == nil { if edge == nil {
return nil, ErrNoChannel{ return nil, ErrNoChannel{
fromNode: hops[i-1], fromNode: hops[i-1],
position: i, position: i,
@ -2863,12 +2862,12 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
if i > 0 { if i > 0 {
// Decrease the amount to send while going forward. // Decrease the amount to send while going forward.
receiverAmt -= policy.policy.ComputeFeeFromIncoming( receiverAmt -= edge.policy.ComputeFeeFromIncoming(
receiverAmt, receiverAmt,
) )
} }
pathEdges = append(pathEdges, policy.policy) pathEdges = append(pathEdges, edge.policy)
} }
// Build and return the final route. // Build and return the final route.

View File

@ -7,16 +7,16 @@ import (
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
) )
// unifiedPolicies holds all unified policies for connections towards a node. // nodeEdgeUnifier holds all edge unifiers for connections towards a node.
type unifiedPolicies struct { type nodeEdgeUnifier struct {
// policies contains a unified policy for every from node. // edgeUnifiers contains an edge unifier for every from node.
policies map[route.Vertex]*unifiedPolicy edgeUnifiers map[route.Vertex]*edgeUnifier
// sourceNode is the sender of a payment. The rules to pick the final // sourceNode is the sender of a payment. The rules to pick the final
// policy are different for local channels. // policy are different for local channels.
sourceNode route.Vertex sourceNode route.Vertex
// toNode is the node for which the unified policies are instantiated. // toNode is the node for which the edge unifiers are instantiated.
toNode route.Vertex toNode route.Vertex
// outChanRestr is an optional outgoing channel restriction for the // outChanRestr is an optional outgoing channel restriction for the
@ -24,13 +24,13 @@ type unifiedPolicies struct {
outChanRestr map[uint64]struct{} outChanRestr map[uint64]struct{}
} }
// newUnifiedPolicies instantiates a new unifiedPolicies object. Channel // newNodeEdgeUnifier instantiates a new nodeEdgeUnifier object. Channel
// policies can be added to this object. // policies can be added to this object.
func newUnifiedPolicies(sourceNode, toNode route.Vertex, func newNodeEdgeUnifier(sourceNode, toNode route.Vertex,
outChanRestr map[uint64]struct{}) *unifiedPolicies { outChanRestr map[uint64]struct{}) *nodeEdgeUnifier {
return &unifiedPolicies{ return &nodeEdgeUnifier{
policies: make(map[route.Vertex]*unifiedPolicy), edgeUnifiers: make(map[route.Vertex]*edgeUnifier),
toNode: toNode, toNode: toNode,
sourceNode: sourceNode, sourceNode: sourceNode,
outChanRestr: outChanRestr, outChanRestr: outChanRestr,
@ -39,7 +39,7 @@ func newUnifiedPolicies(sourceNode, toNode route.Vertex,
// addPolicy adds a single channel policy. Capacity may be zero if unknown // addPolicy adds a single channel policy. Capacity may be zero if unknown
// (light clients). // (light clients).
func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex,
edge *channeldb.CachedEdgePolicy, capacity btcutil.Amount) { edge *channeldb.CachedEdgePolicy, capacity btcutil.Amount) {
localChan := fromNode == u.sourceNode localChan := fromNode == u.sourceNode
@ -51,16 +51,16 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex,
} }
} }
// Update the policies map. // Update the edgeUnifiers map.
policy, ok := u.policies[fromNode] unifier, ok := u.edgeUnifiers[fromNode]
if !ok { if !ok {
policy = &unifiedPolicy{ unifier = &edgeUnifier{
localChan: localChan, localChan: localChan,
} }
u.policies[fromNode] = policy u.edgeUnifiers[fromNode] = unifier
} }
policy.edges = append(policy.edges, &unifiedPolicyEdge{ unifier.edges = append(unifier.edges, &unifiedEdge{
policy: edge, policy: edge,
capacity: capacity, capacity: capacity,
}) })
@ -68,7 +68,7 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex,
// addGraphPolicies adds all policies that are known for the toNode in the // addGraphPolicies adds all policies that are known for the toNode in the
// graph. // graph.
func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error { func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error {
cb := func(channel *channeldb.DirectedChannel) error { cb := func(channel *channeldb.DirectedChannel) error {
// If there is no edge policy for this candidate node, skip. // If there is no edge policy for this candidate node, skip.
// Note that we are searching backwards so this node would have // Note that we are searching backwards so this node would have
@ -77,7 +77,7 @@ func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error {
return nil return nil
} }
// Add this policy to the unified policies map. // Add this policy to the corresponding edgeUnifier.
u.addPolicy( u.addPolicy(
channel.OtherNode, channel.InPolicy, channel.Capacity, channel.OtherNode, channel.InPolicy, channel.Capacity,
) )
@ -89,16 +89,16 @@ func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error {
return g.forEachNodeChannel(u.toNode, cb) return g.forEachNodeChannel(u.toNode, cb)
} }
// unifiedPolicyEdge is the individual channel data that is kept inside an // unifiedEdge is the individual channel data that is kept inside an edgeUnifier
// unifiedPolicy object. // object.
type unifiedPolicyEdge struct { type unifiedEdge struct {
policy *channeldb.CachedEdgePolicy policy *channeldb.CachedEdgePolicy
capacity btcutil.Amount capacity btcutil.Amount
} }
// amtInRange checks whether an amount falls within the valid range for a // amtInRange checks whether an amount falls within the valid range for a
// channel. // channel.
func (u *unifiedPolicyEdge) amtInRange(amt lnwire.MilliSatoshi) bool { func (u *unifiedEdge) amtInRange(amt lnwire.MilliSatoshi) bool {
// If the capacity is available (non-light clients), skip channels that // If the capacity is available (non-light clients), skip channels that
// are too small. // are too small.
if u.capacity > 0 && if u.capacity > 0 &&
@ -122,33 +122,32 @@ func (u *unifiedPolicyEdge) amtInRange(amt lnwire.MilliSatoshi) bool {
return true return true
} }
// unifiedPolicy is the unified policy that covers all channels between a pair // edgeUnifier is an object that covers all channels between a pair of nodes.
// of nodes. type edgeUnifier struct {
type unifiedPolicy struct { edges []*unifiedEdge
edges []*unifiedPolicyEdge
localChan bool localChan bool
} }
// getPolicy returns the optimal policy to use for this connection given a // getEdge returns the optimal unified edge to use for this connection given a
// specific amount to send. It differentiates between local and network // specific amount to send. It differentiates between local and network
// channels. // channels.
func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, func (u *edgeUnifier) getEdge(amt lnwire.MilliSatoshi,
bandwidthHints bandwidthHints) *unifiedPolicyEdge { bandwidthHints bandwidthHints) *unifiedEdge {
if u.localChan { if u.localChan {
return u.getPolicyLocal(amt, bandwidthHints) return u.getEdgeLocal(amt, bandwidthHints)
} }
return u.getPolicyNetwork(amt) return u.getEdgeNetwork(amt)
} }
// getPolicyLocal returns the optimal policy to use for this local connection // getEdgeLocal returns the optimal unified edge to use for this local
// given a specific amount to send. // connection given a specific amount to send.
func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, func (u *edgeUnifier) getEdgeLocal(amt lnwire.MilliSatoshi,
bandwidthHints bandwidthHints) *unifiedPolicyEdge { bandwidthHints bandwidthHints) *unifiedEdge {
var ( var (
bestPolicy *unifiedPolicyEdge bestEdge *unifiedEdge
maxBandwidth lnwire.MilliSatoshi maxBandwidth lnwire.MilliSatoshi
) )
@ -191,19 +190,18 @@ func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi,
} }
maxBandwidth = bandwidth maxBandwidth = bandwidth
// Update best policy. // Update best edge.
bestPolicy = &unifiedPolicyEdge{policy: edge.policy} bestEdge = &unifiedEdge{policy: edge.policy}
} }
return bestPolicy return bestEdge
} }
// getPolicyNetwork returns the optimal policy to use for this connection given // getEdgeNetwork returns the optimal unified edge to use for this connection
// a specific amount to send. The goal is to return a policy that maximizes the // given a specific amount to send. The goal is to return a unified edge with a
// probability of a successful forward in a non-strict forwarding context. // policy that maximizes the probability of a successful forward in a non-strict
func (u *unifiedPolicy) getPolicyNetwork( // forwarding context.
amt lnwire.MilliSatoshi) *unifiedPolicyEdge { func (u *edgeUnifier) getEdgeNetwork(amt lnwire.MilliSatoshi) *unifiedEdge {
var ( var (
bestPolicy *channeldb.CachedEdgePolicy bestPolicy *channeldb.CachedEdgePolicy
maxFee lnwire.MilliSatoshi maxFee lnwire.MilliSatoshi
@ -256,14 +254,14 @@ func (u *unifiedPolicy) getPolicyNetwork(
// chance for this node pair. But this is all only needed for nodes that // chance for this node pair. But this is all only needed for nodes that
// have distinct policies for channels to the same peer. // have distinct policies for channels to the same peer.
policyCopy := *bestPolicy policyCopy := *bestPolicy
modifiedPolicy := unifiedPolicyEdge{policy: &policyCopy} modifiedEdge := unifiedEdge{policy: &policyCopy}
modifiedPolicy.policy.TimeLockDelta = maxTimelock modifiedEdge.policy.TimeLockDelta = maxTimelock
return &modifiedPolicy return &modifiedEdge
} }
// minAmt returns the minimum amount that can be forwarded on this connection. // minAmt returns the minimum amount that can be forwarded on this connection.
func (u *unifiedPolicy) minAmt() lnwire.MilliSatoshi { func (u *edgeUnifier) minAmt() lnwire.MilliSatoshi {
min := lnwire.MaxMilliSatoshi min := lnwire.MaxMilliSatoshi
for _, edge := range u.edges { for _, edge := range u.edges {
if edge.policy.MinHTLC < min { if edge.policy.MinHTLC < min {

View File

@ -8,16 +8,16 @@ import (
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
) )
// TestUnifiedPolicies tests the composition of unified policies for nodes that // TestNodeEdgeUnifier tests the composition of unified edges for nodes that
// have multiple channels between them. // have multiple channels between them.
func TestUnifiedPolicies(t *testing.T) { func TestNodeEdgeUnifier(t *testing.T) {
source := route.Vertex{1} source := route.Vertex{1}
toNode := route.Vertex{2} toNode := route.Vertex{2}
fromNode := route.Vertex{3} fromNode := route.Vertex{3}
bandwidthHints := &mockBandwidthHints{} bandwidthHints := &mockBandwidthHints{}
u := newUnifiedPolicies(source, toNode, nil) u := newNodeEdgeUnifier(source, toNode, nil)
// Add two channels between the pair of nodes. // Add two channels between the pair of nodes.
p1 := channeldb.CachedEdgePolicy{ p1 := channeldb.CachedEdgePolicy{
@ -39,13 +39,12 @@ func TestUnifiedPolicies(t *testing.T) {
u.addPolicy(fromNode, &p1, 7) u.addPolicy(fromNode, &p1, 7)
u.addPolicy(fromNode, &p2, 7) u.addPolicy(fromNode, &p2, 7)
checkPolicy := func(unifiedPolicy *unifiedPolicyEdge, checkPolicy := func(edge *unifiedEdge, feeBase lnwire.MilliSatoshi,
feeBase lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi, timeLockDelta uint16) {
timeLockDelta uint16) {
t.Helper() t.Helper()
policy := unifiedPolicy.policy policy := edge.policy
if policy.FeeBaseMSat != feeBase { if policy.FeeBaseMSat != feeBase {
t.Fatalf("expected fee base %v, got %v", t.Fatalf("expected fee base %v, got %v",
@ -63,31 +62,31 @@ func TestUnifiedPolicies(t *testing.T) {
} }
} }
policy := u.policies[fromNode].getPolicy(50, bandwidthHints) edge := u.edgeUnifiers[fromNode].getEdge(50, bandwidthHints)
if policy != nil { if edge != nil {
t.Fatal("expected no policy for amt below min htlc") t.Fatal("expected no policy for amt below min htlc")
} }
policy = u.policies[fromNode].getPolicy(550, bandwidthHints) edge = u.edgeUnifiers[fromNode].getEdge(550, bandwidthHints)
if policy != nil { if edge != nil {
t.Fatal("expected no policy for amt above max htlc") t.Fatal("expected no policy for amt above max htlc")
} }
// For 200 sat, p1 yields the highest fee. Use that policy to forward, // For 200 sat, p1 yields the highest fee. Use that policy to forward,
// because it will also match p2 in case p1 does not have enough // because it will also match p2 in case p1 does not have enough
// balance. // balance.
policy = u.policies[fromNode].getPolicy(200, bandwidthHints) edge = u.edgeUnifiers[fromNode].getEdge(200, bandwidthHints)
checkPolicy( checkPolicy(
policy, p1.FeeBaseMSat, p1.FeeProportionalMillionths, edge, p1.FeeBaseMSat, p1.FeeProportionalMillionths,
p1.TimeLockDelta, p1.TimeLockDelta,
) )
// For 400 sat, p2 yields the highest fee. Use that policy to forward, // For 400 sat, p2 yields the highest fee. Use that policy to forward,
// because it will also match p1 in case p2 does not have enough // because it will also match p1 in case p2 does not have enough
// balance. In order to match p1, it needs to have p1's time lock delta. // balance. In order to match p1, it needs to have p1's time lock delta.
policy = u.policies[fromNode].getPolicy(400, bandwidthHints) edge = u.edgeUnifiers[fromNode].getEdge(400, bandwidthHints)
checkPolicy( checkPolicy(
policy, p2.FeeBaseMSat, p2.FeeProportionalMillionths, edge, p2.FeeBaseMSat, p2.FeeProportionalMillionths,
p1.TimeLockDelta, p1.TimeLockDelta,
) )
} }