From ba043fa1d1deec2a0691958c0c63d7c3fa323658 Mon Sep 17 00:00:00 2001 From: ffranr Date: Sat, 13 Apr 2024 12:29:41 +0100 Subject: [PATCH] lnwire: add custom records field to type `UpdateAddHtlc` - Introduce the field `CustomRecords` to the type `UpdateAddHtlc`. - Encode and decode the new field into the `ExtraData` field of the `update_add_htlc` wire message. --- lnwire/extra_bytes_test.go | 8 -------- lnwire/update_add_htlc.go | 32 ++++++++++++++++++++++++++++++-- peer/brontide.go | 6 +++--- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go index bc1de8c57..69cf549bf 100644 --- a/lnwire/extra_bytes_test.go +++ b/lnwire/extra_bytes_test.go @@ -86,14 +86,6 @@ func TestExtraOpaqueDataEncodeDecode(t *testing.T) { } } -type recordProducer struct { - record tlv.Record -} - -func (r *recordProducer) Record() tlv.Record { - return r.record -} - // TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of // tlv.Records into a stream, and unpack them on the other side to obtain the // same set of records. diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 8a40710e8..c12b07e80 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -72,6 +72,11 @@ type UpdateAddHTLC struct { // next hop for this htlc. BlindingPoint BlindingPointRecord + // CustomRecords maps TLV types to byte slices, storing arbitrary data + // intended for inclusion in the ExtraData field of the UpdateAddHTLC + // message. + CustomRecords CustomRecords + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -104,7 +109,9 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { return err } + // Extract TLV records from the extra data field. blindingRecord := c.BlindingPoint.Zero() + tlvMap, err := c.ExtraData.ExtractRecords(&blindingRecord) if err != nil { return err @@ -112,8 +119,19 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { if val, ok := tlvMap[c.BlindingPoint.TlvType()]; ok && val == nil { c.BlindingPoint = tlv.SomeRecordT(blindingRecord) + + // Remove the entry from the TLV map. Anything left in the map + // will be included in the custom records field. + delete(tlvMap, c.BlindingPoint.TlvType()) } + // Set the custom records field to the remaining TLV records. + customRecords, err := NewCustomRecordsFromTlvTypeMap(tlvMap) + if err != nil { + return err + } + c.CustomRecords = *customRecords + // Set extra data to nil if we didn't parse anything out of it so that // we can use assert.Equal in tests. if len(tlvMap) == 0 { @@ -152,16 +170,26 @@ func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error { return err } - // Only include blinding point in extra data if present. + // Construct a slice of all the records that we should include in the + // extra data field. var records []tlv.RecordProducer + // Only include blinding point in extra data if present. c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType, *btcec.PublicKey]) { records = append(records, &b) }) - err := EncodeMessageExtraData(&c.ExtraData, records...) + // Extend the 'records' slice with TLV records from the custom records + // field. + records, err := c.CustomRecords.ExtendRecordProducers(records) + if err != nil { + return err + } + + // Encode the records into the extra data field. + err = EncodeMessageExtraData(&c.ExtraData, records...) if err != nil { return err } diff --git a/peer/brontide.go b/peer/brontide.go index 90057feda..b63ac8517 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -2074,9 +2074,9 @@ func messageSummary(msg lnwire.Message) string { ) return fmt.Sprintf("chan_id=%v, id=%v, amt=%v, expiry=%v, "+ - "hash=%x, blinding_point=%x", msg.ChanID, msg.ID, - msg.Amount, msg.Expiry, msg.PaymentHash[:], - blindingPoint) + "hash=%x, blinding_point=%x, custom_records=%v", + msg.ChanID, msg.ID, msg.Amount, msg.Expiry, + msg.PaymentHash[:], blindingPoint, msg.CustomRecords) case *lnwire.UpdateFailHTLC: return fmt.Sprintf("chan_id=%v, id=%v, reason=%x", msg.ChanID,