diff --git a/routing/unified_edges_test.go b/routing/unified_edges_test.go index f1c86abe9..6157c9dac 100644 --- a/routing/unified_edges_test.go +++ b/routing/unified_edges_test.go @@ -3,22 +3,23 @@ package routing import ( "testing" + "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" ) // TestNodeEdgeUnifier tests the composition of unified edges for nodes that // have multiple channels between them. func TestNodeEdgeUnifier(t *testing.T) { + t.Parallel() + source := route.Vertex{1} toNode := route.Vertex{2} fromNode := route.Vertex{3} - bandwidthHints := &mockBandwidthHints{} - u := newNodeEdgeUnifier(source, toNode, nil) - // Add two channels between the pair of nodes. p1 := channeldb.CachedEdgePolicy{ FeeProportionalMillionths: 100000, @@ -36,57 +37,82 @@ func TestNodeEdgeUnifier(t *testing.T) { MaxHTLC: 400, MinHTLC: 100, } - u.addPolicy(fromNode, &p1, 7) - u.addPolicy(fromNode, &p2, 7) + c1 := btcutil.Amount(7) + c2 := btcutil.Amount(8) - checkPolicy := func(edge *unifiedEdge, feeBase lnwire.MilliSatoshi, - feeRate lnwire.MilliSatoshi, timeLockDelta uint16) { + unifierFilled := newNodeEdgeUnifier(source, toNode, nil) + unifierFilled.addPolicy(fromNode, &p1, c1) + unifierFilled.addPolicy(fromNode, &p2, c2) - t.Helper() - - policy := edge.policy - - if policy.FeeBaseMSat != feeBase { - t.Fatalf("expected fee base %v, got %v", - feeBase, policy.FeeBaseMSat) - } - - if policy.TimeLockDelta != timeLockDelta { - t.Fatalf("expected fee base %v, got %v", - timeLockDelta, policy.TimeLockDelta) - } - - if policy.FeeProportionalMillionths != feeRate { - t.Fatalf("expected fee rate %v, got %v", - feeRate, policy.FeeProportionalMillionths) - } + tests := []struct { + name string + unifier *nodeEdgeUnifier + amount lnwire.MilliSatoshi + expectedFeeBase lnwire.MilliSatoshi + expectedFeeRate lnwire.MilliSatoshi + expectedTimeLock uint16 + expectNoPolicy bool + }{ + { + name: "amount below min htlc", + unifier: unifierFilled, + amount: 50, + expectNoPolicy: true, + }, + { + name: "amount above max htlc", + unifier: unifierFilled, + amount: 550, + expectNoPolicy: true, + }, + // For 200 msat, p1 yields the highest fee. Use that policy to + // forward, because it will also match p2 in case p1 does not + // have enough balance. + { + name: "use p1 with highest fee", + unifier: unifierFilled, + amount: 200, + expectedFeeBase: p1.FeeBaseMSat, + expectedFeeRate: p1.FeeProportionalMillionths, + expectedTimeLock: p1.TimeLockDelta, + }, + // For 400 sat, p2 yields the highest fee. Use that policy to + // forward, because it will also match p1 in case p2 does not + // have enough balance. In order to match p1, it needs to have + // p1's time lock delta. + { + name: "use p2 with highest fee", + unifier: unifierFilled, + amount: 400, + expectedFeeBase: p2.FeeBaseMSat, + expectedFeeRate: p2.FeeProportionalMillionths, + expectedTimeLock: p1.TimeLockDelta, + }, } - edge := u.edgeUnifiers[fromNode].getEdge(50, bandwidthHints) - if edge != nil { - t.Fatal("expected no policy for amt below min htlc") + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + edge := test.unifier.edgeUnifiers[fromNode].getEdge( + test.amount, bandwidthHints, + ) + + if test.expectNoPolicy { + require.Nil(t, edge, "expected no policy") + + return + } + + policy := edge.policy + require.Equal(t, test.expectedFeeBase, + policy.FeeBaseMSat, "base fee") + require.Equal(t, test.expectedFeeRate, + policy.FeeProportionalMillionths, "fee rate") + require.Equal(t, test.expectedTimeLock, + policy.TimeLockDelta, "timelock") + }) } - - edge = u.edgeUnifiers[fromNode].getEdge(550, bandwidthHints) - if edge != nil { - t.Fatal("expected no policy for amt above max htlc") - } - - // For 200 sat, p1 yields the highest fee. Use that policy to forward, - // because it will also match p2 in case p1 does not have enough - // balance. - edge = u.edgeUnifiers[fromNode].getEdge(200, bandwidthHints) - checkPolicy( - edge, p1.FeeBaseMSat, p1.FeeProportionalMillionths, - p1.TimeLockDelta, - ) - - // For 400 sat, p2 yields the highest fee. Use that policy to forward, - // because it will also match p1 in case p2 does not have enough - // balance. In order to match p1, it needs to have p1's time lock delta. - edge = u.edgeUnifiers[fromNode].getEdge(400, bandwidthHints) - checkPolicy( - edge, p2.FeeBaseMSat, p2.FeeProportionalMillionths, - p1.TimeLockDelta, - ) }