record+htlcswitch: convert BlindedRouteData fields to optional

For the final hop in a blinded route, the SCID and RelayInfo fields will
_not_ be set. So these fields need to be converted to optional records.

The existing BlindedRouteData constructor is also renamed to
`NewNonFinalBlindedRouteData` in preparation for a
`NewFinalBlindedRouteData` constructor which will be used to construct
the blinded data for the final hop which will contain a much smaller set
of data. The SCID and RelayInfo parameters of the constructor are left
as non-pointers in order to force the caller to set them in the case
that the constructor is called for non-final nodes. The other option
would be to create a single constructor where all parameters are
optional but I think this makes it easier for the caller to make a
mistake.
This commit is contained in:
Elle Mouton
2024-05-02 14:22:34 +02:00
parent 925b68c1ed
commit ad0905f10e
6 changed files with 71 additions and 28 deletions

View File

@@ -360,6 +360,7 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
if err != nil {
return nil, err
}
// Validate the data in the blinded route against our incoming htlc's
// information.
if err := ValidateBlindedRouteData(
@@ -368,9 +369,31 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
return nil, err
}
// Exit early if this onion is for the exit hop of the route since
// route blinding receives are not yet supported.
if isFinalHop {
return nil, fmt.Errorf("being the final hop in a blinded " +
"path is not yet supported")
}
// At this point, we know we are a forwarding node for this onion
// and so we expect the relay info and next SCID fields to be set.
relayInfo, err := routeData.RelayInfo.UnwrapOrErr(
fmt.Errorf("relay info not set for non-final blinded hop"),
)
if err != nil {
return nil, err
}
nextSCID, err := routeData.ShortChannelID.UnwrapOrErr(
fmt.Errorf("next SCID not set for non-final blinded hop"),
)
if err != nil {
return nil, err
}
fwdAmt, err := calculateForwardingAmount(
b.IncomingAmount, routeData.RelayInfo.Val.BaseFee,
routeData.RelayInfo.Val.FeeRate,
b.IncomingAmount, relayInfo.Val.BaseFee, relayInfo.Val.FeeRate,
)
if err != nil {
return nil, err
@@ -400,10 +423,10 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
}
return &ForwardingInfo{
NextHop: routeData.ShortChannelID.Val,
NextHop: nextSCID.Val,
AmountToForward: fwdAmt,
OutgoingCTLV: b.IncomingCltv - uint32(
routeData.RelayInfo.Val.CltvExpiryDelta,
relayInfo.Val.CltvExpiryDelta,
),
// Remap from blinding override type to blinding point type.
NextBlinding: tlv.SomeRecordT(

View File

@@ -186,7 +186,7 @@ func TestDecryptAndValidateFwdInfo(t *testing.T) {
// Encode valid blinding data that we'll fake decrypting for our test.
maxCltv := 1000
blindedData := record.NewBlindedRouteData(
blindedData := record.NewNonFinalBlindedRouteData(
lnwire.NewShortChanIDFromInt(1500), nil,
record.PaymentRelayInfo{
CltvExpiryDelta: 10,

View File

@@ -646,7 +646,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
}{
{
name: "max cltv expired",
data: record.NewBlindedRouteData(
data: record.NewNonFinalBlindedRouteData(
scid,
nil,
record.PaymentRelayInfo{},
@@ -663,7 +663,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
},
{
name: "zero max cltv",
data: record.NewBlindedRouteData(
data: record.NewNonFinalBlindedRouteData(
scid,
nil,
record.PaymentRelayInfo{},
@@ -682,7 +682,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
},
{
name: "amount below minimum",
data: record.NewBlindedRouteData(
data: record.NewNonFinalBlindedRouteData(
scid,
nil,
record.PaymentRelayInfo{},
@@ -699,7 +699,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
},
{
name: "valid, no features",
data: record.NewBlindedRouteData(
data: record.NewNonFinalBlindedRouteData(
scid,
nil,
record.PaymentRelayInfo{},
@@ -714,7 +714,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
},
{
name: "unknown features",
data: record.NewBlindedRouteData(
data: record.NewNonFinalBlindedRouteData(
scid,
nil,
record.PaymentRelayInfo{},
@@ -738,7 +738,7 @@ func TestValidateBlindedRouteData(t *testing.T) {
},
{
name: "valid data",
data: record.NewBlindedRouteData(
data: record.NewNonFinalBlindedRouteData(
scid,
nil,
record.PaymentRelayInfo{

View File

@@ -676,7 +676,7 @@ func (b *blindedForwardTest) createBlindedRoute(hops []*forwardingEdge,
// Encode the route's blinded data and include it in the
// blinded hop.
payload := record.NewBlindedRouteData(
payload := record.NewNonFinalBlindedRouteData(
scid, nil, *relayInfo, constraints, nil,
)
payloadBytes, err := record.EncodeBlindedRouteData(payload)
@@ -739,7 +739,7 @@ func (b *blindedForwardTest) createBlindedRoute(hops []*forwardingEdge,
// node ID here so that it _looks like_ a valid
// forwarding hop (though in reality it's the last
// hop).
record.NewBlindedRouteData(
record.NewNonFinalBlindedRouteData(
lnwire.NewShortChanIDFromInt(100), nil,
record.PaymentRelayInfo{}, nil, nil,
),

View File

@@ -15,7 +15,7 @@ import (
// forwarding information.
type BlindedRouteData struct {
// ShortChannelID is the channel ID of the next hop.
ShortChannelID tlv.RecordT[tlv.TlvType2, lnwire.ShortChannelID]
ShortChannelID tlv.OptionalRecordT[tlv.TlvType2, lnwire.ShortChannelID]
// NextBlindingOverride is a blinding point that should be switched
// in for the next hop. This is used to combine two blinded paths into
@@ -24,7 +24,7 @@ type BlindedRouteData struct {
NextBlindingOverride tlv.OptionalRecordT[tlv.TlvType8, *btcec.PublicKey]
// RelayInfo provides the relay parameters for the hop.
RelayInfo tlv.RecordT[tlv.TlvType10, PaymentRelayInfo]
RelayInfo tlv.OptionalRecordT[tlv.TlvType10, PaymentRelayInfo]
// Constraints provides the payment relay constraints for the hop.
Constraints tlv.OptionalRecordT[tlv.TlvType12, PaymentConstraints]
@@ -33,16 +33,20 @@ type BlindedRouteData struct {
Features tlv.OptionalRecordT[tlv.TlvType14, lnwire.FeatureVector]
}
// NewBlindedRouteData creates the data that's provided for hops within a
// blinded route.
func NewBlindedRouteData(chanID lnwire.ShortChannelID,
// NewNonFinalBlindedRouteData creates the data that's provided for hops within
// a blinded route.
func NewNonFinalBlindedRouteData(chanID lnwire.ShortChannelID,
blindingOverride *btcec.PublicKey, relayInfo PaymentRelayInfo,
constraints *PaymentConstraints,
features *lnwire.FeatureVector) *BlindedRouteData {
info := &BlindedRouteData{
ShortChannelID: tlv.NewRecordT[tlv.TlvType2](chanID),
RelayInfo: tlv.NewRecordT[tlv.TlvType10](relayInfo),
ShortChannelID: tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType2](chanID),
),
RelayInfo: tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType10](relayInfo),
),
}
if blindingOverride != nil {
@@ -69,7 +73,9 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) {
var (
d BlindedRouteData
scid = d.ShortChannelID.Zero()
blindingOverride = d.NextBlindingOverride.Zero()
relayInfo = d.RelayInfo.Zero()
constraints = d.Constraints.Zero()
features = d.Features.Zero()
)
@@ -80,19 +86,25 @@ func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) {
}
typeMap, err := tlvRecords.ExtractRecords(
&d.ShortChannelID,
&blindingOverride, &d.RelayInfo.Val, &constraints,
&features,
&scid, &blindingOverride, &relayInfo, &constraints, &features,
)
if err != nil {
return nil, err
}
if val, ok := typeMap[d.ShortChannelID.TlvType()]; ok && val == nil {
d.ShortChannelID = tlv.SomeRecordT(scid)
}
val, ok := typeMap[d.NextBlindingOverride.TlvType()]
if ok && val == nil {
d.NextBlindingOverride = tlv.SomeRecordT(blindingOverride)
}
if val, ok := typeMap[d.RelayInfo.TlvType()]; ok && val == nil {
d.RelayInfo = tlv.SomeRecordT(relayInfo)
}
if val, ok := typeMap[d.Constraints.TlvType()]; ok && val == nil {
d.Constraints = tlv.SomeRecordT(constraints)
}
@@ -111,7 +123,11 @@ func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) {
recordProducers = make([]tlv.RecordProducer, 0, 5)
)
recordProducers = append(recordProducers, &data.ShortChannelID)
data.ShortChannelID.WhenSome(func(scid tlv.RecordT[tlv.TlvType2,
lnwire.ShortChannelID]) {
recordProducers = append(recordProducers, &scid)
})
data.NextBlindingOverride.WhenSome(func(pk tlv.RecordT[tlv.TlvType8,
*btcec.PublicKey]) {
@@ -119,7 +135,11 @@ func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) {
recordProducers = append(recordProducers, &pk)
})
recordProducers = append(recordProducers, &data.RelayInfo.Val)
data.RelayInfo.WhenSome(func(r tlv.RecordT[tlv.TlvType10,
PaymentRelayInfo]) {
recordProducers = append(recordProducers, &r)
})
data.Constraints.WhenSome(func(cs tlv.RecordT[tlv.TlvType12,
PaymentConstraints]) {

View File

@@ -101,7 +101,7 @@ func TestBlindedDataEncoding(t *testing.T) {
}
}
encodedData := NewBlindedRouteData(
encodedData := NewNonFinalBlindedRouteData(
channelID, pubkey(t), info, constraints,
testCase.features,
)
@@ -134,7 +134,7 @@ func TestBlindingSpecTestVectors(t *testing.T) {
}{
{
encoded: "011a0000000000000000000000000000000000000000000000000000020800000000000006c10a0800240000009627100c06000b69e505dc0e00fd023103123456",
expectedPaymentData: NewBlindedRouteData(
expectedPaymentData: NewNonFinalBlindedRouteData(
lnwire.ShortChannelID{
BlockHeight: 0,
TxIndex: 0,
@@ -158,7 +158,7 @@ func TestBlindingSpecTestVectors(t *testing.T) {
},
{
encoded: "020800000000000004510821031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f0a0800300000006401f40c06000b69c105dc0e00",
expectedPaymentData: NewBlindedRouteData(
expectedPaymentData: NewNonFinalBlindedRouteData(
lnwire.ShortChannelID{
TxPosition: 1105,
},