routing: remove sourceNode from routingGraph interface

In this commit, we further reduce the routingGraph interface and this
time we make it more node-agnostic so that it can be backed by any graph
and not one with a concept of "sourceNode".
This commit is contained in:
Elle Mouton
2024-06-14 18:47:15 -04:00
parent 5c18b5a042
commit 5a903c270f
8 changed files with 24 additions and 32 deletions

View File

@@ -18,9 +18,6 @@ type routingGraph interface {
forEachNodeChannel(nodePub route.Vertex, forEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error cb func(channel *channeldb.DirectedChannel) error) error
// sourceNode returns the source node of the graph.
sourceNode() route.Vertex
// fetchNodeFeatures returns the features of the given node. // fetchNodeFeatures returns the features of the given node.
fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
} }
@@ -73,13 +70,6 @@ func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex,
return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb)
} }
// sourceNode returns the source node of the graph.
//
// NOTE: Part of the routingGraph interface.
func (g *CachedGraph) sourceNode() route.Vertex {
return g.source
}
// fetchNodeFeatures returns the features of the given node. If the node is // fetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported. // unknown, assume no additional features are supported.
// //
@@ -99,7 +89,7 @@ func (g *CachedGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex,
// //
// Note: Inbound fees are not used here because this method is only used // Note: Inbound fees are not used here because this method is only used
// by a deprecated router rpc. // by a deprecated router rpc.
u := newNodeEdgeUnifier(g.sourceNode(), nodeTo, false, nil) u := newNodeEdgeUnifier(g.source, nodeTo, false, nil)
err := u.addGraphPolicies(g) err := u.addGraphPolicies(g)
if err != nil { if err != nil {

View File

@@ -200,7 +200,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
} }
session, err := newPaymentSession( session, err := newPaymentSession(
&payment, getBandwidthHints, &payment, c.graph.source.pubkey, getBandwidthHints,
func() (routingGraph, func(), error) { func() (routingGraph, func(), error) {
return c.graph, func() {}, nil return c.graph, func() {}, nil
}, },

View File

@@ -48,7 +48,7 @@ const (
// pathFinder defines the interface of a path finding algorithm. // pathFinder defines the interface of a path finding algorithm.
type pathFinder = func(g *graphParams, r *RestrictParams, type pathFinder = func(g *graphParams, r *RestrictParams,
cfg *PathFindingConfig, source, target route.Vertex, cfg *PathFindingConfig, self, source, target route.Vertex,
amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) (
[]*unifiedEdge, float64, error) []*unifiedEdge, float64, error)
@@ -521,8 +521,9 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
// path and accurately check the amount to forward at every node against the // path and accurately check the amount to forward at every node against the
// available bandwidth. // available bandwidth.
func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, self, source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*unifiedEdge, float64, error) { timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64,
error) {
// Pathfinding can be a significant portion of the total payment // Pathfinding can be a significant portion of the total payment
// latency, especially on low-powered devices. Log several metrics to // latency, especially on low-powered devices. Log several metrics to
@@ -583,8 +584,6 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// If we are routing from ourselves, check that we have enough local // If we are routing from ourselves, check that we have enough local
// balance available. // balance available.
self := g.graph.sourceNode()
if source == self { if source == self {
max, total, err := getOutgoingBalance( max, total, err := getOutgoingBalance(
self, outgoingChanMap, g.bandwidthHints, g.graph, self, outgoingChanMap, g.bandwidthHints, g.graph,

View File

@@ -3218,7 +3218,8 @@ func dbFindPath(graph *channeldb.ChannelGraph,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
graph: routingGraph, graph: routingGraph,
}, },
r, cfg, source, target, amt, timePref, finalHtlcExpiry, r, cfg, sourceNode.PubKeyBytes, source, target, amt, timePref,
finalHtlcExpiry,
) )
return route, err return route, err

View File

@@ -163,6 +163,8 @@ type PaymentSession interface {
// loop if payment attempts take long enough. An additional set of edges can // loop if payment attempts take long enough. An additional set of edges can
// also be provided to assist in reaching the payment's destination. // also be provided to assist in reaching the payment's destination.
type paymentSession struct { type paymentSession struct {
selfNode route.Vertex
additionalEdges map[route.Vertex][]AdditionalEdge additionalEdges map[route.Vertex][]AdditionalEdge
getBandwidthHints func(routingGraph) (bandwidthHints, error) getBandwidthHints func(routingGraph) (bandwidthHints, error)
@@ -192,7 +194,7 @@ type paymentSession struct {
} }
// newPaymentSession instantiates a new payment session. // newPaymentSession instantiates a new payment session.
func newPaymentSession(p *LightningPayment, func newPaymentSession(p *LightningPayment, selfNode route.Vertex,
getBandwidthHints func(routingGraph) (bandwidthHints, error), getBandwidthHints func(routingGraph) (bandwidthHints, error),
getRoutingGraph func() (routingGraph, func(), error), getRoutingGraph func() (routingGraph, func(), error),
missionControl MissionController, pathFindingConfig PathFindingConfig) ( missionControl MissionController, pathFindingConfig PathFindingConfig) (
@@ -206,6 +208,7 @@ func newPaymentSession(p *LightningPayment,
logPrefix := fmt.Sprintf("PaymentSession(%x):", p.Identifier()) logPrefix := fmt.Sprintf("PaymentSession(%x):", p.Identifier())
return &paymentSession{ return &paymentSession{
selfNode: selfNode,
additionalEdges: edges, additionalEdges: edges,
getBandwidthHints: getBandwidthHints, getBandwidthHints: getBandwidthHints,
payment: p, payment: p,
@@ -296,8 +299,6 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
p.log.Debugf("pathfinding for amt=%v", maxAmt) p.log.Debugf("pathfinding for amt=%v", maxAmt)
sourceVertex := routingGraph.sourceNode()
// Find a route for the current amount. // Find a route for the current amount.
path, _, err := p.pathFinder( path, _, err := p.pathFinder(
&graphParams{ &graphParams{
@@ -306,7 +307,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
graph: routingGraph, graph: routingGraph,
}, },
restrictions, &p.pathFindingConfig, restrictions, &p.pathFindingConfig,
sourceVertex, p.payment.Target, p.selfNode, p.selfNode, p.payment.Target,
maxAmt, p.payment.TimePref, finalHtlcExpiry, maxAmt, p.payment.TimePref, finalHtlcExpiry,
) )
@@ -384,7 +385,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
// this into a route by applying the time-lock and fee // this into a route by applying the time-lock and fee
// requirements. // requirements.
route, err := newRoute( route, err := newRoute(
sourceVertex, path, height, p.selfNode, path, height,
finalHopParams{ finalHopParams{
amt: maxAmt, amt: maxAmt,
totalAmt: p.payment.Amount, totalAmt: p.payment.Amount,

View File

@@ -73,8 +73,8 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
} }
session, err := newPaymentSession( session, err := newPaymentSession(
p, getBandwidthHints, m.getRoutingGraph, p, m.SourceNode.PubKeyBytes, getBandwidthHints,
m.MissionControl, m.PathFindingConfig, m.getRoutingGraph, m.MissionControl, m.PathFindingConfig,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -115,7 +115,7 @@ func TestUpdateAdditionalEdge(t *testing.T) {
// Create the paymentsession. // Create the paymentsession.
session, err := newPaymentSession( session, err := newPaymentSession(
payment, payment, route.Vertex{},
func(routingGraph) (bandwidthHints, error) { func(routingGraph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil return &mockBandwidthHints{}, nil
}, },
@@ -195,7 +195,7 @@ func TestRequestRoute(t *testing.T) {
} }
session, err := newPaymentSession( session, err := newPaymentSession(
payment, payment, route.Vertex{},
func(routingGraph) (bandwidthHints, error) { func(routingGraph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil return &mockBandwidthHints{}, nil
}, },
@@ -211,9 +211,9 @@ func TestRequestRoute(t *testing.T) {
// Override pathfinder with a mock. // Override pathfinder with a mock.
session.pathFinder = func(_ *graphParams, r *RestrictParams, session.pathFinder = func(_ *graphParams, r *RestrictParams,
_ *PathFindingConfig, _, _ route.Vertex, _ lnwire.MilliSatoshi, _ *PathFindingConfig, _, _, _ route.Vertex,
_ float64, _ int32) ([]*unifiedEdge, float64, _ lnwire.MilliSatoshi, _ float64, _ int32) ([]*unifiedEdge,
error) { float64, error) {
// We expect find path to receive a cltv limit excluding the // We expect find path to receive a cltv limit excluding the
// final cltv delta (including the block padding). // final cltv delta (including the block padding).

View File

@@ -2148,8 +2148,9 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
graph: r.cachedGraph, graph: r.cachedGraph,
}, },
req.Restrictions, &r.cfg.PathFindingConfig, req.Source, req.Restrictions, &r.cfg.PathFindingConfig,
req.Target, req.Amount, req.TimePreference, finalHtlcExpiry, r.selfNode.PubKeyBytes, req.Source, req.Target, req.Amount,
req.TimePreference, finalHtlcExpiry,
) )
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err