From 019b8fa8aacd07c8b6944d787b3ced0a1784a268 Mon Sep 17 00:00:00 2001 From: Carla Kirk-Cohen Date: Wed, 14 Dec 2022 15:00:37 -0500 Subject: [PATCH] hop: add function for calculating forwarding amount Co-authored-by: Calvin Zachman --- htlcswitch/hop/iterator.go | 45 ++++++++++++++++++++++++++++ htlcswitch/hop/iterator_test.go | 53 +++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index f2f728038..095fd190d 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -114,6 +114,51 @@ func (r *sphinxHopIterator) ExtractErrorEncrypter( return extracter(r.ogPacket.EphemeralKey) } +// calculateForwardingAmount calculates the amount to forward for a blinded +// hop based on the incoming amount and forwarding parameters. +// +// When forwarding a payment, the fee we take is calculated, not on the +// incoming amount, but rather on the amount we forward. We charge fees based +// on our own liquidity we are forwarding downstream. +// +// With route blinding, we are NOT given the amount to forward. This +// unintuitive looking formula comes from the fact that without the amount to +// forward, we cannot compute the fees taken directly. +// +// The amount to be forwarded can be computed as follows: +// +// amt_to_forward = incoming_amount - total_fees +// total_fees = base_fee + amt_to_forward*(fee_rate/1000000) +// +// Solving for amount_to_forward: +// amt_to_forward = incoming_amount - base_fee - (amount_to_forward * fee_rate)/1e6 +// amt_to_forward + (amount_to_forward * fee_rate) / 1e6 = incoming_amount - base_fee +// amt_to_forward * 1e6 + (amount_to_forward * fee_rate) = (incoming_amount - base_fee) * 1e6 +// amt_to_forward * (1e6 + fee_rate) = (incoming_amount - base_fee) * 1e6 +// amt_to_forward = ((incoming_amount - base_fee) * 1e6) / (1e6 + fee_rate) +// +// From there we use a ceiling formula for integer division so that we always +// round up, otherwise the sender may receive slightly less than intended: +// +// ceil(a/b) = (a + b - 1)/(b). +// +//nolint:lll,dupword +func calculateForwardingAmount(incomingAmount lnwire.MilliSatoshi, baseFee, + proportionalFee uint32) (lnwire.MilliSatoshi, error) { + + // Sanity check to prevent overflow. + if incomingAmount < lnwire.MilliSatoshi(baseFee) { + return 0, fmt.Errorf("incoming amount: %v < base fee: %v", + incomingAmount, baseFee) + } + numerator := (uint64(incomingAmount) - uint64(baseFee)) * 1e6 + denominator := 1e6 + uint64(proportionalFee) + + ceiling := (numerator + denominator - 1) / denominator + + return lnwire.MilliSatoshi(ceiling), nil +} + // OnionProcessor is responsible for keeping all sphinx dependent parts inside // and expose only decoding function. With such approach we give freedom for // subsystems which wants to decode sphinx path to not be dependable from diff --git a/htlcswitch/hop/iterator_test.go b/htlcswitch/hop/iterator_test.go index cb2a2816f..74eb60a7a 100644 --- a/htlcswitch/hop/iterator_test.go +++ b/htlcswitch/hop/iterator_test.go @@ -100,3 +100,56 @@ func TestSphinxHopIteratorForwardingInstructions(t *testing.T) { } } } + +// TestForwardingAmountCalc tests calculation of forwarding amounts from the +// hop's forwarding parameters. +func TestForwardingAmountCalc(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + incomingAmount lnwire.MilliSatoshi + baseFee uint32 + proportional uint32 + forwardAmount lnwire.MilliSatoshi + expectErr bool + }{ + { + name: "overflow", + incomingAmount: 10, + baseFee: 100, + expectErr: true, + }, + { + name: "trivial proportional", + incomingAmount: 100_000, + baseFee: 1000, + proportional: 10, + forwardAmount: 99000, + }, + { + name: "both fees charged", + incomingAmount: 10_002_020, + baseFee: 1000, + proportional: 1, + forwardAmount: 10_001_010, + }, + } + + for _, testCase := range tests { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + actual, err := calculateForwardingAmount( + testCase.incomingAmount, testCase.baseFee, + testCase.proportional, + ) + + require.Equal(t, testCase.expectErr, err != nil) + require.Equal(t, testCase.forwardAmount.ToSatoshis(), + actual.ToSatoshis()) + }) + } +}