diff --git a/record/blinded_data.go b/record/blinded_data.go index 16ff6bd5a..7b7081b3b 100644 --- a/record/blinded_data.go +++ b/record/blinded_data.go @@ -14,6 +14,11 @@ import ( // route encrypted data blob that is created by the recipient to provide // forwarding information. type BlindedRouteData struct { + // Padding is an optional set of bytes that a recipient can use to pad + // the data so that the encrypted recipient data blobs are all the same + // length. + Padding tlv.OptionalRecordT[tlv.TlvType1, []byte] + // ShortChannelID is the channel ID of the next hop. ShortChannelID tlv.OptionalRecordT[tlv.TlvType2, lnwire.ShortChannelID] @@ -98,6 +103,7 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) { var ( d BlindedRouteData + padding = d.Padding.Zero() scid = d.ShortChannelID.Zero() pathID = d.PathID.Zero() blindingOverride = d.NextBlindingOverride.Zero() @@ -112,13 +118,18 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) { } typeMap, err := tlvRecords.ExtractRecords( - &scid, &pathID, &blindingOverride, &relayInfo, &constraints, - &features, + &padding, &scid, &pathID, &blindingOverride, &relayInfo, + &constraints, &features, ) if err != nil { return nil, err } + val, ok := typeMap[d.Padding.TlvType()] + if ok && val == nil { + d.Padding = tlv.SomeRecordT(padding) + } + if val, ok := typeMap[d.ShortChannelID.TlvType()]; ok && val == nil { d.ShortChannelID = tlv.SomeRecordT(scid) } @@ -127,7 +138,7 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) { d.PathID = tlv.SomeRecordT(pathID) } - val, ok := typeMap[d.NextBlindingOverride.TlvType()] + val, ok = typeMap[d.NextBlindingOverride.TlvType()] if ok && val == nil { d.NextBlindingOverride = tlv.SomeRecordT(blindingOverride) } @@ -154,6 +165,10 @@ func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) { recordProducers = make([]tlv.RecordProducer, 0, 5) ) + data.Padding.WhenSome(func(p tlv.RecordT[tlv.TlvType1, []byte]) { + recordProducers = append(recordProducers, &p) + }) + data.ShortChannelID.WhenSome(func(scid tlv.RecordT[tlv.TlvType2, lnwire.ShortChannelID]) { @@ -195,6 +210,19 @@ func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) { return e[:], nil } +// PadBy adds "n" padding bytes to the BlindedRouteData using the Padding field. +// Callers should be aware that the total payload size will change by more than +// "n" since the "n" bytes will be prefixed by BigSize type and length fields. +// Callers may need to call PadBy iteratively until each encrypted data packet +// is the same size and so each call will overwrite the Padding record. +// Note that calling PadBy with an n value of 0 will still result in a zero +// length TLV entry being added. +func (b *BlindedRouteData) PadBy(n int) { + b.Padding = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](make([]byte, n)), + ) +} + // PaymentRelayInfo describes the relay policy for a blinded path. type PaymentRelayInfo struct { // CltvExpiryDelta is the expiry delta for the payment. @@ -208,8 +236,8 @@ type PaymentRelayInfo struct { BaseFee uint32 } -// newPaymentRelayRecord creates a tlv.Record that encodes the payment relay -// (type 10) type for an encrypted blob payload. +// Record creates a tlv.Record that encodes the payment relay (type 10) type for +// an encrypted blob payload. func (i *PaymentRelayInfo) Record() tlv.Record { return tlv.MakeDynamicRecord( 10, &i, func() uint64 { diff --git a/record/blinded_data_test.go b/record/blinded_data_test.go index bc1e706ef..88ab9cddd 100644 --- a/record/blinded_data_test.go +++ b/record/blinded_data_test.go @@ -171,6 +171,74 @@ func TestBlindedDataFinalHopEncoding(t *testing.T) { } } +// TestBlindedRouteDataPadding tests the PadBy method of BlindedRouteData. +func TestBlindedRouteDataPadding(t *testing.T) { + newBlindedRouteData := func() *BlindedRouteData { + channelID := lnwire.NewShortChanIDFromInt(1) + info := PaymentRelayInfo{ + FeeRate: 2, + CltvExpiryDelta: 3, + BaseFee: 30, + } + + constraints := &PaymentConstraints{ + MaxCltvExpiry: 4, + HtlcMinimumMsat: 100, + } + + return NewNonFinalBlindedRouteData( + channelID, pubkey(t), info, constraints, nil, + ) + } + + tests := []struct { + name string + paddingSize int + expectedSizeIncrease uint64 + }{ + { + // Calling PadBy with an n value of 0 in the case where + // there is not yet a padding field will result in a + // zero length TLV entry being added. This will add 2 + // bytes for the type and length fields. + name: "no extra padding", + expectedSizeIncrease: 2, + }, + { + name: "small padding (length " + + "field of 1 byte)", + paddingSize: 200, + expectedSizeIncrease: 202, + }, + { + name: "medium padding (length field " + + "of 3 bytes)", + paddingSize: 256, + expectedSizeIncrease: 260, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + data := newBlindedRouteData() + + prePaddingEncoding, err := EncodeBlindedRouteData(data) + require.NoError(t, err) + + data.PadBy(test.paddingSize) + + postPaddingEncoding, err := EncodeBlindedRouteData(data) + require.NoError(t, err) + + require.EqualValues( + t, test.expectedSizeIncrease, + len(postPaddingEncoding)- + len(prePaddingEncoding), + ) + }) + } +} + // TestBlindedRouteVectors tests encoding/decoding of the test vectors for // blinded route data provided in the specification. // @@ -184,6 +252,7 @@ func TestBlindingSpecTestVectors(t *testing.T) { tests := []struct { encoded string expectedPaymentData *BlindedRouteData + expectedPadding int }{ { encoded: "011a0000000000000000000000000000000000000000000000000000020800000000000006c10a0800240000009627100c06000b69e505dc0e00fd023103123456", @@ -208,6 +277,7 @@ func TestBlindingSpecTestVectors(t *testing.T) { lnwire.Features, ), ), + expectedPadding: 26, }, { encoded: "020800000000000004510821031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f0a0800300000006401f40c06000b69c105dc0e00", @@ -228,7 +298,8 @@ func TestBlindingSpecTestVectors(t *testing.T) { lnwire.NewFeatureVector( lnwire.NewRawFeatureVector(), lnwire.Features, - )), + ), + ), }, } @@ -242,6 +313,12 @@ func TestBlindingSpecTestVectors(t *testing.T) { decodedRoute, err := DecodeBlindedRouteData(buff) require.NoError(t, err) + if test.expectedPadding != 0 { + test.expectedPaymentData.PadBy( + test.expectedPadding, + ) + } + require.Equal( t, test.expectedPaymentData, decodedRoute, )