From 82ae0220c86585b7183f3a20a493d8c9b4578a18 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 9 Sep 2024 11:18:56 +0200 Subject: [PATCH] lnwire21: add custom records parsing We add the new custom records encoding/decoding logic to the "frozen" lnwire21 package. We can do this because nothing uses this logic yet. If the custom records logic changes, the changes should _not_ be added to the lnwire21 version. --- .../migration/lnwire21/custom_records.go | 263 ++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 channeldb/migration/lnwire21/custom_records.go diff --git a/channeldb/migration/lnwire21/custom_records.go b/channeldb/migration/lnwire21/custom_records.go new file mode 100644 index 000000000..f0f59185e --- /dev/null +++ b/channeldb/migration/lnwire21/custom_records.go @@ -0,0 +1,263 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + "sort" + + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // MinCustomRecordsTlvType is the minimum custom records TLV type as + // defined in BOLT 01. + MinCustomRecordsTlvType = 65536 +) + +// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types +// which must be greater than or equal to MinCustomRecordsTlvType. +type CustomRecords map[uint64][]byte + +// NewCustomRecords creates a new CustomRecords instance from a +// tlv.TypeMap. +func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) { + // Make comparisons in unit tests easy by returning nil if the map is + // empty. + if len(tlvMap) == 0 { + return nil, nil + } + + customRecords := make(CustomRecords, len(tlvMap)) + for k, v := range tlvMap { + customRecords[uint64(k)] = v + } + + // Validate the custom records. + err := customRecords.Validate() + if err != nil { + return nil, fmt.Errorf("custom records from tlv map "+ + "validation error: %w", err) + } + + return customRecords, nil +} + +// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob. +func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) { + return ParseCustomRecordsFrom(bytes.NewReader(b)) +} + +// ParseCustomRecordsFrom creates a new CustomRecords instance from a reader. +func ParseCustomRecordsFrom(r io.Reader) (CustomRecords, error) { + typeMap, err := DecodeRecords(r) + if err != nil { + return nil, fmt.Errorf("error decoding HTLC record: %w", err) + } + + return NewCustomRecords(typeMap) +} + +// Validate checks that all custom records are in the custom type range. +func (c CustomRecords) Validate() error { + if c == nil { + return nil + } + + for key := range c { + if key < MinCustomRecordsTlvType { + return fmt.Errorf("custom records entry with TLV "+ + "type below min: %d", MinCustomRecordsTlvType) + } + } + + return nil +} + +// Copy returns a copy of the custom records. +func (c CustomRecords) Copy() CustomRecords { + if c == nil { + return nil + } + + customRecords := make(CustomRecords, len(c)) + for k, v := range c { + customRecords[k] = v + } + + return customRecords +} + +// ExtendRecordProducers extends the given records slice with the custom +// records. The resultant records slice will be sorted if the given records +// slice contains TLV types greater than or equal to MinCustomRecordsTlvType. +func (c CustomRecords) ExtendRecordProducers( + producers []tlv.RecordProducer) ([]tlv.RecordProducer, error) { + + // If the custom records are nil or empty, there is nothing to do. + if len(c) == 0 { + return producers, nil + } + + // Validate the custom records. + err := c.Validate() + if err != nil { + return nil, err + } + + // Ensure that the existing records slice TLV types are not also present + // in the custom records. If they are, the resultant extended records + // slice would erroneously contain duplicate TLV types. + for _, rp := range producers { + record := rp.Record() + recordTlvType := uint64(record.Type()) + + _, foundDuplicateTlvType := c[recordTlvType] + if foundDuplicateTlvType { + return nil, fmt.Errorf("custom records contains a TLV "+ + "type that is already present in the "+ + "existing records: %d", recordTlvType) + } + } + + // Convert the custom records map to a TLV record producer slice and + // append them to the exiting records slice. + customRecordProducers := RecordsAsProducers(tlv.MapToRecords(c)) + producers = append(producers, customRecordProducers...) + + // If the records slice which was given as an argument included TLV + // values greater than or equal to the minimum custom records TLV type + // we will sort the extended records slice to ensure that it is ordered + // correctly. + SortProducers(producers) + + return producers, nil +} + +// RecordProducers returns a slice of record producers for the custom records. +func (c CustomRecords) RecordProducers() []tlv.RecordProducer { + // If the custom records are nil or empty, return an empty slice. + if len(c) == 0 { + return nil + } + + // Convert the custom records map to a TLV record producer slice. + records := tlv.MapToRecords(c) + + return RecordsAsProducers(records) +} + +// Serialize serializes the custom records into a byte slice. +func (c CustomRecords) Serialize() ([]byte, error) { + records := tlv.MapToRecords(c) + return EncodeRecords(records) +} + +// SerializeTo serializes the custom records into the given writer. +func (c CustomRecords) SerializeTo(w io.Writer) error { + records := tlv.MapToRecords(c) + return EncodeRecordsTo(w, records) +} + +// ProduceRecordsSorted converts a slice of record producers into a slice of +// records and then sorts it by type. +func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record { + records := fn.Map(func(producer tlv.RecordProducer) tlv.Record { + return producer.Record() + }, recordProducers) + + // Ensure that the set of records are sorted before we attempt to + // decode from the stream, to ensure they're canonical. + tlv.SortRecords(records) + + return records +} + +// SortProducers sorts the given record producers by their type. +func SortProducers(producers []tlv.RecordProducer) { + sort.Slice(producers, func(i, j int) bool { + recordI := producers[i].Record() + recordJ := producers[j].Record() + return recordI.Type() < recordJ.Type() + }) +} + +// TlvMapToRecords converts a TLV map into a slice of records. +func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record { + tlvMapGeneric := make(map[uint64][]byte) + for k, v := range tlvMap { + tlvMapGeneric[uint64(k)] = v + } + + return tlv.MapToRecords(tlvMapGeneric) +} + +// RecordsAsProducers converts a slice of records into a slice of record +// producers. +func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer { + return fn.Map(func(record tlv.Record) tlv.RecordProducer { + return &record + }, records) +} + +// EncodeRecords encodes the given records into a byte slice. +func EncodeRecords(records []tlv.Record) ([]byte, error) { + var buf bytes.Buffer + if err := EncodeRecordsTo(&buf, records); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// EncodeRecordsTo encodes the given records into the given writer. +func EncodeRecordsTo(w io.Writer, records []tlv.Record) error { + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// DecodeRecords decodes the given byte slice into the given records and returns +// the rest as a TLV type map. +func DecodeRecords(r io.Reader, + records ...tlv.Record) (tlv.TypeMap, error) { + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + return tlvStream.DecodeWithParsedTypes(r) +} + +// DecodeRecordsP2P decodes the given byte slice into the given records and +// returns the rest as a TLV type map. This function is identical to +// DecodeRecords except that the record size is capped at 65535. +func DecodeRecordsP2P(r *bytes.Reader, + records ...tlv.Record) (tlv.TypeMap, error) { + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + return tlvStream.DecodeWithParsedTypesP2P(r) +} + +// AssertUniqueTypes asserts that the given records have unique types. +func AssertUniqueTypes(r []tlv.Record) error { + seen := make(fn.Set[tlv.Type], len(r)) + for _, record := range r { + t := record.Type() + if seen.Contains(t) { + return fmt.Errorf("duplicate record type: %d", t) + } + seen.Add(t) + } + + return nil +}