diff --git a/record/blinded_data.go b/record/blinded_data.go index e133a4763..16ff6bd5a 100644 --- a/record/blinded_data.go +++ b/record/blinded_data.go @@ -17,6 +17,11 @@ type BlindedRouteData struct { // ShortChannelID is the channel ID of the next hop. ShortChannelID tlv.OptionalRecordT[tlv.TlvType2, lnwire.ShortChannelID] + // PathID is a secret set of bytes that the blinded path creator will + // set so that they can check the value on decryption to ensure that the + // path they created was used for the intended purpose. + PathID tlv.OptionalRecordT[tlv.TlvType6, []byte] + // NextBlindingOverride is a blinding point that should be switched // in for the next hop. This is used to combine two blinded paths into // one (which primarily is used in onion messaging, but in theory @@ -68,12 +73,33 @@ func NewNonFinalBlindedRouteData(chanID lnwire.ShortChannelID, return info } +// NewFinalHopBlindedRouteData creates the data that's provided for the final +// hop in a blinded route. +func NewFinalHopBlindedRouteData(constraints *PaymentConstraints, + pathID []byte) *BlindedRouteData { + + var data BlindedRouteData + if pathID != nil { + data.PathID = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6](pathID), + ) + } + + if constraints != nil { + data.Constraints = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType12](*constraints)) + } + + return &data +} + // DecodeBlindedRouteData decodes the data provided within a blinded route. func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) { var ( d BlindedRouteData scid = d.ShortChannelID.Zero() + pathID = d.PathID.Zero() blindingOverride = d.NextBlindingOverride.Zero() relayInfo = d.RelayInfo.Zero() constraints = d.Constraints.Zero() @@ -86,7 +112,8 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) { } typeMap, err := tlvRecords.ExtractRecords( - &scid, &blindingOverride, &relayInfo, &constraints, &features, + &scid, &pathID, &blindingOverride, &relayInfo, &constraints, + &features, ) if err != nil { return nil, err @@ -96,6 +123,10 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) { d.ShortChannelID = tlv.SomeRecordT(scid) } + if val, ok := typeMap[d.PathID.TlvType()]; ok && val == nil { + d.PathID = tlv.SomeRecordT(pathID) + } + val, ok := typeMap[d.NextBlindingOverride.TlvType()] if ok && val == nil { d.NextBlindingOverride = tlv.SomeRecordT(blindingOverride) @@ -129,6 +160,10 @@ func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) { recordProducers = append(recordProducers, &scid) }) + data.PathID.WhenSome(func(pathID tlv.RecordT[tlv.TlvType6, []byte]) { + recordProducers = append(recordProducers, &pathID) + }) + data.NextBlindingOverride.WhenSome(func(pk tlv.RecordT[tlv.TlvType8, *btcec.PublicKey]) { diff --git a/record/blinded_data_test.go b/record/blinded_data_test.go index 2394c8ac9..bc1e706ef 100644 --- a/record/blinded_data_test.go +++ b/record/blinded_data_test.go @@ -118,6 +118,59 @@ func TestBlindedDataEncoding(t *testing.T) { } } +// TestBlindedDataFinalHopEncoding tests the encoding and decoding of a blinded +// data blob intended for the final hop of a blinded path where only the pathID +// will potentially be set. +func TestBlindedDataFinalHopEncoding(t *testing.T) { + tests := []struct { + name string + pathID []byte + constraints bool + }{ + { + name: "with path ID", + pathID: []byte{1, 2, 3, 4, 5, 6}, + }, + { + name: "with no path ID", + pathID: nil, + }, + { + name: "with path ID and constraints", + pathID: []byte{1, 2, 3, 4, 5, 6}, + constraints: true, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var constraints *PaymentConstraints + if test.constraints { + constraints = &PaymentConstraints{ + MaxCltvExpiry: 4, + HtlcMinimumMsat: 5, + } + } + + encodedData := NewFinalHopBlindedRouteData( + constraints, test.pathID, + ) + + encoded, err := EncodeBlindedRouteData(encodedData) + require.NoError(t, err) + + b := bytes.NewBuffer(encoded) + decodedData, err := DecodeBlindedRouteData(b) + require.NoError(t, err) + + require.Equal(t, encodedData, decodedData) + }) + } +} + // TestBlindedRouteVectors tests encoding/decoding of the test vectors for // blinded route data provided in the specification. //