routing: remove sourceNode from routingGraph interface

In this commit, we further reduce the routingGraph interface and this
time we make it more node-agnostic so that it can be backed by any graph
and not one with a concept of "sourceNode".
This commit is contained in:
Elle Mouton
2024-06-14 18:47:15 -04:00
parent 5c18b5a042
commit 5a903c270f
8 changed files with 24 additions and 32 deletions

View File

@@ -18,9 +18,6 @@ type routingGraph interface {
forEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error
// sourceNode returns the source node of the graph.
sourceNode() route.Vertex
// fetchNodeFeatures returns the features of the given node.
fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
}
@@ -73,13 +70,6 @@ func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex,
return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb)
}
// sourceNode returns the source node of the graph.
//
// NOTE: Part of the routingGraph interface.
func (g *CachedGraph) sourceNode() route.Vertex {
return g.source
}
// fetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
@@ -99,7 +89,7 @@ func (g *CachedGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex,
//
// Note: Inbound fees are not used here because this method is only used
// by a deprecated router rpc.
u := newNodeEdgeUnifier(g.sourceNode(), nodeTo, false, nil)
u := newNodeEdgeUnifier(g.source, nodeTo, false, nil)
err := u.addGraphPolicies(g)
if err != nil {

View File

@@ -200,7 +200,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
}
session, err := newPaymentSession(
&payment, getBandwidthHints,
&payment, c.graph.source.pubkey, getBandwidthHints,
func() (routingGraph, func(), error) {
return c.graph, func() {}, nil
},

View File

@@ -48,7 +48,7 @@ const (
// pathFinder defines the interface of a path finding algorithm.
type pathFinder = func(g *graphParams, r *RestrictParams,
cfg *PathFindingConfig, source, target route.Vertex,
cfg *PathFindingConfig, self, source, target route.Vertex,
amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) (
[]*unifiedEdge, float64, error)
@@ -521,8 +521,9 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
// path and accurately check the amount to forward at every node against the
// available bandwidth.
func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64,
finalHtlcExpiry int32) ([]*unifiedEdge, float64, error) {
self, source, target route.Vertex, amt lnwire.MilliSatoshi,
timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64,
error) {
// Pathfinding can be a significant portion of the total payment
// latency, especially on low-powered devices. Log several metrics to
@@ -583,8 +584,6 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// If we are routing from ourselves, check that we have enough local
// balance available.
self := g.graph.sourceNode()
if source == self {
max, total, err := getOutgoingBalance(
self, outgoingChanMap, g.bandwidthHints, g.graph,

View File

@@ -3218,7 +3218,8 @@ func dbFindPath(graph *channeldb.ChannelGraph,
bandwidthHints: bandwidthHints,
graph: routingGraph,
},
r, cfg, source, target, amt, timePref, finalHtlcExpiry,
r, cfg, sourceNode.PubKeyBytes, source, target, amt, timePref,
finalHtlcExpiry,
)
return route, err

View File

@@ -163,6 +163,8 @@ type PaymentSession interface {
// loop if payment attempts take long enough. An additional set of edges can
// also be provided to assist in reaching the payment's destination.
type paymentSession struct {
selfNode route.Vertex
additionalEdges map[route.Vertex][]AdditionalEdge
getBandwidthHints func(routingGraph) (bandwidthHints, error)
@@ -192,7 +194,7 @@ type paymentSession struct {
}
// newPaymentSession instantiates a new payment session.
func newPaymentSession(p *LightningPayment,
func newPaymentSession(p *LightningPayment, selfNode route.Vertex,
getBandwidthHints func(routingGraph) (bandwidthHints, error),
getRoutingGraph func() (routingGraph, func(), error),
missionControl MissionController, pathFindingConfig PathFindingConfig) (
@@ -206,6 +208,7 @@ func newPaymentSession(p *LightningPayment,
logPrefix := fmt.Sprintf("PaymentSession(%x):", p.Identifier())
return &paymentSession{
selfNode: selfNode,
additionalEdges: edges,
getBandwidthHints: getBandwidthHints,
payment: p,
@@ -296,8 +299,6 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
p.log.Debugf("pathfinding for amt=%v", maxAmt)
sourceVertex := routingGraph.sourceNode()
// Find a route for the current amount.
path, _, err := p.pathFinder(
&graphParams{
@@ -306,7 +307,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
graph: routingGraph,
},
restrictions, &p.pathFindingConfig,
sourceVertex, p.payment.Target,
p.selfNode, p.selfNode, p.payment.Target,
maxAmt, p.payment.TimePref, finalHtlcExpiry,
)
@@ -384,7 +385,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
// this into a route by applying the time-lock and fee
// requirements.
route, err := newRoute(
sourceVertex, path, height,
p.selfNode, path, height,
finalHopParams{
amt: maxAmt,
totalAmt: p.payment.Amount,

View File

@@ -73,8 +73,8 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
}
session, err := newPaymentSession(
p, getBandwidthHints, m.getRoutingGraph,
m.MissionControl, m.PathFindingConfig,
p, m.SourceNode.PubKeyBytes, getBandwidthHints,
m.getRoutingGraph, m.MissionControl, m.PathFindingConfig,
)
if err != nil {
return nil, err

View File

@@ -115,7 +115,7 @@ func TestUpdateAdditionalEdge(t *testing.T) {
// Create the paymentsession.
session, err := newPaymentSession(
payment,
payment, route.Vertex{},
func(routingGraph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil
},
@@ -195,7 +195,7 @@ func TestRequestRoute(t *testing.T) {
}
session, err := newPaymentSession(
payment,
payment, route.Vertex{},
func(routingGraph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil
},
@@ -211,9 +211,9 @@ func TestRequestRoute(t *testing.T) {
// Override pathfinder with a mock.
session.pathFinder = func(_ *graphParams, r *RestrictParams,
_ *PathFindingConfig, _, _ route.Vertex, _ lnwire.MilliSatoshi,
_ float64, _ int32) ([]*unifiedEdge, float64,
error) {
_ *PathFindingConfig, _, _, _ route.Vertex,
_ lnwire.MilliSatoshi, _ float64, _ int32) ([]*unifiedEdge,
float64, error) {
// We expect find path to receive a cltv limit excluding the
// final cltv delta (including the block padding).

View File

@@ -2148,8 +2148,9 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64,
bandwidthHints: bandwidthHints,
graph: r.cachedGraph,
},
req.Restrictions, &r.cfg.PathFindingConfig, req.Source,
req.Target, req.Amount, req.TimePreference, finalHtlcExpiry,
req.Restrictions, &r.cfg.PathFindingConfig,
r.selfNode.PubKeyBytes, req.Source, req.Target, req.Amount,
req.TimePreference, finalHtlcExpiry,
)
if err != nil {
return nil, 0, err