diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 46f0f42cc..524231c17 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -77,6 +77,9 @@ * [Add Taproot witness types to rpc](https://github.com/lightningnetwork/lnd/pull/8431) + +* [Fixed](https://github.com/lightningnetwork/lnd/pull/7852) the payload size + calculation in our pathfinder because blinded hops introduced new tlv records. # New Features ## Functional Enhancements diff --git a/itest/lnd_amp_test.go b/itest/lnd_amp_test.go index ba17a30b6..61b6d194a 100644 --- a/itest/lnd_amp_test.go +++ b/itest/lnd_amp_test.go @@ -104,6 +104,10 @@ func testSendPaymentAMPInvoiceCase(ht *lntest.HarnessTest, // expect an extra invoice to appear in the ListInvoices response, since // a new invoice will be JIT inserted under a different payment address // than the one in the invoice. + // + // NOTE: This will only work when the peer has spontaneous AMP payments + // enabled otherwise no invoice under a different payment_addr will be + // found. var ( expNumInvoices = 1 externalPayAddr []byte diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 420259de2..4d82c72e4 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -15,7 +15,6 @@ import ( "github.com/btcsuite/btcd/wire" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnrpc" @@ -280,7 +279,7 @@ func (r *RouterBackend) parseQueryRoutesRequest(in *lnrpc.QueryRoutesRequest) ( // inside of the path rather than the request's fields. var ( targetPubKey *route.Vertex - routeHintEdges map[route.Vertex][]*models.CachedEdgePolicy + routeHintEdges map[route.Vertex][]routing.AdditionalEdge blindedPmt *routing.BlindedPayment // finalCLTVDelta varies depending on whether we're sending to @@ -391,6 +390,7 @@ func (r *RouterBackend) parseQueryRoutesRequest(in *lnrpc.QueryRoutesRequest) ( DestCustomRecords: record.CustomSet(in.DestCustomRecords), CltvLimit: cltvLimit, DestFeatures: destinationFeatures, + BlindedPayment: blindedPmt, } // Pass along an outgoing channel restriction if specified. @@ -967,6 +967,9 @@ func (r *RouterBackend) extractIntentFromSendRequest( // pseudo-reusable, e.g. the invoice parameters are // reused (amt, cltv, hop hints, etc) even though the // payments will share different payment hashes. + // + // NOTE: This will only work when the peer has + // spontaneous AMP payments enabled. if len(rpcPayReq.PaymentAddr) > 0 { var addr [32]byte copy(addr[:], rpcPayReq.PaymentAddr) diff --git a/record/amp.go b/record/amp.go index eb431fec3..f63c7a141 100644 --- a/record/amp.go +++ b/record/amp.go @@ -19,6 +19,16 @@ type AMP struct { childIndex uint32 } +// MaxAmpPayLoadSize is an AMP Record which when serialized to a tlv record uses +// the maximum payload size. The `childIndex` is created randomly and is a +// 4 byte `varint` type so we make sure we use an index which will be encoded in +// 4 bytes. +var MaxAmpPayLoadSize = AMP{ + rootShare: [32]byte{}, + setID: [32]byte{}, + childIndex: 0x80000000, +} + // NewAMP generate a new AMP record with the given root_share, set_id, and // child_index. func NewAMP(rootShare, setID [32]byte, childIndex uint32) *AMP { diff --git a/routing/additional_edge.go b/routing/additional_edge.go new file mode 100644 index 000000000..22cf3032b --- /dev/null +++ b/routing/additional_edge.go @@ -0,0 +1,107 @@ +package routing + +import ( + "errors" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +var ( + // ErrNoPayLoadSizeFunc is returned when no payload size function is + // definied. + ErrNoPayLoadSizeFunc = errors.New("no payloadSizeFunc defined for " + + "additional edge") +) + +// AdditionalEdge is an interface which specifies additional edges which can +// be appended to an existing route. Compared to normal edges of a route they +// provide an explicit payload size function and are introduced because blinded +// paths differ in their payload structure. +type AdditionalEdge interface { + // IntermediatePayloadSize returns the size of the payload for the + // additional edge when being an intermediate hop in a route NOT the + // final hop. + IntermediatePayloadSize(amount lnwire.MilliSatoshi, expiry uint32, + legacy bool, channelID uint64) uint64 + + // EdgePolicy returns the policy of the additional edge. + EdgePolicy() *models.CachedEdgePolicy +} + +// PayloadSizeFunc defines the interface for the payload size function. +type PayloadSizeFunc func(amount lnwire.MilliSatoshi, expiry uint32, + legacy bool, channelID uint64) uint64 + +// PrivateEdge implements the AdditionalEdge interface. As the name implies it +// is used for private route hints that the receiver adds for example to an +// invoice. +type PrivateEdge struct { + policy *models.CachedEdgePolicy +} + +// EdgePolicy return the policy of the PrivateEdge. +func (p *PrivateEdge) EdgePolicy() *models.CachedEdgePolicy { + return p.policy +} + +// IntermediatePayloadSize returns the sphinx payload size defined in BOLT04 if +// this edge were to be included in a route. +func (p *PrivateEdge) IntermediatePayloadSize(amount lnwire.MilliSatoshi, + expiry uint32, legacy bool, channelID uint64) uint64 { + + hop := route.Hop{ + AmtToForward: amount, + OutgoingTimeLock: expiry, + LegacyPayload: legacy, + } + + return hop.PayloadSize(channelID) +} + +// BlindedEdge implements the AdditionalEdge interface. Blinded hops are viewed +// as additional edges because they are appened at the end of a normal route. +type BlindedEdge struct { + policy *models.CachedEdgePolicy + cipherText []byte + blindingPoint *btcec.PublicKey +} + +// EdgePolicy return the policy of the BlindedEdge. +func (b *BlindedEdge) EdgePolicy() *models.CachedEdgePolicy { + return b.policy +} + +// IntermediatePayloadSize returns the sphinx payload size defined in BOLT04 if +// this edge were to be included in a route. +func (b *BlindedEdge) IntermediatePayloadSize(_ lnwire.MilliSatoshi, _ uint32, + _ bool, _ uint64) uint64 { + + hop := route.Hop{ + BlindingPoint: b.blindingPoint, + LegacyPayload: false, + EncryptedData: b.cipherText, + } + + // For blinded paths the next chanID is in the encrypted data tlv. + return hop.PayloadSize(0) +} + +// Compile-time constraints to ensure the PrivateEdge and the BlindedEdge +// implement the AdditionalEdge interface. +var _ AdditionalEdge = (*PrivateEdge)(nil) +var _ AdditionalEdge = (*BlindedEdge)(nil) + +// defaultHopPayloadSize is the default payload size of a normal (not-blinded) +// hop in the route. +func defaultHopPayloadSize(amount lnwire.MilliSatoshi, expiry uint32, + legacy bool, channelID uint64) uint64 { + + // The payload size of a cleartext intermediate hop is equal to the + // payload size of a private edge therefore we reuse its size function. + edge := PrivateEdge{} + + return edge.IntermediatePayloadSize(amount, expiry, legacy, channelID) +} diff --git a/routing/additional_edge_test.go b/routing/additional_edge_test.go new file mode 100644 index 000000000..b3ea3b501 --- /dev/null +++ b/routing/additional_edge_test.go @@ -0,0 +1,136 @@ +package routing + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" +) + +// TestIntermediatePayloadSize tests the payload size functions of the +// PrivateEdge and the BlindedEdge. +func TestIntermediatePayloadSize(t *testing.T) { + t.Parallel() + + testPrivKeyBytes, _ := hex.DecodeString("e126f68f7eafcc8b74f54d269fe" + + "206be715000f94dac067d1c04a8ca3b2db734") + _, blindedPoint := btcec.PrivKeyFromBytes(testPrivKeyBytes) + + testCases := []struct { + name string + hop route.Hop + nextHop uint64 + edge AdditionalEdge + }{ + { + name: "Legacy payload private edge", + hop: route.Hop{ + AmtToForward: 1000, + OutgoingTimeLock: 600000, + ChannelID: 3432483437438, + LegacyPayload: true, + }, + nextHop: 1, + edge: &PrivateEdge{}, + }, + { + name: "Tlv payload private edge", + hop: route.Hop{ + AmtToForward: 1000, + OutgoingTimeLock: 600000, + ChannelID: 3432483437438, + LegacyPayload: false, + }, + nextHop: 1, + edge: &PrivateEdge{}, + }, + { + name: "Blinded edge", + hop: route.Hop{ + EncryptedData: []byte{12, 13}, + }, + edge: &BlindedEdge{ + cipherText: []byte{12, 13}, + }, + }, + { + name: "Blinded edge - introduction point", + hop: route.Hop{ + EncryptedData: []byte{12, 13}, + BlindingPoint: blindedPoint, + }, + edge: &BlindedEdge{ + cipherText: []byte{12, 13}, + blindingPoint: blindedPoint, + }, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payLoad, err := createHopPayload( + testCase.hop, testCase.nextHop, false, + ) + require.NoErrorf(t, err, "failed to create hop payload") + + expectedPayloadSize := testCase.edge. + IntermediatePayloadSize( + testCase.hop.AmtToForward, + testCase.hop.OutgoingTimeLock, + testCase.hop.LegacyPayload, + testCase.nextHop, + ) + + require.Equal( + t, expectedPayloadSize, + uint64(payLoad.NumBytes()), + ) + }) + } +} + +// createHopPayload creates the hop payload of the sphinx package to facilitate +// the testing of the payload size. +func createHopPayload(hop route.Hop, nextHop uint64, + finalHop bool) (sphinx.HopPayload, error) { + + // If this is the legacy payload, then we can just include the + // hop data as normal. + if hop.LegacyPayload { + // Before we encode this value, we'll pack the next hop + // into the NextAddress field of the hop info to ensure + // we point to the right now. + hopData := sphinx.HopData{ + ForwardAmount: uint64(hop.AmtToForward), + OutgoingCltv: hop.OutgoingTimeLock, + } + binary.BigEndian.PutUint64( + hopData.NextAddress[:], nextHop, + ) + + return sphinx.NewLegacyHopPayload(&hopData) + } + + // For non-legacy payloads, we'll need to pack the + // routing information, along with any extra TLV + // information into the new per-hop payload format. + // We'll also pass in the chan ID of the hop this + // channel should be forwarded to so we can construct a + // valid payload. + var b bytes.Buffer + err := hop.PackHopPayload(&b, nextHop, finalHop) + if err != nil { + return sphinx.HopPayload{}, err + } + + return sphinx.NewTLVHopPayload(b.Bytes()) +} diff --git a/routing/blinding.go b/routing/blinding.go index 50d3b6f79..61d303a0c 100644 --- a/routing/blinding.go +++ b/routing/blinding.go @@ -99,7 +99,7 @@ func (b *BlindedPayment) toRouteHints() RouteHints { hintCount := len(b.BlindedPath.BlindedHops) - 1 hints := make( - map[route.Vertex][]*models.CachedEdgePolicy, hintCount, + RouteHints, hintCount, ) // Start at the unblinded introduction node, because our pathfinding @@ -116,25 +116,31 @@ func (b *BlindedPayment) toRouteHints() RouteHints { // will ensure that pathfinding provides sufficient fees/delay for the // blinded portion to the introduction node. firstBlindedHop := b.BlindedPath.BlindedHops[1].BlindedNodePub - hints[fromNode] = []*models.CachedEdgePolicy{ - { - TimeLockDelta: b.CltvExpiryDelta, - MinHTLC: lnwire.MilliSatoshi(b.HtlcMinimum), - MaxHTLC: lnwire.MilliSatoshi(b.HtlcMaximum), - FeeBaseMSat: lnwire.MilliSatoshi(b.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi( - b.ProportionalFee, - ), - ToNodePubKey: func() route.Vertex { - return route.NewVertex( - // The first node in this slice is - // the introduction node, so we start - // at index 1 to get the first blinded - // relaying node. - firstBlindedHop, - ) - }, - ToNodeFeatures: features, + edgePolicy := &models.CachedEdgePolicy{ + TimeLockDelta: b.CltvExpiryDelta, + MinHTLC: lnwire.MilliSatoshi(b.HtlcMinimum), + MaxHTLC: lnwire.MilliSatoshi(b.HtlcMaximum), + FeeBaseMSat: lnwire.MilliSatoshi(b.BaseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi( + b.ProportionalFee, + ), + ToNodePubKey: func() route.Vertex { + return route.NewVertex( + // The first node in this slice is + // the introduction node, so we start + // at index 1 to get the first blinded + // relaying node. + firstBlindedHop, + ) + }, + ToNodeFeatures: features, + } + + hints[fromNode] = []AdditionalEdge{ + &BlindedEdge{ + policy: edgePolicy, + cipherText: b.BlindedPath.BlindedHops[0].CipherText, + blindingPoint: b.BlindedPath.BlindingPoint, }, } @@ -156,15 +162,19 @@ func (b *BlindedPayment) toRouteHints() RouteHints { b.BlindedPath.BlindedHops[nextHopIdx].BlindedNodePub, ) - hint := &models.CachedEdgePolicy{ + edgePolicy := &models.CachedEdgePolicy{ ToNodePubKey: func() route.Vertex { return nextNode }, ToNodeFeatures: features, } - hints[fromNode] = []*models.CachedEdgePolicy{ - hint, + hints[fromNode] = []AdditionalEdge{ + &BlindedEdge{ + policy: edgePolicy, + cipherText: b.BlindedPath.BlindedHops[i]. + CipherText, + }, } } diff --git a/routing/blinding_test.go b/routing/blinding_test.go index 5dc71354f..561ace6fc 100644 --- a/routing/blinding_test.go +++ b/routing/blinding_test.go @@ -1,10 +1,12 @@ package routing import ( + "bytes" "testing" "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" @@ -94,6 +96,12 @@ func TestBlindedPaymentToHints(t *testing.T) { htlcMin uint64 = 100 htlcMax uint64 = 100_000_000 + sizeEncryptedData = 100 + cipherText = bytes.Repeat( + []byte{1}, sizeEncryptedData, + ) + _, blindedPoint = btcec.PrivKeyFromBytes([]byte{5}) + rawFeatures = lnwire.NewRawFeatureVector( lnwire.AMPOptional, ) @@ -108,6 +116,7 @@ func TestBlindedPaymentToHints(t *testing.T) { blindedPayment := &BlindedPayment{ BlindedPath: &sphinx.BlindedPath{ IntroductionPoint: pk1, + BlindingPoint: blindedPoint, BlindedHops: []*sphinx.BlindedHopInfo{ {}, }, @@ -125,40 +134,52 @@ func TestBlindedPaymentToHints(t *testing.T) { blindedPayment.BlindedPath.BlindedHops = []*sphinx.BlindedHopInfo{ { BlindedNodePub: pkb1, + CipherText: cipherText, }, { BlindedNodePub: pkb2, + CipherText: cipherText, }, { BlindedNodePub: pkb3, + CipherText: cipherText, }, } expected := RouteHints{ v1: { - { - TimeLockDelta: cltvDelta, - MinHTLC: lnwire.MilliSatoshi(htlcMin), - MaxHTLC: lnwire.MilliSatoshi(htlcMax), - FeeBaseMSat: lnwire.MilliSatoshi(baseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi( - ppmFee, - ), - ToNodePubKey: func() route.Vertex { - return vb2 + //nolint:lll + &BlindedEdge{ + policy: &models.CachedEdgePolicy{ + TimeLockDelta: cltvDelta, + MinHTLC: lnwire.MilliSatoshi(htlcMin), + MaxHTLC: lnwire.MilliSatoshi(htlcMax), + FeeBaseMSat: lnwire.MilliSatoshi(baseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi( + ppmFee, + ), + ToNodePubKey: func() route.Vertex { + return vb2 + }, + ToNodeFeatures: features, }, - ToNodeFeatures: features, + blindingPoint: blindedPoint, + cipherText: cipherText, }, }, vb2: { - { - ToNodePubKey: func() route.Vertex { - return vb3 + &BlindedEdge{ + policy: &models.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return vb3 + }, + ToNodeFeatures: features, }, - ToNodeFeatures: features, + cipherText: cipherText, }, }, } + actual := blindedPayment.toRouteHints() require.Equal(t, len(expected), len(actual)) @@ -170,13 +191,24 @@ func TestBlindedPaymentToHints(t *testing.T) { require.Len(t, actualHint, 1) // We can't assert that our functions are equal, so we check - // their output and then mark as nil so that we can use + // their output and then mark them as nil so that we can use // require.Equal for all our other fields. - require.Equal(t, expectedHint[0].ToNodePubKey(), - actualHint[0].ToNodePubKey()) + require.Equal(t, expectedHint[0].EdgePolicy().ToNodePubKey(), + actualHint[0].EdgePolicy().ToNodePubKey()) - actualHint[0].ToNodePubKey = nil - expectedHint[0].ToNodePubKey = nil + actualHint[0].EdgePolicy().ToNodePubKey = nil + expectedHint[0].EdgePolicy().ToNodePubKey = nil + + // The arguments we use for the payload do not matter as long as + // both functions return the same payload. + expectedPayloadSize := expectedHint[0].IntermediatePayloadSize( + 0, 0, false, 0, + ) + actualPayloadSize := actualHint[0].IntermediatePayloadSize( + 0, 0, false, 0, + ) + + require.Equal(t, expectedPayloadSize, actualPayloadSize) require.Equal(t, expectedHint[0], actualHint[0]) } diff --git a/routing/mocks.go b/routing/mocks.go new file mode 100644 index 000000000..9c019abc9 --- /dev/null +++ b/routing/mocks.go @@ -0,0 +1,32 @@ +package routing + +import ( + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/mock" +) + +// mockAdditionalEdge is a mock of the AdditionalEdge interface. +type mockAdditionalEdge struct{ mock.Mock } + +// IntermediatePayloadSize returns the sphinx payload size defined in BOLT04 if +// this edge were to be included in a route. +func (m *mockAdditionalEdge) IntermediatePayloadSize(amount lnwire.MilliSatoshi, + expiry uint32, legacy bool, channelID uint64) uint64 { + + args := m.Called(amount, expiry, legacy, channelID) + + return args.Get(0).(uint64) +} + +// EdgePolicy return the policy of the mockAdditionalEdge. +func (m *mockAdditionalEdge) EdgePolicy() *models.CachedEdgePolicy { + args := m.Called() + + edgePolicy := args.Get(0) + if edgePolicy == nil { + return nil + } + + return edgePolicy.(*models.CachedEdgePolicy) +} diff --git a/routing/pathfind.go b/routing/pathfind.go index 642dab90b..53eb0dfb4 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -88,7 +88,7 @@ var ( // of the edge. type edgePolicyWithSource struct { sourceNode route.Vertex - edge *models.CachedEdgePolicy + edge AdditionalEdge } // finalHopParams encapsulates various parameters for route construction that @@ -355,8 +355,9 @@ type graphParams struct { // additionalEdges is an optional set of edges that should be // considered during path finding, that is not already found in the - // channel graph. - additionalEdges map[route.Vertex][]*models.CachedEdgePolicy + // channel graph. These can either be private edges for bolt 11 invoices + // or blinded edges when a payment to a blinded path is made. + additionalEdges map[route.Vertex][]AdditionalEdge // bandwidthHints is an interface that provides bandwidth hints that // can provide a better estimate of the current channel bandwidth than @@ -407,9 +408,18 @@ type RestrictParams struct { // invoices. PaymentAddr *[32]byte + // Amp signals to the pathfinder that this payment is an AMP payment + // and therefore it needs to account for additional AMP data in the + // final hop payload size calculation. + Amp *AMPOptions + // Metadata is additional data that is sent along with the payment to // the payee. Metadata []byte + + // BlindedPayment is necessary to determine the hop size of the + // last/exit hop. + BlindedPayment *BlindedPayment } // PathFindingConfig defines global parameters that control the trade-off in @@ -609,7 +619,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, distance := make(map[route.Vertex]*nodeWithDist, estimatedNodeCount) additionalEdgesWithSrc := make(map[route.Vertex][]*edgePolicyWithSource) - for vertex, outgoingEdgePolicies := range g.additionalEdges { + for vertex, additionalEdges := range g.additionalEdges { // Edges connected to self are always included in the graph, // therefore can be skipped. This prevents us from trying // routes to malformed hop hints. @@ -619,12 +629,13 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Build reverse lookup to find incoming edges. Needed because // search is taken place from target to source. - for _, outgoingEdgePolicy := range outgoingEdgePolicies { + for _, additionalEdge := range additionalEdges { + outgoingEdgePolicy := additionalEdge.EdgePolicy() toVertex := outgoingEdgePolicy.ToNodePubKey() incomingEdgePolicy := &edgePolicyWithSource{ sourceNode: vertex, - edge: outgoingEdgePolicy, + edge: additionalEdge, } additionalEdgesWithSrc[toVertex] = @@ -633,23 +644,10 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, } } - // Build a preliminary destination hop structure to obtain the payload - // size. - var mpp *record.MPP - if r.PaymentAddr != nil { - mpp = record.NewMPP(amt, *r.PaymentAddr) - } - - finalHop := route.Hop{ - AmtToForward: amt, - OutgoingTimeLock: uint32(finalHtlcExpiry), - CustomRecords: r.DestCustomRecords, - LegacyPayload: !features.HasFeature( - lnwire.TLVOnionPayloadOptional, - ), - MPP: mpp, - Metadata: r.Metadata, - } + // The payload size of the final hop differ from intermediate hops + // and depends on whether the destination is blinded or not. + lastHopPayloadSize := lastHopPayloadSize(r, finalHtlcExpiry, amt, + !features.HasFeature(lnwire.TLVOnionPayloadOptional)) // We can't always assume that the end destination is publicly // advertised to the network so we'll manually include the target node. @@ -667,7 +665,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, amountToReceive: amt, incomingCltv: finalHtlcExpiry, probability: 1, - routingInfoSize: finalHop.PayloadSize(0), + routingInfoSize: lastHopPayloadSize, } // Calculate the absolute cltv limit. Use uint64 to prevent an overflow @@ -821,23 +819,30 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // blob. var payloadSize uint64 if fromVertex != source { + // In case the unifiedEdge does not have a payload size + // function supplied we request a graceful shutdown + // because this should never happen. + if edge.hopPayloadSizeFn == nil { + log.Criticalf("No payload size function "+ + "available for edge=%v unable to "+ + "determine payload size: %v", edge, + ErrNoPayLoadSizeFunc) + + return + } + supportsTlv := fromFeatures.HasFeature( lnwire.TLVOnionPayloadOptional, ) - hop := route.Hop{ - AmtToForward: amountToSend, - OutgoingTimeLock: uint32( - toNodeDist.incomingCltv, - ), - LegacyPayload: !supportsTlv, - } - - payloadSize = hop.PayloadSize(edge.policy.ChannelID) + payloadSize = edge.hopPayloadSizeFn( + amountToSend, + uint32(toNodeDist.incomingCltv), + !supportsTlv, edge.policy.ChannelID, + ) } routingInfoSize := toNodeDist.routingInfoSize + payloadSize - // Skip paths that would exceed the maximum routing info size. if routingInfoSize > sphinx.MaxPayloadSize { return @@ -930,9 +935,14 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // calculations. We set a high capacity to act as if // there is enough liquidity, otherwise the hint would // not have been added by a wallet. + // We also pass the payload size function to the + // graph data so that we calculate the exact payload + // size when evaluating this hop for a route. u.addPolicy( - reverseEdge.sourceNode, reverseEdge.edge, + reverseEdge.sourceNode, + reverseEdge.edge.EdgePolicy(), fakeHopHintCapacity, + reverseEdge.edge.IntermediatePayloadSize, ) } @@ -1098,3 +1108,55 @@ func getProbabilityBasedDist(weight int64, probability float64, return int64(dist) } + +// lastHopPayloadSize calculates the payload size of the final hop in a route. +// It depends on the tlv types which are present and also whether the hop is +// part of a blinded route or not. +func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, + amount lnwire.MilliSatoshi, legacy bool) uint64 { + + if r.BlindedPayment != nil { + blindedPath := r.BlindedPayment.BlindedPath.BlindedHops + blindedPoint := r.BlindedPayment.BlindedPath.BlindingPoint + + encryptedData := blindedPath[len(blindedPath)-1].CipherText + finalHop := route.Hop{ + AmtToForward: amount, + OutgoingTimeLock: uint32(finalHtlcExpiry), + LegacyPayload: false, + EncryptedData: encryptedData, + } + if len(blindedPath) == 1 { + finalHop.BlindingPoint = blindedPoint + } + + // The final hop does not have a short chanID set. + return finalHop.PayloadSize(0) + } + + var mpp *record.MPP + if r.PaymentAddr != nil { + mpp = record.NewMPP(amount, *r.PaymentAddr) + } + + var amp *record.AMP + if r.Amp != nil { + // The AMP payload is not easy accessible at this point but we + // are only interested in the size of the payload so we just use + // the AMP record dummy. + amp = &record.MaxAmpPayLoadSize + } + + finalHop := route.Hop{ + AmtToForward: amount, + OutgoingTimeLock: uint32(finalHtlcExpiry), + CustomRecords: r.DestCustomRecords, + LegacyPayload: legacy, + MPP: mpp, + AMP: amp, + Metadata: r.Metadata, + } + + // The final hop does not have a short chanID set. + return finalHop.PayloadSize(0) +} diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 7c4f635df..6bf754c82 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -746,6 +746,9 @@ func TestPathFinding(t *testing.T) { }, { name: "path finding with additional edges", fn: runPathFindingWithAdditionalEdges, + }, { + name: "path finding max payload restriction", + fn: runPathFindingMaxPayloadRestriction, }, { name: "path finding with redundant additional edges", fn: runPathFindingWithRedundantAdditionalEdges, @@ -1204,7 +1207,7 @@ func runPathFindingWithAdditionalEdges(t *testing.T, useCache bool) { // Create the channel edge going from songoku to doge and include it in // our map of additional edges. - songokuToDoge := &models.CachedEdgePolicy{ + songokuToDogePolicy := &models.CachedEdgePolicy{ ToNodePubKey: func() route.Vertex { return doge.PubKeyBytes }, @@ -1215,8 +1218,10 @@ func runPathFindingWithAdditionalEdges(t *testing.T, useCache bool) { TimeLockDelta: 9, } - additionalEdges := map[route.Vertex][]*models.CachedEdgePolicy{ - graph.aliasMap["songoku"]: {songokuToDoge}, + additionalEdges := map[route.Vertex][]AdditionalEdge{ + graph.aliasMap["songoku"]: {&PrivateEdge{ + policy: songokuToDogePolicy, + }}, } find := func(r *RestrictParams) ( @@ -1266,6 +1271,122 @@ func runPathFindingWithAdditionalEdges(t *testing.T, useCache bool) { assertExpectedPath(t, graph.aliasMap, path, "songoku", "doge") } +// runPathFindingMaxPayloadRestriction tests the maximum size of a sphinx +// package when creating a route. So we make sure the pathfinder does not return +// a route which is greater than the maximum sphinx package size of 1300 bytes +// defined in BOLT04. +func runPathFindingMaxPayloadRestriction(t *testing.T, useCache bool) { + graph, err := parseTestGraph(t, useCache, basicGraphFilePath) + require.NoError(t, err, "unable to create graph") + + sourceNode, err := graph.graph.SourceNode() + require.NoError(t, err, "unable to fetch source node") + + paymentAmt := lnwire.NewMSatFromSatoshis(100) + + // Create a node doge which is not visible in the graph. + dogePubKeyHex := "03dd46ff29a6941b4a2607525b043ec9b020b3f318a1bf281" + + "536fd7011ec59c882" + dogePubKeyBytes, err := hex.DecodeString(dogePubKeyHex) + require.NoError(t, err, "unable to decode public key") + dogePubKey, err := btcec.ParsePubKey(dogePubKeyBytes) + require.NoError(t, err, "unable to parse public key from bytes") + + doge := &channeldb.LightningNode{} + doge.AddPubKey(dogePubKey) + doge.Alias = "doge" + copy(doge.PubKeyBytes[:], dogePubKeyBytes) + graph.aliasMap["doge"] = doge.PubKeyBytes + + const ( + chanID uint64 = 1337 + finalHtlcExpiry int32 = 0 + ) + + // Create the channel edge going from songoku to doge and later add it + // with the mocked size function to the graph data. + songokuToDogePolicy := &models.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return doge.PubKeyBytes + }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), + ChannelID: chanID, + FeeBaseMSat: 1, + FeeProportionalMillionths: 1000, + TimeLockDelta: 9, + } + + // The route has 2 hops. The exit hop (doge) and the hop + // (songoku -> doge). The desired path looks like this: + // source -> songoku -> doge + tests := []struct { + name string + mockedPayloadSize uint64 + err error + }{ + { + // The final hop payload size needs to be considered + // as well and because it's treated differently than the + // intermediate hops the following tests choose to use + // the legacy payload format to have a constant final + // hop payload size. + name: "route max payload size (1300)", + mockedPayloadSize: 1300 - sphinx.LegacyHopDataSize, + }, + { + // We increase the enrypted data size by one byte. + name: "route 1 bytes bigger than max " + + "payload", + mockedPayloadSize: 1300 - sphinx.LegacyHopDataSize + 1, + err: errNoPathFound, + }, + } + + for _, testCase := range tests { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + restrictions := *noRestrictions + // No tlv payload, this makes sure the final hop uses + // the legacy payload. + restrictions.DestFeatures = lnwire.EmptyFeatureVector() + + // Create the mocked AdditionalEdge and mock the + // corresponding calls. + mockedEdge := &mockAdditionalEdge{} + + mockedEdge.On("EdgePolicy").Return(songokuToDogePolicy) + + mockedEdge.On("IntermediatePayloadSize", + paymentAmt, uint32(finalHtlcExpiry), true, + chanID).Once(). + Return(testCase.mockedPayloadSize) + + additionalEdges := map[route.Vertex][]AdditionalEdge{ + graph.aliasMap["songoku"]: {mockedEdge}, + } + + path, err := dbFindPath( + graph.graph, additionalEdges, + &mockBandwidthHints{}, &restrictions, + testPathFindingConfig, sourceNode.PubKeyBytes, + doge.PubKeyBytes, paymentAmt, 0, + finalHtlcExpiry, + ) + require.ErrorIs(t, err, testCase.err) + + if err == nil { + assertExpectedPath(t, graph.aliasMap, path, + "songoku", "doge") + } + + mockedEdge.AssertExpectations(t) + }) + } +} + // runPathFindingWithRedundantAdditionalEdges asserts that we are able to find // paths to nodes ignoring additional edges that are already known by self node. func runPathFindingWithRedundantAdditionalEdges(t *testing.T, useCache bool) { @@ -1290,7 +1411,7 @@ func runPathFindingWithRedundantAdditionalEdges(t *testing.T, useCache bool) { // Create the channel edge going from alice to bob and include it in // our map of additional edges. - aliceToBob := &models.CachedEdgePolicy{ + aliceToBobPolicy := &models.CachedEdgePolicy{ ToNodePubKey: func() route.Vertex { return target }, @@ -1301,8 +1422,10 @@ func runPathFindingWithRedundantAdditionalEdges(t *testing.T, useCache bool) { TimeLockDelta: 9, } - additionalEdges := map[route.Vertex][]*models.CachedEdgePolicy{ - ctx.source: {aliceToBob}, + additionalEdges := map[route.Vertex][]AdditionalEdge{ + ctx.source: {&PrivateEdge{ + policy: aliceToBobPolicy, + }}, } path, err := dbFindPath( @@ -2402,7 +2525,8 @@ func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, path []*models.CachedEdgePolicy, nodeAliases ...string) { if len(path) != len(nodeAliases) { - t.Fatal("number of hops and number of aliases do not match") + t.Fatalf("number of hops=(%v) and number of aliases=(%v) do "+ + "not match", len(path), len(nodeAliases)) } for i, hop := range path { @@ -3072,7 +3196,7 @@ func (c *pathFindingTestContext) assertPath(path []*models.CachedEdgePolicy, // dbFindPath calls findPath after getting a db transaction from the database // graph. func dbFindPath(graph *channeldb.ChannelGraph, - additionalEdges map[route.Vertex][]*models.CachedEdgePolicy, + additionalEdges map[route.Vertex][]AdditionalEdge, bandwidthHints bandwidthHints, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, @@ -3230,8 +3354,8 @@ func TestBlindedRouteConstruction(t *testing.T) { edges := []*models.CachedEdgePolicy{ aliceBobEdge, bobCarolEdge, - carolDaveEdge, - daveEveEdge, + carolDaveEdge.EdgePolicy(), + daveEveEdge.EdgePolicy(), } // Total timelock for the route should include: @@ -3327,3 +3451,168 @@ func TestBlindedRouteConstruction(t *testing.T) { require.NoError(t, err) require.Equal(t, expectedRoute, route) } + +// TestLastHopPayloadSize tests the final hop payload size. The final hop +// payload structure differes from the intermediate hop payload for both the +// non-blinded and blinded case. +func TestLastHopPayloadSize(t *testing.T) { + t.Parallel() + + var ( + metadata = []byte{21, 22} + customRecords = map[uint64][]byte{ + record.CustomTypeStart: {1, 2, 3}, + } + sizeEncryptedData = 100 + encrypedData = bytes.Repeat( + []byte{1}, sizeEncryptedData, + ) + _, blindedPoint = btcec.PrivKeyFromBytes([]byte{5}) + paymentAddr = &[32]byte{1} + ampOptions = &Options{} + amtToForward = lnwire.MilliSatoshi(10000) + finalHopExpiry int32 = 144 + + oneHopBlindedPayment = &BlindedPayment{ + BlindedPath: &sphinx.BlindedPath{ + BlindedHops: []*sphinx.BlindedHopInfo{ + { + CipherText: encrypedData, + }, + }, + BlindingPoint: blindedPoint, + }, + } + twoHopBlindedPayment = &BlindedPayment{ + BlindedPath: &sphinx.BlindedPath{ + BlindedHops: []*sphinx.BlindedHopInfo{ + { + CipherText: encrypedData, + }, + { + CipherText: encrypedData, + }, + }, + BlindingPoint: blindedPoint, + }, + } + ) + + testCases := []struct { + name string + restrictions *RestrictParams + finalHopExpiry int32 + amount lnwire.MilliSatoshi + legacy bool + }{ + { + name: "Non blinded final hop", + restrictions: &RestrictParams{ + PaymentAddr: paymentAddr, + DestCustomRecords: customRecords, + Metadata: metadata, + Amp: ampOptions, + }, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + legacy: false, + }, + { + name: "Non blinded final hop legacy", + restrictions: &RestrictParams{ + // The legacy encoding has no ability to include + // those extra data we expect that this data is + // ignored. + PaymentAddr: paymentAddr, + DestCustomRecords: customRecords, + Metadata: metadata, + Amp: ampOptions, + }, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + legacy: true, + }, + { + name: "Blinded final hop introduction point", + restrictions: &RestrictParams{ + BlindedPayment: oneHopBlindedPayment, + }, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + }, + { + name: "Blinded final hop of a two hop payment", + restrictions: &RestrictParams{ + BlindedPayment: twoHopBlindedPayment, + }, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var mpp *record.MPP + if tc.restrictions.PaymentAddr != nil { + mpp = record.NewMPP( + tc.amount, *tc.restrictions.PaymentAddr, + ) + } + + // In case it's an AMP payment we use the max AMP record + // size to estimate the final hop size. + var amp *record.AMP + if tc.restrictions.Amp != nil { + amp = &record.MaxAmpPayLoadSize + } + + var finalHop route.Hop + if tc.restrictions.BlindedPayment != nil { + blindedPath := tc.restrictions.BlindedPayment. + BlindedPath.BlindedHops + + blindedPoint := tc.restrictions.BlindedPayment. + BlindedPath.BlindingPoint + + //nolint:lll + finalHop = route.Hop{ + AmtToForward: tc.amount, + OutgoingTimeLock: uint32(tc.finalHopExpiry), + LegacyPayload: false, + EncryptedData: blindedPath[len(blindedPath)-1].CipherText, + } + if len(blindedPath) == 1 { + finalHop.BlindingPoint = blindedPoint + } + } else { + //nolint:lll + finalHop = route.Hop{ + LegacyPayload: tc.legacy, + AmtToForward: tc.amount, + OutgoingTimeLock: uint32(tc.finalHopExpiry), + Metadata: tc.restrictions.Metadata, + MPP: mpp, + AMP: amp, + CustomRecords: tc.restrictions.DestCustomRecords, + } + } + + payLoad, err := createHopPayload(finalHop, 0, true) + require.NoErrorf(t, err, "failed to create hop payload") + + expectedPayloadSize := lastHopPayloadSize( + tc.restrictions, tc.finalHopExpiry, + tc.amount, tc.legacy, + ) + + require.Equal( + t, expectedPayloadSize, + uint64(payLoad.NumBytes()), + ) + }) + } +} diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 9853f8fd9..e3bf77170 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -628,7 +628,7 @@ func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, return nil, err } - // It this shard carries MPP or AMP options, add them to the last hop + // If this shard carries MPP or AMP options, add them to the last hop // on the route. hop := rt.Hops[len(rt.Hops)-1] if shard.MPP() != nil { diff --git a/routing/payment_session.go b/routing/payment_session.go index a04a4de55..2d174244c 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -163,7 +163,7 @@ type PaymentSession interface { // loop if payment attempts take long enough. An additional set of edges can // also be provided to assist in reaching the payment's destination. type paymentSession struct { - additionalEdges map[route.Vertex][]*models.CachedEdgePolicy + additionalEdges map[route.Vertex][]AdditionalEdge getBandwidthHints func(routingGraph) (bandwidthHints, error) @@ -258,6 +258,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, DestCustomRecords: p.payment.DestCustomRecords, DestFeatures: p.payment.DestFeatures, PaymentAddr: p.payment.PaymentAddr, + Amp: p.payment.amp, Metadata: p.payment.Metadata, } @@ -441,11 +442,12 @@ func (p *paymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, } for _, edge := range edges { - if edge.ChannelID != channelID { + policy := edge.EdgePolicy() + if policy.ChannelID != channelID { continue } - return edge + return policy } return nil diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 229c932f3..b96a2294b 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -95,9 +95,9 @@ func (m *SessionSource) NewPaymentSessionEmpty() PaymentSession { // RouteHintsToEdges converts a list of invoice route hints to an edge map that // can be passed into pathfinding. func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( - map[route.Vertex][]*models.CachedEdgePolicy, error) { + map[route.Vertex][]AdditionalEdge, error) { - edges := make(map[route.Vertex][]*models.CachedEdgePolicy) + edges := make(map[route.Vertex][]AdditionalEdge) // Traverse through all of the available hop hints and include them in // our edges map, indexed by the public key of the channel's starting @@ -127,7 +127,7 @@ func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( // Finally, create the channel edge from the hop hint // and add it to list of edges corresponding to the node // at the start of the channel. - edge := &models.CachedEdgePolicy{ + edgePolicy := &models.CachedEdgePolicy{ ToNodePubKey: func() route.Vertex { return endNode.PubKeyBytes }, @@ -142,6 +142,10 @@ func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( TimeLockDelta: hopHint.CLTVExpiryDelta, } + edge := &PrivateEdge{ + policy: edgePolicy, + } + v := route.NewVertex(hopHint.NodeID) edges[v] = append(edges[v], edge) } diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 1c199ff40..67a285159 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -138,7 +138,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { require.Equal(t, 1, len(policies), "should have 1 edge policy") // Check that the policy has been created as expected. - policy := policies[0] + policy := policies[0].EdgePolicy() require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch") require.Equal(t, oldExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch", diff --git a/routing/router.go b/routing/router.go index c602573eb..00c0ded57 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1954,7 +1954,7 @@ type RouteRequest struct { // RouteHints is an alias type for a set of route hints, with the source node // as the map's key and the details of the hint(s) in the edge policy. -type RouteHints map[route.Vertex][]*models.CachedEdgePolicy +type RouteHints map[route.Vertex][]AdditionalEdge // NewRouteRequest produces a new route request for a regular payment or one // to a blinded route, validating that the target, routeHints and finalExpiry diff --git a/routing/router_test.go b/routing/router_test.go index 227899c83..12c8ff729 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3954,7 +3954,7 @@ func TestNewRouteRequest(t *testing.T) { name: "hints and blinded", blindedPayment: blindedMultiHop, routeHints: make( - map[route.Vertex][]*models.CachedEdgePolicy, + map[route.Vertex][]AdditionalEdge, ), err: ErrHintsAndBlinded, }, diff --git a/routing/unified_edges.go b/routing/unified_edges.go index aee168348..c828e9a6e 100644 --- a/routing/unified_edges.go +++ b/routing/unified_edges.go @@ -40,9 +40,12 @@ func newNodeEdgeUnifier(sourceNode, toNode route.Vertex, } // addPolicy adds a single channel policy. Capacity may be zero if unknown -// (light clients). +// (light clients). We expect a non-nil payload size function and will request a +// graceful shutdown if it is not provided as this indicates that edges are +// incorrectly specified. func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex, - edge *models.CachedEdgePolicy, capacity btcutil.Amount) { + edge *models.CachedEdgePolicy, capacity btcutil.Amount, + hopPayloadSizeFn PayloadSizeFunc) { localChan := fromNode == u.sourceNode @@ -62,9 +65,20 @@ func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex, u.edgeUnifiers[fromNode] = unifier } + // In case no payload size function was provided a graceful shutdown + // is requested, because this function is not used as intended. + if hopPayloadSizeFn == nil { + log.Criticalf("No payloadsize function was provided for the "+ + "edge (chanid=%v) when adding it to the edge unifier "+ + "of node: %v", edge.ChannelID, fromNode) + + return + } + unifier.edges = append(unifier.edges, &unifiedEdge{ - policy: edge, - capacity: capacity, + policy: edge, + capacity: capacity, + hopPayloadSizeFn: hopPayloadSizeFn, }) } @@ -79,9 +93,13 @@ func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error { return nil } - // Add this policy to the corresponding edgeUnifier. + // Add this policy to the corresponding edgeUnifier. We default + // to the clear hop payload size function because + // `addGraphPolicies` is only used for cleartext intermediate + // hops in a route. u.addPolicy( channel.OtherNode, channel.InPolicy, channel.Capacity, + defaultHopPayloadSize, ) return nil @@ -96,6 +114,12 @@ func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error { type unifiedEdge struct { policy *models.CachedEdgePolicy capacity btcutil.Amount + + // hopPayloadSize supplies an edge with the ability to calculate the + // exact payload size if this edge would be included in a route. This + // is needed because hops of a blinded path differ in their payload + // structure compared to cleartext hops. + hopPayloadSizeFn PayloadSizeFunc } // amtInRange checks whether an amount falls within the valid range for a @@ -202,6 +226,7 @@ func (u *edgeUnifier) getEdgeLocal(amt lnwire.MilliSatoshi, log.Debugf("Skipped edge %v: not enough bandwidth, "+ "bandwidth=%v, amt=%v", edge.policy.ChannelID, bandwidth, amt) + continue } @@ -214,14 +239,16 @@ func (u *edgeUnifier) getEdgeLocal(amt lnwire.MilliSatoshi, log.Debugf("Skipped edge %v: not max bandwidth, "+ "bandwidth=%v, maxBandwidth=%v", bandwidth, maxBandwidth) + continue } maxBandwidth = bandwidth // Update best edge. bestEdge = &unifiedEdge{ - policy: edge.policy, - capacity: edge.capacity, + policy: edge.policy, + capacity: edge.capacity, + hopPayloadSizeFn: edge.hopPayloadSizeFn, } } @@ -234,10 +261,11 @@ func (u *edgeUnifier) getEdgeLocal(amt lnwire.MilliSatoshi, // forwarding context. func (u *edgeUnifier) getEdgeNetwork(amt lnwire.MilliSatoshi) *unifiedEdge { var ( - bestPolicy *models.CachedEdgePolicy - maxFee lnwire.MilliSatoshi - maxTimelock uint16 - maxCapMsat lnwire.MilliSatoshi + bestPolicy *models.CachedEdgePolicy + maxFee lnwire.MilliSatoshi + maxTimelock uint16 + maxCapMsat lnwire.MilliSatoshi + hopPayloadSizeFn PayloadSizeFunc ) for _, edge := range u.edges { @@ -274,7 +302,6 @@ func (u *edgeUnifier) getEdgeNetwork(amt lnwire.MilliSatoshi) *unifiedEdge { maxTimelock = lntypes.Max( maxTimelock, edge.policy.TimeLockDelta, ) - // Use the policy that results in the highest fee for this // specific amount. fee := edge.policy.ComputeFee(amt) @@ -282,11 +309,17 @@ func (u *edgeUnifier) getEdgeNetwork(amt lnwire.MilliSatoshi) *unifiedEdge { log.Debugf("Skipped edge %v due to it produces less "+ "fee: fee=%v, maxFee=%v", edge.policy.ChannelID, fee, maxFee) + continue } maxFee = fee bestPolicy = edge.policy + // The payload size function for edges to a connected peer is + // always the same hence there is not need to find the maximum. + // This also counts for blinded edges where we only have one + // edge to a blinded peer. + hopPayloadSizeFn = edge.hopPayloadSizeFn } // Return early if no channel matches. @@ -308,6 +341,7 @@ func (u *edgeUnifier) getEdgeNetwork(amt lnwire.MilliSatoshi) *unifiedEdge { modifiedEdge := unifiedEdge{policy: &policyCopy} modifiedEdge.policy.TimeLockDelta = maxTimelock modifiedEdge.capacity = maxCapMsat.ToSatoshis() + modifiedEdge.hopPayloadSizeFn = hopPayloadSizeFn return &modifiedEdge } diff --git a/routing/unified_edges_test.go b/routing/unified_edges_test.go index 9b603d78c..043447f52 100644 --- a/routing/unified_edges_test.go +++ b/routing/unified_edges_test.go @@ -41,15 +41,17 @@ func TestNodeEdgeUnifier(t *testing.T) { c2 := btcutil.Amount(8) unifierFilled := newNodeEdgeUnifier(source, toNode, nil) - unifierFilled.addPolicy(fromNode, &p1, c1) - unifierFilled.addPolicy(fromNode, &p2, c2) + unifierFilled.addPolicy(fromNode, &p1, c1, defaultHopPayloadSize) + unifierFilled.addPolicy(fromNode, &p2, c2, defaultHopPayloadSize) unifierNoCapacity := newNodeEdgeUnifier(source, toNode, nil) - unifierNoCapacity.addPolicy(fromNode, &p1, 0) - unifierNoCapacity.addPolicy(fromNode, &p2, 0) + unifierNoCapacity.addPolicy(fromNode, &p1, 0, defaultHopPayloadSize) + unifierNoCapacity.addPolicy(fromNode, &p2, 0, defaultHopPayloadSize) unifierNoInfo := newNodeEdgeUnifier(source, toNode, nil) - unifierNoInfo.addPolicy(fromNode, &models.CachedEdgePolicy{}, 0) + unifierNoInfo.addPolicy( + fromNode, &models.CachedEdgePolicy{}, 0, defaultHopPayloadSize, + ) tests := []struct { name string