diff --git a/routing/blinding.go b/routing/blinding.go index 70e1a22cb..a343c15da 100644 --- a/routing/blinding.go +++ b/routing/blinding.go @@ -235,21 +235,33 @@ func (s *BlindedPaymentPathSet) FinalCLTVDelta() uint16 { // LargestLastHopPayloadPath returns the BlindedPayment in the set that has the // largest last-hop payload. This is to be used for onion size estimation in // path finding. -func (s *BlindedPaymentPathSet) LargestLastHopPayloadPath() *BlindedPayment { +func (s *BlindedPaymentPathSet) LargestLastHopPayloadPath() (*BlindedPayment, + error) { + var ( largestPath *BlindedPayment currentMax int ) + + if len(s.paths) == 0 { + return nil, fmt.Errorf("no blinded paths in the set") + } + + // We set the largest path to make sure we always return a path even + // if the cipher text is empty. + largestPath = s.paths[0] + for _, path := range s.paths { numHops := len(path.BlindedPath.BlindedHops) lastHop := path.BlindedPath.BlindedHops[numHops-1] if len(lastHop.CipherText) > currentMax { largestPath = path + currentMax = len(lastHop.CipherText) } } - return largestPath + return largestPath, nil } // ToRouteHints converts the blinded path payment set into a RouteHints map so diff --git a/routing/pathfind.go b/routing/pathfind.go index 5b5bbe9f2..3f1d0ba09 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -700,7 +700,10 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // The payload size of the final hop differ from intermediate hops // and depends on whether the destination is blinded or not. - lastHopPayloadSize := lastHopPayloadSize(r, finalHtlcExpiry, amt) + lastHopPayloadSize, err := lastHopPayloadSize(r, finalHtlcExpiry, amt) + if err != nil { + return nil, 0, err + } // We can't always assume that the end destination is publicly // advertised to the network so we'll manually include the target node. @@ -1433,11 +1436,15 @@ func getProbabilityBasedDist(weight int64, probability float64, // It depends on the tlv types which are present and also whether the hop is // part of a blinded route or not. func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, - amount lnwire.MilliSatoshi) uint64 { + amount lnwire.MilliSatoshi) (uint64, error) { if r.BlindedPaymentPathSet != nil { - paymentPath := r.BlindedPaymentPathSet. + paymentPath, err := r.BlindedPaymentPathSet. LargestLastHopPayloadPath() + if err != nil { + return 0, err + } + blindedPath := paymentPath.BlindedPath.BlindedHops blindedPoint := paymentPath.BlindedPath.BlindingPoint @@ -1452,7 +1459,7 @@ func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, } // The final hop does not have a short chanID set. - return finalHop.PayloadSize(0) + return finalHop.PayloadSize(0), nil } var mpp *record.MPP @@ -1478,7 +1485,7 @@ func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, } // The final hop does not have a short chanID set. - return finalHop.PayloadSize(0) + return finalHop.PayloadSize(0), nil } // overflowSafeAdd adds two MilliSatoshi values and returns the result. If an diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 2adeca134..4579abc6a 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3413,32 +3413,48 @@ func TestLastHopPayloadSize(t *testing.T) { customRecords = map[uint64][]byte{ record.CustomTypeStart: {1, 2, 3}, } - sizeEncryptedData = 100 - encrypedData = bytes.Repeat( - []byte{1}, sizeEncryptedData, + + encrypedDataSmall = bytes.Repeat( + []byte{1}, 5, ) - _, blindedPoint = btcec.PrivKeyFromBytes([]byte{5}) - paymentAddr = &[32]byte{1} - ampOptions = &Options{} - amtToForward = lnwire.MilliSatoshi(10000) - finalHopExpiry int32 = 144 + encrypedDataLarge = bytes.Repeat( + []byte{1}, 100, + ) + _, blindedPoint = btcec.PrivKeyFromBytes([]byte{5}) + paymentAddr = &[32]byte{1} + ampOptions = &Options{} + amtToForward = lnwire.MilliSatoshi(10000) + emptyEncryptedData = []byte{} + finalHopExpiry int32 = 144 oneHopPath = &sphinx.BlindedPath{ BlindedHops: []*sphinx.BlindedHopInfo{ { - CipherText: encrypedData, + CipherText: emptyEncryptedData, }, }, BlindingPoint: blindedPoint, } - twoHopPath = &sphinx.BlindedPath{ + twoHopPathSmallHopSize = &sphinx.BlindedPath{ BlindedHops: []*sphinx.BlindedHopInfo{ { - CipherText: encrypedData, + CipherText: encrypedDataLarge, }, { - CipherText: encrypedData, + CipherText: encrypedDataLarge, + }, + }, + BlindingPoint: blindedPoint, + } + + twoHopPathLargeHopSize = &sphinx.BlindedPath{ + BlindedHops: []*sphinx.BlindedHopInfo{ + { + CipherText: encrypedDataSmall, + }, + { + CipherText: encrypedDataSmall, }, }, BlindingPoint: blindedPoint, @@ -3451,15 +3467,19 @@ func TestLastHopPayloadSize(t *testing.T) { require.NoError(t, err) twoHopBlindedPayment, err := NewBlindedPaymentPathSet( - []*BlindedPayment{{BlindedPath: twoHopPath}}, + []*BlindedPayment{ + {BlindedPath: twoHopPathLargeHopSize}, + {BlindedPath: twoHopPathSmallHopSize}, + }, ) require.NoError(t, err) testCases := []struct { - name string - restrictions *RestrictParams - finalHopExpiry int32 - amount lnwire.MilliSatoshi + name string + restrictions *RestrictParams + finalHopExpiry int32 + amount lnwire.MilliSatoshi + expectedEncryptedData []byte }{ { name: "Non blinded final hop", @@ -3477,16 +3497,18 @@ func TestLastHopPayloadSize(t *testing.T) { restrictions: &RestrictParams{ BlindedPaymentPathSet: oneHopBlindedPayment, }, - amount: amtToForward, - finalHopExpiry: finalHopExpiry, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + expectedEncryptedData: emptyEncryptedData, }, { name: "Blinded final hop of a two hop payment", restrictions: &RestrictParams{ BlindedPaymentPathSet: twoHopBlindedPayment, }, - amount: amtToForward, - finalHopExpiry: finalHopExpiry, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + expectedEncryptedData: encrypedDataLarge, }, } @@ -3510,16 +3532,23 @@ func TestLastHopPayloadSize(t *testing.T) { var finalHop route.Hop if tc.restrictions.BlindedPaymentPathSet != nil { - path := tc.restrictions.BlindedPaymentPathSet. - LargestLastHopPayloadPath() + bPSet := tc.restrictions.BlindedPaymentPathSet + path, err := bPSet.LargestLastHopPayloadPath() + require.NotNil(t, path) + + require.NoError(t, err) + blindedPath := path.BlindedPath.BlindedHops blindedPoint := path.BlindedPath.BlindingPoint + lastHop := blindedPath[len(blindedPath)-1] + require.Equal(t, lastHop.CipherText, + tc.expectedEncryptedData) //nolint:ll finalHop = route.Hop{ AmtToForward: tc.amount, OutgoingTimeLock: uint32(tc.finalHopExpiry), - EncryptedData: blindedPath[len(blindedPath)-1].CipherText, + EncryptedData: lastHop.CipherText, } if len(blindedPath) == 1 { finalHop.BlindingPoint = blindedPoint @@ -3539,11 +3568,11 @@ func TestLastHopPayloadSize(t *testing.T) { payLoad, err := createHopPayload(finalHop, 0, true) require.NoErrorf(t, err, "failed to create hop payload") - expectedPayloadSize := lastHopPayloadSize( + expectedPayloadSize, err := lastHopPayloadSize( tc.restrictions, tc.finalHopExpiry, tc.amount, ) - + require.NoError(t, err) require.Equal( t, expectedPayloadSize, uint64(payLoad.NumBytes()),