routing: make msg and index optional

This is later used to handle their nil values.
This commit is contained in:
bitromortac
2025-04-30 14:36:57 +02:00
parent bc4229b32e
commit e5c541407f
2 changed files with 61 additions and 10 deletions

View File

@@ -845,10 +845,15 @@ func newPaymentFailure(sourceIdx *int,
}
info := paymentFailureInfo{
sourceIdx: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint8(*sourceIdx),
sourceIdx: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType0](
uint8(*sourceIdx),
)),
msg: tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType1](
failureMessage{failureMsg},
),
),
msg: tlv.NewRecordT[tlv.TlvType1](failureMessage{failureMsg}),
}
return &paymentFailure{
@@ -921,8 +926,15 @@ func decodePaymentFailure(r io.Reader, val interface{}, _ *[8]byte,
// paymentFailureInfo holds additional information about a payment failure.
type paymentFailureInfo struct {
sourceIdx tlv.RecordT[tlv.TlvType0, uint8]
msg tlv.RecordT[tlv.TlvType1, failureMessage]
// sourceIdx is the hop the error was reported from. In order to be able
// to decrypt the error message, we need to know the source, which is
// why an error message can only be present if the source is known.
sourceIdx tlv.OptionalRecordT[tlv.TlvType0, uint8]
// msg is the error why a payment failed. If we identify the failure of
// a certain hop at the above index, but aren't able to decode the
// failure message we indicate this by not setting this field.
msg tlv.OptionalRecordT[tlv.TlvType1, failureMessage]
}
// Record returns a TLV record that can be used to encode/decode a
@@ -948,9 +960,27 @@ func (r *paymentFailureInfo) Record() tlv.Record {
func encodePaymentFailureInfo(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*paymentFailureInfo); ok {
var recordProducers []tlv.RecordProducer
v.sourceIdx.WhenSome(
func(r tlv.RecordT[tlv.TlvType0, uint8]) {
recordProducers = append(
recordProducers, &r,
)
},
)
v.msg.WhenSome(
func(r tlv.RecordT[tlv.TlvType1, failureMessage]) {
recordProducers = append(
recordProducers, &r,
)
},
)
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(
&v.sourceIdx, &v.msg,
recordProducers...,
),
)
}
@@ -964,14 +994,26 @@ func decodePaymentFailureInfo(r io.Reader, val interface{}, _ *[8]byte,
if v, ok := val.(*paymentFailureInfo); ok {
var h paymentFailureInfo
_, err := lnwire.DecodeRecords(
sourceIdx := tlv.ZeroRecordT[tlv.TlvType0, uint8]()
msg := tlv.ZeroRecordT[tlv.TlvType1, failureMessage]()
typeMap, err := lnwire.DecodeRecords(
r,
lnwire.ProduceRecordsSorted(&h.sourceIdx, &h.msg)...,
lnwire.ProduceRecordsSorted(&sourceIdx, &msg)...,
)
if err != nil {
return err
}
if _, ok := typeMap[h.sourceIdx.TlvType()]; ok {
h.sourceIdx = tlv.SomeRecordT(sourceIdx)
}
if _, ok := typeMap[h.msg.TlvType()]; ok {
h.msg = tlv.SomeRecordT(msg)
}
*v = h
return nil

View File

@@ -138,8 +138,17 @@ func (i *interpretedResult) processFail(rt *mcRoute, failure paymentFailure) {
failure.info.WhenSome(
func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) {
idx = int(r.Val.sourceIdx.Val)
failMsg = r.Val.msg.Val.FailureMessage
r.Val.sourceIdx.WhenSome(
func(r tlv.RecordT[tlv.TlvType0, uint8]) {
idx = int(r.Val)
},
)
r.Val.msg.WhenSome(
func(r tlv.RecordT[tlv.TlvType1, failureMessage]) {
failMsg = r.Val.FailureMessage
},
)
},
)