diff --git a/routing/router.go b/routing/router.go index 1a5c85397..d67fc2ebd 100644 --- a/routing/router.go +++ b/routing/router.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "math" + "math/big" "sort" "sync" "sync/atomic" @@ -836,6 +837,7 @@ func generateSphinxPacket(rt *route.Route, paymentHash []byte, hopCopy := sphinxPath[i] path[i] = hopCopy } + return spew.Sdump(path) }), ) @@ -1670,3 +1672,224 @@ func receiverAmtForwardPass(runningAmt lnwire.MilliSatoshi, return runningAmt, nil } + +// incomingFromOutgoing computes the incoming amount based on the outgoing +// amount by adding fees to the outgoing amount, replicating the path finding +// and routing process, see also CheckHtlcForward. +func incomingFromOutgoing(outgoingAmt lnwire.MilliSatoshi, + incoming, outgoing *unifiedEdge) lnwire.MilliSatoshi { + + outgoingFee := outgoing.policy.ComputeFee(outgoingAmt) + + // Net amount is the amount the inbound fees are calculated with. + netAmount := outgoingAmt + outgoingFee + + inboundFee := incoming.inboundFees.CalcFee(netAmount) + + // The inbound fee is not allowed to reduce the incoming amount below + // the outgoing amount. + if int64(outgoingFee)+inboundFee < 0 { + return outgoingAmt + } + + return netAmount + lnwire.MilliSatoshi(inboundFee) +} + +// outgoingFromIncoming computes the outgoing amount based on the incoming +// amount by subtracting fees from the incoming amount. Note that this is not +// exactly the inverse of incomingFromOutgoing, because of some rounding. +func outgoingFromIncoming(incomingAmt lnwire.MilliSatoshi, + incoming, outgoing *unifiedEdge) lnwire.MilliSatoshi { + + // Convert all quantities to big.Int to be able to hande negative + // values. The formulas to compute the outgoing amount involve terms + // with PPM*PPM*A, which can easily overflow an int64. + A := big.NewInt(int64(incomingAmt)) + Ro := big.NewInt(int64(outgoing.policy.FeeProportionalMillionths)) + Bo := big.NewInt(int64(outgoing.policy.FeeBaseMSat)) + Ri := big.NewInt(int64(incoming.inboundFees.Rate)) + Bi := big.NewInt(int64(incoming.inboundFees.Base)) + PPM := big.NewInt(1_000_000) + + // The following discussion was contributed by user feelancer21, see + //nolint:lll + // https://github.com/feelancer21/lnd/commit/f6f05fa930985aac0d27c3f6681aada1b599162a. + + // The incoming amount Ai based on the outgoing amount Ao is computed by + // Ai = max(Ai(Ao), Ao), which caps the incoming amount such that the + // total node fee (Ai - Ao) is non-negative. This is commonly enforced + // by routing nodes. + + // The function Ai(Ao) is given by: + // Ai(Ao) = (Ao + Bo + Ro/PPM) + (Bi + (Ao + Ro/PPM + Bo)*Ri/PPM), where + // the first term is the net amount (the outgoing amount plus the + // outbound fee), and the second is the inbound fee computed based on + // the net amount. + + // Ai(Ao) can potentially become more negative in absolute value than + // Ao, which is why the above mentioned capping is needed. We can + // abbreviate Ai(Ao) with Ai(Ao) = m*Ao + n, where m and n are: + // m := (1 + Ro/PPM) * (1 + Ri/PPM) + // n := Bi + Bo*(1 + Ri/PPM) + + // If we know that m > 0, this is equivalent of Ri/PPM > -1, because Ri + // is the only factor that can become negative. A value or Ri/PPM = -1, + // means that the routing node is willing to give up on 100% of the + // net amount (based on the fee rate), which is likely to not happen in + // practice. This condition will be important for a later trick. + + // If we want to compute the incoming amount based on the outgoing + // amount, which is the reverse problem, we need to solve Ai = + // max(Ai(Ao), Ao) for Ao(Ai). Given an incoming amount A, + // we look for an Ao such that A = max(Ai(Ao), Ao). + + // The max function separates this into two cases. The case to take is + // not clear yet, because we don't know Ao, but later we see a trick + // how to determine which case is the one to take. + + // first case: Ai(Ao) <= Ao: + // Therefore, A = max(Ai(Ao), Ao) = Ao, we find Ao = A. + // This also leads to Ai(A) <= A by substitution into the condition. + + // second case: Ai(Ao) > Ao: + // Therefore, A = max(Ai(Ao), Ao) = Ai(Ao) = m*Ao + n. Solving for Ao + // gives Ao = (A - n)/m. + // + // We know + // Ai(Ao) > Ao <=> A = Ai(Ao) > Ao = (A - n)/m, + // so A > (A - n)/m. + // + // **Assuming m > 0**, by multiplying with m, we can transform this to + // A * m + n > A. + // + // We know Ai(A) = A*m + n, therefore Ai(A) > A. + // + // This means that if we apply the incoming amount calculation to the + // **incoming** amount, and this condition holds, then we know that we + // deal with the second case, being able to compute the outgoing amount + // based off the formula Ao = (A - n)/m, otherwise we will just return + // the incoming amount. + + // In case the inbound fee rate is less than -1 (-100%), we fail to + // compute the outbound amount and return the incoming amount. This also + // protects against zero division later. + + // We compute m in terms of big.Int to be safe from overflows and to be + // consistent with later calculations. + // m := (PPM*PPM + Ri*PPM + Ro*PPM + Ro*Ri)/(PPM*PPM) + + // Compute terms in (PPM*PPM + Ri*PPM + Ro*PPM + Ro*Ri). + m1 := new(big.Int).Mul(PPM, PPM) + m2 := new(big.Int).Mul(Ri, PPM) + m3 := new(big.Int).Mul(Ro, PPM) + m4 := new(big.Int).Mul(Ro, Ri) + + // Add up terms m1..m4. + m := big.NewInt(0) + m.Add(m, m1) + m.Add(m, m2) + m.Add(m, m3) + m.Add(m, m4) + + // Since we compare to 0, we can multiply by PPM*PPM to avoid the + // division. + if m.Int64() <= 0 { + return incomingAmt + } + + // In order to decide if the total fee is negative, we apply the fee + // to the *incoming* amount as mentioned before. + + // We compute the test amount in terms of big.Int to be safe from + // overflows and to be consistent later calculations. + // testAmtF := A*m + n = + // = A + Bo + Bi + (PPM*(A*Ri + A*Ro + Ro*Ri) + A*Ri*Ro)/(PPM*PPM) + + // Compute terms in (A*Ri + A*Ro + Ro*Ri). + t1 := new(big.Int).Mul(A, Ri) + t2 := new(big.Int).Mul(A, Ro) + t3 := new(big.Int).Mul(Ro, Ri) + + // Sum up terms t1-t3. + t4 := big.NewInt(0) + t4.Add(t4, t1) + t4.Add(t4, t2) + t4.Add(t4, t3) + + // Compute PPM*(A*Ri + A*Ro + Ro*Ri). + t6 := new(big.Int).Mul(PPM, t4) + + // Compute A*Ri*Ro. + t7 := new(big.Int).Mul(A, Ri) + t7.Mul(t7, Ro) + + // Compute (PPM*(A*Ri + A*Ro + Ro*Ri) + A*Ri*Ro)/(PPM*PPM). + num := new(big.Int).Add(t6, t7) + denom := new(big.Int).Mul(PPM, PPM) + fraction := new(big.Int).Div(num, denom) + + // Sum up all terms. + testAmt := big.NewInt(0) + testAmt.Add(testAmt, A) + testAmt.Add(testAmt, Bo) + testAmt.Add(testAmt, Bi) + testAmt.Add(testAmt, fraction) + + // Protect against negative values for the integer cast to Msat. + if testAmt.Int64() < 0 { + return incomingAmt + } + + // If the second case holds, we have to compute the outgoing amount. + if lnwire.MilliSatoshi(testAmt.Int64()) > incomingAmt { + // Compute the outgoing amount by integer ceiling division. This + // precision is needed because PPM*PPM*A and other terms can + // easily overflow with int64, which happens with about + // A = 10_000 sat. + + // out := (A - n) / m = numerator / denominator + // numerator := PPM*(PPM*(A - Bo - Bi) - Bo*Ri) + // denominator := PPM*(PPM + Ri + Ro) + Ri*Ro + + var numerator big.Int + + // Compute (A - Bo - Bi). + temp1 := new(big.Int).Sub(A, Bo) + temp2 := new(big.Int).Sub(temp1, Bi) + + // Compute terms in (PPM*(A - Bo - Bi) - Bo*Ri). + temp3 := new(big.Int).Mul(PPM, temp2) + temp4 := new(big.Int).Mul(Bo, Ri) + + // Compute PPM*(PPM*(A - Bo - Bi) - Bo*Ri) + temp5 := new(big.Int).Sub(temp3, temp4) + numerator.Mul(PPM, temp5) + + var denominator big.Int + + // Compute (PPM + Ri + Ro). + temp1 = new(big.Int).Add(PPM, Ri) + temp2 = new(big.Int).Add(temp1, Ro) + + // Compute PPM*(PPM + Ri + Ro) + Ri*Ro. + temp3 = new(big.Int).Mul(PPM, temp2) + temp4 = new(big.Int).Mul(Ri, Ro) + denominator.Add(temp3, temp4) + + // We overestimate the outgoing amount by taking the ceiling of + // the division. This means that we may round slightly up by a + // MilliSatoshi, but this helps to ensure that we don't hit min + // HTLC constrains in the context of finding the minimum amount + // of a route. + // ceil = floor((numerator + denominator - 1) / denominator) + ceil := new(big.Int).Add(&numerator, &denominator) + ceil.Sub(ceil, big.NewInt(1)) + ceil.Div(ceil, &denominator) + + return lnwire.MilliSatoshi(ceil.Int64()) + } + + // Otherwise the inbound fee made up for the outbound fee, which is why + // we just return the incoming amount. + return incomingAmt +} diff --git a/routing/router_test.go b/routing/router_test.go index 7ee5817b4..ec095d787 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1886,6 +1886,156 @@ func TestSenderAmtBackwardPass(t *testing.T) { require.Equal(t, testReceiverAmt, receiverAmt) } +// TestInboundOutbound tests the functions that computes the incoming and +// outgoing amounts based on the fees of the incoming and outgoing channels. +func TestInboundOutbound(t *testing.T) { + var outgoingAmt uint64 = 10_000_000 + + tests := []struct { + name string + incomingBase int32 + incomingRate int32 + outgoingBase uint64 + outgoingRate uint64 + }{ + { + name: "only outbound fee", + incomingBase: 0, + incomingRate: 0, + outgoingBase: 20, + outgoingRate: 100, + }, + { + name: "positive inbound and outbound fee", + incomingBase: 20, + incomingRate: 100, + outgoingBase: 20, + outgoingRate: 100, + }, + { + name: "small negative inbound and outbound fee", + incomingBase: -10, + incomingRate: -50, + outgoingBase: 20, + outgoingRate: 100, + }, + { + name: "equal negative inbound and outbound fee", + incomingBase: -20, + incomingRate: -100, + outgoingBase: 20, + outgoingRate: 100, + }, + { + name: "large negative inbound and outbound fee", + incomingBase: -30, + incomingRate: -200, + outgoingBase: 20, + outgoingRate: 100, + }, + { + name: "order of PPM negative inbound and " + + "outbound fee (m=0)", + incomingBase: -30, + incomingRate: -1_000_000, + outgoingBase: 20, + outgoingRate: 100, + }, + { + name: "huge negative inbound and " + + "outbound fee (m<0)", + incomingBase: -30, + incomingRate: -2_000_000, + outgoingBase: 20, + outgoingRate: 100, + }, + } + + for _, tc := range tests { + tc := tc + + t.Run(tc.name, func(tt *testing.T) { + testInboundOutboundFee( + tt, outgoingAmt, tc.incomingBase, + tc.incomingRate, tc.outgoingBase, + tc.outgoingRate, + ) + }) + } +} + +// testInboundOutboundFee is a helper function that tests the outgoing and +// incoming amount relationship. +func testInboundOutboundFee(t *testing.T, outgoingAmt uint64, inBase, + inRate int32, outBase, outRate uint64) { + + debugStr := fmt.Sprintf( + "outAmt=%d, inBase=%d, inRate=%d, outBase=%d, outRate=%d", + outgoingAmt, inBase, inRate, outBase, outRate, + ) + + incomingEdge := &unifiedEdge{ + policy: &models.CachedEdgePolicy{}, + inboundFees: models.InboundFee{ + Base: inBase, + Rate: inRate, + }, + } + + outgoingEdge := &unifiedEdge{ + policy: &models.CachedEdgePolicy{ + FeeBaseMSat: lnwire.MilliSatoshi( + outBase, + ), + FeeProportionalMillionths: lnwire.MilliSatoshi( + outRate, + ), + }, + } + + // We compute the incoming amount based on the outgoing amount, which + // mimicks the path finding process. + incomingAmt := incomingFromOutgoing( + lnwire.MilliSatoshi(outgoingAmt), incomingEdge, + outgoingEdge, + ) + + // We do the reverse and compute the outgoing amount based on the + // incoming amount. + outgoingAmtNew := outgoingFromIncoming( + incomingAmt, incomingEdge, outgoingEdge, + ) + + // We require that the incoming amount is always larger than or equal to + // the outgoing amount, because total fees (=incoming-outgoing) should + // not become negative. + require.GreaterOrEqual( + t, int64(incomingAmt), int64(outgoingAmtNew), debugStr, + "expected incomingAmt >= outgoingAmtNew", + ) + + // We check that up to rounding the amounts are equal. + require.InDelta( + t, int64(outgoingAmt), int64(outgoingAmtNew), 1.0, debugStr, + "expected |outgoingAmt - outgoingAmtNew | <= 1", + ) + + // If we round, the computed outgoing amount should be larger than the + // exact outgoing amount, to not hit any min HTLC limits. + require.GreaterOrEqual( + t, int64(outgoingAmtNew), int64(outgoingAmt), debugStr, + "expected outgoingAmtNew >= outgoingAmt", + ) +} + +// FuzzInboundOutbound tests the incoming and outgoing amount calculation +// functions with fuzzing. +func FuzzInboundOutboundFee(f *testing.F) { + f.Add(uint64(0), int32(0), int32(0), uint64(0), uint64(0)) + + f.Fuzz(testInboundOutboundFee) +} + // TestSendToRouteSkipTempErrSuccess validates a successful payment send. func TestSendToRouteSkipTempErrSuccess(t *testing.T) { t.Parallel()