diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index aa500e3aa..e06163d48 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" ) @@ -645,15 +646,16 @@ func (f *interceptedForward) Packet() InterceptedPacket { ChanID: f.packet.incomingChanID, HtlcID: f.packet.incomingHTLCID, }, - OutgoingChanID: f.packet.outgoingChanID, - Hash: f.htlc.PaymentHash, - OutgoingExpiry: f.htlc.Expiry, - OutgoingAmount: f.htlc.Amount, - IncomingAmount: f.packet.incomingAmount, - IncomingExpiry: f.packet.incomingTimeout, - CustomRecords: f.packet.customRecords, - OnionBlob: f.htlc.OnionBlob, - AutoFailHeight: f.autoFailHeight, + OutgoingChanID: f.packet.outgoingChanID, + Hash: f.htlc.PaymentHash, + OutgoingExpiry: f.htlc.Expiry, + OutgoingAmount: f.htlc.Amount, + IncomingAmount: f.packet.incomingAmount, + IncomingExpiry: f.packet.incomingTimeout, + InOnionCustomRecords: f.packet.inOnionCustomRecords, + OnionBlob: f.htlc.OnionBlob, + AutoFailHeight: f.autoFailHeight, + InWireCustomRecords: f.packet.inWireCustomRecords, } } @@ -723,6 +725,8 @@ func (f *interceptedForward) ResumeModified( } } + log.Tracef("Forwarding packet %v", lnutils.SpewLogClosure(f.packet)) + // Forward to the switch. A link quit channel isn't needed, because we // are on a different thread now. return f.htlcSwitch.ForwardPackets(nil, f.packet) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 5e808f42a..f5b3bbe98 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -357,13 +357,17 @@ type InterceptedPacket struct { // IncomingAmount is the amount of the accepted htlc. IncomingAmount lnwire.MilliSatoshi - // CustomRecords are user-defined records in the custom type range that - // were included in the payload. - CustomRecords record.CustomSet + // InOnionCustomRecords are user-defined records in the custom type + // range that were included in the payload. + InOnionCustomRecords record.CustomSet // OnionBlob is the onion packet for the next hop OnionBlob [lnwire.OnionPacketSize]byte + // InWireCustomRecords are user-defined p2p wire message records that + // were defined by the peer that forwarded this HTLC to us. + InWireCustomRecords lnwire.CustomRecords + // AutoFailHeight is the block height at which this intercept will be // failed back automatically. AutoFailHeight int32 diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 81c82cfa2..57d531935 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -3630,7 +3630,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, } // Otherwise, it was already processed, we can - // can collect it and continue. + // collect it and continue. addMsg := &lnwire.UpdateAddHTLC{ Expiry: fwdInfo.OutgoingCTLV, Amount: fwdInfo.AmountToForward, @@ -3650,19 +3650,21 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, inboundFee := l.cfg.FwrdingPolicy.InboundFee + //nolint:lll updatePacket := &htlcPacket{ - incomingChanID: l.ShortChanID(), - incomingHTLCID: pd.HtlcIndex, - outgoingChanID: fwdInfo.NextHop, - sourceRef: pd.SourceRef, - incomingAmount: pd.Amount, - amount: addMsg.Amount, - htlc: addMsg, - obfuscator: obfuscator, - incomingTimeout: pd.Timeout, - outgoingTimeout: fwdInfo.OutgoingCTLV, - customRecords: pld.CustomRecords(), - inboundFee: inboundFee, + incomingChanID: l.ShortChanID(), + incomingHTLCID: pd.HtlcIndex, + outgoingChanID: fwdInfo.NextHop, + sourceRef: pd.SourceRef, + incomingAmount: pd.Amount, + amount: addMsg.Amount, + htlc: addMsg, + obfuscator: obfuscator, + incomingTimeout: pd.Timeout, + outgoingTimeout: fwdInfo.OutgoingCTLV, + inOnionCustomRecords: pld.CustomRecords(), + inboundFee: inboundFee, + inWireCustomRecords: pd.CustomRecords.Copy(), } switchPackets = append( switchPackets, updatePacket, @@ -3718,19 +3720,21 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, if fwdPkg.State == channeldb.FwdStateLockedIn { inboundFee := l.cfg.FwrdingPolicy.InboundFee + //nolint:lll updatePacket := &htlcPacket{ - incomingChanID: l.ShortChanID(), - incomingHTLCID: pd.HtlcIndex, - outgoingChanID: fwdInfo.NextHop, - sourceRef: pd.SourceRef, - incomingAmount: pd.Amount, - amount: addMsg.Amount, - htlc: addMsg, - obfuscator: obfuscator, - incomingTimeout: pd.Timeout, - outgoingTimeout: fwdInfo.OutgoingCTLV, - customRecords: pld.CustomRecords(), - inboundFee: inboundFee, + incomingChanID: l.ShortChanID(), + incomingHTLCID: pd.HtlcIndex, + outgoingChanID: fwdInfo.NextHop, + sourceRef: pd.SourceRef, + incomingAmount: pd.Amount, + amount: addMsg.Amount, + htlc: addMsg, + obfuscator: obfuscator, + incomingTimeout: pd.Timeout, + outgoingTimeout: fwdInfo.OutgoingCTLV, + inOnionCustomRecords: pld.CustomRecords(), + inboundFee: inboundFee, + inWireCustomRecords: pd.CustomRecords.Copy(), } fwdPkg.FwdFilter.Set(idx) diff --git a/htlcswitch/packet.go b/htlcswitch/packet.go index 45f4e465b..31639dd5d 100644 --- a/htlcswitch/packet.go +++ b/htlcswitch/packet.go @@ -94,9 +94,13 @@ type htlcPacket struct { // link. outgoingTimeout uint32 - // customRecords are user-defined records in the custom type range that - // were included in the payload. - customRecords record.CustomSet + // inOnionCustomRecords are user-defined records in the custom type + // range that were included in the onion payload. + inOnionCustomRecords record.CustomSet + + // inWireCustomRecords are custom type range TLVs that are included + // in the incoming update_add_htlc wire message. + inWireCustomRecords lnwire.CustomRecords // originalOutgoingChanID is used when sending back failure messages. // It is only used for forwarded Adds on option_scid_alias channels. diff --git a/lnrpc/routerrpc/forward_interceptor.go b/lnrpc/routerrpc/forward_interceptor.go index 4ed497034..70b9146fd 100644 --- a/lnrpc/routerrpc/forward_interceptor.go +++ b/lnrpc/routerrpc/forward_interceptor.go @@ -89,7 +89,7 @@ func (r *forwardInterceptor) onIntercept( OutgoingExpiry: htlc.OutgoingExpiry, IncomingAmountMsat: uint64(htlc.IncomingAmount), IncomingExpiry: htlc.IncomingExpiry, - CustomRecords: htlc.CustomRecords, + CustomRecords: htlc.InOnionCustomRecords, OnionBlob: htlc.OnionBlob[:], AutoFailHeight: htlc.AutoFailHeight, } diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 2338f6c10..7af99f527 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -207,6 +207,7 @@ func PayDescsFromRemoteLogUpdates(chanID lnwire.ShortChannelID, height uint64, Index: uint16(i), }, BlindingPoint: wireMsg.BlindingPoint, + CustomRecords: wireMsg.CustomRecords.Copy(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) @@ -1154,6 +1155,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, LogIndex: logUpdate.LogIndex, addCommitHeightRemote: commitHeight, BlindingPoint: wireMsg.BlindingPoint, + CustomRecords: wireMsg.CustomRecords.Copy(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) @@ -1359,6 +1361,7 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd LogIndex: logUpdate.LogIndex, addCommitHeightLocal: commitHeight, BlindingPoint: wireMsg.BlindingPoint, + CustomRecords: wireMsg.CustomRecords.Copy(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob, wireMsg.OnionBlob[:]) @@ -3403,6 +3406,7 @@ func (lc *LightningChannel) createCommitDiff(newCommit *commitment, Expiry: pd.Timeout, PaymentHash: pd.RHash, BlindingPoint: pd.BlindingPoint, + CustomRecords: pd.CustomRecords.Copy(), } copy(htlc.OnionBlob[:], pd.OnionBlob) logUpdate.UpdateMsg = htlc @@ -3543,6 +3547,7 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate { Expiry: pd.Timeout, PaymentHash: pd.RHash, BlindingPoint: pd.BlindingPoint, + CustomRecords: pd.CustomRecords.Copy(), } copy(htlc.OnionBlob[:], pd.OnionBlob) logUpdate.UpdateMsg = htlc @@ -5620,6 +5625,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( Expiry: pd.Timeout, PaymentHash: pd.RHash, BlindingPoint: pd.BlindingPoint, + CustomRecords: pd.CustomRecords.Copy(), } copy(htlc.OnionBlob[:], pd.OnionBlob) logUpdate.UpdateMsg = htlc @@ -5965,9 +5971,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, OnionBlob: htlc.OnionBlob[:], OpenCircuitKey: openKey, BlindingPoint: htlc.BlindingPoint, - // TODO(guggero): Add custom records from HTLC here once we have - // the custom records in the HTLC struct (later commits in this - // PR). + CustomRecords: htlc.CustomRecords.Copy(), } } @@ -6028,9 +6032,7 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, HtlcIndex: lc.updateLogs.Remote.htlcCounter, OnionBlob: htlc.OnionBlob[:], BlindingPoint: htlc.BlindingPoint, - // TODO(guggero): Add custom records from HTLC here once we have - // the custom records in the HTLC struct (later commits in this - // PR). + CustomRecords: htlc.CustomRecords.Copy(), } localACKedIndex := lc.commitChains.Remote.tail().messageIndices.Local diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 5244d4d63..2c7fd3b21 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -25,12 +25,13 @@ var ErrPaymentLifecycleExiting = errors.New("payment lifecycle exiting") // paymentLifecycle holds all information about the current state of a payment // needed to resume if from any point. type paymentLifecycle struct { - router *ChannelRouter - feeLimit lnwire.MilliSatoshi - identifier lntypes.Hash - paySession PaymentSession - shardTracker shards.ShardTracker - currentHeight int32 + router *ChannelRouter + feeLimit lnwire.MilliSatoshi + identifier lntypes.Hash + paySession PaymentSession + shardTracker shards.ShardTracker + currentHeight int32 + firstHopCustomRecords lnwire.CustomRecords // quit is closed to signal the sub goroutines of the payment lifecycle // to stop. @@ -52,18 +53,19 @@ type paymentLifecycle struct { // newPaymentLifecycle initiates a new payment lifecycle and returns it. func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash, paySession PaymentSession, - shardTracker shards.ShardTracker, - currentHeight int32) *paymentLifecycle { + shardTracker shards.ShardTracker, currentHeight int32, + firstHopCustomRecords lnwire.CustomRecords) *paymentLifecycle { p := &paymentLifecycle{ - router: r, - feeLimit: feeLimit, - identifier: identifier, - paySession: paySession, - shardTracker: shardTracker, - currentHeight: currentHeight, - quit: make(chan struct{}), - resultCollected: make(chan error, 1), + router: r, + feeLimit: feeLimit, + identifier: identifier, + paySession: paySession, + shardTracker: shardTracker, + currentHeight: currentHeight, + quit: make(chan struct{}), + resultCollected: make(chan error, 1), + firstHopCustomRecords: firstHopCustomRecords, } // Mount the result collector. @@ -677,9 +679,10 @@ func (p *paymentLifecycle) sendAttempt( // this packet will be used to route the payment through the network, // starting with the first-hop. htlcAdd := &lnwire.UpdateAddHTLC{ - Amount: rt.TotalAmount, - Expiry: rt.TotalTimeLock, - PaymentHash: *attempt.Hash, + Amount: rt.TotalAmount, + Expiry: rt.TotalTimeLock, + PaymentHash: *attempt.Hash, + CustomRecords: p.firstHopCustomRecords, } // Generate the raw encoded sphinx packet to be included along diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 34c8d6c17..4df27523f 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -89,7 +89,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { // Create a test payment lifecycle with no fee limit and no timeout. p := newPaymentLifecycle( rt, noFeeLimit, paymentHash, mockPaymentSession, - mockShardTracker, 0, + mockShardTracker, 0, nil, ) // Create a mock payment which is returned from mockControlTower. diff --git a/routing/router.go b/routing/router.go index 0b6a9beac..0ec5976ae 100644 --- a/routing/router.go +++ b/routing/router.go @@ -865,6 +865,11 @@ type LightningPayment struct { // fail. DestCustomRecords record.CustomSet + // FirstHopCustomRecords are the TLV records that are to be sent to the + // first hop of this payment. These records will be transmitted via the + // wire message and therefore do not affect the onion payload size. + FirstHopCustomRecords lnwire.CustomRecords + // MaxParts is the maximum number of partial payments that may be used // to complete the full amount. MaxParts uint32 @@ -948,6 +953,7 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte, return r.sendPayment( context.Background(), payment.FeeLimit, payment.Identifier(), payment.PayAttemptTimeout, paySession, shardTracker, + payment.FirstHopCustomRecords, ) } @@ -968,6 +974,7 @@ func (r *ChannelRouter) SendPaymentAsync(ctx context.Context, _, _, err := r.sendPayment( ctx, payment.FeeLimit, payment.Identifier(), payment.PayAttemptTimeout, ps, st, + payment.FirstHopCustomRecords, ) if err != nil { log.Errorf("Payment %x failed: %v", @@ -1141,7 +1148,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // - nil payment session (since we already have a route). // - no payment timeout. // - no current block height. - p := newPaymentLifecycle(r, 0, paymentIdentifier, nil, shardTracker, 0) + p := newPaymentLifecycle( + r, 0, paymentIdentifier, nil, shardTracker, 0, nil, + ) // We found a route to try, create a new HTLC attempt to try. // @@ -1237,7 +1246,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, func (r *ChannelRouter) sendPayment(ctx context.Context, feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash, paymentAttemptTimeout time.Duration, paySession PaymentSession, - shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) { + shardTracker shards.ShardTracker, + firstHopCustomRecords lnwire.CustomRecords) ([32]byte, *route.Route, + error) { // If the user provides a timeout, we will additionally wrap the context // in a deadline. @@ -1262,7 +1273,7 @@ func (r *ChannelRouter) sendPayment(ctx context.Context, // can resume the payment from the current state. p := newPaymentLifecycle( r, feeLimit, identifier, paySession, shardTracker, - currentHeight, + currentHeight, firstHopCustomRecords, ) return p.resumePayment(ctx) @@ -1465,7 +1476,7 @@ func (r *ChannelRouter) resumePayments() error { noTimeout := time.Duration(0) _, _, err := r.sendPayment( context.Background(), 0, payHash, noTimeout, paySession, - shardTracker, + shardTracker, nil, ) if err != nil { log.Errorf("Resuming payment %v failed: %v", payHash, diff --git a/witness_beacon.go b/witness_beacon.go index 4fd44e283..2bc3c0850 100644 --- a/witness_beacon.go +++ b/witness_beacon.go @@ -101,10 +101,11 @@ func (p *preimageBeacon) SubscribeUpdates( ChanID: chanID, HtlcID: htlc.HtlcIndex, }, - OutgoingChanID: payload.FwdInfo.NextHop, - OutgoingExpiry: payload.FwdInfo.OutgoingCTLV, - OutgoingAmount: payload.FwdInfo.AmountToForward, - CustomRecords: payload.CustomRecords(), + OutgoingChanID: payload.FwdInfo.NextHop, + OutgoingExpiry: payload.FwdInfo.OutgoingCTLV, + OutgoingAmount: payload.FwdInfo.AmountToForward, + InOnionCustomRecords: payload.CustomRecords(), + InWireCustomRecords: htlc.CustomRecords, } copy(packet.OnionBlob[:], nextHopOnionBlob)