routing: encode first hop records in update_add_htlc message

This commit is contained in:
George Tsagkarelis
2024-04-16 12:29:15 +02:00
committed by Olaoluwa Osuntokun
parent 3a53b16549
commit cbdcdac213
2 changed files with 30 additions and 6 deletions

View File

@@ -1,6 +1,7 @@
package routing package routing
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"time" "time"
@@ -13,8 +14,10 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/routing/shards" "github.com/lightningnetwork/lnd/routing/shards"
"github.com/lightningnetwork/lnd/tlv"
) )
// ErrPaymentLifecycleExiting is used when waiting for htlc attempt result, but // ErrPaymentLifecycleExiting is used when waiting for htlc attempt result, but
@@ -31,6 +34,7 @@ type paymentLifecycle struct {
shardTracker shards.ShardTracker shardTracker shards.ShardTracker
timeoutChan <-chan time.Time timeoutChan <-chan time.Time
currentHeight int32 currentHeight int32
firstHopTLVs record.CustomSet
// quit is closed to signal the sub goroutines of the payment lifecycle // quit is closed to signal the sub goroutines of the payment lifecycle
// to stop. // to stop.
@@ -53,7 +57,7 @@ type paymentLifecycle struct {
func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
identifier lntypes.Hash, paySession PaymentSession, identifier lntypes.Hash, paySession PaymentSession,
shardTracker shards.ShardTracker, timeout time.Duration, shardTracker shards.ShardTracker, timeout time.Duration,
currentHeight int32) *paymentLifecycle { currentHeight int32, firstHopTLVs record.CustomSet) *paymentLifecycle {
p := &paymentLifecycle{ p := &paymentLifecycle{
router: r, router: r,
@@ -64,6 +68,7 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
currentHeight: currentHeight, currentHeight: currentHeight,
quit: make(chan struct{}), quit: make(chan struct{}),
resultCollected: make(chan error, 1), resultCollected: make(chan error, 1),
firstHopTLVs: firstHopTLVs,
} }
// Mount the result collector. // Mount the result collector.
@@ -673,6 +678,18 @@ func (p *paymentLifecycle) sendAttempt(
PaymentHash: *attempt.Hash, PaymentHash: *attempt.Hash,
} }
buffer := new(bytes.Buffer)
if err := p.firstHopTLVs.Encode(buffer); err != nil {
return p.failAttempt(attempt.AttemptID, err)
}
recordsBytes := buffer.Bytes()
tlvRecord := tlv.NewPrimitiveRecord[lnwire.CustomRecordsBlobTlvType](
recordsBytes,
)
htlcAdd.CustomRecordsBlob = tlv.SomeRecordT(tlvRecord)
// Generate the raw encoded sphinx packet to be included along // Generate the raw encoded sphinx packet to be included along
// with the htlcAdd message that we send directly to the // with the htlcAdd message that we send directly to the
// switch. // switch.

View File

@@ -713,7 +713,7 @@ func (r *ChannelRouter) Start() error {
// be tried. // be tried.
_, _, err := r.sendPayment( _, _, err := r.sendPayment(
0, payment.Info.PaymentIdentifier, 0, 0, payment.Info.PaymentIdentifier, 0,
paySession, shardTracker, paySession, shardTracker, nil,
) )
if err != nil { if err != nil {
log.Errorf("Resuming payment %v failed: %v.", log.Errorf("Resuming payment %v failed: %v.",
@@ -2308,6 +2308,11 @@ type LightningPayment struct {
// fail. // fail.
DestCustomRecords record.CustomSet DestCustomRecords record.CustomSet
// FirstHopCustomRecords are the TLV records that are to be sent to the
// first hop of this payment. These records will be transmitted via the
// wire message and therefore do not affect the onion payload size.
FirstHopCustomRecords record.CustomSet
// MaxParts is the maximum number of partial payments that may be used // MaxParts is the maximum number of partial payments that may be used
// to complete the full amount. // to complete the full amount.
MaxParts uint32 MaxParts uint32
@@ -2393,6 +2398,7 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
return r.sendPayment( return r.sendPayment(
payment.FeeLimit, payment.Identifier(), payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, paySession, shardTracker, payment.PayAttemptTimeout, paySession, shardTracker,
payment.FirstHopCustomRecords,
) )
} }
@@ -2413,6 +2419,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
_, _, err := r.sendPayment( _, _, err := r.sendPayment(
payment.FeeLimit, payment.Identifier(), payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, ps, st, payment.PayAttemptTimeout, ps, st,
payment.FirstHopCustomRecords,
) )
if err != nil { if err != nil {
log.Errorf("Payment %x failed: %v", log.Errorf("Payment %x failed: %v",
@@ -2587,7 +2594,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
// - no payment timeout. // - no payment timeout.
// - no current block height. // - no current block height.
p := newPaymentLifecycle( p := newPaymentLifecycle(
r, 0, paymentIdentifier, nil, shardTracker, 0, 0, r, 0, paymentIdentifier, nil, shardTracker, 0, 0, nil,
) )
// We found a route to try, create a new HTLC attempt to try. // We found a route to try, create a new HTLC attempt to try.
@@ -2683,8 +2690,8 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
// the ControlTower. // the ControlTower.
func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi, func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
identifier lntypes.Hash, timeout time.Duration, identifier lntypes.Hash, timeout time.Duration,
paySession PaymentSession, paySession PaymentSession, shardTracker shards.ShardTracker,
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) { firstHopTLVs record.CustomSet) ([32]byte, *route.Route, error) {
// We'll also fetch the current block height so we can properly // We'll also fetch the current block height so we can properly
// calculate the required HTLC time locks within the route. // calculate the required HTLC time locks within the route.
@@ -2697,7 +2704,7 @@ func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
// can resume the payment from the current state. // can resume the payment from the current state.
p := newPaymentLifecycle( p := newPaymentLifecycle(
r, feeLimit, identifier, paySession, r, feeLimit, identifier, paySession,
shardTracker, timeout, currentHeight, shardTracker, timeout, currentHeight, firstHopTLVs,
) )
return p.resumePayment() return p.resumePayment()