diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 8a40710e8..3669f81e8 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -2,6 +2,7 @@ package lnwire import ( "bytes" + "fmt" "io" "github.com/btcsuite/btcd/btcec/v2" @@ -72,6 +73,11 @@ type UpdateAddHTLC struct { // next hop for this htlc. BlindingPoint BlindingPointRecord + // CustomRecords maps TLV types to byte slices, storing arbitrary data + // intended for inclusion in the ExtraData field of the UpdateAddHTLC + // message. + CustomRecords CustomRecords + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -92,6 +98,10 @@ var _ Message = (*UpdateAddHTLC)(nil) // // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { + // msgExtraData is a temporary variable used to read the message extra + // data field from the reader. + var msgExtraData ExtraOpaqueData + if err := ReadElements(r, &c.ChanID, &c.ID, @@ -99,25 +109,76 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { c.PaymentHash[:], &c.Expiry, c.OnionBlob[:], - &c.ExtraData, + &msgExtraData, ); err != nil { return err } + // Extract TLV records from the extra data field. blindingRecord := c.BlindingPoint.Zero() - tlvMap, err := c.ExtraData.ExtractRecords(&blindingRecord) + + extraDataTlvMap, err := msgExtraData.ExtractRecords(&blindingRecord) if err != nil { return err } - if val, ok := tlvMap[c.BlindingPoint.TlvType()]; ok && val == nil { + val, ok := extraDataTlvMap[c.BlindingPoint.TlvType()] + if ok && val == nil { c.BlindingPoint = tlv.SomeRecordT(blindingRecord) + + // Remove the entry from the TLV map. Anything left in the map + // will be included in the custom records field. + delete(extraDataTlvMap, c.BlindingPoint.TlvType()) + } + + // Any records from the extra data TLV map which are in the custom + // records TLV type range will be included in the custom records field + // and removed from the extra data field. + customRecordsTlvMap := make(tlv.TypeMap, len(extraDataTlvMap)) + for k, v := range extraDataTlvMap { + // Skip records that are not in the custom records TLV type + // range. + if k < MinCustomRecordsTlvType { + continue + } + + // Include the record in the custom records map. + customRecordsTlvMap[k] = v + + // Now that the record is included in the custom records map, + // we can remove it from the extra data TLV map. + delete(extraDataTlvMap, k) + } + + // Set the custom records field to the custom records specific TLV + // record map. + customRecords, err := NewCustomRecordsFromTlvTypeMap( + customRecordsTlvMap, + ) + if err != nil { + return err + } + c.CustomRecords = customRecords + + // Set custom records to nil if we didn't parse anything out of it so + // that we can use assert.Equal in tests. + if len(customRecordsTlvMap) == 0 { + c.CustomRecords = nil } // Set extra data to nil if we didn't parse anything out of it so that // we can use assert.Equal in tests. - if len(tlvMap) == 0 { + if len(extraDataTlvMap) == 0 { c.ExtraData = nil + return nil + } + + // Encode the remaining records back into the extra data field. These + // records are not in the custom records TLV type range and do not + // have associated fields in the UpdateAddHTLC struct. + c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap(extraDataTlvMap) + if err != nil { + return err } return nil @@ -152,21 +213,41 @@ func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error { return err } - // Only include blinding point in extra data if present. - var records []tlv.RecordProducer - - c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType, - *btcec.PublicKey]) { - - records = append(records, &b) - }) - - err := EncodeMessageExtraData(&c.ExtraData, records...) + // Construct a slice of all the records that we should include in the + // message extra data field. We will start by including any records from + // the extra data field. + msgExtraDataRecords, err := c.ExtraData.RecordProducers() if err != nil { return err } - return WriteBytes(w, c.ExtraData) + // Include blinding point in extra data if specified. + c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType, + *btcec.PublicKey]) { + + msgExtraDataRecords = append(msgExtraDataRecords, &b) + }) + + // Include custom records in the extra data wire field if they are + // present. Ensure that the custom records are validated before encoding + // them. + if err := c.CustomRecords.Validate(); err != nil { + return fmt.Errorf("custom records validation error: %w", err) + } + + // Extend the message extra data records slice with TLV records from the + // custom records field. + customTlvRecords := c.CustomRecords.RecordProducers() + msgExtraDataRecords = append(msgExtraDataRecords, customTlvRecords...) + + // We will now construct the message extra data field that will be + // encoded into the byte writer. + var msgExtraData ExtraOpaqueData + if err := msgExtraData.PackRecords(msgExtraDataRecords...); err != nil { + return err + } + + return WriteBytes(w, msgExtraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/peer/brontide.go b/peer/brontide.go index b9a9f68ca..e084bb88a 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -2024,9 +2024,9 @@ func messageSummary(msg lnwire.Message) string { ) return fmt.Sprintf("chan_id=%v, id=%v, amt=%v, expiry=%v, "+ - "hash=%x, blinding_point=%x", msg.ChanID, msg.ID, - msg.Amount, msg.Expiry, msg.PaymentHash[:], - blindingPoint) + "hash=%x, blinding_point=%x, custom_records=%v", + msg.ChanID, msg.ID, msg.Amount, msg.Expiry, + msg.PaymentHash[:], blindingPoint, msg.CustomRecords) case *lnwire.UpdateFailHTLC: return fmt.Sprintf("chan_id=%v, id=%v, reason=%x", msg.ChanID,