From af506946433858fbaf47538715745e78990ab0d4 Mon Sep 17 00:00:00 2001 From: ffranr Date: Fri, 3 May 2024 18:36:00 +0100 Subject: [PATCH] lnwire: add `ExtraOpaqueData` helper functions and methods Introduces a couple of new helper functions for both the ExtraOpaqueData and CustomRecords types along with new methods on the ExtraOpaqueData. --- lnwire/custom_records.go | 153 +++++++++++++---- lnwire/custom_records_test.go | 6 +- lnwire/extra_bytes.go | 199 +++++++++++++++++---- lnwire/extra_bytes_test.go | 314 +++++++++++++++++++++++++++++++++- 4 files changed, 596 insertions(+), 76 deletions(-) diff --git a/lnwire/custom_records.go b/lnwire/custom_records.go index 1e1988842..f0f59185e 100644 --- a/lnwire/custom_records.go +++ b/lnwire/custom_records.go @@ -3,8 +3,10 @@ package lnwire import ( "bytes" "fmt" + "io" "sort" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/tlv" ) @@ -44,12 +46,12 @@ func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) { // ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob. func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) { - stream, err := tlv.NewStream() - if err != nil { - return nil, fmt.Errorf("error creating stream: %w", err) - } + return ParseCustomRecordsFrom(bytes.NewReader(b)) +} - typeMap, err := stream.DecodeWithParsedTypes(bytes.NewReader(b)) +// ParseCustomRecordsFrom creates a new CustomRecords instance from a reader. +func ParseCustomRecordsFrom(r io.Reader) (CustomRecords, error) { + typeMap, err := DecodeRecords(r) if err != nil { return nil, fmt.Errorf("error decoding HTLC record: %w", err) } @@ -121,21 +123,14 @@ func (c CustomRecords) ExtendRecordProducers( // Convert the custom records map to a TLV record producer slice and // append them to the exiting records slice. - crRecords := tlv.MapToRecords(c) - for _, record := range crRecords { - r := recordProducer{record} - producers = append(producers, &r) - } + customRecordProducers := RecordsAsProducers(tlv.MapToRecords(c)) + producers = append(producers, customRecordProducers...) // If the records slice which was given as an argument included TLV // values greater than or equal to the minimum custom records TLV type // we will sort the extended records slice to ensure that it is ordered // correctly. - sort.Slice(producers, func(i, j int) bool { - recordI := producers[i].Record() - recordJ := producers[j].Record() - return recordI.Type() < recordJ.Type() - }) + SortProducers(producers) return producers, nil } @@ -150,27 +145,119 @@ func (c CustomRecords) RecordProducers() []tlv.RecordProducer { // Convert the custom records map to a TLV record producer slice. records := tlv.MapToRecords(c) - // Convert the records to record producers. - producers := make([]tlv.RecordProducer, len(records)) - for i, record := range records { - producers[i] = &recordProducer{record} - } - - return producers + return RecordsAsProducers(records) } // Serialize serializes the custom records into a byte slice. func (c CustomRecords) Serialize() ([]byte, error) { records := tlv.MapToRecords(c) - stream, err := tlv.NewStream(records...) - if err != nil { - return nil, fmt.Errorf("error creating stream: %w", err) - } - - var b bytes.Buffer - if err := stream.Encode(&b); err != nil { - return nil, fmt.Errorf("error encoding custom records: %w", err) - } - - return b.Bytes(), nil + return EncodeRecords(records) +} + +// SerializeTo serializes the custom records into the given writer. +func (c CustomRecords) SerializeTo(w io.Writer) error { + records := tlv.MapToRecords(c) + return EncodeRecordsTo(w, records) +} + +// ProduceRecordsSorted converts a slice of record producers into a slice of +// records and then sorts it by type. +func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record { + records := fn.Map(func(producer tlv.RecordProducer) tlv.Record { + return producer.Record() + }, recordProducers) + + // Ensure that the set of records are sorted before we attempt to + // decode from the stream, to ensure they're canonical. + tlv.SortRecords(records) + + return records +} + +// SortProducers sorts the given record producers by their type. +func SortProducers(producers []tlv.RecordProducer) { + sort.Slice(producers, func(i, j int) bool { + recordI := producers[i].Record() + recordJ := producers[j].Record() + return recordI.Type() < recordJ.Type() + }) +} + +// TlvMapToRecords converts a TLV map into a slice of records. +func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record { + tlvMapGeneric := make(map[uint64][]byte) + for k, v := range tlvMap { + tlvMapGeneric[uint64(k)] = v + } + + return tlv.MapToRecords(tlvMapGeneric) +} + +// RecordsAsProducers converts a slice of records into a slice of record +// producers. +func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer { + return fn.Map(func(record tlv.Record) tlv.RecordProducer { + return &record + }, records) +} + +// EncodeRecords encodes the given records into a byte slice. +func EncodeRecords(records []tlv.Record) ([]byte, error) { + var buf bytes.Buffer + if err := EncodeRecordsTo(&buf, records); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// EncodeRecordsTo encodes the given records into the given writer. +func EncodeRecordsTo(w io.Writer, records []tlv.Record) error { + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// DecodeRecords decodes the given byte slice into the given records and returns +// the rest as a TLV type map. +func DecodeRecords(r io.Reader, + records ...tlv.Record) (tlv.TypeMap, error) { + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + return tlvStream.DecodeWithParsedTypes(r) +} + +// DecodeRecordsP2P decodes the given byte slice into the given records and +// returns the rest as a TLV type map. This function is identical to +// DecodeRecords except that the record size is capped at 65535. +func DecodeRecordsP2P(r *bytes.Reader, + records ...tlv.Record) (tlv.TypeMap, error) { + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + return tlvStream.DecodeWithParsedTypesP2P(r) +} + +// AssertUniqueTypes asserts that the given records have unique types. +func AssertUniqueTypes(r []tlv.Record) error { + seen := make(fn.Set[tlv.Type], len(r)) + for _, record := range r { + t := record.Type() + if seen.Contains(t) { + return fmt.Errorf("duplicate record type: %d", t) + } + seen.Add(t) + } + + return nil } diff --git a/lnwire/custom_records_test.go b/lnwire/custom_records_test.go index 0338d0159..1d30e2100 100644 --- a/lnwire/custom_records_test.go +++ b/lnwire/custom_records_test.go @@ -144,10 +144,8 @@ func TestCustomRecordsExtendRecordProducers(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { nonCustomRecords := tlv.MapToRecords(tc.existingTypes) - nonCustomProducers := fn.Map( - func(r tlv.Record) tlv.RecordProducer { - return &recordProducer{r} - }, nonCustomRecords, + nonCustomProducers := RecordsAsProducers( + nonCustomRecords, ) combined, err := tc.customRecords.ExtendRecordProducers( diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index 0ebf48f57..c4ca260e1 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -5,6 +5,7 @@ import ( "fmt" "io" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/tlv" ) @@ -15,6 +16,21 @@ import ( // upgrades to the network in a forwards compatible manner. type ExtraOpaqueData []byte +// NewExtraOpaqueData creates a new ExtraOpaqueData instance from a tlv.TypeMap. +func NewExtraOpaqueData(tlvMap tlv.TypeMap) (ExtraOpaqueData, error) { + // If the tlv map is empty, we'll want to mirror the behavior of + // decoding an empty extra opaque data field (see Decode method). + if len(tlvMap) == 0 { + return make([]byte, 0), nil + } + + // Convert the TLV map into a slice of records. + records := TlvMapToRecords(tlvMap) + + // Encode the records into the extra data byte slice. + return EncodeRecords(records) +} + // Encode attempts to encode the raw extra bytes into the passed io.Writer. func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error { eBytes := []byte((*e)[:]) @@ -25,8 +41,8 @@ func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error { return nil } -// Decode attempts to unpack the raw bytes encoded in the passed io.Reader as a -// set of extra opaque data. +// Decode attempts to unpack the raw bytes encoded in the passed-in io.Reader as +// a set of extra opaque data. func (e *ExtraOpaqueData) Decode(r io.Reader) error { // First, we'll attempt to read a set of bytes contained within the // passed io.Reader (if any exist). @@ -39,7 +55,7 @@ func (e *ExtraOpaqueData) Decode(r io.Reader) error { // This ensures that any struct that embeds this type will properly // store the bytes once this method exits. if len(rawBytes) > 0 { - *e = ExtraOpaqueData(rawBytes) + *e = rawBytes } else { *e = make([]byte, 0) } @@ -50,28 +66,17 @@ func (e *ExtraOpaqueData) Decode(r io.Reader) error { // PackRecords attempts to encode the set of tlv records into the target // ExtraOpaqueData instance. The records will be encoded as a raw TLV stream // and stored within the backing slice pointer. -func (e *ExtraOpaqueData) PackRecords(recordProducers ...tlv.RecordProducer) error { - // First, assemble all the records passed in in series. - records := make([]tlv.Record, 0, len(recordProducers)) - for _, producer := range recordProducers { - records = append(records, producer.Record()) - } +func (e *ExtraOpaqueData) PackRecords( + recordProducers ...tlv.RecordProducer) error { - // Ensure that the set of records are sorted before we encode them into - // the stream, to ensure they're canonical. - tlv.SortRecords(records) - - tlvStream, err := tlv.NewStream(records...) + // Assemble all the records passed in series, then encode them. + records := ProduceRecordsSorted(recordProducers...) + encoded, err := EncodeRecords(records) if err != nil { return err } - var extraBytesWriter bytes.Buffer - if err := tlvStream.Encode(&extraBytesWriter); err != nil { - return err - } - - *e = ExtraOpaqueData(extraBytesWriter.Bytes()) + *e = encoded return nil } @@ -80,29 +85,38 @@ func (e *ExtraOpaqueData) PackRecords(recordProducers ...tlv.RecordProducer) err // it were a tlv stream. The set of raw parsed types is returned, and any // passed records (if found in the stream) will be parsed into the proper // tlv.Record. -func (e *ExtraOpaqueData) ExtractRecords(recordProducers ...tlv.RecordProducer) ( - tlv.TypeMap, error) { - - // First, assemble all the records passed in in series. - records := make([]tlv.Record, 0, len(recordProducers)) - for _, producer := range recordProducers { - records = append(records, producer.Record()) - } - - // Ensure that the set of records are sorted before we attempt to - // decode from the stream, to ensure they're canonical. - tlv.SortRecords(records) +func (e *ExtraOpaqueData) ExtractRecords( + recordProducers ...tlv.RecordProducer) (tlv.TypeMap, error) { + // First, assemble all the records passed in series. + records := ProduceRecordsSorted(recordProducers...) extraBytesReader := bytes.NewReader(*e) - tlvStream, err := tlv.NewStream(records...) + // Since ExtraOpaqueData is provided by a potentially malicious peer, + // pass it into the P2P decoding variant. + return DecodeRecordsP2P(extraBytesReader, records...) +} + +// RecordProducers parses ExtraOpaqueData into a slice of TLV record producers +// by interpreting it as a TLV map. +func (e *ExtraOpaqueData) RecordProducers() ([]tlv.RecordProducer, error) { + var recordProducers []tlv.RecordProducer + + // If the instance is nil or empty, return an empty slice. + if e == nil || len(*e) == 0 { + return recordProducers, nil + } + + // Parse the extra opaque data as a TLV map. + tlvMap, err := e.ExtractRecords() if err != nil { return nil, err } - // Since ExtraOpaqueData is provided by a potentially malicious peer, - // pass it into the P2P decoding variant. - return tlvStream.DecodeWithParsedTypesP2P(extraBytesReader) + // Convert the TLV map into a slice of record producers. + records := TlvMapToRecords(tlvMap) + + return RecordsAsProducers(records), nil } // EncodeMessageExtraData encodes the given recordProducers into the given @@ -120,3 +134,116 @@ func EncodeMessageExtraData(extraData *ExtraOpaqueData, // are all properly sorted. return extraData.PackRecords(recordProducers...) } + +// ParseAndExtractCustomRecords parses the given extra data into the passed-in +// records, then returns any remaining records split into custom records and +// extra data. +func ParseAndExtractCustomRecords(allExtraData ExtraOpaqueData, + knownRecords ...tlv.RecordProducer) (CustomRecords, + fn.Set[tlv.Type], ExtraOpaqueData, error) { + + extraDataTlvMap, err := allExtraData.ExtractRecords(knownRecords...) + if err != nil { + return nil, nil, nil, err + } + + // Remove the known and now extracted records from the leftover extra + // data map. + parsedKnownRecords := make(fn.Set[tlv.Type], len(knownRecords)) + for _, producer := range knownRecords { + r := producer.Record() + + // Only remove the records if it was parsed (remainder is nil). + // We'll just store the type so we can tell the caller which + // records were actually parsed fully. + val, ok := extraDataTlvMap[r.Type()] + if ok && val == nil { + parsedKnownRecords.Add(r.Type()) + delete(extraDataTlvMap, r.Type()) + } + } + + // Any records from the extra data TLV map which are in the custom + // records TLV type range will be included in the custom records field + // and removed from the extra data field. + customRecordsTlvMap := make(tlv.TypeMap, len(extraDataTlvMap)) + for k, v := range extraDataTlvMap { + // Skip records that are not in the custom records TLV type + // range. + if k < MinCustomRecordsTlvType { + continue + } + + // Include the record in the custom records map. + customRecordsTlvMap[k] = v + + // Now that the record is included in the custom records map, + // we can remove it from the extra data TLV map. + delete(extraDataTlvMap, k) + } + + // Set the custom records field to the custom records specific TLV + // record map. + customRecords, err := NewCustomRecords(customRecordsTlvMap) + if err != nil { + return nil, nil, nil, err + } + + // Encode the remaining records back into the extra data field. These + // records are not in the custom records TLV type range and do not + // have associated fields in the struct that produced the records. + extraData, err := NewExtraOpaqueData(extraDataTlvMap) + if err != nil { + return nil, nil, nil, err + } + + // Help with unit testing where we might have the empty value (nil) for + // the extra data instead of the default that's returned by the + // constructor (empty slice). + if len(extraData) == 0 { + extraData = nil + } + + return customRecords, parsedKnownRecords, extraData, nil +} + +// MergeAndEncode merges the known records with the extra data and custom +// records, then encodes the merged records into raw bytes. +func MergeAndEncode(knownRecords []tlv.RecordProducer, + extraData ExtraOpaqueData, customRecords CustomRecords) ([]byte, + error) { + + // Construct a slice of all the records that we should include in the + // message extra data field. We will start by including any records from + // the extra data field. + mergedRecords, err := extraData.RecordProducers() + if err != nil { + return nil, err + } + + // Merge the known and extra data records. + mergedRecords = append(mergedRecords, knownRecords...) + + // Include custom records in the extra data wire field if they are + // present. Ensure that the custom records are validated before encoding + // them. + if err := customRecords.Validate(); err != nil { + return nil, fmt.Errorf("custom records validation error: %w", + err) + } + + // Extend the message extra data records slice with TLV records from the + // custom records field. + mergedRecords = append( + mergedRecords, customRecords.RecordProducers()..., + ) + + // Now we can sort the records and make sure there are no records with + // the same type that would collide when encoding. + sortedRecords := ProduceRecordsSorted(mergedRecords...) + if err := AssertUniqueTypes(sortedRecords); err != nil { + return nil, err + } + + return EncodeRecords(sortedRecords) +} diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go index 1ecc7adf5..79c913362 100644 --- a/lnwire/extra_bytes_test.go +++ b/lnwire/extra_bytes_test.go @@ -11,6 +11,12 @@ import ( "github.com/stretchr/testify/require" ) +var ( + tlvType1 tlv.TlvType1 + tlvType2 tlv.TlvType2 + tlvType3 tlv.TlvType3 +) + // TestExtraOpaqueDataEncodeDecode tests that we're able to encode/decode // arbitrary payloads. func TestExtraOpaqueDataEncodeDecode(t *testing.T) { @@ -153,21 +159,18 @@ func TestPackRecords(t *testing.T) { var ( // Record type 1. - tlvType1 tlv.TlvType1 recordBytes1 = []byte("recordBytes1") tlvRecord1 = tlv.NewPrimitiveRecord[tlv.TlvType1]( recordBytes1, ) // Record type 2. - tlvType2 tlv.TlvType2 recordBytes2 = []byte("recordBytes2") tlvRecord2 = tlv.NewPrimitiveRecord[tlv.TlvType2]( recordBytes2, ) // Record type 3. - tlvType3 tlv.TlvType3 recordBytes3 = []byte("recordBytes3") tlvRecord3 = tlv.NewPrimitiveRecord[tlv.TlvType3]( recordBytes3, @@ -203,3 +206,308 @@ func TestPackRecords(t *testing.T) { require.Equal(t, recordBytes2, extractedRecords[tlvType2.TypeVal()]) require.Equal(t, recordBytes3, extractedRecords[tlvType3.TypeVal()]) } + +type dummyRecordProducer struct { + typ tlv.Type + scratchValue []byte + expectedValue []byte +} + +func (d *dummyRecordProducer) Record() tlv.Record { + return tlv.MakePrimitiveRecord(d.typ, &d.scratchValue) +} + +// TestExtraOpaqueData tests that we're able to properly encode/decode an +// ExtraOpaqueData instance. +func TestExtraOpaqueData(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + types tlv.TypeMap + expectedData ExtraOpaqueData + expectedTypes tlv.TypeMap + decoders []tlv.RecordProducer + }{ + { + name: "empty map", + expectedTypes: tlv.TypeMap{}, + expectedData: make([]byte, 0), + }, + { + name: "single record", + types: tlv.TypeMap{ + tlvType1.TypeVal(): []byte{1, 2, 3}, + }, + expectedData: ExtraOpaqueData{ + 0x01, 0x03, 1, 2, 3, + }, + expectedTypes: tlv.TypeMap{ + tlvType1.TypeVal(): []byte{1, 2, 3}, + }, + decoders: []tlv.RecordProducer{ + &dummyRecordProducer{ + typ: tlvType1.TypeVal(), + expectedValue: []byte{1, 2, 3}, + }, + }, + }, + { + name: "multiple records", + types: tlv.TypeMap{ + tlvType2.TypeVal(): []byte{4, 5, 6}, + tlvType1.TypeVal(): []byte{1, 2, 3}, + }, + expectedData: ExtraOpaqueData{ + 0x01, 0x03, 1, 2, 3, + 0x02, 0x03, 4, 5, 6, + }, + expectedTypes: tlv.TypeMap{ + tlvType1.TypeVal(): []byte{1, 2, 3}, + tlvType2.TypeVal(): []byte{4, 5, 6}, + }, + decoders: []tlv.RecordProducer{ + &dummyRecordProducer{ + typ: tlvType1.TypeVal(), + expectedValue: []byte{1, 2, 3}, + }, + &dummyRecordProducer{ + typ: tlvType2.TypeVal(), + expectedValue: []byte{4, 5, 6}, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // First, test the constructor. + opaqueData, err := NewExtraOpaqueData(tc.types) + require.NoError(t, err) + + require.Equal(t, tc.expectedData, opaqueData) + + // Now encode/decode. + var b bytes.Buffer + err = opaqueData.Encode(&b) + require.NoError(t, err) + + var decoded ExtraOpaqueData + err = decoded.Decode(&b) + require.NoError(t, err) + + require.Equal(t, opaqueData, decoded) + + // Now RecordProducers/PackRecords. + producers, err := opaqueData.RecordProducers() + require.NoError(t, err) + + var packed ExtraOpaqueData + err = packed.PackRecords(producers...) + require.NoError(t, err) + + // PackRecords returns nil vs. an empty slice if there + // are no records. We need to handle this case + // separately. + if len(producers) == 0 { + // Make sure the packed data is empty. + require.Empty(t, packed) + + // Now change it to an empty slice for the + // comparison below. + packed = make([]byte, 0) + } + require.Equal(t, opaqueData, packed) + + // ExtractRecords with an empty set of record producers + // should return the original type map. + extracted, err := opaqueData.ExtractRecords() + require.NoError(t, err) + + require.Equal(t, tc.expectedTypes, extracted) + + if len(tc.decoders) == 0 { + return + } + + // ExtractRecords with a set of record producers should + // only return the types that weren't in the passed-in + // set of producers. + extracted, err = opaqueData.ExtractRecords( + tc.decoders..., + ) + require.NoError(t, err) + + for parsedType := range tc.expectedTypes { + remainder, ok := extracted[parsedType] + require.True(t, ok) + require.Nil(t, remainder) + } + + for _, dec := range tc.decoders { + //nolint:forcetypeassert + dec := dec.(*dummyRecordProducer) + require.Equal( + t, dec.expectedValue, dec.scratchValue, + ) + } + }) + } +} + +// TestExtractAndMerge tests that the ParseAndExtractCustomRecords and +// MergeAndEncode functions work as expected. +func TestExtractAndMerge(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + knownRecords []tlv.RecordProducer + extraData ExtraOpaqueData + customRecords CustomRecords + expectedErr string + expectEncoded []byte + }{ + { + name: "invalid custom record", + customRecords: CustomRecords{ + 123: []byte("invalid"), + }, + expectedErr: "custom records validation error", + }, + { + name: "empty everything", + }, + { + name: "just extra data", + extraData: ExtraOpaqueData{ + 0x01, 0x03, 1, 2, 3, + 0x02, 0x03, 4, 5, 6, + }, + expectEncoded: []byte{ + 0x01, 0x03, 1, 2, 3, + 0x02, 0x03, 4, 5, 6, + }, + }, + { + name: "extra data with known record", + extraData: ExtraOpaqueData{ + 0x04, 0x03, 4, 4, 4, + 0x05, 0x03, 5, 5, 5, + }, + knownRecords: []tlv.RecordProducer{ + &dummyRecordProducer{ + typ: tlvType1.TypeVal(), + scratchValue: []byte{1, 2, 3}, + expectedValue: []byte{1, 2, 3}, + }, + &dummyRecordProducer{ + typ: tlvType2.TypeVal(), + scratchValue: []byte{4, 5, 6}, + expectedValue: []byte{4, 5, 6}, + }, + }, + expectEncoded: []byte{ + 0x01, 0x03, 1, 2, 3, + 0x02, 0x03, 4, 5, 6, + 0x04, 0x03, 4, 4, 4, + 0x05, 0x03, 5, 5, 5, + }, + }, + { + name: "extra data and custom records with known record", + extraData: ExtraOpaqueData{ + 0x04, 0x03, 4, 4, 4, + 0x05, 0x03, 5, 5, 5, + }, + customRecords: CustomRecords{ + MinCustomRecordsTlvType + 1: []byte{99, 99, 99}, + }, + knownRecords: []tlv.RecordProducer{ + &dummyRecordProducer{ + typ: tlvType1.TypeVal(), + scratchValue: []byte{1, 2, 3}, + expectedValue: []byte{1, 2, 3}, + }, + &dummyRecordProducer{ + typ: tlvType2.TypeVal(), + scratchValue: []byte{4, 5, 6}, + expectedValue: []byte{4, 5, 6}, + }, + }, + expectEncoded: []byte{ + 0x01, 0x03, 1, 2, 3, + 0x02, 0x03, 4, 5, 6, + 0x04, 0x03, 4, 4, 4, + 0x05, 0x03, 5, 5, 5, + 0xfe, 0x0, 0x1, 0x0, 0x1, 0x3, 0x63, 0x63, 0x63, + }, + }, + { + name: "duplicate records", + extraData: ExtraOpaqueData{ + 0x01, 0x03, 4, 4, 4, + 0x05, 0x03, 5, 5, 5, + }, + customRecords: CustomRecords{ + MinCustomRecordsTlvType + 1: []byte{99, 99, 99}, + }, + knownRecords: []tlv.RecordProducer{ + &dummyRecordProducer{ + typ: tlvType1.TypeVal(), + scratchValue: []byte{1, 2, 3}, + expectedValue: []byte{1, 2, 3}, + }, + &dummyRecordProducer{ + typ: tlvType2.TypeVal(), + scratchValue: []byte{4, 5, 6}, + expectedValue: []byte{4, 5, 6}, + }, + }, + expectedErr: "duplicate record type: 1", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encoded, err := MergeAndEncode( + tc.knownRecords, tc.extraData, tc.customRecords, + ) + + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) + + return + } + + require.NoError(t, err) + require.Equal(t, tc.expectEncoded, encoded) + + // Clear all the scratch values, to make sure they're + // decoded from the data again. + for _, dec := range tc.knownRecords { + //nolint:forcetypeassert + dec := dec.(*dummyRecordProducer) + dec.scratchValue = nil + } + + pCR, pKR, pED, err := ParseAndExtractCustomRecords( + encoded, tc.knownRecords..., + ) + require.NoError(t, err) + + require.Equal(t, tc.customRecords, pCR) + require.Equal(t, tc.extraData, pED) + + for _, dec := range tc.knownRecords { + //nolint:forcetypeassert + dec := dec.(*dummyRecordProducer) + require.Equal( + t, dec.expectedValue, dec.scratchValue, + ) + + require.Contains(t, pKR, dec.typ) + } + }) + } +}