htlcswitch: split parsing and validation of TLV payloads

When handling blinded errors, we need to know whether there was a
blinding key in our payload when we successfully parsed our payload
but then found an invalid set of fields. The combination of
parsing and validation in NewPayloadFromReader means that we don't know
whether a blinding point was available to us by the time the error is
returned.

This commit splits parsing and validation into two functions so that
we can take a look at what we actually pulled of the payload in between
parsing and TLV validation.
This commit is contained in:
Carla Kirk-Cohen
2024-04-23 11:27:14 -04:00
parent 4d051b4170
commit b81a6f3d2f
4 changed files with 63 additions and 42 deletions

View File

@@ -133,14 +133,10 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload {
}
}
// NewPayloadFromReader builds a new Hop from the passed io.Reader and returns
// a map of all the types that were found in the payload. The reader
// should correspond to the bytes encapsulated in a TLV onion payload. The
// final hop bool signals that this payload was the final packet parsed by
// sphinx.
func NewPayloadFromReader(r io.Reader, finalHop,
updateAddBlinding bool) (*Payload, map[tlv.Type][]byte, error) {
// ParseTLVPayload builds a new Hop from the passed io.Reader and returns
// a map of all the types that were found in the payload. This function
// does not perform validation of TLV types included in the payload.
func ParseTLVPayload(r io.Reader) (*Payload, map[tlv.Type][]byte, error) {
var (
cid uint64
amt uint64
@@ -175,25 +171,6 @@ func NewPayloadFromReader(r io.Reader, finalHop,
return nil, nil, err
}
// Validate whether the sender properly included or omitted tlv records
// in accordance with BOLT 04.
err = ValidateParsedPayloadTypes(
parsedTypes, finalHop, updateAddBlinding,
)
if err != nil {
return nil, nil, err
}
// Check for violation of the rules for mandatory fields.
violatingType := getMinRequiredViolation(parsedTypes)
if violatingType != nil {
return nil, nil, ErrInvalidPayload{
Type: *violatingType,
Violation: RequiredViolation,
FinalHop: finalHop,
}
}
// If no MPP field was parsed, set the MPP field on the resulting
// payload to nil.
if _, ok := parsedTypes[record.MPPOnionType]; !ok {
@@ -234,7 +211,34 @@ func NewPayloadFromReader(r io.Reader, finalHop,
blindingPoint: blindingPoint,
customRecords: customRecords,
totalAmtMsat: lnwire.MilliSatoshi(totalAmtMsat),
}, nil, nil
}, parsedTypes, nil
}
// ValidateTLVPayload validates the TLV fields that were included in a TLV
// payload.
func ValidateTLVPayload(parsedTypes map[tlv.Type][]byte,
finalHop bool, updateAddBlinding bool) error {
// Validate whether the sender properly included or omitted tlv records
// in accordance with BOLT 04.
err := ValidateParsedPayloadTypes(
parsedTypes, finalHop, updateAddBlinding,
)
if err != nil {
return err
}
// Check for violation of the rules for mandatory fields.
violatingType := getMinRequiredViolation(parsedTypes)
if violatingType != nil {
return ErrInvalidPayload{
Type: *violatingType,
Violation: RequiredViolation,
FinalHop: finalHop,
}
}
return nil
}
// ForwardingInfo returns the basic parameters required for HTLC forwarding,