routing: improve lasthoppaylaod size calculation

Fixes a bug and makes the function more robust. Before
we would always return the encrypted data size of last hop
of the last path. Now we return the greatest last hop payload
not always the one of the last path.
This commit is contained in:
ziggie
2024-12-09 22:15:21 +01:00
parent e47024b790
commit 3cec72ae9c
3 changed files with 81 additions and 33 deletions

View File

@@ -3413,32 +3413,48 @@ func TestLastHopPayloadSize(t *testing.T) {
customRecords = map[uint64][]byte{
record.CustomTypeStart: {1, 2, 3},
}
sizeEncryptedData = 100
encrypedData = bytes.Repeat(
[]byte{1}, sizeEncryptedData,
encrypedDataSmall = bytes.Repeat(
[]byte{1}, 5,
)
_, blindedPoint = btcec.PrivKeyFromBytes([]byte{5})
paymentAddr = &[32]byte{1}
ampOptions = &AMPOptions{}
amtToForward = lnwire.MilliSatoshi(10000)
finalHopExpiry int32 = 144
encrypedDataLarge = bytes.Repeat(
[]byte{1}, 100,
)
_, blindedPoint = btcec.PrivKeyFromBytes([]byte{5})
paymentAddr = &[32]byte{1}
ampOptions = &AMPOptions{}
amtToForward = lnwire.MilliSatoshi(10000)
emptyEncryptedData = []byte{}
finalHopExpiry int32 = 144
oneHopPath = &sphinx.BlindedPath{
BlindedHops: []*sphinx.BlindedHopInfo{
{
CipherText: encrypedData,
CipherText: emptyEncryptedData,
},
},
BlindingPoint: blindedPoint,
}
twoHopPath = &sphinx.BlindedPath{
twoHopPathSmallHopSize = &sphinx.BlindedPath{
BlindedHops: []*sphinx.BlindedHopInfo{
{
CipherText: encrypedData,
CipherText: encrypedDataLarge,
},
{
CipherText: encrypedData,
CipherText: encrypedDataLarge,
},
},
BlindingPoint: blindedPoint,
}
twoHopPathLargeHopSize = &sphinx.BlindedPath{
BlindedHops: []*sphinx.BlindedHopInfo{
{
CipherText: encrypedDataSmall,
},
{
CipherText: encrypedDataSmall,
},
},
BlindingPoint: blindedPoint,
@@ -3451,15 +3467,19 @@ func TestLastHopPayloadSize(t *testing.T) {
require.NoError(t, err)
twoHopBlindedPayment, err := NewBlindedPaymentPathSet(
[]*BlindedPayment{{BlindedPath: twoHopPath}},
[]*BlindedPayment{
{BlindedPath: twoHopPathLargeHopSize},
{BlindedPath: twoHopPathSmallHopSize},
},
)
require.NoError(t, err)
testCases := []struct {
name string
restrictions *RestrictParams
finalHopExpiry int32
amount lnwire.MilliSatoshi
name string
restrictions *RestrictParams
finalHopExpiry int32
amount lnwire.MilliSatoshi
expectedEncryptedData []byte
}{
{
name: "Non blinded final hop",
@@ -3477,16 +3497,18 @@ func TestLastHopPayloadSize(t *testing.T) {
restrictions: &RestrictParams{
BlindedPaymentPathSet: oneHopBlindedPayment,
},
amount: amtToForward,
finalHopExpiry: finalHopExpiry,
amount: amtToForward,
finalHopExpiry: finalHopExpiry,
expectedEncryptedData: emptyEncryptedData,
},
{
name: "Blinded final hop of a two hop payment",
restrictions: &RestrictParams{
BlindedPaymentPathSet: twoHopBlindedPayment,
},
amount: amtToForward,
finalHopExpiry: finalHopExpiry,
amount: amtToForward,
finalHopExpiry: finalHopExpiry,
expectedEncryptedData: encrypedDataLarge,
},
}
@@ -3510,16 +3532,23 @@ func TestLastHopPayloadSize(t *testing.T) {
var finalHop route.Hop
if tc.restrictions.BlindedPaymentPathSet != nil {
path := tc.restrictions.BlindedPaymentPathSet.
LargestLastHopPayloadPath()
bPSet := tc.restrictions.BlindedPaymentPathSet
path, err := bPSet.LargestLastHopPayloadPath()
require.NotNil(t, path)
require.NoError(t, err)
blindedPath := path.BlindedPath.BlindedHops
blindedPoint := path.BlindedPath.BlindingPoint
lastHop := blindedPath[len(blindedPath)-1]
require.Equal(t, lastHop.CipherText,
tc.expectedEncryptedData)
//nolint:ll
finalHop = route.Hop{
AmtToForward: tc.amount,
OutgoingTimeLock: uint32(tc.finalHopExpiry),
EncryptedData: blindedPath[len(blindedPath)-1].CipherText,
EncryptedData: lastHop.CipherText,
}
if len(blindedPath) == 1 {
finalHop.BlindingPoint = blindedPoint
@@ -3539,11 +3568,11 @@ func TestLastHopPayloadSize(t *testing.T) {
payLoad, err := createHopPayload(finalHop, 0, true)
require.NoErrorf(t, err, "failed to create hop payload")
expectedPayloadSize := lastHopPayloadSize(
expectedPayloadSize, err := lastHopPayloadSize(
tc.restrictions, tc.finalHopExpiry,
tc.amount,
)
require.NoError(t, err)
require.Equal(
t, expectedPayloadSize,
uint64(payLoad.NumBytes()),