routing: rename and export routingGraph

In preparation for structs outside of the `routing` package implementing
this interface, export `routingGraph` and rename it to `Graph` so as to
avoid stuttering.
This commit is contained in:
Elle Mouton
2024-06-25 19:22:00 -07:00
parent 5a903c270f
commit 3f121cbe81
10 changed files with 50 additions and 50 deletions

View File

@ -39,7 +39,7 @@ type bandwidthManager struct {
// hints for the edges we directly have open ourselves. Obtaining these hints // hints for the edges we directly have open ourselves. Obtaining these hints
// allows us to reduce the number of extraneous attempts as we can skip channels // allows us to reduce the number of extraneous attempts as we can skip channels
// that are inactive, or just don't have enough bandwidth to carry the payment. // that are inactive, or just don't have enough bandwidth to carry the payment.
func newBandwidthManager(graph routingGraph, sourceNode route.Vertex, func newBandwidthManager(graph Graph, sourceNode route.Vertex,
linkQuery getLinkQuery) (*bandwidthManager, error) { linkQuery getLinkQuery) (*bandwidthManager, error) {
manager := &bandwidthManager{ manager := &bandwidthManager{
@ -49,7 +49,7 @@ func newBandwidthManager(graph routingGraph, sourceNode route.Vertex,
// First, we'll collect the set of outbound edges from the target // First, we'll collect the set of outbound edges from the target
// source node and add them to our bandwidth manager's map of channels. // source node and add them to our bandwidth manager's map of channels.
err := graph.forEachNodeChannel(sourceNode, err := graph.ForEachNodeChannel(sourceNode,
func(channel *channeldb.DirectedChannel) error { func(channel *channeldb.DirectedChannel) error {
shortID := lnwire.NewShortChanIDFromInt( shortID := lnwire.NewShortChanIDFromInt(
channel.ChannelID, channel.ChannelID,

View File

@ -10,19 +10,19 @@ import (
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
) )
// routingGraph is an abstract interface that provides information about nodes // Graph is an abstract interface that provides information about nodes and
// and edges to pathfinding. // edges to pathfinding.
type routingGraph interface { type Graph interface {
// forEachNodeChannel calls the callback for every channel of the given // ForEachNodeChannel calls the callback for every channel of the given
// node. // node.
forEachNodeChannel(nodePub route.Vertex, ForEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error cb func(channel *channeldb.DirectedChannel) error) error
// fetchNodeFeatures returns the features of the given node. // FetchNodeFeatures returns the features of the given node.
fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
} }
// CachedGraph is a routingGraph implementation that retrieves from the // CachedGraph is a Graph implementation that retrieves from the
// database. // database.
type CachedGraph struct { type CachedGraph struct {
graph *channeldb.ChannelGraph graph *channeldb.ChannelGraph
@ -30,9 +30,9 @@ type CachedGraph struct {
source route.Vertex source route.Vertex
} }
// A compile time assertion to make sure CachedGraph implements the routingGraph // A compile time assertion to make sure CachedGraph implements the Graph
// interface. // interface.
var _ routingGraph = (*CachedGraph)(nil) var _ Graph = (*CachedGraph)(nil)
// NewCachedGraph instantiates a new db-connected routing graph. It implicitly // NewCachedGraph instantiates a new db-connected routing graph. It implicitly
// instantiates a new read transaction. // instantiates a new read transaction.
@ -61,20 +61,20 @@ func (g *CachedGraph) Close() error {
return g.tx.Rollback() return g.tx.Rollback()
} }
// forEachNodeChannel calls the callback for every channel of the given node. // ForEachNodeChannel calls the callback for every channel of the given node.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the Graph interface.
func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex, func (g *CachedGraph) ForEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error { cb func(channel *channeldb.DirectedChannel) error) error {
return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb)
} }
// fetchNodeFeatures returns the features of the given node. If the node is // FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported. // unknown, assume no additional features are supported.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the Graph interface.
func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) ( func (g *CachedGraph) FetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) { *lnwire.FeatureVector, error) {
return g.graph.FetchNodeFeatures(nodePub) return g.graph.FetchNodeFeatures(nodePub)

View File

@ -163,7 +163,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
c.t.Fatal(err) c.t.Fatal(err)
} }
getBandwidthHints := func(_ routingGraph) (bandwidthHints, error) { getBandwidthHints := func(_ Graph) (bandwidthHints, error) {
// Create bandwidth hints based on local channel balances. // Create bandwidth hints based on local channel balances.
bandwidthHints := map[uint64]lnwire.MilliSatoshi{} bandwidthHints := map[uint64]lnwire.MilliSatoshi{}
for _, ch := range c.graph.nodes[c.source.pubkey].channels { for _, ch := range c.graph.nodes[c.source.pubkey].channels {
@ -201,7 +201,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
session, err := newPaymentSession( session, err := newPaymentSession(
&payment, c.graph.source.pubkey, getBandwidthHints, &payment, c.graph.source.pubkey, getBandwidthHints,
func() (routingGraph, func(), error) { func() (Graph, func(), error) {
return c.graph, func() {}, nil return c.graph, func() {}, nil
}, },
mc, c.pathFindingCfg, mc, c.pathFindingCfg,

View File

@ -164,8 +164,8 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte,
// forEachNodeChannel calls the callback for every channel of the given node. // forEachNodeChannel calls the callback for every channel of the given node.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the Graph interface.
func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, func (m *mockGraph) ForEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error { cb func(channel *channeldb.DirectedChannel) error) error {
// Look up the mock node. // Look up the mock node.
@ -213,15 +213,15 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex,
// sourceNode returns the source node of the graph. // sourceNode returns the source node of the graph.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the Graph interface.
func (m *mockGraph) sourceNode() route.Vertex { func (m *mockGraph) sourceNode() route.Vertex {
return m.source.pubkey return m.source.pubkey
} }
// fetchNodeFeatures returns the features of the given node. // fetchNodeFeatures returns the features of the given node.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the Graph interface.
func (m *mockGraph) fetchNodeFeatures(nodePub route.Vertex) ( func (m *mockGraph) FetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) { *lnwire.FeatureVector, error) {
return lnwire.EmptyFeatureVector(), nil return lnwire.EmptyFeatureVector(), nil
@ -230,7 +230,7 @@ func (m *mockGraph) fetchNodeFeatures(nodePub route.Vertex) (
// FetchAmountPairCapacity returns the maximal capacity between nodes in the // FetchAmountPairCapacity returns the maximal capacity between nodes in the
// graph. // graph.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the Graph interface.
func (m *mockGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex, func (m *mockGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex,
amount lnwire.MilliSatoshi) (btcutil.Amount, error) { amount lnwire.MilliSatoshi) (btcutil.Amount, error) {
@ -244,7 +244,7 @@ func (m *mockGraph) FetchAmountPairCapacity(nodeFrom, nodeTo route.Vertex,
return nil return nil
} }
err := m.forEachNodeChannel(nodeFrom, cb) err := m.ForEachNodeChannel(nodeFrom, cb)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -295,5 +295,5 @@ func (m *mockGraph) sendHtlc(route *route.Route) (htlcResult, error) {
return source.fwd(nil, next) return source.fwd(nil, next)
} }
// Compile-time check for the routingGraph interface. // Compile-time check for the Graph interface.
var _ routingGraph = &mockGraph{} var _ Graph = &mockGraph{}

View File

@ -369,7 +369,7 @@ func edgeWeight(lockedAmt lnwire.MilliSatoshi, fee lnwire.MilliSatoshi,
// graphParams wraps the set of graph parameters passed to findPath. // graphParams wraps the set of graph parameters passed to findPath.
type graphParams struct { type graphParams struct {
// graph is the ChannelGraph to be used during path finding. // graph is the ChannelGraph to be used during path finding.
graph routingGraph graph Graph
// additionalEdges is an optional set of edges that should be // additionalEdges is an optional set of edges that should be
// considered during path finding, that is not already found in the // considered during path finding, that is not already found in the
@ -464,7 +464,7 @@ type PathFindingConfig struct {
// available balance. // available balance.
func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
bandwidthHints bandwidthHints, bandwidthHints bandwidthHints,
g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { g Graph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {
var max, total lnwire.MilliSatoshi var max, total lnwire.MilliSatoshi
cb := func(channel *channeldb.DirectedChannel) error { cb := func(channel *channeldb.DirectedChannel) error {
@ -502,7 +502,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
} }
// Iterate over all channels of the to node. // Iterate over all channels of the to node.
err := g.forEachNodeChannel(node, cb) err := g.ForEachNodeChannel(node, cb)
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
@ -542,7 +542,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
features := r.DestFeatures features := r.DestFeatures
if features == nil { if features == nil {
var err error var err error
features, err = g.graph.fetchNodeFeatures(target) features, err = g.graph.FetchNodeFeatures(target)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -920,7 +920,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
} }
// Fetch node features fresh from the graph. // Fetch node features fresh from the graph.
fromFeatures, err := g.graph.fetchNodeFeatures(node) fromFeatures, err := g.graph.FetchNodeFeatures(node)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -167,7 +167,7 @@ type paymentSession struct {
additionalEdges map[route.Vertex][]AdditionalEdge additionalEdges map[route.Vertex][]AdditionalEdge
getBandwidthHints func(routingGraph) (bandwidthHints, error) getBandwidthHints func(Graph) (bandwidthHints, error)
payment *LightningPayment payment *LightningPayment
@ -175,7 +175,7 @@ type paymentSession struct {
pathFinder pathFinder pathFinder pathFinder
getRoutingGraph func() (routingGraph, func(), error) getRoutingGraph func() (Graph, func(), error)
// pathFindingConfig defines global parameters that control the // pathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probability. // trade-off in path finding between fees and probability.
@ -195,8 +195,8 @@ type paymentSession struct {
// newPaymentSession instantiates a new payment session. // newPaymentSession instantiates a new payment session.
func newPaymentSession(p *LightningPayment, selfNode route.Vertex, func newPaymentSession(p *LightningPayment, selfNode route.Vertex,
getBandwidthHints func(routingGraph) (bandwidthHints, error), getBandwidthHints func(Graph) (bandwidthHints, error),
getRoutingGraph func() (routingGraph, func(), error), getRoutingGraph func() (Graph, func(), error),
missionControl MissionController, pathFindingConfig PathFindingConfig) ( missionControl MissionController, pathFindingConfig PathFindingConfig) (
*paymentSession, error) { *paymentSession, error) {

View File

@ -46,7 +46,7 @@ type SessionSource struct {
// getRoutingGraph returns a routing graph and a clean-up function for // getRoutingGraph returns a routing graph and a clean-up function for
// pathfinding. // pathfinding.
func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { func (m *SessionSource) getRoutingGraph() (Graph, func(), error) {
routingTx, err := NewCachedGraph(m.SourceNode, m.Graph) routingTx, err := NewCachedGraph(m.SourceNode, m.Graph)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -66,7 +66,7 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
PaymentSession, error) { PaymentSession, error) {
getBandwidthHints := func(graph routingGraph) (bandwidthHints, error) { getBandwidthHints := func(graph Graph) (bandwidthHints, error) {
return newBandwidthManager( return newBandwidthManager(
graph, m.SourceNode.PubKeyBytes, m.GetLink, graph, m.SourceNode.PubKeyBytes, m.GetLink,
) )

View File

@ -116,10 +116,10 @@ func TestUpdateAdditionalEdge(t *testing.T) {
// Create the paymentsession. // Create the paymentsession.
session, err := newPaymentSession( session, err := newPaymentSession(
payment, route.Vertex{}, payment, route.Vertex{},
func(routingGraph) (bandwidthHints, error) { func(Graph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil return &mockBandwidthHints{}, nil
}, },
func() (routingGraph, func(), error) { func() (Graph, func(), error) {
return &sessionGraph{}, func() {}, nil return &sessionGraph{}, func() {}, nil
}, },
&MissionControl{}, &MissionControl{},
@ -196,10 +196,10 @@ func TestRequestRoute(t *testing.T) {
session, err := newPaymentSession( session, err := newPaymentSession(
payment, route.Vertex{}, payment, route.Vertex{},
func(routingGraph) (bandwidthHints, error) { func(Graph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil return &mockBandwidthHints{}, nil
}, },
func() (routingGraph, func(), error) { func() (Graph, func(), error) {
return &sessionGraph{}, func() {}, nil return &sessionGraph{}, func() {}, nil
}, },
&MissionControl{}, &MissionControl{},
@ -253,7 +253,7 @@ func TestRequestRoute(t *testing.T) {
} }
type sessionGraph struct { type sessionGraph struct {
routingGraph Graph
} }
func (g *sessionGraph) sourceNode() route.Vertex { func (g *sessionGraph) sourceNode() route.Vertex {

View File

@ -453,9 +453,9 @@ type ChannelRouter struct {
// when doing any path finding. // when doing any path finding.
selfNode *channeldb.LightningNode selfNode *channeldb.LightningNode
// cachedGraph is an instance of routingGraph that caches the source // cachedGraph is an instance of Graph that caches the source
// node as well as the channel graph itself in memory. // node as well as the channel graph itself in memory.
cachedGraph routingGraph cachedGraph Graph
// newBlocks is a channel in which new blocks connected to the end of // newBlocks is a channel in which new blocks connected to the end of
// the main chain are sent over, and blocks updated after a call to // the main chain are sent over, and blocks updated after a call to
@ -3177,7 +3177,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// getRouteUnifiers returns a list of edge unifiers for the given route. // getRouteUnifiers returns a list of edge unifiers for the given route.
func getRouteUnifiers(source route.Vertex, hops []route.Vertex, func getRouteUnifiers(source route.Vertex, hops []route.Vertex,
useMinAmt bool, runningAmt lnwire.MilliSatoshi, useMinAmt bool, runningAmt lnwire.MilliSatoshi,
outgoingChans map[uint64]struct{}, graph routingGraph, outgoingChans map[uint64]struct{}, graph Graph,
bandwidthHints *bandwidthManager) ([]*edgeUnifier, lnwire.MilliSatoshi, bandwidthHints *bandwidthManager) ([]*edgeUnifier, lnwire.MilliSatoshi,
error) { error) {

View File

@ -94,7 +94,7 @@ func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex,
// addGraphPolicies adds all policies that are known for the toNode in the // addGraphPolicies adds all policies that are known for the toNode in the
// graph. // graph.
func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error { func (u *nodeEdgeUnifier) addGraphPolicies(g Graph) error {
cb := func(channel *channeldb.DirectedChannel) error { cb := func(channel *channeldb.DirectedChannel) error {
// If there is no edge policy for this candidate node, skip. // If there is no edge policy for this candidate node, skip.
// Note that we are searching backwards so this node would have // Note that we are searching backwards so this node would have
@ -120,7 +120,7 @@ func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error {
} }
// Iterate over all channels of the to node. // Iterate over all channels of the to node.
return g.forEachNodeChannel(u.toNode, cb) return g.ForEachNodeChannel(u.toNode, cb)
} }
// unifiedEdge is the individual channel data that is kept inside an edgeUnifier // unifiedEdge is the individual channel data that is kept inside an edgeUnifier