routing: add outgoingFromIncoming amount calc

Adds a utility function to be able to compute the outgoing routing
amount from the incoming amount by taking inbound and outbound fees into
account. The discussion was contributed by user feelancer21, see
f6f05fa930.
This commit is contained in:
bitromortac 2024-07-01 13:34:45 +02:00
parent 2c79bf9635
commit 36cd03669b
No known key found for this signature in database
GPG Key ID: 1965063FC13BEBE2
2 changed files with 373 additions and 0 deletions

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"fmt" "fmt"
"math" "math"
"math/big"
"sort" "sort"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -836,6 +837,7 @@ func generateSphinxPacket(rt *route.Route, paymentHash []byte,
hopCopy := sphinxPath[i] hopCopy := sphinxPath[i]
path[i] = hopCopy path[i] = hopCopy
} }
return spew.Sdump(path) return spew.Sdump(path)
}), }),
) )
@ -1670,3 +1672,224 @@ func receiverAmtForwardPass(runningAmt lnwire.MilliSatoshi,
return runningAmt, nil return runningAmt, nil
} }
// incomingFromOutgoing computes the incoming amount based on the outgoing
// amount by adding fees to the outgoing amount, replicating the path finding
// and routing process, see also CheckHtlcForward.
func incomingFromOutgoing(outgoingAmt lnwire.MilliSatoshi,
incoming, outgoing *unifiedEdge) lnwire.MilliSatoshi {
outgoingFee := outgoing.policy.ComputeFee(outgoingAmt)
// Net amount is the amount the inbound fees are calculated with.
netAmount := outgoingAmt + outgoingFee
inboundFee := incoming.inboundFees.CalcFee(netAmount)
// The inbound fee is not allowed to reduce the incoming amount below
// the outgoing amount.
if int64(outgoingFee)+inboundFee < 0 {
return outgoingAmt
}
return netAmount + lnwire.MilliSatoshi(inboundFee)
}
// outgoingFromIncoming computes the outgoing amount based on the incoming
// amount by subtracting fees from the incoming amount. Note that this is not
// exactly the inverse of incomingFromOutgoing, because of some rounding.
func outgoingFromIncoming(incomingAmt lnwire.MilliSatoshi,
incoming, outgoing *unifiedEdge) lnwire.MilliSatoshi {
// Convert all quantities to big.Int to be able to hande negative
// values. The formulas to compute the outgoing amount involve terms
// with PPM*PPM*A, which can easily overflow an int64.
A := big.NewInt(int64(incomingAmt))
Ro := big.NewInt(int64(outgoing.policy.FeeProportionalMillionths))
Bo := big.NewInt(int64(outgoing.policy.FeeBaseMSat))
Ri := big.NewInt(int64(incoming.inboundFees.Rate))
Bi := big.NewInt(int64(incoming.inboundFees.Base))
PPM := big.NewInt(1_000_000)
// The following discussion was contributed by user feelancer21, see
//nolint:lll
// https://github.com/feelancer21/lnd/commit/f6f05fa930985aac0d27c3f6681aada1b599162a.
// The incoming amount Ai based on the outgoing amount Ao is computed by
// Ai = max(Ai(Ao), Ao), which caps the incoming amount such that the
// total node fee (Ai - Ao) is non-negative. This is commonly enforced
// by routing nodes.
// The function Ai(Ao) is given by:
// Ai(Ao) = (Ao + Bo + Ro/PPM) + (Bi + (Ao + Ro/PPM + Bo)*Ri/PPM), where
// the first term is the net amount (the outgoing amount plus the
// outbound fee), and the second is the inbound fee computed based on
// the net amount.
// Ai(Ao) can potentially become more negative in absolute value than
// Ao, which is why the above mentioned capping is needed. We can
// abbreviate Ai(Ao) with Ai(Ao) = m*Ao + n, where m and n are:
// m := (1 + Ro/PPM) * (1 + Ri/PPM)
// n := Bi + Bo*(1 + Ri/PPM)
// If we know that m > 0, this is equivalent of Ri/PPM > -1, because Ri
// is the only factor that can become negative. A value or Ri/PPM = -1,
// means that the routing node is willing to give up on 100% of the
// net amount (based on the fee rate), which is likely to not happen in
// practice. This condition will be important for a later trick.
// If we want to compute the incoming amount based on the outgoing
// amount, which is the reverse problem, we need to solve Ai =
// max(Ai(Ao), Ao) for Ao(Ai). Given an incoming amount A,
// we look for an Ao such that A = max(Ai(Ao), Ao).
// The max function separates this into two cases. The case to take is
// not clear yet, because we don't know Ao, but later we see a trick
// how to determine which case is the one to take.
// first case: Ai(Ao) <= Ao:
// Therefore, A = max(Ai(Ao), Ao) = Ao, we find Ao = A.
// This also leads to Ai(A) <= A by substitution into the condition.
// second case: Ai(Ao) > Ao:
// Therefore, A = max(Ai(Ao), Ao) = Ai(Ao) = m*Ao + n. Solving for Ao
// gives Ao = (A - n)/m.
//
// We know
// Ai(Ao) > Ao <=> A = Ai(Ao) > Ao = (A - n)/m,
// so A > (A - n)/m.
//
// **Assuming m > 0**, by multiplying with m, we can transform this to
// A * m + n > A.
//
// We know Ai(A) = A*m + n, therefore Ai(A) > A.
//
// This means that if we apply the incoming amount calculation to the
// **incoming** amount, and this condition holds, then we know that we
// deal with the second case, being able to compute the outgoing amount
// based off the formula Ao = (A - n)/m, otherwise we will just return
// the incoming amount.
// In case the inbound fee rate is less than -1 (-100%), we fail to
// compute the outbound amount and return the incoming amount. This also
// protects against zero division later.
// We compute m in terms of big.Int to be safe from overflows and to be
// consistent with later calculations.
// m := (PPM*PPM + Ri*PPM + Ro*PPM + Ro*Ri)/(PPM*PPM)
// Compute terms in (PPM*PPM + Ri*PPM + Ro*PPM + Ro*Ri).
m1 := new(big.Int).Mul(PPM, PPM)
m2 := new(big.Int).Mul(Ri, PPM)
m3 := new(big.Int).Mul(Ro, PPM)
m4 := new(big.Int).Mul(Ro, Ri)
// Add up terms m1..m4.
m := big.NewInt(0)
m.Add(m, m1)
m.Add(m, m2)
m.Add(m, m3)
m.Add(m, m4)
// Since we compare to 0, we can multiply by PPM*PPM to avoid the
// division.
if m.Int64() <= 0 {
return incomingAmt
}
// In order to decide if the total fee is negative, we apply the fee
// to the *incoming* amount as mentioned before.
// We compute the test amount in terms of big.Int to be safe from
// overflows and to be consistent later calculations.
// testAmtF := A*m + n =
// = A + Bo + Bi + (PPM*(A*Ri + A*Ro + Ro*Ri) + A*Ri*Ro)/(PPM*PPM)
// Compute terms in (A*Ri + A*Ro + Ro*Ri).
t1 := new(big.Int).Mul(A, Ri)
t2 := new(big.Int).Mul(A, Ro)
t3 := new(big.Int).Mul(Ro, Ri)
// Sum up terms t1-t3.
t4 := big.NewInt(0)
t4.Add(t4, t1)
t4.Add(t4, t2)
t4.Add(t4, t3)
// Compute PPM*(A*Ri + A*Ro + Ro*Ri).
t6 := new(big.Int).Mul(PPM, t4)
// Compute A*Ri*Ro.
t7 := new(big.Int).Mul(A, Ri)
t7.Mul(t7, Ro)
// Compute (PPM*(A*Ri + A*Ro + Ro*Ri) + A*Ri*Ro)/(PPM*PPM).
num := new(big.Int).Add(t6, t7)
denom := new(big.Int).Mul(PPM, PPM)
fraction := new(big.Int).Div(num, denom)
// Sum up all terms.
testAmt := big.NewInt(0)
testAmt.Add(testAmt, A)
testAmt.Add(testAmt, Bo)
testAmt.Add(testAmt, Bi)
testAmt.Add(testAmt, fraction)
// Protect against negative values for the integer cast to Msat.
if testAmt.Int64() < 0 {
return incomingAmt
}
// If the second case holds, we have to compute the outgoing amount.
if lnwire.MilliSatoshi(testAmt.Int64()) > incomingAmt {
// Compute the outgoing amount by integer ceiling division. This
// precision is needed because PPM*PPM*A and other terms can
// easily overflow with int64, which happens with about
// A = 10_000 sat.
// out := (A - n) / m = numerator / denominator
// numerator := PPM*(PPM*(A - Bo - Bi) - Bo*Ri)
// denominator := PPM*(PPM + Ri + Ro) + Ri*Ro
var numerator big.Int
// Compute (A - Bo - Bi).
temp1 := new(big.Int).Sub(A, Bo)
temp2 := new(big.Int).Sub(temp1, Bi)
// Compute terms in (PPM*(A - Bo - Bi) - Bo*Ri).
temp3 := new(big.Int).Mul(PPM, temp2)
temp4 := new(big.Int).Mul(Bo, Ri)
// Compute PPM*(PPM*(A - Bo - Bi) - Bo*Ri)
temp5 := new(big.Int).Sub(temp3, temp4)
numerator.Mul(PPM, temp5)
var denominator big.Int
// Compute (PPM + Ri + Ro).
temp1 = new(big.Int).Add(PPM, Ri)
temp2 = new(big.Int).Add(temp1, Ro)
// Compute PPM*(PPM + Ri + Ro) + Ri*Ro.
temp3 = new(big.Int).Mul(PPM, temp2)
temp4 = new(big.Int).Mul(Ri, Ro)
denominator.Add(temp3, temp4)
// We overestimate the outgoing amount by taking the ceiling of
// the division. This means that we may round slightly up by a
// MilliSatoshi, but this helps to ensure that we don't hit min
// HTLC constrains in the context of finding the minimum amount
// of a route.
// ceil = floor((numerator + denominator - 1) / denominator)
ceil := new(big.Int).Add(&numerator, &denominator)
ceil.Sub(ceil, big.NewInt(1))
ceil.Div(ceil, &denominator)
return lnwire.MilliSatoshi(ceil.Int64())
}
// Otherwise the inbound fee made up for the outbound fee, which is why
// we just return the incoming amount.
return incomingAmt
}

