diff --git a/routing/router.go b/routing/router.go index 13ed21468..a38980bba 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2782,7 +2782,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, } pathEdges, receiverAmt, err := getPathEdges( - senderAmt, unifiers, bandwidthHints, hops, + sourceNode, senderAmt, unifiers, bandwidthHints, hops, ) if err != nil { return nil, err @@ -2879,8 +2879,8 @@ func getRouteUnifiers(source route.Vertex, hops []route.Vertex, // getPathEdges returns the edges that make up the path and the total amount, // including fees, to send the payment. -func getPathEdges(receiverAmt lnwire.MilliSatoshi, unifiers []*edgeUnifier, - bandwidthHints *bandwidthManager, +func getPathEdges(source route.Vertex, receiverAmt lnwire.MilliSatoshi, + unifiers []*edgeUnifier, bandwidthHints *bandwidthManager, hops []route.Vertex) ([]*channeldb.CachedEdgePolicy, lnwire.MilliSatoshi, error) { @@ -2892,8 +2892,13 @@ func getPathEdges(receiverAmt lnwire.MilliSatoshi, unifiers []*edgeUnifier, for i, unifier := range unifiers { edge := unifier.getEdge(receiverAmt, bandwidthHints) if edge == nil { + fromNode := source + if i > 0 { + fromNode = hops[i-1] + } + return nil, 0, ErrNoChannel{ - fromNode: hops[i-1], + fromNode: fromNode, position: i, } } diff --git a/routing/router_test.go b/routing/router_test.go index 767cf7347..27a43d474 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3187,6 +3187,55 @@ func TestBuildRoute(t *testing.T) { } } +// TestGetPathEdges tests that the getPathEdges function returns the expected +// edges and amount when given a set of unifiers and does not panic. +func TestGetPathEdges(t *testing.T) { + t.Parallel() + + const startingBlockHeight = 101 + ctx := createTestCtxFromFile(t, startingBlockHeight, basicGraphFilePath) + + testCases := []struct { + sourceNode route.Vertex + amt lnwire.MilliSatoshi + unifiers []*edgeUnifier + bandwidthHints *bandwidthManager + hops []route.Vertex + + expectedEdges []*channeldb.CachedEdgePolicy + expectedAmt lnwire.MilliSatoshi + expectedErr string + }{{ + sourceNode: ctx.aliases["roasbeef"], + unifiers: []*edgeUnifier{ + { + edges: []*unifiedEdge{}, + localChan: true, + }, + }, + expectedErr: fmt.Sprintf("no matching outgoing channel "+ + "available for node 0 (%v)", ctx.aliases["roasbeef"]), + }} + + for _, tc := range testCases { + pathEdges, amt, err := getPathEdges( + tc.sourceNode, tc.amt, tc.unifiers, tc.bandwidthHints, + tc.hops, + ) + + if tc.expectedErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.expectedErr) + + continue + } + + require.NoError(t, err) + require.Equal(t, pathEdges, tc.expectedEdges) + require.Equal(t, amt, tc.expectedAmt) + } +} + // edgeCreationModifier is an enum-like type used to modify steps that are // skipped when creating a channel in the test context. type edgeCreationModifier uint8