multi: use single read transaction for path finding

This commit partially reverts bf27d05a.
To avoid creating multiple database transactions during a single path
finding operation, we create an explicit transaction when the cached
graph is instantiated.
We cache the source node to avoid needing to look that up for every path
finding session.
The database transaction will be nil in case of the in-memory graph.
This commit is contained in:
Oliver Gugger
2021-10-21 13:55:22 +02:00
parent 1fef2970cf
commit 0a2ccfc52b
9 changed files with 105 additions and 33 deletions

View File

@@ -308,6 +308,16 @@ func initChannelGraph(db kvdb.Backend) error {
return nil
}
// NewPathFindTx returns a new read transaction that can be used for a single
// path finding session. Will return nil if the graph cache is enabled.
func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) {
if c.graphCache != nil {
return nil, nil
}
return c.db.BeginReadTx()
}
// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges as since this is a directed graph, both the in/out edges are visited.
@@ -376,7 +386,7 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo,
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
func (c *ChannelGraph) ForEachNodeChannel(node route.Vertex,
func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, node route.Vertex,
cb func(channel *DirectedChannel) error) error {
if c.graphCache != nil {
@@ -414,7 +424,7 @@ func (c *ChannelGraph) ForEachNodeChannel(node route.Vertex,
return cb(directedChannel)
}
return nodeTraversal(nil, node[:], c.db, dbCallback)
return nodeTraversal(tx, node[:], c.db, dbCallback)
}
// FetchNodeFeatures returns the features of a given node. If no features are

View File

@@ -2,6 +2,7 @@ package routing
import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
@@ -25,6 +26,7 @@ type routingGraph interface {
// database.
type CachedGraph struct {
graph *channeldb.ChannelGraph
tx kvdb.RTx
source route.Vertex
}
@@ -32,27 +34,40 @@ type CachedGraph struct {
// interface.
var _ routingGraph = (*CachedGraph)(nil)
// NewCachedGraph instantiates a new db-connected routing graph. It implictly
// NewCachedGraph instantiates a new db-connected routing graph. It implicitly
// instantiates a new read transaction.
func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) {
sourceNode, err := graph.SourceNode()
func NewCachedGraph(sourceNode *channeldb.LightningNode,
graph *channeldb.ChannelGraph) (*CachedGraph, error) {
tx, err := graph.NewPathFindTx()
if err != nil {
return nil, err
}
return &CachedGraph{
graph: graph,
tx: tx,
source: sourceNode.PubKeyBytes,
}, nil
}
// close attempts to close the underlying db transaction. This is a no-op in
// case the underlying graph uses an in-memory cache.
func (g *CachedGraph) close() error {
if g.tx == nil {
return nil
}
return g.tx.Rollback()
}
// forEachNodeChannel calls the callback for every channel of the given node.
//
// NOTE: Part of the routingGraph interface.
func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error {
return g.graph.ForEachNodeChannel(nodePub, cb)
return g.graph.ForEachNodeChannel(g.tx, nodePub, cb)
}
// sourceNode returns the source node of the graph.

View File

@@ -145,7 +145,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
c.t.Fatal(err)
}
getBandwidthHints := func() (bandwidthHints, error) {
getBandwidthHints := func(_ routingGraph) (bandwidthHints, error) {
// Create bandwidth hints based on local channel balances.
bandwidthHints := map[uint64]lnwire.MilliSatoshi{}
for _, ch := range c.graph.nodes[c.source.pubkey].channels {
@@ -179,7 +179,11 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
}
session, err := newPaymentSession(
&payment, getBandwidthHints, c.graph, mc, c.pathFindingCfg,
&payment, getBandwidthHints,
func() (routingGraph, func(), error) {
return c.graph, func() {}, nil
},
mc, c.pathFindingCfg,
)
if err != nil {
c.t.Fatal(err)

View File

@@ -3060,11 +3060,22 @@ func dbFindPath(graph *channeldb.ChannelGraph,
source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) {
routingGraph, err := NewCachedGraph(graph)
sourceNode, err := graph.SourceNode()
if err != nil {
return nil, err
}
routingGraph, err := NewCachedGraph(sourceNode, graph)
if err != nil {
return nil, err
}
defer func() {
if err := routingGraph.close(); err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
return findPath(
&graphParams{
additionalEdges: additionalEdges,

View File

@@ -164,7 +164,7 @@ type PaymentSession interface {
type paymentSession struct {
additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy
getBandwidthHints func() (bandwidthHints, error)
getBandwidthHints func(routingGraph) (bandwidthHints, error)
payment *LightningPayment
@@ -172,7 +172,7 @@ type paymentSession struct {
pathFinder pathFinder
routingGraph routingGraph
getRoutingGraph func() (routingGraph, func(), error)
// pathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probabiity.
@@ -192,8 +192,8 @@ type paymentSession struct {
// newPaymentSession instantiates a new payment session.
func newPaymentSession(p *LightningPayment,
getBandwidthHints func() (bandwidthHints, error),
routingGraph routingGraph,
getBandwidthHints func(routingGraph) (bandwidthHints, error),
getRoutingGraph func() (routingGraph, func(), error),
missionControl MissionController, pathFindingConfig PathFindingConfig) (
*paymentSession, error) {
@@ -209,7 +209,7 @@ func newPaymentSession(p *LightningPayment,
getBandwidthHints: getBandwidthHints,
payment: p,
pathFinder: findPath,
routingGraph: routingGraph,
getRoutingGraph: getRoutingGraph,
pathFindingConfig: pathFindingConfig,
missionControl: missionControl,
minShardAmt: DefaultShardMinAmt,
@@ -274,33 +274,42 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
}
for {
// Get a routing graph.
routingGraph, cleanup, err := p.getRoutingGraph()
if err != nil {
return nil, err
}
// We'll also obtain a set of bandwidthHints from the lower
// layer for each of our outbound channels. This will allow the
// path finding to skip any links that aren't active or just
// don't have enough bandwidth to carry the payment. New
// bandwidth hints are queried for every new path finding
// attempt, because concurrent payments may change balances.
bandwidthHints, err := p.getBandwidthHints()
bandwidthHints, err := p.getBandwidthHints(routingGraph)
if err != nil {
return nil, err
}
p.log.Debugf("pathfinding for amt=%v", maxAmt)
sourceVertex := p.routingGraph.sourceNode()
sourceVertex := routingGraph.sourceNode()
// Find a route for the current amount.
path, err := p.pathFinder(
&graphParams{
additionalEdges: p.additionalEdges,
bandwidthHints: bandwidthHints,
graph: p.routingGraph,
graph: routingGraph,
},
restrictions, &p.pathFindingConfig,
sourceVertex, p.payment.Target,
maxAmt, finalHtlcExpiry,
)
// Close routing graph.
cleanup()
switch {
case err == errNoPathFound:
// Don't split if this is a legacy payment without mpp

View File

@@ -17,7 +17,10 @@ var _ PaymentSessionSource = (*SessionSource)(nil)
type SessionSource struct {
// Graph is the channel graph that will be used to gather metrics from
// and also to carry out path finding queries.
Graph routingGraph
Graph *channeldb.ChannelGraph
// SourceNode is the graph's source node.
SourceNode *channeldb.LightningNode
// GetLink is a method that allows querying the lower link layer
// to determine the up to date available bandwidth at a prospective link
@@ -40,6 +43,21 @@ type SessionSource struct {
PathFindingConfig PathFindingConfig
}
// getRoutingGraph returns a routing graph and a clean-up function for
// pathfinding.
func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
routingTx, err := NewCachedGraph(m.SourceNode, m.Graph)
if err != nil {
return nil, nil, err
}
return routingTx, func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}, nil
}
// NewPaymentSession creates a new payment session backed by the latest prune
// view from Mission Control. An optional set of routing hints can be provided
// in order to populate additional edges to explore when finding a path to the
@@ -47,14 +65,14 @@ type SessionSource struct {
func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
PaymentSession, error) {
sourceNode := m.Graph.sourceNode()
getBandwidthHints := func() (bandwidthHints, error) {
return newBandwidthManager(m.Graph, sourceNode, m.GetLink)
getBandwidthHints := func(graph routingGraph) (bandwidthHints, error) {
return newBandwidthManager(
graph, m.SourceNode.PubKeyBytes, m.GetLink,
)
}
session, err := newPaymentSession(
p, getBandwidthHints, m.Graph,
p, getBandwidthHints, m.getRoutingGraph,
m.MissionControl, m.PathFindingConfig,
)
if err != nil {

View File

@@ -116,10 +116,12 @@ func TestUpdateAdditionalEdge(t *testing.T) {
// Create the paymentsession.
session, err := newPaymentSession(
payment,
func() (bandwidthHints, error) {
func(routingGraph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil
},
&sessionGraph{},
func() (routingGraph, func(), error) {
return &sessionGraph{}, func() {}, nil
},
&MissionControl{},
PathFindingConfig{},
)
@@ -194,10 +196,12 @@ func TestRequestRoute(t *testing.T) {
session, err := newPaymentSession(
payment,
func() (bandwidthHints, error) {
func(routingGraph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil
},
&sessionGraph{},
func() (routingGraph, func(), error) {
return &sessionGraph{}, func() {}, nil
},
&MissionControl{},
PathFindingConfig{},
)

View File

@@ -129,11 +129,11 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
)
require.NoError(t, err, "failed to create missioncontrol")
cachedGraph, err := NewCachedGraph(graphInstance.graph)
sourceNode, err := graphInstance.graph.SourceNode()
require.NoError(t, err)
sessionSource := &SessionSource{
Graph: cachedGraph,
Graph: graphInstance.graph,
SourceNode: sourceNode,
GetLink: graphInstance.getLink,
PathFindingConfig: pathFindingConfig,
MissionControl: mc,

View File

@@ -860,12 +860,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
MinProbability: routingConfig.MinRouteProbability,
}
cachedGraph, err := routing.NewCachedGraph(chanGraph)
sourceNode, err := chanGraph.SourceNode()
if err != nil {
return nil, err
return nil, fmt.Errorf("error getting source node: %v", err)
}
paymentSessionSource := &routing.SessionSource{
Graph: cachedGraph,
Graph: chanGraph,
SourceNode: sourceNode,
MissionControl: s.missionControl,
GetLink: s.htlcSwitch.GetLinkByShortID,
PathFindingConfig: pathFindingConfig,