mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-05-31 02:01:46 +02:00
lnwire: add custom records field to type UpdateAddHtlc
- Introduce the field `CustomRecords` to the type `UpdateAddHtlc`. - Encode and decode the new field into the `ExtraData` field of the `update_add_htlc` wire message.
This commit is contained in:
parent
41a5b9abf8
commit
2b3618c14d
@ -2,6 +2,7 @@ package lnwire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
@ -72,6 +73,11 @@ type UpdateAddHTLC struct {
|
||||
// next hop for this htlc.
|
||||
BlindingPoint BlindingPointRecord
|
||||
|
||||
// CustomRecords maps TLV types to byte slices, storing arbitrary data
|
||||
// intended for inclusion in the ExtraData field of the UpdateAddHTLC
|
||||
// message.
|
||||
CustomRecords CustomRecords
|
||||
|
||||
// ExtraData is the set of data that was appended to this message to
|
||||
// fill out the full maximum transport message size. These fields can
|
||||
// be used to specify optional data such as custom TLV fields.
|
||||
@ -92,6 +98,10 @@ var _ Message = (*UpdateAddHTLC)(nil)
|
||||
//
|
||||
// This is part of the lnwire.Message interface.
|
||||
func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
|
||||
// msgExtraData is a temporary variable used to read the message extra
|
||||
// data field from the reader.
|
||||
var msgExtraData ExtraOpaqueData
|
||||
|
||||
if err := ReadElements(r,
|
||||
&c.ChanID,
|
||||
&c.ID,
|
||||
@ -99,25 +109,76 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
|
||||
c.PaymentHash[:],
|
||||
&c.Expiry,
|
||||
c.OnionBlob[:],
|
||||
&c.ExtraData,
|
||||
&msgExtraData,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract TLV records from the extra data field.
|
||||
blindingRecord := c.BlindingPoint.Zero()
|
||||
tlvMap, err := c.ExtraData.ExtractRecords(&blindingRecord)
|
||||
|
||||
extraDataTlvMap, err := msgExtraData.ExtractRecords(&blindingRecord)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val, ok := tlvMap[c.BlindingPoint.TlvType()]; ok && val == nil {
|
||||
val, ok := extraDataTlvMap[c.BlindingPoint.TlvType()]
|
||||
if ok && val == nil {
|
||||
c.BlindingPoint = tlv.SomeRecordT(blindingRecord)
|
||||
|
||||
// Remove the entry from the TLV map. Anything left in the map
|
||||
// will be included in the custom records field.
|
||||
delete(extraDataTlvMap, c.BlindingPoint.TlvType())
|
||||
}
|
||||
|
||||
// Any records from the extra data TLV map which are in the custom
|
||||
// records TLV type range will be included in the custom records field
|
||||
// and removed from the extra data field.
|
||||
customRecordsTlvMap := make(tlv.TypeMap, len(extraDataTlvMap))
|
||||
for k, v := range extraDataTlvMap {
|
||||
// Skip records that are not in the custom records TLV type
|
||||
// range.
|
||||
if k < MinCustomRecordsTlvType {
|
||||
continue
|
||||
}
|
||||
|
||||
// Include the record in the custom records map.
|
||||
customRecordsTlvMap[k] = v
|
||||
|
||||
// Now that the record is included in the custom records map,
|
||||
// we can remove it from the extra data TLV map.
|
||||
delete(extraDataTlvMap, k)
|
||||
}
|
||||
|
||||
// Set the custom records field to the custom records specific TLV
|
||||
// record map.
|
||||
customRecords, err := NewCustomRecordsFromTlvTypeMap(
|
||||
customRecordsTlvMap,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.CustomRecords = customRecords
|
||||
|
||||
// Set custom records to nil if we didn't parse anything out of it so
|
||||
// that we can use assert.Equal in tests.
|
||||
if len(customRecordsTlvMap) == 0 {
|
||||
c.CustomRecords = nil
|
||||
}
|
||||
|
||||
// Set extra data to nil if we didn't parse anything out of it so that
|
||||
// we can use assert.Equal in tests.
|
||||
if len(tlvMap) == 0 {
|
||||
if len(extraDataTlvMap) == 0 {
|
||||
c.ExtraData = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode the remaining records back into the extra data field. These
|
||||
// records are not in the custom records TLV type range and do not
|
||||
// have associated fields in the UpdateAddHTLC struct.
|
||||
c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap(extraDataTlvMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -152,21 +213,41 @@ func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only include blinding point in extra data if present.
|
||||
var records []tlv.RecordProducer
|
||||
|
||||
c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType,
|
||||
*btcec.PublicKey]) {
|
||||
|
||||
records = append(records, &b)
|
||||
})
|
||||
|
||||
err := EncodeMessageExtraData(&c.ExtraData, records...)
|
||||
// Construct a slice of all the records that we should include in the
|
||||
// message extra data field. We will start by including any records from
|
||||
// the extra data field.
|
||||
msgExtraDataRecords, err := c.ExtraData.RecordProducers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteBytes(w, c.ExtraData)
|
||||
// Include blinding point in extra data if specified.
|
||||
c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType,
|
||||
*btcec.PublicKey]) {
|
||||
|
||||
msgExtraDataRecords = append(msgExtraDataRecords, &b)
|
||||
})
|
||||
|
||||
// Include custom records in the extra data wire field if they are
|
||||
// present. Ensure that the custom records are validated before encoding
|
||||
// them.
|
||||
if err := c.CustomRecords.Validate(); err != nil {
|
||||
return fmt.Errorf("custom records validation error: %w", err)
|
||||
}
|
||||
|
||||
// Extend the message extra data records slice with TLV records from the
|
||||
// custom records field.
|
||||
customTlvRecords := c.CustomRecords.RecordProducers()
|
||||
msgExtraDataRecords = append(msgExtraDataRecords, customTlvRecords...)
|
||||
|
||||
// We will now construct the message extra data field that will be
|
||||
// encoded into the byte writer.
|
||||
var msgExtraData ExtraOpaqueData
|
||||
if err := msgExtraData.PackRecords(msgExtraDataRecords...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteBytes(w, msgExtraData)
|
||||
}
|
||||
|
||||
// MsgType returns the integer uniquely identifying this message type on the
|
||||
|
@ -2024,9 +2024,9 @@ func messageSummary(msg lnwire.Message) string {
|
||||
)
|
||||
|
||||
return fmt.Sprintf("chan_id=%v, id=%v, amt=%v, expiry=%v, "+
|
||||
"hash=%x, blinding_point=%x", msg.ChanID, msg.ID,
|
||||
msg.Amount, msg.Expiry, msg.PaymentHash[:],
|
||||
blindingPoint)
|
||||
"hash=%x, blinding_point=%x, custom_records=%v",
|
||||
msg.ChanID, msg.ID, msg.Amount, msg.Expiry,
|
||||
msg.PaymentHash[:], blindingPoint, msg.CustomRecords)
|
||||
|
||||
case *lnwire.UpdateFailHTLC:
|
||||
return fmt.Sprintf("chan_id=%v, id=%v, reason=%x", msg.ChanID,
|
||||
|
Loading…
x
Reference in New Issue
Block a user