mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-27 18:22:24 +01:00
htlcswitch: split parsing and validation of TLV payloads
When handling blinded errors, we need to know whether there was a blinding key in our payload when we successfully parsed our payload but then found an invalid set of fields. The combination of parsing and validation in NewPayloadFromReader means that we don't know whether a blinding point was available to us by the time the error is returned. This commit splits parsing and validation into two functions so that we can take a look at what we actually pulled of the payload in between parsing and TLV validation.
This commit is contained in:
parent
4d051b4170
commit
b81a6f3d2f
@ -135,19 +135,23 @@ func fuzzPayload(f *testing.F, finalPayload, updateAddBlinded bool) {
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
payload1, _, err := NewPayloadFromReader(
|
||||
r, finalPayload, updateAddBlinded,
|
||||
)
|
||||
payload1, parsed, err := ParseTLVPayload(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = ValidateParsedPayloadTypes(
|
||||
parsed, finalPayload, updateAddBlinded,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
hop, nextChanID := hopFromPayload(payload1)
|
||||
err = hop.PackHopPayload(&b, nextChanID, finalPayload)
|
||||
switch {
|
||||
// PackHopPayload refuses to encode an AMP record
|
||||
// without an MPP record. However, NewPayloadFromReader
|
||||
// without an MPP record. However, ValidateParsedPayloadTypes
|
||||
// does allow decoding an AMP record without an MPP
|
||||
// record, since validation is done at a later stage. Do
|
||||
// not report a bug for this case.
|
||||
@ -156,9 +160,9 @@ func fuzzPayload(f *testing.F, finalPayload, updateAddBlinded bool) {
|
||||
|
||||
// PackHopPayload will not encode regular payloads or final
|
||||
// hops in blinded routes that do not have an amount or expiry
|
||||
// TLV set. However, NewPayloadFromReader will allow creation
|
||||
// of payloads where these TLVs are present, but they have
|
||||
// zero values because validation is done at a later stage.
|
||||
// TLV set. However, ValidateParsedPayloadTypes will allow
|
||||
// creation of payloads where these TLVs are present, but they
|
||||
// have zero values because validation is done at a later stage.
|
||||
case errors.Is(err, route.ErrMissingField):
|
||||
return
|
||||
|
||||
@ -166,8 +170,11 @@ func fuzzPayload(f *testing.F, finalPayload, updateAddBlinded bool) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
payload2, _, err := NewPayloadFromReader(
|
||||
&b, finalPayload, updateAddBlinded,
|
||||
payload2, parsed, err := ParseTLVPayload(&b)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ValidateParsedPayloadTypes(
|
||||
parsed, finalPayload, updateAddBlinded,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -112,14 +112,20 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, error) {
|
||||
// to decode only what we need to make routing decisions.
|
||||
case sphinx.PayloadTLV:
|
||||
isFinal := r.processedPacket.Action == sphinx.ExitNode
|
||||
payload, parsed, err := NewPayloadFromReader(
|
||||
payload, parsed, err := ParseTLVPayload(
|
||||
bytes.NewReader(r.processedPacket.Payload.Payload),
|
||||
isFinal, r.blindingKit.UpdateAddBlinding.IsSome(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := ValidateTLVPayload(
|
||||
parsed, isFinal,
|
||||
r.blindingKit.UpdateAddBlinding.IsSome(),
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If we had an encrypted data payload present, pull out our
|
||||
// forwarding info from the blob.
|
||||
if payload.encryptedData != nil {
|
||||
|
@ -133,14 +133,10 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload {
|
||||
}
|
||||
}
|
||||
|
||||
// NewPayloadFromReader builds a new Hop from the passed io.Reader and returns
|
||||
// a map of all the types that were found in the payload. The reader
|
||||
// should correspond to the bytes encapsulated in a TLV onion payload. The
|
||||
// final hop bool signals that this payload was the final packet parsed by
|
||||
// sphinx.
|
||||
func NewPayloadFromReader(r io.Reader, finalHop,
|
||||
updateAddBlinding bool) (*Payload, map[tlv.Type][]byte, error) {
|
||||
|
||||
// ParseTLVPayload builds a new Hop from the passed io.Reader and returns
|
||||
// a map of all the types that were found in the payload. This function
|
||||
// does not perform validation of TLV types included in the payload.
|
||||
func ParseTLVPayload(r io.Reader) (*Payload, map[tlv.Type][]byte, error) {
|
||||
var (
|
||||
cid uint64
|
||||
amt uint64
|
||||
@ -175,25 +171,6 @@ func NewPayloadFromReader(r io.Reader, finalHop,
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Validate whether the sender properly included or omitted tlv records
|
||||
// in accordance with BOLT 04.
|
||||
err = ValidateParsedPayloadTypes(
|
||||
parsedTypes, finalHop, updateAddBlinding,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Check for violation of the rules for mandatory fields.
|
||||
violatingType := getMinRequiredViolation(parsedTypes)
|
||||
if violatingType != nil {
|
||||
return nil, nil, ErrInvalidPayload{
|
||||
Type: *violatingType,
|
||||
Violation: RequiredViolation,
|
||||
FinalHop: finalHop,
|
||||
}
|
||||
}
|
||||
|
||||
// If no MPP field was parsed, set the MPP field on the resulting
|
||||
// payload to nil.
|
||||
if _, ok := parsedTypes[record.MPPOnionType]; !ok {
|
||||
@ -234,7 +211,34 @@ func NewPayloadFromReader(r io.Reader, finalHop,
|
||||
blindingPoint: blindingPoint,
|
||||
customRecords: customRecords,
|
||||
totalAmtMsat: lnwire.MilliSatoshi(totalAmtMsat),
|
||||
}, nil, nil
|
||||
}, parsedTypes, nil
|
||||
}
|
||||
|
||||
// ValidateTLVPayload validates the TLV fields that were included in a TLV
|
||||
// payload.
|
||||
func ValidateTLVPayload(parsedTypes map[tlv.Type][]byte,
|
||||
finalHop bool, updateAddBlinding bool) error {
|
||||
|
||||
// Validate whether the sender properly included or omitted tlv records
|
||||
// in accordance with BOLT 04.
|
||||
err := ValidateParsedPayloadTypes(
|
||||
parsedTypes, finalHop, updateAddBlinding,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for violation of the rules for mandatory fields.
|
||||
violatingType := getMinRequiredViolation(parsedTypes)
|
||||
if violatingType != nil {
|
||||
return ErrInvalidPayload{
|
||||
Type: *violatingType,
|
||||
Violation: RequiredViolation,
|
||||
FinalHop: finalHop,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForwardingInfo returns the basic parameters required for HTLC forwarding,
|
||||
|
@ -547,9 +547,13 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
||||
testChildIndex = uint32(9)
|
||||
)
|
||||
|
||||
p, _, err := hop.NewPayloadFromReader(
|
||||
bytes.NewReader(test.payload), test.isFinalHop,
|
||||
test.updateAddBlinded,
|
||||
p, parsedTypes, err := hop.ParseTLVPayload(
|
||||
bytes.NewReader(test.payload),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = hop.ValidateTLVPayload(
|
||||
parsedTypes, test.isFinalHop, test.updateAddBlinded,
|
||||
)
|
||||
if !reflect.DeepEqual(test.expErr, err) {
|
||||
t.Fatalf("expected error mismatch, want: %v, got: %v",
|
||||
|
Loading…
x
Reference in New Issue
Block a user