diff --git a/.gitignore b/.gitignore index 5ddd602d3..1e498b0cf 100644 --- a/.gitignore +++ b/.gitignore @@ -80,3 +80,5 @@ coverage.txt # Release build directory (to avoid build.vcs.modified Golang build tag to be # set to true by having untracked files in the working directory). /lnd-*/ + +.aider* diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index e45bcf9ab..afb2f1412 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -128,6 +128,10 @@ type AcceptChannel struct { // interface. var _ Message = (*AcceptChannel)(nil) +// A compile time check to ensure AcceptChannel implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*AcceptChannel)(nil) + // Encode serializes the target AcceptChannel into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -281,3 +285,10 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { func (a *AcceptChannel) MsgType() MessageType { return MsgAcceptChannel } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *AcceptChannel) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/announcement_signatures.go b/lnwire/announcement_signatures.go index a2bc21f39..cf8f68be5 100644 --- a/lnwire/announcement_signatures.go +++ b/lnwire/announcement_signatures.go @@ -47,6 +47,10 @@ type AnnounceSignatures1 struct { // lnwire.Message interface. var _ Message = (*AnnounceSignatures1)(nil) +// A compile time check to ensure AnnounceSignatures1 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*AnnounceSignatures1)(nil) + // A compile time check to ensure AnnounceSignatures1 implements the // lnwire.AnnounceSignatures interface. var _ AnnounceSignatures = (*AnnounceSignatures1)(nil) @@ -97,6 +101,13 @@ func (a *AnnounceSignatures1) MsgType() MessageType { return MsgAnnounceSignatures } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *AnnounceSignatures1) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} + // SCID returns the ShortChannelID of the channel. // // This is part of the lnwire.AnnounceSignatures interface. diff --git a/lnwire/announcement_signatures_2.go b/lnwire/announcement_signatures_2.go index a10447032..6e893dafd 100644 --- a/lnwire/announcement_signatures_2.go +++ b/lnwire/announcement_signatures_2.go @@ -41,6 +41,10 @@ type AnnounceSignatures2 struct { // lnwire.Message interface. var _ Message = (*AnnounceSignatures2)(nil) +// A compile time check to ensure AnnounceSignatures2 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*AnnounceSignatures2)(nil) + // Decode deserializes a serialized AnnounceSignatures2 stored in the passed // io.Reader observing the specified protocol version. // @@ -82,6 +86,13 @@ func (a *AnnounceSignatures2) MsgType() MessageType { return MsgAnnounceSignatures2 } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *AnnounceSignatures2) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} + // SCID returns the ShortChannelID of the channel. // // NOTE: this is part of the AnnounceSignatures interface. diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index ed4c5b97e..0a3989abb 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -62,6 +62,10 @@ type ChannelAnnouncement1 struct { // lnwire.Message interface. var _ Message = (*ChannelAnnouncement1)(nil) +// A compile time check to ensure ChannelAnnouncement1 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelAnnouncement1)(nil) + // Decode deserializes a serialized ChannelAnnouncement stored in the passed // io.Reader observing the specified protocol version. // @@ -143,6 +147,13 @@ func (a *ChannelAnnouncement1) MsgType() MessageType { return MsgChannelAnnouncement } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *ChannelAnnouncement1) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} + // DataToSign is used to retrieve part of the announcement message which should // be signed. func (a *ChannelAnnouncement1) DataToSign() ([]byte, error) { diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go index 074e7d084..57b3a24b8 100644 --- a/lnwire/channel_announcement_2.go +++ b/lnwire/channel_announcement_2.go @@ -194,10 +194,21 @@ func (c *ChannelAnnouncement2) MsgType() MessageType { return MsgChannelAnnouncement2 } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ChannelAnnouncement2) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // A compile time check to ensure ChannelAnnouncement2 implements the // lnwire.Message interface. var _ Message = (*ChannelAnnouncement2)(nil) +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelAnnouncement2)(nil) + // Node1KeyBytes returns the bytes representing the public key of node 1 in the // channel. // diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index 912a068bd..f388db1d1 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -63,6 +63,10 @@ func NewChannelReady(cid ChannelID, npcp *btcec.PublicKey) *ChannelReady { // interface. var _ Message = (*ChannelReady)(nil) +// A compile time check to ensure ChannelReady implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelReady)(nil) + // Decode deserializes the serialized ChannelReady message stored in the // passed io.Reader into the target ChannelReady using the deserialization // rules defined by the passed protocol version. @@ -170,3 +174,10 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { func (c *ChannelReady) MsgType() MessageType { return MsgChannelReady } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ChannelReady) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index 577379623..f26a2fc5d 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -99,6 +99,10 @@ type ChannelReestablish struct { // lnwire.Message interface. var _ Message = (*ChannelReestablish)(nil) +// A compile time check to ensure ChannelReestablish implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelReestablish)(nil) + // Encode serializes the target ChannelReestablish into the passed io.Writer // observing the protocol version specified. // @@ -234,3 +238,10 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { func (a *ChannelReestablish) MsgType() MessageType { return MsgChannelReestablish } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *ChannelReestablish) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index 33cc3bff0..88f981671 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -124,6 +124,10 @@ type ChannelUpdate1 struct { // interface. var _ Message = (*ChannelUpdate1)(nil) +// A compile time check to ensure ChannelUpdate1 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelUpdate1)(nil) + // Decode deserializes a serialized ChannelUpdate stored in the passed // io.Reader observing the specified protocol version. // @@ -367,3 +371,10 @@ func (a *ChannelUpdate1) SetSCID(scid ShortChannelID) { // A compile time assertion to ensure ChannelUpdate1 implements the // ChannelUpdate interface. var _ ChannelUpdate = (*ChannelUpdate1)(nil) + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *ChannelUpdate1) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go index b41f2f29f..56f7edf6b 100644 --- a/lnwire/channel_update_2.go +++ b/lnwire/channel_update_2.go @@ -241,6 +241,13 @@ func (c *ChannelUpdate2) MsgType() MessageType { return MsgChannelUpdate2 } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ChannelUpdate2) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + func (c *ChannelUpdate2) ExtraData() ExtraOpaqueData { return c.ExtraOpaqueData } diff --git a/lnwire/closing_complete.go b/lnwire/closing_complete.go index 4d390fd6f..7980ef1ee 100644 --- a/lnwire/closing_complete.go +++ b/lnwire/closing_complete.go @@ -169,6 +169,17 @@ func (c *ClosingComplete) MsgType() MessageType { return MsgClosingComplete } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ClosingComplete) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // A compile time check to ensure ClosingComplete implements the lnwire.Message // interface. var _ Message = (*ClosingComplete)(nil) + +// A compile time check to ensure ClosingComplete implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ClosingComplete)(nil) diff --git a/lnwire/closing_sig.go b/lnwire/closing_sig.go index 2c73fa720..94a356066 100644 --- a/lnwire/closing_sig.go +++ b/lnwire/closing_sig.go @@ -107,6 +107,17 @@ func (c *ClosingSig) MsgType() MessageType { return MsgClosingSig } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ClosingSig) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // A compile time check to ensure ClosingSig implements the lnwire.Message // interface. var _ Message = (*ClosingSig)(nil) + +// A compile time check to ensure ClosingSig implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ClosingSig)(nil) diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 08b5bb6a7..c247cfe0a 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -59,6 +59,10 @@ func NewClosingSigned(cid ChannelID, fs btcutil.Amount, // interface. var _ Message = (*ClosingSigned)(nil) +// A compile time check to ensure ClosingSigned implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ClosingSigned)(nil) + // Decode deserializes a serialized ClosingSigned message stored in the passed // io.Reader observing the specified protocol version. // @@ -130,3 +134,10 @@ func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error { func (c *ClosingSigned) MsgType() MessageType { return MsgClosingSigned } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ClosingSigned) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 3a475e71f..600ff81e7 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -64,6 +64,10 @@ func NewCommitSig() *CommitSig { // interface. var _ Message = (*CommitSig)(nil) +// A compile time check to ensure CommitSig implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*CommitSig)(nil) + // Decode deserializes a serialized CommitSig message stored in the // passed io.Reader observing the specified protocol version. // @@ -151,3 +155,10 @@ func (c *CommitSig) MsgType() MessageType { func (c *CommitSig) TargetChanID() ChannelID { return c.ChanID } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *CommitSig) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/custom.go b/lnwire/custom.go index 232a8be52..740c23a14 100644 --- a/lnwire/custom.go +++ b/lnwire/custom.go @@ -69,10 +69,14 @@ type Custom struct { Data []byte } -// A compile time check to ensure FundingCreated implements the lnwire.Message +// A compile time check to ensure Custom implements the lnwire.Message // interface. var _ Message = (*Custom)(nil) +// A compile time check to ensure Custom implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Custom)(nil) + // NewCustom instantiates a new custom message. func NewCustom(msgType MessageType, data []byte) (*Custom, error) { if msgType < CustomTypeStart && !IsCustomOverride(msgType) { @@ -117,3 +121,10 @@ func (c *Custom) Decode(r io.Reader, pver uint32) error { func (c *Custom) MsgType() MessageType { return c.Type } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *Custom) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/dyn_ack.go b/lnwire/dyn_ack.go index d477461e7..1cc57e955 100644 --- a/lnwire/dyn_ack.go +++ b/lnwire/dyn_ack.go @@ -41,6 +41,10 @@ type DynAck struct { // interface. var _ Message = (*DynAck)(nil) +// A compile time check to ensure DynAck implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*DynAck)(nil) + // Encode serializes the target DynAck into the passed io.Writer. Serialization // will observe the rules defined by the passed protocol version. // @@ -136,3 +140,10 @@ func (da *DynAck) Decode(r io.Reader, _ uint32) error { func (da *DynAck) MsgType() MessageType { return MsgDynAck } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (da *DynAck) SerializedSize() (uint32, error) { + return MessageSerializedSize(da) +} diff --git a/lnwire/dyn_propose.go b/lnwire/dyn_propose.go index 394fff6f3..cc19ec394 100644 --- a/lnwire/dyn_propose.go +++ b/lnwire/dyn_propose.go @@ -105,6 +105,10 @@ type DynPropose struct { // interface. var _ Message = (*DynPropose)(nil) +// A compile time check to ensure DynPropose implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*DynPropose)(nil) + // Encode serializes the target DynPropose into the passed io.Writer. // Serialization will observe the rules defined by the passed protocol version. // @@ -317,3 +321,10 @@ func (dp *DynPropose) Decode(r io.Reader, _ uint32) error { func (dp *DynPropose) MsgType() MessageType { return MsgDynPropose } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (dp *DynPropose) SerializedSize() (uint32, error) { + return MessageSerializedSize(dp) +} diff --git a/lnwire/dyn_reject.go b/lnwire/dyn_reject.go index 2c6484424..51cb8cee2 100644 --- a/lnwire/dyn_reject.go +++ b/lnwire/dyn_reject.go @@ -30,6 +30,10 @@ type DynReject struct { // interface. var _ Message = (*DynReject)(nil) +// A compile time check to ensure DynReject implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*DynReject)(nil) + // Encode serializes the target DynReject into the passed io.Writer. // Serialization will observe the rules defined by the passed protocol version. // @@ -74,3 +78,10 @@ func (dr *DynReject) Decode(r io.Reader, _ uint32) error { func (dr *DynReject) MsgType() MessageType { return MsgDynReject } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (dr *DynReject) SerializedSize() (uint32, error) { + return MessageSerializedSize(dr) +} diff --git a/lnwire/error.go b/lnwire/error.go index 120dfc950..fcc22ef24 100644 --- a/lnwire/error.go +++ b/lnwire/error.go @@ -70,6 +70,10 @@ func NewError() *Error { // interface. var _ Message = (*Error)(nil) +// A compile time check to ensure Error implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Error)(nil) + // Error returns the string representation to Error. // // NOTE: Satisfies the error interface. @@ -113,6 +117,13 @@ func (c *Error) MsgType() MessageType { return MsgError } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *Error) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // isASCII is a helper method that checks whether all bytes in `data` would be // printable ASCII characters if interpreted as a string. func isASCII(data []byte) bool { diff --git a/lnwire/funding_created.go b/lnwire/funding_created.go index 86aa0bb40..82d0ff87c 100644 --- a/lnwire/funding_created.go +++ b/lnwire/funding_created.go @@ -44,6 +44,10 @@ type FundingCreated struct { // interface. var _ Message = (*FundingCreated)(nil) +// A compile time check to ensure FundingCreated implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*FundingCreated)(nil) + // Encode serializes the target FundingCreated into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -117,3 +121,10 @@ func (f *FundingCreated) Decode(r io.Reader, pver uint32) error { func (f *FundingCreated) MsgType() MessageType { return MsgFundingCreated } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (f *FundingCreated) SerializedSize() (uint32, error) { + return MessageSerializedSize(f) +} diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index 2dd62e177..182e4bbde 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -36,6 +36,10 @@ type FundingSigned struct { // interface. var _ Message = (*FundingSigned)(nil) +// A compile time check to ensure FundingSigned implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*FundingSigned)(nil) + // Encode serializes the target FundingSigned into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -103,3 +107,10 @@ func (f *FundingSigned) Decode(r io.Reader, pver uint32) error { func (f *FundingSigned) MsgType() MessageType { return MsgFundingSigned } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (f *FundingSigned) SerializedSize() (uint32, error) { + return MessageSerializedSize(f) +} diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go index 7b628752a..45ff1f939 100644 --- a/lnwire/gossip_timestamp_range.go +++ b/lnwire/gossip_timestamp_range.go @@ -58,6 +58,10 @@ func NewGossipTimestampRange() *GossipTimestampRange { // lnwire.Message interface. var _ Message = (*GossipTimestampRange)(nil) +// A compile time check to ensure GossipTimestampRange implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*GossipTimestampRange)(nil) + // Decode deserializes a serialized GossipTimestampRange message stored in the // passed io.Reader observing the specified protocol version. // @@ -143,3 +147,10 @@ func (g *GossipTimestampRange) Encode(w *bytes.Buffer, pver uint32) error { func (g *GossipTimestampRange) MsgType() MessageType { return MsgGossipTimestampRange } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (g *GossipTimestampRange) SerializedSize() (uint32, error) { + return MessageSerializedSize(g) +} diff --git a/lnwire/init_message.go b/lnwire/init_message.go index dbbddea2b..b88891b08 100644 --- a/lnwire/init_message.go +++ b/lnwire/init_message.go @@ -43,6 +43,10 @@ func NewInitMessage(gf *RawFeatureVector, f *RawFeatureVector) *Init { // interface. var _ Message = (*Init)(nil) +// A compile time check to ensure Init implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Init)(nil) + // Decode deserializes a serialized Init message stored in the passed // io.Reader observing the specified protocol version. // @@ -78,3 +82,10 @@ func (msg *Init) Encode(w *bytes.Buffer, pver uint32) error { func (msg *Init) MsgType() MessageType { return MsgInit } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (msg *Init) SerializedSize() (uint32, error) { + return MessageSerializedSize(msg) +} diff --git a/lnwire/kickoff_sig.go b/lnwire/kickoff_sig.go index 3e46db453..b9c4c206a 100644 --- a/lnwire/kickoff_sig.go +++ b/lnwire/kickoff_sig.go @@ -27,6 +27,10 @@ type KickoffSig struct { // interface. var _ Message = (*KickoffSig)(nil) +// A compile time check to ensure KickoffSig implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*KickoffSig)(nil) + // Encode serializes the target KickoffSig into the passed bytes.Buffer // observing the specified protocol version. // @@ -54,3 +58,10 @@ func (ks *KickoffSig) Decode(r io.Reader, _ uint32) error { // // This is part of the lnwire.Message interface. func (ks *KickoffSig) MsgType() MessageType { return MsgKickoffSig } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (ks *KickoffSig) SerializedSize() (uint32, error) { + return MessageSerializedSize(ks) +} diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 123689902..7f068e526 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -3,31 +3,20 @@ package lnwire import ( "bytes" crand "crypto/rand" - "encoding/binary" "encoding/hex" - "fmt" - "image/color" - "io" "math" "math/rand" "net" - "reflect" "testing" - "testing/quick" "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" - "github.com/btcsuite/btcd/btcutil" - "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn/v2" - "github.com/lightningnetwork/lnd/lnwallet/chainfee" - "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tor" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "pgregory.net/rapid" ) var ( @@ -54,106 +43,6 @@ var ( const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -func randLocalNonce(r *rand.Rand) Musig2Nonce { - var nonce Musig2Nonce - _, _ = io.ReadFull(r, nonce[:]) - - return nonce -} - -func someLocalNonce[T tlv.TlvType]( - r *rand.Rand) tlv.OptionalRecordT[T, Musig2Nonce] { - - return tlv.SomeRecordT(tlv.NewRecordT[T, Musig2Nonce]( - randLocalNonce(r), - )) -} - -func randPartialSig(r *rand.Rand) (*PartialSig, error) { - var sigBytes [32]byte - if _, err := r.Read(sigBytes[:]); err != nil { - return nil, fmt.Errorf("unable to generate sig: %w", err) - } - - var s btcec.ModNScalar - s.SetByteSlice(sigBytes[:]) - - return &PartialSig{ - Sig: s, - }, nil -} - -func somePartialSig(t *testing.T, - r *rand.Rand) tlv.OptionalRecordT[PartialSigType, PartialSig] { - - sig, err := randPartialSig(r) - if err != nil { - t.Fatal(err) - } - - return tlv.SomeRecordT(tlv.NewRecordT[PartialSigType, PartialSig]( - *sig, - )) -} - -func randPartialSigWithNonce(r *rand.Rand) (*PartialSigWithNonce, error) { - var sigBytes [32]byte - if _, err := r.Read(sigBytes[:]); err != nil { - return nil, fmt.Errorf("unable to generate sig: %w", err) - } - - var s btcec.ModNScalar - s.SetByteSlice(sigBytes[:]) - - return &PartialSigWithNonce{ - PartialSig: NewPartialSig(s), - Nonce: randLocalNonce(r), - }, nil -} - -func somePartialSigWithNonce(t *testing.T, - r *rand.Rand) OptPartialSigWithNonceTLV { - - sig, err := randPartialSigWithNonce(r) - if err != nil { - t.Fatal(err) - } - - return tlv.SomeRecordT( - tlv.NewRecordT[PartialSigWithNonceType, PartialSigWithNonce]( - *sig, - ), - ) -} - -func randAlias(r *rand.Rand) NodeAlias { - var a NodeAlias - for i := range a { - a[i] = letterBytes[r.Intn(len(letterBytes))] - } - - return a -} - -func randPubKey() (*btcec.PublicKey, error) { - priv, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } - - return priv.PubKey(), nil -} - -// pubkeyFromHex parses a Bitcoin public key from a hex encoded string. -func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) { - pubKeyBytes, err := hex.DecodeString(keyHex) - if err != nil { - return nil, err - } - - return btcec.ParsePubKey(pubKeyBytes) -} - // generateRandomBytes returns a slice of n random bytes. func generateRandomBytes(n int) ([]byte, error) { b := make([]byte, n) @@ -176,140 +65,23 @@ func randRawKey(t *testing.T) [33]byte { return n } -func randDeliveryAddress(r *rand.Rand) (DeliveryAddress, error) { - // Generate size minimum one. Empty scripts should be tested specifically. - size := r.Intn(deliveryAddressMaxSize) + 1 - da := DeliveryAddress(make([]byte, size)) - - _, err := r.Read(da) - return da, err -} - -func randRawFeatureVector(r *rand.Rand) *RawFeatureVector { - featureVec := NewRawFeatureVector() - for i := 0; i < 10000; i++ { - if r.Int31n(2) == 0 { - featureVec.Set(FeatureBit(i)) - } - } - return featureVec -} - -func randTCP4Addr(r *rand.Rand) (*net.TCPAddr, error) { - var ip [4]byte - if _, err := r.Read(ip[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - addrIP := net.IP(ip[:]) - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil -} - -func randTCP6Addr(r *rand.Rand) (*net.TCPAddr, error) { - var ip [16]byte - if _, err := r.Read(ip[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - addrIP := net.IP(ip[:]) - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil -} - -func randV2OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) { - var serviceID [tor.V2DecodedLen]byte - if _, err := r.Read(serviceID[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) - onionService += tor.OnionSuffix - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil -} - -func randV3OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) { - var serviceID [tor.V3DecodedLen]byte - if _, err := r.Read(serviceID[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) - onionService += tor.OnionSuffix - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil -} - -func randOpaqueAddr(r *rand.Rand) (*OpaqueAddrs, error) { - payloadLen := r.Int63n(64) + 1 - payload := make([]byte, payloadLen) - - // The first byte is the address type. So set it to one that we - // definitely don't know about. - payload[0] = math.MaxUint8 - - // Generate random bytes for the rest of the payload. - if _, err := r.Read(payload[1:]); err != nil { - return nil, err - } - - return &OpaqueAddrs{Payload: payload}, nil -} - -func randAddrs(r *rand.Rand) ([]net.Addr, error) { - tcp4Addr, err := randTCP4Addr(r) +func randPubKey() (*btcec.PublicKey, error) { + priv, err := btcec.NewPrivateKey() if err != nil { return nil, err } - tcp6Addr, err := randTCP6Addr(r) + return priv.PubKey(), nil +} + +// pubkeyFromHex parses a Bitcoin public key from a hex encoded string. +func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) { + pubKeyBytes, err := hex.DecodeString(keyHex) if err != nil { return nil, err } - v2OnionAddr, err := randV2OnionAddr(r) - if err != nil { - return nil, err - } - - v3OnionAddr, err := randV3OnionAddr(r) - if err != nil { - return nil, err - } - - opaqueAddrs, err := randOpaqueAddr(r) - if err != nil { - return nil, err - } - - return []net.Addr{ - tcp4Addr, tcp6Addr, v2OnionAddr, v3OnionAddr, opaqueAddrs, - }, nil + return btcec.ParsePubKey(pubKeyBytes) } // TestChanUpdateChanFlags ensures that converting the ChanUpdateChanFlags and @@ -418,1558 +190,64 @@ func TestEmptyMessageUnknownType(t *testing.T) { } } -// randCustomRecords generates a random set of custom records for testing. -func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords { - var ( - customRecords = CustomRecords{} - - // We'll generate a random number of records, between 1 and 10. - numRecords = r.Intn(9) + 1 - ) - - // For each record, we'll generate a random key and value. - for i := 0; i < numRecords; i++ { - // Keys must be equal to or greater than - // MinCustomRecordsTlvType. - keyOffset := uint64(r.Intn(100)) - key := MinCustomRecordsTlvType + keyOffset - - // Values are byte slices of any length. - value := make([]byte, r.Intn(10)) - _, err := r.Read(value) - require.NoError(t, err) - - customRecords[key] = value - } - - // Validate the custom records as a sanity check. - err := customRecords.Validate() - require.NoError(t, err) - - return customRecords -} - -// TestLightningWireProtocol uses the testing/quick package to create a series -// of fuzz tests to attempt to break a primary scenario which is implemented as -// property based testing scenario. +// TestLightningWireProtocol uses the rapid property-based testing framework to +// verify that all message types can be serialized and deserialized correctly. func TestLightningWireProtocol(t *testing.T) { t.Parallel() - // mainScenario is the primary test that will programmatically be - // executed for all registered wire messages. The quick-checker within - // testing/quick will attempt to find an input to this function, s.t - // the function returns false, if so then we've found an input that - // violates our model of the system. - mainScenario := func(msg Message) bool { - // Give a new message, we'll serialize the message into a new - // bytes buffer. - var b bytes.Buffer - if _, err := WriteMessage(&b, msg, 0); err != nil { - t.Fatalf("unable to write msg: %v", err) - return false + for msgType := MessageType(0); msgType < MsgEnd; msgType++ { + // If MakeEmptyMessage returns an error, then this isn't yet a + // used message type. + if _, err := MakeEmptyMessage(msgType); err != nil { + continue } - // Next, we'll ensure that the serialized payload (subtracting - // the 2 bytes for the message type) is _below_ the specified - // max payload size for this message. - payloadLen := uint32(b.Len()) - 2 - if payloadLen > MaxMsgBody { - t.Fatalf("msg payload constraint violated: %v > %v", - payloadLen, MaxMsgBody) - return false - } + t.Run(msgType.String(), rapid.MakeCheck(func(t *rapid.T) { + // Create an empty message of the given type. + m, err := MakeEmptyMessage(msgType) - // Finally, we'll deserialize the message from the written - // buffer, and finally assert that the messages are equal. - newMsg, err := ReadMessage(&b, 0) - if err != nil { - t.Fatalf("unable to read msg: %v", err) - return false - } - if !assert.Equalf(t, msg, newMsg, "message mismatch") { - return false - } - - return true - } - - // customTypeGen is a map of functions that are able to randomly - // generate a given type. These functions are needed for types which - // are too complex for the testing/quick package to automatically - // generate. - customTypeGen := map[MessageType]func([]reflect.Value, *rand.Rand){ - MsgStfu: func(v []reflect.Value, r *rand.Rand) { - req := Stfu{} - if _, err := r.Read(req.ChanID[:]); err != nil { - t.Fatalf("unable to generate ChanID: %v", err) + // An error means this isn't a valid message type, so we + // skip it. + if err != nil { + return } - // 1/2 chance of being initiator - req.Initiator = r.Intn(2) == 1 - - // 1/2 chance additional TLV data. - if r.Intn(2) == 0 { - req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} - } - - v[0] = reflect.ValueOf(req) - }, - MsgInit: func(v []reflect.Value, r *rand.Rand) { - req := NewInitMessage( - randRawFeatureVector(r), - randRawFeatureVector(r), + // The message must support the message type interface. + testMsg, ok := m.(TestMessage) + require.True( + t, ok, "message %v doesn't support TestMessage", + msgType, ) - v[0] = reflect.ValueOf(*req) - }, - MsgOpenChannel: func(v []reflect.Value, r *rand.Rand) { - req := OpenChannel{ - FundingAmount: btcutil.Amount(r.Int63()), - PushAmount: MilliSatoshi(r.Int63()), - DustLimit: btcutil.Amount(r.Int63()), - MaxValueInFlight: MilliSatoshi(r.Int63()), - ChannelReserve: btcutil.Amount(r.Int63()), - HtlcMinimum: MilliSatoshi(r.Int31()), - FeePerKiloWeight: uint32(r.Int63()), - CsvDelay: uint16(r.Int31()), - MaxAcceptedHTLCs: uint16(r.Int31()), - ChannelFlags: FundingFlag(uint8(r.Int31())), - } - - if _, err := r.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to generate chain hash: %v", err) - return - } - - if _, err := r.Read(req.PendingChannelID[:]); err != nil { - t.Fatalf("unable to generate pending chan id: %v", err) - return - } - - var err error - req.FundingKey, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.RevocationPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.PaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.DelayedPaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.HtlcPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.FirstCommitmentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - // 1/2 chance empty TLV records. - if r.Intn(2) == 0 { - req.UpfrontShutdownScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery address: %v", err) - return - } - - req.ChannelType = new(ChannelType) - *req.ChannelType = ChannelType(*randRawFeatureVector(r)) - - req.LeaseExpiry = new(LeaseExpiry) - *req.LeaseExpiry = LeaseExpiry(1337) - - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } else { - req.UpfrontShutdownScript = []byte{} - } - - // 1/2 chance additional TLV data. - if r.Intn(2) == 0 { - req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} - } - - v[0] = reflect.ValueOf(req) - }, - MsgAcceptChannel: func(v []reflect.Value, r *rand.Rand) { - req := AcceptChannel{ - DustLimit: btcutil.Amount(r.Int63()), - MaxValueInFlight: MilliSatoshi(r.Int63()), - ChannelReserve: btcutil.Amount(r.Int63()), - MinAcceptDepth: uint32(r.Int31()), - HtlcMinimum: MilliSatoshi(r.Int31()), - CsvDelay: uint16(r.Int31()), - MaxAcceptedHTLCs: uint16(r.Int31()), - } - - if _, err := r.Read(req.PendingChannelID[:]); err != nil { - t.Fatalf("unable to generate pending chan id: %v", err) - return - } - - var err error - req.FundingKey, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.RevocationPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.PaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.DelayedPaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.HtlcPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.FirstCommitmentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - // 1/2 chance empty TLV records. - if r.Intn(2) == 0 { - req.UpfrontShutdownScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery address: %v", err) - return - } - - req.ChannelType = new(ChannelType) - *req.ChannelType = ChannelType(*randRawFeatureVector(r)) - - req.LeaseExpiry = new(LeaseExpiry) - *req.LeaseExpiry = LeaseExpiry(1337) - - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } else { - req.UpfrontShutdownScript = []byte{} - } - - // 1/2 chance additional TLV data. - if r.Intn(2) == 0 { - req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} - } - - v[0] = reflect.ValueOf(req) - }, - MsgFundingCreated: func(v []reflect.Value, r *rand.Rand) { - req := FundingCreated{ - ExtraData: make([]byte, 0), - } - - if _, err := r.Read(req.PendingChannelID[:]); err != nil { - t.Fatalf("unable to generate pending chan id: %v", err) - return - } - - if _, err := r.Read(req.FundingPoint.Hash[:]); err != nil { - t.Fatalf("unable to generate hash: %v", err) - return - } - req.FundingPoint.Index = uint32(r.Int31()) % math.MaxUint16 - - var err error - req.CommitSig, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - // 1/2 chance to attach a partial sig. - if r.Intn(2) == 0 { - req.PartialSig = somePartialSigWithNonce(t, r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgFundingSigned: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - req := FundingSigned{ - ChanID: ChannelID(c), - ExtraData: make([]byte, 0), - } - req.CommitSig, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - // 1/2 chance to attach a partial sig. - if r.Intn(2) == 0 { - req.PartialSig = somePartialSigWithNonce(t, r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelReady: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - require.NoError(t, err) - - pubKey, err := randPubKey() - require.NoError(t, err) - - req := NewChannelReady(c, pubKey) - - if r.Int31()%2 == 0 { - scid := NewShortChanIDFromInt(uint64(r.Int63())) - req.AliasScid = &scid - - //nolint:ll - req.NextLocalNonce = someLocalNonce[NonceRecordTypeT](r) - } - - if r.Int31()%2 == 0 { - nodeNonce := tlv.ZeroRecordT[ - tlv.TlvType0, Musig2Nonce, - ]() - nodeNonce.Val = randLocalNonce(r) - req.AnnouncementNodeNonce = tlv.SomeRecordT( - nodeNonce, - ) - - btcNonce := tlv.ZeroRecordT[ - tlv.TlvType2, Musig2Nonce, - ]() - btcNonce.Val = randLocalNonce(r) - req.AnnouncementBitcoinNonce = tlv.SomeRecordT( - btcNonce, - ) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgShutdown: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - shutdownAddr, err := randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - - req := Shutdown{ - ChannelID: ChannelID(c), - Address: shutdownAddr, - } - - if r.Int31()%2 == 0 { - //nolint:ll - req.ShutdownNonce = someLocalNonce[ShutdownNonceType](r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgClosingSigned: func(v []reflect.Value, r *rand.Rand) { - req := ClosingSigned{ - FeeSatoshis: btcutil.Amount(r.Int63()), - ExtraData: make([]byte, 0), - } - var err error - req.Signature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChannelID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - if r.Int31()%2 == 0 { - req.PartialSig = somePartialSig(t, r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgDynPropose: func(v []reflect.Value, r *rand.Rand) { - var dp DynPropose - rand.Read(dp.ChanID[:]) - - if rand.Uint32()%2 == 0 { - v := btcutil.Amount(rand.Uint32()) - dp.DustLimit = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := MilliSatoshi(rand.Uint32()) - dp.MaxValueInFlight = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := btcutil.Amount(rand.Uint32()) - dp.ChannelReserve = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := uint16(rand.Uint32()) - dp.CsvDelay = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := uint16(rand.Uint32()) - dp.MaxAcceptedHTLCs = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v, _ := btcec.NewPrivateKey() - dp.FundingKey = fn.Some(*v.PubKey()) - } - - if rand.Uint32()%2 == 0 { - v := ChannelType(*NewRawFeatureVector()) - dp.ChannelType = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := chainfee.SatPerKWeight(rand.Uint32()) - dp.KickoffFeerate = fn.Some(v) - } - - v[0] = reflect.ValueOf(dp) - }, - MsgDynReject: func(v []reflect.Value, r *rand.Rand) { - var dr DynReject - rand.Read(dr.ChanID[:]) - - features := NewRawFeatureVector() - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPDustLimitSatoshis)) - } - - if rand.Uint32()%2 == 0 { - features.Set( - FeatureBit(DPMaxHtlcValueInFlightMsat), - ) - } - - if rand.Uint32()%2 == 0 { - features.Set( - FeatureBit(DPChannelReserveSatoshis), - ) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPToSelfDelay)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPMaxAcceptedHtlcs)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPFundingPubkey)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPChannelType)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPKickoffFeerate)) - } - dr.UpdateRejections = *features - - v[0] = reflect.ValueOf(dr) - }, - MsgDynAck: func(v []reflect.Value, r *rand.Rand) { - var da DynAck - - rand.Read(da.ChanID[:]) - if rand.Uint32()%2 == 0 { - var nonce Musig2Nonce - rand.Read(nonce[:]) - da.LocalNonce = fn.Some(nonce) - } - - v[0] = reflect.ValueOf(da) - }, - MsgKickoffSig: func(v []reflect.Value, r *rand.Rand) { - ks := KickoffSig{ - ExtraData: make([]byte, 0), - } - - rand.Read(ks.ChanID[:]) - rand.Read(ks.Signature.bytes[:]) - - v[0] = reflect.ValueOf(ks) - }, - MsgCommitSig: func(v []reflect.Value, r *rand.Rand) { - req := NewCommitSig() - if _, err := r.Read(req.ChanID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - var err error - req.CommitSig, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - // Only create the slice if there will be any signatures - // in it to prevent false positive test failures due to - // an empty slice versus a nil slice. - numSigs := uint16(r.Int31n(500)) - if numSigs > 0 { - req.HtlcSigs = make([]Sig, numSigs) - } - for i := 0; i < int(numSigs); i++ { - req.HtlcSigs[i], err = NewSigFromSignature( - testSig, - ) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - } - - req.CustomRecords = randCustomRecords(t, r) - - // 50/50 chance to attach a partial sig. - if r.Int31()%2 == 0 { - req.PartialSig = somePartialSigWithNonce(t, r) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgRevokeAndAck: func(v []reflect.Value, r *rand.Rand) { - req := NewRevokeAndAck() - if _, err := r.Read(req.ChanID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - if _, err := r.Read(req.Revocation[:]); err != nil { - t.Fatalf("unable to generate bytes: %v", err) - return - } - var err error - req.NextRevocationKey, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - // 50/50 chance to attach a local nonce. - if r.Int31()%2 == 0 { - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgChannelAnnouncement: func(v []reflect.Value, r *rand.Rand) { - var err error - req := ChannelAnnouncement1{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - NodeID1: randRawKey(t), - NodeID2: randRawKey(t), - BitcoinKey1: randRawKey(t), - BitcoinKey2: randRawKey(t), - Features: randRawFeatureVector(r), - ExtraOpaqueData: make([]byte, 0), - } - req.NodeSig1, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - req.NodeSig2, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - req.BitcoinSig1, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - req.BitcoinSig2, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to generate chain hash: %v", err) - return - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgNodeAnnouncement: func(v []reflect.Value, r *rand.Rand) { - var err error - req := NodeAnnouncement{ - NodeID: randRawKey(t), - Features: randRawFeatureVector(r), - Timestamp: uint32(r.Int31()), - Alias: randAlias(r), - RGBColor: color.RGBA{ - R: uint8(r.Int31()), - G: uint8(r.Int31()), - B: uint8(r.Int31()), - }, - ExtraOpaqueData: make([]byte, 0), - } - req.Signature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - req.Addresses, err = randAddrs(r) - if err != nil { - t.Fatalf("unable to generate addresses: %v", err) - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelUpdate: func(v []reflect.Value, r *rand.Rand) { - var err error - - msgFlags := ChanUpdateMsgFlags(r.Int31()) - maxHtlc := MilliSatoshi(r.Int63()) - - // We make the max_htlc field zero if it is not flagged - // as being part of the ChannelUpdate, to pass - // serialization tests, as it will be ignored if the bit - // is not set. - if msgFlags&ChanUpdateRequiredMaxHtlc == 0 { - maxHtlc = 0 - } - - req := ChannelUpdate1{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - Timestamp: uint32(r.Int31()), - MessageFlags: msgFlags, - ChannelFlags: ChanUpdateChanFlags(r.Int31()), - TimeLockDelta: uint16(r.Int31()), - HtlcMinimumMsat: MilliSatoshi(r.Int63()), - HtlcMaximumMsat: maxHtlc, - BaseFee: uint32(r.Int31()), - FeeRate: uint32(r.Int31()), - ExtraOpaqueData: make([]byte, 0), - } - req.Signature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to generate chain hash: %v", err) - return - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgAnnounceSignatures: func(v []reflect.Value, r *rand.Rand) { - var err error - req := AnnounceSignatures1{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), - ExtraOpaqueData: make([]byte, 0), - } - - req.NodeSignature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - req.BitcoinSignature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChannelID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelReestablish: func(v []reflect.Value, r *rand.Rand) { - req := ChannelReestablish{ - NextLocalCommitHeight: uint64(r.Int63()), - RemoteCommitTailHeight: uint64(r.Int63()), - ExtraData: make([]byte, 0), - } - - // With a 50/50 probability, we'll include the - // additional fields so we can test our ability to - // properly parse, and write out the optional fields. - if r.Int()%2 == 0 { - _, err := r.Read(req.LastRemoteCommitSecret[:]) - if err != nil { - t.Fatalf("unable to read commit secret: %v", err) - return - } - - req.LocalUnrevokedCommitPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgGossipTimestampRange: func(v []reflect.Value, r *rand.Rand) { - req := GossipTimestampRange{ - FirstTimestamp: rand.Uint32(), - TimestampRange: rand.Uint32(), - ExtraData: make([]byte, 0), - } - - _, err := rand.Read(req.ChainHash[:]) - require.NoError(t, err) - - // Sometimes add a block range. - if r.Int31()%2 == 0 { - firstBlock := tlv.ZeroRecordT[ - tlv.TlvType2, uint32, - ]() - firstBlock.Val = rand.Uint32() - req.FirstBlockHeight = tlv.SomeRecordT( - firstBlock, - ) - - blockRange := tlv.ZeroRecordT[ - tlv.TlvType4, uint32, - ]() - blockRange.Val = rand.Uint32() - req.BlockRange = tlv.SomeRecordT(blockRange) - } - - v[0] = reflect.ValueOf(req) - }, - MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { - req := QueryShortChanIDs{ - ExtraData: make([]byte, 0), - } - - // With a 50/50 change, we'll either use zlib encoding, - // or regular encoding. - if r.Int31()%2 == 0 { - req.EncodingType = EncodingSortedZlib - } else { - req.EncodingType = EncodingSortedPlain - } - - if _, err := rand.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to read chain hash: %v", err) - return - } - - numChanIDs := rand.Int31n(5000) - for i := int32(0); i < numChanIDs; i++ { - req.ShortChanIDs = append(req.ShortChanIDs, - NewShortChanIDFromInt(uint64(r.Int63()))) - } - - v[0] = reflect.ValueOf(req) - }, - MsgReplyChannelRange: func(v []reflect.Value, r *rand.Rand) { - req := ReplyChannelRange{ - FirstBlockHeight: uint32(r.Int31()), - NumBlocks: uint32(r.Int31()), - ExtraData: make([]byte, 0), - } - - if _, err := rand.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to read chain hash: %v", err) - return - } - - req.Complete = uint8(r.Int31n(2)) - - // With a 50/50 change, we'll either use zlib encoding, - // or regular encoding. - if r.Int31()%2 == 0 { - req.EncodingType = EncodingSortedZlib - } else { - req.EncodingType = EncodingSortedPlain - } - - numChanIDs := rand.Int31n(4000) - for i := int32(0); i < numChanIDs; i++ { - req.ShortChanIDs = append(req.ShortChanIDs, - NewShortChanIDFromInt(uint64(r.Int63()))) - } - - // With a 50/50 chance, add some timestamps. - if r.Int31()%2 == 0 { - for i := int32(0); i < numChanIDs; i++ { - timestamps := ChanUpdateTimestamps{ - Timestamp1: rand.Uint32(), - Timestamp2: rand.Uint32(), - } - req.Timestamps = append( - req.Timestamps, timestamps, - ) - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgQueryChannelRange: func(v []reflect.Value, r *rand.Rand) { - req := QueryChannelRange{ - FirstBlockHeight: uint32(r.Int31()), - NumBlocks: uint32(r.Int31()), - ExtraData: make([]byte, 0), - } - - _, err := rand.Read(req.ChainHash[:]) - require.NoError(t, err) - - // With a 50/50 change, we'll set a query option. - if r.Int31()%2 == 0 { - req.QueryOptions = NewTimestampQueryOption() - } - - v[0] = reflect.ValueOf(req) - }, - MsgPing: func(v []reflect.Value, r *rand.Rand) { - // We use a special message generator here to ensure we - // don't generate ping messages that are too large, - // which'll cause the test to fail. - // - // We'll allow the test to generate padding bytes up to - // the max message limit, factoring in the 2 bytes for - // the num pong bytes and 2 bytes for encoding the - // length of the padding bytes. - paddingBytes := make([]byte, rand.Intn(MaxMsgBody-3)) - req := Ping{ - NumPongBytes: uint16(r.Intn(MaxPongBytes + 1)), - PaddingBytes: paddingBytes, - } - - v[0] = reflect.ValueOf(req) - }, - MsgClosingComplete: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", - err) - return - } - - req := ClosingComplete{ - ChannelID: ChannelID(c), - FeeSatoshis: btcutil.Amount(r.Int63()), - LockTime: uint32(r.Int63()), - ClosingSigs: ClosingSigs{}, - } - req.CloserScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - req.CloseeScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - - if r.Intn(2) == 0 { - sig := req.CloserNoClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserNoClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.NoCloserClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.NoCloserClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.CloserAndClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserAndClosee = tlv.SomeRecordT(sig) - } - - v[0] = reflect.ValueOf(req) - }, - MsgClosingSig: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - req := ClosingSig{ - ChannelID: ChannelID(c), - ClosingSigs: ClosingSigs{}, - FeeSatoshis: btcutil.Amount(r.Int63()), - LockTime: uint32(r.Int63()), - } - req.CloserScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - req.CloseeScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - - if r.Intn(2) == 0 { - sig := req.CloserNoClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserNoClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.NoCloserClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.NoCloserClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.CloserAndClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserAndClosee = tlv.SomeRecordT(sig) - } - - v[0] = reflect.ValueOf(req) - }, - MsgUpdateAddHTLC: func(v []reflect.Value, r *rand.Rand) { - req := &UpdateAddHTLC{ - ID: r.Uint64(), - Amount: MilliSatoshi(r.Uint64()), - Expiry: r.Uint32(), - } - - _, err := r.Read(req.ChanID[:]) - require.NoError(t, err) - - _, err = r.Read(req.PaymentHash[:]) - require.NoError(t, err) - - _, err = r.Read(req.OnionBlob[:]) - require.NoError(t, err) - - req.CustomRecords = randCustomRecords(t, r) - - // Generate a blinding point 50% of the time, since not - // all update adds will use route blinding. - if r.Int31()%2 == 0 { - pubkey, err := randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", - err) - - return - } - - req.BlindingPoint = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType0]( - pubkey, - ), - ) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgUpdateFulfillHTLC: func(v []reflect.Value, r *rand.Rand) { - req := &UpdateFulfillHTLC{ - ID: r.Uint64(), - } - - _, err := r.Read(req.ChanID[:]) - require.NoError(t, err) - - _, err = r.Read(req.PaymentPreimage[:]) - require.NoError(t, err) - - req.CustomRecords = randCustomRecords(t, r) - - // Generate some random TLV records 50% of the time. - if r.Int31()%2 == 0 { - req.ExtraData = []byte{ - 0x01, 0x03, 1, 2, 3, - 0x02, 0x03, 4, 5, 6, - } - } - - v[0] = reflect.ValueOf(*req) - }, - MsgAnnounceSignatures2: func(v []reflect.Value, - r *rand.Rand) { - - req := AnnounceSignatures2{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - ExtraOpaqueData: make([]byte, 0), - } - - _, err := r.Read(req.ChannelID[:]) - require.NoError(t, err) - - partialSig, err := randPartialSig(r) - require.NoError(t, err) - - req.PartialSignature = *partialSig - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelAnnouncement2: func(v []reflect.Value, r *rand.Rand) { - req := ChannelAnnouncement2{ - Signature: testSchnorrSig, - ExtraOpaqueData: make([]byte, 0), - } - - req.ShortChannelID.Val = NewShortChanIDFromInt( - uint64(r.Int63()), + // Use the RandTestMessage method to create a randomized + // message. + msg := testMsg.RandTestMessage(t) + + // Serialize the message to a buffer. + var b bytes.Buffer + writtenBytes, err := WriteMessage(&b, msg, 0) + require.NoError(t, err, "unable to write msg") + + // Check that the serialized payload is below the max + // payload size, accounting for the message type size. + payloadLen := uint32(writtenBytes) - 2 + require.LessOrEqual( + t, payloadLen, uint32(MaxMsgBody), + "msg payload constraint violated: %v > %v", + payloadLen, MaxMsgBody, ) - req.Capacity.Val = rand.Uint64() - req.Features.Val = *randRawFeatureVector(r) + // Deserialize the message from the buffer. + newMsg, err := ReadMessage(&b, 0) + require.NoError(t, err, "unable to read msg") - req.NodeID1.Val = randRawKey(t) - req.NodeID2.Val = randRawKey(t) - - // Sometimes set chain hash to bitcoin mainnet genesis - // hash. - req.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash - if r.Int31()%2 == 0 { - _, err := r.Read(req.ChainHash.Val[:]) - require.NoError(t, err) - } - - // Sometimes set the bitcoin keys. - if r.Int31()%2 == 0 { - btcKey1 := tlv.ZeroRecordT[ - tlv.TlvType12, [33]byte, - ]() - btcKey1.Val = randRawKey(t) - req.BitcoinKey1 = tlv.SomeRecordT(btcKey1) - - btcKey2 := tlv.ZeroRecordT[ - tlv.TlvType14, [33]byte, - ]() - btcKey2.Val = randRawKey(t) - req.BitcoinKey2 = tlv.SomeRecordT(btcKey2) - - // Occasionally also set the merkle root hash. - if r.Int31()%2 == 0 { - hash := tlv.ZeroRecordT[ - tlv.TlvType16, [32]byte, - ]() - - _, err := r.Read(hash.Val[:]) - require.NoError(t, err) - - req.MerkleRootHash = tlv.SomeRecordT( - hash, - ) - } - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelUpdate2: func(v []reflect.Value, r *rand.Rand) { - req := ChannelUpdate2{ - Signature: testSchnorrSig, - ExtraOpaqueData: make([]byte, 0), - } - - req.ShortChannelID.Val = NewShortChanIDFromInt( - uint64(r.Int63()), + // Verify the deserialized message matches the original. + require.Equal( + t, msg, newMsg, + "message mismatch for type %s", msgType, ) - req.BlockHeight.Val = r.Uint32() - req.HTLCMaximumMsat.Val = MilliSatoshi(r.Uint64()) - - // Sometimes set chain hash to bitcoin mainnet genesis - // hash. - req.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash - if r.Int31()%2 == 0 { - _, err := r.Read(req.ChainHash.Val[:]) - require.NoError(t, err) - } - - // Sometimes use default htlc min msat. - req.HTLCMinimumMsat.Val = defaultHtlcMinMsat - if r.Int31()%2 == 0 { - req.HTLCMinimumMsat.Val = MilliSatoshi( - r.Uint64(), - ) - } - - // Sometimes set the cltv expiry delta to the default. - req.CLTVExpiryDelta.Val = defaultCltvExpiryDelta - if r.Int31()%2 == 0 { - req.CLTVExpiryDelta.Val = uint16(r.Int31()) - } - - // Sometimes use default fee base. - req.FeeBaseMsat.Val = defaultFeeBaseMsat - if r.Int31()%2 == 0 { - req.FeeBaseMsat.Val = r.Uint32() - } - - // Sometimes use default proportional fee. - req.FeeProportionalMillionths.Val = - defaultFeeProportionalMillionths - if r.Int31()%2 == 0 { - req.FeeProportionalMillionths.Val = r.Uint32() - } - - // Alternate between the two direction possibilities. - if r.Int31()%2 == 0 { - req.SecondPeer = tlv.SomeRecordT( - tlv.ZeroRecordT[tlv.TlvType8, TrueBoolean](), //nolint:ll - ) - } - - // Sometimes set the incoming disabled flag. - if r.Int31()%2 == 0 { - req.DisabledFlags.Val |= - ChanUpdateDisableIncoming - } - - // Sometimes set the outgoing disabled flag. - if r.Int31()%2 == 0 { - req.DisabledFlags.Val |= - ChanUpdateDisableOutgoing - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - - v[0] = reflect.ValueOf(req) - }, + })) } - - // With the above types defined, we'll now generate a slice of - // scenarios to feed into quick.Check. The function scans in input - // space of the target function under test, so we'll need to create a - // series of wrapper functions to force it to iterate over the target - // types, but re-use the mainScenario defined above. - tests := []struct { - msgType MessageType - scenario interface{} - }{ - { - msgType: MsgStfu, - scenario: func(m Stfu) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgInit, - scenario: func(m Init) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgWarning, - scenario: func(m Warning) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgError, - scenario: func(m Error) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgPing, - scenario: func(m Ping) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgPong, - scenario: func(m Pong) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgOpenChannel, - scenario: func(m OpenChannel) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgAcceptChannel, - scenario: func(m AcceptChannel) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgFundingCreated, - scenario: func(m FundingCreated) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgFundingSigned, - scenario: func(m FundingSigned) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelReady, - scenario: func(m ChannelReady) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgShutdown, - scenario: func(m Shutdown) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgClosingSigned, - scenario: func(m ClosingSigned) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgDynPropose, - scenario: func(m DynPropose) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgDynReject, - scenario: func(m DynReject) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgDynAck, - scenario: func(m DynAck) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgKickoffSig, - scenario: func(m KickoffSig) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateAddHTLC, - scenario: func(m UpdateAddHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateFulfillHTLC, - scenario: func(m UpdateFulfillHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateFailHTLC, - scenario: func(m UpdateFailHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgCommitSig, - scenario: func(m CommitSig) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgRevokeAndAck, - scenario: func(m RevokeAndAck) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateFee, - scenario: func(m UpdateFee) bool { - return mainScenario(&m) - }, - }, - { - - msgType: MsgUpdateFailMalformedHTLC, - scenario: func(m UpdateFailMalformedHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelReestablish, - scenario: func(m ChannelReestablish) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelAnnouncement, - scenario: func(m ChannelAnnouncement1) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgNodeAnnouncement, - scenario: func(m NodeAnnouncement) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelUpdate, - scenario: func(m ChannelUpdate1) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgAnnounceSignatures, - scenario: func(m AnnounceSignatures1) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgGossipTimestampRange, - scenario: func(m GossipTimestampRange) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgQueryShortChanIDs, - scenario: func(m QueryShortChanIDs) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgReplyShortChanIDsEnd, - scenario: func(m ReplyShortChanIDsEnd) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgQueryChannelRange, - scenario: func(m QueryChannelRange) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgReplyChannelRange, - scenario: func(m ReplyChannelRange) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgClosingComplete, - scenario: func(m ClosingComplete) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgClosingSig, - scenario: func(m ClosingSig) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgAnnounceSignatures2, - scenario: func(m AnnounceSignatures2) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelAnnouncement2, - scenario: func(m ChannelAnnouncement2) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelUpdate2, - scenario: func(m ChannelUpdate2) bool { - return mainScenario(&m) - }, - }, - } - for _, test := range tests { - t.Run(test.msgType.String(), func(t *testing.T) { - var config *quick.Config - - // If the type defined is within the custom type gen - // map above, then we'll modify the default config to - // use this Value function that knows how to generate - // the proper types. - if valueGen, ok := customTypeGen[test.msgType]; ok { - config = &quick.Config{ - Values: valueGen, - } - } - - t.Logf("Running fuzz tests for msgType=%v", - test.msgType) - - err := quick.Check(test.scenario, config) - if err != nil { - t.Fatalf("fuzz checks for msg=%v failed: %v", - test.msgType, err) - } - }) - } - } func init() { diff --git a/lnwire/message.go b/lnwire/message.go index 68b09692e..ea480075a 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -12,6 +12,10 @@ import ( "io" ) +// MessageTypeSize is the size in bytes of the message type field in the header +// of all messages. +const MessageTypeSize = 2 + // MessageType is the unique 2 byte big-endian integer that indicates the type // of message on the wire. All messages have a very simple header which // consists simply of 2-byte message type. We omit a length field, and checksum @@ -61,6 +65,11 @@ const ( MsgChannelAnnouncement2 = 267 MsgChannelUpdate2 = 271 MsgKickoffSig = 777 + + // MsgEnd defines the end of the official message range of the protocol. + // If a new message is added beyond this message, then this should be + // modified. + MsgEnd = 778 ) // IsChannelUpdate is a filter function that discerns channel update messages @@ -234,6 +243,31 @@ type LinkUpdater interface { TargetChanID() ChannelID } +// SizeableMessage is an interface that extends the base Message interface with +// a method to calculate the serialized size of a message. +type SizeableMessage interface { + Message + + // SerializedSize returns the serialized size of the message in bytes. + // The returned size includes the message type header bytes. + SerializedSize() (uint32, error) +} + +// MessageSerializedSize calculates the serialized size of a message in bytes. +// This is a helper function that can be used by all message types to implement +// the SerializedSize method. +func MessageSerializedSize(msg Message) (uint32, error) { + var buf bytes.Buffer + + // Encode the message to the buffer. + if err := msg.Encode(&buf, 0); err != nil { + return 0, err + } + + // Add the size of the message type. + return uint32(buf.Len()) + MessageTypeSize, nil +} + // makeEmptyMessage creates a new empty message of the proper concrete type // based on the passed message type. func makeEmptyMessage(msgType MessageType) (Message, error) { @@ -337,6 +371,12 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { return msg, nil } +// MakeEmptyMessage creates a new empty message of the proper concrete type +// based on the passed message type. This is exported to be used in tests. +func MakeEmptyMessage(msgType MessageType) (Message, error) { + return makeEmptyMessage(msgType) +} + // WriteMessage writes a lightning Message to a buffer including the necessary // header information and returns the number of bytes written. If any error is // encountered, the buffer passed will be reset to its original state since we diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index 4f1620281..5ba2d7a1d 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -104,6 +104,10 @@ type NodeAnnouncement struct { // lnwire.Message interface. var _ Message = (*NodeAnnouncement)(nil) +// A compile time check to ensure NodeAnnouncement implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*NodeAnnouncement)(nil) + // Decode deserializes a serialized NodeAnnouncement stored in the passed // io.Reader observing the specified protocol version. // @@ -202,3 +206,10 @@ func (a *NodeAnnouncement) DataToSign() ([]byte, error) { return buf.Bytes(), nil } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *NodeAnnouncement) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 9694290f7..1751f748b 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -164,6 +164,10 @@ type OpenChannel struct { // interface. var _ Message = (*OpenChannel)(nil) +// A compile time check to ensure OpenChannel implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*OpenChannel)(nil) + // Encode serializes the target OpenChannel into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -335,3 +339,10 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { func (o *OpenChannel) MsgType() MessageType { return MsgOpenChannel } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (o *OpenChannel) SerializedSize() (uint32, error) { + return MessageSerializedSize(o) +} diff --git a/lnwire/ping.go b/lnwire/ping.go index a21f2fa8b..230187b84 100644 --- a/lnwire/ping.go +++ b/lnwire/ping.go @@ -33,6 +33,10 @@ func NewPing(numBytes uint16) *Ping { // A compile time check to ensure Ping implements the lnwire.Message interface. var _ Message = (*Ping)(nil) +// A compile time check to ensure Ping implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Ping)(nil) + // Decode deserializes a serialized Ping message stored in the passed io.Reader // observing the specified protocol version. // @@ -69,3 +73,10 @@ func (p *Ping) Encode(w *bytes.Buffer, pver uint32) error { func (p *Ping) MsgType() MessageType { return MsgPing } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (p *Ping) SerializedSize() (uint32, error) { + return MessageSerializedSize(p) +} diff --git a/lnwire/pong.go b/lnwire/pong.go index 3ab80d70f..b5fca24c9 100644 --- a/lnwire/pong.go +++ b/lnwire/pong.go @@ -39,6 +39,10 @@ func NewPong(pongBytes []byte) *Pong { // A compile time check to ensure Pong implements the lnwire.Message interface. var _ Message = (*Pong)(nil) +// A compile time check to ensure Pong implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Pong)(nil) + // Decode deserializes a serialized Pong message stored in the passed io.Reader // observing the specified protocol version. // @@ -64,3 +68,10 @@ func (p *Pong) Encode(w *bytes.Buffer, pver uint32) error { func (p *Pong) MsgType() MessageType { return MsgPong } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (p *Pong) SerializedSize() (uint32, error) { + return MessageSerializedSize(p) +} diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index 1e0dcb0fa..c816a0050 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -49,6 +49,10 @@ func NewQueryChannelRange() *QueryChannelRange { // lnwire.Message interface. var _ Message = (*QueryChannelRange)(nil) +// A compile time check to ensure QueryChannelRange implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*QueryChannelRange)(nil) + // Decode deserializes a serialized QueryChannelRange message stored in the // passed io.Reader observing the specified protocol version. // @@ -121,6 +125,14 @@ func (q *QueryChannelRange) MsgType() MessageType { return MsgQueryChannelRange } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (q *QueryChannelRange) SerializedSize() (uint32, error) { + msgCpy := *q + return MessageSerializedSize(&msgCpy) +} + // LastBlockHeight returns the last block height covered by the range of a // QueryChannelRange message. func (q *QueryChannelRange) LastBlockHeight() uint32 { diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index cb1bfa22c..37a73ab7c 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -91,6 +91,10 @@ func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding, // lnwire.Message interface. var _ Message = (*QueryShortChanIDs)(nil) +// A compile time check to ensure QueryShortChanIDs implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*QueryShortChanIDs)(nil) + // Decode deserializes a serialized QueryShortChanIDs message stored in the // passed io.Reader observing the specified protocol version. // @@ -427,3 +431,10 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding, func (q *QueryShortChanIDs) MsgType() MessageType { return MsgQueryShortChanIDs } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (q *QueryShortChanIDs) SerializedSize() (uint32, error) { + return MessageSerializedSize(q) +} diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 591fc2bd6..c3a744ebd 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -70,6 +70,10 @@ func NewReplyChannelRange() *ReplyChannelRange { // lnwire.Message interface. var _ Message = (*ReplyChannelRange)(nil) +// A compile time check to ensure ReplyChannelRange implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ReplyChannelRange)(nil) + // Decode deserializes a serialized ReplyChannelRange message stored in the // passed io.Reader observing the specified protocol version. // @@ -223,3 +227,10 @@ func (c *ReplyChannelRange) LastBlockHeight() uint32 { } return uint32(lastBlockHeight) } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ReplyChannelRange) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/reply_short_chan_ids_end.go b/lnwire/reply_short_chan_ids_end.go index 53676f719..2e50d840f 100644 --- a/lnwire/reply_short_chan_ids_end.go +++ b/lnwire/reply_short_chan_ids_end.go @@ -39,6 +39,10 @@ func NewReplyShortChanIDsEnd() *ReplyShortChanIDsEnd { // lnwire.Message interface. var _ Message = (*ReplyShortChanIDsEnd)(nil) +// A compile time check to ensure ReplyShortChanIDsEnd implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ReplyShortChanIDsEnd)(nil) + // Decode deserializes a serialized ReplyShortChanIDsEnd message stored in the // passed io.Reader observing the specified protocol version. // @@ -74,3 +78,10 @@ func (c *ReplyShortChanIDsEnd) Encode(w *bytes.Buffer, pver uint32) error { func (c *ReplyShortChanIDsEnd) MsgType() MessageType { return MsgReplyShortChanIDsEnd } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ReplyShortChanIDsEnd) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index 9dca1631a..3c9775c99 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -55,6 +55,10 @@ func NewRevokeAndAck() *RevokeAndAck { // interface. var _ Message = (*RevokeAndAck)(nil) +// A compile time check to ensure RevokeAndAck implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*RevokeAndAck)(nil) + // Decode deserializes a serialized RevokeAndAck message stored in the // passed io.Reader observing the specified protocol version. // @@ -136,3 +140,10 @@ func (c *RevokeAndAck) MsgType() MessageType { func (c *RevokeAndAck) TargetChanID() ChannelID { return c.ChanID } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *RevokeAndAck) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/serialized_size_test.go b/lnwire/serialized_size_test.go new file mode 100644 index 000000000..cce315fa2 --- /dev/null +++ b/lnwire/serialized_size_test.go @@ -0,0 +1,68 @@ +package lnwire + +import ( + "bytes" + "math" + "testing" + + "github.com/stretchr/testify/require" + "pgregory.net/rapid" +) + +// TestSerializedSize uses property-based testing to verify that +// SerializedSize returns the correct value for randomly generated messages. +func TestSerializedSize(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + // Pick a random message type. + msgType := rapid.Custom(func(t *rapid.T) MessageType { + return MessageType( + rapid.IntRange( + 0, int(math.MaxUint16), + ).Draw(t, "msgType"), + ) + }).Draw(t, "msgType") + + // Create an empty message of the given type. + m, err := MakeEmptyMessage(msgType) + + // An error means this isn't a valid message type, so we skip + // it. + if err != nil { + return + } + + testMsg, ok := m.(TestMessage) + require.True( + t, ok, "message type %s does not "+ + "implement TestMessage", msgType, + ) + + // Use the testMsg to make a new random message. + msg := testMsg.RandTestMessage(t) + + // Type assertion to ensure the message implements + // SizeableMessage. + sizeMsg, ok := msg.(SizeableMessage) + require.True( + t, ok, "message type %s does not "+ + "implement SizeableMessage", msgType, + ) + + // Get the size using SerializedSize. + size, err := sizeMsg.SerializedSize() + require.NoError(t, err, "SerializedSize error") + + // Get the size by actually serializing the message. + var buf bytes.Buffer + writtenBytes, err := WriteMessage(&buf, msg, 0) + require.NoError(t, err, "WriteMessage error") + + // The SerializedSize should match the number of bytes written. + require.Equal(t, uint32(writtenBytes), size, + "SerializedSize = %d, actual bytes "+ + "written = %d for message type %s (populated)", + size, writtenBytes, msgType) + }) +} diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index b9899fcfb..28df9a4ca 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -61,6 +61,10 @@ func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown { // interface. var _ Message = (*Shutdown)(nil) +// A compile-time check to ensure Shutdown implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Shutdown)(nil) + // Decode deserializes a serialized Shutdown from the passed io.Reader, // observing the specified protocol version. // @@ -133,3 +137,10 @@ func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { func (s *Shutdown) MsgType() MessageType { return MsgShutdown } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (s *Shutdown) SerializedSize() (uint32, error) { + return MessageSerializedSize(s) +} diff --git a/lnwire/stfu.go b/lnwire/stfu.go index a052a517a..8e57739d5 100644 --- a/lnwire/stfu.go +++ b/lnwire/stfu.go @@ -24,6 +24,10 @@ type Stfu struct { // A compile time check to ensure Stfu implements the lnwire.Message interface. var _ Message = (*Stfu)(nil) +// A compile time check to ensure Stfu implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Stfu)(nil) + // Encode serializes the target Stfu into the passed io.Writer. // Serialization will observe the rules defined by the passed protocol version. // @@ -68,6 +72,13 @@ func (s *Stfu) MsgType() MessageType { return MsgStfu } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (s *Stfu) SerializedSize() (uint32, error) { + return MessageSerializedSize(s) +} + // A compile time check to ensure Stfu implements the // lnwire.LinkUpdater interface. var _ LinkUpdater = (*Stfu)(nil) diff --git a/lnwire/test_message.go b/lnwire/test_message.go new file mode 100644 index 000000000..8b3d98400 --- /dev/null +++ b/lnwire/test_message.go @@ -0,0 +1,1669 @@ +package lnwire + +import ( + "bytes" + "fmt" + "image/color" + "math" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/tlv" + "pgregory.net/rapid" +) + +// TestMessage is an interface that extends the base Message interface with a +// method to populate the message with random testing data. +type TestMessage interface { + Message + + // RandTestMessage populates the message with random data suitable for + // testing. It uses the rapid testing framework to generate random + // values. + RandTestMessage(t *rapid.T) Message +} + +// A compile time check to ensure AcceptChannel implements the TestMessage +// interface. +var _ TestMessage = (*AcceptChannel)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *AcceptChannel) RandTestMessage(t *rapid.T) Message { + var pendingChanID [32]byte + pendingChanIDBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "pendingChanID", + ) + copy(pendingChanID[:], pendingChanIDBytes) + + var channelType *ChannelType + includeChannelType := rapid.Bool().Draw(t, "includeChannelType") + includeLeaseExpiry := rapid.Bool().Draw(t, "includeLeaseExpiry") + includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce") + + if includeChannelType { + channelType = RandChannelType(t) + } + + var leaseExpiry *LeaseExpiry + if includeLeaseExpiry { + leaseExpiry = RandLeaseExpiry(t) + } + + var localNonce OptMusig2NonceTLV + if includeLocalNonce { + nonce := RandMusig2Nonce(t) + localNonce = tlv.SomeRecordT( + tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce), + ) + } + + return &AcceptChannel{ + PendingChannelID: pendingChanID, + DustLimit: btcutil.Amount( + rapid.IntRange(100, 1000).Draw(t, "dustLimit"), + ), + MaxValueInFlight: MilliSatoshi( + rapid.IntRange(10000, 1000000).Draw( + t, "maxValueInFlight", + ), + ), + ChannelReserve: btcutil.Amount( + rapid.IntRange(1000, 10000).Draw(t, "channelReserve"), + ), + HtlcMinimum: MilliSatoshi( + rapid.IntRange(1, 1000).Draw(t, "htlcMinimum"), + ), + MinAcceptDepth: uint32( + rapid.IntRange(1, 10).Draw(t, "minAcceptDepth"), + ), + CsvDelay: uint16( + rapid.IntRange(144, 1000).Draw(t, "csvDelay"), + ), + MaxAcceptedHTLCs: uint16( + rapid.IntRange(10, 500).Draw(t, "maxAcceptedHTLCs"), + ), + FundingKey: RandPubKey(t), + RevocationPoint: RandPubKey(t), + PaymentPoint: RandPubKey(t), + DelayedPaymentPoint: RandPubKey(t), + HtlcPoint: RandPubKey(t), + FirstCommitmentPoint: RandPubKey(t), + UpfrontShutdownScript: RandDeliveryAddress(t), + ChannelType: channelType, + LeaseExpiry: leaseExpiry, + LocalNonce: localNonce, + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure AnnounceSignatures1 implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*AnnounceSignatures1)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *AnnounceSignatures1) RandTestMessage(t *rapid.T) Message { + return &AnnounceSignatures1{ + ChannelID: RandChannelID(t), + ShortChannelID: RandShortChannelID(t), + NodeSignature: RandSignature(t), + BitcoinSignature: RandSignature(t), + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure AnnounceSignatures2 implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*AnnounceSignatures2)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *AnnounceSignatures2) RandTestMessage(t *rapid.T) Message { + return &AnnounceSignatures2{ + ChannelID: RandChannelID(t), + ShortChannelID: RandShortChannelID(t), + PartialSignature: *RandPartialSig(t), + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure ChannelAnnouncement1 implements the +// TestMessage interface. +var _ TestMessage = (*ChannelAnnouncement1)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *ChannelAnnouncement1) RandTestMessage(t *rapid.T) Message { + // Generate Node IDs and Bitcoin keys (compressed public keys) + node1PubKey := RandPubKey(t) + node2PubKey := RandPubKey(t) + bitcoin1PubKey := RandPubKey(t) + bitcoin2PubKey := RandPubKey(t) + + // Convert to byte arrays + var nodeID1, nodeID2, bitcoinKey1, bitcoinKey2 [33]byte + copy(nodeID1[:], node1PubKey.SerializeCompressed()) + copy(nodeID2[:], node2PubKey.SerializeCompressed()) + copy(bitcoinKey1[:], bitcoin1PubKey.SerializeCompressed()) + copy(bitcoinKey2[:], bitcoin2PubKey.SerializeCompressed()) + + // Ensure nodeID1 is numerically less than nodeID2 + // This is a requirement stated in the field description + if bytes.Compare(nodeID1[:], nodeID2[:]) > 0 { + nodeID1, nodeID2 = nodeID2, nodeID1 + } + + // Generate chain hash + chainHash := RandChainHash(t) + var hash chainhash.Hash + copy(hash[:], chainHash[:]) + + return &ChannelAnnouncement1{ + NodeSig1: RandSignature(t), + NodeSig2: RandSignature(t), + BitcoinSig1: RandSignature(t), + BitcoinSig2: RandSignature(t), + Features: RandFeatureVector(t), + ChainHash: hash, + ShortChannelID: RandShortChannelID(t), + NodeID1: nodeID1, + NodeID2: nodeID2, + BitcoinKey1: bitcoinKey1, + BitcoinKey2: bitcoinKey2, + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ChannelAnnouncement2)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message { + features := RandFeatureVector(t) + shortChanID := RandShortChannelID(t) + capacity := uint64(rapid.IntRange(1, 16777215).Draw(t, "capacity")) + + var nodeID1, nodeID2 [33]byte + copy(nodeID1[:], RandPubKey(t).SerializeCompressed()) + copy(nodeID2[:], RandPubKey(t).SerializeCompressed()) + + // Make sure nodeID1 is numerically less than nodeID2 (as per spec). + if bytes.Compare(nodeID1[:], nodeID2[:]) > 0 { + nodeID1, nodeID2 = nodeID2, nodeID1 + } + + chainHash := RandChainHash(t) + var chainHashObj chainhash.Hash + copy(chainHashObj[:], chainHash[:]) + + msg := &ChannelAnnouncement2{ + Signature: RandSignature(t), + ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( + chainHashObj, + ), + Features: tlv.NewRecordT[tlv.TlvType2, RawFeatureVector]( + *features, + ), + ShortChannelID: tlv.NewRecordT[tlv.TlvType4, ShortChannelID]( + shortChanID, + ), + Capacity: tlv.NewPrimitiveRecord[tlv.TlvType6, uint64]( + capacity, + ), + NodeID1: tlv.NewPrimitiveRecord[tlv.TlvType8, [33]byte]( + nodeID1, + ), + NodeID2: tlv.NewPrimitiveRecord[tlv.TlvType10, [33]byte]( + nodeID2, + ), + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } + + msg.Signature.ForceSchnorr() + + // Randomly include optional fields + if rapid.Bool().Draw(t, "includeBitcoinKey1") { + var bitcoinKey1 [33]byte + copy(bitcoinKey1[:], RandPubKey(t).SerializeCompressed()) + msg.BitcoinKey1 = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType12, [33]byte]( + bitcoinKey1, + ), + ) + } + + if rapid.Bool().Draw(t, "includeBitcoinKey2") { + var bitcoinKey2 [33]byte + copy(bitcoinKey2[:], RandPubKey(t).SerializeCompressed()) + msg.BitcoinKey2 = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType14, [33]byte]( + bitcoinKey2, + ), + ) + } + + if rapid.Bool().Draw(t, "includeMerkleRootHash") { + hash := RandSHA256Hash(t) + var merkleRootHash [32]byte + copy(merkleRootHash[:], hash[:]) + msg.MerkleRootHash = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType16, [32]byte]( + merkleRootHash, + ), + ) + } + + return msg +} + +// A compile time check to ensure ChannelReady implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*ChannelReady)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ChannelReady) RandTestMessage(t *rapid.T) Message { + msg := &ChannelReady{ + ChanID: RandChannelID(t), + NextPerCommitmentPoint: RandPubKey(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + includeAliasScid := rapid.Bool().Draw(t, "includeAliasScid") + includeNextLocalNonce := rapid.Bool().Draw(t, "includeNextLocalNonce") + includeAnnouncementNodeNonce := rapid.Bool().Draw( + t, "includeAnnouncementNodeNonce", + ) + includeAnnouncementBitcoinNonce := rapid.Bool().Draw( + t, "includeAnnouncementBitcoinNonce", + ) + + if includeAliasScid { + scid := RandShortChannelID(t) + msg.AliasScid = &scid + } + + if includeNextLocalNonce { + nonce := RandMusig2Nonce(t) + msg.NextLocalNonce = SomeMusig2Nonce(nonce) + } + + if includeAnnouncementNodeNonce { + nonce := RandMusig2Nonce(t) + msg.AnnouncementNodeNonce = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType0, Musig2Nonce](nonce), + ) + } + + if includeAnnouncementBitcoinNonce { + nonce := RandMusig2Nonce(t) + msg.AnnouncementBitcoinNonce = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2, Musig2Nonce](nonce), + ) + } + + return msg +} + +// A compile time check to ensure ChannelReestablish implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ChannelReestablish)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *ChannelReestablish) RandTestMessage(t *rapid.T) Message { + msg := &ChannelReestablish{ + ChanID: RandChannelID(t), + NextLocalCommitHeight: rapid.Uint64().Draw( + t, "nextLocalCommitHeight", + ), + RemoteCommitTailHeight: rapid.Uint64().Draw( + t, "remoteCommitTailHeight", + ), + LastRemoteCommitSecret: RandPaymentPreimage(t), + LocalUnrevokedCommitPoint: RandPubKey(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + // Randomly decide whether to include optional fields + includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce") + includeDynHeight := rapid.Bool().Draw(t, "includeDynHeight") + + if includeLocalNonce { + nonce := RandMusig2Nonce(t) + msg.LocalNonce = SomeMusig2Nonce(nonce) + } + + if includeDynHeight { + height := DynHeight(rapid.Uint64().Draw(t, "dynHeight")) + msg.DynHeight = fn.Some(height) + } + + return msg +} + +// A compile time check to ensure ChannelUpdate1 implements the TestMessage +// interface. +var _ TestMessage = (*ChannelUpdate1)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *ChannelUpdate1) RandTestMessage(t *rapid.T) Message { + // Generate random message flags + // Randomly decide whether to include max HTLC field + includeMaxHtlc := rapid.Bool().Draw(t, "includeMaxHtlc") + var msgFlags ChanUpdateMsgFlags + if includeMaxHtlc { + msgFlags |= ChanUpdateRequiredMaxHtlc + } + + // Generate random channel flags + // Randomly decide direction (node1 or node2) + isNode2 := rapid.Bool().Draw(t, "isNode2") + var chanFlags ChanUpdateChanFlags + if isNode2 { + chanFlags |= ChanUpdateDirection + } + + // Randomly decide if channel is disabled + isDisabled := rapid.Bool().Draw(t, "isDisabled") + if isDisabled { + chanFlags |= ChanUpdateDisabled + } + + // Generate chain hash + chainHash := RandChainHash(t) + var hash chainhash.Hash + copy(hash[:], chainHash[:]) + + // Generate other random fields + maxHtlc := MilliSatoshi(rapid.Uint64().Draw(t, "maxHtlc")) + + // If max HTLC flag is not set, we need to zero the value + if !includeMaxHtlc { + maxHtlc = 0 + } + + return &ChannelUpdate1{ + Signature: RandSignature(t), + ChainHash: hash, + ShortChannelID: RandShortChannelID(t), + Timestamp: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw( + t, "timestamp"), + ), + MessageFlags: msgFlags, + ChannelFlags: chanFlags, + TimeLockDelta: uint16(rapid.IntRange(0, 65535).Draw( + t, "timelockDelta"), + ), + HtlcMinimumMsat: MilliSatoshi(rapid.Uint64().Draw( + t, "htlcMinimum"), + ), + BaseFee: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw( + t, "baseFee"), + ), + FeeRate: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw( + t, "feeRate"), + ), + HtlcMaximumMsat: maxHtlc, + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure ChannelUpdate2 implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ChannelUpdate2)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ChannelUpdate2) RandTestMessage(t *rapid.T) Message { + shortChanID := RandShortChannelID(t) + blockHeight := uint32(rapid.IntRange(0, 1000000).Draw(t, "blockHeight")) + + var disabledFlags ChanUpdateDisableFlags + if rapid.Bool().Draw(t, "disableIncoming") { + disabledFlags |= ChanUpdateDisableIncoming + } + if rapid.Bool().Draw(t, "disableOutgoing") { + disabledFlags |= ChanUpdateDisableOutgoing + } + + cltvExpiryDelta := uint16(rapid.IntRange(10, 200).Draw( + t, "cltvExpiryDelta"), + ) + + htlcMinMsat := MilliSatoshi(rapid.IntRange(1, 10000).Draw( + t, "htlcMinMsat"), + ) + htlcMaxMsat := MilliSatoshi(rapid.IntRange(10000, 100000000).Draw( + t, "htlcMaxMsat"), + ) + feeBaseMsat := uint32(rapid.IntRange(0, 10000).Draw(t, "feeBaseMsat")) + feeProportionalMillionths := uint32(rapid.IntRange(0, 10000).Draw( + t, "feeProportionalMillionths"), + ) + + chainHash := RandChainHash(t) + var chainHashObj chainhash.Hash + copy(chainHashObj[:], chainHash[:]) + + //nolint:ll + msg := &ChannelUpdate2{ + Signature: RandSignature(t), + ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( + chainHashObj, + ), + ShortChannelID: tlv.NewRecordT[tlv.TlvType2, ShortChannelID]( + shortChanID, + ), + BlockHeight: tlv.NewPrimitiveRecord[tlv.TlvType4, uint32]( + blockHeight, + ), + DisabledFlags: tlv.NewPrimitiveRecord[tlv.TlvType6, ChanUpdateDisableFlags]( //nolint:ll + disabledFlags, + ), + CLTVExpiryDelta: tlv.NewPrimitiveRecord[tlv.TlvType10, uint16]( + cltvExpiryDelta, + ), + HTLCMinimumMsat: tlv.NewPrimitiveRecord[tlv.TlvType12, MilliSatoshi]( + htlcMinMsat, + ), + HTLCMaximumMsat: tlv.NewPrimitiveRecord[tlv.TlvType14, MilliSatoshi]( + htlcMaxMsat, + ), + FeeBaseMsat: tlv.NewPrimitiveRecord[tlv.TlvType16, uint32]( + feeBaseMsat, + ), + FeeProportionalMillionths: tlv.NewPrimitiveRecord[tlv.TlvType18, uint32]( + feeProportionalMillionths, + ), + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } + + msg.Signature.ForceSchnorr() + + if rapid.Bool().Draw(t, "isSecondPeer") { + msg.SecondPeer = tlv.SomeRecordT( + tlv.RecordT[tlv.TlvType8, TrueBoolean]{}, + ) + } + + return msg +} + +// A compile time check to ensure ClosingComplete implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ClosingComplete)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ClosingComplete) RandTestMessage(t *rapid.T) Message { + msg := &ClosingComplete{ + ChannelID: RandChannelID(t), + FeeSatoshis: btcutil.Amount(rapid.Int64Range(0, 1000000).Draw( + t, "feeSatoshis"), + ), + LockTime: rapid.Uint32Range(0, 0xffffffff).Draw( + t, "lockTime", + ), + CloseeScript: RandDeliveryAddress(t), + CloserScript: RandDeliveryAddress(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + includeCloserNoClosee := rapid.Bool().Draw(t, "includeCloserNoClosee") + includeNoCloserClosee := rapid.Bool().Draw(t, "includeNoCloserClosee") + includeCloserAndClosee := rapid.Bool().Draw(t, "includeCloserAndClosee") + + // Ensure at least one signature is present. + if !includeCloserNoClosee && !includeNoCloserClosee && + !includeCloserAndClosee { + + // If all are false, enable at least one randomly. + choice := rapid.IntRange(0, 2).Draw(t, "sigChoice") + switch choice { + case 0: + includeCloserNoClosee = true + case 1: + includeNoCloserClosee = true + case 2: + includeCloserAndClosee = true + } + } + + if includeCloserNoClosee { + sig := RandSignature(t) + msg.CloserNoClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType1, Sig](sig), + ) + } + + if includeNoCloserClosee { + sig := RandSignature(t) + msg.NoCloserClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2, Sig](sig), + ) + } + + if includeCloserAndClosee { + sig := RandSignature(t) + msg.CloserAndClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3, Sig](sig), + ) + } + + return msg +} + +// A compile time check to ensure ClosingSig implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*ClosingSig)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ClosingSig) RandTestMessage(t *rapid.T) Message { + msg := &ClosingSig{ + ChannelID: RandChannelID(t), + CloseeScript: RandDeliveryAddress(t), + CloserScript: RandDeliveryAddress(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + includeCloserNoClosee := rapid.Bool().Draw(t, "includeCloserNoClosee") + includeNoCloserClosee := rapid.Bool().Draw(t, "includeNoCloserClosee") + includeCloserAndClosee := rapid.Bool().Draw(t, "includeCloserAndClosee") + + // Ensure at least one signature is present. + if !includeCloserNoClosee && !includeNoCloserClosee && + !includeCloserAndClosee { + + // If all are false, enable at least one randomly. + choice := rapid.IntRange(0, 2).Draw(t, "sigChoice") + switch choice { + case 0: + includeCloserNoClosee = true + case 1: + includeNoCloserClosee = true + case 2: + includeCloserAndClosee = true + } + } + + if includeCloserNoClosee { + sig := RandSignature(t) + msg.CloserNoClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType1, Sig](sig), + ) + } + + if includeNoCloserClosee { + sig := RandSignature(t) + msg.NoCloserClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2, Sig](sig), + ) + } + + if includeCloserAndClosee { + sig := RandSignature(t) + msg.CloserAndClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3, Sig](sig), + ) + } + + return msg +} + +// A compile time check to ensure ClosingSigned implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ClosingSigned)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ClosingSigned) RandTestMessage(t *rapid.T) Message { + // Generate a random boolean to decide whether to include CommitSig or + // PartialSig Since they're mutually exclusive, when one is populated, + // the other must be blank. + usePartialSig := rapid.Bool().Draw(t, "usePartialSig") + + msg := &ClosingSigned{ + ChannelID: RandChannelID(t), + FeeSatoshis: btcutil.Amount( + rapid.Int64Range(0, 1000000).Draw(t, "feeSatoshis"), + ), + ExtraData: RandExtraOpaqueData(t, nil), + } + + if usePartialSig { + sigBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "sigScalar", + ) + var s btcec.ModNScalar + _ = s.SetByteSlice(sigBytes) + + msg.PartialSig = SomePartialSig(NewPartialSig(s)) + msg.Signature = Sig{} + } else { + msg.Signature = RandSignature(t) + } + + return msg +} + +// A compile time check to ensure CommitSig implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*CommitSig)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *CommitSig) RandTestMessage(t *rapid.T) Message { + cr, _ := RandCustomRecords(t, nil, true) + sig := &CommitSig{ + ChanID: RandChannelID(t), + CommitSig: RandSignature(t), + CustomRecords: cr, + } + + numHtlcSigs := rapid.IntRange(0, 20).Draw(t, "numHtlcSigs") + htlcSigs := make([]Sig, numHtlcSigs) + for i := 0; i < numHtlcSigs; i++ { + htlcSigs[i] = RandSignature(t) + } + + if len(htlcSigs) > 0 { + sig.HtlcSigs = htlcSigs + } + + includePartialSig := rapid.Bool().Draw(t, "includePartialSig") + if includePartialSig { + sigWithNonce := RandPartialSigWithNonce(t) + sig.PartialSig = MaybePartialSigWithNonce(sigWithNonce) + } + + return sig +} + +// A compile time check to ensure Custom implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Custom)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *Custom) RandTestMessage(t *rapid.T) Message { + msgType := MessageType( + rapid.IntRange(int(CustomTypeStart), 65535).Draw( + t, "customMsgType", + ), + ) + + dataLen := rapid.IntRange(0, 1000).Draw(t, "customDataLength") + data := rapid.SliceOfN(rapid.Byte(), dataLen, dataLen).Draw( + t, "customData", + ) + + msg, err := NewCustom(msgType, data) + if err != nil { + panic(fmt.Sprintf("Error creating custom message: %v", err)) + } + + return msg +} + +// A compile time check to ensure DynAck implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*DynAck)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (da *DynAck) RandTestMessage(t *rapid.T) Message { + msg := &DynAck{ + ChanID: RandChannelID(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce") + + if includeLocalNonce { + msg.LocalNonce = fn.Some(RandMusig2Nonce(t)) + } + + return msg +} + +// A compile time check to ensure DynPropose implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*DynPropose)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (dp *DynPropose) RandTestMessage(t *rapid.T) Message { + msg := &DynPropose{ + ChanID: RandChannelID(t), + Initiator: rapid.Bool().Draw(t, "initiator"), + ExtraData: RandExtraOpaqueData(t, nil), + } + + // Randomly decide which optional fields to include + includeDustLimit := rapid.Bool().Draw(t, "includeDustLimit") + includeMaxValueInFlight := rapid.Bool().Draw( + t, "includeMaxValueInFlight", + ) + includeChannelReserve := rapid.Bool().Draw(t, "includeChannelReserve") + includeCsvDelay := rapid.Bool().Draw(t, "includeCsvDelay") + includeMaxAcceptedHTLCs := rapid.Bool().Draw( + t, "includeMaxAcceptedHTLCs", + ) + includeFundingKey := rapid.Bool().Draw(t, "includeFundingKey") + includeChannelType := rapid.Bool().Draw(t, "includeChannelType") + includeKickoffFeerate := rapid.Bool().Draw(t, "includeKickoffFeerate") + + // Generate random values for each included field + if includeDustLimit { + dl := btcutil.Amount(rapid.Uint32().Draw(t, "dustLimit")) + msg.DustLimit = fn.Some(dl) + } + + if includeMaxValueInFlight { + mvif := MilliSatoshi(rapid.Uint64().Draw(t, "maxValueInFlight")) + msg.MaxValueInFlight = fn.Some(mvif) + } + + if includeChannelReserve { + cr := btcutil.Amount(rapid.Uint32().Draw(t, "channelReserve")) + msg.ChannelReserve = fn.Some(cr) + } + + if includeCsvDelay { + cd := rapid.Uint16().Draw(t, "csvDelay") + msg.CsvDelay = fn.Some(cd) + } + + if includeMaxAcceptedHTLCs { + mah := rapid.Uint16().Draw(t, "maxAcceptedHTLCs") + msg.MaxAcceptedHTLCs = fn.Some(mah) + } + + if includeFundingKey { + msg.FundingKey = fn.Some(*RandPubKey(t)) + } + + if includeChannelType { + msg.ChannelType = fn.Some(*RandChannelType(t)) + } + + if includeKickoffFeerate { + kf := chainfee.SatPerKWeight(rapid.Uint32().Draw( + t, "kickoffFeerate"), + ) + msg.KickoffFeerate = fn.Some(kf) + } + + return msg +} + +// A compile time check to ensure DynReject implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*DynReject)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (dr *DynReject) RandTestMessage(t *rapid.T) Message { + featureVec := NewRawFeatureVector() + + numFeatures := rapid.IntRange(0, 8).Draw(t, "numRejections") + for i := 0; i < numFeatures; i++ { + bit := FeatureBit( + rapid.IntRange(0, 31).Draw( + t, fmt.Sprintf("rejectionBit-%d", i), + ), + ) + featureVec.Set(bit) + } + + var extraData ExtraOpaqueData + randData := RandExtraOpaqueData(t, nil) + if len(randData) > 0 { + extraData = randData + } + + return &DynReject{ + ChanID: RandChannelID(t), + UpdateRejections: *featureVec, + ExtraData: extraData, + } +} + +// A compile time check to ensure FundingCreated implements the TestMessage +// interface. +var _ TestMessage = (*FundingCreated)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (f *FundingCreated) RandTestMessage(t *rapid.T) Message { + var pendingChanID [32]byte + pendingChanIDBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "pendingChanID", + ) + copy(pendingChanID[:], pendingChanIDBytes) + + includePartialSig := rapid.Bool().Draw(t, "includePartialSig") + var partialSig OptPartialSigWithNonceTLV + var commitSig Sig + + if includePartialSig { + sigWithNonce := RandPartialSigWithNonce(t) + partialSig = MaybePartialSigWithNonce(sigWithNonce) + + // When using partial sig, CommitSig should be empty/blank. + commitSig = Sig{} + } else { + commitSig = RandSignature(t) + } + + return &FundingCreated{ + PendingChannelID: pendingChanID, + FundingPoint: RandOutPoint(t), + CommitSig: commitSig, + PartialSig: partialSig, + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure FundingSigned implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*FundingSigned)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (f *FundingSigned) RandTestMessage(t *rapid.T) Message { + usePartialSig := rapid.Bool().Draw(t, "usePartialSig") + + msg := &FundingSigned{ + ChanID: RandChannelID(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + if usePartialSig { + sigWithNonce := RandPartialSigWithNonce(t) + msg.PartialSig = MaybePartialSigWithNonce(sigWithNonce) + + msg.CommitSig = Sig{} + } else { + msg.CommitSig = RandSignature(t) + } + + return msg +} + +// A compile time check to ensure GossipTimestampRange implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*GossipTimestampRange)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (g *GossipTimestampRange) RandTestMessage(t *rapid.T) Message { + var chainHash chainhash.Hash + hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash") + copy(chainHash[:], hashBytes) + + msg := &GossipTimestampRange{ + ChainHash: chainHash, + FirstTimestamp: rapid.Uint32().Draw(t, "firstTimestamp"), + TimestampRange: rapid.Uint32().Draw(t, "timestampRange"), + ExtraData: RandExtraOpaqueData(t, nil), + } + + includeFirstBlockHeight := rapid.Bool().Draw( + t, "includeFirstBlockHeight", + ) + includeBlockRange := rapid.Bool().Draw(t, "includeBlockRange") + + if includeFirstBlockHeight { + height := rapid.Uint32().Draw(t, "firstBlockHeight") + msg.FirstBlockHeight = tlv.SomeRecordT( + tlv.RecordT[tlv.TlvType2, uint32]{Val: height}, + ) + } + + if includeBlockRange { + blockRange := rapid.Uint32().Draw(t, "blockRange") + msg.BlockRange = tlv.SomeRecordT( + tlv.RecordT[tlv.TlvType4, uint32]{Val: blockRange}, + ) + } + + return msg +} + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (msg *Init) RandTestMessage(t *rapid.T) Message { + global := NewRawFeatureVector() + local := NewRawFeatureVector() + + numGlobalFeatures := rapid.IntRange(0, 20).Draw(t, "numGlobalFeatures") + for i := 0; i < numGlobalFeatures; i++ { + bit := FeatureBit( + rapid.IntRange(0, 100).Draw( + t, fmt.Sprintf("globalFeatureBit%d", i), + ), + ) + global.Set(bit) + } + + numLocalFeatures := rapid.IntRange(0, 20).Draw(t, "numLocalFeatures") + for i := 0; i < numLocalFeatures; i++ { + bit := FeatureBit( + rapid.IntRange(0, 100).Draw( + t, fmt.Sprintf("localFeatureBit%d", i), + ), + ) + local.Set(bit) + } + + return NewInitMessage(global, local) +} + +// A compile time check to ensure KickoffSig implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*KickoffSig)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (ks *KickoffSig) RandTestMessage(t *rapid.T) Message { + return &KickoffSig{ + ChanID: RandChannelID(t), + Signature: RandSignature(t), + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure NodeAnnouncement implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*NodeAnnouncement)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *NodeAnnouncement) RandTestMessage(t *rapid.T) Message { + // Generate random compressed public key for node ID + pubKey := RandPubKey(t) + var nodeID [33]byte + copy(nodeID[:], pubKey.SerializeCompressed()) + + // Generate random RGB color + rgbColor := color.RGBA{ + R: uint8(rapid.IntRange(0, 255).Draw(t, "rgbR")), + G: uint8(rapid.IntRange(0, 255).Draw(t, "rgbG")), + B: uint8(rapid.IntRange(0, 255).Draw(t, "rgbB")), + } + + return &NodeAnnouncement{ + Signature: RandSignature(t), + Features: RandFeatureVector(t), + Timestamp: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw( + t, "timestamp"), + ), + NodeID: nodeID, + RGBColor: rgbColor, + Alias: RandNodeAlias(t), + Addresses: RandNetAddrs(t), + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure OpenChannel implements the TestMessage +// interface. +var _ TestMessage = (*OpenChannel)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (o *OpenChannel) RandTestMessage(t *rapid.T) Message { + chainHash := RandChainHash(t) + var hash chainhash.Hash + copy(hash[:], chainHash[:]) + + var pendingChanID [32]byte + pendingChanIDBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "pendingChanID", + ) + copy(pendingChanID[:], pendingChanIDBytes) + + includeChannelType := rapid.Bool().Draw(t, "includeChannelType") + includeLeaseExpiry := rapid.Bool().Draw(t, "includeLeaseExpiry") + includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce") + + var channelFlags FundingFlag + if rapid.Bool().Draw(t, "announceChannel") { + channelFlags |= FFAnnounceChannel + } + + var localNonce OptMusig2NonceTLV + if includeLocalNonce { + nonce := RandMusig2Nonce(t) + localNonce = tlv.SomeRecordT( + tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce), + ) + } + + var channelType *ChannelType + if includeChannelType { + channelType = RandChannelType(t) + } + + var leaseExpiry *LeaseExpiry + if includeLeaseExpiry { + leaseExpiry = RandLeaseExpiry(t) + } + + return &OpenChannel{ + ChainHash: hash, + PendingChannelID: pendingChanID, + FundingAmount: btcutil.Amount( + rapid.IntRange(5000, 10000000).Draw(t, "fundingAmount"), + ), + PushAmount: MilliSatoshi( + rapid.IntRange(0, 1000000).Draw(t, "pushAmount"), + ), + DustLimit: btcutil.Amount( + rapid.IntRange(100, 1000).Draw(t, "dustLimit"), + ), + MaxValueInFlight: MilliSatoshi( + rapid.IntRange(10000, 1000000).Draw( + t, "maxValueInFlight", + ), + ), + ChannelReserve: btcutil.Amount( + rapid.IntRange(1000, 10000).Draw(t, "channelReserve"), + ), + HtlcMinimum: MilliSatoshi( + rapid.IntRange(1, 1000).Draw(t, "htlcMinimum"), + ), + FeePerKiloWeight: uint32( + rapid.IntRange(250, 10000).Draw(t, "feePerKw"), + ), + CsvDelay: uint16( + rapid.IntRange(144, 1000).Draw(t, "csvDelay"), + ), + MaxAcceptedHTLCs: uint16( + rapid.IntRange(10, 500).Draw(t, "maxAcceptedHTLCs"), + ), + FundingKey: RandPubKey(t), + RevocationPoint: RandPubKey(t), + PaymentPoint: RandPubKey(t), + DelayedPaymentPoint: RandPubKey(t), + HtlcPoint: RandPubKey(t), + FirstCommitmentPoint: RandPubKey(t), + ChannelFlags: channelFlags, + UpfrontShutdownScript: RandDeliveryAddress(t), + ChannelType: channelType, + LeaseExpiry: leaseExpiry, + LocalNonce: localNonce, + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure Ping implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Ping)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (p *Ping) RandTestMessage(t *rapid.T) Message { + numPongBytes := uint16(rapid.IntRange(0, int(MaxPongBytes)).Draw( + t, "numPongBytes"), + ) + + // Generate padding bytes (but keeping within allowed message size) + // MaxMsgBody - 2 (for NumPongBytes) - 2 (for padding length) + maxPaddingLen := MaxMsgBody - 4 + paddingLen := rapid.IntRange(0, maxPaddingLen).Draw( + t, "paddingLen", + ) + padding := make(PingPayload, paddingLen) + + // Fill padding with random bytes + for i := 0; i < paddingLen; i++ { + padding[i] = byte(rapid.IntRange(0, 255).Draw( + t, fmt.Sprintf("paddingByte%d", i)), + ) + } + + return &Ping{ + NumPongBytes: numPongBytes, + PaddingBytes: padding, + } +} + +// A compile time check to ensure Pong implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Pong)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (p *Pong) RandTestMessage(t *rapid.T) Message { + payloadLen := rapid.IntRange(0, 1000).Draw(t, "pongPayloadLength") + payload := rapid.SliceOfN(rapid.Byte(), payloadLen, payloadLen).Draw( + t, "pongPayload", + ) + + return &Pong{ + PongBytes: payload, + } +} + +// A compile time check to ensure QueryChannelRange implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*QueryChannelRange)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (q *QueryChannelRange) RandTestMessage(t *rapid.T) Message { + msg := &QueryChannelRange{ + FirstBlockHeight: uint32(rapid.IntRange(0, 1000000).Draw( + t, "firstBlockHeight"), + ), + NumBlocks: uint32(rapid.IntRange(1, 10000).Draw( + t, "numBlocks"), + ), + ExtraData: RandExtraOpaqueData(t, nil), + } + + // Generate chain hash + chainHash := RandChainHash(t) + var chainHashObj chainhash.Hash + copy(chainHashObj[:], chainHash[:]) + msg.ChainHash = chainHashObj + + // Randomly include QueryOptions + if rapid.Bool().Draw(t, "includeQueryOptions") { + queryOptions := &QueryOptions{} + *queryOptions = QueryOptions(*RandFeatureVector(t)) + msg.QueryOptions = queryOptions + } + + return msg +} + +// A compile time check to ensure QueryShortChanIDs implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*QueryShortChanIDs)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (q *QueryShortChanIDs) RandTestMessage(t *rapid.T) Message { + var chainHash chainhash.Hash + hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash") + copy(chainHash[:], hashBytes) + + encodingType := EncodingSortedPlain + if rapid.Bool().Draw(t, "useZlibEncoding") { + encodingType = EncodingSortedZlib + } + + msg := &QueryShortChanIDs{ + ChainHash: chainHash, + EncodingType: encodingType, + ExtraData: RandExtraOpaqueData(t, nil), + noSort: false, + } + + numIDs := rapid.IntRange(2, 20).Draw(t, "numShortChanIDs") + + // Generate sorted short channel IDs. + shortChanIDs := make([]ShortChannelID, numIDs) + for i := 0; i < numIDs; i++ { + shortChanIDs[i] = RandShortChannelID(t) + + // Ensure they're properly sorted. + if i > 0 && shortChanIDs[i].ToUint64() <= + shortChanIDs[i-1].ToUint64() { + + // Ensure this ID is larger than the previous one. + shortChanIDs[i] = NewShortChanIDFromInt( + shortChanIDs[i-1].ToUint64() + 1, + ) + } + } + + msg.ShortChanIDs = shortChanIDs + + return msg +} + +// A compile time check to ensure ReplyChannelRange implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ReplyChannelRange)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ReplyChannelRange) RandTestMessage(t *rapid.T) Message { + msg := &ReplyChannelRange{ + FirstBlockHeight: uint32(rapid.IntRange(0, 1000000).Draw( + t, "firstBlockHeight"), + ), + NumBlocks: uint32(rapid.IntRange(1, 10000).Draw( + t, "numBlocks"), + ), + Complete: uint8(rapid.IntRange(0, 1).Draw(t, "complete")), + EncodingType: QueryEncoding( + rapid.IntRange(0, 1).Draw(t, "encodingType"), + ), + ExtraData: RandExtraOpaqueData(t, nil), + } + + msg.ChainHash = RandChainHash(t) + + numShortChanIDs := rapid.IntRange(0, 20).Draw(t, "numShortChanIDs") + if numShortChanIDs == 0 { + return msg + } + + scidSet := fn.NewSet[ShortChannelID]() + scids := make([]ShortChannelID, numShortChanIDs) + for i := 0; i < numShortChanIDs; i++ { + scid := RandShortChannelID(t) + for scidSet.Contains(scid) { + scid = RandShortChannelID(t) + } + + scids[i] = scid + + scidSet.Add(scid) + } + + // Make sure there're no duplicates. + msg.ShortChanIDs = scids + + if rapid.Bool().Draw(t, "includeTimestamps") && numShortChanIDs > 0 { + msg.Timestamps = make(Timestamps, numShortChanIDs) + for i := 0; i < numShortChanIDs; i++ { + msg.Timestamps[i] = ChanUpdateTimestamps{ + Timestamp1: uint32(rapid.IntRange(0, math.MaxInt32).Draw(t, fmt.Sprintf("timestamp-1-%d", i))), //nolint:ll + Timestamp2: uint32(rapid.IntRange(0, math.MaxInt32).Draw(t, fmt.Sprintf("timestamp-2-%d", i))), //nolint:ll + } + } + } + + return msg +} + +// A compile time check to ensure ReplyShortChanIDsEnd implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ReplyShortChanIDsEnd)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ReplyShortChanIDsEnd) RandTestMessage(t *rapid.T) Message { + var chainHash chainhash.Hash + hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash") + copy(chainHash[:], hashBytes) + + complete := uint8(rapid.IntRange(0, 1).Draw(t, "complete")) + + return &ReplyShortChanIDsEnd{ + ChainHash: chainHash, + Complete: complete, + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// RandTestMessage returns a RevokeAndAck message populated with random data. +// +// This is part of the TestMessage interface. +func (c *RevokeAndAck) RandTestMessage(t *rapid.T) Message { + msg := NewRevokeAndAck() + + var chanID ChannelID + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "channelID") + copy(chanID[:], bytes) + msg.ChanID = chanID + + revBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "revocation") + copy(msg.Revocation[:], revBytes) + + msg.NextRevocationKey = RandPubKey(t) + + if rapid.Bool().Draw(t, "includeLocalNonce") { + var nonce Musig2Nonce + nonceBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "nonce", + ) + copy(nonce[:], nonceBytes) + + msg.LocalNonce = tlv.SomeRecordT( + tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce), + ) + } + + return msg +} + +// A compile-time check to ensure Shutdown implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Shutdown)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (s *Shutdown) RandTestMessage(t *rapid.T) Message { + // Generate random delivery address + // First decide the address type (P2PKH, P2SH, P2WPKH, P2WSH, P2TR) + addrType := rapid.IntRange(0, 4).Draw(t, "addrType") + + // Generate random address length based on type + var addrLen int + switch addrType { + // P2PKH + case 0: + addrLen = 25 + // P2SH + case 1: + addrLen = 23 + // P2WPKH + case 2: + addrLen = 22 + // P2WSH + case 3: + addrLen = 34 + // P2TR + case 4: + addrLen = 34 + } + + addr := rapid.SliceOfN(rapid.Byte(), addrLen, addrLen).Draw( + t, "address", + ) + + // Randomly decide whether to include a shutdown nonce + includeNonce := rapid.Bool().Draw(t, "includeNonce") + var shutdownNonce ShutdownNonceTLV + + if includeNonce { + shutdownNonce = SomeShutdownNonce(RandMusig2Nonce(t)) + } + + cr, _ := RandCustomRecords(t, nil, true) + + return &Shutdown{ + ChannelID: RandChannelID(t), + Address: addr, + ShutdownNonce: shutdownNonce, + CustomRecords: cr, + } +} + +// A compile time check to ensure Stfu implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Stfu)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (s *Stfu) RandTestMessage(t *rapid.T) Message { + m := &Stfu{ + ChanID: RandChannelID(t), + Initiator: rapid.Bool().Draw(t, "initiator"), + } + + extraData := RandExtraOpaqueData(t, nil) + if len(extraData) > 0 { + m.ExtraData = extraData + } + + return m +} + +// A compile time check to ensure UpdateAddHTLC implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*UpdateAddHTLC)(nil) + +// RandTestMessage returns an UpdateAddHTLC message populated with random data. +// +// This is part of the TestMessage interface. +func (c *UpdateAddHTLC) RandTestMessage(t *rapid.T) Message { + msg := &UpdateAddHTLC{ + ChanID: RandChannelID(t), + ID: rapid.Uint64().Draw(t, "id"), + Amount: MilliSatoshi(rapid.Uint64().Draw(t, "amount")), + Expiry: rapid.Uint32().Draw(t, "expiry"), + } + + hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "paymentHash") + copy(msg.PaymentHash[:], hashBytes) + + onionBytes := rapid.SliceOfN( + rapid.Byte(), OnionPacketSize, OnionPacketSize, + ).Draw(t, "onionBlob") + copy(msg.OnionBlob[:], onionBytes) + + numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords") + if numRecords > 0 { + msg.CustomRecords, _ = RandCustomRecords(t, nil, true) + } + + // 50/50 chance to add a blinding point + if rapid.Bool().Draw(t, "includeBlindingPoint") { + pubKey := RandPubKey(t) + + msg.BlindingPoint = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[BlindingPointTlvType](pubKey), + ) + } + + return msg +} + +// A compile time check to ensure UpdateFailHTLC implements the TestMessage +// interface. +var _ TestMessage = (*UpdateFailHTLC)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *UpdateFailHTLC) RandTestMessage(t *rapid.T) Message { + return &UpdateFailHTLC{ + ChanID: RandChannelID(t), + ID: rapid.Uint64().Draw(t, "id"), + Reason: RandOpaqueReason(t), + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure UpdateFailMalformedHTLC implements the +// TestMessage interface. +var _ TestMessage = (*UpdateFailMalformedHTLC)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *UpdateFailMalformedHTLC) RandTestMessage(t *rapid.T) Message { + return &UpdateFailMalformedHTLC{ + ChanID: RandChannelID(t), + ID: rapid.Uint64().Draw(t, "id"), + ShaOnionBlob: RandSHA256Hash(t), + FailureCode: RandFailCode(t), + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure UpdateFee implements the TestMessage +// interface. +var _ TestMessage = (*UpdateFee)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *UpdateFee) RandTestMessage(t *rapid.T) Message { + return &UpdateFee{ + ChanID: RandChannelID(t), + FeePerKw: uint32(rapid.IntRange(1, 10000).Draw(t, "feePerKw")), + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure UpdateFulfillHTLC implements the TestMessage +// interface. +var _ TestMessage = (*UpdateFulfillHTLC)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *UpdateFulfillHTLC) RandTestMessage(t *rapid.T) Message { + msg := &UpdateFulfillHTLC{ + ChanID: RandChannelID(t), + ID: rapid.Uint64().Draw(t, "id"), + PaymentPreimage: RandPaymentPreimage(t), + } + + cr, ignoreRecords := RandCustomRecords(t, nil, true) + msg.CustomRecords = cr + + randData := RandExtraOpaqueData(t, ignoreRecords) + if len(randData) > 0 { + msg.ExtraData = randData + } + + return msg +} + +// A compile time check to ensure Warning implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Warning)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *Warning) RandTestMessage(t *rapid.T) Message { + msg := &Warning{ + ChanID: RandChannelID(t), + } + + useASCII := rapid.Bool().Draw(t, "useASCII") + if useASCII { + length := rapid.IntRange(1, 100).Draw(t, "warningDataLength") + data := make([]byte, length) + for i := 0; i < length; i++ { + data[i] = byte( + rapid.IntRange(32, 126).Draw( + t, fmt.Sprintf("warningDataByte-%d", i), + ), + ) + } + msg.Data = data + } else { + length := rapid.IntRange(1, 100).Draw(t, "warningDataLength") + msg.Data = rapid.SliceOfN(rapid.Byte(), length, length).Draw( + t, "warningData", + ) + } + + return msg +} + +// A compile time check to ensure Error implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Error)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *Error) RandTestMessage(t *rapid.T) Message { + msg := &Error{ + ChanID: RandChannelID(t), + } + + useASCII := rapid.Bool().Draw(t, "useASCII") + if useASCII { + length := rapid.IntRange(1, 100).Draw(t, "errorDataLength") + data := make([]byte, length) + for i := 0; i < length; i++ { + data[i] = byte( + rapid.IntRange(32, 126).Draw( + t, fmt.Sprintf("errorDataByte-%d", i), + ), + ) + } + msg.Data = data + } else { + // Generate random binary data + length := rapid.IntRange(1, 100).Draw(t, "errorDataLength") + msg.Data = rapid.SliceOfN( + rapid.Byte(), length, length, + ).Draw(t, "errorData") + } + + return msg +} diff --git a/lnwire/test_utils.go b/lnwire/test_utils.go new file mode 100644 index 000000000..1065cbacf --- /dev/null +++ b/lnwire/test_utils.go @@ -0,0 +1,360 @@ +package lnwire + +import ( + "crypto/sha256" + "fmt" + "net" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/stretchr/testify/require" + "pgregory.net/rapid" +) + +// RandChannelUpdate generates a random ChannelUpdate message using rapid's +// generators. +func RandPartialSig(t *rapid.T) *PartialSig { + // Generate random private key bytes + sigBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "privKeyBytes") + + var s btcec.ModNScalar + s.SetByteSlice(sigBytes) + + return &PartialSig{ + Sig: s, + } +} + +// RandPartialSigWithNonce generates a random PartialSigWithNonce using rapid +// generators. +func RandPartialSigWithNonce(t *rapid.T) *PartialSigWithNonce { + sigLen := rapid.IntRange(1, 65).Draw(t, "partialSigLen") + sigBytes := rapid.SliceOfN( + rapid.Byte(), sigLen, sigLen, + ).Draw(t, "partialSig") + + sigScalar := new(btcec.ModNScalar) + sigScalar.SetByteSlice(sigBytes) + + return NewPartialSigWithNonce( + RandMusig2Nonce(t), *sigScalar, + ) +} + +// RandPubKey generates a random public key using rapid's generators. +func RandPubKey(t *rapid.T) *btcec.PublicKey { + privKeyBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "privKeyBytes", + ) + _, pub := btcec.PrivKeyFromBytes(privKeyBytes) + + return pub +} + +// RandChannelID generates a random channel ID. +func RandChannelID(t *rapid.T) ChannelID { + var c ChannelID + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "channelID") + copy(c[:], bytes) + + return c +} + +// RandShortChannelID generates a random short channel ID. +func RandShortChannelID(t *rapid.T) ShortChannelID { + return NewShortChanIDFromInt( + uint64(rapid.IntRange(1, 100000).Draw(t, "shortChanID")), + ) +} + +// RandFeatureVector generates a random feature vector. +func RandFeatureVector(t *rapid.T) *RawFeatureVector { + featureVec := NewRawFeatureVector() + + // Add a random number of random feature bits + numFeatures := rapid.IntRange(0, 20).Draw(t, "numFeatures") + for i := 0; i < numFeatures; i++ { + bit := FeatureBit(rapid.IntRange(0, 100).Draw( + t, fmt.Sprintf("featureBit-%d", i)), + ) + featureVec.Set(bit) + } + + return featureVec +} + +// RandSignature generates a signature for testing. +func RandSignature(t *rapid.T) Sig { + testRScalar := new(btcec.ModNScalar) + testSScalar := new(btcec.ModNScalar) + + // Generate random bytes for R and S + rBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "rBytes") + sBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "sBytes") + _ = testRScalar.SetByteSlice(rBytes) + _ = testSScalar.SetByteSlice(sBytes) + + testSig := ecdsa.NewSignature(testRScalar, testSScalar) + + sig, err := NewSigFromSignature(testSig) + if err != nil { + panic(fmt.Sprintf("unable to create signature: %v", err)) + } + + return sig +} + +// RandPaymentHash generates a random payment hash. +func RandPaymentHash(t *rapid.T) [32]byte { + var hash [32]byte + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "paymentHash") + copy(hash[:], bytes) + + return hash +} + +// RandPaymentPreimage generates a random payment preimage. +func RandPaymentPreimage(t *rapid.T) [32]byte { + var preimage [32]byte + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "preimage") + copy(preimage[:], bytes) + + return preimage +} + +// RandChainHash generates a random chain hash. +func RandChainHash(t *rapid.T) chainhash.Hash { + var hash [32]byte + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash") + copy(hash[:], bytes) + + return hash +} + +// RandNodeAlias generates a random node alias. +func RandNodeAlias(t *rapid.T) NodeAlias { + var alias NodeAlias + aliasLength := rapid.IntRange(0, 32).Draw(t, "aliasLength") + + aliasBytes := rapid.StringN( + 0, aliasLength, aliasLength, + ).Draw(t, "alias") + + copy(alias[:], aliasBytes) + + return alias +} + +// RandNetAddrs generates random network addresses. +func RandNetAddrs(t *rapid.T) []net.Addr { + numAddresses := rapid.IntRange(0, 5).Draw(t, "numAddresses") + if numAddresses == 0 { + return nil + } + + addresses := make([]net.Addr, numAddresses) + for i := 0; i < numAddresses; i++ { + addressType := rapid.IntRange(0, 1).Draw( + t, fmt.Sprintf("addressType-%d", i), + ) + + switch addressType { + // IPv4. + case 0: + ipBytes := rapid.SliceOfN(rapid.Byte(), 4, 4).Draw( + t, fmt.Sprintf("ipv4-%d", i), + ) + port := rapid.IntRange(1, 65535).Draw( + t, fmt.Sprintf("port-%d", i), + ) + addresses[i] = &net.TCPAddr{ + IP: ipBytes, + Port: port, + } + + // IPv6. + case 1: + ipBytes := rapid.SliceOfN(rapid.Byte(), 16, 16).Draw( + t, fmt.Sprintf("ipv6-%d", i), + ) + port := rapid.IntRange(1, 65535).Draw( + t, fmt.Sprintf("port-%d", i), + ) + addresses[i] = &net.TCPAddr{ + IP: ipBytes, + Port: port, + } + } + } + + return addresses +} + +// RandCustomRecords generates random custom TLV records. +func RandCustomRecords(t *rapid.T, + ignoreRecords fn.Set[uint64], + custom bool) (CustomRecords, fn.Set[uint64]) { + + numRecords := rapid.IntRange(0, 5).Draw(t, "numCustomRecords") + customRecords := make(CustomRecords) + + if numRecords == 0 { + return nil, nil + } + + rangeStart := 0 + rangeStop := int(CustomTypeStart) + if custom { + rangeStart = 70_000 + rangeStop = 100_000 + } + + ignoreSet := fn.NewSet[uint64]() + for i := 0; i < numRecords; i++ { + recordType := uint64( + rapid.IntRange(rangeStart, rangeStop). + Filter(func(i int) bool { + return !ignoreRecords.Contains( + uint64(i), + ) + }). + Draw( + t, fmt.Sprintf("recordType-%d", i), + ), + ) + recordLen := rapid.IntRange(4, 64).Draw( + t, fmt.Sprintf("recordLen-%d", i), + ) + record := rapid.SliceOfN( + rapid.Byte(), recordLen, recordLen, + ).Draw(t, fmt.Sprintf("record-%d", i)) + + customRecords[recordType] = record + + ignoreSet.Add(recordType) + } + + return customRecords, ignoreSet +} + +// RandMusig2Nonce generates a random musig2 nonce. +func RandMusig2Nonce(t *rapid.T) Musig2Nonce { + var nonce Musig2Nonce + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "nonce") + copy(nonce[:], bytes) + + return nonce +} + +// RandExtraOpaqueData generates random extra opaque data. +func RandExtraOpaqueData(t *rapid.T, + ignoreRecords fn.Set[uint64]) ExtraOpaqueData { + + // Make some random records. + cRecords, _ := RandCustomRecords(t, ignoreRecords, false) + if cRecords == nil { + return ExtraOpaqueData{} + } + + // Encode those records as opaque data. + recordBytes, err := cRecords.Serialize() + require.NoError(t, err) + + return ExtraOpaqueData(recordBytes) +} + +// RandOpaqueReason generates a random opaque reason for HTLC failures. +func RandOpaqueReason(t *rapid.T) OpaqueReason { + reasonLen := rapid.IntRange(32, 300).Draw(t, "reasonLen") + return rapid.SliceOfN(rapid.Byte(), reasonLen, reasonLen).Draw( + t, "opaqueReason", + ) +} + +// RandFailCode generates a random HTLC failure code. +func RandFailCode(t *rapid.T) FailCode { + // List of known failure codes to choose from Using only the documented + // codes. + validCodes := []FailCode{ + CodeInvalidRealm, + CodeTemporaryNodeFailure, + CodePermanentNodeFailure, + CodeRequiredNodeFeatureMissing, + CodePermanentChannelFailure, + CodeRequiredChannelFeatureMissing, + CodeUnknownNextPeer, + CodeIncorrectOrUnknownPaymentDetails, + CodeIncorrectPaymentAmount, + CodeFinalExpiryTooSoon, + CodeInvalidOnionVersion, + CodeInvalidOnionHmac, + CodeInvalidOnionKey, + CodeTemporaryChannelFailure, + CodeChannelDisabled, + CodeExpiryTooSoon, + CodeMPPTimeout, + CodeInvalidOnionPayload, + CodeFeeInsufficient, + } + + // Choose a random code from the list. + idx := rapid.IntRange(0, len(validCodes)-1).Draw(t, "failCodeIndex") + + return validCodes[idx] +} + +// RandSHA256Hash generates a random SHA256 hash. +func RandSHA256Hash(t *rapid.T) [sha256.Size]byte { + var hash [sha256.Size]byte + bytes := rapid.SliceOfN(rapid.Byte(), sha256.Size, sha256.Size).Draw( + t, "sha256Hash", + ) + copy(hash[:], bytes) + + return hash +} + +// RandDeliveryAddress generates a random delivery address (script). +func RandDeliveryAddress(t *rapid.T) DeliveryAddress { + addrLen := rapid.IntRange(1, 34).Draw(t, "addrLen") + + return rapid.SliceOfN(rapid.Byte(), addrLen, addrLen).Draw( + t, "deliveryAddress", + ) +} + +// RandChannelType generates a random channel type. +func RandChannelType(t *rapid.T) *ChannelType { + vec := RandFeatureVector(t) + chanType := ChannelType(*vec) + + return &chanType +} + +// RandLeaseExpiry generates a random lease expiry. +func RandLeaseExpiry(t *rapid.T) *LeaseExpiry { + exp := LeaseExpiry( + uint32(rapid.IntRange(1000, 1000000).Draw(t, "leaseExpiry")), + ) + + return &exp +} + +// RandOutPoint generates a random transaction outpoint. +func RandOutPoint(t *rapid.T) wire.OutPoint { + // Generate a random transaction ID + var txid chainhash.Hash + txidBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "txid") + copy(txid[:], txidBytes) + + // Generate a random output index + vout := uint32(rapid.IntRange(0, 10).Draw(t, "vout")) + + return wire.OutPoint{ + Hash: txid, + Index: vout, + } +} diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 4873cd84b..e627dbf4e 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -212,3 +212,14 @@ func (c *UpdateAddHTLC) MsgType() MessageType { func (c *UpdateAddHTLC) TargetChanID() ChannelID { return c.ChanID } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateAddHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + +// A compile time check to ensure UpdateAddHTLC implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateAddHTLC)(nil) diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 8cd9c7687..1d26444ba 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -38,6 +38,10 @@ type UpdateFailHTLC struct { // interface. var _ Message = (*UpdateFailHTLC)(nil) +// A compile time check to ensure UpdateFailHTLC implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateFailHTLC)(nil) + // Decode deserializes a serialized UpdateFailHTLC message stored in the passed // io.Reader observing the specified protocol version. // @@ -51,8 +55,8 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { ) } -// Encode serializes the target UpdateFailHTLC into the passed io.Writer observing -// the protocol version specified. +// Encode serializes the target UpdateFailHTLC into the passed io.Writer +// observing the protocol version specified. // // This is part of the lnwire.Message interface. func (c *UpdateFailHTLC) Encode(w *bytes.Buffer, pver uint32) error { @@ -79,6 +83,13 @@ func (c *UpdateFailHTLC) MsgType() MessageType { return MsgUpdateFailHTLC } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFailHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go index f28107a9a..bc9bd9aba 100644 --- a/lnwire/update_fail_malformed_htlc.go +++ b/lnwire/update_fail_malformed_htlc.go @@ -36,6 +36,10 @@ type UpdateFailMalformedHTLC struct { // lnwire.Message interface. var _ Message = (*UpdateFailMalformedHTLC)(nil) +// A compile time check to ensure UpdateFailMalformedHTLC implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateFailMalformedHTLC)(nil) + // Decode deserializes a serialized UpdateFailMalformedHTLC message stored in the passed // io.Reader observing the specified protocol version. // @@ -84,6 +88,13 @@ func (c *UpdateFailMalformedHTLC) MsgType() MessageType { return MsgUpdateFailMalformedHTLC } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFailMalformedHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/update_fee.go b/lnwire/update_fee.go index a7026044c..f32a7385a 100644 --- a/lnwire/update_fee.go +++ b/lnwire/update_fee.go @@ -36,6 +36,10 @@ func NewUpdateFee(chanID ChannelID, feePerKw uint32) *UpdateFee { // interface. var _ Message = (*UpdateFee)(nil) +// A compile time check to ensure UpdateFee implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateFee)(nil) + // Decode deserializes a serialized UpdateFee message stored in the passed // io.Reader observing the specified protocol version. // @@ -72,6 +76,13 @@ func (c *UpdateFee) MsgType() MessageType { return MsgUpdateFee } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFee) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go index 35aaa2ff5..9ec747406 100644 --- a/lnwire/update_fulfill_htlc.go +++ b/lnwire/update_fulfill_htlc.go @@ -44,12 +44,16 @@ func NewUpdateFulfillHTLC(chanID ChannelID, id uint64, } } -// A compile time check to ensure UpdateFulfillHTLC implements the lnwire.Message -// interface. +// A compile time check to ensure UpdateFulfillHTLC implements the +// lnwire.Message interface. var _ Message = (*UpdateFulfillHTLC)(nil) -// Decode deserializes a serialized UpdateFulfillHTLC message stored in the passed -// io.Reader observing the specified protocol version. +// A compile time check to ensure UpdateFulfillHTLC implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateFulfillHTLC)(nil) + +// Decode deserializes a serialized UpdateFulfillHTLC message stored in the +// passed io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { @@ -115,6 +119,13 @@ func (c *UpdateFulfillHTLC) MsgType() MessageType { return MsgUpdateFulfillHTLC } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFulfillHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/warning.go b/lnwire/warning.go index adb595cc0..2a0b3df19 100644 --- a/lnwire/warning.go +++ b/lnwire/warning.go @@ -29,6 +29,10 @@ type Warning struct { // interface. var _ Message = (*Warning)(nil) +// A compile time check to ensure Warning implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Warning)(nil) + // NewWarning creates a new Warning message. func NewWarning() *Warning { return &Warning{} @@ -74,3 +78,10 @@ func (c *Warning) Encode(w *bytes.Buffer, _ uint32) error { func (c *Warning) MsgType() MessageType { return MsgWarning } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *Warning) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +}