tlv: Added bool to primitive

Signed-off-by: Ononiwu Maureen <amaka013@gmail.com>
This commit is contained in:
Ononiwu Maureen
2023-10-05 06:12:54 +01:00
parent ad5cd9c8bb
commit 206f773a9b
4 changed files with 102 additions and 32 deletions

View File

@@ -97,6 +97,13 @@ func FuzzVarBytes(f *testing.F) {
})
}
func FuzzBool(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte) {
var val bool
harness(t, data, EBool, DBool, &val, 1)
})
}
// bigSizeHarness works the same as harness, except that it compares decoded
// values instead of encoded values. We do this because DBigSize may leave some
// bytes unparsed from data, causing the encoded data to be shorter than the
@@ -224,6 +231,7 @@ func FuzzStream(f *testing.F) {
tu16 uint16
tu32 uint32
tu64 uint64
boolean bool
)
sizeTU16 := func() uint64 {
@@ -260,6 +268,7 @@ func FuzzStream(f *testing.F) {
MakeDynamicRecord(
13, &tu64, sizeTU64, ETUint64, DTUint64,
),
MakePrimitiveRecord(14, &boolean),
}
decodeStream := MustNewStream(decodeRecords...)

View File

@@ -2,6 +2,7 @@ package tlv
import (
"encoding/binary"
"errors"
"fmt"
"io"
@@ -143,6 +144,33 @@ func EUint64T(w io.Writer, val uint64, buf *[8]byte) error {
return err
}
// EBool encodes a boolean. An error is returned if val is not a boolean.
func EBool(w io.Writer, val interface{}, buf *[8]byte) error {
if i, ok := val.(*bool); ok {
if *i {
buf[0] = 1
} else {
buf[0] = 0
}
_, err := w.Write(buf[:1])
return err
}
return NewTypeForEncodingErr(val, "bool")
}
// EBoolT encodes a bool val to the provided io.Writer. This method is exposed
// so that encodings for custom bool-like types can be created without
// incurring an extra heap allocation.
func EBoolT(w io.Writer, val bool, buf *[8]byte) error {
if val {
buf[0] = 1
} else {
buf[0] = 0
}
_, err := w.Write(buf[:1])
return err
}
// DUint8 is a Decoder for uint8 values. An error is returned if val is not a
// *uint8.
func DUint8(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
@@ -195,6 +223,21 @@ func DUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
return NewTypeForDecodingErr(val, "uint64", l, 8)
}
// DBool decodes a boolean. An error is returned if val is not a boolean.
func DBool(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if i, ok := val.(*bool); ok && l == 1 {
if _, err := io.ReadFull(r, buf[:1]); err != nil {
return err
}
if buf[0] != 0 && buf[0] != 1 {
return errors.New("corrupted data")
}
*i = buf[0] != 0
return nil
}
return NewTypeForDecodingErr(val, "bool", l, 1)
}
// EBytes32 is an Encoder for 32-byte arrays. An error is returned if val is not
// a *[32]byte.
func EBytes32(w io.Writer, val interface{}, _ *[8]byte) error {

View File

@@ -26,6 +26,7 @@ type primitive struct {
b64 [64]byte
pk *btcec.PublicKey
bytes []byte
boolean bool
}
// TestWrongEncodingType asserts that all primitives encoders will fail with a
@@ -41,6 +42,7 @@ func TestWrongEncodingType(t *testing.T) {
tlv.EBytes64,
tlv.EPubKey,
tlv.EVarBytes,
tlv.EBool,
}
// We'll use an int32 since it is not a primitive type, which should
@@ -73,6 +75,7 @@ func TestWrongDecodingType(t *testing.T) {
tlv.DBytes64,
tlv.DPubKey,
tlv.DVarBytes,
tlv.DBool,
}
// We'll use an int32 since it is not a primitive type, which should
@@ -117,6 +120,7 @@ func TestPrimitiveEncodings(t *testing.T) {
b64: [64]byte{0x02, 0x01},
pk: testPK,
bytes: []byte{0xaa, 0xbb},
boolean: true,
}
encoders := []fieldEncoder{
@@ -156,6 +160,10 @@ func TestPrimitiveEncodings(t *testing.T) {
val: &prim.bytes,
encoder: tlv.EVarBytes,
},
{
val: &prim.boolean,
encoder: tlv.EBool,
},
}
// First we'll encode the primitive fields into a buffer.
@@ -222,6 +230,11 @@ func TestPrimitiveEncodings(t *testing.T) {
decoder: tlv.DVarBytes,
size: 2,
},
{
val: &prim2.boolean,
decoder: tlv.DBool,
size: 1,
},
}
for _, field := range decoders {

View File

@@ -104,6 +104,11 @@ func MakePrimitiveRecord(typ Type, val interface{}) Record {
decoder Decoder
)
switch e := val.(type) {
case *bool:
staticSize = 1
encoder = EBool
decoder = DBool
case *uint8:
staticSize = 1
encoder = EUint8