From c5cc6f1392279d9a5ee3900ee4b63fd7ec7876e1 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 11 Nov 2024 11:15:46 +0200 Subject: [PATCH] routing: remove context.TODOs --- graph/session/graph_session.go | 18 ++-- lnrpc/invoicesrpc/addinvoice.go | 3 +- lnrpc/routerrpc/router_backend.go | 8 +- lnrpc/routerrpc/router_backend_test.go | 7 +- lnrpc/routerrpc/router_server.go | 12 +-- lnrpc/routerrpc/router_server_deprecated.go | 4 +- routing/bandwidth.go | 8 +- routing/bandwidth_test.go | 4 +- routing/blindedpath/blinded_path.go | 5 +- routing/blindedpath/blinded_path_test.go | 16 +-- routing/graph.go | 15 +-- routing/integrated_routing_context_test.go | 26 +++-- routing/mock_graph_test.go | 5 +- routing/mock_test.go | 14 +-- routing/pathfind.go | 39 +++---- routing/pathfind_test.go | 15 +-- routing/payment_lifecycle.go | 6 +- routing/payment_lifecycle_test.go | 34 ++++--- routing/payment_session.go | 19 ++-- routing/payment_session_source.go | 8 +- routing/payment_session_test.go | 12 ++- routing/router.go | 30 +++--- routing/router_test.go | 107 +++++++++++--------- routing/unified_edges.go | 5 +- rpcserver.go | 11 +- 25 files changed, 242 insertions(+), 189 deletions(-) diff --git a/graph/session/graph_session.go b/graph/session/graph_session.go index ede5542c9..d1617bb59 100644 --- a/graph/session/graph_session.go +++ b/graph/session/graph_session.go @@ -30,8 +30,10 @@ func NewGraphSessionFactory(graph ReadOnlyGraph) routing.GraphSessionFactory { // was created at Graph construction time. // // NOTE: This is part of the routing.GraphSessionFactory interface. -func (g *Factory) NewGraphSession() (routing.Graph, func() error, error) { - tx, err := g.graph.NewPathFindTx(context.TODO()) +func (g *Factory) NewGraphSession(ctx context.Context) (routing.Graph, + func() error, error) { + + tx, err := g.graph.NewPathFindTx(ctx) if err != nil { return nil, nil, err } @@ -83,22 +85,20 @@ func (g *session) close() error { // ForEachNodeChannel calls the callback for every channel of the given node. // // NOTE: Part of the routing.Graph interface. -func (g *session) ForEachNodeChannel(nodePub route.Vertex, +func (g *session) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, cb func(channel *graphdb.DirectedChannel) error) error { - return g.graph.ForEachNodeDirectedChannel( - context.TODO(), g.tx, nodePub, cb, - ) + return g.graph.ForEachNodeDirectedChannel(ctx, g.tx, nodePub, cb) } // FetchNodeFeatures returns the features of the given node. If the node is // unknown, assume no additional features are supported. // // NOTE: Part of the routing.Graph interface. -func (g *session) FetchNodeFeatures(nodePub route.Vertex) ( - *lnwire.FeatureVector, error) { +func (g *session) FetchNodeFeatures(ctx context.Context, + nodePub route.Vertex) (*lnwire.FeatureVector, error) { - return g.graph.FetchNodeFeatures(context.TODO(), g.tx, nodePub) + return g.graph.FetchNodeFeatures(ctx, g.tx, nodePub) } // A compile-time check to ensure that *session implements the diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index d9a4d7afa..c69db4f5c 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -96,7 +96,8 @@ type AddInvoiceConfig struct { // QueryBlindedRoutes can be used to generate a few routes to this node // that can then be used in the construction of a blinded payment path. - QueryBlindedRoutes func(lnwire.MilliSatoshi) ([]*route.Route, error) + QueryBlindedRoutes func(context.Context, lnwire.MilliSatoshi) ( + []*route.Route, error) } // AddInvoiceData contains the required data to create a new invoice. diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 9421e991b..5aabbf618 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -50,7 +50,8 @@ type RouterBackend struct { // FetchAmountPairCapacity determines the maximal channel capacity // between two nodes given a certain amount. - FetchAmountPairCapacity func(nodeFrom, nodeTo route.Vertex, + FetchAmountPairCapacity func(ctx context.Context, nodeFrom, + nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) // FetchChannelEndpoints returns the pubkeys of both endpoints of the @@ -60,7 +61,8 @@ type RouterBackend struct { // FindRoute is a closure that abstracts away how we locate/query for // routes. - FindRoute func(*routing.RouteRequest) (*route.Route, float64, error) + FindRoute func(context.Context, *routing.RouteRequest) (*route.Route, + float64, error) MissionControl MissionControl @@ -169,7 +171,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, // Query the channel router for a possible path to the destination that // can carry `in.Amt` satoshis _including_ the total fee required on // the route - route, successProb, err := r.FindRoute(routeReq) + route, successProb, err := r.FindRoute(ctx, routeReq) if err != nil { return nil, err } diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index 877a3cc17..78e25f7e5 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -120,8 +120,8 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, } } - findRoute := func(req *routing.RouteRequest) (*route.Route, float64, - error) { + findRoute := func(_ context.Context, req *routing.RouteRequest) ( + *route.Route, float64, error) { if int64(req.Amount) != amtSat*1000 { t.Fatal("unexpected amount") @@ -200,7 +200,8 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, return 1, nil }, - FetchAmountPairCapacity: func(nodeFrom, nodeTo route.Vertex, + FetchAmountPairCapacity: func(_ context.Context, nodeFrom, + nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { return 1, nil diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index a4112ba64..b4460d87d 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -426,7 +426,7 @@ func (s *Server) EstimateRouteFee(ctx context.Context, return nil, errors.New("amount must be greater than 0") default: - return s.probeDestination(req.Dest, req.AmtSat) + return s.probeDestination(ctx, req.Dest, req.AmtSat) } case isProbeInvoice: @@ -440,8 +440,8 @@ func (s *Server) EstimateRouteFee(ctx context.Context, // probeDestination estimates fees along a route to a destination based on the // contents of the local graph. -func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse, - error) { +func (s *Server) probeDestination(ctx context.Context, dest []byte, + amtSat int64) (*RouteFeeResponse, error) { destNode, err := route.NewVertexFromBytes(dest) if err != nil { @@ -469,7 +469,7 @@ func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse, return nil, err } - route, _, err := s.cfg.Router.FindRoute(routeReq) + route, _, err := s.cfg.Router.FindRoute(ctx, routeReq) if err != nil { return nil, err } @@ -1429,7 +1429,7 @@ func (s *Server) trackPaymentStream(context context.Context, } // BuildRoute builds a route from a list of hop addresses. -func (s *Server) BuildRoute(_ context.Context, +func (s *Server) BuildRoute(ctx context.Context, req *BuildRouteRequest) (*BuildRouteResponse, error) { if len(req.HopPubkeys) == 0 { @@ -1490,7 +1490,7 @@ func (s *Server) BuildRoute(_ context.Context, // Build the route and return it to the caller. route, err := s.cfg.Router.BuildRoute( - amt, hops, outgoingChan, req.FinalCltvDelta, payAddr, + ctx, amt, hops, outgoingChan, req.FinalCltvDelta, payAddr, firstHopBlob, ) if err != nil { diff --git a/lnrpc/routerrpc/router_server_deprecated.go b/lnrpc/routerrpc/router_server_deprecated.go index 7be1e3d91..fee5bcab5 100644 --- a/lnrpc/routerrpc/router_server_deprecated.go +++ b/lnrpc/routerrpc/router_server_deprecated.go @@ -123,7 +123,7 @@ func (s *Server) SendToRoute(ctx context.Context, // QueryProbability returns the current success probability estimate for a // given node pair and amount. -func (s *Server) QueryProbability(_ context.Context, +func (s *Server) QueryProbability(ctx context.Context, req *QueryProbabilityRequest) (*QueryProbabilityResponse, error) { fromNode, err := route.NewVertexFromBytes(req.FromNode) @@ -142,7 +142,7 @@ func (s *Server) QueryProbability(_ context.Context, var prob float64 mc := s.cfg.RouterBackend.MissionControl capacity, err := s.cfg.RouterBackend.FetchAmountPairCapacity( - fromNode, toNode, amt, + ctx, fromNode, toNode, amt, ) // If we cannot query the capacity this means that either we don't have diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 12e82131d..186417327 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "github.com/lightningnetwork/lnd/fn" @@ -82,8 +83,9 @@ type bandwidthManager struct { // hints for the edges we directly have open ourselves. Obtaining these hints // allows us to reduce the number of extraneous attempts as we can skip channels // that are inactive, or just don't have enough bandwidth to carry the payment. -func newBandwidthManager(graph Graph, sourceNode route.Vertex, - linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob], +func newBandwidthManager(ctx context.Context, graph Graph, + sourceNode route.Vertex, linkQuery getLinkQuery, + firstHopBlob fn.Option[tlv.Blob], trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) { manager := &bandwidthManager{ @@ -95,7 +97,7 @@ func newBandwidthManager(graph Graph, sourceNode route.Vertex, // First, we'll collect the set of outbound edges from the target // source node and add them to our bandwidth manager's map of channels. - err := graph.ForEachNodeChannel(sourceNode, + err := graph.ForEachNodeChannel(ctx, sourceNode, func(channel *graphdb.DirectedChannel) error { shortID := lnwire.NewShortChanIDFromInt( channel.ChannelID, diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index 4872b5a7e..083559f79 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "testing" "github.com/btcsuite/btcd/btcutil" @@ -116,7 +117,8 @@ func TestBandwidthManager(t *testing.T) { ) m, err := newBandwidthManager( - g, sourceNode.pubkey, testCase.linkQuery, + context.Background(), g, sourceNode.pubkey, + testCase.linkQuery, fn.None[[]byte](), fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), ) diff --git a/routing/blindedpath/blinded_path.go b/routing/blindedpath/blinded_path.go index e452f4cfc..e3d8378b9 100644 --- a/routing/blindedpath/blinded_path.go +++ b/routing/blindedpath/blinded_path.go @@ -39,7 +39,8 @@ type BuildBlindedPathCfg struct { // various lengths and may even contain only a single hop. Any route // shorter than MinNumHops will be padded with dummy hops during route // construction. - FindRoutes func(value lnwire.MilliSatoshi) ([]*route.Route, error) + FindRoutes func(ctx context.Context, value lnwire.MilliSatoshi) ( + []*route.Route, error) // FetchChannelEdgesByID attempts to look up the two directed edges for // the channel identified by the channel ID. @@ -118,7 +119,7 @@ func BuildBlindedPaymentPaths(ctx context.Context, cfg *BuildBlindedPathCfg) ( // Find some appropriate routes for the value to be routed. This will // return a set of routes made up of real nodes. - routes, err := cfg.FindRoutes(cfg.ValueMsat) + routes, err := cfg.FindRoutes(ctx, cfg.ValueMsat) if err != nil { return nil, err } diff --git a/routing/blindedpath/blinded_path_test.go b/routing/blindedpath/blinded_path_test.go index a21bb5fa3..f2fa7b4a0 100644 --- a/routing/blindedpath/blinded_path_test.go +++ b/routing/blindedpath/blinded_path_test.go @@ -595,8 +595,8 @@ func TestBuildBlindedPath(t *testing.T) { } paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{ - FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route, - error) { + FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) ( + []*route.Route, error) { return []*route.Route{realRoute}, nil }, @@ -765,8 +765,8 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) { } paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{ - FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route, - error) { + FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) ( + []*route.Route, error) { return []*route.Route{realRoute}, nil }, @@ -935,8 +935,8 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) { // still get 1 valid path. var errCount int paths, err = BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{ - FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route, - error) { + FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) ( + []*route.Route, error) { return []*route.Route{realRoute, realRoute, realRoute}, nil @@ -1016,8 +1016,8 @@ func TestSingleHopBlindedPath(t *testing.T) { } paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{ - FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route, - error) { + FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) ( + []*route.Route, error) { return []*route.Route{realRoute}, nil }, diff --git a/routing/graph.go b/routing/graph.go index 7608ee92b..dc17fea9f 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "github.com/btcsuite/btcd/btcutil" @@ -14,11 +15,12 @@ import ( type Graph interface { // ForEachNodeChannel calls the callback for every channel of the given // node. - ForEachNodeChannel(nodePub route.Vertex, + ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, cb func(channel *graphdb.DirectedChannel) error) error // FetchNodeFeatures returns the features of the given node. - FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) + FetchNodeFeatures(ctx context.Context, nodePub route.Vertex) ( + *lnwire.FeatureVector, error) } // GraphSessionFactory can be used to produce a new Graph instance which can @@ -30,13 +32,14 @@ type GraphSessionFactory interface { // session. It returns the Graph along with a call-back that must be // called once Graph access is complete. This call-back will close any // read-only transaction that was created at Graph construction time. - NewGraphSession() (Graph, func() error, error) + NewGraphSession(ctx context.Context) (Graph, func() error, error) } // FetchAmountPairCapacity determines the maximal public capacity between two // nodes depending on the amount we try to send. -func FetchAmountPairCapacity(graph Graph, source, nodeFrom, nodeTo route.Vertex, - amount lnwire.MilliSatoshi) (btcutil.Amount, error) { +func FetchAmountPairCapacity(ctx context.Context, graph Graph, source, nodeFrom, + nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, + error) { // Create unified edges for all incoming connections. // @@ -44,7 +47,7 @@ func FetchAmountPairCapacity(graph Graph, source, nodeFrom, nodeTo route.Vertex, // by a deprecated router rpc. u := newNodeEdgeUnifier(source, nodeTo, false, nil) - err := u.addGraphPolicies(graph) + err := u.addGraphPolicies(ctx, graph) if err != nil { return 0, err } diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 4a2e679ba..7bb1677b4 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "math" "os" @@ -122,6 +123,8 @@ func (h htlcAttempt) String() string { func (c *integratedRoutingContext) testPayment(maxParts uint32, destFeatureBits ...lnwire.FeatureBit) ([]htlcAttempt, error) { + ctx := context.Background() + // We start out with the base set of MPP feature bits. If the caller // overrides this set of bits, then we'll use their feature bits // entirely. @@ -173,7 +176,9 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, ) require.NoError(c.t, err) - getBandwidthHints := func(_ Graph) (bandwidthHints, error) { + getBandwidthHints := func(_ context.Context, _ Graph) (bandwidthHints, + error) { + // Create bandwidth hints based on local channel balances. bandwidthHints := map[uint64]lnwire.MilliSatoshi{} for _, ch := range c.graph.nodes[c.source.pubkey].channels { @@ -235,8 +240,8 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, // Find a route. route, err := session.RequestRoute( - amtRemaining, lnwire.MaxMilliSatoshi, inFlightHtlcs, 0, - lnwire.CustomRecords{ + ctx, amtRemaining, lnwire.MaxMilliSatoshi, + inFlightHtlcs, 0, lnwire.CustomRecords{ lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3}, }, ) @@ -326,8 +331,8 @@ func newMockGraphSessionFactory(graph Graph) GraphSessionFactory { return &mockGraphSessionFactory{Graph: graph} } -func (m *mockGraphSessionFactory) NewGraphSession() (Graph, func() error, - error) { +func (m *mockGraphSessionFactory) NewGraphSession(_ context.Context) (Graph, + func() error, error) { return m, func() error { return nil @@ -349,8 +354,8 @@ func newMockGraphSessionFactoryFromChanDB( } } -func (g *mockGraphSessionFactoryChanDB) NewGraphSession() (Graph, func() error, - error) { +func (g *mockGraphSessionFactoryChanDB) NewGraphSession(_ context.Context) ( + Graph, func() error, error) { tx, err := g.graph.NewPathFindTx() if err != nil { @@ -391,14 +396,15 @@ func (g *mockGraphSessionChanDB) close() error { return nil } -func (g *mockGraphSessionChanDB) ForEachNodeChannel(nodePub route.Vertex, +func (g *mockGraphSessionChanDB) ForEachNodeChannel(_ context.Context, + nodePub route.Vertex, cb func(channel *graphdb.DirectedChannel) error) error { return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) } -func (g *mockGraphSessionChanDB) FetchNodeFeatures(nodePub route.Vertex) ( - *lnwire.FeatureVector, error) { +func (g *mockGraphSessionChanDB) FetchNodeFeatures(_ context.Context, + nodePub route.Vertex) (*lnwire.FeatureVector, error) { return g.graph.FetchNodeFeatures(g.tx, nodePub) } diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index cab7c9726..74fc48333 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "fmt" "testing" @@ -165,7 +166,7 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte, // forEachNodeChannel calls the callback for every channel of the given node. // // NOTE: Part of the Graph interface. -func (m *mockGraph) ForEachNodeChannel(nodePub route.Vertex, +func (m *mockGraph) ForEachNodeChannel(_ context.Context, nodePub route.Vertex, cb func(channel *graphdb.DirectedChannel) error) error { // Look up the mock node. @@ -221,7 +222,7 @@ func (m *mockGraph) sourceNode() route.Vertex { // fetchNodeFeatures returns the features of the given node. // // NOTE: Part of the Graph interface. -func (m *mockGraph) FetchNodeFeatures(nodePub route.Vertex) ( +func (m *mockGraph) FetchNodeFeatures(_ context.Context, _ route.Vertex) ( *lnwire.FeatureVector, error) { return lnwire.EmptyFeatureVector(), nil diff --git a/routing/mock_test.go b/routing/mock_test.go index 3cdb5ebaf..b6d24d83a 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "sync" @@ -168,9 +169,9 @@ type mockPaymentSessionOld struct { var _ PaymentSession = (*mockPaymentSessionOld)(nil) -func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, - _, height uint32, _ lnwire.CustomRecords) (*route.Route, - error) { +func (m *mockPaymentSessionOld) RequestRoute(_ context.Context, _, + _ lnwire.MilliSatoshi, _, height uint32, + _ lnwire.CustomRecords) (*route.Route, error) { if m.release != nil { m.release <- struct{}{} @@ -694,12 +695,13 @@ type mockPaymentSession struct { var _ PaymentSession = (*mockPaymentSession)(nil) -func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32, +func (m *mockPaymentSession) RequestRoute(ctx context.Context, maxAmt, + feeLimit lnwire.MilliSatoshi, activeShards, height uint32, firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) { args := m.Called( - maxAmt, feeLimit, activeShards, height, firstHopCustomRecords, + ctx, maxAmt, feeLimit, activeShards, height, + firstHopCustomRecords, ) // Type assertion on nil will fail, so we check and return here. diff --git a/routing/pathfind.go b/routing/pathfind.go index db474e1e8..160752e22 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -3,6 +3,7 @@ package routing import ( "bytes" "container/heap" + "context" "errors" "fmt" "math" @@ -50,7 +51,7 @@ const ( ) // pathFinder defines the interface of a path finding algorithm. -type pathFinder = func(g *graphParams, r *RestrictParams, +type pathFinder = func(ctx context.Context, g *graphParams, r *RestrictParams, cfg *PathFindingConfig, self, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( []*unifiedEdge, float64, error) @@ -491,8 +492,8 @@ type PathFindingConfig struct { // getOutgoingBalance returns the maximum available balance in any of the // channels of the given node. The second return parameters is the total // available balance. -func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, - bandwidthHints bandwidthHints, +func getOutgoingBalance(ctx context.Context, node route.Vertex, + outgoingChans map[uint64]struct{}, bandwidthHints bandwidthHints, g Graph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi @@ -540,7 +541,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, } // Iterate over all channels of the to node. - err := g.ForEachNodeChannel(node, cb) + err := g.ForEachNodeChannel(ctx, node, cb) if err != nil { return 0, 0, err } @@ -558,10 +559,10 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // source. This is to properly accumulate fees that need to be paid along the // path and accurately check the amount to forward at every node against the // available bandwidth. -func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, - self, source, target route.Vertex, amt lnwire.MilliSatoshi, - timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64, - error) { +func findPath(ctx context.Context, g *graphParams, r *RestrictParams, + cfg *PathFindingConfig, self, source, target route.Vertex, + amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( + []*unifiedEdge, float64, error) { // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to @@ -580,7 +581,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, features := r.DestFeatures if features == nil { var err error - features, err = g.graph.FetchNodeFeatures(target) + features, err = g.graph.FetchNodeFeatures(ctx, target) if err != nil { return nil, 0, err } @@ -624,7 +625,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // balance available. if source == self { max, total, err := getOutgoingBalance( - self, outgoingChanMap, g.bandwidthHints, g.graph, + ctx, self, outgoingChanMap, g.bandwidthHints, g.graph, ) if err != nil { return nil, 0, err @@ -968,7 +969,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, } // Fetch node features fresh from the graph. - fromFeatures, err := g.graph.FetchNodeFeatures(node) + fromFeatures, err := g.graph.FetchNodeFeatures(ctx, node) if err != nil { return nil, err } @@ -1008,7 +1009,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, self, pivot, !isExitHop, outgoingChanMap, ) - err := u.addGraphPolicies(g.graph) + err := u.addGraphPolicies(ctx, g.graph) if err != nil { return nil, 0, err } @@ -1183,7 +1184,7 @@ type blindedHop struct { // _and_ the introduction node for the path has more than one public channel. // Any filtering of paths based on payment value or success probabilities is // left to the caller. -func findBlindedPaths(g Graph, target route.Vertex, +func findBlindedPaths(ctx context.Context, g Graph, target route.Vertex, restrictions *blindedPathRestrictions) ([][]blindedHop, error) { // Sanity check the restrictions. @@ -1202,7 +1203,7 @@ func findBlindedPaths(g Graph, target route.Vertex, return true, nil } - features, err := g.FetchNodeFeatures(node) + features, err := g.FetchNodeFeatures(ctx, node) if err != nil { return false, err } @@ -1216,7 +1217,7 @@ func findBlindedPaths(g Graph, target route.Vertex, // a node that doesn't have any other edges - in that final case, the // whole path should be ignored. paths, _, err := processNodeForBlindedPath( - g, target, supportsRouteBlinding, nil, restrictions, + ctx, g, target, supportsRouteBlinding, nil, restrictions, ) if err != nil { return nil, err @@ -1251,7 +1252,7 @@ func findBlindedPaths(g Graph, target route.Vertex, // processNodeForBlindedPath is a recursive function that traverses the graph // in a depth first manner searching for a set of blinded paths to the given // node. -func processNodeForBlindedPath(g Graph, node route.Vertex, +func processNodeForBlindedPath(ctx context.Context, g Graph, node route.Vertex, supportsRouteBlinding func(vertex route.Vertex) (bool, error), alreadyVisited map[route.Vertex]bool, restrictions *blindedPathRestrictions) ([][]blindedHop, bool, error) { @@ -1298,7 +1299,7 @@ func processNodeForBlindedPath(g Graph, node route.Vertex, // Now, iterate over the node's channels in search for paths to this // node that can be used for blinded paths - err = g.ForEachNodeChannel(node, + err = g.ForEachNodeChannel(ctx, node, func(channel *graphdb.DirectedChannel) error { // Keep track of how many incoming channels this node // has. We only use a node as an introduction node if it @@ -1308,8 +1309,8 @@ func processNodeForBlindedPath(g Graph, node route.Vertex, // Process each channel peer to gather any paths that // lead to the peer. nextPaths, hasMoreChans, err := processNodeForBlindedPath( //nolint:lll - g, channel.OtherNode, supportsRouteBlinding, - visited, restrictions, + ctx, g, channel.OtherNode, + supportsRouteBlinding, visited, restrictions, ) if err != nil { return err diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 81708d393..92b36e63e 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "crypto/sha256" "encoding/hex" "encoding/json" @@ -2221,7 +2222,7 @@ func TestPathFindSpecExample(t *testing.T) { ) require.NoError(t, err, "invalid route request") - route, _, err := ctx.router.FindRoute(req) + route, _, err := ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find route") // Now we'll examine the route returned for correctness. @@ -2248,7 +2249,7 @@ func TestPathFindSpecExample(t *testing.T) { ) require.NoError(t, err, "invalid route request") - route, _, err = ctx.router.FindRoute(req) + route, _, err = ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find routes") // The route should be two hops. @@ -3112,6 +3113,8 @@ func dbFindPath(graph *graphdb.ChannelGraph, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, error) { + ctx := context.Background() + sourceNode, err := graph.SourceNode() if err != nil { return nil, err @@ -3119,7 +3122,7 @@ func dbFindPath(graph *graphdb.ChannelGraph, graphSessFactory := newMockGraphSessionFactoryFromChanDB(graph) - graphSess, closeGraphSess, err := graphSessFactory.NewGraphSession() + graphSess, closeGraphSess, err := graphSessFactory.NewGraphSession(ctx) if err != nil { return nil, err } @@ -3131,7 +3134,7 @@ func dbFindPath(graph *graphdb.ChannelGraph, }() route, _, err := findPath( - &graphParams{ + ctx, &graphParams{ additionalEdges: additionalEdges, bandwidthHints: bandwidthHints, graph: graphSess, @@ -3154,8 +3157,8 @@ func dbFindBlindedPaths(graph *graphdb.ChannelGraph, } return findBlindedPaths( - newMockGraphSessionChanDB(graph), sourceNode.PubKeyBytes, - restrictions, + context.Background(), newMockGraphSessionChanDB(graph), + sourceNode.PubKeyBytes, restrictions, ) } diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 267ce3965..aa5de1044 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -259,7 +259,7 @@ lifecycle: } // Now request a route to be used to create our HTLC attempt. - rt, err := p.requestRoute(ps) + rt, err := p.requestRoute(ctx, ps) if err != nil { return exitWithErr(err) } @@ -366,14 +366,14 @@ func (p *paymentLifecycle) checkContext(ctx context.Context) error { // requestRoute is responsible for finding a route to be used to create an HTLC // attempt. -func (p *paymentLifecycle) requestRoute( +func (p *paymentLifecycle) requestRoute(ctx context.Context, ps *channeldb.MPPaymentState) (*route.Route, error) { remainingFees := p.calcFeeBudget(ps.FeesPaid) // Query our payment session to construct a route. rt, err := p.paySession.RequestRoute( - ps.RemainingAmt, remainingFees, + ctx, ps.RemainingAmt, remainingFees, uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), p.firstHopCustomRecords, ) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 315c1bad5..4b7721344 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -380,10 +380,10 @@ func TestRequestRouteSucceed(t *testing.T) { // Mock the paySession's `RequestRoute` method to return no error. paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, + mock.Anything, mock.Anything, ).Return(dummyRoute, nil) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(context.Background(), ps) require.NoError(t, err, "expect no error") require.Equal(t, dummyRoute, result, "returned route not matched") @@ -417,10 +417,10 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) { // Mock the paySession's `RequestRoute` method to return an error. paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, + mock.Anything, mock.Anything, ).Return(nil, errDummy) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(context.Background(), ps) // Expect an error is returned since it's critical. require.ErrorIs(t, err, errDummy, "error not matched") @@ -452,7 +452,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { // type. m.paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, + mock.Anything, mock.Anything, ).Return(nil, errNoTlvPayload) // The payment should be failed with reason no route. @@ -460,7 +460,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { p.identifier, channeldb.FailureReasonNoRoute, ).Return(nil).Once() - result, err := p.requestRoute(ps) + result, err := p.requestRoute(context.Background(), ps) // Expect no error is returned since it's not critical. require.NoError(t, err, "expected no error") @@ -500,10 +500,10 @@ func TestRequestRouteFailPaymentError(t *testing.T) { // Mock the paySession's `RequestRoute` method to return an error. paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, + mock.Anything, mock.Anything, ).Return(nil, errNoTlvPayload) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(context.Background(), ps) // Expect an error is returned. require.ErrorIs(t, err, errDummy, "error not matched") @@ -876,7 +876,8 @@ func TestResumePaymentFailOnRequestRouteErr(t *testing.T) { // 4. mock requestRoute to return an error. m.paySession.On("RequestRoute", - paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(nil, errDummy).Once() @@ -922,7 +923,8 @@ func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) { // 4. mock requestRoute to return an route. m.paySession.On("RequestRoute", - paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() @@ -982,7 +984,8 @@ func TestResumePaymentFailOnSendAttemptErr(t *testing.T) { // 4. mock requestRoute to return an route. m.paySession.On("RequestRoute", - paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() @@ -1074,7 +1077,8 @@ func TestResumePaymentSuccess(t *testing.T) { // 1.4. mock requestRoute to return an route. m.paySession.On("RequestRoute", - paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() @@ -1175,7 +1179,8 @@ func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { // 1.4. mock requestRoute to return an route. m.paySession.On("RequestRoute", - paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() @@ -1237,7 +1242,8 @@ func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { // 2.4. mock requestRoute to return an route. m.paySession.On("RequestRoute", - paymentAmt/2, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt/2, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() diff --git a/routing/payment_session.go b/routing/payment_session.go index 0afdf822f..9c962a88d 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "github.com/btcsuite/btcd/btcec/v2" @@ -137,7 +138,7 @@ type PaymentSession interface { // // A noRouteError is returned if a non-critical error is encountered // during path finding. - RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, + RequestRoute(ctx context.Context, maxAmt, feeLimit lnwire.MilliSatoshi, activeShards, height uint32, firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) @@ -169,7 +170,7 @@ type paymentSession struct { additionalEdges map[route.Vertex][]AdditionalEdge - getBandwidthHints func(Graph) (bandwidthHints, error) + getBandwidthHints func(context.Context, Graph) (bandwidthHints, error) payment *LightningPayment @@ -197,7 +198,7 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. func newPaymentSession(p *LightningPayment, selfNode route.Vertex, - getBandwidthHints func(Graph) (bandwidthHints, error), + getBandwidthHints func(context.Context, Graph) (bandwidthHints, error), graphSessFactory GraphSessionFactory, missionControl MissionControlQuerier, pathFindingConfig PathFindingConfig) (*paymentSession, error) { @@ -244,8 +245,8 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex, // // NOTE: This function is safe for concurrent access. // NOTE: Part of the PaymentSession interface. -func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32, +func (p *paymentSession) RequestRoute(ctx context.Context, maxAmt, + feeLimit lnwire.MilliSatoshi, activeShards, height uint32, firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) { if p.empty { @@ -297,7 +298,9 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, for { // Get a routing graph session. - graph, closeGraph, err := p.graphSessFactory.NewGraphSession() + graph, closeGraph, err := p.graphSessFactory.NewGraphSession( + ctx, + ) if err != nil { return nil, err } @@ -308,7 +311,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // don't have enough bandwidth to carry the payment. New // bandwidth hints are queried for every new path finding // attempt, because concurrent payments may change balances. - bandwidthHints, err := p.getBandwidthHints(graph) + bandwidthHints, err := p.getBandwidthHints(ctx, graph) if err != nil { // Close routing graph session. if graphErr := closeGraph(); graphErr != nil { @@ -323,7 +326,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // Find a route for the current amount. path, _, err := p.pathFinder( - &graphParams{ + ctx, &graphParams{ additionalEdges: p.additionalEdges, bandwidthHints: bandwidthHints, graph: graph, diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index d5f1a6af4..69f4d0bea 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -1,6 +1,8 @@ package routing import ( + "context" + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph/db/models" @@ -54,9 +56,11 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment, firstHopBlob fn.Option[tlv.Blob], trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { - getBandwidthHints := func(graph Graph) (bandwidthHints, error) { + getBandwidthHints := func(ctx context.Context, + graph Graph) (bandwidthHints, error) { + return newBandwidthManager( - graph, m.SourceNode.PubKeyBytes, m.GetLink, + ctx, graph, m.SourceNode.PubKeyBytes, m.GetLink, firstHopBlob, trafficShaper, ) } diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 278e09044..51cfadb0d 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "testing" "time" @@ -115,7 +116,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { // Create the paymentsession. session, err := newPaymentSession( payment, route.Vertex{}, - func(Graph) (bandwidthHints, error) { + func(context.Context, Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, newMockGraphSessionFactory(&sessionGraph{}), @@ -193,7 +194,7 @@ func TestRequestRoute(t *testing.T) { session, err := newPaymentSession( payment, route.Vertex{}, - func(Graph) (bandwidthHints, error) { + func(context.Context, Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, newMockGraphSessionFactory(&sessionGraph{}), @@ -205,7 +206,8 @@ func TestRequestRoute(t *testing.T) { } // Override pathfinder with a mock. - session.pathFinder = func(_ *graphParams, r *RestrictParams, + session.pathFinder = func(_ context.Context, _ *graphParams, + r *RestrictParams, _ *PathFindingConfig, _, _, _ route.Vertex, _ lnwire.MilliSatoshi, _ float64, _ int32) ([]*unifiedEdge, float64, error) { @@ -233,8 +235,8 @@ func TestRequestRoute(t *testing.T) { } route, err := session.RequestRoute( - payment.Amount, payment.FeeLimit, 0, height, - lnwire.CustomRecords{ + context.Background(), payment.Amount, payment.FeeLimit, 0, + height, lnwire.CustomRecords{ lnwire.MinCustomRecordsTlvType + 123: []byte{1, 2, 3}, }, ) diff --git a/routing/router.go b/routing/router.go index b92aa1502..8eedfaf16 100644 --- a/routing/router.go +++ b/routing/router.go @@ -515,8 +515,8 @@ func getTargetNode(target *route.Vertex, // FindRoute attempts to query the ChannelRouter for the optimum path to a // particular target destination to which it is able to send `amt` after // factoring in channel capacities and cumulative fees along the route. -func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, - error) { +func (r *ChannelRouter) FindRoute(ctx context.Context, req *RouteRequest) ( + *route.Route, float64, error) { log.Debugf("Searching for path to %v, sending %v", req.Target, req.Amount) @@ -524,7 +524,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. bandwidthHints, err := newBandwidthManager( - r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, + ctx, r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, fn.None[tlv.Blob](), r.cfg.TrafficShaper, ) if err != nil { @@ -549,7 +549,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, } path, probability, err := findPath( - &graphParams{ + ctx, &graphParams{ additionalEdges: req.RouteHints, bandwidthHints: bandwidthHints, graph: r.cfg.RoutingGraph, @@ -616,14 +616,15 @@ type BlindedPathRestrictions struct { // FindBlindedPaths finds a selection of paths to the destination node that can // be used in blinded payment paths. -func (r *ChannelRouter) FindBlindedPaths(destination route.Vertex, - amt lnwire.MilliSatoshi, probabilitySrc probabilitySource, +func (r *ChannelRouter) FindBlindedPaths(ctx context.Context, + destination route.Vertex, amt lnwire.MilliSatoshi, + probabilitySrc probabilitySource, restrictions *BlindedPathRestrictions) ([]*route.Route, error) { // First, find a set of candidate paths given the destination node and // path length restrictions. paths, err := findBlindedPaths( - r.cfg.RoutingGraph, destination, &blindedPathRestrictions{ + ctx, r.cfg.RoutingGraph, destination, &blindedPathRestrictions{ minNumHops: restrictions.MinDistanceFromIntroNode, maxNumHops: restrictions.NumHops, nodeOmissionSet: restrictions.NodeOmissionSet, @@ -1366,7 +1367,8 @@ func (e ErrNoChannel) Error() string { // BuildRoute returns a fully specified route based on a list of pubkeys. If // amount is nil, the minimum routable amount is used. To force a specific // outgoing channel, use the outgoingChan parameter. -func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], +func (r *ChannelRouter) BuildRoute(ctx context.Context, + amt fn.Option[lnwire.MilliSatoshi], hops []route.Vertex, outgoingChan *uint64, finalCltvDelta int32, payAddr fn.Option[[32]byte], firstHopBlob fn.Option[[]byte]) ( *route.Route, error) { @@ -1383,8 +1385,8 @@ func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. bandwidthHints, err := newBandwidthManager( - r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, firstHopBlob, - r.cfg.TrafficShaper, + ctx, r.cfg.RoutingGraph, r.cfg.SelfNode, + r.cfg.GetLink, firstHopBlob, r.cfg.TrafficShaper, ) if err != nil { return nil, err @@ -1395,7 +1397,7 @@ func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], // We check that each node in the route has a connection to others that // can forward in principle. unifiers, err := getEdgeUnifiers( - r.cfg.SelfNode, hops, outgoingChans, r.cfg.RoutingGraph, + ctx, r.cfg.SelfNode, hops, outgoingChans, r.cfg.RoutingGraph, ) if err != nil { return nil, err @@ -1652,8 +1654,8 @@ func (r *ChannelRouter) failStaleAttempt(a channeldb.HTLCAttempt, } // getEdgeUnifiers returns a list of edge unifiers for the given route. -func getEdgeUnifiers(source route.Vertex, hops []route.Vertex, - outgoingChans map[uint64]struct{}, +func getEdgeUnifiers(ctx context.Context, source route.Vertex, + hops []route.Vertex, outgoingChans map[uint64]struct{}, graph Graph) ([]*edgeUnifier, error) { // Allocate a list that will contain the edge unifiers for this route. @@ -1678,7 +1680,7 @@ func getEdgeUnifiers(source route.Vertex, hops []route.Vertex, source, toNode, !isExitHop, outgoingChans, ) - err := u.addGraphPolicies(graph) + err := u.addGraphPolicies(ctx, graph) if err != nil { return nil, err } diff --git a/routing/router_test.go b/routing/router_test.go index db72bf266..4de54af1b 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "fmt" "image/color" "math" @@ -271,7 +272,7 @@ func TestFindRoutesWithFeeLimit(t *testing.T) { ) require.NoError(t, err, "invalid route request") - route, _, err := ctx.router.FindRoute(req) + route, _, err := ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find any routes") require.Falsef(t, @@ -1530,6 +1531,8 @@ func TestSendToRouteMaxHops(t *testing.T) { // TestBuildRoute tests whether correct routes are built. func TestBuildRoute(t *testing.T) { + ctx := context.Background() + // Setup a three node network. chanCapSat := btcutil.Amount(100000) paymentAddrFeatures := lnwire.NewFeatureVector( @@ -1638,7 +1641,9 @@ func TestBuildRoute(t *testing.T) { const startingBlockHeight = 101 - ctx := createTestCtxFromGraphInstance(t, startingBlockHeight, testGraph) + tctx := createTestCtxFromGraphInstance( + t, startingBlockHeight, testGraph, + ) checkHops := func(rt *route.Route, expected []uint64, payAddr [32]byte) { @@ -1664,27 +1669,28 @@ func TestBuildRoute(t *testing.T) { // Test that we can't build a route when no hops are given. hops = []route.Vertex{} - _, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.None[[32]byte](), fn.None[[]byte](), + _, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.None[[32]byte](), + fn.None[[]byte](), ) require.Error(t, err) // Create hop list for an unknown destination. - hops := []route.Vertex{ctx.aliases["b"], ctx.aliases["y"]} - _, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + hops := []route.Vertex{tctx.aliases["b"], tctx.aliases["y"]} + _, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) noChanErr := ErrNoChannel{} require.ErrorAs(t, err, &noChanErr) require.Equal(t, 1, noChanErr.position) // Create hop list from the route node pubkeys. - hops = []route.Vertex{ctx.aliases["b"], ctx.aliases["c"]} + hops = []route.Vertex{tctx.aliases["b"], tctx.aliases["c"]} amt := lnwire.NewMSatFromSatoshis(100) // Build the route for the given amount. - rt, err := ctx.router.BuildRoute( - fn.Some(amt), hops, nil, 40, fn.Some(payAddr), + rt, err := tctx.router.BuildRoute( + ctx, fn.Some(amt), hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) @@ -1696,8 +1702,8 @@ func TestBuildRoute(t *testing.T) { require.Equal(t, lnwire.MilliSatoshi(106000), rt.TotalAmount) // Build the route for the minimum amount. - rt, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + rt, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) @@ -1713,9 +1719,10 @@ func TestBuildRoute(t *testing.T) { // Test a route that contains incompatible channel htlc constraints. // There is no amount that can pass through both channel 5 and 4. - hops = []route.Vertex{ctx.aliases["e"], ctx.aliases["c"]} - _, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.None[[32]byte](), fn.None[[]byte](), + hops = []route.Vertex{tctx.aliases["e"], tctx.aliases["c"]} + _, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.None[[32]byte](), + fn.None[[]byte](), ) require.Error(t, err) noChanErr = ErrNoChannel{} @@ -1733,9 +1740,9 @@ func TestBuildRoute(t *testing.T) { // could me more applicable, which is why we don't get back the highest // amount that could be delivered to the receiver of 21819 msat, using // policy of channel 3. - hops = []route.Vertex{ctx.aliases["b"], ctx.aliases["z"]} - rt, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + hops = []route.Vertex{tctx.aliases["b"], tctx.aliases["z"]} + rt, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) checkHops(rt, []uint64{1, 8}, payAddr) @@ -1746,10 +1753,10 @@ func TestBuildRoute(t *testing.T) { // inbound fees. We expect a similar amount as for the above case of // b->c, but reduced by the inbound discount on the channel a->d. // We get 106000 - 1000 (base in) - 0.001 * 106000 (rate in) = 104894. - hops = []route.Vertex{ctx.aliases["d"], ctx.aliases["f"]} + hops = []route.Vertex{tctx.aliases["d"], tctx.aliases["f"]} amt = lnwire.NewMSatFromSatoshis(100) - rt, err = ctx.router.BuildRoute( - fn.Some(amt), hops, nil, 40, fn.Some(payAddr), + rt, err = tctx.router.BuildRoute( + ctx, fn.Some(amt), hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) @@ -1764,9 +1771,9 @@ func TestBuildRoute(t *testing.T) { // due to rounding. This would not be compatible with the sender amount // of 20179 msat, which results in underpayment of 1 msat in fee. There // is a third pass through newRoute in which this gets corrected to end - hops = []route.Vertex{ctx.aliases["d"], ctx.aliases["f"]} - rt, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + hops = []route.Vertex{tctx.aliases["d"], tctx.aliases["f"]} + rt, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) checkHops(rt, []uint64{9, 10}, payAddr) @@ -2894,7 +2901,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { paymentAmt, 0, noRestrictions, nil, nil, nil, MinCLTVDelta, ) require.NoError(t, err, "invalid route request") - _, _, err = ctx.router.FindRoute(req) + _, _, err = ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find any routes") // Now check that we can update the node info for the partial node @@ -2933,7 +2940,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { ) require.NoError(t, err, "invalid route request") - _, _, err = ctx.router.FindRoute(req) + _, _, err = ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find any routes") copy1, err := ctx.graph.FetchLightningNode(pub1) @@ -3072,6 +3079,8 @@ func createChannelEdge(bitcoinKey1, bitcoinKey2 []byte, func TestFindBlindedPathsWithMC(t *testing.T) { t.Parallel() + ctx := context.Background() + rbFeatureBits := []lnwire.FeatureBit{ lnwire.RouteBlindingOptional, } @@ -3128,15 +3137,15 @@ func TestFindBlindedPathsWithMC(t *testing.T) { ) require.NoError(t, err) - ctx := createTestCtxFromGraphInstance(t, 101, testGraph) + tctx := createTestCtxFromGraphInstance(t, 101, testGraph) var ( - alice = ctx.aliases["alice"] - bob = ctx.aliases["bob"] - charlie = ctx.aliases["charlie"] - dave = ctx.aliases["dave"] - eve = ctx.aliases["eve"] - frank = ctx.aliases["frank"] + alice = tctx.aliases["alice"] + bob = tctx.aliases["bob"] + charlie = tctx.aliases["charlie"] + dave = tctx.aliases["dave"] + eve = tctx.aliases["eve"] + frank = tctx.aliases["frank"] ) // Create a mission control store which initially sets the success @@ -3163,8 +3172,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // All the probabilities are set to 1. So if we restrict the path length // to 2 and allow a max of 3 routes, then we expect three paths here. - routes, err := ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err := tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3181,12 +3190,12 @@ func TestFindBlindedPathsWithMC(t *testing.T) { var actualPaths []string for _, path := range paths { label := getAliasFromPubKey( - path.SourcePubKey, ctx.aliases, + path.SourcePubKey, tctx.aliases, ) + "," for _, hop := range path.Hops { label += getAliasFromPubKey( - hop.PubKeyBytes, ctx.aliases, + hop.PubKeyBytes, tctx.aliases, ) + "," } @@ -3208,8 +3217,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // 3) A -> F -> D missionControl[bob][dave] = 0.5 missionControl[frank][dave] = 0.25 - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3225,8 +3234,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Just to show that the above result was not a fluke, let's change // the C->D link to be the weak one. missionControl[charlie][dave] = 0.125 - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3241,8 +3250,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Change the MaxNumPaths to 1 to assert that only the best route is // returned. - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 1, @@ -3255,8 +3264,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Test the edge case where Dave, the recipient, is also the // introduction node. - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 0, NumHops: 0, MaxNumPaths: 1, @@ -3270,8 +3279,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Finally, we make one of the routes have a probability less than the // minimum. This means we expect that route not to be chosen. missionControl[charlie][dave] = DefaultMinRouteProbability - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3285,8 +3294,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Test that if the user explicitly indicates that we should ignore // the Frank node during path selection, then this is done. - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, diff --git a/routing/unified_edges.go b/routing/unified_edges.go index c2e008e47..5ad777072 100644 --- a/routing/unified_edges.go +++ b/routing/unified_edges.go @@ -1,6 +1,7 @@ package routing import ( + "context" "math" "github.com/btcsuite/btcd/btcutil" @@ -94,7 +95,7 @@ func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex, // addGraphPolicies adds all policies that are known for the toNode in the // graph. -func (u *nodeEdgeUnifier) addGraphPolicies(g Graph) error { +func (u *nodeEdgeUnifier) addGraphPolicies(ctx context.Context, g Graph) error { cb := func(channel *graphdb.DirectedChannel) error { // If there is no edge policy for this candidate node, skip. // Note that we are searching backwards so this node would have @@ -120,7 +121,7 @@ func (u *nodeEdgeUnifier) addGraphPolicies(g Graph) error { } // Iterate over all channels of the to node. - return g.ForEachNodeChannel(u.toNode, cb) + return g.ForEachNodeChannel(ctx, u.toNode, cb) } // unifiedEdge is the individual channel data that is kept inside an edgeUnifier diff --git a/rpcserver.go b/rpcserver.go index ab7d78d17..d448e23c8 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -707,11 +707,12 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, } return info.Capacity, nil }, - FetchAmountPairCapacity: func(nodeFrom, nodeTo route.Vertex, + FetchAmountPairCapacity: func(ctx context.Context, nodeFrom, + nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { return routing.FetchAmountPairCapacity( - graphsession.NewRoutingGraph(graphSource), + ctx, graphsession.NewRoutingGraph(graphSource), selfNode.PubKeyBytes, nodeFrom, nodeTo, amount, ) }, @@ -6125,11 +6126,11 @@ func (r *rpcServer) AddInvoice(ctx context.Context, }, GetAlias: r.server.aliasMgr.GetPeerAlias, BestHeight: r.server.cc.BestBlockTracker.BestHeight, - QueryBlindedRoutes: func(amt lnwire.MilliSatoshi) ( - []*route.Route, error) { + QueryBlindedRoutes: func(ctx context.Context, + amt lnwire.MilliSatoshi) ([]*route.Route, error) { return r.server.chanRouter.FindBlindedPaths( - r.selfNode, amt, + ctx, r.selfNode, amt, r.server.defaultMC.GetProbability, blindingRestrictions, )