lnwire: use require package for fuzz tests

Simplify code by using the require package instead of t.Fatal().
This commit is contained in:
Matt Morehouse 2023-05-19 11:59:45 -05:00
parent b95faaba45
commit 460ba4ad82
No known key found for this signature in database
GPG Key ID: CC8ECA224831C982

View File

@ -4,8 +4,9 @@ import (
"bytes" "bytes"
"compress/zlib" "compress/zlib"
"encoding/binary" "encoding/binary"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
// prefixWithMsgType takes []byte and adds a wire protocol prefix // prefixWithMsgType takes []byte and adds a wire protocol prefix
@ -41,26 +42,15 @@ func harness(t *testing.T, data []byte) {
// We will serialize the message into a new bytes buffer. // We will serialize the message into a new bytes buffer.
var b bytes.Buffer var b bytes.Buffer
if _, err := WriteMessage(&b, msg, 0); err != nil { _, err = WriteMessage(&b, msg, 0)
// Could not serialize message into bytes buffer, panic require.NoError(t, err)
t.Fatal(err)
}
// Deserialize the message from the serialized bytes buffer, and then // Deserialize the message from the serialized bytes buffer, and then
// assert that the original message is equal to the newly deserialized // assert that the original message is equal to the newly deserialized
// message. // message.
newMsg, err := ReadMessage(&b, 0) newMsg, err := ReadMessage(&b, 0)
if err != nil { require.NoError(t, err)
// Could not deserialize message from bytes buffer, panic require.Equal(t, msg, newMsg)
t.Fatal(err)
}
if !reflect.DeepEqual(msg, newMsg) {
// Deserialized message and original message are not deeply
// equal.
t.Fatal("original message and deserialized message are not " +
"deeply equal")
}
} }
func FuzzAcceptChannel(f *testing.F) { func FuzzAcceptChannel(f *testing.F) {
@ -83,107 +73,32 @@ func FuzzAcceptChannel(f *testing.F) {
// We will serialize the message into a new bytes buffer. // We will serialize the message into a new bytes buffer.
var b bytes.Buffer var b bytes.Buffer
if _, err := WriteMessage(&b, msg, 0); err != nil { _, err = WriteMessage(&b, msg, 0)
// Could not serialize message into bytes buffer, panic require.NoError(t, err)
t.Fatal(err)
}
// Deserialize the message from the serialized bytes buffer, and // Deserialize the message from the serialized bytes buffer, and
// then assert that the original message is equal to the newly // then assert that the original message is equal to the newly
// deserialized message. // deserialized message.
newMsg, err := ReadMessage(&b, 0) newMsg, err := ReadMessage(&b, 0)
if err != nil { require.NoError(t, err)
// Could not deserialize message from bytes buffer,
// panic
t.Fatal(err)
}
// Now compare every field instead of using reflect.DeepEqual. require.IsType(t, &AcceptChannel{}, msg)
// For UpfrontShutdownScript, we only compare bytes. This first, _ := msg.(*AcceptChannel)
// probably takes up more branches than necessary, but that's require.IsType(t, &AcceptChannel{}, newMsg)
// fine for now. second, _ := newMsg.(*AcceptChannel)
var shouldPanic bool
first, ok := msg.(*AcceptChannel)
if !ok {
t.Fatal("message was not AcceptChannel")
}
second, ok := newMsg.(*AcceptChannel)
if !ok {
t.Fatal("new message was not AcceptChannel")
}
if !bytes.Equal(first.PendingChannelID[:], // We can't use require.Equal for UpfrontShutdownScript, since
second.PendingChannelID[:]) { // we consider the empty slice and nil to be equivalent.
require.True(
t, bytes.Equal(
first.UpfrontShutdownScript,
second.UpfrontShutdownScript,
),
)
first.UpfrontShutdownScript = nil
second.UpfrontShutdownScript = nil
shouldPanic = true require.Equal(t, first, second)
}
if first.DustLimit != second.DustLimit {
shouldPanic = true
}
if first.MaxValueInFlight != second.MaxValueInFlight {
shouldPanic = true
}
if first.ChannelReserve != second.ChannelReserve {
shouldPanic = true
}
if first.HtlcMinimum != second.HtlcMinimum {
shouldPanic = true
}
if first.MinAcceptDepth != second.MinAcceptDepth {
shouldPanic = true
}
if first.CsvDelay != second.CsvDelay {
shouldPanic = true
}
if first.MaxAcceptedHTLCs != second.MaxAcceptedHTLCs {
shouldPanic = true
}
if !first.FundingKey.IsEqual(second.FundingKey) {
shouldPanic = true
}
if !first.RevocationPoint.IsEqual(second.RevocationPoint) {
shouldPanic = true
}
if !first.PaymentPoint.IsEqual(second.PaymentPoint) {
shouldPanic = true
}
if !first.DelayedPaymentPoint.IsEqual(
second.DelayedPaymentPoint) {
shouldPanic = true
}
if !first.HtlcPoint.IsEqual(second.HtlcPoint) {
shouldPanic = true
}
if !first.FirstCommitmentPoint.IsEqual(
second.FirstCommitmentPoint) {
shouldPanic = true
}
if !bytes.Equal(first.UpfrontShutdownScript,
second.UpfrontShutdownScript) {
shouldPanic = true
}
if shouldPanic {
t.Fatal("original message and deseralized message " +
"are not equal")
}
}) })
} }
@ -356,80 +271,34 @@ func FuzzNodeAnnouncement(f *testing.F) {
// We will serialize the message into a new bytes buffer. // We will serialize the message into a new bytes buffer.
var b bytes.Buffer var b bytes.Buffer
if _, err := WriteMessage(&b, msg, 0); err != nil { _, err = WriteMessage(&b, msg, 0)
// Could not serialize message into bytes buffer, panic require.NoError(t, err)
t.Fatal(err)
}
// Deserialize the message from the serialized bytes buffer, and // Deserialize the message from the serialized bytes buffer, and
// then assert that the original message is equal to the newly // then assert that the original message is equal to the newly
// deserialized message. // deserialized message.
newMsg, err := ReadMessage(&b, 0) newMsg, err := ReadMessage(&b, 0)
if err != nil { require.NoError(t, err)
// Could not deserialize message from bytes buffer,
// panic
t.Fatal(err)
}
// Now compare every field instead of using reflect.DeepEqual require.IsType(t, &NodeAnnouncement{}, msg)
// for the Addresses field. first, _ := msg.(*NodeAnnouncement)
var shouldPanic bool require.IsType(t, &NodeAnnouncement{}, newMsg)
first, ok := msg.(*NodeAnnouncement) second, _ := newMsg.(*NodeAnnouncement)
if !ok {
t.Fatal("message was not NodeAnnouncement")
}
second, ok := newMsg.(*NodeAnnouncement)
if !ok {
t.Fatal("new message was not NodeAnnouncement")
}
if !bytes.Equal(first.Signature[:], second.Signature[:]) {
shouldPanic = true
}
if !reflect.DeepEqual(first.Features, second.Features) {
shouldPanic = true
}
if first.Timestamp != second.Timestamp {
shouldPanic = true
}
if !bytes.Equal(first.NodeID[:], second.NodeID[:]) {
shouldPanic = true
}
if !reflect.DeepEqual(first.RGBColor, second.RGBColor) {
shouldPanic = true
}
if !bytes.Equal(first.Alias[:], second.Alias[:]) {
shouldPanic = true
}
if len(first.Addresses) != len(second.Addresses) {
shouldPanic = true
}
// We can't use require.Equal for Addresses, since the same IP
// can be represented by different underlying bytes. Instead, we
// compare the normalized string representation of each address.
require.Equal(t, len(first.Addresses), len(second.Addresses))
for i := range first.Addresses { for i := range first.Addresses {
if first.Addresses[i].String() != require.Equal(
second.Addresses[i].String() { t, first.Addresses[i].String(),
second.Addresses[i].String(),
shouldPanic = true )
break
}
} }
first.Addresses = nil
second.Addresses = nil
if !reflect.DeepEqual(first.ExtraOpaqueData, require.Equal(t, first, second)
second.ExtraOpaqueData) {
shouldPanic = true
}
if shouldPanic {
t.Fatal("original message and deserialized message " +
"are not equal")
}
}) })
} }
@ -461,123 +330,32 @@ func FuzzOpenChannel(f *testing.F) {
// We will serialize the message into a new bytes buffer. // We will serialize the message into a new bytes buffer.
var b bytes.Buffer var b bytes.Buffer
if _, err := WriteMessage(&b, msg, 0); err != nil { _, err = WriteMessage(&b, msg, 0)
// Could not serialize message into bytes buffer, panic require.NoError(t, err)
t.Fatal(err)
}
// Deserialize the message from the serialized bytes buffer, and // Deserialize the message from the serialized bytes buffer, and
// then assert that the original message is equal to the newly // then assert that the original message is equal to the newly
// deserialized message. // deserialized message.
newMsg, err := ReadMessage(&b, 0) newMsg, err := ReadMessage(&b, 0)
if err != nil { require.NoError(t, err)
// Could not deserialize message from bytes buffer,
// panic
t.Fatal(err)
}
// Now compare every field instead of using reflect.DeepEqual. require.IsType(t, &OpenChannel{}, msg)
// For UpfrontShutdownScript, we only compare bytes. This first, _ := msg.(*OpenChannel)
// probably takes up more branches than necessary, but that's require.IsType(t, &OpenChannel{}, newMsg)
// fine for now. second, _ := newMsg.(*OpenChannel)
var shouldPanic bool
first, ok := msg.(*OpenChannel)
if !ok {
t.Fatal("message was not OpenChannel")
}
second, ok := newMsg.(*OpenChannel)
if !ok {
t.Fatal("new message was not OpenChannel")
}
if !first.ChainHash.IsEqual(&second.ChainHash) { // We can't use require.Equal for UpfrontShutdownScript, since
shouldPanic = true // we consider the empty slice and nil to be equivalent.
} require.True(
t, bytes.Equal(
first.UpfrontShutdownScript,
second.UpfrontShutdownScript,
),
)
first.UpfrontShutdownScript = nil
second.UpfrontShutdownScript = nil
if !bytes.Equal(first.PendingChannelID[:], require.Equal(t, first, second)
second.PendingChannelID[:]) {
shouldPanic = true
}
if first.FundingAmount != second.FundingAmount {
shouldPanic = true
}
if first.PushAmount != second.PushAmount {
shouldPanic = true
}
if first.DustLimit != second.DustLimit {
shouldPanic = true
}
if first.MaxValueInFlight != second.MaxValueInFlight {
shouldPanic = true
}
if first.ChannelReserve != second.ChannelReserve {
shouldPanic = true
}
if first.HtlcMinimum != second.HtlcMinimum {
shouldPanic = true
}
if first.FeePerKiloWeight != second.FeePerKiloWeight {
shouldPanic = true
}
if first.CsvDelay != second.CsvDelay {
shouldPanic = true
}
if first.MaxAcceptedHTLCs != second.MaxAcceptedHTLCs {
shouldPanic = true
}
if !first.FundingKey.IsEqual(second.FundingKey) {
shouldPanic = true
}
if !first.RevocationPoint.IsEqual(second.RevocationPoint) {
shouldPanic = true
}
if !first.PaymentPoint.IsEqual(second.PaymentPoint) {
shouldPanic = true
}
if !first.DelayedPaymentPoint.IsEqual(
second.DelayedPaymentPoint) {
shouldPanic = true
}
if !first.HtlcPoint.IsEqual(second.HtlcPoint) {
shouldPanic = true
}
if !first.FirstCommitmentPoint.IsEqual(
second.FirstCommitmentPoint) {
shouldPanic = true
}
if first.ChannelFlags != second.ChannelFlags {
shouldPanic = true
}
if !bytes.Equal(first.UpfrontShutdownScript,
second.UpfrontShutdownScript) {
shouldPanic = true
}
if shouldPanic {
t.Fatal("original message and deserialized message " +
"are not equal")
}
}) })
} }
@ -619,15 +397,10 @@ func FuzzZlibQueryShortChanIDs(f *testing.F) {
var buf bytes.Buffer var buf bytes.Buffer
zlibWriter := zlib.NewWriter(&buf) zlibWriter := zlib.NewWriter(&buf)
_, err := zlibWriter.Write(data) _, err := zlibWriter.Write(data)
if err != nil { require.NoError(t, err) // Zlib bug?
// Zlib bug?
t.Fatal(err)
}
if err := zlibWriter.Close(); err != nil { err = zlibWriter.Close()
// Zlib bug? require.NoError(t, err) // Zlib bug?
t.Fatal(err)
}
compressedPayload := buf.Bytes() compressedPayload := buf.Bytes()
@ -668,15 +441,10 @@ func FuzzZlibReplyChannelRange(f *testing.F) {
var buf bytes.Buffer var buf bytes.Buffer
zlibWriter := zlib.NewWriter(&buf) zlibWriter := zlib.NewWriter(&buf)
_, err := zlibWriter.Write(data) _, err := zlibWriter.Write(data)
if err != nil { require.NoError(t, err) // Zlib bug?
// Zlib bug?
t.Fatal(err)
}
if err := zlibWriter.Close(); err != nil { err = zlibWriter.Close()
// Zlib bug? require.NoError(t, err) // Zlib bug?
t.Fatal(err)
}
compressedPayload := buf.Bytes() compressedPayload := buf.Bytes()
@ -834,13 +602,9 @@ func FuzzParseRawSignature(f *testing.F) {
} }
sig2, err := NewSigFromRawSignature(sig.ToSignatureBytes()) sig2, err := NewSigFromRawSignature(sig.ToSignatureBytes())
if err != nil { require.NoError(t, err, "failed to reparse signature")
t.Fatalf("failed to reparse signature: %v", err)
}
if !reflect.DeepEqual(sig, sig2) { require.Equal(t, sig, sig2, "signature mismatch")
t.Fatalf("signature mismatch: %v != %v", sig, sig2)
}
}) })
} }
@ -861,21 +625,13 @@ func FuzzConvertFixedSignature(f *testing.F) {
} }
sig2, err := NewSigFromSignature(derSig) sig2, err := NewSigFromSignature(derSig)
if err != nil { require.NoError(t, err, "failed to parse signature")
t.Fatalf("failed to parse signature: %v", err)
}
derSig2, err := sig2.ToSignature() derSig2, err := sig2.ToSignature()
if err != nil { require.NoError(t, err, "failed to reconvert signature to DER")
t.Fatalf("failed to reconvert signature to DER: %v",
err)
}
derBytes := derSig.Serialize() derBytes := derSig.Serialize()
derBytes2 := derSig2.Serialize() derBytes2 := derSig2.Serialize()
if !bytes.Equal(derBytes, derBytes2) { require.Equal(t, derBytes, derBytes2, "signature mismatch")
t.Fatalf("signature mismatch: %v != %v", derBytes,
derBytes2)
}
}) })
} }