From 5a903c270f3c1e5e4a049a0e407c744ea796dd39 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 14 Jun 2024 18:47:15 -0400 Subject: [PATCH] routing: remove sourceNode from routingGraph interface In this commit, we further reduce the routingGraph interface and this time we make it more node-agnostic so that it can be backed by any graph and not one with a concept of "sourceNode". --- routing/graph.go | 12 +----------- routing/integrated_routing_context_test.go | 2 +- routing/pathfind.go | 9 ++++----- routing/pathfind_test.go | 3 ++- routing/payment_session.go | 11 ++++++----- routing/payment_session_source.go | 4 ++-- routing/payment_session_test.go | 10 +++++----- routing/router.go | 5 +++-- 8 files changed, 24 insertions(+), 32 deletions(-) diff --git a/routing/graph.go b/routing/graph.go index dafadb892..1f0abf9c0 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -18,9 +18,6 @@ type routingGraph interface { forEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error - // sourceNode returns the source node of the graph. - sourceNode() route.Vertex - // fetchNodeFeatures returns the features of the given node. fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) } @@ -73,13 +70,6 @@ func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex, return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) } -// sourceNode returns the source node of the graph. -// -// NOTE: Part of the routingGraph interface. -func (g *CachedGraph) sourceNode() route.Vertex { - return g.source -} - // fetchNodeFeatures returns the features of the given node. If the node is // unknown, assume no additional features are supported. // @@ -99,7 +89,7 @@ func (g *CachedGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, // // Note: Inbound fees are not used here because this method is only used // by a deprecated router rpc. - u := newNodeEdgeUnifier(g.sourceNode(), nodeTo, false, nil) + u := newNodeEdgeUnifier(g.source, nodeTo, false, nil) err := u.addGraphPolicies(g) if err != nil { diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 4215d3b25..95a5eaf65 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -200,7 +200,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, } session, err := newPaymentSession( - &payment, getBandwidthHints, + &payment, c.graph.source.pubkey, getBandwidthHints, func() (routingGraph, func(), error) { return c.graph, func() {}, nil }, diff --git a/routing/pathfind.go b/routing/pathfind.go index 208a55085..d7d2893b0 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -48,7 +48,7 @@ const ( // pathFinder defines the interface of a path finding algorithm. type pathFinder = func(g *graphParams, r *RestrictParams, - cfg *PathFindingConfig, source, target route.Vertex, + cfg *PathFindingConfig, self, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( []*unifiedEdge, float64, error) @@ -521,8 +521,9 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // path and accurately check the amount to forward at every node against the // available bandwidth. func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, - source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, - finalHtlcExpiry int32) ([]*unifiedEdge, float64, error) { + self, source, target route.Vertex, amt lnwire.MilliSatoshi, + timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64, + error) { // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to @@ -583,8 +584,6 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // If we are routing from ourselves, check that we have enough local // balance available. - self := g.graph.sourceNode() - if source == self { max, total, err := getOutgoingBalance( self, outgoingChanMap, g.bandwidthHints, g.graph, diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 0f2a2659b..a35c9e2f7 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3218,7 +3218,8 @@ func dbFindPath(graph *channeldb.ChannelGraph, bandwidthHints: bandwidthHints, graph: routingGraph, }, - r, cfg, source, target, amt, timePref, finalHtlcExpiry, + r, cfg, sourceNode.PubKeyBytes, source, target, amt, timePref, + finalHtlcExpiry, ) return route, err diff --git a/routing/payment_session.go b/routing/payment_session.go index 2d174244c..bdd194812 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -163,6 +163,8 @@ type PaymentSession interface { // loop if payment attempts take long enough. An additional set of edges can // also be provided to assist in reaching the payment's destination. type paymentSession struct { + selfNode route.Vertex + additionalEdges map[route.Vertex][]AdditionalEdge getBandwidthHints func(routingGraph) (bandwidthHints, error) @@ -192,7 +194,7 @@ type paymentSession struct { } // newPaymentSession instantiates a new payment session. -func newPaymentSession(p *LightningPayment, +func newPaymentSession(p *LightningPayment, selfNode route.Vertex, getBandwidthHints func(routingGraph) (bandwidthHints, error), getRoutingGraph func() (routingGraph, func(), error), missionControl MissionController, pathFindingConfig PathFindingConfig) ( @@ -206,6 +208,7 @@ func newPaymentSession(p *LightningPayment, logPrefix := fmt.Sprintf("PaymentSession(%x):", p.Identifier()) return &paymentSession{ + selfNode: selfNode, additionalEdges: edges, getBandwidthHints: getBandwidthHints, payment: p, @@ -296,8 +299,6 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, p.log.Debugf("pathfinding for amt=%v", maxAmt) - sourceVertex := routingGraph.sourceNode() - // Find a route for the current amount. path, _, err := p.pathFinder( &graphParams{ @@ -306,7 +307,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, graph: routingGraph, }, restrictions, &p.pathFindingConfig, - sourceVertex, p.payment.Target, + p.selfNode, p.selfNode, p.payment.Target, maxAmt, p.payment.TimePref, finalHtlcExpiry, ) @@ -384,7 +385,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // this into a route by applying the time-lock and fee // requirements. route, err := newRoute( - sourceVertex, path, height, + p.selfNode, path, height, finalHopParams{ amt: maxAmt, totalAmt: p.payment.Amount, diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index b96a2294b..ba010391b 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -73,8 +73,8 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( } session, err := newPaymentSession( - p, getBandwidthHints, m.getRoutingGraph, - m.MissionControl, m.PathFindingConfig, + p, m.SourceNode.PubKeyBytes, getBandwidthHints, + m.getRoutingGraph, m.MissionControl, m.PathFindingConfig, ) if err != nil { return nil, err diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 75b84a51a..b7efed5b7 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -115,7 +115,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { // Create the paymentsession. session, err := newPaymentSession( - payment, + payment, route.Vertex{}, func(routingGraph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, @@ -195,7 +195,7 @@ func TestRequestRoute(t *testing.T) { } session, err := newPaymentSession( - payment, + payment, route.Vertex{}, func(routingGraph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, @@ -211,9 +211,9 @@ func TestRequestRoute(t *testing.T) { // Override pathfinder with a mock. session.pathFinder = func(_ *graphParams, r *RestrictParams, - _ *PathFindingConfig, _, _ route.Vertex, _ lnwire.MilliSatoshi, - _ float64, _ int32) ([]*unifiedEdge, float64, - error) { + _ *PathFindingConfig, _, _, _ route.Vertex, + _ lnwire.MilliSatoshi, _ float64, _ int32) ([]*unifiedEdge, + float64, error) { // We expect find path to receive a cltv limit excluding the // final cltv delta (including the block padding). diff --git a/routing/router.go b/routing/router.go index 149cd3415..597705754 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2148,8 +2148,9 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, bandwidthHints: bandwidthHints, graph: r.cachedGraph, }, - req.Restrictions, &r.cfg.PathFindingConfig, req.Source, - req.Target, req.Amount, req.TimePreference, finalHtlcExpiry, + req.Restrictions, &r.cfg.PathFindingConfig, + r.selfNode.PubKeyBytes, req.Source, req.Target, req.Amount, + req.TimePreference, finalHtlcExpiry, ) if err != nil { return nil, 0, err