diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index b0b67fc34..ad7b1b097 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -14,6 +14,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" ) @@ -495,8 +496,9 @@ type AuxHtlcModifier interface { // data blob of an HTLC, may produce a different blob or modify the // amount of bitcoin this htlc should carry. ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, - htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi, - lnwire.CustomRecords, error) + htlcCustomRecords lnwire.CustomRecords, + peer route.Vertex) (lnwire.MilliSatoshi, lnwire.CustomRecords, + error) } // AuxTrafficShaper is an interface that allows the sender to determine if a diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index 4099b9608..6ee1f1f64 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -9,6 +9,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -161,8 +162,8 @@ func (*mockTrafficShaper) PaymentBandwidth(_, _, _ fn.Option[tlv.Blob], // data blob of an HTLC, may produce a different blob or modify the // amount of bitcoin this htlc should carry. func (*mockTrafficShaper) ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, - _ lnwire.CustomRecords) (lnwire.MilliSatoshi, lnwire.CustomRecords, - error) { + _ lnwire.CustomRecords, _ route.Vertex) (lnwire.MilliSatoshi, + lnwire.CustomRecords, error) { return totalAmount, nil, nil } diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index a64c8c87e..7a6443ab6 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -724,6 +724,13 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { // value. rt.FirstHopWireCustomRecords = p.firstHopCustomRecords + if len(rt.Hops) == 0 { + return fmt.Errorf("cannot amend first hop data, route length " + + "is zero") + } + + firstHopPK := rt.Hops[0].PubKeyBytes + // extraDataRequest is a helper struct to pass the custom records and // amount back from the traffic shaper. type extraDataRequest struct { @@ -740,6 +747,7 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { func(ts htlcswitch.AuxTrafficShaper) fn.Result[extraDataRequest] { newAmt, newRecords, err := ts.ProduceHtlcExtraData( rt.TotalAmount, p.firstHopCustomRecords, + firstHopPK, ) if err != nil { return fn.Err[extraDataRequest](err) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 6d4c2fb45..0ee751196 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -368,7 +368,11 @@ func TestRequestRouteSucceed(t *testing.T) { // Create a mock payment session and a dummy route. paySession := &mockPaymentSession{} - dummyRoute := &route.Route{} + dummyRoute := &route.Route{ + Hops: []*route.Hop{ + testHop, + }, + } // Mount the mocked payment session. p.paySession = paySession