mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-26 13:42:49 +02:00
lnwire: add type CustomRecords
This commit introduces the `CustomRecords` type in the `lnwire` package, designed to hold arbitrary byte slices. Each entry in this map can associate with TLV type values that are greater than or equal to 65536.
This commit is contained in:
176
lnwire/custom_records.go
Normal file
176
lnwire/custom_records.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
package lnwire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MinCustomRecordsTlvType is the minimum custom records TLV type as
|
||||||
|
// defined in BOLT 01.
|
||||||
|
MinCustomRecordsTlvType = 65536
|
||||||
|
)
|
||||||
|
|
||||||
|
// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types
|
||||||
|
// which must be greater than or equal to MinCustomRecordsTlvType.
|
||||||
|
type CustomRecords map[uint64][]byte
|
||||||
|
|
||||||
|
// NewCustomRecords creates a new CustomRecords instance from a
|
||||||
|
// tlv.TypeMap.
|
||||||
|
func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) {
|
||||||
|
// Make comparisons in unit tests easy by returning nil if the map is
|
||||||
|
// empty.
|
||||||
|
if len(tlvMap) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
customRecords := make(CustomRecords, len(tlvMap))
|
||||||
|
for k, v := range tlvMap {
|
||||||
|
customRecords[uint64(k)] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the custom records.
|
||||||
|
err := customRecords.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("custom records from tlv map "+
|
||||||
|
"validation error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return customRecords, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob.
|
||||||
|
func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) {
|
||||||
|
stream, err := tlv.NewStream()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating stream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
typeMap, err := stream.DecodeWithParsedTypes(bytes.NewReader(b))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error decoding HTLC record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewCustomRecords(typeMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks that all custom records are in the custom type range.
|
||||||
|
func (c CustomRecords) Validate() error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for key := range c {
|
||||||
|
if key < MinCustomRecordsTlvType {
|
||||||
|
return fmt.Errorf("custom records entry with TLV "+
|
||||||
|
"type below min: %d", MinCustomRecordsTlvType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy returns a copy of the custom records.
|
||||||
|
func (c CustomRecords) Copy() CustomRecords {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
customRecords := make(CustomRecords, len(c))
|
||||||
|
for k, v := range c {
|
||||||
|
customRecords[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return customRecords
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtendRecordProducers extends the given records slice with the custom
|
||||||
|
// records. The resultant records slice will be sorted if the given records
|
||||||
|
// slice contains TLV types greater than or equal to MinCustomRecordsTlvType.
|
||||||
|
func (c CustomRecords) ExtendRecordProducers(
|
||||||
|
producers []tlv.RecordProducer) ([]tlv.RecordProducer, error) {
|
||||||
|
|
||||||
|
// If the custom records are nil or empty, there is nothing to do.
|
||||||
|
if len(c) == 0 {
|
||||||
|
return producers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the custom records.
|
||||||
|
err := c.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that the existing records slice TLV types are not also present
|
||||||
|
// in the custom records. If they are, the resultant extended records
|
||||||
|
// slice would erroneously contain duplicate TLV types.
|
||||||
|
for _, rp := range producers {
|
||||||
|
record := rp.Record()
|
||||||
|
recordTlvType := uint64(record.Type())
|
||||||
|
|
||||||
|
_, foundDuplicateTlvType := c[recordTlvType]
|
||||||
|
if foundDuplicateTlvType {
|
||||||
|
return nil, fmt.Errorf("custom records contains a TLV "+
|
||||||
|
"type that is already present in the "+
|
||||||
|
"existing records: %d", recordTlvType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the custom records map to a TLV record producer slice and
|
||||||
|
// append them to the exiting records slice.
|
||||||
|
crRecords := tlv.MapToRecords(c)
|
||||||
|
for _, record := range crRecords {
|
||||||
|
r := recordProducer{record}
|
||||||
|
producers = append(producers, &r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the records slice which was given as an argument included TLV
|
||||||
|
// values greater than or equal to the minimum custom records TLV type
|
||||||
|
// we will sort the extended records slice to ensure that it is ordered
|
||||||
|
// correctly.
|
||||||
|
sort.Slice(producers, func(i, j int) bool {
|
||||||
|
recordI := producers[i].Record()
|
||||||
|
recordJ := producers[j].Record()
|
||||||
|
return recordI.Type() < recordJ.Type()
|
||||||
|
})
|
||||||
|
|
||||||
|
return producers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordProducers returns a slice of record producers for the custom records.
|
||||||
|
func (c CustomRecords) RecordProducers() []tlv.RecordProducer {
|
||||||
|
// If the custom records are nil or empty, return an empty slice.
|
||||||
|
if len(c) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the custom records map to a TLV record producer slice.
|
||||||
|
records := tlv.MapToRecords(c)
|
||||||
|
|
||||||
|
// Convert the records to record producers.
|
||||||
|
producers := make([]tlv.RecordProducer, len(records))
|
||||||
|
for i, record := range records {
|
||||||
|
producers[i] = &recordProducer{record}
|
||||||
|
}
|
||||||
|
|
||||||
|
return producers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize serializes the custom records into a byte slice.
|
||||||
|
func (c CustomRecords) Serialize() ([]byte, error) {
|
||||||
|
records := tlv.MapToRecords(c)
|
||||||
|
stream, err := tlv.NewStream(records...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating stream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := stream.Encode(&b); err != nil {
|
||||||
|
return nil, fmt.Errorf("error encoding custom records: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.Bytes(), nil
|
||||||
|
}
|
198
lnwire/custom_records_test.go
Normal file
198
lnwire/custom_records_test.go
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
package lnwire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/fn"
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestCustomRecords tests the custom records serialization and deserialization,
|
||||||
|
// as well as copying and producing records.
|
||||||
|
func TestCustomRecords(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
customTypes tlv.TypeMap
|
||||||
|
expectedRecords CustomRecords
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty custom records",
|
||||||
|
customTypes: tlv.TypeMap{},
|
||||||
|
expectedRecords: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom record with invalid type",
|
||||||
|
customTypes: tlv.TypeMap{
|
||||||
|
123: []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
expectedErr: "TLV type below min: 65536",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid custom record",
|
||||||
|
customTypes: tlv.TypeMap{
|
||||||
|
65536: []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
expectedRecords: map[uint64][]byte{
|
||||||
|
65536: {1, 2, 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid custom records, wrong order",
|
||||||
|
customTypes: tlv.TypeMap{
|
||||||
|
65537: []byte{3, 4, 5},
|
||||||
|
65536: []byte{1, 2, 3},
|
||||||
|
},
|
||||||
|
expectedRecords: map[uint64][]byte{
|
||||||
|
65536: {1, 2, 3},
|
||||||
|
65537: {3, 4, 5},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
records, err := NewCustomRecords(tc.customTypes)
|
||||||
|
|
||||||
|
if tc.expectedErr != "" {
|
||||||
|
require.ErrorContains(t, err, tc.expectedErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tc.expectedRecords, records)
|
||||||
|
|
||||||
|
// Serialize, then parse the records again.
|
||||||
|
blob, err := records.Serialize()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedRecords, err := ParseCustomRecords(blob)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, tc.expectedRecords, parsedRecords)
|
||||||
|
|
||||||
|
// Copy() should also return the same records.
|
||||||
|
require.Equal(
|
||||||
|
t, tc.expectedRecords, parsedRecords.Copy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecordProducers() should also allow us to serialize
|
||||||
|
// the records again.
|
||||||
|
serializedProducers := serializeRecordProducers(
|
||||||
|
t, parsedRecords.RecordProducers(),
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Equal(t, blob, serializedProducers)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCustomRecordsExtendRecordProducers tests that we can extend a slice of
|
||||||
|
// record producers with custom records.
|
||||||
|
func TestCustomRecordsExtendRecordProducers(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
existingTypes map[uint64][]byte
|
||||||
|
customRecords CustomRecords
|
||||||
|
expectedResult tlv.TypeMap
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal merge",
|
||||||
|
existingTypes: map[uint64][]byte{
|
||||||
|
123: {3, 4, 5},
|
||||||
|
345: {1, 2, 3},
|
||||||
|
},
|
||||||
|
customRecords: CustomRecords{
|
||||||
|
65536: {1, 2, 3},
|
||||||
|
},
|
||||||
|
expectedResult: tlv.TypeMap{
|
||||||
|
123: {3, 4, 5},
|
||||||
|
345: {1, 2, 3},
|
||||||
|
65536: {1, 2, 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "duplicates",
|
||||||
|
existingTypes: map[uint64][]byte{
|
||||||
|
123: {3, 4, 5},
|
||||||
|
345: {1, 2, 3},
|
||||||
|
65536: {1, 2, 3},
|
||||||
|
},
|
||||||
|
customRecords: CustomRecords{
|
||||||
|
65536: {1, 2, 3},
|
||||||
|
},
|
||||||
|
expectedErr: "contains a TLV type that is already " +
|
||||||
|
"present in the existing records: 65536",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non custom type in custom records",
|
||||||
|
existingTypes: map[uint64][]byte{
|
||||||
|
123: {3, 4, 5},
|
||||||
|
345: {1, 2, 3},
|
||||||
|
65536: {1, 2, 3},
|
||||||
|
},
|
||||||
|
customRecords: CustomRecords{
|
||||||
|
123: {1, 2, 3},
|
||||||
|
},
|
||||||
|
expectedErr: "TLV type below min: 65536",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
nonCustomRecords := tlv.MapToRecords(tc.existingTypes)
|
||||||
|
nonCustomProducers := fn.Map(
|
||||||
|
func(r tlv.Record) tlv.RecordProducer {
|
||||||
|
return &recordProducer{r}
|
||||||
|
}, nonCustomRecords,
|
||||||
|
)
|
||||||
|
|
||||||
|
combined, err := tc.customRecords.ExtendRecordProducers(
|
||||||
|
nonCustomProducers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tc.expectedErr != "" {
|
||||||
|
require.ErrorContains(t, err, tc.expectedErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serializedProducers := serializeRecordProducers(
|
||||||
|
t, combined,
|
||||||
|
)
|
||||||
|
|
||||||
|
stream, err := tlv.NewStream()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedMap, err := stream.DecodeWithParsedTypes(
|
||||||
|
bytes.NewReader(serializedProducers),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, tc.expectedResult, parsedMap)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// serializeRecordProducers is a helper function that serializes a slice of
|
||||||
|
// record producers into a byte slice.
|
||||||
|
func serializeRecordProducers(t *testing.T,
|
||||||
|
producers []tlv.RecordProducer) []byte {
|
||||||
|
|
||||||
|
tlvRecords := fn.Map(func(p tlv.RecordProducer) tlv.Record {
|
||||||
|
return p.Record()
|
||||||
|
}, producers)
|
||||||
|
|
||||||
|
stream, err := tlv.NewStream(tlvRecords...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
err = stream.Encode(&b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
@@ -1,5 +1,7 @@
|
|||||||
package lnwire
|
package lnwire
|
||||||
|
|
||||||
|
import "github.com/lightningnetwork/lnd/tlv"
|
||||||
|
|
||||||
// QueryEncoding is an enum-like type that represents exactly how a set data is
|
// QueryEncoding is an enum-like type that represents exactly how a set data is
|
||||||
// encoded on the wire.
|
// encoded on the wire.
|
||||||
type QueryEncoding uint8
|
type QueryEncoding uint8
|
||||||
@@ -15,3 +17,17 @@ const (
|
|||||||
// NOTE: this should no longer be used or accepted.
|
// NOTE: this should no longer be used or accepted.
|
||||||
EncodingSortedZlib QueryEncoding = 1
|
EncodingSortedZlib QueryEncoding = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// recordProducer is a simple helper struct that implements the
|
||||||
|
// tlv.RecordProducer interface.
|
||||||
|
type recordProducer struct {
|
||||||
|
record tlv.Record
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record returns the underlying record.
|
||||||
|
func (r *recordProducer) Record() tlv.Record {
|
||||||
|
return r.record
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that recordProducer implements the tlv.RecordProducer interface.
|
||||||
|
var _ tlv.RecordProducer = (*recordProducer)(nil)
|
||||||
|
@@ -86,14 +86,6 @@ func TestExtraOpaqueDataEncodeDecode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type recordProducer struct {
|
|
||||||
record tlv.Record
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *recordProducer) Record() tlv.Record {
|
|
||||||
return r.record
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of
|
// TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of
|
||||||
// tlv.Records into a stream, and unpack them on the other side to obtain the
|
// tlv.Records into a stream, and unpack them on the other side to obtain the
|
||||||
// same set of records.
|
// same set of records.
|
||||||
|
Reference in New Issue
Block a user