From 90d6b863a8007153f3a296fb0b9f5ff4771a04d8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 25 Jun 2024 19:27:13 -0700 Subject: [PATCH] routing+refactor: remove the need to give CachedGraph source node access In preparation for the next commit. --- routing/graph.go | 20 ++++++++------------ routing/mock_graph_test.go | 25 ------------------------- routing/pathfind_test.go | 2 +- routing/payment_session_source.go | 2 +- routing/router.go | 3 +-- rpcserver.go | 9 ++++----- 6 files changed, 15 insertions(+), 46 deletions(-) diff --git a/routing/graph.go b/routing/graph.go index 1f4b24bb5..0c4d2e1d4 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -25,9 +25,8 @@ type Graph interface { // CachedGraph is a Graph implementation that retrieves from the // database. type CachedGraph struct { - graph *channeldb.ChannelGraph - tx kvdb.RTx - source route.Vertex + graph *channeldb.ChannelGraph + tx kvdb.RTx } // A compile time assertion to make sure CachedGraph implements the Graph @@ -36,18 +35,15 @@ var _ Graph = (*CachedGraph)(nil) // NewCachedGraph instantiates a new db-connected routing graph. It implicitly // instantiates a new read transaction. -func NewCachedGraph(sourceNode *channeldb.LightningNode, - graph *channeldb.ChannelGraph) (*CachedGraph, error) { - +func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) { tx, err := graph.NewPathFindTx() if err != nil { return nil, err } return &CachedGraph{ - graph: graph, - tx: tx, - source: sourceNode.PubKeyBytes, + graph: graph, + tx: tx, }, nil } @@ -82,16 +78,16 @@ func (g *CachedGraph) FetchNodeFeatures(nodePub route.Vertex) ( // FetchAmountPairCapacity determines the maximal public capacity between two // nodes depending on the amount we try to send. -func (g *CachedGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, +func FetchAmountPairCapacity(graph Graph, source, nodeFrom, nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { // Create unified edges for all incoming connections. // // Note: Inbound fees are not used here because this method is only used // by a deprecated router rpc. - u := newNodeEdgeUnifier(g.source, nodeTo, false, nil) + u := newNodeEdgeUnifier(source, nodeTo, false, nil) - err := u.addGraphPolicies(g) + err := u.addGraphPolicies(graph) if err != nil { return 0, err } diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index 348eb3746..de0341234 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -227,31 +227,6 @@ func (m *mockGraph) FetchNodeFeatures(nodePub route.Vertex) ( return lnwire.EmptyFeatureVector(), nil } -// FetchAmountPairCapacity returns the maximal capacity between nodes in the -// graph. -// -// NOTE: Part of the Graph interface. -func (m *mockGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, - amount lnwire.MilliSatoshi) (btcutil.Amount, error) { - - var capacity btcutil.Amount - - cb := func(channel *channeldb.DirectedChannel) error { - if channel.OtherNode == nodeTo { - capacity = channel.Capacity - } - - return nil - } - - err := m.ForEachNodeChannel(nodeFrom, cb) - if err != nil { - return 0, err - } - - return capacity, nil -} - // htlcResult describes the resolution of an htlc. If failure is nil, the htlc // was settled. type htlcResult struct { diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index a35c9e2f7..0eea8edc0 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3201,7 +3201,7 @@ func dbFindPath(graph *channeldb.ChannelGraph, return nil, err } - routingGraph, err := NewCachedGraph(sourceNode, graph) + routingGraph, err := NewCachedGraph(graph) if err != nil { return nil, err } diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 51bfc9781..cc90c465b 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -47,7 +47,7 @@ type SessionSource struct { // getRoutingGraph returns a routing graph and a clean-up function for // pathfinding. func (m *SessionSource) getRoutingGraph() (Graph, func(), error) { - routingTx, err := NewCachedGraph(m.SourceNode, m.Graph) + routingTx, err := NewCachedGraph(m.Graph) if err != nil { return nil, nil, err } diff --git a/routing/router.go b/routing/router.go index 9af047f4e..9af37a543 100644 --- a/routing/router.go +++ b/routing/router.go @@ -517,8 +517,7 @@ func New(cfg Config) (*ChannelRouter, error) { r := &ChannelRouter{ cfg: &cfg, cachedGraph: &CachedGraph{ - graph: cfg.Graph, - source: selfNode.PubKeyBytes, + graph: cfg.Graph, }, networkUpdates: make(chan *routingMsg), topologyClients: &lnutils.SyncMap[uint64, *topologyClient]{}, diff --git a/rpcserver.go b/rpcserver.go index a306d4fd2..59e3196fe 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -691,9 +691,7 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, FetchAmountPairCapacity: func(nodeFrom, nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { - routingGraph, err := routing.NewCachedGraph( - selfNode, graph, - ) + routingGraph, err := routing.NewCachedGraph(graph) if err != nil { return 0, err } @@ -706,8 +704,9 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, } }() - return routingGraph.FetchAmountPairCapacity( - nodeFrom, nodeTo, amount, + return routing.FetchAmountPairCapacity( + routingGraph, selfNode.PubKeyBytes, nodeFrom, + nodeTo, amount, ) }, FetchChannelEndpoints: func(chanID uint64) (route.Vertex,