mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-29 23:21:12 +02:00
lnwire: add ExtraOpaqueData
helper functions and methods
Introduces a couple of new helper functions for both the ExtraOpaqueData and CustomRecords types along with new methods on the ExtraOpaqueData.
This commit is contained in:
@@ -3,8 +3,10 @@ package lnwire
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/fn"
|
||||||
"github.com/lightningnetwork/lnd/tlv"
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,12 +46,12 @@ func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) {
|
|||||||
|
|
||||||
// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob.
|
// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob.
|
||||||
func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) {
|
func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) {
|
||||||
stream, err := tlv.NewStream()
|
return ParseCustomRecordsFrom(bytes.NewReader(b))
|
||||||
if err != nil {
|
}
|
||||||
return nil, fmt.Errorf("error creating stream: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
typeMap, err := stream.DecodeWithParsedTypes(bytes.NewReader(b))
|
// ParseCustomRecordsFrom creates a new CustomRecords instance from a reader.
|
||||||
|
func ParseCustomRecordsFrom(r io.Reader) (CustomRecords, error) {
|
||||||
|
typeMap, err := DecodeRecords(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error decoding HTLC record: %w", err)
|
return nil, fmt.Errorf("error decoding HTLC record: %w", err)
|
||||||
}
|
}
|
||||||
@@ -121,21 +123,14 @@ func (c CustomRecords) ExtendRecordProducers(
|
|||||||
|
|
||||||
// Convert the custom records map to a TLV record producer slice and
|
// Convert the custom records map to a TLV record producer slice and
|
||||||
// append them to the exiting records slice.
|
// append them to the exiting records slice.
|
||||||
crRecords := tlv.MapToRecords(c)
|
customRecordProducers := RecordsAsProducers(tlv.MapToRecords(c))
|
||||||
for _, record := range crRecords {
|
producers = append(producers, customRecordProducers...)
|
||||||
r := recordProducer{record}
|
|
||||||
producers = append(producers, &r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the records slice which was given as an argument included TLV
|
// If the records slice which was given as an argument included TLV
|
||||||
// values greater than or equal to the minimum custom records TLV type
|
// values greater than or equal to the minimum custom records TLV type
|
||||||
// we will sort the extended records slice to ensure that it is ordered
|
// we will sort the extended records slice to ensure that it is ordered
|
||||||
// correctly.
|
// correctly.
|
||||||
sort.Slice(producers, func(i, j int) bool {
|
SortProducers(producers)
|
||||||
recordI := producers[i].Record()
|
|
||||||
recordJ := producers[j].Record()
|
|
||||||
return recordI.Type() < recordJ.Type()
|
|
||||||
})
|
|
||||||
|
|
||||||
return producers, nil
|
return producers, nil
|
||||||
}
|
}
|
||||||
@@ -150,27 +145,119 @@ func (c CustomRecords) RecordProducers() []tlv.RecordProducer {
|
|||||||
// Convert the custom records map to a TLV record producer slice.
|
// Convert the custom records map to a TLV record producer slice.
|
||||||
records := tlv.MapToRecords(c)
|
records := tlv.MapToRecords(c)
|
||||||
|
|
||||||
// Convert the records to record producers.
|
return RecordsAsProducers(records)
|
||||||
producers := make([]tlv.RecordProducer, len(records))
|
|
||||||
for i, record := range records {
|
|
||||||
producers[i] = &recordProducer{record}
|
|
||||||
}
|
|
||||||
|
|
||||||
return producers
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialize serializes the custom records into a byte slice.
|
// Serialize serializes the custom records into a byte slice.
|
||||||
func (c CustomRecords) Serialize() ([]byte, error) {
|
func (c CustomRecords) Serialize() ([]byte, error) {
|
||||||
records := tlv.MapToRecords(c)
|
records := tlv.MapToRecords(c)
|
||||||
stream, err := tlv.NewStream(records...)
|
return EncodeRecords(records)
|
||||||
if err != nil {
|
}
|
||||||
return nil, fmt.Errorf("error creating stream: %w", err)
|
|
||||||
}
|
// SerializeTo serializes the custom records into the given writer.
|
||||||
|
func (c CustomRecords) SerializeTo(w io.Writer) error {
|
||||||
var b bytes.Buffer
|
records := tlv.MapToRecords(c)
|
||||||
if err := stream.Encode(&b); err != nil {
|
return EncodeRecordsTo(w, records)
|
||||||
return nil, fmt.Errorf("error encoding custom records: %w", err)
|
}
|
||||||
}
|
|
||||||
|
// ProduceRecordsSorted converts a slice of record producers into a slice of
|
||||||
return b.Bytes(), nil
|
// records and then sorts it by type.
|
||||||
|
func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record {
|
||||||
|
records := fn.Map(func(producer tlv.RecordProducer) tlv.Record {
|
||||||
|
return producer.Record()
|
||||||
|
}, recordProducers)
|
||||||
|
|
||||||
|
// Ensure that the set of records are sorted before we attempt to
|
||||||
|
// decode from the stream, to ensure they're canonical.
|
||||||
|
tlv.SortRecords(records)
|
||||||
|
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
|
||||||
|
// SortProducers sorts the given record producers by their type.
|
||||||
|
func SortProducers(producers []tlv.RecordProducer) {
|
||||||
|
sort.Slice(producers, func(i, j int) bool {
|
||||||
|
recordI := producers[i].Record()
|
||||||
|
recordJ := producers[j].Record()
|
||||||
|
return recordI.Type() < recordJ.Type()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TlvMapToRecords converts a TLV map into a slice of records.
|
||||||
|
func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record {
|
||||||
|
tlvMapGeneric := make(map[uint64][]byte)
|
||||||
|
for k, v := range tlvMap {
|
||||||
|
tlvMapGeneric[uint64(k)] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlv.MapToRecords(tlvMapGeneric)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordsAsProducers converts a slice of records into a slice of record
|
||||||
|
// producers.
|
||||||
|
func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer {
|
||||||
|
return fn.Map(func(record tlv.Record) tlv.RecordProducer {
|
||||||
|
return &record
|
||||||
|
}, records)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeRecords encodes the given records into a byte slice.
|
||||||
|
func EncodeRecords(records []tlv.Record) ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := EncodeRecordsTo(&buf, records); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeRecordsTo encodes the given records into the given writer.
|
||||||
|
func EncodeRecordsTo(w io.Writer, records []tlv.Record) error {
|
||||||
|
tlvStream, err := tlv.NewStream(records...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlvStream.Encode(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeRecords decodes the given byte slice into the given records and returns
|
||||||
|
// the rest as a TLV type map.
|
||||||
|
func DecodeRecords(r io.Reader,
|
||||||
|
records ...tlv.Record) (tlv.TypeMap, error) {
|
||||||
|
|
||||||
|
tlvStream, err := tlv.NewStream(records...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlvStream.DecodeWithParsedTypes(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeRecordsP2P decodes the given byte slice into the given records and
|
||||||
|
// returns the rest as a TLV type map. This function is identical to
|
||||||
|
// DecodeRecords except that the record size is capped at 65535.
|
||||||
|
func DecodeRecordsP2P(r *bytes.Reader,
|
||||||
|
records ...tlv.Record) (tlv.TypeMap, error) {
|
||||||
|
|
||||||
|
tlvStream, err := tlv.NewStream(records...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlvStream.DecodeWithParsedTypesP2P(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertUniqueTypes asserts that the given records have unique types.
|
||||||
|
func AssertUniqueTypes(r []tlv.Record) error {
|
||||||
|
seen := make(fn.Set[tlv.Type], len(r))
|
||||||
|
for _, record := range r {
|
||||||
|
t := record.Type()
|
||||||
|
if seen.Contains(t) {
|
||||||
|
return fmt.Errorf("duplicate record type: %d", t)
|
||||||
|
}
|
||||||
|
seen.Add(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -144,10 +144,8 @@ func TestCustomRecordsExtendRecordProducers(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
nonCustomRecords := tlv.MapToRecords(tc.existingTypes)
|
nonCustomRecords := tlv.MapToRecords(tc.existingTypes)
|
||||||
nonCustomProducers := fn.Map(
|
nonCustomProducers := RecordsAsProducers(
|
||||||
func(r tlv.Record) tlv.RecordProducer {
|
nonCustomRecords,
|
||||||
return &recordProducer{r}
|
|
||||||
}, nonCustomRecords,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
combined, err := tc.customRecords.ExtendRecordProducers(
|
combined, err := tc.customRecords.ExtendRecordProducers(
|
||||||
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/fn"
|
||||||
"github.com/lightningnetwork/lnd/tlv"
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,6 +16,21 @@ import (
|
|||||||
// upgrades to the network in a forwards compatible manner.
|
// upgrades to the network in a forwards compatible manner.
|
||||||
type ExtraOpaqueData []byte
|
type ExtraOpaqueData []byte
|
||||||
|
|
||||||
|
// NewExtraOpaqueData creates a new ExtraOpaqueData instance from a tlv.TypeMap.
|
||||||
|
func NewExtraOpaqueData(tlvMap tlv.TypeMap) (ExtraOpaqueData, error) {
|
||||||
|
// If the tlv map is empty, we'll want to mirror the behavior of
|
||||||
|
// decoding an empty extra opaque data field (see Decode method).
|
||||||
|
if len(tlvMap) == 0 {
|
||||||
|
return make([]byte, 0), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the TLV map into a slice of records.
|
||||||
|
records := TlvMapToRecords(tlvMap)
|
||||||
|
|
||||||
|
// Encode the records into the extra data byte slice.
|
||||||
|
return EncodeRecords(records)
|
||||||
|
}
|
||||||
|
|
||||||
// Encode attempts to encode the raw extra bytes into the passed io.Writer.
|
// Encode attempts to encode the raw extra bytes into the passed io.Writer.
|
||||||
func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error {
|
func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error {
|
||||||
eBytes := []byte((*e)[:])
|
eBytes := []byte((*e)[:])
|
||||||
@@ -25,8 +41,8 @@ func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode attempts to unpack the raw bytes encoded in the passed io.Reader as a
|
// Decode attempts to unpack the raw bytes encoded in the passed-in io.Reader as
|
||||||
// set of extra opaque data.
|
// a set of extra opaque data.
|
||||||
func (e *ExtraOpaqueData) Decode(r io.Reader) error {
|
func (e *ExtraOpaqueData) Decode(r io.Reader) error {
|
||||||
// First, we'll attempt to read a set of bytes contained within the
|
// First, we'll attempt to read a set of bytes contained within the
|
||||||
// passed io.Reader (if any exist).
|
// passed io.Reader (if any exist).
|
||||||
@@ -39,7 +55,7 @@ func (e *ExtraOpaqueData) Decode(r io.Reader) error {
|
|||||||
// This ensures that any struct that embeds this type will properly
|
// This ensures that any struct that embeds this type will properly
|
||||||
// store the bytes once this method exits.
|
// store the bytes once this method exits.
|
||||||
if len(rawBytes) > 0 {
|
if len(rawBytes) > 0 {
|
||||||
*e = ExtraOpaqueData(rawBytes)
|
*e = rawBytes
|
||||||
} else {
|
} else {
|
||||||
*e = make([]byte, 0)
|
*e = make([]byte, 0)
|
||||||
}
|
}
|
||||||
@@ -50,28 +66,17 @@ func (e *ExtraOpaqueData) Decode(r io.Reader) error {
|
|||||||
// PackRecords attempts to encode the set of tlv records into the target
|
// PackRecords attempts to encode the set of tlv records into the target
|
||||||
// ExtraOpaqueData instance. The records will be encoded as a raw TLV stream
|
// ExtraOpaqueData instance. The records will be encoded as a raw TLV stream
|
||||||
// and stored within the backing slice pointer.
|
// and stored within the backing slice pointer.
|
||||||
func (e *ExtraOpaqueData) PackRecords(recordProducers ...tlv.RecordProducer) error {
|
func (e *ExtraOpaqueData) PackRecords(
|
||||||
// First, assemble all the records passed in in series.
|
recordProducers ...tlv.RecordProducer) error {
|
||||||
records := make([]tlv.Record, 0, len(recordProducers))
|
|
||||||
for _, producer := range recordProducers {
|
|
||||||
records = append(records, producer.Record())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that the set of records are sorted before we encode them into
|
// Assemble all the records passed in series, then encode them.
|
||||||
// the stream, to ensure they're canonical.
|
records := ProduceRecordsSorted(recordProducers...)
|
||||||
tlv.SortRecords(records)
|
encoded, err := EncodeRecords(records)
|
||||||
|
|
||||||
tlvStream, err := tlv.NewStream(records...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var extraBytesWriter bytes.Buffer
|
*e = encoded
|
||||||
if err := tlvStream.Encode(&extraBytesWriter); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
*e = ExtraOpaqueData(extraBytesWriter.Bytes())
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -80,29 +85,38 @@ func (e *ExtraOpaqueData) PackRecords(recordProducers ...tlv.RecordProducer) err
|
|||||||
// it were a tlv stream. The set of raw parsed types is returned, and any
|
// it were a tlv stream. The set of raw parsed types is returned, and any
|
||||||
// passed records (if found in the stream) will be parsed into the proper
|
// passed records (if found in the stream) will be parsed into the proper
|
||||||
// tlv.Record.
|
// tlv.Record.
|
||||||
func (e *ExtraOpaqueData) ExtractRecords(recordProducers ...tlv.RecordProducer) (
|
func (e *ExtraOpaqueData) ExtractRecords(
|
||||||
tlv.TypeMap, error) {
|
recordProducers ...tlv.RecordProducer) (tlv.TypeMap, error) {
|
||||||
|
|
||||||
// First, assemble all the records passed in in series.
|
|
||||||
records := make([]tlv.Record, 0, len(recordProducers))
|
|
||||||
for _, producer := range recordProducers {
|
|
||||||
records = append(records, producer.Record())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that the set of records are sorted before we attempt to
|
|
||||||
// decode from the stream, to ensure they're canonical.
|
|
||||||
tlv.SortRecords(records)
|
|
||||||
|
|
||||||
|
// First, assemble all the records passed in series.
|
||||||
|
records := ProduceRecordsSorted(recordProducers...)
|
||||||
extraBytesReader := bytes.NewReader(*e)
|
extraBytesReader := bytes.NewReader(*e)
|
||||||
|
|
||||||
tlvStream, err := tlv.NewStream(records...)
|
// Since ExtraOpaqueData is provided by a potentially malicious peer,
|
||||||
|
// pass it into the P2P decoding variant.
|
||||||
|
return DecodeRecordsP2P(extraBytesReader, records...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordProducers parses ExtraOpaqueData into a slice of TLV record producers
|
||||||
|
// by interpreting it as a TLV map.
|
||||||
|
func (e *ExtraOpaqueData) RecordProducers() ([]tlv.RecordProducer, error) {
|
||||||
|
var recordProducers []tlv.RecordProducer
|
||||||
|
|
||||||
|
// If the instance is nil or empty, return an empty slice.
|
||||||
|
if e == nil || len(*e) == 0 {
|
||||||
|
return recordProducers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the extra opaque data as a TLV map.
|
||||||
|
tlvMap, err := e.ExtractRecords()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Since ExtraOpaqueData is provided by a potentially malicious peer,
|
// Convert the TLV map into a slice of record producers.
|
||||||
// pass it into the P2P decoding variant.
|
records := TlvMapToRecords(tlvMap)
|
||||||
return tlvStream.DecodeWithParsedTypesP2P(extraBytesReader)
|
|
||||||
|
return RecordsAsProducers(records), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncodeMessageExtraData encodes the given recordProducers into the given
|
// EncodeMessageExtraData encodes the given recordProducers into the given
|
||||||
@@ -120,3 +134,116 @@ func EncodeMessageExtraData(extraData *ExtraOpaqueData,
|
|||||||
// are all properly sorted.
|
// are all properly sorted.
|
||||||
return extraData.PackRecords(recordProducers...)
|
return extraData.PackRecords(recordProducers...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseAndExtractCustomRecords parses the given extra data into the passed-in
|
||||||
|
// records, then returns any remaining records split into custom records and
|
||||||
|
// extra data.
|
||||||
|
func ParseAndExtractCustomRecords(allExtraData ExtraOpaqueData,
|
||||||
|
knownRecords ...tlv.RecordProducer) (CustomRecords,
|
||||||
|
fn.Set[tlv.Type], ExtraOpaqueData, error) {
|
||||||
|
|
||||||
|
extraDataTlvMap, err := allExtraData.ExtractRecords(knownRecords...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the known and now extracted records from the leftover extra
|
||||||
|
// data map.
|
||||||
|
parsedKnownRecords := make(fn.Set[tlv.Type], len(knownRecords))
|
||||||
|
for _, producer := range knownRecords {
|
||||||
|
r := producer.Record()
|
||||||
|
|
||||||
|
// Only remove the records if it was parsed (remainder is nil).
|
||||||
|
// We'll just store the type so we can tell the caller which
|
||||||
|
// records were actually parsed fully.
|
||||||
|
val, ok := extraDataTlvMap[r.Type()]
|
||||||
|
if ok && val == nil {
|
||||||
|
parsedKnownRecords.Add(r.Type())
|
||||||
|
delete(extraDataTlvMap, r.Type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Any records from the extra data TLV map which are in the custom
|
||||||
|
// records TLV type range will be included in the custom records field
|
||||||
|
// and removed from the extra data field.
|
||||||
|
customRecordsTlvMap := make(tlv.TypeMap, len(extraDataTlvMap))
|
||||||
|
for k, v := range extraDataTlvMap {
|
||||||
|
// Skip records that are not in the custom records TLV type
|
||||||
|
// range.
|
||||||
|
if k < MinCustomRecordsTlvType {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include the record in the custom records map.
|
||||||
|
customRecordsTlvMap[k] = v
|
||||||
|
|
||||||
|
// Now that the record is included in the custom records map,
|
||||||
|
// we can remove it from the extra data TLV map.
|
||||||
|
delete(extraDataTlvMap, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the custom records field to the custom records specific TLV
|
||||||
|
// record map.
|
||||||
|
customRecords, err := NewCustomRecords(customRecordsTlvMap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode the remaining records back into the extra data field. These
|
||||||
|
// records are not in the custom records TLV type range and do not
|
||||||
|
// have associated fields in the struct that produced the records.
|
||||||
|
extraData, err := NewExtraOpaqueData(extraDataTlvMap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Help with unit testing where we might have the empty value (nil) for
|
||||||
|
// the extra data instead of the default that's returned by the
|
||||||
|
// constructor (empty slice).
|
||||||
|
if len(extraData) == 0 {
|
||||||
|
extraData = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return customRecords, parsedKnownRecords, extraData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeAndEncode merges the known records with the extra data and custom
|
||||||
|
// records, then encodes the merged records into raw bytes.
|
||||||
|
func MergeAndEncode(knownRecords []tlv.RecordProducer,
|
||||||
|
extraData ExtraOpaqueData, customRecords CustomRecords) ([]byte,
|
||||||
|
error) {
|
||||||
|
|
||||||
|
// Construct a slice of all the records that we should include in the
|
||||||
|
// message extra data field. We will start by including any records from
|
||||||
|
// the extra data field.
|
||||||
|
mergedRecords, err := extraData.RecordProducers()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge the known and extra data records.
|
||||||
|
mergedRecords = append(mergedRecords, knownRecords...)
|
||||||
|
|
||||||
|
// Include custom records in the extra data wire field if they are
|
||||||
|
// present. Ensure that the custom records are validated before encoding
|
||||||
|
// them.
|
||||||
|
if err := customRecords.Validate(); err != nil {
|
||||||
|
return nil, fmt.Errorf("custom records validation error: %w",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend the message extra data records slice with TLV records from the
|
||||||
|
// custom records field.
|
||||||
|
mergedRecords = append(
|
||||||
|
mergedRecords, customRecords.RecordProducers()...,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Now we can sort the records and make sure there are no records with
|
||||||
|
// the same type that would collide when encoding.
|
||||||
|
sortedRecords := ProduceRecordsSorted(mergedRecords...)
|
||||||
|
if err := AssertUniqueTypes(sortedRecords); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return EncodeRecords(sortedRecords)
|
||||||
|
}
|
||||||
|
@@ -11,6 +11,12 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
tlvType1 tlv.TlvType1
|
||||||
|
tlvType2 tlv.TlvType2
|
||||||
|
tlvType3 tlv.TlvType3
|
||||||
|
)
|
||||||
|
|
||||||
// TestExtraOpaqueDataEncodeDecode tests that we're able to encode/decode
|
// TestExtraOpaqueDataEncodeDecode tests that we're able to encode/decode
|
||||||
// arbitrary payloads.
|
// arbitrary payloads.
|
||||||
func TestExtraOpaqueDataEncodeDecode(t *testing.T) {
|
func TestExtraOpaqueDataEncodeDecode(t *testing.T) {
|
||||||
@@ -153,21 +159,18 @@ func TestPackRecords(t *testing.T) {
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
// Record type 1.
|
// Record type 1.
|
||||||
tlvType1 tlv.TlvType1
|
|
||||||
recordBytes1 = []byte("recordBytes1")
|
recordBytes1 = []byte("recordBytes1")
|
||||||
tlvRecord1 = tlv.NewPrimitiveRecord[tlv.TlvType1](
|
tlvRecord1 = tlv.NewPrimitiveRecord[tlv.TlvType1](
|
||||||
recordBytes1,
|
recordBytes1,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Record type 2.
|
// Record type 2.
|
||||||
tlvType2 tlv.TlvType2
|
|
||||||
recordBytes2 = []byte("recordBytes2")
|
recordBytes2 = []byte("recordBytes2")
|
||||||
tlvRecord2 = tlv.NewPrimitiveRecord[tlv.TlvType2](
|
tlvRecord2 = tlv.NewPrimitiveRecord[tlv.TlvType2](
|
||||||
recordBytes2,
|
recordBytes2,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Record type 3.
|
// Record type 3.
|
||||||
tlvType3 tlv.TlvType3
|
|
||||||
recordBytes3 = []byte("recordBytes3")
|
recordBytes3 = []byte("recordBytes3")
|
||||||
tlvRecord3 = tlv.NewPrimitiveRecord[tlv.TlvType3](
|
tlvRecord3 = tlv.NewPrimitiveRecord[tlv.TlvType3](
|
||||||
recordBytes3,
|
recordBytes3,
|
||||||
@@ -203,3 +206,308 @@ func TestPackRecords(t *testing.T) {
|
|||||||
require.Equal(t, recordBytes2, extractedRecords[tlvType2.TypeVal()])
|
require.Equal(t, recordBytes2, extractedRecords[tlvType2.TypeVal()])
|
||||||
require.Equal(t, recordBytes3, extractedRecords[tlvType3.TypeVal()])
|
require.Equal(t, recordBytes3, extractedRecords[tlvType3.TypeVal()])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type dummyRecordProducer struct {
|
||||||
|
typ tlv.Type
|
||||||
|
scratchValue []byte
|
||||||
|
expectedValue []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dummyRecordProducer) Record() tlv.Record {
|
||||||
|
return tlv.MakePrimitiveRecord(d.typ, &d.scratchValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtraOpaqueData tests that we're able to properly encode/decode an
|
||||||
|
// ExtraOpaqueData instance.
|
||||||
|
func TestExtraOpaqueData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
types tlv.TypeMap
|
||||||
|
expectedData ExtraOpaqueData
|
||||||
|
expectedTypes tlv.TypeMap
|
||||||
|
decoders []tlv.RecordProducer
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty map",
|
||||||
|
expectedTypes: tlv.TypeMap{},
|
||||||
|
expectedData: make([]byte, 0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single record",
|
||||||
|
types: tlv.TypeMap{
|
||||||
|
tlvType1.TypeVal(): []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
expectedData: ExtraOpaqueData{
|
||||||
|
0x01, 0x03, 1, 2, 3,
|
||||||
|
},
|
||||||
|
expectedTypes: tlv.TypeMap{
|
||||||
|
tlvType1.TypeVal(): []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
decoders: []tlv.RecordProducer{
|
||||||
|
&dummyRecordProducer{
|
||||||
|
typ: tlvType1.TypeVal(),
|
||||||
|
expectedValue: []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple records",
|
||||||
|
types: tlv.TypeMap{
|
||||||
|
tlvType2.TypeVal(): []byte{4, 5, 6},
|
||||||
|
tlvType1.TypeVal(): []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
expectedData: ExtraOpaqueData{
|
||||||
|
0x01, 0x03, 1, 2, 3,
|
||||||
|
0x02, 0x03, 4, 5, 6,
|
||||||
|
},
|
||||||
|
expectedTypes: tlv.TypeMap{
|
||||||
|
tlvType1.TypeVal(): []byte{1, 2, 3},
|
||||||
|
tlvType2.TypeVal(): []byte{4, 5, 6},
|
||||||
|
},
|
||||||
|
decoders: []tlv.RecordProducer{
|
||||||
|
&dummyRecordProducer{
|
||||||
|
typ: tlvType1.TypeVal(),
|
||||||
|
expectedValue: []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
&dummyRecordProducer{
|
||||||
|
typ: tlvType2.TypeVal(),
|
||||||
|
expectedValue: []byte{4, 5, 6},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// First, test the constructor.
|
||||||
|
opaqueData, err := NewExtraOpaqueData(tc.types)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, tc.expectedData, opaqueData)
|
||||||
|
|
||||||
|
// Now encode/decode.
|
||||||
|
var b bytes.Buffer
|
||||||
|
err = opaqueData.Encode(&b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var decoded ExtraOpaqueData
|
||||||
|
err = decoded.Decode(&b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, opaqueData, decoded)
|
||||||
|
|
||||||
|
// Now RecordProducers/PackRecords.
|
||||||
|
producers, err := opaqueData.RecordProducers()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var packed ExtraOpaqueData
|
||||||
|
err = packed.PackRecords(producers...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// PackRecords returns nil vs. an empty slice if there
|
||||||
|
// are no records. We need to handle this case
|
||||||
|
// separately.
|
||||||
|
if len(producers) == 0 {
|
||||||
|
// Make sure the packed data is empty.
|
||||||
|
require.Empty(t, packed)
|
||||||
|
|
||||||
|
// Now change it to an empty slice for the
|
||||||
|
// comparison below.
|
||||||
|
packed = make([]byte, 0)
|
||||||
|
}
|
||||||
|
require.Equal(t, opaqueData, packed)
|
||||||
|
|
||||||
|
// ExtractRecords with an empty set of record producers
|
||||||
|
// should return the original type map.
|
||||||
|
extracted, err := opaqueData.ExtractRecords()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, tc.expectedTypes, extracted)
|
||||||
|
|
||||||
|
if len(tc.decoders) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractRecords with a set of record producers should
|
||||||
|
// only return the types that weren't in the passed-in
|
||||||
|
// set of producers.
|
||||||
|
extracted, err = opaqueData.ExtractRecords(
|
||||||
|
tc.decoders...,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for parsedType := range tc.expectedTypes {
|
||||||
|
remainder, ok := extracted[parsedType]
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Nil(t, remainder)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range tc.decoders {
|
||||||
|
//nolint:forcetypeassert
|
||||||
|
dec := dec.(*dummyRecordProducer)
|
||||||
|
require.Equal(
|
||||||
|
t, dec.expectedValue, dec.scratchValue,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractAndMerge tests that the ParseAndExtractCustomRecords and
|
||||||
|
// MergeAndEncode functions work as expected.
|
||||||
|
func TestExtractAndMerge(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
knownRecords []tlv.RecordProducer
|
||||||
|
extraData ExtraOpaqueData
|
||||||
|
customRecords CustomRecords
|
||||||
|
expectedErr string
|
||||||
|
expectEncoded []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "invalid custom record",
|
||||||
|
customRecords: CustomRecords{
|
||||||
|
123: []byte("invalid"),
|
||||||
|
},
|
||||||
|
expectedErr: "custom records validation error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty everything",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "just extra data",
|
||||||
|
extraData: ExtraOpaqueData{
|
||||||
|
0x01, 0x03, 1, 2, 3,
|
||||||
|
0x02, 0x03, 4, 5, 6,
|
||||||
|
},
|
||||||
|
expectEncoded: []byte{
|
||||||
|
0x01, 0x03, 1, 2, 3,
|
||||||
|
0x02, 0x03, 4, 5, 6,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra data with known record",
|
||||||
|
extraData: ExtraOpaqueData{
|
||||||
|
0x04, 0x03, 4, 4, 4,
|
||||||
|
0x05, 0x03, 5, 5, 5,
|
||||||
|
},
|
||||||
|
knownRecords: []tlv.RecordProducer{
|
||||||
|
&dummyRecordProducer{
|
||||||
|
typ: tlvType1.TypeVal(),
|
||||||
|
scratchValue: []byte{1, 2, 3},
|
||||||
|
expectedValue: []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
&dummyRecordProducer{
|
||||||
|
typ: tlvType2.TypeVal(),
|
||||||
|
scratchValue: []byte{4, 5, 6},
|
||||||
|
expectedValue: []byte{4, 5, 6},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectEncoded: []byte{
|
||||||
|
0x01, 0x03, 1, 2, 3,
|
||||||
|
0x02, 0x03, 4, 5, 6,
|
||||||
|
0x04, 0x03, 4, 4, 4,
|
||||||
|
0x05, 0x03, 5, 5, 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra data and custom records with known record",
|
||||||
|
extraData: ExtraOpaqueData{
|
||||||
|
0x04, 0x03, 4, 4, 4,
|
||||||
|
0x05, 0x03, 5, 5, 5,
|
||||||
|
},
|
||||||
|
customRecords: CustomRecords{
|
||||||
|
MinCustomRecordsTlvType + 1: []byte{99, 99, 99},
|
||||||
|
},
|
||||||
|
knownRecords: []tlv.RecordProducer{
|
||||||
|
&dummyRecordProducer{
|
||||||
|
typ: tlvType1.TypeVal(),
|
||||||
|
scratchValue: []byte{1, 2, 3},
|
||||||
|
expectedValue: []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
&dummyRecordProducer{
|
||||||
|
typ: tlvType2.TypeVal(),
|
||||||
|
scratchValue: []byte{4, 5, 6},
|
||||||
|
expectedValue: []byte{4, 5, 6},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectEncoded: []byte{
|
||||||
|
0x01, 0x03, 1, 2, 3,
|
||||||
|
0x02, 0x03, 4, 5, 6,
|
||||||
|
0x04, 0x03, 4, 4, 4,
|
||||||
|
0x05, 0x03, 5, 5, 5,
|
||||||
|
0xfe, 0x0, 0x1, 0x0, 0x1, 0x3, 0x63, 0x63, 0x63,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "duplicate records",
|
||||||
|
extraData: ExtraOpaqueData{
|
||||||
|
0x01, 0x03, 4, 4, 4,
|
||||||
|
0x05, 0x03, 5, 5, 5,
|
||||||
|
},
|
||||||
|
customRecords: CustomRecords{
|
||||||
|
MinCustomRecordsTlvType + 1: []byte{99, 99, 99},
|
||||||
|
},
|
||||||
|
knownRecords: []tlv.RecordProducer{
|
||||||
|
&dummyRecordProducer{
|
||||||
|
typ: tlvType1.TypeVal(),
|
||||||
|
scratchValue: []byte{1, 2, 3},
|
||||||
|
expectedValue: []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
&dummyRecordProducer{
|
||||||
|
typ: tlvType2.TypeVal(),
|
||||||
|
scratchValue: []byte{4, 5, 6},
|
||||||
|
expectedValue: []byte{4, 5, 6},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: "duplicate record type: 1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
encoded, err := MergeAndEncode(
|
||||||
|
tc.knownRecords, tc.extraData, tc.customRecords,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tc.expectedErr != "" {
|
||||||
|
require.ErrorContains(t, err, tc.expectedErr)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tc.expectEncoded, encoded)
|
||||||
|
|
||||||
|
// Clear all the scratch values, to make sure they're
|
||||||
|
// decoded from the data again.
|
||||||
|
for _, dec := range tc.knownRecords {
|
||||||
|
//nolint:forcetypeassert
|
||||||
|
dec := dec.(*dummyRecordProducer)
|
||||||
|
dec.scratchValue = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pCR, pKR, pED, err := ParseAndExtractCustomRecords(
|
||||||
|
encoded, tc.knownRecords...,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, tc.customRecords, pCR)
|
||||||
|
require.Equal(t, tc.extraData, pED)
|
||||||
|
|
||||||
|
for _, dec := range tc.knownRecords {
|
||||||
|
//nolint:forcetypeassert
|
||||||
|
dec := dec.(*dummyRecordProducer)
|
||||||
|
require.Equal(
|
||||||
|
t, dec.expectedValue, dec.scratchValue,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Contains(t, pKR, dec.typ)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user