From 309a1564d09861c294925d8c0503c73ff2f94df5 Mon Sep 17 00:00:00 2001 From: George Tsagkarelis Date: Thu, 2 May 2024 18:48:28 +0200 Subject: [PATCH] routing: use first hop records on path finding --- routing/pathfind.go | 54 +++++++++++++++++++++++++++++++++--- routing/payment_lifecycle.go | 1 + routing/payment_session.go | 28 +++++++++++-------- 3 files changed, 67 insertions(+), 16 deletions(-) diff --git a/routing/pathfind.go b/routing/pathfind.go index add833ecc..c22020682 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -13,9 +13,11 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/feature" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -433,6 +435,10 @@ type RestrictParams struct { // BlindedPayment is necessary to determine the hop size of the // last/exit hop. BlindedPayment *BlindedPayment + + // FirstHopCustomRecords includes any records that should be included in + // the update_add_htlc message towards our peer. + FirstHopCustomRecords record.CustomSet } // PathFindingConfig defines global parameters that control the trade-off in @@ -459,9 +465,10 @@ type PathFindingConfig struct { // available balance. func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, bandwidthHints bandwidthHints, - g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { + g routingGraph, htlcBlob fn.Option[tlv.Blob]) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi + cb := func(channel *channeldb.DirectedChannel) error { if !channel.OutPolicySet { return nil @@ -477,7 +484,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, } bandwidth, ok := bandwidthHints.availableChanBandwidth( - chanID, 0, + chanID, 0, htlcBlob, ) // If the bandwidth is not available, use the channel capacity. @@ -491,7 +498,9 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, max = bandwidth } - total += bandwidth + total = lnwire.MilliSatoshi( + safeAdd(uint64(total), uint64(bandwidth)), + ) return nil } @@ -599,8 +608,23 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, self := g.graph.sourceNode() if source == self { + + firstHopTLVs := tlv.MapToRecords(r.FirstHopCustomRecords) + wireRecords := fn.Map(func(r tlv.Record) tlv.RecordProducer { + return &r + }, firstHopTLVs) + + firstHopData := lnwire.ExtraOpaqueData{} + + err := firstHopData.PackRecords(wireRecords...) + if err != nil { + return nil, 0, err + } + + tlvOption := fn.Some[tlv.Blob](firstHopData) max, total, err := getOutgoingBalance( self, outgoingChanMap, g.bandwidthHints, g.graph, + tlvOption, ) if err != nil { return nil, 0, err @@ -1029,9 +1053,23 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, continue } + firstHopTLVs := tlv.MapToRecords(r.FirstHopCustomRecords) + wireRecords := fn.Map(func(r tlv.Record) tlv.RecordProducer { + return &r + }, firstHopTLVs) + + firstHopData := lnwire.ExtraOpaqueData{} + + err := firstHopData.PackRecords(wireRecords...) + if err != nil { + return nil, 0, err + } + + tlvOption := fn.Some[tlv.Blob](firstHopData) + edge := edgeUnifier.getEdge( netAmountReceived, g.bandwidthHints, - partialPath.outboundFee, + partialPath.outboundFee, tlvOption, ) if edge == nil { @@ -1223,3 +1261,11 @@ func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, // The final hop does not have a short chanID set. return finalHop.PayloadSize(0) } + +func safeAdd(x, y uint64) uint64 { + if y > math.MaxUint64-x { + // Overflow would occur, return maximum uint64 value + return math.MaxUint64 + } + return x + y +} diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 57b77e5f8..9757afebd 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -362,6 +362,7 @@ func (p *paymentLifecycle) requestRoute( rt, err := p.paySession.RequestRoute( ps.RemainingAmt, remainingFees, uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), + p.firstHopTLVs, ) // Exit early if there's no error. diff --git a/routing/payment_session.go b/routing/payment_session.go index 2d174244c..e435b47d1 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -10,6 +10,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" ) @@ -138,7 +139,8 @@ type PaymentSession interface { // A noRouteError is returned if a non-critical error is encountered // during path finding. RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32) (*route.Route, error) + activeShards, height uint32, + firstHopTLVs record.CustomSet) (*route.Route, error) // UpdateAdditionalEdge takes an additional channel edge policy // (private channels) and applies the update from the message. Returns @@ -228,7 +230,8 @@ func newPaymentSession(p *LightningPayment, // NOTE: This function is safe for concurrent access. // NOTE: Part of the PaymentSession interface. func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32) (*route.Route, error) { + activeShards, height uint32, + firstHopTLVs record.CustomSet) (*route.Route, error) { if p.empty { return nil, errEmptyPaySession @@ -250,16 +253,17 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // to our destination, respecting the recommendations from // MissionControl. restrictions := &RestrictParams{ - ProbabilitySource: p.missionControl.GetProbability, - FeeLimit: feeLimit, - OutgoingChannelIDs: p.payment.OutgoingChannelIDs, - LastHop: p.payment.LastHop, - CltvLimit: cltvLimit, - DestCustomRecords: p.payment.DestCustomRecords, - DestFeatures: p.payment.DestFeatures, - PaymentAddr: p.payment.PaymentAddr, - Amp: p.payment.amp, - Metadata: p.payment.Metadata, + ProbabilitySource: p.missionControl.GetProbability, + FeeLimit: feeLimit, + OutgoingChannelIDs: p.payment.OutgoingChannelIDs, + LastHop: p.payment.LastHop, + CltvLimit: cltvLimit, + DestCustomRecords: p.payment.DestCustomRecords, + DestFeatures: p.payment.DestFeatures, + PaymentAddr: p.payment.PaymentAddr, + Amp: p.payment.amp, + Metadata: p.payment.Metadata, + FirstHopCustomRecords: firstHopTLVs, } finalHtlcExpiry := int32(height) + int32(finalCltvDelta)