lnwire: add custom records field to type UpdateAddHtlc

- Introduce the field `CustomRecords` to the type `UpdateAddHtlc`.
- Encode and decode the new field into the `ExtraData` field of
  the `update_add_htlc` wire message.
This commit is contained in:
ffranr 2024-04-13 12:29:41 +01:00 committed by Oliver Gugger
parent 7c2d6586b8
commit ba043fa1d1
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
3 changed files with 33 additions and 13 deletions

View File

@ -86,14 +86,6 @@ func TestExtraOpaqueDataEncodeDecode(t *testing.T) {
}
}
type recordProducer struct {
record tlv.Record
}
func (r *recordProducer) Record() tlv.Record {
return r.record
}
// TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of
// tlv.Records into a stream, and unpack them on the other side to obtain the
// same set of records.

View File

@ -72,6 +72,11 @@ type UpdateAddHTLC struct {
// next hop for this htlc.
BlindingPoint BlindingPointRecord
// CustomRecords maps TLV types to byte slices, storing arbitrary data
// intended for inclusion in the ExtraData field of the UpdateAddHTLC
// message.
CustomRecords CustomRecords
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
@ -104,7 +109,9 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
return err
}
// Extract TLV records from the extra data field.
blindingRecord := c.BlindingPoint.Zero()
tlvMap, err := c.ExtraData.ExtractRecords(&blindingRecord)
if err != nil {
return err
@ -112,8 +119,19 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
if val, ok := tlvMap[c.BlindingPoint.TlvType()]; ok && val == nil {
c.BlindingPoint = tlv.SomeRecordT(blindingRecord)
// Remove the entry from the TLV map. Anything left in the map
// will be included in the custom records field.
delete(tlvMap, c.BlindingPoint.TlvType())
}
// Set the custom records field to the remaining TLV records.
customRecords, err := NewCustomRecordsFromTlvTypeMap(tlvMap)
if err != nil {
return err
}
c.CustomRecords = *customRecords
// Set extra data to nil if we didn't parse anything out of it so that
// we can use assert.Equal in tests.
if len(tlvMap) == 0 {
@ -152,16 +170,26 @@ func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error {
return err
}
// Only include blinding point in extra data if present.
// Construct a slice of all the records that we should include in the
// extra data field.
var records []tlv.RecordProducer
// Only include blinding point in extra data if present.
c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType,
*btcec.PublicKey]) {
records = append(records, &b)
})
err := EncodeMessageExtraData(&c.ExtraData, records...)
// Extend the 'records' slice with TLV records from the custom records
// field.
records, err := c.CustomRecords.ExtendRecordProducers(records)
if err != nil {
return err
}
// Encode the records into the extra data field.
err = EncodeMessageExtraData(&c.ExtraData, records...)
if err != nil {
return err
}

View File

@ -2074,9 +2074,9 @@ func messageSummary(msg lnwire.Message) string {
)
return fmt.Sprintf("chan_id=%v, id=%v, amt=%v, expiry=%v, "+
"hash=%x, blinding_point=%x", msg.ChanID, msg.ID,
msg.Amount, msg.Expiry, msg.PaymentHash[:],
blindingPoint)
"hash=%x, blinding_point=%x, custom_records=%v",
msg.ChanID, msg.ID, msg.Amount, msg.Expiry,
msg.PaymentHash[:], blindingPoint, msg.CustomRecords)
case *lnwire.UpdateFailHTLC:
return fmt.Sprintf("chan_id=%v, id=%v, reason=%x", msg.ChanID,