diff --git a/routing/graph.go b/routing/graph.go index 578f480ab..7e0ba65b2 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -9,7 +9,8 @@ import ( // routingGraph is an abstract interface that provides information about nodes // and edges to pathfinding. type routingGraph interface { - // forEachNodeChannel calls the callback for every channel of the given node. + // forEachNodeChannel calls the callback for every channel of the given + // node. forEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error @@ -20,22 +21,26 @@ type routingGraph interface { fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) } -// dbRoutingTx is a routingGraph implementation that retrieves from the +// CachedGraph is a routingGraph implementation that retrieves from the // database. -type dbRoutingTx struct { +type CachedGraph struct { graph *channeldb.ChannelGraph source route.Vertex } -// newDbRoutingTx instantiates a new db-connected routing graph. It implictly +// A compile time assertion to make sure CachedGraph implements the routingGraph +// interface. +var _ routingGraph = (*CachedGraph)(nil) + +// NewCachedGraph instantiates a new db-connected routing graph. It implictly // instantiates a new read transaction. -func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) { +func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) { sourceNode, err := graph.SourceNode() if err != nil { return nil, err } - return &dbRoutingTx{ + return &CachedGraph{ graph: graph, source: sourceNode.PubKeyBytes, }, nil @@ -44,7 +49,7 @@ func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) { // forEachNodeChannel calls the callback for every channel of the given node. // // NOTE: Part of the routingGraph interface. -func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex, +func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error { return g.graph.ForEachNodeChannel(nodePub, cb) @@ -53,7 +58,7 @@ func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex, // sourceNode returns the source node of the graph. // // NOTE: Part of the routingGraph interface. -func (g *dbRoutingTx) sourceNode() route.Vertex { +func (g *CachedGraph) sourceNode() route.Vertex { return g.source } @@ -61,7 +66,7 @@ func (g *dbRoutingTx) sourceNode() route.Vertex { // unknown, assume no additional features are supported. // // NOTE: Part of the routingGraph interface. -func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) ( +func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { return g.graph.FetchNodeFeatures(nodePub) diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 114b2272e..d13b1c432 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -162,11 +162,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, } session, err := newPaymentSession( - &payment, getBandwidthHints, - func() (routingGraph, func(), error) { - return c.graph, func() {}, nil - }, - mc, c.pathFindingCfg, + &payment, getBandwidthHints, c.graph, mc, c.pathFindingCfg, ) if err != nil { c.t.Fatal(err) diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index d29c096fd..6d0156666 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -182,10 +182,10 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, // Call the per channel callback. err := cb( &channeldb.DirectedChannel{ - ChannelID: channel.id, - IsNode1: nodePub == node1, - OtherNode: peer, - Capacity: channel.capacity, + ChannelID: channel.id, + IsNode1: nodePub == node1, + OtherNode: peer, + Capacity: channel.capacity, OutPolicySet: true, InPolicy: &channeldb.CachedEdgePolicy{ ChannelID: channel.id, @@ -193,7 +193,7 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, return nodePub }, ToNodeFeatures: lnwire.EmptyFeatureVector(), - FeeBaseMSat: peerNode.baseFee, + FeeBaseMSat: peerNode.baseFee, }, }, ) diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 7c7c7586b..d438de824 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3021,7 +3021,7 @@ func dbFindPath(graph *channeldb.ChannelGraph, source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { - routingTx, err := newDbRoutingTx(graph) + routingGraph, err := NewCachedGraph(graph) if err != nil { return nil, err } @@ -3030,7 +3030,7 @@ func dbFindPath(graph *channeldb.ChannelGraph, &graphParams{ additionalEdges: additionalEdges, bandwidthHints: bandwidthHints, - graph: routingTx, + graph: routingGraph, }, r, cfg, source, target, amt, finalHtlcExpiry, ) diff --git a/routing/payment_session.go b/routing/payment_session.go index d3024d3ff..bbf9b6f96 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -172,7 +172,7 @@ type paymentSession struct { pathFinder pathFinder - getRoutingGraph func() (routingGraph, func(), error) + routingGraph routingGraph // pathFindingConfig defines global parameters that control the // trade-off in path finding between fees and probabiity. @@ -193,7 +193,7 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. func newPaymentSession(p *LightningPayment, getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error), - getRoutingGraph func() (routingGraph, func(), error), + routingGraph routingGraph, missionControl MissionController, pathFindingConfig PathFindingConfig) ( *paymentSession, error) { @@ -209,7 +209,7 @@ func newPaymentSession(p *LightningPayment, getBandwidthHints: getBandwidthHints, payment: p, pathFinder: findPath, - getRoutingGraph: getRoutingGraph, + routingGraph: routingGraph, pathFindingConfig: pathFindingConfig, missionControl: missionControl, minShardAmt: DefaultShardMinAmt, @@ -287,29 +287,20 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, p.log.Debugf("pathfinding for amt=%v", maxAmt) - // Get a routing graph. - routingGraph, cleanup, err := p.getRoutingGraph() - if err != nil { - return nil, err - } - - sourceVertex := routingGraph.sourceNode() + sourceVertex := p.routingGraph.sourceNode() // Find a route for the current amount. path, err := p.pathFinder( &graphParams{ additionalEdges: p.additionalEdges, bandwidthHints: bandwidthHints, - graph: routingGraph, + graph: p.routingGraph, }, restrictions, &p.pathFindingConfig, sourceVertex, p.payment.Target, maxAmt, finalHtlcExpiry, ) - // Close routing graph. - cleanup() - switch { case err == errNoPathFound: // Don't split if this is a legacy payment without mpp diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index fdfccd5f1..d688f9814 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -17,7 +17,7 @@ var _ PaymentSessionSource = (*SessionSource)(nil) type SessionSource struct { // Graph is the channel graph that will be used to gather metrics from // and also to carry out path finding queries. - Graph *channeldb.ChannelGraph + Graph routingGraph // QueryBandwidth is a method that allows querying the lower link layer // to determine the up to date available bandwidth at a prospective link @@ -40,16 +40,6 @@ type SessionSource struct { PathFindingConfig PathFindingConfig } -// getRoutingGraph returns a routing graph and a clean-up function for -// pathfinding. -func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { - routingTx, err := newDbRoutingTx(m.Graph) - if err != nil { - return nil, nil, err - } - return routingTx, func() {}, nil -} - // NewPaymentSession creates a new payment session backed by the latest prune // view from Mission Control. An optional set of routing hints can be provided // in order to populate additional edges to explore when finding a path to the @@ -57,21 +47,16 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( PaymentSession, error) { - sourceNode, err := m.Graph.SourceNode() - if err != nil { - return nil, err - } - getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, error) { return generateBandwidthHints( - sourceNode.PubKeyBytes, m.Graph, m.QueryBandwidth, + m.Graph.sourceNode(), m.Graph, m.QueryBandwidth, ) } session, err := newPaymentSession( - p, getBandwidthHints, m.getRoutingGraph, + p, getBandwidthHints, m.Graph, m.MissionControl, m.PathFindingConfig, ) if err != nil { diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index bcfc3b0e9..dae331f84 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -121,9 +121,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { return nil, nil }, - func() (routingGraph, func(), error) { - return &sessionGraph{}, func() {}, nil - }, + &sessionGraph{}, &MissionControl{}, PathFindingConfig{}, ) @@ -203,9 +201,7 @@ func TestRequestRoute(t *testing.T) { return nil, nil }, - func() (routingGraph, func(), error) { - return &sessionGraph{}, func() {}, nil - }, + &sessionGraph{}, &MissionControl{}, PathFindingConfig{}, ) diff --git a/routing/router.go b/routing/router.go index 1de113056..dd8a375a2 100644 --- a/routing/router.go +++ b/routing/router.go @@ -406,6 +406,10 @@ type ChannelRouter struct { // when doing any path finding. selfNode *channeldb.LightningNode + // cachedGraph is an instance of routingGraph that caches the source node as + // well as the channel graph itself in memory. + cachedGraph routingGraph + // newBlocks is a channel in which new blocks connected to the end of // the main chain are sent over, and blocks updated after a call to // UpdateFilter. @@ -460,14 +464,17 @@ var _ ChannelGraphSource = (*ChannelRouter)(nil) // channel graph is a subset of the UTXO set) set, then the router will proceed // to fully sync to the latest state of the UTXO set. func New(cfg Config) (*ChannelRouter, error) { - selfNode, err := cfg.Graph.SourceNode() if err != nil { return nil, err } r := &ChannelRouter{ - cfg: &cfg, + cfg: &cfg, + cachedGraph: &CachedGraph{ + graph: cfg.Graph, + source: selfNode.PubKeyBytes, + }, networkUpdates: make(chan *routingMsg), topologyClients: make(map[uint64]*topologyClient), ntfnClientUpdates: make(chan *topologyClientUpdate), @@ -1735,7 +1742,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. bandwidthHints, err := generateBandwidthHints( - r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth, + r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth, ) if err != nil { return nil, err @@ -1752,16 +1759,11 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // execute our path finding algorithm. finalHtlcExpiry := currentHeight + int32(finalExpiry) - routingTx, err := newDbRoutingTx(r.cfg.Graph) - if err != nil { - return nil, err - } - path, err := findPath( &graphParams{ additionalEdges: routeHints, bandwidthHints: bandwidthHints, - graph: routingTx, + graph: r.cachedGraph, }, restrictions, &r.cfg.PathFindingConfig, @@ -2657,14 +2659,14 @@ func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error { // these hints allows us to reduce the number of extraneous attempts as we can // skip channels that are inactive, or just don't have enough bandwidth to // carry the payment. -func generateBandwidthHints(sourceNode route.Vertex, graph *channeldb.ChannelGraph, +func generateBandwidthHints(sourceNode route.Vertex, graph routingGraph, queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) ( map[uint64]lnwire.MilliSatoshi, error) { // First, we'll collect the set of outbound edges from the target // source node. var localChans []*channeldb.DirectedChannel - err := graph.ForEachNodeChannel( + err := graph.forEachNodeChannel( sourceNode, func(channel *channeldb.DirectedChannel) error { localChans = append(localChans, channel) return nil @@ -2722,7 +2724,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. bandwidthHints, err := generateBandwidthHints( - r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth, + r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth, ) if err != nil { return nil, err @@ -2752,12 +2754,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, runningAmt = *amt } - // Open a transaction to execute the graph queries in. - routingTx, err := newDbRoutingTx(r.cfg.Graph) - if err != nil { - return nil, err - } - // Traverse hops backwards to accumulate fees in the running amounts. source := r.selfNode.PubKeyBytes for i := len(hops) - 1; i >= 0; i-- { @@ -2776,7 +2772,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // known in the graph. u := newUnifiedPolicies(source, toNode, outgoingChans) - err := u.addGraphPolicies(routingTx) + err := u.addGraphPolicies(r.cachedGraph) if err != nil { return nil, err } diff --git a/routing/router_test.go b/routing/router_test.go index ed6bfdc6a..4b5dd505f 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -129,8 +129,11 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, ) require.NoError(t, err, "failed to create missioncontrol") + cachedGraph, err := NewCachedGraph(graphInstance.graph) + require.NoError(t, err) + sessionSource := &SessionSource{ - Graph: graphInstance.graph, + Graph: cachedGraph, QueryBandwidth: func( c *channeldb.DirectedChannel) lnwire.MilliSatoshi { diff --git a/server.go b/server.go index 55e49d9f2..5531a6b33 100644 --- a/server.go +++ b/server.go @@ -776,8 +776,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MinProbability: routingConfig.MinRouteProbability, } + cachedGraph, err := routing.NewCachedGraph(chanGraph) + if err != nil { + return nil, err + } paymentSessionSource := &routing.SessionSource{ - Graph: chanGraph, + Graph: cachedGraph, MissionControl: s.missionControl, QueryBandwidth: queryBandwidth, PathFindingConfig: pathFindingConfig,