diff --git a/lnwire/dyn_propose.go b/lnwire/dyn_propose.go index 4d7459cf2..4744b87d2 100644 --- a/lnwire/dyn_propose.go +++ b/lnwire/dyn_propose.go @@ -119,9 +119,9 @@ func (dp *DynPropose) Decode(r io.Reader, _ uint32) error { maxHtlcs := dp.MaxAcceptedHTLCs.Zero() chanType := dp.ChannelType.Zero() - typeMap, err := tlvRecords.ExtractRecords( - &dustLimit, &maxValue, &htlcMin, &reserve, &csvDelay, &maxHtlcs, - &chanType, + knownRecords, extraData, err := ParseAndExtractExtraData( + tlvRecords, &dustLimit, &maxValue, &htlcMin, &reserve, + &csvDelay, &maxHtlcs, &chanType, ) if err != nil { return err @@ -129,57 +129,43 @@ func (dp *DynPropose) Decode(r io.Reader, _ uint32) error { // Check the results of the TLV Stream decoding and appropriately set // message fields. - if val, ok := typeMap[dp.DustLimit.TlvType()]; ok && val == nil { + if _, ok := knownRecords[dp.DustLimit.TlvType()]; ok { var rec tlv.RecordT[tlv.TlvType0, tlv.BigSizeT[btcutil.Amount]] rec.Val = dustLimit.Val dp.DustLimit = tlv.SomeRecordT(rec) - delete(typeMap, dp.DustLimit.TlvType()) } - if val, ok := typeMap[dp.MaxValueInFlight.TlvType()]; ok && val == nil { + + if _, ok := knownRecords[dp.MaxValueInFlight.TlvType()]; ok { var rec tlv.RecordT[tlv.TlvType2, MilliSatoshi] rec.Val = maxValue.Val dp.MaxValueInFlight = tlv.SomeRecordT(rec) - - delete(typeMap, dp.MaxValueInFlight.TlvType()) } - if val, ok := typeMap[dp.HtlcMinimum.TlvType()]; ok && val == nil { + + if _, ok := knownRecords[dp.HtlcMinimum.TlvType()]; ok { var rec tlv.RecordT[tlv.TlvType4, MilliSatoshi] rec.Val = htlcMin.Val dp.HtlcMinimum = tlv.SomeRecordT(rec) - - delete(typeMap, dp.HtlcMinimum.TlvType()) } - if val, ok := typeMap[dp.ChannelReserve.TlvType()]; ok && val == nil { + + if _, ok := knownRecords[dp.ChannelReserve.TlvType()]; ok { var rec tlv.RecordT[tlv.TlvType6, tlv.BigSizeT[btcutil.Amount]] rec.Val = reserve.Val dp.ChannelReserve = tlv.SomeRecordT(rec) - - delete(typeMap, dp.ChannelReserve.TlvType()) } - if val, ok := typeMap[dp.CsvDelay.TlvType()]; ok && val == nil { + + if _, ok := knownRecords[dp.CsvDelay.TlvType()]; ok { dp.CsvDelay = tlv.SomeRecordT(csvDelay) - - delete(typeMap, dp.CsvDelay.TlvType()) } - if val, ok := typeMap[dp.MaxAcceptedHTLCs.TlvType()]; ok && val == nil { + + if _, ok := knownRecords[dp.MaxAcceptedHTLCs.TlvType()]; ok { dp.MaxAcceptedHTLCs = tlv.SomeRecordT(maxHtlcs) - - delete(typeMap, dp.MaxAcceptedHTLCs.TlvType()) } - if val, ok := typeMap[dp.ChannelType.TlvType()]; ok && val == nil { + + if _, ok := knownRecords[dp.ChannelType.TlvType()]; ok { dp.ChannelType = tlv.SomeRecordT(chanType) - - delete(typeMap, dp.ChannelType.TlvType()) } - if len(typeMap) != 0 { - extraData, err := NewExtraOpaqueData(typeMap) - if err != nil { - return err - } - - dp.ExtraData = extraData - } + dp.ExtraData = extraData return nil } diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index 9530e06e8..39228c101 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -269,3 +269,41 @@ func MergeAndEncode(knownRecords []tlv.RecordProducer, return EncodeRecords(sortedRecords) } + +// ParseAndExtractExtraData parses the given extra data into the passed-in +// records, then returns any remaining records as extra data. +func ParseAndExtractExtraData(allTlvData ExtraOpaqueData, + knownRecords ...tlv.RecordProducer) (fn.Set[tlv.Type], + ExtraOpaqueData, error) { + + extraDataTlvMap, err := allTlvData.ExtractRecords(knownRecords...) + if err != nil { + return nil, nil, err + } + + // Remove the known and now extracted records from the leftover extra + // data map. + parsedKnownRecords := make(fn.Set[tlv.Type], len(knownRecords)) + for _, producer := range knownRecords { + r := producer.Record() + + // Only remove the records if it was parsed (remainder is nil). + // We'll just store the type so we can tell the caller which + // records were actually parsed fully. + val, ok := extraDataTlvMap[r.Type()] + if ok && val == nil { + parsedKnownRecords.Add(r.Type()) + delete(extraDataTlvMap, r.Type()) + } + } + + // Encode the remaining records back into the extra data field. These + // records are not in the custom records TLV type range and do not + // have associated fields in the struct that produced the records. + extraData, err := NewExtraOpaqueData(extraDataTlvMap) + if err != nil { + return nil, nil, err + } + + return parsedKnownRecords, extraData, nil +} diff --git a/lnwire/test_message.go b/lnwire/test_message.go index 28ffba6c2..33b98d03a 100644 --- a/lnwire/test_message.go +++ b/lnwire/test_message.go @@ -881,10 +881,7 @@ func (dp *DynPropose) RandTestMessage(t *rapid.T) Message { } } - extraData := RandExtraOpaqueData(t, ignoreRecords) - if len(extraData) > 0 { - msg.ExtraData = extraData - } + msg.ExtraData = RandExtraOpaqueData(t, ignoreRecords) return msg }