lnwire: add ExtraOpaqueData helper functions and methods

Introduces a couple of new helper functions for both the ExtraOpaqueData
and CustomRecords types along with new methods on the ExtraOpaqueData.
This commit is contained in:
ffranr
2024-05-03 18:36:00 +01:00
committed by Oliver Gugger
parent 17c0a70b07
commit af50694643
4 changed files with 596 additions and 76 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/tlv"
)
@@ -15,6 +16,21 @@ import (
// upgrades to the network in a forwards compatible manner.
type ExtraOpaqueData []byte
// NewExtraOpaqueData creates a new ExtraOpaqueData instance from a tlv.TypeMap.
func NewExtraOpaqueData(tlvMap tlv.TypeMap) (ExtraOpaqueData, error) {
// If the tlv map is empty, we'll want to mirror the behavior of
// decoding an empty extra opaque data field (see Decode method).
if len(tlvMap) == 0 {
return make([]byte, 0), nil
}
// Convert the TLV map into a slice of records.
records := TlvMapToRecords(tlvMap)
// Encode the records into the extra data byte slice.
return EncodeRecords(records)
}
// Encode attempts to encode the raw extra bytes into the passed io.Writer.
func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error {
eBytes := []byte((*e)[:])
@@ -25,8 +41,8 @@ func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error {
return nil
}
// Decode attempts to unpack the raw bytes encoded in the passed io.Reader as a
// set of extra opaque data.
// Decode attempts to unpack the raw bytes encoded in the passed-in io.Reader as
// a set of extra opaque data.
func (e *ExtraOpaqueData) Decode(r io.Reader) error {
// First, we'll attempt to read a set of bytes contained within the
// passed io.Reader (if any exist).
@@ -39,7 +55,7 @@ func (e *ExtraOpaqueData) Decode(r io.Reader) error {
// This ensures that any struct that embeds this type will properly
// store the bytes once this method exits.
if len(rawBytes) > 0 {
*e = ExtraOpaqueData(rawBytes)
*e = rawBytes
} else {
*e = make([]byte, 0)
}
@@ -50,28 +66,17 @@ func (e *ExtraOpaqueData) Decode(r io.Reader) error {
// PackRecords attempts to encode the set of tlv records into the target
// ExtraOpaqueData instance. The records will be encoded as a raw TLV stream
// and stored within the backing slice pointer.
func (e *ExtraOpaqueData) PackRecords(recordProducers ...tlv.RecordProducer) error {
// First, assemble all the records passed in in series.
records := make([]tlv.Record, 0, len(recordProducers))
for _, producer := range recordProducers {
records = append(records, producer.Record())
}
func (e *ExtraOpaqueData) PackRecords(
recordProducers ...tlv.RecordProducer) error {
// Ensure that the set of records are sorted before we encode them into
// the stream, to ensure they're canonical.
tlv.SortRecords(records)
tlvStream, err := tlv.NewStream(records...)
// Assemble all the records passed in series, then encode them.
records := ProduceRecordsSorted(recordProducers...)
encoded, err := EncodeRecords(records)
if err != nil {
return err
}
var extraBytesWriter bytes.Buffer
if err := tlvStream.Encode(&extraBytesWriter); err != nil {
return err
}
*e = ExtraOpaqueData(extraBytesWriter.Bytes())
*e = encoded
return nil
}
@@ -80,29 +85,38 @@ func (e *ExtraOpaqueData) PackRecords(recordProducers ...tlv.RecordProducer) err
// it were a tlv stream. The set of raw parsed types is returned, and any
// passed records (if found in the stream) will be parsed into the proper
// tlv.Record.
func (e *ExtraOpaqueData) ExtractRecords(recordProducers ...tlv.RecordProducer) (
tlv.TypeMap, error) {
// First, assemble all the records passed in in series.
records := make([]tlv.Record, 0, len(recordProducers))
for _, producer := range recordProducers {
records = append(records, producer.Record())
}
// Ensure that the set of records are sorted before we attempt to
// decode from the stream, to ensure they're canonical.
tlv.SortRecords(records)
func (e *ExtraOpaqueData) ExtractRecords(
recordProducers ...tlv.RecordProducer) (tlv.TypeMap, error) {
// First, assemble all the records passed in series.
records := ProduceRecordsSorted(recordProducers...)
extraBytesReader := bytes.NewReader(*e)
tlvStream, err := tlv.NewStream(records...)
// Since ExtraOpaqueData is provided by a potentially malicious peer,
// pass it into the P2P decoding variant.
return DecodeRecordsP2P(extraBytesReader, records...)
}
// RecordProducers parses ExtraOpaqueData into a slice of TLV record producers
// by interpreting it as a TLV map.
func (e *ExtraOpaqueData) RecordProducers() ([]tlv.RecordProducer, error) {
var recordProducers []tlv.RecordProducer
// If the instance is nil or empty, return an empty slice.
if e == nil || len(*e) == 0 {
return recordProducers, nil
}
// Parse the extra opaque data as a TLV map.
tlvMap, err := e.ExtractRecords()
if err != nil {
return nil, err
}
// Since ExtraOpaqueData is provided by a potentially malicious peer,
// pass it into the P2P decoding variant.
return tlvStream.DecodeWithParsedTypesP2P(extraBytesReader)
// Convert the TLV map into a slice of record producers.
records := TlvMapToRecords(tlvMap)
return RecordsAsProducers(records), nil
}
// EncodeMessageExtraData encodes the given recordProducers into the given
@@ -120,3 +134,116 @@ func EncodeMessageExtraData(extraData *ExtraOpaqueData,
// are all properly sorted.
return extraData.PackRecords(recordProducers...)
}
// ParseAndExtractCustomRecords parses the given extra data into the passed-in
// records, then returns any remaining records split into custom records and
// extra data.
func ParseAndExtractCustomRecords(allExtraData ExtraOpaqueData,
knownRecords ...tlv.RecordProducer) (CustomRecords,
fn.Set[tlv.Type], ExtraOpaqueData, error) {
extraDataTlvMap, err := allExtraData.ExtractRecords(knownRecords...)
if err != nil {
return nil, nil, nil, err
}
// Remove the known and now extracted records from the leftover extra
// data map.
parsedKnownRecords := make(fn.Set[tlv.Type], len(knownRecords))
for _, producer := range knownRecords {
r := producer.Record()
// Only remove the records if it was parsed (remainder is nil).
// We'll just store the type so we can tell the caller which
// records were actually parsed fully.
val, ok := extraDataTlvMap[r.Type()]
if ok && val == nil {
parsedKnownRecords.Add(r.Type())
delete(extraDataTlvMap, r.Type())
}
}
// Any records from the extra data TLV map which are in the custom
// records TLV type range will be included in the custom records field
// and removed from the extra data field.
customRecordsTlvMap := make(tlv.TypeMap, len(extraDataTlvMap))
for k, v := range extraDataTlvMap {
// Skip records that are not in the custom records TLV type
// range.
if k < MinCustomRecordsTlvType {
continue
}
// Include the record in the custom records map.
customRecordsTlvMap[k] = v
// Now that the record is included in the custom records map,
// we can remove it from the extra data TLV map.
delete(extraDataTlvMap, k)
}
// Set the custom records field to the custom records specific TLV
// record map.
customRecords, err := NewCustomRecords(customRecordsTlvMap)
if err != nil {
return nil, nil, nil, err
}
// Encode the remaining records back into the extra data field. These
// records are not in the custom records TLV type range and do not
// have associated fields in the struct that produced the records.
extraData, err := NewExtraOpaqueData(extraDataTlvMap)
if err != nil {
return nil, nil, nil, err
}
// Help with unit testing where we might have the empty value (nil) for
// the extra data instead of the default that's returned by the
// constructor (empty slice).
if len(extraData) == 0 {
extraData = nil
}
return customRecords, parsedKnownRecords, extraData, nil
}
// MergeAndEncode merges the known records with the extra data and custom
// records, then encodes the merged records into raw bytes.
func MergeAndEncode(knownRecords []tlv.RecordProducer,
extraData ExtraOpaqueData, customRecords CustomRecords) ([]byte,
error) {
// Construct a slice of all the records that we should include in the
// message extra data field. We will start by including any records from
// the extra data field.
mergedRecords, err := extraData.RecordProducers()
if err != nil {
return nil, err
}
// Merge the known and extra data records.
mergedRecords = append(mergedRecords, knownRecords...)
// Include custom records in the extra data wire field if they are
// present. Ensure that the custom records are validated before encoding
// them.
if err := customRecords.Validate(); err != nil {
return nil, fmt.Errorf("custom records validation error: %w",
err)
}
// Extend the message extra data records slice with TLV records from the
// custom records field.
mergedRecords = append(
mergedRecords, customRecords.RecordProducers()...,
)
// Now we can sort the records and make sure there are no records with
// the same type that would collide when encoding.
sortedRecords := ProduceRecordsSorted(mergedRecords...)
if err := AssertUniqueTypes(sortedRecords); err != nil {
return nil, err
}
return EncodeRecords(sortedRecords)
}