From 58ed8e751d2c855918c1a0c15f70aab75ef2f6c6 Mon Sep 17 00:00:00 2001 From: ffranr Date: Mon, 29 Apr 2024 11:28:22 +0100 Subject: [PATCH] lnwire: add type `CustomRecords` This commit introduces the `CustomRecords` type in the `lnwire` package, designed to hold arbitrary byte slices. Each entry in this map can associate with TLV type values that are greater than or equal to 65536. --- lnwire/custom_records.go | 176 ++++++++++++++++++++++++++++++ lnwire/custom_records_test.go | 198 ++++++++++++++++++++++++++++++++++ lnwire/encoding.go | 16 +++ lnwire/extra_bytes_test.go | 8 -- 4 files changed, 390 insertions(+), 8 deletions(-) create mode 100644 lnwire/custom_records.go create mode 100644 lnwire/custom_records_test.go diff --git a/lnwire/custom_records.go b/lnwire/custom_records.go new file mode 100644 index 000000000..1e1988842 --- /dev/null +++ b/lnwire/custom_records.go @@ -0,0 +1,176 @@ +package lnwire + +import ( + "bytes" + "fmt" + "sort" + + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // MinCustomRecordsTlvType is the minimum custom records TLV type as + // defined in BOLT 01. + MinCustomRecordsTlvType = 65536 +) + +// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types +// which must be greater than or equal to MinCustomRecordsTlvType. +type CustomRecords map[uint64][]byte + +// NewCustomRecords creates a new CustomRecords instance from a +// tlv.TypeMap. +func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) { + // Make comparisons in unit tests easy by returning nil if the map is + // empty. + if len(tlvMap) == 0 { + return nil, nil + } + + customRecords := make(CustomRecords, len(tlvMap)) + for k, v := range tlvMap { + customRecords[uint64(k)] = v + } + + // Validate the custom records. + err := customRecords.Validate() + if err != nil { + return nil, fmt.Errorf("custom records from tlv map "+ + "validation error: %w", err) + } + + return customRecords, nil +} + +// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob. +func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) { + stream, err := tlv.NewStream() + if err != nil { + return nil, fmt.Errorf("error creating stream: %w", err) + } + + typeMap, err := stream.DecodeWithParsedTypes(bytes.NewReader(b)) + if err != nil { + return nil, fmt.Errorf("error decoding HTLC record: %w", err) + } + + return NewCustomRecords(typeMap) +} + +// Validate checks that all custom records are in the custom type range. +func (c CustomRecords) Validate() error { + if c == nil { + return nil + } + + for key := range c { + if key < MinCustomRecordsTlvType { + return fmt.Errorf("custom records entry with TLV "+ + "type below min: %d", MinCustomRecordsTlvType) + } + } + + return nil +} + +// Copy returns a copy of the custom records. +func (c CustomRecords) Copy() CustomRecords { + if c == nil { + return nil + } + + customRecords := make(CustomRecords, len(c)) + for k, v := range c { + customRecords[k] = v + } + + return customRecords +} + +// ExtendRecordProducers extends the given records slice with the custom +// records. The resultant records slice will be sorted if the given records +// slice contains TLV types greater than or equal to MinCustomRecordsTlvType. +func (c CustomRecords) ExtendRecordProducers( + producers []tlv.RecordProducer) ([]tlv.RecordProducer, error) { + + // If the custom records are nil or empty, there is nothing to do. + if len(c) == 0 { + return producers, nil + } + + // Validate the custom records. + err := c.Validate() + if err != nil { + return nil, err + } + + // Ensure that the existing records slice TLV types are not also present + // in the custom records. If they are, the resultant extended records + // slice would erroneously contain duplicate TLV types. + for _, rp := range producers { + record := rp.Record() + recordTlvType := uint64(record.Type()) + + _, foundDuplicateTlvType := c[recordTlvType] + if foundDuplicateTlvType { + return nil, fmt.Errorf("custom records contains a TLV "+ + "type that is already present in the "+ + "existing records: %d", recordTlvType) + } + } + + // Convert the custom records map to a TLV record producer slice and + // append them to the exiting records slice. + crRecords := tlv.MapToRecords(c) + for _, record := range crRecords { + r := recordProducer{record} + producers = append(producers, &r) + } + + // If the records slice which was given as an argument included TLV + // values greater than or equal to the minimum custom records TLV type + // we will sort the extended records slice to ensure that it is ordered + // correctly. + sort.Slice(producers, func(i, j int) bool { + recordI := producers[i].Record() + recordJ := producers[j].Record() + return recordI.Type() < recordJ.Type() + }) + + return producers, nil +} + +// RecordProducers returns a slice of record producers for the custom records. +func (c CustomRecords) RecordProducers() []tlv.RecordProducer { + // If the custom records are nil or empty, return an empty slice. + if len(c) == 0 { + return nil + } + + // Convert the custom records map to a TLV record producer slice. + records := tlv.MapToRecords(c) + + // Convert the records to record producers. + producers := make([]tlv.RecordProducer, len(records)) + for i, record := range records { + producers[i] = &recordProducer{record} + } + + return producers +} + +// Serialize serializes the custom records into a byte slice. +func (c CustomRecords) Serialize() ([]byte, error) { + records := tlv.MapToRecords(c) + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, fmt.Errorf("error creating stream: %w", err) + } + + var b bytes.Buffer + if err := stream.Encode(&b); err != nil { + return nil, fmt.Errorf("error encoding custom records: %w", err) + } + + return b.Bytes(), nil +} diff --git a/lnwire/custom_records_test.go b/lnwire/custom_records_test.go new file mode 100644 index 000000000..0338d0159 --- /dev/null +++ b/lnwire/custom_records_test.go @@ -0,0 +1,198 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestCustomRecords tests the custom records serialization and deserialization, +// as well as copying and producing records. +func TestCustomRecords(t *testing.T) { + testCases := []struct { + name string + customTypes tlv.TypeMap + expectedRecords CustomRecords + expectedErr string + }{ + { + name: "empty custom records", + customTypes: tlv.TypeMap{}, + expectedRecords: nil, + }, + { + name: "custom record with invalid type", + customTypes: tlv.TypeMap{ + 123: []byte{1, 2, 3}, + }, + expectedErr: "TLV type below min: 65536", + }, + { + name: "valid custom record", + customTypes: tlv.TypeMap{ + 65536: []byte{1, 2, 3}, + }, + expectedRecords: map[uint64][]byte{ + 65536: {1, 2, 3}, + }, + }, + { + name: "valid custom records, wrong order", + customTypes: tlv.TypeMap{ + 65537: []byte{3, 4, 5}, + 65536: []byte{1, 2, 3}, + }, + expectedRecords: map[uint64][]byte{ + 65536: {1, 2, 3}, + 65537: {3, 4, 5}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + records, err := NewCustomRecords(tc.customTypes) + + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) + return + } + + require.NoError(t, err) + require.Equal(t, tc.expectedRecords, records) + + // Serialize, then parse the records again. + blob, err := records.Serialize() + require.NoError(t, err) + + parsedRecords, err := ParseCustomRecords(blob) + require.NoError(t, err) + + require.Equal(t, tc.expectedRecords, parsedRecords) + + // Copy() should also return the same records. + require.Equal( + t, tc.expectedRecords, parsedRecords.Copy(), + ) + + // RecordProducers() should also allow us to serialize + // the records again. + serializedProducers := serializeRecordProducers( + t, parsedRecords.RecordProducers(), + ) + + require.Equal(t, blob, serializedProducers) + }) + } +} + +// TestCustomRecordsExtendRecordProducers tests that we can extend a slice of +// record producers with custom records. +func TestCustomRecordsExtendRecordProducers(t *testing.T) { + testCases := []struct { + name string + existingTypes map[uint64][]byte + customRecords CustomRecords + expectedResult tlv.TypeMap + expectedErr string + }{ + { + name: "normal merge", + existingTypes: map[uint64][]byte{ + 123: {3, 4, 5}, + 345: {1, 2, 3}, + }, + customRecords: CustomRecords{ + 65536: {1, 2, 3}, + }, + expectedResult: tlv.TypeMap{ + 123: {3, 4, 5}, + 345: {1, 2, 3}, + 65536: {1, 2, 3}, + }, + }, + { + name: "duplicates", + existingTypes: map[uint64][]byte{ + 123: {3, 4, 5}, + 345: {1, 2, 3}, + 65536: {1, 2, 3}, + }, + customRecords: CustomRecords{ + 65536: {1, 2, 3}, + }, + expectedErr: "contains a TLV type that is already " + + "present in the existing records: 65536", + }, + { + name: "non custom type in custom records", + existingTypes: map[uint64][]byte{ + 123: {3, 4, 5}, + 345: {1, 2, 3}, + 65536: {1, 2, 3}, + }, + customRecords: CustomRecords{ + 123: {1, 2, 3}, + }, + expectedErr: "TLV type below min: 65536", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + nonCustomRecords := tlv.MapToRecords(tc.existingTypes) + nonCustomProducers := fn.Map( + func(r tlv.Record) tlv.RecordProducer { + return &recordProducer{r} + }, nonCustomRecords, + ) + + combined, err := tc.customRecords.ExtendRecordProducers( + nonCustomProducers, + ) + + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) + return + } + + require.NoError(t, err) + + serializedProducers := serializeRecordProducers( + t, combined, + ) + + stream, err := tlv.NewStream() + require.NoError(t, err) + + parsedMap, err := stream.DecodeWithParsedTypes( + bytes.NewReader(serializedProducers), + ) + require.NoError(t, err) + + require.Equal(t, tc.expectedResult, parsedMap) + }) + } +} + +// serializeRecordProducers is a helper function that serializes a slice of +// record producers into a byte slice. +func serializeRecordProducers(t *testing.T, + producers []tlv.RecordProducer) []byte { + + tlvRecords := fn.Map(func(p tlv.RecordProducer) tlv.Record { + return p.Record() + }, producers) + + stream, err := tlv.NewStream(tlvRecords...) + require.NoError(t, err) + + var b bytes.Buffer + err = stream.Encode(&b) + require.NoError(t, err) + + return b.Bytes() +} diff --git a/lnwire/encoding.go b/lnwire/encoding.go index e04b2b01d..72000f81d 100644 --- a/lnwire/encoding.go +++ b/lnwire/encoding.go @@ -1,5 +1,7 @@ package lnwire +import "github.com/lightningnetwork/lnd/tlv" + // QueryEncoding is an enum-like type that represents exactly how a set data is // encoded on the wire. type QueryEncoding uint8 @@ -15,3 +17,17 @@ const ( // NOTE: this should no longer be used or accepted. EncodingSortedZlib QueryEncoding = 1 ) + +// recordProducer is a simple helper struct that implements the +// tlv.RecordProducer interface. +type recordProducer struct { + record tlv.Record +} + +// Record returns the underlying record. +func (r *recordProducer) Record() tlv.Record { + return r.record +} + +// Ensure that recordProducer implements the tlv.RecordProducer interface. +var _ tlv.RecordProducer = (*recordProducer)(nil) diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go index fd9f28841..97a908bce 100644 --- a/lnwire/extra_bytes_test.go +++ b/lnwire/extra_bytes_test.go @@ -86,14 +86,6 @@ func TestExtraOpaqueDataEncodeDecode(t *testing.T) { } } -type recordProducer struct { - record tlv.Record -} - -func (r *recordProducer) Record() tlv.Record { - return r.record -} - // TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of // tlv.Records into a stream, and unpack them on the other side to obtain the // same set of records.