diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 8b267960d..8594e1384 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -55,14 +55,9 @@ type RouterBackend struct { FetchChannelEndpoints func(chanID uint64) (route.Vertex, route.Vertex, error) - // FindRoutes is a closure that abstracts away how we locate/query for + // FindRoute is a closure that abstracts away how we locate/query for // routes. - FindRoute func(source, target route.Vertex, - amt lnwire.MilliSatoshi, timePref float64, - restrictions *routing.RestrictParams, - destCustomRecords record.CustomSet, - routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, - finalExpiry uint16) (*route.Route, float64, error) + FindRoute func(*routing.RouteRequest) (*route.Route, float64, error) MissionControl MissionControl @@ -330,14 +325,19 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, // Query the channel router for a possible path to the destination that // can carry `in.Amt` satoshis _including_ the total fee required on // the route. - route, successProb, err := r.FindRoute( - sourcePubKey, targetPubKey, amt, in.TimePref, restrictions, + routeReq, err := routing.NewRouteRequest( + sourcePubKey, &targetPubKey, amt, in.TimePref, restrictions, customRecords, routeHintEdges, finalCLTVDelta, ) if err != nil { return nil, err } + route, successProb, err := r.FindRoute(routeReq) + if err != nil { + return nil, err + } + // For each valid route, we'll convert the result into the format // required by the RPC system. rpcRoute, err := r.MarshallRoute(route) diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index 4a5d73135..877a3cc17 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -7,10 +7,8 @@ import ( "testing" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" @@ -122,24 +120,23 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, } } - findRoute := func(source, target route.Vertex, - amt lnwire.MilliSatoshi, _ float64, - restrictions *routing.RestrictParams, _ record.CustomSet, - routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, - finalExpiry uint16) (*route.Route, float64, error) { + findRoute := func(req *routing.RouteRequest) (*route.Route, float64, + error) { - if int64(amt) != amtSat*1000 { + if int64(req.Amount) != amtSat*1000 { t.Fatal("unexpected amount") } - if source != sourceKey { + if req.Source != sourceKey { t.Fatal("unexpected source key") } + target := req.Target if !bytes.Equal(target[:], destNodeBytes) { t.Fatal("unexpected target key") } + restrictions := req.Restrictions if restrictions.FeeLimit != 250*1000 { t.Fatal("unexpected fee limit") } @@ -172,6 +169,7 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, t.Fatal("unexpected dest features") } + routeHints := req.RouteHints if _, ok := routeHints[hintNode]; !ok { t.Fatal("expected route hint") } @@ -187,7 +185,9 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, } hops := []*route.Hop{{}} - route, err := route.NewRouteFromHops(amt, 144, source, hops) + route, err := route.NewRouteFromHops( + req.Amount, 144, req.Source, hops, + ) return route, expectedProb, err } diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index 3dff8f50b..fa0913dab 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -398,8 +398,8 @@ func (s *Server) EstimateRouteFee(ctx context.Context, // restriction for the default CLTV limit, otherwise we can find a route // that exceeds it and is useless to us. mc := s.cfg.RouterBackend.MissionControl - route, _, err := s.cfg.Router.FindRoute( - s.cfg.RouterBackend.SelfNode, destNode, amtMsat, 0, + routeReq, err := routing.NewRouteRequest( + s.cfg.RouterBackend.SelfNode, &destNode, amtMsat, 0, &routing.RestrictParams{ FeeLimit: feeLimit, CltvLimit: s.cfg.RouterBackend.MaxTotalTimelock, @@ -410,6 +410,11 @@ func (s *Server) EstimateRouteFee(ctx context.Context, return nil, err } + route, _, err := s.cfg.Router.FindRoute(routeReq) + if err != nil { + return nil, err + } + return &RouteFeeResponse{ RoutingFeeMsat: int64(route.TotalFees()), TimeLockDelay: int64(route.TotalTimeLock), diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 4923fd7ab..5b66e5c73 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -2291,10 +2291,13 @@ func TestPathFindSpecExample(t *testing.T) { // Query for a route of 4,999,999 mSAT to carol. carol := ctx.aliases["C"] const amt lnwire.MilliSatoshi = 4999999 - route, _, err := ctx.router.FindRoute( - bobNode.PubKeyBytes, carol, amt, 0, noRestrictions, nil, nil, + req, err := NewRouteRequest( + bobNode.PubKeyBytes, &carol, amt, 0, noRestrictions, nil, nil, MinCLTVDelta, ) + require.NoError(t, err, "invalid route request") + + route, _, err := ctx.router.FindRoute(req) require.NoError(t, err, "unable to find route") // Now we'll examine the route returned for correctness. @@ -2341,10 +2344,13 @@ func TestPathFindSpecExample(t *testing.T) { } // We'll now request a route from A -> B -> C. - route, _, err = ctx.router.FindRoute( - source.PubKeyBytes, carol, amt, 0, noRestrictions, nil, nil, + req, err = NewRouteRequest( + source.PubKeyBytes, &carol, amt, 0, noRestrictions, nil, nil, MinCLTVDelta, ) + require.NoError(t, err, "invalid route request") + + route, _, err = ctx.router.FindRoute(req) require.NoError(t, err, "unable to find routes") // The route should be two hops. diff --git a/routing/router.go b/routing/router.go index 77dd9ceb0..f0858bf5a 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1809,16 +1809,72 @@ type routingMsg struct { err chan error } +// RouteRequest contains the parameters for a pathfinding request. +type RouteRequest struct { + // Source is the node that the path originates from. + Source route.Vertex + + // Target is the node that the path terminates at. + Target route.Vertex + + // Amount is the Amount in millisatoshis to be delivered to the target + // node. + Amount lnwire.MilliSatoshi + + // TimePreference expresses the caller's time preference for + // pathfinding. + TimePreference float64 + + // Restrictions provides a set of additional Restrictions that the + // route must adhere to. + Restrictions *RestrictParams + + // CustomRecords is a set of custom tlv records to include for the + // final hop. + CustomRecords record.CustomSet + + // RouteHints contains an additional set of edges to include in our + // view of the graph. + RouteHints RouteHints + + // FinalExpiry is the cltv delta for the final hop. + FinalExpiry uint16 +} + +// RouteHints is an alias type for a set of route hints, with the source node +// as the map's key and the details of the hint(s) in the edge policy. +type RouteHints map[route.Vertex][]*channeldb.CachedEdgePolicy + +// NewRouteRequest produces a new route request. +func NewRouteRequest(source route.Vertex, target *route.Vertex, + amount lnwire.MilliSatoshi, timePref float64, + restrictions *RestrictParams, customRecords record.CustomSet, + routeHints RouteHints, finalExpiry uint16) (*RouteRequest, error) { + + if target == nil { + return nil, errors.New("target node required") + } + + return &RouteRequest{ + Source: source, + Target: *target, + Amount: amount, + TimePreference: timePref, + Restrictions: restrictions, + CustomRecords: customRecords, + RouteHints: routeHints, + FinalExpiry: finalExpiry, + }, nil +} + // FindRoute attempts to query the ChannelRouter for the optimum path to a // particular target destination to which it is able to send `amt` after // factoring in channel capacities and cumulative fees along the route. -func (r *ChannelRouter) FindRoute(source, target route.Vertex, - amt lnwire.MilliSatoshi, timePref float64, restrictions *RestrictParams, - destCustomRecords record.CustomSet, - routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, - finalExpiry uint16) (*route.Route, float64, error) { +func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, + error) { - log.Debugf("Searching for path to %v, sending %v", target, amt) + log.Debugf("Searching for path to %v, sending %v", req.Target, + req.Amount) // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. @@ -1838,22 +1894,22 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // Now that we know the destination is reachable within the graph, we'll // execute our path finding algorithm. - finalHtlcExpiry := currentHeight + int32(finalExpiry) + finalHtlcExpiry := currentHeight + int32(req.FinalExpiry) // Validate time preference. + timePref := req.TimePreference if timePref < -1 || timePref > 1 { return nil, 0, errors.New("time preference out of range") } path, probability, err := findPath( &graphParams{ - additionalEdges: routeHints, + additionalEdges: req.RouteHints, bandwidthHints: bandwidthHints, graph: r.cachedGraph, }, - restrictions, - &r.cfg.PathFindingConfig, - source, target, amt, timePref, finalHtlcExpiry, + req.Restrictions, &r.cfg.PathFindingConfig, req.Source, + req.Target, req.Amount, req.TimePreference, finalHtlcExpiry, ) if err != nil { return nil, 0, err @@ -1861,12 +1917,12 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // Create the route with absolute time lock values. route, err := newRoute( - source, path, uint32(currentHeight), + req.Source, path, uint32(currentHeight), finalHopParams{ - amt: amt, - totalAmt: amt, - cltvDelta: finalExpiry, - records: destCustomRecords, + amt: req.Amount, + totalAmt: req.Amount, + cltvDelta: req.FinalExpiry, + records: req.CustomRecords, }, ) if err != nil { @@ -1874,7 +1930,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, } go log.Tracef("Obtained path to send %v to %x: %v", - amt, target, newLogClosure(func() string { + req.Amount, req.Target, newLogClosure(func() string { return spew.Sdump(route) }), ) diff --git a/routing/router_test.go b/routing/router_test.go index ce5925ef8..b6bcef350 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -271,11 +271,13 @@ func TestFindRoutesWithFeeLimit(t *testing.T) { CltvLimit: math.MaxUint32, } - route, _, err := ctx.router.FindRoute( - ctx.router.selfNode.PubKeyBytes, - target, paymentAmt, 0, restrictions, nil, nil, - MinCLTVDelta, + req, err := NewRouteRequest( + ctx.router.selfNode.PubKeyBytes, &target, paymentAmt, 0, + restrictions, nil, nil, MinCLTVDelta, ) + require.NoError(t, err, "invalid route request") + + route, _, err := ctx.router.FindRoute(req) require.NoError(t, err, "unable to find any routes") require.Falsef(t, @@ -1558,11 +1560,13 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { targetNode := priv2.PubKey() var targetPubKeyBytes route.Vertex copy(targetPubKeyBytes[:], targetNode.SerializeCompressed()) - _, _, err = ctx.router.FindRoute( - ctx.router.selfNode.PubKeyBytes, - targetPubKeyBytes, paymentAmt, 0, noRestrictions, nil, nil, - MinCLTVDelta, + + req, err := NewRouteRequest( + ctx.router.selfNode.PubKeyBytes, &targetPubKeyBytes, + paymentAmt, 0, noRestrictions, nil, nil, MinCLTVDelta, ) + require.NoError(t, err, "invalid route request") + _, _, err = ctx.router.FindRoute(req) require.NoError(t, err, "unable to find any routes") // Now check that we can update the node info for the partial node @@ -1599,11 +1603,13 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { // Should still be able to find the route, and the info should be // updated. - _, _, err = ctx.router.FindRoute( - ctx.router.selfNode.PubKeyBytes, - targetPubKeyBytes, paymentAmt, 0, noRestrictions, nil, nil, - MinCLTVDelta, + req, err = NewRouteRequest( + ctx.router.selfNode.PubKeyBytes, &targetPubKeyBytes, + paymentAmt, 0, noRestrictions, nil, nil, MinCLTVDelta, ) + require.NoError(t, err, "invalid route request") + + _, _, err = ctx.router.FindRoute(req) require.NoError(t, err, "unable to find any routes") copy1, err := ctx.graph.FetchLightningNode(pub1)