diff --git a/routing/notifications_test.go b/routing/notifications_test.go index 0cb475367..f49d68be2 100644 --- a/routing/notifications_test.go +++ b/routing/notifications_test.go @@ -107,72 +107,6 @@ func randChannelEdge(ctx *testCtx, chanValue btcutil.Amount, return fundingTx, chanUtxo, chanID } -type testCtx struct { - router *ChannelRouter - - graph *channeldb.ChannelGraph - - chain *mockChain - - notifier *mockNotifier -} - -func createTestCtx(startingHeight uint32) (*testCtx, func(), error) { - // First we'll set up a test graph for usage within the test. - graph, cleanup, err := makeTestGraph() - if err != nil { - return nil, nil, fmt.Errorf("unable to create test graph: %v", err) - } - - sourceNode, err := createGraphNode() - if err != nil { - return nil, nil, fmt.Errorf("unable to create source node: %v", err) - } - if err := graph.SetSourceNode(sourceNode); err != nil { - return nil, nil, fmt.Errorf("unable to set source node: %v", err) - } - - // Next we'll initialize an instance of the channel router with mock - // versions of the chain and channel notifier. As we don't need to test - // any p2p functionality, the peer send and switch send messages won't - // be populated. - chain := newMockChain(startingHeight) - notifier := newMockNotifier() - router, err := New(Config{ - Graph: graph, - Chain: chain, - Notifier: notifier, - Broadcast: func(_ *btcec.PublicKey, msg ...lnwire.Message) error { - return nil - }, - SendMessages: func(_ *btcec.PublicKey, msg ...lnwire.Message) error { - return nil - }, - SendToSwitch: func(_ *btcec.PublicKey, - _ *lnwire.UpdateAddHTLC) ([32]byte, error) { - return [32]byte{}, nil - }, - }) - if err != nil { - return nil, nil, fmt.Errorf("unable to create router %v", err) - } - if err := router.Start(); err != nil { - return nil, nil, fmt.Errorf("unable to start router: %v", err) - } - - cleanUp := func() { - router.Stop() - cleanup() - } - - return &testCtx{ - router: router, - graph: graph, - chain: chain, - notifier: notifier, - }, cleanUp, nil -} - type mockChain struct { blocks map[chainhash.Hash]*wire.MsgBlock blockIndex map[uint32]chainhash.Hash diff --git a/routing/router.go b/routing/router.go index ac6fec891..d819040b0 100644 --- a/routing/router.go +++ b/routing/router.go @@ -3,7 +3,9 @@ package routing import ( "bytes" "encoding/hex" + "errors" "fmt" + "sort" "sync" "sync/atomic" "time" @@ -1050,13 +1052,16 @@ func (r *ChannelRouter) ProcessRoutingMessage(msg lnwire.Message, src *btcec.Pub } } -// FindRoute attempts to query the ChannelRouter for the "best" path to a +// FindRoutes attempts to query the ChannelRouter for the all available paths to a // particular target destination which is able to send `amt` after factoring in -// channel capacities and cumulative fees along the route. Once we have a set -// of candidate routes, we calculate the required fee and time lock values -// running backwards along the route. The route that will be ranked the highest -// is the one with the lowest cumulative fee along the route. -func (r *ChannelRouter) FindRoute(target *btcec.PublicKey, amt btcutil.Amount) (*Route, error) { +// channel capacities and cumulative fees along each route route. To find all +// elgible paths, we use a modified version of Yen's algorithm which itself +// uses a modidifed version of Dijkstra's algorithm within its inner loop. +// Once we have a set of candidate routes, we calculate the required fee and +// time lock values running backwards along the route. The route that will be +// ranked the highest is the one with the lowest cumulative fee along the +// route. +func (r *ChannelRouter) FindRoutes(target *btcec.PublicKey, amt btcutil.Amount) ([]*Route, error) { dest := target.SerializeCompressed() log.Debugf("Searching for path to %x, sending %v", dest, amt) @@ -1070,32 +1075,52 @@ func (r *ChannelRouter) FindRoute(target *btcec.PublicKey, amt btcutil.Amount) ( return nil, ErrTargetNotInNetwork } - // First we'll find a single shortest path from the source (our - // selfNode) to the target destination that's capable of carrying amt - // satoshis along the path before fees are calculated. - // - // TODO(roasbeef): add k-shortest paths - routeHops, err := findRoute(r.cfg.Graph, r.selfNode, target, amt) - if err != nil { - log.Errorf("Unable to find path: %v", err) - return nil, err - } - - // If we were able to find a path we construct a new route which - // calculate the relevant total fees and proper time lock values for - // each hop. - route, err := newRoute(amt, routeHops) + // Now that we know the destination is reachable within the graph, + // we'll execute our KSP algorithm to find the k-shortest paths from + // our source to the destination. + shortestPaths, err := findPaths(r.cfg.Graph, r.selfNode, target, amt) if err != nil { return nil, err } - log.Debugf("Obtained path sending %v to %x: %v", amt, dest, - newLogClosure(func() string { - return spew.Sdump(route) + // Now that we have a set of paths, we'll need to turn them into + // *routes* by computing the required time-lock and fee information for + // each path. During this process, some paths may be discarded if they + // aren't able to support the total satoshis flow once fees have been + // factored in. + validRoutes := make([]*Route, 0, len(shortestPaths)) + for _, path := range shortestPaths { + // Attempt to make the path into a route. We snip off the first + // hop inthe path as it contains a "self-hop" that is inserted + // by our KSP algorithm. + route, err := newRoute(amt, path[1:]) + if err != nil { + continue + } + + // If the path as enough total flow to support the computed + // route, then we'll add it to our set of valid routes. + validRoutes = append(validRoutes, route) + } + + // Finally, we'll sort the set of validate routes to optimize for + // loweest total fees, using the reuired time-lcok within the route as + // a tie-breaker. + sort.Slice(validRoutes, func(i, j int) bool { + if validRoutes[i].TotalFees == validRoutes[j].TotalFees { + return validRoutes[i].TotalTimeLock < validRoutes[j].TotalTimeLock + } + + return validRoutes[i].TotalFees < validRoutes[j].TotalFees + }) + + log.Debugf("Obtained %v paths sending %v to %x: %v", len(validRoutes), + amt, dest, newLogClosure(func() string { + return spew.Sdump(validRoutes) }), ) - return route, nil + return validRoutes, nil } // generateSphinxPacket generates then encodes a sphinx packet which encodes diff --git a/routing/router_test.go b/routing/router_test.go new file mode 100644 index 000000000..5ee592fa5 --- /dev/null +++ b/routing/router_test.go @@ -0,0 +1,148 @@ +package routing + +import ( + "bytes" + "errors" + "fmt" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/roasbeef/btcd/btcec" + "github.com/roasbeef/btcutil" +) + +type testCtx struct { + router *ChannelRouter + + graph *channeldb.ChannelGraph + + aliases map[string]*btcec.PublicKey + + chain *mockChain + + notifier *mockNotifier +} + +func createTestCtx(startingHeight uint32, testGraph ...string) (*testCtx, func(), error) { + var ( + graph *channeldb.ChannelGraph + sourceNode *channeldb.LightningNode + cleanup func() + err error + ) + + aliasMap := make(map[string]*btcec.PublicKey) + + // If the testGraph isn't set, then we'll create an empty graph to + // start out with. Our usage of a variadic parameter allows caller to + // omit the testGraph argument all together if they wish to start with + // a blank graph. + if testGraph == nil { + // First we'll set up a test graph for usage within the test. + graph, cleanup, err = makeTestGraph() + if err != nil { + return nil, nil, fmt.Errorf("unable to create test graph: %v", err) + } + + sourceNode, err = createGraphNode() + if err != nil { + return nil, nil, fmt.Errorf("unable to create source node: %v", err) + } + if err = graph.SetSourceNode(sourceNode); err != nil { + return nil, nil, fmt.Errorf("unable to set source node: %v", err) + } + } else { + // Otherwise, we'll attempt to locate and parse out the file + // that encodes the graph that our tests should be run against. + graph, cleanup, aliasMap, err = parseTestGraph(testGraph[0]) + if err != nil { + return nil, nil, fmt.Errorf("unable to create test graph: %v", err) + } + + sourceNode, err = graph.SourceNode() + if err != nil { + return nil, nil, fmt.Errorf("unable to fetch source node: %v", err) + } + } + + // Next we'll initialize an instance of the channel router with mock + // versions of the chain and channel notifier. As we don't need to test + // any p2p functionality, the peer send and switch send messages won't + // be populated. + chain := newMockChain(startingHeight) + notifier := newMockNotifier() + router, err := New(Config{ + Graph: graph, + Chain: chain, + Notifier: notifier, + Broadcast: func(_ *btcec.PublicKey, msg ...lnwire.Message) error { + return nil + }, + SendMessages: func(_ *btcec.PublicKey, msg ...lnwire.Message) error { + return nil + }, + SendToSwitch: func(_ *btcec.PublicKey, + _ *lnwire.UpdateAddHTLC) ([32]byte, error) { + return [32]byte{}, nil + }, + }) + if err != nil { + return nil, nil, fmt.Errorf("unable to create router %v", err) + } + if err := router.Start(); err != nil { + return nil, nil, fmt.Errorf("unable to start router: %v", err) + } + + cleanUp := func() { + router.Stop() + cleanup() + } + + return &testCtx{ + router: router, + graph: graph, + aliases: aliasMap, + chain: chain, + notifier: notifier, + }, cleanUp, nil +} + +// TestFindRoutesFeeSorting asserts that routes found by the FindRoutes method +// within the channel router are properly returned in a sorted order, with the +// lowest fee route coming first. +func TestFindRoutesFeeSorting(t *testing.T) { + const startingBlockHeight = 101 + ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) + defer cleanUp() + if err != nil { + t.Fatalf("unable to create router: %v", err) + } + + // In this test we'd like to ensure proper integration of the various + // functions that are involved in path finding, and also route + // selection. + + // Execute a query for all possible routes between roasbeef and luo ji. + const paymentAmt = btcutil.Amount(100) + target := ctx.aliases["luoji"] + routes, err := ctx.router.FindRoutes(target, paymentAmt) + if err != nil { + t.Fatalf("unable to find any routes: %v", err) + } + + // Exactly, two such paths should be found. + if len(routes) != 2 { + t.Fatalf("2 routes shouldn't been selected, instead %v were: ", + len(routes)) + } + + // The paths should properly be ranked according to their total fee + // rate. + if routes[0].TotalFees > routes[1].TotalFees { + t.Fatalf("routes not ranked by total fee: %v", + spew.Sdump(routes)) + } +} +