Merge pull request #8334 from lightningnetwork/tlv-record-enchancements

tlv: various enhancements to the new RecordT type
This commit is contained in:
Olaoluwa Osuntokun 2024-01-05 13:57:55 -08:00 committed by GitHub
commit 2f04ce7c6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 381 additions and 105 deletions

View File

@ -20,10 +20,12 @@ package tlv
type tlvType{{ $index }} struct{} type tlvType{{ $index }} struct{}
func (t *tlvType{{ $index }}) typeVal() Type { func (t *tlvType{{ $index }}) TypeVal() Type {
return {{ $index }} return {{ $index }}
} }
func (t *tlvType{{ $index }}) tlv() {}
type TlvType{{ $index }} = *tlvType{{ $index }} type TlvType{{ $index }} = *tlvType{{ $index }}
{{- end }} {{- end }}
` `

View File

@ -62,11 +62,29 @@ func (t *RecordT[T, V]) Record() Record {
tlvRecord, ok := any(&t.Val).(RecordProducer) tlvRecord, ok := any(&t.Val).(RecordProducer)
if !ok { if !ok {
return MakePrimitiveRecord( return MakePrimitiveRecord(
t.recordType.typeVal(), &t.Val, t.recordType.TypeVal(), &t.Val,
) )
} }
return tlvRecord.Record() // To enforce proper usage of the RecordT type, we'll make a wrapper
// record that uses the proper internal type value.
ogRecord := tlvRecord.Record()
return Record{
value: ogRecord.value,
typ: t.recordType.TypeVal(),
staticSize: ogRecord.staticSize,
sizeFunc: ogRecord.sizeFunc,
encoder: ogRecord.encoder,
decoder: ogRecord.decoder,
}
}
// TlvType returns the type of the record. This is the value used to identify
// this type on the wire. This value is bound to the specified TlvType type
// param.
func (t *RecordT[T, V]) TlvType() Type {
return t.recordType.TypeVal()
} }
// OptionalRecordT is a high-order type that represents an optional TLV record. // OptionalRecordT is a high-order type that represents an optional TLV record.
@ -76,6 +94,29 @@ type OptionalRecordT[T TlvType, V any] struct {
fn.Option[RecordT[T, V]] fn.Option[RecordT[T, V]]
} }
// TlvType returns the type of the record. This is the value used to identify
// this type on the wire. This value is bound to the specified TlvType type
// param.
func (t *OptionalRecordT[T, V]) TlvType() Type {
zeroRecord := ZeroRecordT[T, V]()
return zeroRecord.TlvType()
}
// WhenSomeV executes the given function if the optional record is present.
// This operates on the inner most type, V, which is the value of the record.
func (t *OptionalRecordT[T, V]) WhenSomeV(f func(V)) {
t.Option.WhenSome(func(r RecordT[T, V]) {
f(r.Val)
})
}
// SomeRecordT creates a new OptionalRecordT type from a given RecordT type.
func SomeRecordT[T TlvType, V any](record RecordT[T, V]) OptionalRecordT[T, V] {
return OptionalRecordT[T, V]{
Option: fn.Some(record),
}
}
// ZeroRecordT returns a zero value of the RecordT type. // ZeroRecordT returns a zero value of the RecordT type.
func ZeroRecordT[T TlvType, V any]() RecordT[T, V] { func ZeroRecordT[T TlvType, V any]() RecordT[T, V] {
var v V var v V

View File

@ -63,6 +63,10 @@ type coolWireMsg struct {
CsvDelay RecordT[TlvType1, wireCsv] CsvDelay RecordT[TlvType1, wireCsv]
} }
type coolWireMsgDiffContext struct {
CsvDelay RecordT[TlvType3, wireCsv]
}
// TestRecordTFromRecord tests that we can create a RecordT type from an // TestRecordTFromRecord tests that we can create a RecordT type from an
// existing record type and encode/decode as normal. // existing record type and encode/decode as normal.
func TestRecordTFromRecord(t *testing.T) { func TestRecordTFromRecord(t *testing.T) {
@ -91,3 +95,24 @@ func TestRecordTFromRecord(t *testing.T) {
require.Equal(t, wireMsg, wireMsg2) require.Equal(t, wireMsg, wireMsg2)
} }
// TestRecordTFromRecordTypeOverride tests that we can create a RecordT type
// from an existing record type and encode/decode as normal. In this variant,
// we make sure that we can use the type system to override the type of an
// original record.
func TestRecordTFromRecordTypeOverride(t *testing.T) {
t.Parallel()
// First, we'll make a new wire message. Instead of using the TLV type
// of 1 (hard coded in the Record() method defined above), we'll
// instead use TLvType3, as we want to use the same encode/decode, but
// in a context with a different integer type.
val := wireCsv(5)
wireMsg := coolWireMsgDiffContext{
CsvDelay: NewRecordT[TlvType3](val),
}
// If we extract the record, we should see that the type is now 3.
tlvRecord := wireMsg.CsvDelay.Record()
require.Equal(t, tlvRecord.Type(), Type(3))
}

View File

@ -5,7 +5,13 @@ import "fmt"
// TlvType is an interface used to enable binding the integer type of a TLV // TlvType is an interface used to enable binding the integer type of a TLV
// record to the type at compile time. // record to the type at compile time.
type TlvType interface { type TlvType interface {
typeVal() Type // TypeVal returns the integer TLV type that this TlvType struct
// instance maps to.
TypeVal() Type
// tlv is an internal method to make this a "sealed" interface, meaning
// only this package can declare new instances.
tlv()
} }
//go:generate go run internal/gen/gen_tlv_types.go -o tlv_types_generated.go //go:generate go run internal/gen/gen_tlv_types.go -o tlv_types_generated.go

File diff suppressed because it is too large Load Diff