diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index aca562b03..0a6065c82 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -202,20 +202,135 @@ func extractTLVPayload(r *sphinxHopIterator) (*Payload, RouteRole, error) { return payload, routeRole, nil } - // If we had an encrypted data payload present, pull out our forwarding - // info from the blob. - fwdInfo, err := r.blindingKit.DecryptAndValidateFwdInfo( - payload, isFinal, + return parseAndValidateRecipientData(r, payload, isFinal, routeRole) +} + +// parseAndValidateRecipientData decrypts the payload from the recipient and +// then continues handling and validation based on if we are a forwarding node +// in this blinded path or the final destination node. +func parseAndValidateRecipientData(r *sphinxHopIterator, payload *Payload, + isFinal bool, routeRole RouteRole) (*Payload, RouteRole, error) { + + // Decrypt and validate the blinded route data + routeData, blindingPoint, err := decryptAndValidateBlindedRouteData( + r, payload, ) if err != nil { return nil, routeRole, err } - payload.FwdInfo = *fwdInfo + // Exit early if this onion is for the exit hop of the route since + // route blinding receives are not yet supported. + if isFinal { + return nil, routeRole, fmt.Errorf("being the final hop in a " + + "blinded path is not yet supported") + } + + // Else, we are a forwarding node in this blinded path. + return deriveBlindedRouteForwardingInfo( + r, routeData, payload, routeRole, blindingPoint, + ) +} + +// deriveBlindedRouteForwardingInfo uses the parsed BlindedRouteData from the +// recipient to derive the ForwardingInfo for the payment. +func deriveBlindedRouteForwardingInfo(r *sphinxHopIterator, + routeData *record.BlindedRouteData, payload *Payload, + routeRole RouteRole, blindingPoint *btcec.PublicKey) (*Payload, + RouteRole, error) { + + relayInfo, err := routeData.RelayInfo.UnwrapOrErr( + fmt.Errorf("relay info not set for non-final blinded hop"), + ) + if err != nil { + return nil, routeRole, err + } + + nextSCID, err := routeData.ShortChannelID.UnwrapOrErr( + fmt.Errorf("next SCID not set for non-final blinded hop"), + ) + if err != nil { + return nil, routeRole, err + } + + fwdAmt, err := calculateForwardingAmount( + r.blindingKit.IncomingAmount, relayInfo.Val.BaseFee, + relayInfo.Val.FeeRate, + ) + if err != nil { + return nil, routeRole, err + } + + nextEph, err := routeData.NextBlindingOverride.UnwrapOrFuncErr( + func() (tlv.RecordT[tlv.TlvType8, *btcec.PublicKey], error) { + next, err := r.blindingKit.Processor.NextEphemeral( + blindingPoint, + ) + if err != nil { + return routeData.NextBlindingOverride.Zero(), + err + } + + return tlv.NewPrimitiveRecord[tlv.TlvType8](next), nil + }) + if err != nil { + return nil, routeRole, err + } + + payload.FwdInfo = ForwardingInfo{ + NextHop: nextSCID.Val, + AmountToForward: fwdAmt, + OutgoingCTLV: r.blindingKit.IncomingCltv - uint32( + relayInfo.Val.CltvExpiryDelta, + ), + // Remap from blinding override type to blinding point type. + NextBlinding: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + nextEph.Val, + ), + ), + } return payload, routeRole, nil } +// decryptAndValidateBlindedRouteData decrypts the encrypted payload from the +// payment recipient using a blinding key. The incoming HTLC amount and CLTV +// values are then verified against the policy values from the recipient. +func decryptAndValidateBlindedRouteData(r *sphinxHopIterator, + payload *Payload) (*record.BlindedRouteData, *btcec.PublicKey, error) { + + blindingPoint, err := r.blindingKit.getBlindingPoint( + payload.blindingPoint, + ) + if err != nil { + return nil, nil, err + } + + decrypted, err := r.blindingKit.Processor.DecryptBlindedHopData( + blindingPoint, payload.encryptedData, + ) + if err != nil { + return nil, nil, fmt.Errorf("decrypt blinded data: %w", err) + } + + buf := bytes.NewBuffer(decrypted) + routeData, err := record.DecodeBlindedRouteData(buf) + if err != nil { + return nil, nil, fmt.Errorf("%w: %w", ErrDecodeFailed, err) + } + + err = ValidateBlindedRouteData( + routeData, r.blindingKit.IncomingAmount, + r.blindingKit.IncomingCltv, + ) + if err != nil { + return nil, nil, err + } + + return routeData, blindingPoint, nil +} + // parseAndValidateSenderPayload parses the payload bytes received from the // onion constructor (the sender) and validates that various fields have been // set. It also uses the presence of a blinding key in either the @@ -367,110 +482,6 @@ func (b *BlindingKit) getBlindingPoint(payloadBlinding *btcec.PublicKey) ( } } -// DecryptAndValidateFwdInfo performs all operations required to decrypt and -// validate a blinded route. -func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload, - isFinalHop bool) (*ForwardingInfo, error) { - - // We expect this function to be called when we have encrypted data - // present, and expect validation to already have ensured that a - // blinding key is set either in the payload or the - // update_add_htlc message. - blindingPoint, err := b.getBlindingPoint(payload.blindingPoint) - if err != nil { - return nil, err - } - - decrypted, err := b.Processor.DecryptBlindedHopData( - blindingPoint, payload.encryptedData, - ) - if err != nil { - return nil, fmt.Errorf("decrypt blinded "+ - "data: %w", err) - } - - buf := bytes.NewBuffer(decrypted) - routeData, err := record.DecodeBlindedRouteData(buf) - if err != nil { - return nil, fmt.Errorf("%w: %w", - ErrDecodeFailed, err) - } - - // Validate the data in the blinded route against our incoming htlc's - // information. - if err := ValidateBlindedRouteData( - routeData, b.IncomingAmount, b.IncomingCltv, - ); err != nil { - return nil, err - } - - // Exit early if this onion is for the exit hop of the route since - // route blinding receives are not yet supported. - if isFinalHop { - return nil, fmt.Errorf("being the final hop in a blinded " + - "path is not yet supported") - } - - // At this point, we know we are a forwarding node for this onion - // and so we expect the relay info and next SCID fields to be set. - relayInfo, err := routeData.RelayInfo.UnwrapOrErr( - fmt.Errorf("relay info not set for non-final blinded hop"), - ) - if err != nil { - return nil, err - } - - nextSCID, err := routeData.ShortChannelID.UnwrapOrErr( - fmt.Errorf("next SCID not set for non-final blinded hop"), - ) - if err != nil { - return nil, err - } - - fwdAmt, err := calculateForwardingAmount( - b.IncomingAmount, relayInfo.Val.BaseFee, relayInfo.Val.FeeRate, - ) - if err != nil { - return nil, err - } - - // If we have an override for the blinding point for the next node, - // we'll just use it without tweaking (the sender intended to switch - // out directly for this blinding point). Otherwise, we'll tweak our - // blinding point to get the next ephemeral key. - nextEph, err := routeData.NextBlindingOverride.UnwrapOrFuncErr( - func() (tlv.RecordT[tlv.TlvType8, - *btcec.PublicKey], error) { - - next, err := b.Processor.NextEphemeral(blindingPoint) - if err != nil { - // Return a zero record because we expect the - // error to be checked. - return routeData.NextBlindingOverride.Zero(), - err - } - - return tlv.NewPrimitiveRecord[tlv.TlvType8](next), nil - }, - ) - if err != nil { - return nil, err - } - - return &ForwardingInfo{ - NextHop: nextSCID.Val, - AmountToForward: fwdAmt, - OutgoingCTLV: b.IncomingCltv - uint32( - relayInfo.Val.CltvExpiryDelta, - ), - // Remap from blinding override type to blinding point type. - NextBlinding: tlv.SomeRecordT( - tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( - nextEph.Val), - ), - }, nil -} - // calculateForwardingAmount calculates the amount to forward for a blinded // hop based on the incoming amount and forwarding parameters. // diff --git a/htlcswitch/hop/iterator_test.go b/htlcswitch/hop/iterator_test.go index 62df54412..74be6b190 100644 --- a/htlcswitch/hop/iterator_test.go +++ b/htlcswitch/hop/iterator_test.go @@ -177,11 +177,11 @@ func (m *mockProcessor) NextEphemeral(*btcec.PublicKey) (*btcec.PublicKey, return nil, nil } -// TestDecryptAndValidateFwdInfo tests deriving forwarding info using a +// TestParseAndValidateRecipientData tests deriving forwarding info using a // blinding kit. This test does not cover assertions on the calculations of // forwarding information, because this is covered in a test dedicated to those // calculations. -func TestDecryptAndValidateFwdInfo(t *testing.T) { +func TestParseAndValidateRecipientData(t *testing.T) { t.Parallel() // Encode valid blinding data that we'll fake decrypting for our test. @@ -282,11 +282,15 @@ func TestDecryptAndValidateFwdInfo(t *testing.T) { tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](testCase.updateAddBlinding), ) } - _, err := kit.DecryptAndValidateFwdInfo( - &Payload{ + iterator := &sphinxHopIterator{ + blindingKit: kit, + } + + _, _, err = parseAndValidateRecipientData( + iterator, &Payload{ encryptedData: testCase.data, blindingPoint: testCase.payloadBlinding, - }, false, + }, false, RouteRoleCleartext, ) require.ErrorIs(t, err, testCase.expectedErr) })