diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index 5ef870874..4d9ee95b0 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -173,53 +173,7 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, RouteRole, error) { // Otherwise, if this is the TLV payload, then we'll make a new stream // to decode only what we need to make routing decisions. case sphinx.PayloadTLV: - isFinal := r.processedPacket.Action == sphinx.ExitNode - payload, parsed, err := ParseTLVPayload( - bytes.NewReader(r.processedPacket.Payload.Payload), - ) - if err != nil { - // If we couldn't even parse our payload then we do - // a best-effort of determining our role in a blinded - // route, accepting that we can't know whether we - // were the introduction node (as the payload - // is not parseable). - routeRole := RouteRoleCleartext - if r.blindingKit.UpdateAddBlinding.IsSome() { - routeRole = RouteRoleRelaying - } - - return nil, routeRole, err - } - - // Now that we've parsed our payload we can determine which - // role we're playing in the route. - _, payloadBlinding := parsed[record.BlindingPointOnionType] - routeRole := NewRouteRole( - r.blindingKit.UpdateAddBlinding.IsSome(), - payloadBlinding, - ) - - if err := ValidateTLVPayload( - parsed, isFinal, - r.blindingKit.UpdateAddBlinding.IsSome(), - ); err != nil { - return nil, routeRole, err - } - - // If we had an encrypted data payload present, pull out our - // forwarding info from the blob. - if payload.encryptedData != nil { - fwdInfo, err := r.blindingKit.DecryptAndValidateFwdInfo( - payload, isFinal, parsed, - ) - if err != nil { - return nil, routeRole, err - } - - payload.FwdInfo = *fwdInfo - } - - return payload, routeRole, nil + return extractTLVPayload(r) default: return nil, RouteRoleCleartext, @@ -228,6 +182,73 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, RouteRole, error) { } } +// extractTLVPayload parses the hop payload and assumes that it uses the TLV +// format. It returns the parsed payload along with the RouteRole that this hop +// plays given the contents of the payload. +func extractTLVPayload(r *sphinxHopIterator) (*Payload, RouteRole, error) { + isFinal := r.processedPacket.Action == sphinx.ExitNode + + // Extract TLVs from the packet constructor (the sender). + payload, parsed, err := ParseTLVPayload( + bytes.NewReader(r.processedPacket.Payload.Payload), + ) + if err != nil { + // If we couldn't even parse our payload then we do a + // best-effort of determining our role in a blinded route, + // accepting that we can't know whether we were the introduction + // node (as the payload is not parseable). + routeRole := RouteRoleCleartext + if r.blindingKit.UpdateAddBlinding.IsSome() { + routeRole = RouteRoleRelaying + } + + return nil, routeRole, err + } + + // Now that we've parsed our payload we can determine which role we're + // playing in the route. + _, payloadBlinding := parsed[record.BlindingPointOnionType] + routeRole := NewRouteRole( + r.blindingKit.UpdateAddBlinding.IsSome(), payloadBlinding, + ) + + // Validate the presence of the various payload fields we received from + // the sender. + if err := ValidateTLVPayload( + parsed, isFinal, r.blindingKit.UpdateAddBlinding.IsSome(), + ); err != nil { + return nil, routeRole, err + } + + // If there is no encrypted data from the receiver then return the + // payload as is since the forwarding info would have been received + // from the sender. + if payload.encryptedData != nil { + return payload, routeRole, nil + } + + // Validate the presence of various fields in the sender payload given + // that we now know that this is a hop with instructions from the + // recipient. + err = ValidatePayloadWithBlinded(isFinal, parsed) + if err != nil { + return payload, routeRole, err + } + + // If we had an encrypted data payload present, pull out our forwarding + // info from the blob. + fwdInfo, err := r.blindingKit.DecryptAndValidateFwdInfo( + payload, isFinal, + ) + if err != nil { + return nil, routeRole, err + } + + payload.FwdInfo = *fwdInfo + + return payload, routeRole, nil +} + // ExtractErrorEncrypter decodes and returns the ErrorEncrypter for this hop, // along with a failure code to signal if the decoding was successful. The // ErrorEncrypter is used to encrypt errors back to the sender in the event that @@ -327,8 +348,7 @@ func (b *BlindingKit) getBlindingPoint(payloadBlinding *btcec.PublicKey) ( // DecryptAndValidateFwdInfo performs all operations required to decrypt and // validate a blinded route. func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload, - isFinalHop bool, payloadParsed map[tlv.Type][]byte) ( - *ForwardingInfo, error) { + isFinalHop bool) (*ForwardingInfo, error) { // We expect this function to be called when we have encrypted data // present, and expect validation to already have ensured that a @@ -354,13 +374,6 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload, ErrDecodeFailed, err) } - // Validate the contents of the payload against the values we've - // just pulled out of the encrypted data blob. - err = ValidatePayloadWithBlinded(isFinalHop, payloadParsed) - if err != nil { - return nil, err - } - // Validate the data in the blinded route against our incoming htlc's // information. if err := ValidateBlindedRouteData( diff --git a/htlcswitch/hop/iterator_test.go b/htlcswitch/hop/iterator_test.go index d99f1de24..62df54412 100644 --- a/htlcswitch/hop/iterator_test.go +++ b/htlcswitch/hop/iterator_test.go @@ -287,7 +287,6 @@ func TestDecryptAndValidateFwdInfo(t *testing.T) { encryptedData: testCase.data, blindingPoint: testCase.payloadBlinding, }, false, - make(map[tlv.Type][]byte), ) require.ErrorIs(t, err, testCase.expectedErr) })