View File

@ -1886,6 +1886,156 @@ func TestSenderAmtBackwardPass(t *testing.T) {
require.Equal(t, testReceiverAmt, receiverAmt) require.Equal(t, testReceiverAmt, receiverAmt)
} }
// TestInboundOutbound tests the functions that computes the incoming and
// outgoing amounts based on the fees of the incoming and outgoing channels.
func TestInboundOutbound(t *testing.T) {
var outgoingAmt uint64 = 10_000_000
tests := []struct {
name string
incomingBase int32
incomingRate int32
outgoingBase uint64
outgoingRate uint64
}{
{
name: "only outbound fee",
incomingBase: 0,
incomingRate: 0,
outgoingBase: 20,
outgoingRate: 100,
},
{
name: "positive inbound and outbound fee",
incomingBase: 20,
incomingRate: 100,
outgoingBase: 20,
outgoingRate: 100,
},
{
name: "small negative inbound and outbound fee",
incomingBase: -10,
incomingRate: -50,
outgoingBase: 20,
outgoingRate: 100,
},
{
name: "equal negative inbound and outbound fee",
incomingBase: -20,
incomingRate: -100,
outgoingBase: 20,
outgoingRate: 100,
},
{
name: "large negative inbound and outbound fee",
incomingBase: -30,
incomingRate: -200,
outgoingBase: 20,
outgoingRate: 100,
},
{
name: "order of PPM negative inbound and " +
"outbound fee (m=0)",
incomingBase: -30,
incomingRate: -1_000_000,
outgoingBase: 20,
outgoingRate: 100,
},
{
name: "huge negative inbound and " +
"outbound fee (m<0)",
incomingBase: -30,
incomingRate: -2_000_000,
outgoingBase: 20,
outgoingRate: 100,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(tt *testing.T) {
testInboundOutboundFee(
tt, outgoingAmt, tc.incomingBase,
tc.incomingRate, tc.outgoingBase,
tc.outgoingRate,
)
})
}
}
// testInboundOutboundFee is a helper function that tests the outgoing and
// incoming amount relationship.
func testInboundOutboundFee(t *testing.T, outgoingAmt uint64, inBase,
inRate int32, outBase, outRate uint64) {
debugStr := fmt.Sprintf(
"outAmt=%d, inBase=%d, inRate=%d, outBase=%d, outRate=%d",
outgoingAmt, inBase, inRate, outBase, outRate,
)
incomingEdge := &unifiedEdge{
policy: &models.CachedEdgePolicy{},
inboundFees: models.InboundFee{
Base: inBase,
Rate: inRate,
},
}
outgoingEdge := &unifiedEdge{
policy: &models.CachedEdgePolicy{
FeeBaseMSat: lnwire.MilliSatoshi(
outBase,
),
FeeProportionalMillionths: lnwire.MilliSatoshi(
outRate,
),
},
}
// We compute the incoming amount based on the outgoing amount, which
// mimicks the path finding process.
incomingAmt := incomingFromOutgoing(
lnwire.MilliSatoshi(outgoingAmt), incomingEdge,
outgoingEdge,
)
// We do the reverse and compute the outgoing amount based on the
// incoming amount.
outgoingAmtNew := outgoingFromIncoming(
incomingAmt, incomingEdge, outgoingEdge,
)
// We require that the incoming amount is always larger than or equal to
// the outgoing amount, because total fees (=incoming-outgoing) should
// not become negative.
require.GreaterOrEqual(
t, int64(incomingAmt), int64(outgoingAmtNew), debugStr,
"expected incomingAmt >= outgoingAmtNew",
)
// We check that up to rounding the amounts are equal.
require.InDelta(
t, int64(outgoingAmt), int64(outgoingAmtNew), 1.0, debugStr,
"expected |outgoingAmt - outgoingAmtNew | <= 1",
)
// If we round, the computed outgoing amount should be larger than the
// exact outgoing amount, to not hit any min HTLC limits.
require.GreaterOrEqual(
t, int64(outgoingAmtNew), int64(outgoingAmt), debugStr,
"expected outgoingAmtNew >= outgoingAmt",
)
}
// FuzzInboundOutbound tests the incoming and outgoing amount calculation
// functions with fuzzing.
func FuzzInboundOutboundFee(f *testing.F) {
f.Add(uint64(0), int32(0), int32(0), uint64(0), uint64(0))
f.Fuzz(testInboundOutboundFee)
}
// TestSendToRouteSkipTempErrSuccess validates a successful payment send. // TestSendToRouteSkipTempErrSuccess validates a successful payment send.
func TestSendToRouteSkipTempErrSuccess(t *testing.T) { func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
t.Parallel() t.Parallel()