lnwire: add type CustomRecords

This commit introduces the `CustomRecords` type in the `lnwire` package,
designed to hold arbitrary byte slices. Each entry in this map can
associate with TLV type values that are greater than or equal to 65536.
This commit is contained in:
ffranr 2024-04-29 11:28:22 +01:00
parent 8b1d9c9248
commit 71c32511dd
No known key found for this signature in database
GPG Key ID: B1F8848557AA29D2
3 changed files with 185 additions and 8 deletions

169
lnwire/custom_records.go Normal file
View File

@ -0,0 +1,169 @@
package lnwire
import (
"bytes"
"fmt"
"sort"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// MinCustomRecordsTlvType is the minimum custom records TLV type as
// defined in BOLT 01.
MinCustomRecordsTlvType = 65536
)
// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types
// which must be greater than or equal to MinCustomRecordsTlvType.
type CustomRecords map[uint64][]byte
// NewCustomRecordsFromTlvTypeMap creates a new CustomRecords instance from a
// tlv.TypeMap.
func NewCustomRecordsFromTlvTypeMap(tlvMap tlv.TypeMap) (CustomRecords,
error) {
customRecords := make(CustomRecords, len(tlvMap))
for k, v := range tlvMap {
customRecords[uint64(k)] = v
}
// Validate the custom records.
err := customRecords.Validate()
if err != nil {
return nil, fmt.Errorf("custom records from tlv map "+
"validation error: %v", err)
}
return customRecords, nil
}
// NewCustomRecordsFromTlvBlob creates a new CustomRecords instance from a
// tlv.Blob.
func NewCustomRecordsFromTlvBlob(b tlv.Blob) (CustomRecords, error) {
stream, err := tlv.NewStream()
if err != nil {
return nil, fmt.Errorf("error creating stream: %w", err)
}
typeMap, err := stream.DecodeWithParsedTypes(bytes.NewReader(b))
if err != nil {
return nil, fmt.Errorf("error decoding HTLC record: %w", err)
}
return NewCustomRecordsFromTlvTypeMap(typeMap)
}
// Validate checks that all custom records are in the custom type range.
func (c CustomRecords) Validate() error {
if c == nil {
return nil
}
for key := range c {
if key < MinCustomRecordsTlvType {
return fmt.Errorf("custom records entry with TLV "+
"type below min: %d", MinCustomRecordsTlvType)
}
}
return nil
}
// Copy returns a copy of the custom records.
func (c CustomRecords) Copy() CustomRecords {
customRecords := make(CustomRecords, len(c))
for k, v := range c {
customRecords[k] = v
}
return customRecords
}
// ExtendRecordProducers extends the given records slice with the custom
// records. The resultant records slice will be sorted if the given records
// slice contains TLV types greater than or equal to MinCustomRecordsTlvType.
func (c CustomRecords) ExtendRecordProducers(
producers []tlv.RecordProducer) ([]tlv.RecordProducer, error) {
// If the custom records are nil or empty, there is nothing to do.
if len(c) == 0 {
return producers, nil
}
// Validate the custom records.
err := c.Validate()
if err != nil {
return nil, err
}
// Ensure that the existing records slice TLV types are not also present
// in the custom records. If they are, the resultant extended records
// slice would erroneously contain duplicate TLV types.
for _, rp := range producers {
record := rp.Record()
recordTlvType := uint64(record.Type())
_, foundDuplicateTlvType := c[recordTlvType]
if foundDuplicateTlvType {
return nil, fmt.Errorf("custom records contains a TLV "+
"type that is already present in the "+
"existing records: %d", recordTlvType)
}
}
// Convert the custom records map to a TLV record producer slice and
// append them to the exiting records slice.
crRecords := tlv.MapToRecords(c)
for _, record := range crRecords {
r := recordProducer{record}
producers = append(producers, &r)
}
// If the records slice which was given as an argument included TLV
// values greater than or equal to the minimum custom records TLV type
// we will sort the extended records slice to ensure that it is ordered
// correctly.
sort.Slice(producers, func(i, j int) bool {
recordI := producers[i].Record()
recordJ := producers[j].Record()
return recordI.Type() < recordJ.Type()
})
return producers, nil
}
// RecordProducers returns a slice of record producers for the custom records.
func (c CustomRecords) RecordProducers() []tlv.RecordProducer {
// If the custom records are nil or empty, return an empty slice.
if len(c) == 0 {
return nil
}
// Convert the custom records map to a TLV record producer slice.
records := tlv.MapToRecords(c)
// Convert the records to record producers.
producers := make([]tlv.RecordProducer, len(records))
for i, record := range records {
producers[i] = &recordProducer{record}
}
return producers
}
// Serialize serializes the custom records into a byte slice.
func (c CustomRecords) Serialize() ([]byte, error) {
records := tlv.MapToRecords(c)
stream, err := tlv.NewStream(records...)
if err != nil {
return nil, fmt.Errorf("error creating stream: %w", err)
}
var b bytes.Buffer
if err := stream.Encode(&b); err != nil {
return nil, fmt.Errorf("error encoding custom records: %w", err)
}
return b.Bytes(), nil
}

View File

@ -1,5 +1,7 @@
package lnwire
import "github.com/lightningnetwork/lnd/tlv"
// QueryEncoding is an enum-like type that represents exactly how a set data is
// encoded on the wire.
type QueryEncoding uint8
@ -15,3 +17,17 @@ const (
// NOTE: this should no longer be used or accepted.
EncodingSortedZlib QueryEncoding = 1
)
// recordProducer is a simple helper struct that implements the
// tlv.RecordProducer interface.
type recordProducer struct {
record tlv.Record
}
// Record returns the underlying record.
func (r *recordProducer) Record() tlv.Record {
return r.record
}
// Ensure that recordProducer implements the tlv.RecordProducer interface.
var _ tlv.RecordProducer = (*recordProducer)(nil)

View File

@ -86,14 +86,6 @@ func TestExtraOpaqueDataEncodeDecode(t *testing.T) {
}
}
type recordProducer struct {
record tlv.Record
}
func (r *recordProducer) Record() tlv.Record {
return r.record
}
// TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of
// tlv.Records into a stream, and unpack them on the other side to obtain the
// same set of records.