mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-27 14:11:04 +02:00
record+htlcswitch: convert BlindedRouteData fields to optional
For the final hop in a blinded route, the SCID and RelayInfo fields will _not_ be set. So these fields need to be converted to optional records. The existing BlindedRouteData constructor is also renamed to `NewNonFinalBlindedRouteData` in preparation for a `NewFinalBlindedRouteData` constructor which will be used to construct the blinded data for the final hop which will contain a much smaller set of data. The SCID and RelayInfo parameters of the constructor are left as non-pointers in order to force the caller to set them in the case that the constructor is called for non-final nodes. The other option would be to create a single constructor where all parameters are optional but I think this makes it easier for the caller to make a mistake.
This commit is contained in:
@@ -360,6 +360,7 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate the data in the blinded route against our incoming htlc's
|
||||
// information.
|
||||
if err := ValidateBlindedRouteData(
|
||||
@@ -368,9 +369,31 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Exit early if this onion is for the exit hop of the route since
|
||||
// route blinding receives are not yet supported.
|
||||
if isFinalHop {
|
||||
return nil, fmt.Errorf("being the final hop in a blinded " +
|
||||
"path is not yet supported")
|
||||
}
|
||||
|
||||
// At this point, we know we are a forwarding node for this onion
|
||||
// and so we expect the relay info and next SCID fields to be set.
|
||||
relayInfo, err := routeData.RelayInfo.UnwrapOrErr(
|
||||
fmt.Errorf("relay info not set for non-final blinded hop"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nextSCID, err := routeData.ShortChannelID.UnwrapOrErr(
|
||||
fmt.Errorf("next SCID not set for non-final blinded hop"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fwdAmt, err := calculateForwardingAmount(
|
||||
b.IncomingAmount, routeData.RelayInfo.Val.BaseFee,
|
||||
routeData.RelayInfo.Val.FeeRate,
|
||||
b.IncomingAmount, relayInfo.Val.BaseFee, relayInfo.Val.FeeRate,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -400,10 +423,10 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
|
||||
}
|
||||
|
||||
return &ForwardingInfo{
|
||||
NextHop: routeData.ShortChannelID.Val,
|
||||
NextHop: nextSCID.Val,
|
||||
AmountToForward: fwdAmt,
|
||||
OutgoingCTLV: b.IncomingCltv - uint32(
|
||||
routeData.RelayInfo.Val.CltvExpiryDelta,
|
||||
relayInfo.Val.CltvExpiryDelta,
|
||||
),
|
||||
// Remap from blinding override type to blinding point type.
|
||||
NextBlinding: tlv.SomeRecordT(
|
||||
|
@@ -186,7 +186,7 @@ func TestDecryptAndValidateFwdInfo(t *testing.T) {
|
||||
|
||||
// Encode valid blinding data that we'll fake decrypting for our test.
|
||||
maxCltv := 1000
|
||||
blindedData := record.NewBlindedRouteData(
|
||||
blindedData := record.NewNonFinalBlindedRouteData(
|
||||
lnwire.NewShortChanIDFromInt(1500), nil,
|
||||
record.PaymentRelayInfo{
|
||||
CltvExpiryDelta: 10,
|
||||
|
@@ -646,7 +646,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "max cltv expired",
|
||||
data: record.NewBlindedRouteData(
|
||||
data: record.NewNonFinalBlindedRouteData(
|
||||
scid,
|
||||
nil,
|
||||
record.PaymentRelayInfo{},
|
||||
@@ -663,7 +663,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "zero max cltv",
|
||||
data: record.NewBlindedRouteData(
|
||||
data: record.NewNonFinalBlindedRouteData(
|
||||
scid,
|
||||
nil,
|
||||
record.PaymentRelayInfo{},
|
||||
@@ -682,7 +682,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "amount below minimum",
|
||||
data: record.NewBlindedRouteData(
|
||||
data: record.NewNonFinalBlindedRouteData(
|
||||
scid,
|
||||
nil,
|
||||
record.PaymentRelayInfo{},
|
||||
@@ -699,7 +699,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "valid, no features",
|
||||
data: record.NewBlindedRouteData(
|
||||
data: record.NewNonFinalBlindedRouteData(
|
||||
scid,
|
||||
nil,
|
||||
record.PaymentRelayInfo{},
|
||||
@@ -714,7 +714,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "unknown features",
|
||||
data: record.NewBlindedRouteData(
|
||||
data: record.NewNonFinalBlindedRouteData(
|
||||
scid,
|
||||
nil,
|
||||
record.PaymentRelayInfo{},
|
||||
@@ -738,7 +738,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "valid data",
|
||||
data: record.NewBlindedRouteData(
|
||||
data: record.NewNonFinalBlindedRouteData(
|
||||
scid,
|
||||
nil,
|
||||
record.PaymentRelayInfo{
|
||||
|
@@ -676,7 +676,7 @@ func (b *blindedForwardTest) createBlindedRoute(hops []*forwardingEdge,
|
||||
|
||||
// Encode the route's blinded data and include it in the
|
||||
// blinded hop.
|
||||
payload := record.NewBlindedRouteData(
|
||||
payload := record.NewNonFinalBlindedRouteData(
|
||||
scid, nil, *relayInfo, constraints, nil,
|
||||
)
|
||||
payloadBytes, err := record.EncodeBlindedRouteData(payload)
|
||||
@@ -739,7 +739,7 @@ func (b *blindedForwardTest) createBlindedRoute(hops []*forwardingEdge,
|
||||
// node ID here so that it _looks like_ a valid
|
||||
// forwarding hop (though in reality it's the last
|
||||
// hop).
|
||||
record.NewBlindedRouteData(
|
||||
record.NewNonFinalBlindedRouteData(
|
||||
lnwire.NewShortChanIDFromInt(100), nil,
|
||||
record.PaymentRelayInfo{}, nil, nil,
|
||||
),
|
||||
|
@@ -15,7 +15,7 @@ import (
|
||||
// forwarding information.
|
||||
type BlindedRouteData struct {
|
||||
// ShortChannelID is the channel ID of the next hop.
|
||||
ShortChannelID tlv.RecordT[tlv.TlvType2, lnwire.ShortChannelID]
|
||||
ShortChannelID tlv.OptionalRecordT[tlv.TlvType2, lnwire.ShortChannelID]
|
||||
|
||||
// NextBlindingOverride is a blinding point that should be switched
|
||||
// in for the next hop. This is used to combine two blinded paths into
|
||||
@@ -24,7 +24,7 @@ type BlindedRouteData struct {
|
||||
NextBlindingOverride tlv.OptionalRecordT[tlv.TlvType8, *btcec.PublicKey]
|
||||
|
||||
// RelayInfo provides the relay parameters for the hop.
|
||||
RelayInfo tlv.RecordT[tlv.TlvType10, PaymentRelayInfo]
|
||||
RelayInfo tlv.OptionalRecordT[tlv.TlvType10, PaymentRelayInfo]
|
||||
|
||||
// Constraints provides the payment relay constraints for the hop.
|
||||
Constraints tlv.OptionalRecordT[tlv.TlvType12, PaymentConstraints]
|
||||
@@ -33,16 +33,20 @@ type BlindedRouteData struct {
|
||||
Features tlv.OptionalRecordT[tlv.TlvType14, lnwire.FeatureVector]
|
||||
}
|
||||
|
||||
// NewBlindedRouteData creates the data that's provided for hops within a
|
||||
// blinded route.
|
||||
func NewBlindedRouteData(chanID lnwire.ShortChannelID,
|
||||
// NewNonFinalBlindedRouteData creates the data that's provided for hops within
|
||||
// a blinded route.
|
||||
func NewNonFinalBlindedRouteData(chanID lnwire.ShortChannelID,
|
||||
blindingOverride *btcec.PublicKey, relayInfo PaymentRelayInfo,
|
||||
constraints *PaymentConstraints,
|
||||
features *lnwire.FeatureVector) *BlindedRouteData {
|
||||
|
||||
info := &BlindedRouteData{
|
||||
ShortChannelID: tlv.NewRecordT[tlv.TlvType2](chanID),
|
||||
RelayInfo: tlv.NewRecordT[tlv.TlvType10](relayInfo),
|
||||
ShortChannelID: tlv.SomeRecordT(
|
||||
tlv.NewRecordT[tlv.TlvType2](chanID),
|
||||
),
|
||||
RelayInfo: tlv.SomeRecordT(
|
||||
tlv.NewRecordT[tlv.TlvType10](relayInfo),
|
||||
),
|
||||
}
|
||||
|
||||
if blindingOverride != nil {
|
||||
@@ -69,7 +73,9 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) {
|
||||
var (
|
||||
d BlindedRouteData
|
||||
|
||||
scid = d.ShortChannelID.Zero()
|
||||
blindingOverride = d.NextBlindingOverride.Zero()
|
||||
relayInfo = d.RelayInfo.Zero()
|
||||
constraints = d.Constraints.Zero()
|
||||
features = d.Features.Zero()
|
||||
)
|
||||
@@ -80,19 +86,25 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) {
|
||||
}
|
||||
|
||||
typeMap, err := tlvRecords.ExtractRecords(
|
||||
&d.ShortChannelID,
|
||||
&blindingOverride, &d.RelayInfo.Val, &constraints,
|
||||
&features,
|
||||
&scid, &blindingOverride, &relayInfo, &constraints, &features,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if val, ok := typeMap[d.ShortChannelID.TlvType()]; ok && val == nil {
|
||||
d.ShortChannelID = tlv.SomeRecordT(scid)
|
||||
}
|
||||
|
||||
val, ok := typeMap[d.NextBlindingOverride.TlvType()]
|
||||
if ok && val == nil {
|
||||
d.NextBlindingOverride = tlv.SomeRecordT(blindingOverride)
|
||||
}
|
||||
|
||||
if val, ok := typeMap[d.RelayInfo.TlvType()]; ok && val == nil {
|
||||
d.RelayInfo = tlv.SomeRecordT(relayInfo)
|
||||
}
|
||||
|
||||
if val, ok := typeMap[d.Constraints.TlvType()]; ok && val == nil {
|
||||
d.Constraints = tlv.SomeRecordT(constraints)
|
||||
}
|
||||
@@ -111,7 +123,11 @@ func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) {
|
||||
recordProducers = make([]tlv.RecordProducer, 0, 5)
|
||||
)
|
||||
|
||||
recordProducers = append(recordProducers, &data.ShortChannelID)
|
||||
data.ShortChannelID.WhenSome(func(scid tlv.RecordT[tlv.TlvType2,
|
||||
lnwire.ShortChannelID]) {
|
||||
|
||||
recordProducers = append(recordProducers, &scid)
|
||||
})
|
||||
|
||||
data.NextBlindingOverride.WhenSome(func(pk tlv.RecordT[tlv.TlvType8,
|
||||
*btcec.PublicKey]) {
|
||||
@@ -119,7 +135,11 @@ func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) {
|
||||
recordProducers = append(recordProducers, &pk)
|
||||
})
|
||||
|
||||
recordProducers = append(recordProducers, &data.RelayInfo.Val)
|
||||
data.RelayInfo.WhenSome(func(r tlv.RecordT[tlv.TlvType10,
|
||||
PaymentRelayInfo]) {
|
||||
|
||||
recordProducers = append(recordProducers, &r)
|
||||
})
|
||||
|
||||
data.Constraints.WhenSome(func(cs tlv.RecordT[tlv.TlvType12,
|
||||
PaymentConstraints]) {
|
||||
|
@@ -101,7 +101,7 @@ func TestBlindedDataEncoding(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
encodedData := NewBlindedRouteData(
|
||||
encodedData := NewNonFinalBlindedRouteData(
|
||||
channelID, pubkey(t), info, constraints,
|
||||
testCase.features,
|
||||
)
|
||||
@@ -134,7 +134,7 @@ func TestBlindingSpecTestVectors(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
encoded: "011a0000000000000000000000000000000000000000000000000000020800000000000006c10a0800240000009627100c06000b69e505dc0e00fd023103123456",
|
||||
expectedPaymentData: NewBlindedRouteData(
|
||||
expectedPaymentData: NewNonFinalBlindedRouteData(
|
||||
lnwire.ShortChannelID{
|
||||
BlockHeight: 0,
|
||||
TxIndex: 0,
|
||||
@@ -158,7 +158,7 @@ func TestBlindingSpecTestVectors(t *testing.T) {
|
||||
},
|
||||
{
|
||||
encoded: "020800000000000004510821031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f0a0800300000006401f40c06000b69c105dc0e00",
|
||||
expectedPaymentData: NewBlindedRouteData(
|
||||
expectedPaymentData: NewNonFinalBlindedRouteData(
|
||||
lnwire.ShortChannelID{
|
||||
TxPosition: 1105,
|
||||
},
|
||||
|
Reference in New Issue
Block a user