From c48841a38bf7002999f2fe1fb0e22e2f8d907aa7 Mon Sep 17 00:00:00 2001 From: Carla Kirk-Cohen Date: Tue, 14 Jun 2022 14:05:30 +0200 Subject: [PATCH] record: add TLV encoding/decoding for blinded route data blobs This commit adds encoding and decoding for blinded route data blobs. TLV fields such as path_id (which are only used for the final hop) are omitted to minimize the change size. --- lnwire/features.go | 44 ++++++ record/blinded_data.go | 304 ++++++++++++++++++++++++++++++++++++ record/blinded_data_test.go | 197 +++++++++++++++++++++++ 3 files changed, 545 insertions(+) create mode 100644 record/blinded_data.go create mode 100644 record/blinded_data_test.go diff --git a/lnwire/features.go b/lnwire/features.go index 7179a906f..8d603731f 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -678,6 +678,50 @@ func EmptyFeatureVector() *FeatureVector { return NewFeatureVector(nil, Features) } +// Record implements the RecordProducer interface for FeatureVector. Note that +// it uses a zero-value type is used to produce the record, as we expect this +// type value to be overwritten when used in generic TLV record production. +// This allows a single Record function to serve in the many different contexts +// in which feature vectors are encoded. This record wraps the encoding/ +// decoding for our raw feature vectors so that we can directly parse fully +// formed feature vector types. +func (fv *FeatureVector) Record() tlv.Record { + return tlv.MakeDynamicRecord(0, fv, fv.sizeFunc, + func(w io.Writer, val interface{}, buf *[8]byte) error { + if f, ok := val.(*FeatureVector); ok { + return rawFeatureEncoder( + w, f.RawFeatureVector, buf, + ) + } + + return tlv.NewTypeForEncodingErr( + val, "*lnwire.FeatureVector", + ) + }, + func(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if f, ok := val.(*FeatureVector); ok { + features := NewFeatureVector(nil, Features) + err := rawFeatureDecoder( + r, features.RawFeatureVector, buf, l, + ) + if err != nil { + return err + } + + *f = *features + + return nil + } + + return tlv.NewTypeForDecodingErr( + val, "*lnwire.FeatureVector", l, l, + ) + }, + ) +} + // HasFeature returns whether a particular feature is included in the set. The // feature can be seen as set either if the bit is set directly OR the queried // bit has the same meaning as its corresponding even/odd bit, which is set diff --git a/record/blinded_data.go b/record/blinded_data.go new file mode 100644 index 000000000..7990fa738 --- /dev/null +++ b/record/blinded_data.go @@ -0,0 +1,304 @@ +package record + +import ( + "bytes" + "encoding/binary" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// BlindedRouteData contains the information that is included in a blinded +// route encrypted data blob that is created by the recipient to provide +// forwarding information. +type BlindedRouteData struct { + // ShortChannelID is the channel ID of the next hop. + ShortChannelID tlv.RecordT[tlv.TlvType2, lnwire.ShortChannelID] + + // NextBlindingOverride is a blinding point that should be switched + // in for the next hop. This is used to combine two blinded paths into + // one (which primarily is used in onion messaging, but in theory + // could be used for payments as well). + NextBlindingOverride tlv.OptionalRecordT[tlv.TlvType8, *btcec.PublicKey] + + // RelayInfo provides the relay parameters for the hop. + RelayInfo tlv.RecordT[tlv.TlvType10, PaymentRelayInfo] + + // Constraints provides the payment relay constraints for the hop. + Constraints tlv.OptionalRecordT[tlv.TlvType12, PaymentConstraints] + + // Features is the set of features the payment requires. + Features tlv.OptionalRecordT[tlv.TlvType14, lnwire.FeatureVector] +} + +// NewBlindedRouteData creates the data that's provided for hops within a +// blinded route. +func NewBlindedRouteData(chanID lnwire.ShortChannelID, + blindingOverride *btcec.PublicKey, relayInfo PaymentRelayInfo, + constraints *PaymentConstraints, + features *lnwire.FeatureVector) *BlindedRouteData { + + info := &BlindedRouteData{ + ShortChannelID: tlv.NewRecordT[tlv.TlvType2](chanID), + RelayInfo: tlv.NewRecordT[tlv.TlvType10](relayInfo), + } + + if blindingOverride != nil { + info.NextBlindingOverride = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType8](blindingOverride)) + } + + if constraints != nil { + info.Constraints = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType12](*constraints)) + } + + if features != nil { + info.Features = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType14](*features), + ) + } + + return info +} + +// DecodeBlindedRouteData decodes the data provided within a blinded route. +func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) { + var ( + d BlindedRouteData + + blindingOverride = d.NextBlindingOverride.Zero() + constraints = d.Constraints.Zero() + features = d.Features.Zero() + ) + + var tlvRecords lnwire.ExtraOpaqueData + if err := lnwire.ReadElements(r, &tlvRecords); err != nil { + return nil, err + } + + typeMap, err := tlvRecords.ExtractRecords( + &d.ShortChannelID, + &blindingOverride, &d.RelayInfo.Val, &constraints, + &features, + ) + if err != nil { + return nil, err + } + + val, ok := typeMap[d.NextBlindingOverride.TlvType()] + if ok && val == nil { + d.NextBlindingOverride = tlv.SomeRecordT(blindingOverride) + } + + if val, ok := typeMap[d.Constraints.TlvType()]; ok && val == nil { + d.Constraints = tlv.SomeRecordT(constraints) + } + + if val, ok := typeMap[d.Features.TlvType()]; ok && val == nil { + d.Features = tlv.SomeRecordT(features) + } + + return &d, nil +} + +// EncodeBlindedRouteData encodes the blinded route data provided. +func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) { + var ( + e lnwire.ExtraOpaqueData + recordProducers = make([]tlv.RecordProducer, 0, 5) + ) + + recordProducers = append(recordProducers, &data.ShortChannelID) + + data.NextBlindingOverride.WhenSome(func(pk tlv.RecordT[tlv.TlvType8, + *btcec.PublicKey]) { + + recordProducers = append(recordProducers, &pk) + }) + + recordProducers = append(recordProducers, &data.RelayInfo.Val) + + data.Constraints.WhenSome(func(cs tlv.RecordT[tlv.TlvType12, + PaymentConstraints]) { + + recordProducers = append(recordProducers, &cs) + }) + + data.Features.WhenSome(func(f tlv.RecordT[tlv.TlvType14, + lnwire.FeatureVector]) { + + recordProducers = append(recordProducers, &f) + }) + + if err := e.PackRecords(recordProducers...); err != nil { + return nil, err + } + + return e[:], nil +} + +// PaymentRelayInfo describes the relay policy for a blinded path. +type PaymentRelayInfo struct { + // CltvExpiryDelta is the expiry delta for the payment. + CltvExpiryDelta uint16 + + // FeeRate is the fee rate that will be charged per millionth of a + // satoshi. + FeeRate uint32 + + // BaseFee is the per-htlc fee charged. + BaseFee uint32 +} + +// newPaymentRelayRecord creates a tlv.Record that encodes the payment relay +// (type 10) type for an encrypted blob payload. +func (i *PaymentRelayInfo) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 10, &i, func() uint64 { + // uint16 + uint32 + tuint32 + return 2 + 4 + tlv.SizeTUint32(i.BaseFee) + }, encodePaymentRelay, decodePaymentRelay, + ) +} + +func encodePaymentRelay(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(**PaymentRelayInfo); ok { + relayInfo := *t + + // Just write our first 6 bytes directly. + binary.BigEndian.PutUint16(buf[:2], relayInfo.CltvExpiryDelta) + binary.BigEndian.PutUint32(buf[2:6], relayInfo.FeeRate) + if _, err := w.Write(buf[0:6]); err != nil { + return err + } + + // We can safely reuse buf here because we overwrite its + // contents. + return tlv.ETUint32(w, &relayInfo.BaseFee, buf) + } + + return tlv.NewTypeForEncodingErr(val, "**hop.PaymentRelayInfo") +} + +func decodePaymentRelay(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if t, ok := val.(**PaymentRelayInfo); ok && l <= 10 { + scratch := make([]byte, l) + + n, err := io.ReadFull(r, scratch) + if err != nil { + return err + } + + // We expect at least 6 bytes, because we have 2 bytes for + // cltv delta and 4 bytes for fee rate. + if n < 6 { + return tlv.NewTypeForDecodingErr(val, + "*hop.paymentRelayInfo", uint64(n), 6) + } + + relayInfo := *t + + relayInfo.CltvExpiryDelta = binary.BigEndian.Uint16( + scratch[0:2], + ) + relayInfo.FeeRate = binary.BigEndian.Uint32(scratch[2:6]) + + // To be able to re-use the DTUint32 function we create a + // buffer with just the bytes holding the variable length u32. + // If the base fee is zero, this will be an empty buffer, which + // is okay. + b := bytes.NewBuffer(scratch[6:]) + + return tlv.DTUint32(b, &relayInfo.BaseFee, buf, l-6) + } + + return tlv.NewTypeForDecodingErr(val, "*hop.paymentRelayInfo", l, 10) +} + +// PaymentConstraints is a set of restrictions on a payment. +type PaymentConstraints struct { + // MaxCltvExpiry is the maximum expiry height for the payment. + MaxCltvExpiry uint32 + + // HtlcMinimumMsat is the minimum htlc size for the payment. + HtlcMinimumMsat lnwire.MilliSatoshi +} + +func (p *PaymentConstraints) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 12, &p, func() uint64 { + // uint32 + tuint64. + return 4 + tlv.SizeTUint64(uint64( + p.HtlcMinimumMsat, + )) + }, + encodePaymentConstraints, decodePaymentConstraints, + ) +} + +func encodePaymentConstraints(w io.Writer, val interface{}, + buf *[8]byte) error { + + if c, ok := val.(**PaymentConstraints); ok { + constraints := *c + + binary.BigEndian.PutUint32(buf[:4], constraints.MaxCltvExpiry) + if _, err := w.Write(buf[:4]); err != nil { + return err + } + + // We can safely re-use buf here because we overwrite its + // contents. + htlcMsat := uint64(constraints.HtlcMinimumMsat) + + return tlv.ETUint64(w, &htlcMsat, buf) + } + + return tlv.NewTypeForEncodingErr(val, "**PaymentConstraints") +} + +func decodePaymentConstraints(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if c, ok := val.(**PaymentConstraints); ok && l <= 12 { + scratch := make([]byte, l) + + n, err := io.ReadFull(r, scratch) + if err != nil { + return err + } + + // We expect at least 4 bytes for our uint32. + if n < 4 { + return tlv.NewTypeForDecodingErr(val, + "*paymentConstraints", uint64(n), 4) + } + + payConstraints := *c + + payConstraints.MaxCltvExpiry = binary.BigEndian.Uint32( + scratch[:4], + ) + + // This could be empty if our minimum is zero, that's okay. + var ( + b = bytes.NewBuffer(scratch[4:]) + minHtlc uint64 + ) + + err = tlv.DTUint64(b, &minHtlc, buf, l-4) + if err != nil { + return err + } + payConstraints.HtlcMinimumMsat = lnwire.MilliSatoshi(minHtlc) + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "**PaymentConstraints", l, l) +} diff --git a/record/blinded_data_test.go b/record/blinded_data_test.go new file mode 100644 index 000000000..f8e95cdcc --- /dev/null +++ b/record/blinded_data_test.go @@ -0,0 +1,197 @@ +package record + +import ( + "bytes" + "encoding/hex" + "fmt" + "math" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +//nolint:lll +const pubkeyStr = "02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619" + +func pubkey(t *testing.T) *btcec.PublicKey { + t.Helper() + + nodeBytes, err := hex.DecodeString(pubkeyStr) + require.NoError(t, err) + + nodePk, err := btcec.ParsePubKey(nodeBytes) + require.NoError(t, err) + + return nodePk +} + +// TestBlindedDataEncoding tests encoding and decoding of blinded data blobs. +// These tests specifically cover cases where the variable length encoded +// integers values have different numbers of leading zeros trimmed because +// these TLVs are the first composite records with variable length tlvs +// (previously, a variable length integer would take up the whole record). +func TestBlindedDataEncoding(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + baseFee uint32 + htlcMin lnwire.MilliSatoshi + features *lnwire.FeatureVector + constraints bool + }{ + { + name: "zero variable values", + baseFee: 0, + htlcMin: 0, + }, + { + name: "zeros trimmed", + baseFee: math.MaxUint32 / 2, + htlcMin: math.MaxUint64 / 2, + }, + { + name: "no zeros trimmed", + baseFee: math.MaxUint32, + htlcMin: math.MaxUint64, + }, + { + name: "nil feature vector", + features: nil, + }, + { + name: "non-nil, but empty feature vector", + features: lnwire.EmptyFeatureVector(), + }, + { + name: "populated feature vector", + features: lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(lnwire.AMPOptional), + lnwire.Features, + ), + }, + { + name: "no payment constraints", + constraints: true, + }, + } + + for _, testCase := range tests { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + // Create a standard set of blinded route data, using + // the values from our test case for the variable + // length encoded values. + channelID := lnwire.NewShortChanIDFromInt(1) + info := PaymentRelayInfo{ + FeeRate: 2, + CltvExpiryDelta: 3, + BaseFee: testCase.baseFee, + } + + var constraints *PaymentConstraints + if testCase.constraints { + constraints = &PaymentConstraints{ + MaxCltvExpiry: 4, + HtlcMinimumMsat: testCase.htlcMin, + } + } + + encodedData := NewBlindedRouteData( + channelID, pubkey(t), info, constraints, + testCase.features, + ) + + encoded, err := EncodeBlindedRouteData(encodedData) + require.NoError(t, err) + + b := bytes.NewBuffer(encoded) + decodedData, err := DecodeBlindedRouteData(b) + require.NoError(t, err) + + require.Equal(t, encodedData, decodedData) + }) + } +} + +// TestBlindedRouteVectors tests encoding/decoding of the test vectors for +// blinded route data provided in the specification. +// +//nolint:lll +func TestBlindingSpecTestVectors(t *testing.T) { + nextBlindingOverrideStr, err := hex.DecodeString("031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f") + require.NoError(t, err) + nextBlindingOverride, err := btcec.ParsePubKey(nextBlindingOverrideStr) + require.NoError(t, err) + + tests := []struct { + encoded string + expectedPaymentData *BlindedRouteData + }{ + { + encoded: "011a0000000000000000000000000000000000000000000000000000020800000000000006c10a0800240000009627100c06000b69e505dc0e00fd023103123456", + expectedPaymentData: NewBlindedRouteData( + lnwire.ShortChannelID{ + BlockHeight: 0, + TxIndex: 0, + TxPosition: 1729, + }, + nil, + PaymentRelayInfo{ + CltvExpiryDelta: 36, + FeeRate: 150, + BaseFee: 10000, + }, + &PaymentConstraints{ + MaxCltvExpiry: 748005, + HtlcMinimumMsat: 1500, + }, + lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(), + lnwire.Features, + ), + ), + }, + { + encoded: "020800000000000004510821031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f0a0800300000006401f40c06000b69c105dc0e00", + expectedPaymentData: NewBlindedRouteData( + lnwire.ShortChannelID{ + TxPosition: 1105, + }, + nextBlindingOverride, + PaymentRelayInfo{ + CltvExpiryDelta: 48, + FeeRate: 100, + BaseFee: 500, + }, + &PaymentConstraints{ + MaxCltvExpiry: 747969, + HtlcMinimumMsat: 1500, + }, + lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(), + lnwire.Features, + )), + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + route, err := hex.DecodeString(test.encoded) + require.NoError(t, err) + + buff := bytes.NewBuffer(route) + + decodedRoute, err := DecodeBlindedRouteData(buff) + require.NoError(t, err) + + require.Equal( + t, test.expectedPaymentData, decodedRoute, + ) + }) + } +}