From 7db072e0207c86ed33b63ee07189ca41618d8226 Mon Sep 17 00:00:00 2001 From: Carla Kirk-Cohen Date: Wed, 1 Nov 2023 09:54:17 -0400 Subject: [PATCH] routing: add additional validation to hop payload creation --- routing/route/route.go | 98 ++++++++++++++++++++++++++++++++++--- routing/route/route_test.go | 2 +- 2 files changed, 93 insertions(+), 7 deletions(-) diff --git a/routing/route/route.go b/routing/route/route.go index 47c305ba0..2ed9751ae 100644 --- a/routing/route/route.go +++ b/routing/route/route.go @@ -38,6 +38,13 @@ var ( // ErrAMPMissingMPP is returned when the caller tries to attach an AMP // record but no MPP record is presented for the final hop. ErrAMPMissingMPP = errors.New("cannot send AMP without MPP record") + + // ErrMissingField is returned if a required TLV is missing. + ErrMissingField = errors.New("required tlv missing") + + // ErrIncorrectField is returned if a tlv field is included when it + // should not be. + ErrIncorrectField = errors.New("incorrect tlv included") ) // Vertex is a simple alias for the serialization of a compressed Bitcoin @@ -193,8 +200,28 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64, var records []tlv.Record // Hops that are not part of a blinded path will have an amount and - // a CLTV expiry field. Zero values indicate that the hop is inside of - // a blinded route, so the TLV should not be included. + // a CLTV expiry field. In a blinded route (where encrypted data is + // non-nil), these values may be omitted for intermediate nodes. + // Validate these fields against the structure of the payload so that + // we know they're included (or excluded) correctly. + isBlinded := h.EncryptedData != nil + + if err := optionalBlindedField( + h.AmtToForward == 0, isBlinded, finalHop, + ); err != nil { + return fmt.Errorf("%w: amount to forward: %v", err, + h.AmtToForward) + } + + if err := optionalBlindedField( + h.OutgoingTimeLock == 0, isBlinded, finalHop, + ); err != nil { + return fmt.Errorf("%w: outgoing timelock: %v", err, + h.OutgoingTimeLock) + } + + // Once we've validated that these TLVs are set as we expect, we can + // go ahead and include them if non-zero. amt := uint64(h.AmtToForward) if amt != 0 { records = append( @@ -208,10 +235,13 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64, ) } - // BOLT 04 says the next_hop_id should be omitted for the final hop, - // but present for all others. - // - // TODO(conner): test using hop.Exit once available + // Validate channel TLV is present as expected based on location in + // route and whether this hop is blinded. + err := validateNextChanID(nextChanID != 0, isBlinded, finalHop) + if err != nil { + return fmt.Errorf("%w: channel id: %v", err, nextChanID) + } + if nextChanID != 0 { records = append(records, record.NewNextHopIDRecord(&nextChanID), @@ -284,6 +314,62 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64, return tlvStream.Encode(w) } +// optionalBlindedField validates fields that we expect to be non-zero for all +// hops in a regular route, but may be zero for intermediate nodes in a blinded +// route. It will validate the following cases: +// - Not blinded: require non-zero values. +// - Intermediate blinded node: require zero values. +// - Final blinded node: require non-zero values. +func optionalBlindedField(isZero, blindedHop, finalHop bool) error { + switch { + // We are not in a blinded route and the TLV is not set when it should + // be. + case !blindedHop && isZero: + return ErrMissingField + + // We are not in a blinded route and the TLV is set as expected. + case !blindedHop: + return nil + + // In a blinded route the final hop is expected to have TLV values set. + case finalHop && isZero: + return ErrMissingField + + // In an intermediate hop in a blinded route and the field is not zero. + case !finalHop && !isZero: + return ErrIncorrectField + } + + return nil +} + +// validateNextChanID validates the presence of the nextChanID TLV field in +// a payload. For regular payments, it is expected to be present for all hops +// except the final hop. For blinded paths, it is not expected to be included +// at all (as this value is provided in encrypted data). +func validateNextChanID(nextChanIDIsSet, isBlinded, finalHop bool) error { + switch { + // Hops in a blinded route should not have a next channel ID set. + case isBlinded && nextChanIDIsSet: + return ErrIncorrectField + + // Otherwise, blinded hops are allowed to have a zero value. + case isBlinded: + return nil + + // The final hop in a regular route is expected to have a zero value. + case finalHop && nextChanIDIsSet: + return ErrIncorrectField + + // Intermediate hops in regular routes require non-zero value. + case !finalHop && !nextChanIDIsSet: + return ErrMissingField + + default: + return nil + } +} + // Size returns the total size this hop's payload would take up in the onion // packet. func (h *Hop) PayloadSize(nextChanID uint64) uint64 { diff --git a/routing/route/route_test.go b/routing/route/route_test.go index d73233e00..e5202d56a 100644 --- a/routing/route/route_test.go +++ b/routing/route/route_test.go @@ -170,7 +170,7 @@ func TestNoForwardingParams(t *testing.T) { } var b bytes.Buffer - err := hop.PackHopPayload(&b, 2, false) + err := hop.PackHopPayload(&b, 0, false) require.NoError(t, err) }