lnwire: add new TestMessage interface for property tests

In this commit, we add a new `TestMessage` interface for use in property
tests. With this, we'll be able to generate a random instance of a given
message, using the rapid byte stream. This can also eventually be useful
for fuzzing.
This commit is contained in:
Olaoluwa Osuntokun 2025-03-19 15:10:35 -05:00
parent 56a100123b
commit eb877db2ff
No known key found for this signature in database
GPG Key ID: 90525F7DEEE0AD86
25 changed files with 2090 additions and 32 deletions

2
.gitignore vendored
View File

@ -80,3 +80,5 @@ coverage.txt
# Release build directory (to avoid build.vcs.modified Golang build tag to be
# set to true by having untracked files in the working directory).
/lnd-*/
.aider*

View File

@ -128,8 +128,8 @@ type AcceptChannel struct {
// interface.
var _ Message = (*AcceptChannel)(nil)
// A compile time check to ensure AcceptChannel implements the lnwire.SizeableMessage
// interface.
// A compile time check to ensure AcceptChannel implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*AcceptChannel)(nil)
// Encode serializes the target AcceptChannel into the passed io.Writer

View File

@ -124,7 +124,8 @@ type ChannelUpdate1 struct {
// interface.
var _ Message = (*ChannelUpdate1)(nil)
// A compile time check to ensure ChannelUpdate1 implements the lnwire.SizeableMessage interface.
// A compile time check to ensure ChannelUpdate1 implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ChannelUpdate1)(nil)
// Decode deserializes a serialized ChannelUpdate stored in the passed

View File

@ -180,5 +180,6 @@ func (c *ClosingComplete) SerializedSize() (uint32, error) {
// interface.
var _ Message = (*ClosingComplete)(nil)
// A compile time check to ensure ClosingComplete implements the lnwire.SizeableMessage interface.
// A compile time check to ensure ClosingComplete implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ClosingComplete)(nil)

View File

@ -118,6 +118,6 @@ func (c *ClosingSig) SerializedSize() (uint32, error) {
// interface.
var _ Message = (*ClosingSig)(nil)
// A compile time check to ensure ClosingSig implements the lnwire.SizeableMessage
// interface.
// A compile time check to ensure ClosingSig implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ClosingSig)(nil)

View File

@ -59,7 +59,8 @@ func NewClosingSigned(cid ChannelID, fs btcutil.Amount,
// interface.
var _ Message = (*ClosingSigned)(nil)
// A compile time check to ensure ClosingSigned implements the lnwire.SizeableMessage interface.
// A compile time check to ensure ClosingSigned implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ClosingSigned)(nil)
// Decode deserializes a serialized ClosingSigned message stored in the passed

View File

@ -64,7 +64,8 @@ func NewCommitSig() *CommitSig {
// interface.
var _ Message = (*CommitSig)(nil)
// A compile time check to ensure CommitSig implements the lnwire.SizeableMessage interface.
// A compile time check to ensure CommitSig implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*CommitSig)(nil)
// Decode deserializes a serialized CommitSig message stored in the

View File

@ -69,7 +69,7 @@ type Custom struct {
Data []byte
}
// A compile time check to ensure FundingCreated implements the lnwire.Message
// A compile time check to ensure Custom implements the lnwire.Message
// interface.
var _ Message = (*Custom)(nil)

View File

@ -36,8 +36,8 @@ type FundingSigned struct {
// interface.
var _ Message = (*FundingSigned)(nil)
// A compile time check to ensure FundingSigned implements the lnwire.SizeableMessage
// interface.
// A compile time check to ensure FundingSigned implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*FundingSigned)(nil)
// Encode serializes the target FundingSigned into the passed io.Writer

View File

@ -58,7 +58,8 @@ func NewGossipTimestampRange() *GossipTimestampRange {
// lnwire.Message interface.
var _ Message = (*GossipTimestampRange)(nil)
// A compile time check to ensure GossipTimestampRange implements the lnwire.SizeableMessage interface.
// A compile time check to ensure GossipTimestampRange implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*GossipTimestampRange)(nil)
// Decode deserializes a serialized GossipTimestampRange message stored in the

View File

@ -27,6 +27,10 @@ type KickoffSig struct {
// interface.
var _ Message = (*KickoffSig)(nil)
// A compile time check to ensure KickoffSig implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*KickoffSig)(nil)
// Encode serializes the target KickoffSig into the passed bytes.Buffer
// observing the specified protocol version.
//
@ -54,3 +58,10 @@ func (ks *KickoffSig) Decode(r io.Reader, _ uint32) error {
//
// This is part of the lnwire.Message interface.
func (ks *KickoffSig) MsgType() MessageType { return MsgKickoffSig }
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (ks *KickoffSig) SerializedSize() (uint32, error) {
return MessageSerializedSize(ks)
}

View File

@ -65,6 +65,11 @@ const (
MsgChannelAnnouncement2 = 267
MsgChannelUpdate2 = 271
MsgKickoffSig = 777
// MsgEnd defines the end of the official message range of the protocol.
// If a new message is added beyond this message, then this should be
// modified.
MsgEnd = 778
)
// IsChannelUpdate is a filter function that discerns channel update messages

View File

@ -104,7 +104,8 @@ type NodeAnnouncement struct {
// lnwire.Message interface.
var _ Message = (*NodeAnnouncement)(nil)
// A compile time check to ensure NodeAnnouncement implements the lnwire.SizeableMessage interface.
// A compile time check to ensure NodeAnnouncement implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*NodeAnnouncement)(nil)
// Decode deserializes a serialized NodeAnnouncement stored in the passed

View File

@ -164,8 +164,8 @@ type OpenChannel struct {
// interface.
var _ Message = (*OpenChannel)(nil)
// A compile time check to ensure OpenChannel implements the lnwire.SizeableMessage
// interface.
// A compile time check to ensure OpenChannel implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*OpenChannel)(nil)
// Encode serializes the target OpenChannel into the passed io.Writer

View File

@ -49,8 +49,8 @@ func NewQueryChannelRange() *QueryChannelRange {
// lnwire.Message interface.
var _ Message = (*QueryChannelRange)(nil)
// A compile time check to ensure QueryChannelRange implements the lnwire.SizeableMessage
// interface.
// A compile time check to ensure QueryChannelRange implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*QueryChannelRange)(nil)
// Decode deserializes a serialized QueryChannelRange message stored in the

View File

@ -91,8 +91,8 @@ func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding,
// lnwire.Message interface.
var _ Message = (*QueryShortChanIDs)(nil)
// A compile time check to ensure QueryShortChanIDs implements the lnwire.SizeableMessage
// interface.
// A compile time check to ensure QueryShortChanIDs implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*QueryShortChanIDs)(nil)
// Decode deserializes a serialized QueryShortChanIDs message stored in the

View File

@ -70,7 +70,8 @@ func NewReplyChannelRange() *ReplyChannelRange {
// lnwire.Message interface.
var _ Message = (*ReplyChannelRange)(nil)
// A compile time check to ensure ReplyChannelRange implements the lnwire.SizeableMessage interface.
// A compile time check to ensure ReplyChannelRange implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ReplyChannelRange)(nil)
// Decode deserializes a serialized ReplyChannelRange message stored in the

View File

@ -39,7 +39,8 @@ func NewReplyShortChanIDsEnd() *ReplyShortChanIDsEnd {
// lnwire.Message interface.
var _ Message = (*ReplyShortChanIDsEnd)(nil)
// A compile time check to ensure ReplyShortChanIDsEnd implements the lnwire.SizeableMessage interface.
// A compile time check to ensure ReplyShortChanIDsEnd implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ReplyShortChanIDsEnd)(nil)
// Decode deserializes a serialized ReplyShortChanIDsEnd message stored in the

View File

@ -55,7 +55,8 @@ func NewRevokeAndAck() *RevokeAndAck {
// interface.
var _ Message = (*RevokeAndAck)(nil)
// A compile time check to ensure RevokeAndAck implements the lnwire.SizeableMessage interface.
// A compile time check to ensure RevokeAndAck implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*RevokeAndAck)(nil)
// Decode deserializes a serialized RevokeAndAck message stored in the

View File

@ -61,7 +61,8 @@ func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown {
// interface.
var _ Message = (*Shutdown)(nil)
// A compile-time check to ensure Shutdown implements the lnwire.SizeableMessage interface.
// A compile-time check to ensure Shutdown implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*Shutdown)(nil)
// Decode deserializes a serialized Shutdown from the passed io.Reader,

View File

@ -24,7 +24,8 @@ type Stfu struct {
// A compile time check to ensure Stfu implements the lnwire.Message interface.
var _ Message = (*Stfu)(nil)
// A compile time check to ensure Stfu implements the lnwire.SizeableMessage interface.
// A compile time check to ensure Stfu implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*Stfu)(nil)
// Encode serializes the target Stfu into the passed io.Writer.

1669
lnwire/test_message.go Normal file

File diff suppressed because it is too large Load Diff

360
lnwire/test_utils.go Normal file
View File

@ -0,0 +1,360 @@
package lnwire
import (
"crypto/sha256"
"fmt"
"net"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/stretchr/testify/require"
"pgregory.net/rapid"
)
// RandChannelUpdate generates a random ChannelUpdate message using rapid's
// generators.
func RandPartialSig(t *rapid.T) *PartialSig {
// Generate random private key bytes
sigBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "privKeyBytes")
var s btcec.ModNScalar
s.SetByteSlice(sigBytes)
return &PartialSig{
Sig: s,
}
}
// RandPartialSigWithNonce generates a random PartialSigWithNonce using rapid
// generators.
func RandPartialSigWithNonce(t *rapid.T) *PartialSigWithNonce {
sigLen := rapid.IntRange(1, 65).Draw(t, "partialSigLen")
sigBytes := rapid.SliceOfN(
rapid.Byte(), sigLen, sigLen,
).Draw(t, "partialSig")
sigScalar := new(btcec.ModNScalar)
sigScalar.SetByteSlice(sigBytes)
return NewPartialSigWithNonce(
RandMusig2Nonce(t), *sigScalar,
)
}
// RandPubKey generates a random public key using rapid's generators.
func RandPubKey(t *rapid.T) *btcec.PublicKey {
privKeyBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(
t, "privKeyBytes",
)
_, pub := btcec.PrivKeyFromBytes(privKeyBytes)
return pub
}
// RandChannelID generates a random channel ID.
func RandChannelID(t *rapid.T) ChannelID {
var c ChannelID
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "channelID")
copy(c[:], bytes)
return c
}
// RandShortChannelID generates a random short channel ID.
func RandShortChannelID(t *rapid.T) ShortChannelID {
return NewShortChanIDFromInt(
uint64(rapid.IntRange(1, 100000).Draw(t, "shortChanID")),
)
}
// RandFeatureVector generates a random feature vector.
func RandFeatureVector(t *rapid.T) *RawFeatureVector {
featureVec := NewRawFeatureVector()
// Add a random number of random feature bits
numFeatures := rapid.IntRange(0, 20).Draw(t, "numFeatures")
for i := 0; i < numFeatures; i++ {
bit := FeatureBit(rapid.IntRange(0, 100).Draw(
t, fmt.Sprintf("featureBit-%d", i)),
)
featureVec.Set(bit)
}
return featureVec
}
// RandSignature generates a signature for testing.
func RandSignature(t *rapid.T) Sig {
testRScalar := new(btcec.ModNScalar)
testSScalar := new(btcec.ModNScalar)
// Generate random bytes for R and S
rBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "rBytes")
sBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "sBytes")
_ = testRScalar.SetByteSlice(rBytes)
_ = testSScalar.SetByteSlice(sBytes)
testSig := ecdsa.NewSignature(testRScalar, testSScalar)
sig, err := NewSigFromSignature(testSig)
if err != nil {
panic(fmt.Sprintf("unable to create signature: %v", err))
}
return sig
}
// RandPaymentHash generates a random payment hash.
func RandPaymentHash(t *rapid.T) [32]byte {
var hash [32]byte
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "paymentHash")
copy(hash[:], bytes)
return hash
}
// RandPaymentPreimage generates a random payment preimage.
func RandPaymentPreimage(t *rapid.T) [32]byte {
var preimage [32]byte
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "preimage")
copy(preimage[:], bytes)
return preimage
}
// RandChainHash generates a random chain hash.
func RandChainHash(t *rapid.T) chainhash.Hash {
var hash [32]byte
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash")
copy(hash[:], bytes)
return hash
}
// RandNodeAlias generates a random node alias.
func RandNodeAlias(t *rapid.T) NodeAlias {
var alias NodeAlias
aliasLength := rapid.IntRange(0, 32).Draw(t, "aliasLength")
aliasBytes := rapid.StringN(
0, aliasLength, aliasLength,
).Draw(t, "alias")
copy(alias[:], aliasBytes)
return alias
}
// RandNetAddrs generates random network addresses.
func RandNetAddrs(t *rapid.T) []net.Addr {
numAddresses := rapid.IntRange(0, 5).Draw(t, "numAddresses")
if numAddresses == 0 {
return nil
}
addresses := make([]net.Addr, numAddresses)
for i := 0; i < numAddresses; i++ {
addressType := rapid.IntRange(0, 1).Draw(
t, fmt.Sprintf("addressType-%d", i),
)
switch addressType {
// IPv4.
case 0:
ipBytes := rapid.SliceOfN(rapid.Byte(), 4, 4).Draw(
t, fmt.Sprintf("ipv4-%d", i),
)
port := rapid.IntRange(1, 65535).Draw(
t, fmt.Sprintf("port-%d", i),
)
addresses[i] = &net.TCPAddr{
IP: ipBytes,
Port: port,
}
// IPv6.
case 1:
ipBytes := rapid.SliceOfN(rapid.Byte(), 16, 16).Draw(
t, fmt.Sprintf("ipv6-%d", i),
)
port := rapid.IntRange(1, 65535).Draw(
t, fmt.Sprintf("port-%d", i),
)
addresses[i] = &net.TCPAddr{
IP: ipBytes,
Port: port,
}
}
}
return addresses
}
// RandCustomRecords generates random custom TLV records.
func RandCustomRecords(t *rapid.T,
ignoreRecords fn.Set[uint64],
custom bool) (CustomRecords, fn.Set[uint64]) {
numRecords := rapid.IntRange(0, 5).Draw(t, "numCustomRecords")
customRecords := make(CustomRecords)
if numRecords == 0 {
return nil, nil
}
rangeStart := 0
rangeStop := int(CustomTypeStart)
if custom {
rangeStart = 70_000
rangeStop = 100_000
}
ignoreSet := fn.NewSet[uint64]()
for i := 0; i < numRecords; i++ {
recordType := uint64(
rapid.IntRange(rangeStart, rangeStop).
Filter(func(i int) bool {
return !ignoreRecords.Contains(
uint64(i),
)
}).
Draw(
t, fmt.Sprintf("recordType-%d", i),
),
)
recordLen := rapid.IntRange(4, 64).Draw(
t, fmt.Sprintf("recordLen-%d", i),
)
record := rapid.SliceOfN(
rapid.Byte(), recordLen, recordLen,
).Draw(t, fmt.Sprintf("record-%d", i))
customRecords[recordType] = record
ignoreSet.Add(recordType)
}
return customRecords, ignoreSet
}
// RandMusig2Nonce generates a random musig2 nonce.
func RandMusig2Nonce(t *rapid.T) Musig2Nonce {
var nonce Musig2Nonce
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "nonce")
copy(nonce[:], bytes)
return nonce
}
// RandExtraOpaqueData generates random extra opaque data.
func RandExtraOpaqueData(t *rapid.T,
ignoreRecords fn.Set[uint64]) ExtraOpaqueData {
// Make some random records.
cRecords, _ := RandCustomRecords(t, ignoreRecords, false)
if cRecords == nil {
return ExtraOpaqueData{}
}
// Encode those records as opaque data.
recordBytes, err := cRecords.Serialize()
require.NoError(t, err)
return ExtraOpaqueData(recordBytes)
}
// RandOpaqueReason generates a random opaque reason for HTLC failures.
func RandOpaqueReason(t *rapid.T) OpaqueReason {
reasonLen := rapid.IntRange(32, 300).Draw(t, "reasonLen")
return rapid.SliceOfN(rapid.Byte(), reasonLen, reasonLen).Draw(
t, "opaqueReason",
)
}
// RandFailCode generates a random HTLC failure code.
func RandFailCode(t *rapid.T) FailCode {
// List of known failure codes to choose from Using only the documented
// codes.
validCodes := []FailCode{
CodeInvalidRealm,
CodeTemporaryNodeFailure,
CodePermanentNodeFailure,
CodeRequiredNodeFeatureMissing,
CodePermanentChannelFailure,
CodeRequiredChannelFeatureMissing,
CodeUnknownNextPeer,
CodeIncorrectOrUnknownPaymentDetails,
CodeIncorrectPaymentAmount,
CodeFinalExpiryTooSoon,
CodeInvalidOnionVersion,
CodeInvalidOnionHmac,
CodeInvalidOnionKey,
CodeTemporaryChannelFailure,
CodeChannelDisabled,
CodeExpiryTooSoon,
CodeMPPTimeout,
CodeInvalidOnionPayload,
CodeFeeInsufficient,
}
// Choose a random code from the list.
idx := rapid.IntRange(0, len(validCodes)-1).Draw(t, "failCodeIndex")
return validCodes[idx]
}
// RandSHA256Hash generates a random SHA256 hash.
func RandSHA256Hash(t *rapid.T) [sha256.Size]byte {
var hash [sha256.Size]byte
bytes := rapid.SliceOfN(rapid.Byte(), sha256.Size, sha256.Size).Draw(
t, "sha256Hash",
)
copy(hash[:], bytes)
return hash
}
// RandDeliveryAddress generates a random delivery address (script).
func RandDeliveryAddress(t *rapid.T) DeliveryAddress {
addrLen := rapid.IntRange(1, 34).Draw(t, "addrLen")
return rapid.SliceOfN(rapid.Byte(), addrLen, addrLen).Draw(
t, "deliveryAddress",
)
}
// RandChannelType generates a random channel type.
func RandChannelType(t *rapid.T) *ChannelType {
vec := RandFeatureVector(t)
chanType := ChannelType(*vec)
return &chanType
}
// RandLeaseExpiry generates a random lease expiry.
func RandLeaseExpiry(t *rapid.T) *LeaseExpiry {
exp := LeaseExpiry(
uint32(rapid.IntRange(1000, 1000000).Draw(t, "leaseExpiry")),
)
return &exp
}
// RandOutPoint generates a random transaction outpoint.
func RandOutPoint(t *rapid.T) wire.OutPoint {
// Generate a random transaction ID
var txid chainhash.Hash
txidBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "txid")
copy(txid[:], txidBytes)
// Generate a random output index
vout := uint32(rapid.IntRange(0, 10).Draw(t, "vout"))
return wire.OutPoint{
Hash: txid,
Index: vout,
}
}

View File

@ -110,10 +110,6 @@ func NewUpdateAddHTLC() *UpdateAddHTLC {
// interface.
var _ Message = (*UpdateAddHTLC)(nil)
// A compile time check to ensure UpdateAddHTLC implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*UpdateAddHTLC)(nil)
// Decode deserializes a serialized UpdateAddHTLC message stored in the passed
// io.Reader observing the specified protocol version.
//
@ -223,3 +219,7 @@ func (c *UpdateAddHTLC) TargetChanID() ChannelID {
func (c *UpdateAddHTLC) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
// A compile time check to ensure UpdateAddHTLC implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*UpdateAddHTLC)(nil)

View File

@ -38,8 +38,8 @@ type UpdateFailHTLC struct {
// interface.
var _ Message = (*UpdateFailHTLC)(nil)
// A compile time check to ensure UpdateFailHTLC implements the lnwire.SizeableMessage
// interface.
// A compile time check to ensure UpdateFailHTLC implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*UpdateFailHTLC)(nil)
// Decode deserializes a serialized UpdateFailHTLC message stored in the passed
@ -55,8 +55,8 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error {
)
}
// Encode serializes the target UpdateFailHTLC into the passed io.Writer observing
// the protocol version specified.
// Encode serializes the target UpdateFailHTLC into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailHTLC) Encode(w *bytes.Buffer, pver uint32) error {