htlcswitch: split parsing and validation of TLV payloads

When handling blinded errors, we need to know whether there was a
blinding key in our payload when we successfully parsed our payload
but then found an invalid set of fields. The combination of
parsing and validation in NewPayloadFromReader means that we don't know
whether a blinding point was available to us by the time the error is
returned.

This commit splits parsing and validation into two functions so that
we can take a look at what we actually pulled of the payload in between
parsing and TLV validation.
This commit is contained in:
Carla Kirk-Cohen
2024-04-23 11:27:14 -04:00
parent 4d051b4170
commit b81a6f3d2f
4 changed files with 63 additions and 42 deletions

View File

@@ -135,19 +135,23 @@ func fuzzPayload(f *testing.F, finalPayload, updateAddBlinded bool) {
r := bytes.NewReader(data) r := bytes.NewReader(data)
payload1, _, err := NewPayloadFromReader( payload1, parsed, err := ParseTLVPayload(r)
r, finalPayload, updateAddBlinded,
)
if err != nil { if err != nil {
return return
} }
if err = ValidateParsedPayloadTypes(
parsed, finalPayload, updateAddBlinded,
); err != nil {
return
}
var b bytes.Buffer var b bytes.Buffer
hop, nextChanID := hopFromPayload(payload1) hop, nextChanID := hopFromPayload(payload1)
err = hop.PackHopPayload(&b, nextChanID, finalPayload) err = hop.PackHopPayload(&b, nextChanID, finalPayload)
switch { switch {
// PackHopPayload refuses to encode an AMP record // PackHopPayload refuses to encode an AMP record
// without an MPP record. However, NewPayloadFromReader // without an MPP record. However, ValidateParsedPayloadTypes
// does allow decoding an AMP record without an MPP // does allow decoding an AMP record without an MPP
// record, since validation is done at a later stage. Do // record, since validation is done at a later stage. Do
// not report a bug for this case. // not report a bug for this case.
@@ -156,9 +160,9 @@ func fuzzPayload(f *testing.F, finalPayload, updateAddBlinded bool) {
// PackHopPayload will not encode regular payloads or final // PackHopPayload will not encode regular payloads or final
// hops in blinded routes that do not have an amount or expiry // hops in blinded routes that do not have an amount or expiry
// TLV set. However, NewPayloadFromReader will allow creation // TLV set. However, ValidateParsedPayloadTypes will allow
// of payloads where these TLVs are present, but they have // creation of payloads where these TLVs are present, but they
// zero values because validation is done at a later stage. // have zero values because validation is done at a later stage.
case errors.Is(err, route.ErrMissingField): case errors.Is(err, route.ErrMissingField):
return return
@@ -166,8 +170,11 @@ func fuzzPayload(f *testing.F, finalPayload, updateAddBlinded bool) {
require.NoError(t, err) require.NoError(t, err)
} }
payload2, _, err := NewPayloadFromReader( payload2, parsed, err := ParseTLVPayload(&b)
&b, finalPayload, updateAddBlinded, require.NoError(t, err)
err = ValidateParsedPayloadTypes(
parsed, finalPayload, updateAddBlinded,
) )
require.NoError(t, err) require.NoError(t, err)

View File

@@ -112,14 +112,20 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, error) {
// to decode only what we need to make routing decisions. // to decode only what we need to make routing decisions.
case sphinx.PayloadTLV: case sphinx.PayloadTLV:
isFinal := r.processedPacket.Action == sphinx.ExitNode isFinal := r.processedPacket.Action == sphinx.ExitNode
payload, parsed, err := NewPayloadFromReader( payload, parsed, err := ParseTLVPayload(
bytes.NewReader(r.processedPacket.Payload.Payload), bytes.NewReader(r.processedPacket.Payload.Payload),
isFinal, r.blindingKit.UpdateAddBlinding.IsSome(),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := ValidateTLVPayload(
parsed, isFinal,
r.blindingKit.UpdateAddBlinding.IsSome(),
); err != nil {
return nil, err
}
// If we had an encrypted data payload present, pull out our // If we had an encrypted data payload present, pull out our
// forwarding info from the blob. // forwarding info from the blob.
if payload.encryptedData != nil { if payload.encryptedData != nil {

View File

@@ -133,14 +133,10 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload {
} }
} }
// NewPayloadFromReader builds a new Hop from the passed io.Reader and returns // ParseTLVPayload builds a new Hop from the passed io.Reader and returns
// a map of all the types that were found in the payload. The reader // a map of all the types that were found in the payload. This function
// should correspond to the bytes encapsulated in a TLV onion payload. The // does not perform validation of TLV types included in the payload.
// final hop bool signals that this payload was the final packet parsed by func ParseTLVPayload(r io.Reader) (*Payload, map[tlv.Type][]byte, error) {
// sphinx.
func NewPayloadFromReader(r io.Reader, finalHop,
updateAddBlinding bool) (*Payload, map[tlv.Type][]byte, error) {
var ( var (
cid uint64 cid uint64
amt uint64 amt uint64
@@ -175,25 +171,6 @@ func NewPayloadFromReader(r io.Reader, finalHop,
return nil, nil, err return nil, nil, err
} }
// Validate whether the sender properly included or omitted tlv records
// in accordance with BOLT 04.
err = ValidateParsedPayloadTypes(
parsedTypes, finalHop, updateAddBlinding,
)
if err != nil {
return nil, nil, err
}
// Check for violation of the rules for mandatory fields.
violatingType := getMinRequiredViolation(parsedTypes)
if violatingType != nil {
return nil, nil, ErrInvalidPayload{
Type: *violatingType,
Violation: RequiredViolation,
FinalHop: finalHop,
}
}
// If no MPP field was parsed, set the MPP field on the resulting // If no MPP field was parsed, set the MPP field on the resulting
// payload to nil. // payload to nil.
if _, ok := parsedTypes[record.MPPOnionType]; !ok { if _, ok := parsedTypes[record.MPPOnionType]; !ok {
@@ -234,7 +211,34 @@ func NewPayloadFromReader(r io.Reader, finalHop,
blindingPoint: blindingPoint, blindingPoint: blindingPoint,
customRecords: customRecords, customRecords: customRecords,
totalAmtMsat: lnwire.MilliSatoshi(totalAmtMsat), totalAmtMsat: lnwire.MilliSatoshi(totalAmtMsat),
}, nil, nil }, parsedTypes, nil
}
// ValidateTLVPayload validates the TLV fields that were included in a TLV
// payload.
func ValidateTLVPayload(parsedTypes map[tlv.Type][]byte,
finalHop bool, updateAddBlinding bool) error {
// Validate whether the sender properly included or omitted tlv records
// in accordance with BOLT 04.
err := ValidateParsedPayloadTypes(
parsedTypes, finalHop, updateAddBlinding,
)
if err != nil {
return err
}
// Check for violation of the rules for mandatory fields.
violatingType := getMinRequiredViolation(parsedTypes)
if violatingType != nil {
return ErrInvalidPayload{
Type: *violatingType,
Violation: RequiredViolation,
FinalHop: finalHop,
}
}
return nil
} }
// ForwardingInfo returns the basic parameters required for HTLC forwarding, // ForwardingInfo returns the basic parameters required for HTLC forwarding,

View File

@@ -547,9 +547,13 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
testChildIndex = uint32(9) testChildIndex = uint32(9)
) )
p, _, err := hop.NewPayloadFromReader( p, parsedTypes, err := hop.ParseTLVPayload(
bytes.NewReader(test.payload), test.isFinalHop, bytes.NewReader(test.payload),
test.updateAddBlinded, )
require.NoError(t, err)
err = hop.ValidateTLVPayload(
parsedTypes, test.isFinalHop, test.updateAddBlinded,
) )
if !reflect.DeepEqual(test.expErr, err) { if !reflect.DeepEqual(test.expErr, err) {
t.Fatalf("expected error mismatch, want: %v, got: %v", t.Fatalf("expected error mismatch, want: %v, got: %v",