diff --git a/routing/additional_edge.go b/routing/additional_edge.go index eee17cce1..a1e5fc856 100644 --- a/routing/additional_edge.go +++ b/routing/additional_edge.go @@ -2,8 +2,8 @@ package routing import ( "errors" + "fmt" - "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -61,11 +61,42 @@ func (p *PrivateEdge) IntermediatePayloadSize(amount lnwire.MilliSatoshi, } // BlindedEdge implements the AdditionalEdge interface. Blinded hops are viewed -// as additional edges because they are appened at the end of a normal route. +// as additional edges because they are appended at the end of a normal route. type BlindedEdge struct { - policy *models.CachedEdgePolicy - cipherText []byte - blindingPoint *btcec.PublicKey + policy *models.CachedEdgePolicy + + // blindedPayment is the BlindedPayment that this blinded edge was + // derived from. + blindedPayment *BlindedPayment + + // hopIndex is the index of the hop in the blinded payment path that + // this edge is associated with. + hopIndex int +} + +// NewBlindedEdge constructs a new BlindedEdge which packages the policy info +// for a specific hop within the given blinded payment path. The hop index +// should correspond to the hop within the blinded payment that this edge is +// associated with. +func NewBlindedEdge(policy *models.CachedEdgePolicy, payment *BlindedPayment, + hopIndex int) (*BlindedEdge, error) { + + if payment == nil { + return nil, fmt.Errorf("blinded payment cannot be nil for " + + "blinded edge") + } + + if hopIndex < 0 || hopIndex >= len(payment.BlindedPath.BlindedHops) { + return nil, fmt.Errorf("the hop index %d is outside the "+ + "valid range between 0 and %d", hopIndex, + len(payment.BlindedPath.BlindedHops)-1) + } + + return &BlindedEdge{ + policy: policy, + hopIndex: hopIndex, + blindedPayment: payment, + }, nil } // EdgePolicy return the policy of the BlindedEdge. @@ -78,9 +109,11 @@ func (b *BlindedEdge) EdgePolicy() *models.CachedEdgePolicy { func (b *BlindedEdge) IntermediatePayloadSize(_ lnwire.MilliSatoshi, _ uint32, _ uint64) uint64 { + blindedPath := b.blindedPayment.BlindedPath + hop := route.Hop{ - BlindingPoint: b.blindingPoint, - EncryptedData: b.cipherText, + BlindingPoint: blindedPath.BlindingPoint, + EncryptedData: blindedPath.BlindedHops[b.hopIndex].CipherText, } // For blinded paths the next chanID is in the encrypted data tlv. diff --git a/routing/additional_edge_test.go b/routing/additional_edge_test.go index b5628c6b7..0324e2e10 100644 --- a/routing/additional_edge_test.go +++ b/routing/additional_edge_test.go @@ -42,9 +42,13 @@ func TestIntermediatePayloadSize(t *testing.T) { hop: route.Hop{ EncryptedData: []byte{12, 13}, }, - edge: &BlindedEdge{ - cipherText: []byte{12, 13}, - }, + edge: &BlindedEdge{blindedPayment: &BlindedPayment{ + BlindedPath: &sphinx.BlindedPath{ + BlindedHops: []*sphinx.BlindedHopInfo{ + {CipherText: []byte{12, 13}}, + }, + }, + }}, }, { name: "Blinded edge - introduction point", @@ -52,10 +56,14 @@ func TestIntermediatePayloadSize(t *testing.T) { EncryptedData: []byte{12, 13}, BlindingPoint: blindedPoint, }, - edge: &BlindedEdge{ - cipherText: []byte{12, 13}, - blindingPoint: blindedPoint, - }, + edge: &BlindedEdge{blindedPayment: &BlindedPayment{ + BlindedPath: &sphinx.BlindedPath{ + BlindingPoint: blindedPoint, + BlindedHops: []*sphinx.BlindedHopInfo{ + {CipherText: []byte{12, 13}}, + }, + }, + }}, }, } diff --git a/routing/blinding.go b/routing/blinding.go index d2d64aa5d..788fb7b77 100644 --- a/routing/blinding.go +++ b/routing/blinding.go @@ -88,13 +88,13 @@ func (b *BlindedPayment) Validate() error { // the case of multiple blinded hops, CLTV delta is fully accounted for in the // hints (both for intermediate hops and the final_cltv_delta for the receiving // node). -func (b *BlindedPayment) toRouteHints() RouteHints { +func (b *BlindedPayment) toRouteHints() (RouteHints, error) { // If we just have a single hop in our blinded route, it just contains // an introduction node (this is a valid path according to the spec). // Since we have the un-blinded node ID for the introduction node, we // don't need to add any route hints. if len(b.BlindedPath.BlindedHops) == 1 { - return nil + return nil, nil } hintCount := len(b.BlindedPath.BlindedHops) - 1 @@ -136,14 +136,13 @@ func (b *BlindedPayment) toRouteHints() RouteHints { ToNodeFeatures: features, } - hints[fromNode] = []AdditionalEdge{ - &BlindedEdge{ - policy: edgePolicy, - cipherText: b.BlindedPath.BlindedHops[0].CipherText, - blindingPoint: b.BlindedPath.BlindingPoint, - }, + edge, err := NewBlindedEdge(edgePolicy, b, 0) + if err != nil { + return nil, err } + hints[fromNode] = []AdditionalEdge{edge} + // Start at an offset of 1 because the first node in our blinded hops // is the introduction node and terminate at the second-last node // because we're dealing with hops as pairs. @@ -169,14 +168,13 @@ func (b *BlindedPayment) toRouteHints() RouteHints { ToNodeFeatures: features, } - hints[fromNode] = []AdditionalEdge{ - &BlindedEdge{ - policy: edgePolicy, - cipherText: b.BlindedPath.BlindedHops[i]. - CipherText, - }, + edge, err := NewBlindedEdge(edgePolicy, b, i) + if err != nil { + return nil, err } + + hints[fromNode] = []AdditionalEdge{edge} } - return hints + return hints, nil } diff --git a/routing/blinding_test.go b/routing/blinding_test.go index f6327ecd5..58ad56594 100644 --- a/routing/blinding_test.go +++ b/routing/blinding_test.go @@ -128,7 +128,9 @@ func TestBlindedPaymentToHints(t *testing.T) { HtlcMaximum: htlcMax, Features: features, } - require.Nil(t, blindedPayment.toRouteHints()) + hints, err := blindedPayment.toRouteHints() + require.NoError(t, err) + require.Nil(t, hints) // Populate the blinded payment with hops. blindedPayment.BlindedPath.BlindedHops = []*sphinx.BlindedHopInfo{ @@ -146,41 +148,43 @@ func TestBlindedPaymentToHints(t *testing.T) { }, } + policy1 := &models.CachedEdgePolicy{ + TimeLockDelta: cltvDelta, + MinHTLC: lnwire.MilliSatoshi(htlcMin), + MaxHTLC: lnwire.MilliSatoshi(htlcMax), + FeeBaseMSat: lnwire.MilliSatoshi(baseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi( + ppmFee, + ), + ToNodePubKey: func() route.Vertex { + return vb2 + }, + ToNodeFeatures: features, + } + policy2 := &models.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return vb3 + }, + ToNodeFeatures: features, + } + + blindedEdge1, err := NewBlindedEdge(policy1, blindedPayment, 0) + require.NoError(t, err) + + blindedEdge2, err := NewBlindedEdge(policy2, blindedPayment, 1) + require.NoError(t, err) + expected := RouteHints{ v1: { - //nolint:lll - &BlindedEdge{ - policy: &models.CachedEdgePolicy{ - TimeLockDelta: cltvDelta, - MinHTLC: lnwire.MilliSatoshi(htlcMin), - MaxHTLC: lnwire.MilliSatoshi(htlcMax), - FeeBaseMSat: lnwire.MilliSatoshi(baseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi( - ppmFee, - ), - ToNodePubKey: func() route.Vertex { - return vb2 - }, - ToNodeFeatures: features, - }, - blindingPoint: blindedPoint, - cipherText: cipherText, - }, + blindedEdge1, }, vb2: { - &BlindedEdge{ - policy: &models.CachedEdgePolicy{ - ToNodePubKey: func() route.Vertex { - return vb3 - }, - ToNodeFeatures: features, - }, - cipherText: cipherText, - }, + blindedEdge2, }, } - actual := blindedPayment.toRouteHints() + actual, err := blindedPayment.toRouteHints() + require.NoError(t, err) require.Equal(t, len(expected), len(actual)) for vertex, expectedHint := range expected { diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 06005716f..fd9839dba 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3341,7 +3341,9 @@ func TestBlindedRouteConstruction(t *testing.T) { // that make up the graph we'll give to route construction. The hints // map is keyed by source node, so we can retrieve our blinded edges // accordingly. - blindedEdges := blindedPayment.toRouteHints() + blindedEdges, err := blindedPayment.toRouteHints() + require.NoError(t, err) + carolDaveEdge := blindedEdges[carolVertex][0] daveEveEdge := blindedEdges[daveBlindedVertex][0] diff --git a/routing/router.go b/routing/router.go index 851db4af0..149cd3415 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2001,6 +2001,7 @@ func NewRouteRequest(source route.Vertex, target *route.Vertex, // Assume that we're starting off with a regular payment. requestHints = routeHints requestExpiry = finalExpiry + err error ) if blindedPayment != nil { @@ -2038,7 +2039,10 @@ func NewRouteRequest(source route.Vertex, target *route.Vertex, requestExpiry = blindedPayment.CltvExpiryDelta } - requestHints = blindedPayment.toRouteHints() + requestHints, err = blindedPayment.toRouteHints() + if err != nil { + return nil, err + } } requestTarget, err := getTargetNode(target, blindedPayment)