diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index d5e35ee4f..a47ffa813 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -622,19 +622,8 @@ func TestFindLowestFeePath(t *testing.T) { }), } - testGraphInstance, err := createTestGraphFromChannels( - testChannels, "roasbeef", - ) - if err != nil { - t.Fatalf("unable to create graph: %v", err) - } - defer testGraphInstance.cleanUp() - - sourceNode, err := testGraphInstance.graph.SourceNode() - if err != nil { - t.Fatalf("unable to fetch source node: %v", err) - } - sourceVertex := route.Vertex(sourceNode.PubKeyBytes) + ctx := newPathFindingTestContext(t, testChannels, "roasbeef") + defer ctx.cleanup() const ( startingHeight = 100 @@ -642,20 +631,13 @@ func TestFindLowestFeePath(t *testing.T) { ) paymentAmt := lnwire.NewMSatFromSatoshis(100) - target := testGraphInstance.aliasMap["target"] - path, err := findPath( - &graphParams{ - graph: testGraphInstance.graph, - }, - noRestrictions, - testPathFindingConfig, - sourceNode.PubKeyBytes, target, paymentAmt, - ) + target := ctx.keyFromAlias("target") + path, err := ctx.findPath(target, paymentAmt) if err != nil { t.Fatalf("unable to find path: %v", err) } route, err := newRoute( - paymentAmt, sourceVertex, path, startingHeight, + paymentAmt, ctx.source, path, startingHeight, finalHopCLTV, nil, ) if err != nil { @@ -663,11 +645,10 @@ func TestFindLowestFeePath(t *testing.T) { } // Assert that the lowest fee route is returned. - if route.Hops[1].PubKeyBytes != testGraphInstance.aliasMap["b"] { + if route.Hops[1].PubKeyBytes != ctx.keyFromAlias("b") { t.Fatalf("expected route to pass through b, "+ "but got a route through %v", - getAliasFromPubKey(route.Hops[1].PubKeyBytes, - testGraphInstance.aliasMap)) + ctx.aliasFromKey(route.Hops[1].PubKeyBytes)) } } @@ -1394,53 +1375,34 @@ func TestRouteFailMaxHTLC(t *testing.T) { }), } - graph, err := createTestGraphFromChannels(testChannels, "roasbeef") - if err != nil { - t.Fatalf("unable to create graph: %v", err) - } - defer graph.cleanUp() - - sourceNode, err := graph.graph.SourceNode() - if err != nil { - t.Fatalf("unable to fetch source node: %v", err) - } + ctx := newPathFindingTestContext(t, testChannels, "roasbeef") + defer ctx.cleanup() // First, attempt to send a payment greater than the max HTLC we are // about to set, which should succeed. - target := graph.aliasMap["target"] + target := ctx.keyFromAlias("target") payAmt := lnwire.MilliSatoshi(100001) - _, err = findPath( - &graphParams{ - graph: graph.graph, - }, - noRestrictions, testPathFindingConfig, - sourceNode.PubKeyBytes, target, payAmt, - ) + _, err := ctx.findPath(target, payAmt) if err != nil { t.Fatalf("graph should've been able to support payment: %v", err) } // Next, update the middle edge policy to only allow payments up to 100k // msat. - _, midEdge, _, err := graph.graph.FetchChannelEdgesByID(firstToSecondID) + graph := ctx.testGraphInstance.graph + _, midEdge, _, err := graph.FetchChannelEdgesByID(firstToSecondID) if err != nil { t.Fatalf("unable to fetch channel edges by ID: %v", err) } midEdge.MessageFlags = 1 midEdge.MaxHTLC = payAmt - 1 - if err := graph.graph.UpdateEdgePolicy(midEdge); err != nil { + if err := graph.UpdateEdgePolicy(midEdge); err != nil { t.Fatalf("unable to update edge: %v", err) } // We'll now attempt to route through that edge with a payment above // 100k msat, which should fail. - _, err = findPath( - &graphParams{ - graph: graph.graph, - }, - noRestrictions, testPathFindingConfig, - sourceNode.PubKeyBytes, target, payAmt, - ) + _, err = ctx.findPath(target, payAmt) if !IsError(err, ErrNoPathFound) { t.Fatalf("graph shouldn't be able to support payment: %v", err) } @@ -1885,19 +1847,8 @@ func TestRestrictOutgoingChannel(t *testing.T) { }), } - testGraphInstance, err := createTestGraphFromChannels( - testChannels, "roasbeef", - ) - if err != nil { - t.Fatalf("unable to create graph: %v", err) - } - defer testGraphInstance.cleanUp() - - sourceNode, err := testGraphInstance.graph.SourceNode() - if err != nil { - t.Fatalf("unable to fetch source node: %v", err) - } - sourceVertex := route.Vertex(sourceNode.PubKeyBytes) + ctx := newPathFindingTestContext(t, testChannels, "roasbeef") + defer ctx.cleanup() const ( startingHeight = 100 @@ -1905,29 +1856,18 @@ func TestRestrictOutgoingChannel(t *testing.T) { ) paymentAmt := lnwire.NewMSatFromSatoshis(100) - target := testGraphInstance.aliasMap["target"] + target := ctx.keyFromAlias("target") outgoingChannelID := uint64(2) // Find the best path given the restriction to only use channel 2 as the // outgoing channel. - path, err := findPath( - &graphParams{ - graph: testGraphInstance.graph, - }, - &RestrictParams{ - FeeLimit: noFeeLimit, - OutgoingChannelID: &outgoingChannelID, - ProbabilitySource: noProbabilitySource, - CltvLimit: math.MaxUint32, - }, - testPathFindingConfig, - sourceVertex, target, paymentAmt, - ) + ctx.restrictParams.OutgoingChannelID = &outgoingChannelID + path, err := ctx.findPath(target, paymentAmt) if err != nil { t.Fatalf("unable to find path: %v", err) } route, err := newRoute( - paymentAmt, sourceVertex, path, startingHeight, + paymentAmt, ctx.source, path, startingHeight, finalHopCLTV, nil, ) if err != nil { @@ -1984,35 +1924,14 @@ func testCltvLimit(t *testing.T, limit uint32, expectedChannel uint64) { }), } - testGraphInstance, err := createTestGraphFromChannels( - testChannels, "roasbeef", - ) - if err != nil { - t.Fatalf("unable to create graph: %v", err) - } - defer testGraphInstance.cleanUp() - - sourceNode, err := testGraphInstance.graph.SourceNode() - if err != nil { - t.Fatalf("unable to fetch source node: %v", err) - } - sourceVertex := route.Vertex(sourceNode.PubKeyBytes) + ctx := newPathFindingTestContext(t, testChannels, "roasbeef") + defer ctx.cleanup() paymentAmt := lnwire.NewMSatFromSatoshis(100) - target := testGraphInstance.aliasMap["target"] + target := ctx.keyFromAlias("target") - path, err := findPath( - &graphParams{ - graph: testGraphInstance.graph, - }, - &RestrictParams{ - FeeLimit: noFeeLimit, - CltvLimit: limit, - ProbabilitySource: noProbabilitySource, - }, - testPathFindingConfig, - sourceVertex, target, paymentAmt, - ) + ctx.restrictParams.CltvLimit = limit + path, err := ctx.findPath(target, paymentAmt) if expectedChannel == 0 { // Finish test if we expect no route. if IsError(err, ErrNoPathFound) { @@ -2029,7 +1948,7 @@ func testCltvLimit(t *testing.T, limit uint32, expectedChannel uint64) { finalHopCLTV = 1 ) route, err := newRoute( - paymentAmt, sourceVertex, path, startingHeight, finalHopCLTV, + paymentAmt, ctx.source, path, startingHeight, finalHopCLTV, nil, ) if err != nil { @@ -2146,27 +2065,16 @@ func testProbabilityRouting(t *testing.T, p10, p11, p20, minProbability float64, }, 20), } - testGraphInstance, err := createTestGraphFromChannels( - testChannels, "roasbeef", - ) - if err != nil { - t.Fatalf("unable to create graph: %v", err) - } - defer testGraphInstance.cleanUp() + ctx := newPathFindingTestContext(t, testChannels, "roasbeef") + defer ctx.cleanup() - alias := testGraphInstance.aliasMap - - sourceNode, err := testGraphInstance.graph.SourceNode() - if err != nil { - t.Fatalf("unable to fetch source node: %v", err) - } - sourceVertex := route.Vertex(sourceNode.PubKeyBytes) + alias := ctx.testGraphInstance.aliasMap paymentAmt := lnwire.NewMSatFromSatoshis(100) - target := testGraphInstance.aliasMap["target"] + target := ctx.testGraphInstance.aliasMap["target"] // Configure a probability source with the test parameters. - probabilitySource := func(fromNode, toNode route.Vertex, + ctx.restrictParams.ProbabilitySource = func(fromNode, toNode route.Vertex, amt lnwire.MilliSatoshi) float64 { if amt == 0 { @@ -2185,21 +2093,12 @@ func testProbabilityRouting(t *testing.T, p10, p11, p20, minProbability float64, } } - path, err := findPath( - &graphParams{ - graph: testGraphInstance.graph, - }, - &RestrictParams{ - FeeLimit: noFeeLimit, - ProbabilitySource: probabilitySource, - CltvLimit: math.MaxUint32, - }, - &PathFindingConfig{ - PaymentAttemptPenalty: lnwire.NewMSatFromSatoshis(10), - MinProbability: minProbability, - }, - sourceVertex, target, paymentAmt, - ) + ctx.pathFindingConfig = PathFindingConfig{ + PaymentAttemptPenalty: lnwire.NewMSatFromSatoshis(10), + MinProbability: minProbability, + } + + path, err := ctx.findPath(target, paymentAmt) if expectedChan == 0 { if err == nil || !IsError(err, ErrNoPathFound) { t.Fatalf("expected no path found, but got %v", err) @@ -2217,3 +2116,71 @@ func testProbabilityRouting(t *testing.T, p10, p11, p20, minProbability float64, path[1].ChannelID) } } + +type pathFindingTestContext struct { + t *testing.T + graphParams graphParams + restrictParams RestrictParams + pathFindingConfig PathFindingConfig + testGraphInstance *testGraphInstance + source route.Vertex +} + +func newPathFindingTestContext(t *testing.T, testChannels []*testChannel, + source string) *pathFindingTestContext { + + testGraphInstance, err := createTestGraphFromChannels( + testChannels, source, + ) + if err != nil { + t.Fatalf("unable to create graph: %v", err) + } + + sourceNode, err := testGraphInstance.graph.SourceNode() + if err != nil { + t.Fatalf("unable to fetch source node: %v", err) + } + + ctx := &pathFindingTestContext{ + t: t, + testGraphInstance: testGraphInstance, + source: route.Vertex(sourceNode.PubKeyBytes), + } + + ctx.pathFindingConfig = *testPathFindingConfig + + ctx.graphParams.graph = testGraphInstance.graph + + ctx.restrictParams.FeeLimit = noFeeLimit + ctx.restrictParams.ProbabilitySource = noProbabilitySource + ctx.restrictParams.CltvLimit = math.MaxUint32 + + return ctx +} + +func (c *pathFindingTestContext) keyFromAlias(alias string) route.Vertex { + return c.testGraphInstance.aliasMap[alias] +} + +func (c *pathFindingTestContext) aliasFromKey(pubKey route.Vertex) string { + for alias, key := range c.testGraphInstance.aliasMap { + if key == pubKey { + return alias + } + } + return "" +} + +func (c *pathFindingTestContext) cleanup() { + c.testGraphInstance.cleanUp() +} + +func (c *pathFindingTestContext) findPath(target route.Vertex, + amt lnwire.MilliSatoshi) ([]*channeldb.ChannelEdgePolicy, + error) { + + return findPath( + &c.graphParams, &c.restrictParams, &c.pathFindingConfig, + c.source, target, amt, + ) +}