From 0a2ccfc52b68725b10e97502f6ea480e3a5f5d14 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Thu, 21 Oct 2021 13:55:22 +0200 Subject: [PATCH] multi: use single read transaction for path finding This commit partially reverts bf27d05a. To avoid creating multiple database transactions during a single path finding operation, we create an explicit transaction when the cached graph is instantiated. We cache the source node to avoid needing to look that up for every path finding session. The database transaction will be nil in case of the in-memory graph. --- channeldb/graph.go | 14 ++++++++-- routing/graph.go | 23 ++++++++++++++--- routing/integrated_routing_context_test.go | 8 ++++-- routing/pathfind_test.go | 13 +++++++++- routing/payment_session.go | 25 ++++++++++++------ routing/payment_session_source.go | 30 +++++++++++++++++----- routing/payment_session_test.go | 12 ++++++--- routing/router_test.go | 6 ++--- server.go | 7 ++--- 9 files changed, 105 insertions(+), 33 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index daf6ae6d2..d1bb85b17 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -308,6 +308,16 @@ func initChannelGraph(db kvdb.Backend) error { return nil } +// NewPathFindTx returns a new read transaction that can be used for a single +// path finding session. Will return nil if the graph cache is enabled. +func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) { + if c.graphCache != nil { + return nil, nil + } + + return c.db.BeginReadTx() +} + // ForEachChannel iterates through all the channel edges stored within the // graph and invokes the passed callback for each edge. The callback takes two // edges as since this is a directed graph, both the in/out edges are visited. @@ -376,7 +386,7 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. -func (c *ChannelGraph) ForEachNodeChannel(node route.Vertex, +func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, node route.Vertex, cb func(channel *DirectedChannel) error) error { if c.graphCache != nil { @@ -414,7 +424,7 @@ func (c *ChannelGraph) ForEachNodeChannel(node route.Vertex, return cb(directedChannel) } - return nodeTraversal(nil, node[:], c.db, dbCallback) + return nodeTraversal(tx, node[:], c.db, dbCallback) } // FetchNodeFeatures returns the features of a given node. If no features are diff --git a/routing/graph.go b/routing/graph.go index 7e0ba65b2..54ddd46b1 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -2,6 +2,7 @@ package routing import ( "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -25,6 +26,7 @@ type routingGraph interface { // database. type CachedGraph struct { graph *channeldb.ChannelGraph + tx kvdb.RTx source route.Vertex } @@ -32,27 +34,40 @@ type CachedGraph struct { // interface. var _ routingGraph = (*CachedGraph)(nil) -// NewCachedGraph instantiates a new db-connected routing graph. It implictly +// NewCachedGraph instantiates a new db-connected routing graph. It implicitly // instantiates a new read transaction. -func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) { - sourceNode, err := graph.SourceNode() +func NewCachedGraph(sourceNode *channeldb.LightningNode, + graph *channeldb.ChannelGraph) (*CachedGraph, error) { + + tx, err := graph.NewPathFindTx() if err != nil { return nil, err } return &CachedGraph{ graph: graph, + tx: tx, source: sourceNode.PubKeyBytes, }, nil } +// close attempts to close the underlying db transaction. This is a no-op in +// case the underlying graph uses an in-memory cache. +func (g *CachedGraph) close() error { + if g.tx == nil { + return nil + } + + return g.tx.Rollback() +} + // forEachNodeChannel calls the callback for every channel of the given node. // // NOTE: Part of the routingGraph interface. func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex, cb func(channel *channeldb.DirectedChannel) error) error { - return g.graph.ForEachNodeChannel(nodePub, cb) + return g.graph.ForEachNodeChannel(g.tx, nodePub, cb) } // sourceNode returns the source node of the graph. diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 79fa37080..bbca1975a 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -145,7 +145,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, c.t.Fatal(err) } - getBandwidthHints := func() (bandwidthHints, error) { + getBandwidthHints := func(_ routingGraph) (bandwidthHints, error) { // Create bandwidth hints based on local channel balances. bandwidthHints := map[uint64]lnwire.MilliSatoshi{} for _, ch := range c.graph.nodes[c.source.pubkey].channels { @@ -179,7 +179,11 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, } session, err := newPaymentSession( - &payment, getBandwidthHints, c.graph, mc, c.pathFindingCfg, + &payment, getBandwidthHints, + func() (routingGraph, func(), error) { + return c.graph, func() {}, nil + }, + mc, c.pathFindingCfg, ) if err != nil { c.t.Fatal(err) diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 768729f33..923bdd5d6 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3060,11 +3060,22 @@ func dbFindPath(graph *channeldb.ChannelGraph, source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { - routingGraph, err := NewCachedGraph(graph) + sourceNode, err := graph.SourceNode() if err != nil { return nil, err } + routingGraph, err := NewCachedGraph(sourceNode, graph) + if err != nil { + return nil, err + } + + defer func() { + if err := routingGraph.close(); err != nil { + log.Errorf("Error closing db tx: %v", err) + } + }() + return findPath( &graphParams{ additionalEdges: additionalEdges, diff --git a/routing/payment_session.go b/routing/payment_session.go index 8895d28fe..4d593113f 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -164,7 +164,7 @@ type PaymentSession interface { type paymentSession struct { additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy - getBandwidthHints func() (bandwidthHints, error) + getBandwidthHints func(routingGraph) (bandwidthHints, error) payment *LightningPayment @@ -172,7 +172,7 @@ type paymentSession struct { pathFinder pathFinder - routingGraph routingGraph + getRoutingGraph func() (routingGraph, func(), error) // pathFindingConfig defines global parameters that control the // trade-off in path finding between fees and probabiity. @@ -192,8 +192,8 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. func newPaymentSession(p *LightningPayment, - getBandwidthHints func() (bandwidthHints, error), - routingGraph routingGraph, + getBandwidthHints func(routingGraph) (bandwidthHints, error), + getRoutingGraph func() (routingGraph, func(), error), missionControl MissionController, pathFindingConfig PathFindingConfig) ( *paymentSession, error) { @@ -209,7 +209,7 @@ func newPaymentSession(p *LightningPayment, getBandwidthHints: getBandwidthHints, payment: p, pathFinder: findPath, - routingGraph: routingGraph, + getRoutingGraph: getRoutingGraph, pathFindingConfig: pathFindingConfig, missionControl: missionControl, minShardAmt: DefaultShardMinAmt, @@ -274,33 +274,42 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, } for { + // Get a routing graph. + routingGraph, cleanup, err := p.getRoutingGraph() + if err != nil { + return nil, err + } + // We'll also obtain a set of bandwidthHints from the lower // layer for each of our outbound channels. This will allow the // path finding to skip any links that aren't active or just // don't have enough bandwidth to carry the payment. New // bandwidth hints are queried for every new path finding // attempt, because concurrent payments may change balances. - bandwidthHints, err := p.getBandwidthHints() + bandwidthHints, err := p.getBandwidthHints(routingGraph) if err != nil { return nil, err } p.log.Debugf("pathfinding for amt=%v", maxAmt) - sourceVertex := p.routingGraph.sourceNode() + sourceVertex := routingGraph.sourceNode() // Find a route for the current amount. path, err := p.pathFinder( &graphParams{ additionalEdges: p.additionalEdges, bandwidthHints: bandwidthHints, - graph: p.routingGraph, + graph: routingGraph, }, restrictions, &p.pathFindingConfig, sourceVertex, p.payment.Target, maxAmt, finalHtlcExpiry, ) + // Close routing graph. + cleanup() + switch { case err == errNoPathFound: // Don't split if this is a legacy payment without mpp diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 6889d0a17..930d69e00 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -17,7 +17,10 @@ var _ PaymentSessionSource = (*SessionSource)(nil) type SessionSource struct { // Graph is the channel graph that will be used to gather metrics from // and also to carry out path finding queries. - Graph routingGraph + Graph *channeldb.ChannelGraph + + // SourceNode is the graph's source node. + SourceNode *channeldb.LightningNode // GetLink is a method that allows querying the lower link layer // to determine the up to date available bandwidth at a prospective link @@ -40,6 +43,21 @@ type SessionSource struct { PathFindingConfig PathFindingConfig } +// getRoutingGraph returns a routing graph and a clean-up function for +// pathfinding. +func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { + routingTx, err := NewCachedGraph(m.SourceNode, m.Graph) + if err != nil { + return nil, nil, err + } + return routingTx, func() { + err := routingTx.close() + if err != nil { + log.Errorf("Error closing db tx: %v", err) + } + }, nil +} + // NewPaymentSession creates a new payment session backed by the latest prune // view from Mission Control. An optional set of routing hints can be provided // in order to populate additional edges to explore when finding a path to the @@ -47,14 +65,14 @@ type SessionSource struct { func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( PaymentSession, error) { - sourceNode := m.Graph.sourceNode() - - getBandwidthHints := func() (bandwidthHints, error) { - return newBandwidthManager(m.Graph, sourceNode, m.GetLink) + getBandwidthHints := func(graph routingGraph) (bandwidthHints, error) { + return newBandwidthManager( + graph, m.SourceNode.PubKeyBytes, m.GetLink, + ) } session, err := newPaymentSession( - p, getBandwidthHints, m.Graph, + p, getBandwidthHints, m.getRoutingGraph, m.MissionControl, m.PathFindingConfig, ) if err != nil { diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index f177da730..11823d4c7 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -116,10 +116,12 @@ func TestUpdateAdditionalEdge(t *testing.T) { // Create the paymentsession. session, err := newPaymentSession( payment, - func() (bandwidthHints, error) { + func(routingGraph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, - &sessionGraph{}, + func() (routingGraph, func(), error) { + return &sessionGraph{}, func() {}, nil + }, &MissionControl{}, PathFindingConfig{}, ) @@ -194,10 +196,12 @@ func TestRequestRoute(t *testing.T) { session, err := newPaymentSession( payment, - func() (bandwidthHints, error) { + func(routingGraph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, - &sessionGraph{}, + func() (routingGraph, func(), error) { + return &sessionGraph{}, func() {}, nil + }, &MissionControl{}, PathFindingConfig{}, ) diff --git a/routing/router_test.go b/routing/router_test.go index 77681b2a6..f77c3059c 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -129,11 +129,11 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, ) require.NoError(t, err, "failed to create missioncontrol") - cachedGraph, err := NewCachedGraph(graphInstance.graph) + sourceNode, err := graphInstance.graph.SourceNode() require.NoError(t, err) - sessionSource := &SessionSource{ - Graph: cachedGraph, + Graph: graphInstance.graph, + SourceNode: sourceNode, GetLink: graphInstance.getLink, PathFindingConfig: pathFindingConfig, MissionControl: mc, diff --git a/server.go b/server.go index e2eb3039f..354b90dba 100644 --- a/server.go +++ b/server.go @@ -860,12 +860,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MinProbability: routingConfig.MinRouteProbability, } - cachedGraph, err := routing.NewCachedGraph(chanGraph) + sourceNode, err := chanGraph.SourceNode() if err != nil { - return nil, err + return nil, fmt.Errorf("error getting source node: %v", err) } paymentSessionSource := &routing.SessionSource{ - Graph: cachedGraph, + Graph: chanGraph, + SourceNode: sourceNode, MissionControl: s.missionControl, GetLink: s.htlcSwitch.GetLinkByShortID, PathFindingConfig: pathFindingConfig,