record: add Padding field to BlindedRouteData

When we start creating blinded paths to ourselves, we will want to be
able to pad the data for each hop so that the `encrypted_recipient_data`
for each hop is the same. We add a `PadBy` method that allows a caller
to add a certain number of bytes to the padding field. Note that adding
n bytes won't always mean that the encoded payload will increase by size
n since there will be overhead for the type and lenght fields for the new
TLV field. This will also be the case when the number of bytes added
results in a BigSize bucket jump for TLV length field. The
responsibility of ensuring that the final payloads are the same size is
left to the caller who may need to call PadBy iteratively to achieve the
goal. I decided to leave this to the caller since doing this at the
actual TLV level will be quite intrusive & I think it is uneccessary to
touch that code for this unique use case.
This commit is contained in:
Elle Mouton
2024-05-02 14:35:34 +02:00
parent 15f3cce27d
commit 9ada4a9068
2 changed files with 111 additions and 6 deletions

View File

@@ -14,6 +14,11 @@ import (
// route encrypted data blob that is created by the recipient to provide // route encrypted data blob that is created by the recipient to provide
// forwarding information. // forwarding information.
type BlindedRouteData struct { type BlindedRouteData struct {
// Padding is an optional set of bytes that a recipient can use to pad
// the data so that the encrypted recipient data blobs are all the same
// length.
Padding tlv.OptionalRecordT[tlv.TlvType1, []byte]
// ShortChannelID is the channel ID of the next hop. // ShortChannelID is the channel ID of the next hop.
ShortChannelID tlv.OptionalRecordT[tlv.TlvType2, lnwire.ShortChannelID] ShortChannelID tlv.OptionalRecordT[tlv.TlvType2, lnwire.ShortChannelID]
@@ -98,6 +103,7 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) {
var ( var (
d BlindedRouteData d BlindedRouteData
padding = d.Padding.Zero()
scid = d.ShortChannelID.Zero() scid = d.ShortChannelID.Zero()
pathID = d.PathID.Zero() pathID = d.PathID.Zero()
blindingOverride = d.NextBlindingOverride.Zero() blindingOverride = d.NextBlindingOverride.Zero()
@@ -112,13 +118,18 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) {
} }
typeMap, err := tlvRecords.ExtractRecords( typeMap, err := tlvRecords.ExtractRecords(
&scid, &pathID, &blindingOverride, &relayInfo, &constraints, &padding, &scid, &pathID, &blindingOverride, &relayInfo,
&features, &constraints, &features,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
val, ok := typeMap[d.Padding.TlvType()]
if ok && val == nil {
d.Padding = tlv.SomeRecordT(padding)
}
if val, ok := typeMap[d.ShortChannelID.TlvType()]; ok && val == nil { if val, ok := typeMap[d.ShortChannelID.TlvType()]; ok && val == nil {
d.ShortChannelID = tlv.SomeRecordT(scid) d.ShortChannelID = tlv.SomeRecordT(scid)
} }
@@ -127,7 +138,7 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) {
d.PathID = tlv.SomeRecordT(pathID) d.PathID = tlv.SomeRecordT(pathID)
} }
val, ok := typeMap[d.NextBlindingOverride.TlvType()] val, ok = typeMap[d.NextBlindingOverride.TlvType()]
if ok && val == nil { if ok && val == nil {
d.NextBlindingOverride = tlv.SomeRecordT(blindingOverride) d.NextBlindingOverride = tlv.SomeRecordT(blindingOverride)
} }
@@ -154,6 +165,10 @@ func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) {
recordProducers = make([]tlv.RecordProducer, 0, 5) recordProducers = make([]tlv.RecordProducer, 0, 5)
) )
data.Padding.WhenSome(func(p tlv.RecordT[tlv.TlvType1, []byte]) {
recordProducers = append(recordProducers, &p)
})
data.ShortChannelID.WhenSome(func(scid tlv.RecordT[tlv.TlvType2, data.ShortChannelID.WhenSome(func(scid tlv.RecordT[tlv.TlvType2,
lnwire.ShortChannelID]) { lnwire.ShortChannelID]) {
@@ -195,6 +210,19 @@ func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) {
return e[:], nil return e[:], nil
} }
// PadBy adds "n" padding bytes to the BlindedRouteData using the Padding field.
// Callers should be aware that the total payload size will change by more than
// "n" since the "n" bytes will be prefixed by BigSize type and length fields.
// Callers may need to call PadBy iteratively until each encrypted data packet
// is the same size and so each call will overwrite the Padding record.
// Note that calling PadBy with an n value of 0 will still result in a zero
// length TLV entry being added.
func (b *BlindedRouteData) PadBy(n int) {
b.Padding = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType1](make([]byte, n)),
)
}
// PaymentRelayInfo describes the relay policy for a blinded path. // PaymentRelayInfo describes the relay policy for a blinded path.
type PaymentRelayInfo struct { type PaymentRelayInfo struct {
// CltvExpiryDelta is the expiry delta for the payment. // CltvExpiryDelta is the expiry delta for the payment.
@@ -208,8 +236,8 @@ type PaymentRelayInfo struct {
BaseFee uint32 BaseFee uint32
} }
// newPaymentRelayRecord creates a tlv.Record that encodes the payment relay // Record creates a tlv.Record that encodes the payment relay (type 10) type for
// (type 10) type for an encrypted blob payload. // an encrypted blob payload.
func (i *PaymentRelayInfo) Record() tlv.Record { func (i *PaymentRelayInfo) Record() tlv.Record {
return tlv.MakeDynamicRecord( return tlv.MakeDynamicRecord(
10, &i, func() uint64 { 10, &i, func() uint64 {

View File

@@ -171,6 +171,74 @@ func TestBlindedDataFinalHopEncoding(t *testing.T) {
} }
} }
// TestBlindedRouteDataPadding tests the PadBy method of BlindedRouteData.
func TestBlindedRouteDataPadding(t *testing.T) {
newBlindedRouteData := func() *BlindedRouteData {
channelID := lnwire.NewShortChanIDFromInt(1)
info := PaymentRelayInfo{
FeeRate: 2,
CltvExpiryDelta: 3,
BaseFee: 30,
}
constraints := &PaymentConstraints{
MaxCltvExpiry: 4,
HtlcMinimumMsat: 100,
}
return NewNonFinalBlindedRouteData(
channelID, pubkey(t), info, constraints, nil,
)
}
tests := []struct {
name string
paddingSize int
expectedSizeIncrease uint64
}{
{
// Calling PadBy with an n value of 0 in the case where
// there is not yet a padding field will result in a
// zero length TLV entry being added. This will add 2
// bytes for the type and length fields.
name: "no extra padding",
expectedSizeIncrease: 2,
},
{
name: "small padding (length " +
"field of 1 byte)",
paddingSize: 200,
expectedSizeIncrease: 202,
},
{
name: "medium padding (length field " +
"of 3 bytes)",
paddingSize: 256,
expectedSizeIncrease: 260,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
data := newBlindedRouteData()
prePaddingEncoding, err := EncodeBlindedRouteData(data)
require.NoError(t, err)
data.PadBy(test.paddingSize)
postPaddingEncoding, err := EncodeBlindedRouteData(data)
require.NoError(t, err)
require.EqualValues(
t, test.expectedSizeIncrease,
len(postPaddingEncoding)-
len(prePaddingEncoding),
)
})
}
}
// TestBlindedRouteVectors tests encoding/decoding of the test vectors for // TestBlindedRouteVectors tests encoding/decoding of the test vectors for
// blinded route data provided in the specification. // blinded route data provided in the specification.
// //
@@ -184,6 +252,7 @@ func TestBlindingSpecTestVectors(t *testing.T) {
tests := []struct { tests := []struct {
encoded string encoded string
expectedPaymentData *BlindedRouteData expectedPaymentData *BlindedRouteData
expectedPadding int
}{ }{
{ {
encoded: "011a0000000000000000000000000000000000000000000000000000020800000000000006c10a0800240000009627100c06000b69e505dc0e00fd023103123456", encoded: "011a0000000000000000000000000000000000000000000000000000020800000000000006c10a0800240000009627100c06000b69e505dc0e00fd023103123456",
@@ -208,6 +277,7 @@ func TestBlindingSpecTestVectors(t *testing.T) {
lnwire.Features, lnwire.Features,
), ),
), ),
expectedPadding: 26,
}, },
{ {
encoded: "020800000000000004510821031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f0a0800300000006401f40c06000b69c105dc0e00", encoded: "020800000000000004510821031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f0a0800300000006401f40c06000b69c105dc0e00",
@@ -228,7 +298,8 @@ func TestBlindingSpecTestVectors(t *testing.T) {
lnwire.NewFeatureVector( lnwire.NewFeatureVector(
lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
lnwire.Features, lnwire.Features,
)), ),
),
}, },
} }
@@ -242,6 +313,12 @@ func TestBlindingSpecTestVectors(t *testing.T) {
decodedRoute, err := DecodeBlindedRouteData(buff) decodedRoute, err := DecodeBlindedRouteData(buff)
require.NoError(t, err) require.NoError(t, err)
if test.expectedPadding != 0 {
test.expectedPaymentData.PadBy(
test.expectedPadding,
)
}
require.Equal( require.Equal(
t, test.expectedPaymentData, decodedRoute, t, test.expectedPaymentData, decodedRoute,
) )