lnwire: add method ParseAndExtractExtraData

Similar to `ParseAndExtractCustomRecords`, we now add this helper method
to make sure the extra data is parsed correctly.
This commit is contained in:
yyforyongyu
2025-07-02 02:42:14 +08:00
parent e94ca84449
commit 5961f7a1bd
3 changed files with 56 additions and 35 deletions

View File

@@ -119,9 +119,9 @@ func (dp *DynPropose) Decode(r io.Reader, _ uint32) error {
maxHtlcs := dp.MaxAcceptedHTLCs.Zero()
chanType := dp.ChannelType.Zero()
typeMap, err := tlvRecords.ExtractRecords(
&dustLimit, &maxValue, &htlcMin, &reserve, &csvDelay, &maxHtlcs,
&chanType,
knownRecords, extraData, err := ParseAndExtractExtraData(
tlvRecords, &dustLimit, &maxValue, &htlcMin, &reserve,
&csvDelay, &maxHtlcs, &chanType,
)
if err != nil {
return err
@@ -129,57 +129,43 @@ func (dp *DynPropose) Decode(r io.Reader, _ uint32) error {
// Check the results of the TLV Stream decoding and appropriately set
// message fields.
if val, ok := typeMap[dp.DustLimit.TlvType()]; ok && val == nil {
if _, ok := knownRecords[dp.DustLimit.TlvType()]; ok {
var rec tlv.RecordT[tlv.TlvType0, tlv.BigSizeT[btcutil.Amount]]
rec.Val = dustLimit.Val
dp.DustLimit = tlv.SomeRecordT(rec)
delete(typeMap, dp.DustLimit.TlvType())
}
if val, ok := typeMap[dp.MaxValueInFlight.TlvType()]; ok && val == nil {
if _, ok := knownRecords[dp.MaxValueInFlight.TlvType()]; ok {
var rec tlv.RecordT[tlv.TlvType2, MilliSatoshi]
rec.Val = maxValue.Val
dp.MaxValueInFlight = tlv.SomeRecordT(rec)
delete(typeMap, dp.MaxValueInFlight.TlvType())
}
if val, ok := typeMap[dp.HtlcMinimum.TlvType()]; ok && val == nil {
if _, ok := knownRecords[dp.HtlcMinimum.TlvType()]; ok {
var rec tlv.RecordT[tlv.TlvType4, MilliSatoshi]
rec.Val = htlcMin.Val
dp.HtlcMinimum = tlv.SomeRecordT(rec)
delete(typeMap, dp.HtlcMinimum.TlvType())
}
if val, ok := typeMap[dp.ChannelReserve.TlvType()]; ok && val == nil {
if _, ok := knownRecords[dp.ChannelReserve.TlvType()]; ok {
var rec tlv.RecordT[tlv.TlvType6, tlv.BigSizeT[btcutil.Amount]]
rec.Val = reserve.Val
dp.ChannelReserve = tlv.SomeRecordT(rec)
delete(typeMap, dp.ChannelReserve.TlvType())
}
if val, ok := typeMap[dp.CsvDelay.TlvType()]; ok && val == nil {
if _, ok := knownRecords[dp.CsvDelay.TlvType()]; ok {
dp.CsvDelay = tlv.SomeRecordT(csvDelay)
delete(typeMap, dp.CsvDelay.TlvType())
}
if val, ok := typeMap[dp.MaxAcceptedHTLCs.TlvType()]; ok && val == nil {
if _, ok := knownRecords[dp.MaxAcceptedHTLCs.TlvType()]; ok {
dp.MaxAcceptedHTLCs = tlv.SomeRecordT(maxHtlcs)
delete(typeMap, dp.MaxAcceptedHTLCs.TlvType())
}
if val, ok := typeMap[dp.ChannelType.TlvType()]; ok && val == nil {
if _, ok := knownRecords[dp.ChannelType.TlvType()]; ok {
dp.ChannelType = tlv.SomeRecordT(chanType)
delete(typeMap, dp.ChannelType.TlvType())
}
if len(typeMap) != 0 {
extraData, err := NewExtraOpaqueData(typeMap)
if err != nil {
return err
}
dp.ExtraData = extraData
}
dp.ExtraData = extraData
return nil
}

View File

@@ -269,3 +269,41 @@ func MergeAndEncode(knownRecords []tlv.RecordProducer,
return EncodeRecords(sortedRecords)
}
// ParseAndExtractExtraData parses the given extra data into the passed-in
// records, then returns any remaining records as extra data.
func ParseAndExtractExtraData(allTlvData ExtraOpaqueData,
knownRecords ...tlv.RecordProducer) (fn.Set[tlv.Type],
ExtraOpaqueData, error) {
extraDataTlvMap, err := allTlvData.ExtractRecords(knownRecords...)
if err != nil {
return nil, nil, err
}
// Remove the known and now extracted records from the leftover extra
// data map.
parsedKnownRecords := make(fn.Set[tlv.Type], len(knownRecords))
for _, producer := range knownRecords {
r := producer.Record()
// Only remove the records if it was parsed (remainder is nil).
// We'll just store the type so we can tell the caller which
// records were actually parsed fully.
val, ok := extraDataTlvMap[r.Type()]
if ok && val == nil {
parsedKnownRecords.Add(r.Type())
delete(extraDataTlvMap, r.Type())
}
}
// Encode the remaining records back into the extra data field. These
// records are not in the custom records TLV type range and do not
// have associated fields in the struct that produced the records.
extraData, err := NewExtraOpaqueData(extraDataTlvMap)
if err != nil {
return nil, nil, err
}
return parsedKnownRecords, extraData, nil
}

View File

@@ -881,10 +881,7 @@ func (dp *DynPropose) RandTestMessage(t *rapid.T) Message {
}
}
extraData := RandExtraOpaqueData(t, ignoreRecords)
if len(extraData) > 0 {
msg.ExtraData = extraData
}
msg.ExtraData = RandExtraOpaqueData(t, ignoreRecords)
return msg
}