From 56a100123b543707cf40af1c3f35cb0cc17626c2 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 18 Mar 2025 18:55:16 -0500 Subject: [PATCH 1/4] lnwire: add new SerializedSize method to all wire messages This'll be useful for the bandwidth based rate limiting we'll implement in the next commit. --- lnwire/accept_channel.go | 11 +++++++++ lnwire/announcement_signatures.go | 11 +++++++++ lnwire/announcement_signatures_2.go | 11 +++++++++ lnwire/channel_announcement.go | 11 +++++++++ lnwire/channel_announcement_2.go | 11 +++++++++ lnwire/channel_ready.go | 11 +++++++++ lnwire/channel_reestablish.go | 11 +++++++++ lnwire/channel_update.go | 10 ++++++++ lnwire/channel_update_2.go | 7 ++++++ lnwire/closing_complete.go | 10 ++++++++ lnwire/closing_sig.go | 11 +++++++++ lnwire/closing_signed.go | 10 ++++++++ lnwire/commit_sig.go | 10 ++++++++ lnwire/custom.go | 11 +++++++++ lnwire/dyn_ack.go | 11 +++++++++ lnwire/dyn_propose.go | 11 +++++++++ lnwire/dyn_reject.go | 11 +++++++++ lnwire/error.go | 11 +++++++++ lnwire/funding_created.go | 11 +++++++++ lnwire/funding_signed.go | 11 +++++++++ lnwire/gossip_timestamp_range.go | 10 ++++++++ lnwire/init_message.go | 11 +++++++++ lnwire/message.go | 35 ++++++++++++++++++++++++++++ lnwire/node_announcement.go | 10 ++++++++ lnwire/open_channel.go | 11 +++++++++ lnwire/ping.go | 11 +++++++++ lnwire/pong.go | 10 ++++++++ lnwire/query_channel_range.go | 12 ++++++++++ lnwire/query_short_chan_ids.go | 11 +++++++++ lnwire/reply_channel_range.go | 10 ++++++++ lnwire/reply_short_chan_ids_end.go | 10 ++++++++ lnwire/revoke_and_ack.go | 10 ++++++++ lnwire/shutdown.go | 10 ++++++++ lnwire/stfu.go | 10 ++++++++ lnwire/update_add_htlc.go | 11 +++++++++ lnwire/update_fail_htlc.go | 11 +++++++++ lnwire/update_fail_malformed_htlc.go | 11 +++++++++ lnwire/update_fee.go | 11 +++++++++ lnwire/update_fulfill_htlc.go | 19 +++++++++++---- lnwire/warning.go | 11 +++++++++ 40 files changed, 453 insertions(+), 4 deletions(-) diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index e45bcf9ab..aace5d536 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -128,6 +128,10 @@ type AcceptChannel struct { // interface. var _ Message = (*AcceptChannel)(nil) +// A compile time check to ensure AcceptChannel implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*AcceptChannel)(nil) + // Encode serializes the target AcceptChannel into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -281,3 +285,10 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { func (a *AcceptChannel) MsgType() MessageType { return MsgAcceptChannel } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *AcceptChannel) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/announcement_signatures.go b/lnwire/announcement_signatures.go index a2bc21f39..cf8f68be5 100644 --- a/lnwire/announcement_signatures.go +++ b/lnwire/announcement_signatures.go @@ -47,6 +47,10 @@ type AnnounceSignatures1 struct { // lnwire.Message interface. var _ Message = (*AnnounceSignatures1)(nil) +// A compile time check to ensure AnnounceSignatures1 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*AnnounceSignatures1)(nil) + // A compile time check to ensure AnnounceSignatures1 implements the // lnwire.AnnounceSignatures interface. var _ AnnounceSignatures = (*AnnounceSignatures1)(nil) @@ -97,6 +101,13 @@ func (a *AnnounceSignatures1) MsgType() MessageType { return MsgAnnounceSignatures } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *AnnounceSignatures1) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} + // SCID returns the ShortChannelID of the channel. // // This is part of the lnwire.AnnounceSignatures interface. diff --git a/lnwire/announcement_signatures_2.go b/lnwire/announcement_signatures_2.go index a10447032..6e893dafd 100644 --- a/lnwire/announcement_signatures_2.go +++ b/lnwire/announcement_signatures_2.go @@ -41,6 +41,10 @@ type AnnounceSignatures2 struct { // lnwire.Message interface. var _ Message = (*AnnounceSignatures2)(nil) +// A compile time check to ensure AnnounceSignatures2 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*AnnounceSignatures2)(nil) + // Decode deserializes a serialized AnnounceSignatures2 stored in the passed // io.Reader observing the specified protocol version. // @@ -82,6 +86,13 @@ func (a *AnnounceSignatures2) MsgType() MessageType { return MsgAnnounceSignatures2 } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *AnnounceSignatures2) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} + // SCID returns the ShortChannelID of the channel. // // NOTE: this is part of the AnnounceSignatures interface. diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index ed4c5b97e..0a3989abb 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -62,6 +62,10 @@ type ChannelAnnouncement1 struct { // lnwire.Message interface. var _ Message = (*ChannelAnnouncement1)(nil) +// A compile time check to ensure ChannelAnnouncement1 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelAnnouncement1)(nil) + // Decode deserializes a serialized ChannelAnnouncement stored in the passed // io.Reader observing the specified protocol version. // @@ -143,6 +147,13 @@ func (a *ChannelAnnouncement1) MsgType() MessageType { return MsgChannelAnnouncement } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *ChannelAnnouncement1) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} + // DataToSign is used to retrieve part of the announcement message which should // be signed. func (a *ChannelAnnouncement1) DataToSign() ([]byte, error) { diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go index 074e7d084..57b3a24b8 100644 --- a/lnwire/channel_announcement_2.go +++ b/lnwire/channel_announcement_2.go @@ -194,10 +194,21 @@ func (c *ChannelAnnouncement2) MsgType() MessageType { return MsgChannelAnnouncement2 } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ChannelAnnouncement2) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // A compile time check to ensure ChannelAnnouncement2 implements the // lnwire.Message interface. var _ Message = (*ChannelAnnouncement2)(nil) +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelAnnouncement2)(nil) + // Node1KeyBytes returns the bytes representing the public key of node 1 in the // channel. // diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index 912a068bd..f388db1d1 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -63,6 +63,10 @@ func NewChannelReady(cid ChannelID, npcp *btcec.PublicKey) *ChannelReady { // interface. var _ Message = (*ChannelReady)(nil) +// A compile time check to ensure ChannelReady implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelReady)(nil) + // Decode deserializes the serialized ChannelReady message stored in the // passed io.Reader into the target ChannelReady using the deserialization // rules defined by the passed protocol version. @@ -170,3 +174,10 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { func (c *ChannelReady) MsgType() MessageType { return MsgChannelReady } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ChannelReady) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index 577379623..f26a2fc5d 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -99,6 +99,10 @@ type ChannelReestablish struct { // lnwire.Message interface. var _ Message = (*ChannelReestablish)(nil) +// A compile time check to ensure ChannelReestablish implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelReestablish)(nil) + // Encode serializes the target ChannelReestablish into the passed io.Writer // observing the protocol version specified. // @@ -234,3 +238,10 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { func (a *ChannelReestablish) MsgType() MessageType { return MsgChannelReestablish } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *ChannelReestablish) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index 33cc3bff0..1c7f0eed2 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -124,6 +124,9 @@ type ChannelUpdate1 struct { // interface. var _ Message = (*ChannelUpdate1)(nil) +// A compile time check to ensure ChannelUpdate1 implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelUpdate1)(nil) + // Decode deserializes a serialized ChannelUpdate stored in the passed // io.Reader observing the specified protocol version. // @@ -367,3 +370,10 @@ func (a *ChannelUpdate1) SetSCID(scid ShortChannelID) { // A compile time assertion to ensure ChannelUpdate1 implements the // ChannelUpdate interface. var _ ChannelUpdate = (*ChannelUpdate1)(nil) + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *ChannelUpdate1) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go index b41f2f29f..56f7edf6b 100644 --- a/lnwire/channel_update_2.go +++ b/lnwire/channel_update_2.go @@ -241,6 +241,13 @@ func (c *ChannelUpdate2) MsgType() MessageType { return MsgChannelUpdate2 } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ChannelUpdate2) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + func (c *ChannelUpdate2) ExtraData() ExtraOpaqueData { return c.ExtraOpaqueData } diff --git a/lnwire/closing_complete.go b/lnwire/closing_complete.go index 4d390fd6f..14784a3c1 100644 --- a/lnwire/closing_complete.go +++ b/lnwire/closing_complete.go @@ -169,6 +169,16 @@ func (c *ClosingComplete) MsgType() MessageType { return MsgClosingComplete } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ClosingComplete) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // A compile time check to ensure ClosingComplete implements the lnwire.Message // interface. var _ Message = (*ClosingComplete)(nil) + +// A compile time check to ensure ClosingComplete implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ClosingComplete)(nil) diff --git a/lnwire/closing_sig.go b/lnwire/closing_sig.go index 2c73fa720..1eeec2580 100644 --- a/lnwire/closing_sig.go +++ b/lnwire/closing_sig.go @@ -107,6 +107,17 @@ func (c *ClosingSig) MsgType() MessageType { return MsgClosingSig } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ClosingSig) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // A compile time check to ensure ClosingSig implements the lnwire.Message // interface. var _ Message = (*ClosingSig)(nil) + +// A compile time check to ensure ClosingSig implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*ClosingSig)(nil) diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 08b5bb6a7..a82dbf402 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -59,6 +59,9 @@ func NewClosingSigned(cid ChannelID, fs btcutil.Amount, // interface. var _ Message = (*ClosingSigned)(nil) +// A compile time check to ensure ClosingSigned implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ClosingSigned)(nil) + // Decode deserializes a serialized ClosingSigned message stored in the passed // io.Reader observing the specified protocol version. // @@ -130,3 +133,10 @@ func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error { func (c *ClosingSigned) MsgType() MessageType { return MsgClosingSigned } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ClosingSigned) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 3a475e71f..7c5a41ccc 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -64,6 +64,9 @@ func NewCommitSig() *CommitSig { // interface. var _ Message = (*CommitSig)(nil) +// A compile time check to ensure CommitSig implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*CommitSig)(nil) + // Decode deserializes a serialized CommitSig message stored in the // passed io.Reader observing the specified protocol version. // @@ -151,3 +154,10 @@ func (c *CommitSig) MsgType() MessageType { func (c *CommitSig) TargetChanID() ChannelID { return c.ChanID } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *CommitSig) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/custom.go b/lnwire/custom.go index 232a8be52..e8c299297 100644 --- a/lnwire/custom.go +++ b/lnwire/custom.go @@ -73,6 +73,10 @@ type Custom struct { // interface. var _ Message = (*Custom)(nil) +// A compile time check to ensure Custom implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Custom)(nil) + // NewCustom instantiates a new custom message. func NewCustom(msgType MessageType, data []byte) (*Custom, error) { if msgType < CustomTypeStart && !IsCustomOverride(msgType) { @@ -117,3 +121,10 @@ func (c *Custom) Decode(r io.Reader, pver uint32) error { func (c *Custom) MsgType() MessageType { return c.Type } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *Custom) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/dyn_ack.go b/lnwire/dyn_ack.go index d477461e7..1cc57e955 100644 --- a/lnwire/dyn_ack.go +++ b/lnwire/dyn_ack.go @@ -41,6 +41,10 @@ type DynAck struct { // interface. var _ Message = (*DynAck)(nil) +// A compile time check to ensure DynAck implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*DynAck)(nil) + // Encode serializes the target DynAck into the passed io.Writer. Serialization // will observe the rules defined by the passed protocol version. // @@ -136,3 +140,10 @@ func (da *DynAck) Decode(r io.Reader, _ uint32) error { func (da *DynAck) MsgType() MessageType { return MsgDynAck } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (da *DynAck) SerializedSize() (uint32, error) { + return MessageSerializedSize(da) +} diff --git a/lnwire/dyn_propose.go b/lnwire/dyn_propose.go index 394fff6f3..cc19ec394 100644 --- a/lnwire/dyn_propose.go +++ b/lnwire/dyn_propose.go @@ -105,6 +105,10 @@ type DynPropose struct { // interface. var _ Message = (*DynPropose)(nil) +// A compile time check to ensure DynPropose implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*DynPropose)(nil) + // Encode serializes the target DynPropose into the passed io.Writer. // Serialization will observe the rules defined by the passed protocol version. // @@ -317,3 +321,10 @@ func (dp *DynPropose) Decode(r io.Reader, _ uint32) error { func (dp *DynPropose) MsgType() MessageType { return MsgDynPropose } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (dp *DynPropose) SerializedSize() (uint32, error) { + return MessageSerializedSize(dp) +} diff --git a/lnwire/dyn_reject.go b/lnwire/dyn_reject.go index 2c6484424..51cb8cee2 100644 --- a/lnwire/dyn_reject.go +++ b/lnwire/dyn_reject.go @@ -30,6 +30,10 @@ type DynReject struct { // interface. var _ Message = (*DynReject)(nil) +// A compile time check to ensure DynReject implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*DynReject)(nil) + // Encode serializes the target DynReject into the passed io.Writer. // Serialization will observe the rules defined by the passed protocol version. // @@ -74,3 +78,10 @@ func (dr *DynReject) Decode(r io.Reader, _ uint32) error { func (dr *DynReject) MsgType() MessageType { return MsgDynReject } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (dr *DynReject) SerializedSize() (uint32, error) { + return MessageSerializedSize(dr) +} diff --git a/lnwire/error.go b/lnwire/error.go index 120dfc950..fcc22ef24 100644 --- a/lnwire/error.go +++ b/lnwire/error.go @@ -70,6 +70,10 @@ func NewError() *Error { // interface. var _ Message = (*Error)(nil) +// A compile time check to ensure Error implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Error)(nil) + // Error returns the string representation to Error. // // NOTE: Satisfies the error interface. @@ -113,6 +117,13 @@ func (c *Error) MsgType() MessageType { return MsgError } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *Error) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // isASCII is a helper method that checks whether all bytes in `data` would be // printable ASCII characters if interpreted as a string. func isASCII(data []byte) bool { diff --git a/lnwire/funding_created.go b/lnwire/funding_created.go index 86aa0bb40..82d0ff87c 100644 --- a/lnwire/funding_created.go +++ b/lnwire/funding_created.go @@ -44,6 +44,10 @@ type FundingCreated struct { // interface. var _ Message = (*FundingCreated)(nil) +// A compile time check to ensure FundingCreated implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*FundingCreated)(nil) + // Encode serializes the target FundingCreated into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -117,3 +121,10 @@ func (f *FundingCreated) Decode(r io.Reader, pver uint32) error { func (f *FundingCreated) MsgType() MessageType { return MsgFundingCreated } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (f *FundingCreated) SerializedSize() (uint32, error) { + return MessageSerializedSize(f) +} diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index 2dd62e177..a7f23310a 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -36,6 +36,10 @@ type FundingSigned struct { // interface. var _ Message = (*FundingSigned)(nil) +// A compile time check to ensure FundingSigned implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*FundingSigned)(nil) + // Encode serializes the target FundingSigned into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -103,3 +107,10 @@ func (f *FundingSigned) Decode(r io.Reader, pver uint32) error { func (f *FundingSigned) MsgType() MessageType { return MsgFundingSigned } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (f *FundingSigned) SerializedSize() (uint32, error) { + return MessageSerializedSize(f) +} diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go index 7b628752a..25c2a033b 100644 --- a/lnwire/gossip_timestamp_range.go +++ b/lnwire/gossip_timestamp_range.go @@ -58,6 +58,9 @@ func NewGossipTimestampRange() *GossipTimestampRange { // lnwire.Message interface. var _ Message = (*GossipTimestampRange)(nil) +// A compile time check to ensure GossipTimestampRange implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*GossipTimestampRange)(nil) + // Decode deserializes a serialized GossipTimestampRange message stored in the // passed io.Reader observing the specified protocol version. // @@ -143,3 +146,10 @@ func (g *GossipTimestampRange) Encode(w *bytes.Buffer, pver uint32) error { func (g *GossipTimestampRange) MsgType() MessageType { return MsgGossipTimestampRange } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (g *GossipTimestampRange) SerializedSize() (uint32, error) { + return MessageSerializedSize(g) +} diff --git a/lnwire/init_message.go b/lnwire/init_message.go index dbbddea2b..b88891b08 100644 --- a/lnwire/init_message.go +++ b/lnwire/init_message.go @@ -43,6 +43,10 @@ func NewInitMessage(gf *RawFeatureVector, f *RawFeatureVector) *Init { // interface. var _ Message = (*Init)(nil) +// A compile time check to ensure Init implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Init)(nil) + // Decode deserializes a serialized Init message stored in the passed // io.Reader observing the specified protocol version. // @@ -78,3 +82,10 @@ func (msg *Init) Encode(w *bytes.Buffer, pver uint32) error { func (msg *Init) MsgType() MessageType { return MsgInit } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (msg *Init) SerializedSize() (uint32, error) { + return MessageSerializedSize(msg) +} diff --git a/lnwire/message.go b/lnwire/message.go index 68b09692e..8c310bd78 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -12,6 +12,10 @@ import ( "io" ) +// MessageTypeSize is the size in bytes of the message type field in the header +// of all messages. +const MessageTypeSize = 2 + // MessageType is the unique 2 byte big-endian integer that indicates the type // of message on the wire. All messages have a very simple header which // consists simply of 2-byte message type. We omit a length field, and checksum @@ -234,6 +238,31 @@ type LinkUpdater interface { TargetChanID() ChannelID } +// SizeableMessage is an interface that extends the base Message interface with +// a method to calculate the serialized size of a message. +type SizeableMessage interface { + Message + + // SerializedSize returns the serialized size of the message in bytes. + // The returned size includes the message type header bytes. + SerializedSize() (uint32, error) +} + +// MessageSerializedSize calculates the serialized size of a message in bytes. +// This is a helper function that can be used by all message types to implement +// the SerializedSize method. +func MessageSerializedSize(msg Message) (uint32, error) { + var buf bytes.Buffer + + // Encode the message to the buffer. + if err := msg.Encode(&buf, 0); err != nil { + return 0, err + } + + // Add the size of the message type. + return uint32(buf.Len()) + MessageTypeSize, nil +} + // makeEmptyMessage creates a new empty message of the proper concrete type // based on the passed message type. func makeEmptyMessage(msgType MessageType) (Message, error) { @@ -337,6 +366,12 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { return msg, nil } +// MakeEmptyMessage creates a new empty message of the proper concrete type +// based on the passed message type. This is exported to be used in tests. +func MakeEmptyMessage(msgType MessageType) (Message, error) { + return makeEmptyMessage(msgType) +} + // WriteMessage writes a lightning Message to a buffer including the necessary // header information and returns the number of bytes written. If any error is // encountered, the buffer passed will be reset to its original state since we diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index 4f1620281..ae883e886 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -104,6 +104,9 @@ type NodeAnnouncement struct { // lnwire.Message interface. var _ Message = (*NodeAnnouncement)(nil) +// A compile time check to ensure NodeAnnouncement implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*NodeAnnouncement)(nil) + // Decode deserializes a serialized NodeAnnouncement stored in the passed // io.Reader observing the specified protocol version. // @@ -202,3 +205,10 @@ func (a *NodeAnnouncement) DataToSign() ([]byte, error) { return buf.Bytes(), nil } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *NodeAnnouncement) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 9694290f7..217ddbea0 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -164,6 +164,10 @@ type OpenChannel struct { // interface. var _ Message = (*OpenChannel)(nil) +// A compile time check to ensure OpenChannel implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*OpenChannel)(nil) + // Encode serializes the target OpenChannel into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -335,3 +339,10 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { func (o *OpenChannel) MsgType() MessageType { return MsgOpenChannel } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (o *OpenChannel) SerializedSize() (uint32, error) { + return MessageSerializedSize(o) +} diff --git a/lnwire/ping.go b/lnwire/ping.go index a21f2fa8b..230187b84 100644 --- a/lnwire/ping.go +++ b/lnwire/ping.go @@ -33,6 +33,10 @@ func NewPing(numBytes uint16) *Ping { // A compile time check to ensure Ping implements the lnwire.Message interface. var _ Message = (*Ping)(nil) +// A compile time check to ensure Ping implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Ping)(nil) + // Decode deserializes a serialized Ping message stored in the passed io.Reader // observing the specified protocol version. // @@ -69,3 +73,10 @@ func (p *Ping) Encode(w *bytes.Buffer, pver uint32) error { func (p *Ping) MsgType() MessageType { return MsgPing } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (p *Ping) SerializedSize() (uint32, error) { + return MessageSerializedSize(p) +} diff --git a/lnwire/pong.go b/lnwire/pong.go index 3ab80d70f..c33e904e1 100644 --- a/lnwire/pong.go +++ b/lnwire/pong.go @@ -39,6 +39,9 @@ func NewPong(pongBytes []byte) *Pong { // A compile time check to ensure Pong implements the lnwire.Message interface. var _ Message = (*Pong)(nil) +// A compile time check to ensure Pong implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*Pong)(nil) + // Decode deserializes a serialized Pong message stored in the passed io.Reader // observing the specified protocol version. // @@ -64,3 +67,10 @@ func (p *Pong) Encode(w *bytes.Buffer, pver uint32) error { func (p *Pong) MsgType() MessageType { return MsgPong } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (p *Pong) SerializedSize() (uint32, error) { + return MessageSerializedSize(p) +} diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index 1e0dcb0fa..90b144bda 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -49,6 +49,10 @@ func NewQueryChannelRange() *QueryChannelRange { // lnwire.Message interface. var _ Message = (*QueryChannelRange)(nil) +// A compile time check to ensure QueryChannelRange implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*QueryChannelRange)(nil) + // Decode deserializes a serialized QueryChannelRange message stored in the // passed io.Reader observing the specified protocol version. // @@ -121,6 +125,14 @@ func (q *QueryChannelRange) MsgType() MessageType { return MsgQueryChannelRange } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (q *QueryChannelRange) SerializedSize() (uint32, error) { + msgCpy := *q + return MessageSerializedSize(&msgCpy) +} + // LastBlockHeight returns the last block height covered by the range of a // QueryChannelRange message. func (q *QueryChannelRange) LastBlockHeight() uint32 { diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index cb1bfa22c..f12d07abf 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -91,6 +91,10 @@ func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding, // lnwire.Message interface. var _ Message = (*QueryShortChanIDs)(nil) +// A compile time check to ensure QueryShortChanIDs implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*QueryShortChanIDs)(nil) + // Decode deserializes a serialized QueryShortChanIDs message stored in the // passed io.Reader observing the specified protocol version. // @@ -427,3 +431,10 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding, func (q *QueryShortChanIDs) MsgType() MessageType { return MsgQueryShortChanIDs } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (q *QueryShortChanIDs) SerializedSize() (uint32, error) { + return MessageSerializedSize(q) +} diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 591fc2bd6..a3b11c53e 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -70,6 +70,9 @@ func NewReplyChannelRange() *ReplyChannelRange { // lnwire.Message interface. var _ Message = (*ReplyChannelRange)(nil) +// A compile time check to ensure ReplyChannelRange implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ReplyChannelRange)(nil) + // Decode deserializes a serialized ReplyChannelRange message stored in the // passed io.Reader observing the specified protocol version. // @@ -223,3 +226,10 @@ func (c *ReplyChannelRange) LastBlockHeight() uint32 { } return uint32(lastBlockHeight) } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ReplyChannelRange) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/reply_short_chan_ids_end.go b/lnwire/reply_short_chan_ids_end.go index 53676f719..30660a9cf 100644 --- a/lnwire/reply_short_chan_ids_end.go +++ b/lnwire/reply_short_chan_ids_end.go @@ -39,6 +39,9 @@ func NewReplyShortChanIDsEnd() *ReplyShortChanIDsEnd { // lnwire.Message interface. var _ Message = (*ReplyShortChanIDsEnd)(nil) +// A compile time check to ensure ReplyShortChanIDsEnd implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ReplyShortChanIDsEnd)(nil) + // Decode deserializes a serialized ReplyShortChanIDsEnd message stored in the // passed io.Reader observing the specified protocol version. // @@ -74,3 +77,10 @@ func (c *ReplyShortChanIDsEnd) Encode(w *bytes.Buffer, pver uint32) error { func (c *ReplyShortChanIDsEnd) MsgType() MessageType { return MsgReplyShortChanIDsEnd } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ReplyShortChanIDsEnd) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index 9dca1631a..aa70b0714 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -55,6 +55,9 @@ func NewRevokeAndAck() *RevokeAndAck { // interface. var _ Message = (*RevokeAndAck)(nil) +// A compile time check to ensure RevokeAndAck implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*RevokeAndAck)(nil) + // Decode deserializes a serialized RevokeAndAck message stored in the // passed io.Reader observing the specified protocol version. // @@ -136,3 +139,10 @@ func (c *RevokeAndAck) MsgType() MessageType { func (c *RevokeAndAck) TargetChanID() ChannelID { return c.ChanID } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *RevokeAndAck) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index b9899fcfb..9ac6eb131 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -61,6 +61,9 @@ func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown { // interface. var _ Message = (*Shutdown)(nil) +// A compile-time check to ensure Shutdown implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*Shutdown)(nil) + // Decode deserializes a serialized Shutdown from the passed io.Reader, // observing the specified protocol version. // @@ -133,3 +136,10 @@ func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { func (s *Shutdown) MsgType() MessageType { return MsgShutdown } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (s *Shutdown) SerializedSize() (uint32, error) { + return MessageSerializedSize(s) +} diff --git a/lnwire/stfu.go b/lnwire/stfu.go index a052a517a..f923c94b8 100644 --- a/lnwire/stfu.go +++ b/lnwire/stfu.go @@ -24,6 +24,9 @@ type Stfu struct { // A compile time check to ensure Stfu implements the lnwire.Message interface. var _ Message = (*Stfu)(nil) +// A compile time check to ensure Stfu implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*Stfu)(nil) + // Encode serializes the target Stfu into the passed io.Writer. // Serialization will observe the rules defined by the passed protocol version. // @@ -68,6 +71,13 @@ func (s *Stfu) MsgType() MessageType { return MsgStfu } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (s *Stfu) SerializedSize() (uint32, error) { + return MessageSerializedSize(s) +} + // A compile time check to ensure Stfu implements the // lnwire.LinkUpdater interface. var _ LinkUpdater = (*Stfu)(nil) diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 4873cd84b..7976a13c5 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -110,6 +110,10 @@ func NewUpdateAddHTLC() *UpdateAddHTLC { // interface. var _ Message = (*UpdateAddHTLC)(nil) +// A compile time check to ensure UpdateAddHTLC implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*UpdateAddHTLC)(nil) + // Decode deserializes a serialized UpdateAddHTLC message stored in the passed // io.Reader observing the specified protocol version. // @@ -212,3 +216,10 @@ func (c *UpdateAddHTLC) MsgType() MessageType { func (c *UpdateAddHTLC) TargetChanID() ChannelID { return c.ChanID } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateAddHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 8cd9c7687..397c70084 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -38,6 +38,10 @@ type UpdateFailHTLC struct { // interface. var _ Message = (*UpdateFailHTLC)(nil) +// A compile time check to ensure UpdateFailHTLC implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*UpdateFailHTLC)(nil) + // Decode deserializes a serialized UpdateFailHTLC message stored in the passed // io.Reader observing the specified protocol version. // @@ -79,6 +83,13 @@ func (c *UpdateFailHTLC) MsgType() MessageType { return MsgUpdateFailHTLC } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFailHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go index f28107a9a..bc9bd9aba 100644 --- a/lnwire/update_fail_malformed_htlc.go +++ b/lnwire/update_fail_malformed_htlc.go @@ -36,6 +36,10 @@ type UpdateFailMalformedHTLC struct { // lnwire.Message interface. var _ Message = (*UpdateFailMalformedHTLC)(nil) +// A compile time check to ensure UpdateFailMalformedHTLC implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateFailMalformedHTLC)(nil) + // Decode deserializes a serialized UpdateFailMalformedHTLC message stored in the passed // io.Reader observing the specified protocol version. // @@ -84,6 +88,13 @@ func (c *UpdateFailMalformedHTLC) MsgType() MessageType { return MsgUpdateFailMalformedHTLC } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFailMalformedHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/update_fee.go b/lnwire/update_fee.go index a7026044c..f32a7385a 100644 --- a/lnwire/update_fee.go +++ b/lnwire/update_fee.go @@ -36,6 +36,10 @@ func NewUpdateFee(chanID ChannelID, feePerKw uint32) *UpdateFee { // interface. var _ Message = (*UpdateFee)(nil) +// A compile time check to ensure UpdateFee implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateFee)(nil) + // Decode deserializes a serialized UpdateFee message stored in the passed // io.Reader observing the specified protocol version. // @@ -72,6 +76,13 @@ func (c *UpdateFee) MsgType() MessageType { return MsgUpdateFee } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFee) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go index 35aaa2ff5..9ec747406 100644 --- a/lnwire/update_fulfill_htlc.go +++ b/lnwire/update_fulfill_htlc.go @@ -44,12 +44,16 @@ func NewUpdateFulfillHTLC(chanID ChannelID, id uint64, } } -// A compile time check to ensure UpdateFulfillHTLC implements the lnwire.Message -// interface. +// A compile time check to ensure UpdateFulfillHTLC implements the +// lnwire.Message interface. var _ Message = (*UpdateFulfillHTLC)(nil) -// Decode deserializes a serialized UpdateFulfillHTLC message stored in the passed -// io.Reader observing the specified protocol version. +// A compile time check to ensure UpdateFulfillHTLC implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateFulfillHTLC)(nil) + +// Decode deserializes a serialized UpdateFulfillHTLC message stored in the +// passed io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { @@ -115,6 +119,13 @@ func (c *UpdateFulfillHTLC) MsgType() MessageType { return MsgUpdateFulfillHTLC } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFulfillHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/warning.go b/lnwire/warning.go index adb595cc0..2a0b3df19 100644 --- a/lnwire/warning.go +++ b/lnwire/warning.go @@ -29,6 +29,10 @@ type Warning struct { // interface. var _ Message = (*Warning)(nil) +// A compile time check to ensure Warning implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Warning)(nil) + // NewWarning creates a new Warning message. func NewWarning() *Warning { return &Warning{} @@ -74,3 +78,10 @@ func (c *Warning) Encode(w *bytes.Buffer, _ uint32) error { func (c *Warning) MsgType() MessageType { return MsgWarning } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *Warning) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} From eb877db2ffbcd378f279734615ee2b6fbc03b8e4 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 19 Mar 2025 15:10:35 -0500 Subject: [PATCH 2/4] lnwire: add new TestMessage interface for property tests In this commit, we add a new `TestMessage` interface for use in property tests. With this, we'll be able to generate a random instance of a given message, using the rapid byte stream. This can also eventually be useful for fuzzing. --- .gitignore | 2 + lnwire/accept_channel.go | 4 +- lnwire/channel_update.go | 3 +- lnwire/closing_complete.go | 3 +- lnwire/closing_sig.go | 4 +- lnwire/closing_signed.go | 3 +- lnwire/commit_sig.go | 3 +- lnwire/custom.go | 2 +- lnwire/funding_signed.go | 4 +- lnwire/gossip_timestamp_range.go | 3 +- lnwire/kickoff_sig.go | 11 + lnwire/message.go | 5 + lnwire/node_announcement.go | 3 +- lnwire/open_channel.go | 4 +- lnwire/query_channel_range.go | 4 +- lnwire/query_short_chan_ids.go | 4 +- lnwire/reply_channel_range.go | 3 +- lnwire/reply_short_chan_ids_end.go | 3 +- lnwire/revoke_and_ack.go | 3 +- lnwire/shutdown.go | 3 +- lnwire/stfu.go | 3 +- lnwire/test_message.go | 1669 ++++++++++++++++++++++++++++ lnwire/test_utils.go | 360 ++++++ lnwire/update_add_htlc.go | 8 +- lnwire/update_fail_htlc.go | 8 +- 25 files changed, 2090 insertions(+), 32 deletions(-) create mode 100644 lnwire/test_message.go create mode 100644 lnwire/test_utils.go diff --git a/.gitignore b/.gitignore index 5ddd602d3..1e498b0cf 100644 --- a/.gitignore +++ b/.gitignore @@ -80,3 +80,5 @@ coverage.txt # Release build directory (to avoid build.vcs.modified Golang build tag to be # set to true by having untracked files in the working directory). /lnd-*/ + +.aider* diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index aace5d536..afb2f1412 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -128,8 +128,8 @@ type AcceptChannel struct { // interface. var _ Message = (*AcceptChannel)(nil) -// A compile time check to ensure AcceptChannel implements the lnwire.SizeableMessage -// interface. +// A compile time check to ensure AcceptChannel implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*AcceptChannel)(nil) // Encode serializes the target AcceptChannel into the passed io.Writer diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index 1c7f0eed2..88f981671 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -124,7 +124,8 @@ type ChannelUpdate1 struct { // interface. var _ Message = (*ChannelUpdate1)(nil) -// A compile time check to ensure ChannelUpdate1 implements the lnwire.SizeableMessage interface. +// A compile time check to ensure ChannelUpdate1 implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*ChannelUpdate1)(nil) // Decode deserializes a serialized ChannelUpdate stored in the passed diff --git a/lnwire/closing_complete.go b/lnwire/closing_complete.go index 14784a3c1..7980ef1ee 100644 --- a/lnwire/closing_complete.go +++ b/lnwire/closing_complete.go @@ -180,5 +180,6 @@ func (c *ClosingComplete) SerializedSize() (uint32, error) { // interface. var _ Message = (*ClosingComplete)(nil) -// A compile time check to ensure ClosingComplete implements the lnwire.SizeableMessage interface. +// A compile time check to ensure ClosingComplete implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*ClosingComplete)(nil) diff --git a/lnwire/closing_sig.go b/lnwire/closing_sig.go index 1eeec2580..94a356066 100644 --- a/lnwire/closing_sig.go +++ b/lnwire/closing_sig.go @@ -118,6 +118,6 @@ func (c *ClosingSig) SerializedSize() (uint32, error) { // interface. var _ Message = (*ClosingSig)(nil) -// A compile time check to ensure ClosingSig implements the lnwire.SizeableMessage -// interface. +// A compile time check to ensure ClosingSig implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*ClosingSig)(nil) diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index a82dbf402..c247cfe0a 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -59,7 +59,8 @@ func NewClosingSigned(cid ChannelID, fs btcutil.Amount, // interface. var _ Message = (*ClosingSigned)(nil) -// A compile time check to ensure ClosingSigned implements the lnwire.SizeableMessage interface. +// A compile time check to ensure ClosingSigned implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*ClosingSigned)(nil) // Decode deserializes a serialized ClosingSigned message stored in the passed diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 7c5a41ccc..600ff81e7 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -64,7 +64,8 @@ func NewCommitSig() *CommitSig { // interface. var _ Message = (*CommitSig)(nil) -// A compile time check to ensure CommitSig implements the lnwire.SizeableMessage interface. +// A compile time check to ensure CommitSig implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*CommitSig)(nil) // Decode deserializes a serialized CommitSig message stored in the diff --git a/lnwire/custom.go b/lnwire/custom.go index e8c299297..740c23a14 100644 --- a/lnwire/custom.go +++ b/lnwire/custom.go @@ -69,7 +69,7 @@ type Custom struct { Data []byte } -// A compile time check to ensure FundingCreated implements the lnwire.Message +// A compile time check to ensure Custom implements the lnwire.Message // interface. var _ Message = (*Custom)(nil) diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index a7f23310a..182e4bbde 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -36,8 +36,8 @@ type FundingSigned struct { // interface. var _ Message = (*FundingSigned)(nil) -// A compile time check to ensure FundingSigned implements the lnwire.SizeableMessage -// interface. +// A compile time check to ensure FundingSigned implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*FundingSigned)(nil) // Encode serializes the target FundingSigned into the passed io.Writer diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go index 25c2a033b..45ff1f939 100644 --- a/lnwire/gossip_timestamp_range.go +++ b/lnwire/gossip_timestamp_range.go @@ -58,7 +58,8 @@ func NewGossipTimestampRange() *GossipTimestampRange { // lnwire.Message interface. var _ Message = (*GossipTimestampRange)(nil) -// A compile time check to ensure GossipTimestampRange implements the lnwire.SizeableMessage interface. +// A compile time check to ensure GossipTimestampRange implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*GossipTimestampRange)(nil) // Decode deserializes a serialized GossipTimestampRange message stored in the diff --git a/lnwire/kickoff_sig.go b/lnwire/kickoff_sig.go index 3e46db453..b9c4c206a 100644 --- a/lnwire/kickoff_sig.go +++ b/lnwire/kickoff_sig.go @@ -27,6 +27,10 @@ type KickoffSig struct { // interface. var _ Message = (*KickoffSig)(nil) +// A compile time check to ensure KickoffSig implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*KickoffSig)(nil) + // Encode serializes the target KickoffSig into the passed bytes.Buffer // observing the specified protocol version. // @@ -54,3 +58,10 @@ func (ks *KickoffSig) Decode(r io.Reader, _ uint32) error { // // This is part of the lnwire.Message interface. func (ks *KickoffSig) MsgType() MessageType { return MsgKickoffSig } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (ks *KickoffSig) SerializedSize() (uint32, error) { + return MessageSerializedSize(ks) +} diff --git a/lnwire/message.go b/lnwire/message.go index 8c310bd78..ea480075a 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -65,6 +65,11 @@ const ( MsgChannelAnnouncement2 = 267 MsgChannelUpdate2 = 271 MsgKickoffSig = 777 + + // MsgEnd defines the end of the official message range of the protocol. + // If a new message is added beyond this message, then this should be + // modified. + MsgEnd = 778 ) // IsChannelUpdate is a filter function that discerns channel update messages diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index ae883e886..5ba2d7a1d 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -104,7 +104,8 @@ type NodeAnnouncement struct { // lnwire.Message interface. var _ Message = (*NodeAnnouncement)(nil) -// A compile time check to ensure NodeAnnouncement implements the lnwire.SizeableMessage interface. +// A compile time check to ensure NodeAnnouncement implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*NodeAnnouncement)(nil) // Decode deserializes a serialized NodeAnnouncement stored in the passed diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 217ddbea0..1751f748b 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -164,8 +164,8 @@ type OpenChannel struct { // interface. var _ Message = (*OpenChannel)(nil) -// A compile time check to ensure OpenChannel implements the lnwire.SizeableMessage -// interface. +// A compile time check to ensure OpenChannel implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*OpenChannel)(nil) // Encode serializes the target OpenChannel into the passed io.Writer diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index 90b144bda..c816a0050 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -49,8 +49,8 @@ func NewQueryChannelRange() *QueryChannelRange { // lnwire.Message interface. var _ Message = (*QueryChannelRange)(nil) -// A compile time check to ensure QueryChannelRange implements the lnwire.SizeableMessage -// interface. +// A compile time check to ensure QueryChannelRange implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*QueryChannelRange)(nil) // Decode deserializes a serialized QueryChannelRange message stored in the diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index f12d07abf..37a73ab7c 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -91,8 +91,8 @@ func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding, // lnwire.Message interface. var _ Message = (*QueryShortChanIDs)(nil) -// A compile time check to ensure QueryShortChanIDs implements the lnwire.SizeableMessage -// interface. +// A compile time check to ensure QueryShortChanIDs implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*QueryShortChanIDs)(nil) // Decode deserializes a serialized QueryShortChanIDs message stored in the diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index a3b11c53e..c3a744ebd 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -70,7 +70,8 @@ func NewReplyChannelRange() *ReplyChannelRange { // lnwire.Message interface. var _ Message = (*ReplyChannelRange)(nil) -// A compile time check to ensure ReplyChannelRange implements the lnwire.SizeableMessage interface. +// A compile time check to ensure ReplyChannelRange implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*ReplyChannelRange)(nil) // Decode deserializes a serialized ReplyChannelRange message stored in the diff --git a/lnwire/reply_short_chan_ids_end.go b/lnwire/reply_short_chan_ids_end.go index 30660a9cf..2e50d840f 100644 --- a/lnwire/reply_short_chan_ids_end.go +++ b/lnwire/reply_short_chan_ids_end.go @@ -39,7 +39,8 @@ func NewReplyShortChanIDsEnd() *ReplyShortChanIDsEnd { // lnwire.Message interface. var _ Message = (*ReplyShortChanIDsEnd)(nil) -// A compile time check to ensure ReplyShortChanIDsEnd implements the lnwire.SizeableMessage interface. +// A compile time check to ensure ReplyShortChanIDsEnd implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*ReplyShortChanIDsEnd)(nil) // Decode deserializes a serialized ReplyShortChanIDsEnd message stored in the diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index aa70b0714..3c9775c99 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -55,7 +55,8 @@ func NewRevokeAndAck() *RevokeAndAck { // interface. var _ Message = (*RevokeAndAck)(nil) -// A compile time check to ensure RevokeAndAck implements the lnwire.SizeableMessage interface. +// A compile time check to ensure RevokeAndAck implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*RevokeAndAck)(nil) // Decode deserializes a serialized RevokeAndAck message stored in the diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index 9ac6eb131..28df9a4ca 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -61,7 +61,8 @@ func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown { // interface. var _ Message = (*Shutdown)(nil) -// A compile-time check to ensure Shutdown implements the lnwire.SizeableMessage interface. +// A compile-time check to ensure Shutdown implements the lnwire.SizeableMessage +// interface. var _ SizeableMessage = (*Shutdown)(nil) // Decode deserializes a serialized Shutdown from the passed io.Reader, diff --git a/lnwire/stfu.go b/lnwire/stfu.go index f923c94b8..8e57739d5 100644 --- a/lnwire/stfu.go +++ b/lnwire/stfu.go @@ -24,7 +24,8 @@ type Stfu struct { // A compile time check to ensure Stfu implements the lnwire.Message interface. var _ Message = (*Stfu)(nil) -// A compile time check to ensure Stfu implements the lnwire.SizeableMessage interface. +// A compile time check to ensure Stfu implements the lnwire.SizeableMessage +// interface. var _ SizeableMessage = (*Stfu)(nil) // Encode serializes the target Stfu into the passed io.Writer. diff --git a/lnwire/test_message.go b/lnwire/test_message.go new file mode 100644 index 000000000..8b3d98400 --- /dev/null +++ b/lnwire/test_message.go @@ -0,0 +1,1669 @@ +package lnwire + +import ( + "bytes" + "fmt" + "image/color" + "math" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/tlv" + "pgregory.net/rapid" +) + +// TestMessage is an interface that extends the base Message interface with a +// method to populate the message with random testing data. +type TestMessage interface { + Message + + // RandTestMessage populates the message with random data suitable for + // testing. It uses the rapid testing framework to generate random + // values. + RandTestMessage(t *rapid.T) Message +} + +// A compile time check to ensure AcceptChannel implements the TestMessage +// interface. +var _ TestMessage = (*AcceptChannel)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *AcceptChannel) RandTestMessage(t *rapid.T) Message { + var pendingChanID [32]byte + pendingChanIDBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "pendingChanID", + ) + copy(pendingChanID[:], pendingChanIDBytes) + + var channelType *ChannelType + includeChannelType := rapid.Bool().Draw(t, "includeChannelType") + includeLeaseExpiry := rapid.Bool().Draw(t, "includeLeaseExpiry") + includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce") + + if includeChannelType { + channelType = RandChannelType(t) + } + + var leaseExpiry *LeaseExpiry + if includeLeaseExpiry { + leaseExpiry = RandLeaseExpiry(t) + } + + var localNonce OptMusig2NonceTLV + if includeLocalNonce { + nonce := RandMusig2Nonce(t) + localNonce = tlv.SomeRecordT( + tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce), + ) + } + + return &AcceptChannel{ + PendingChannelID: pendingChanID, + DustLimit: btcutil.Amount( + rapid.IntRange(100, 1000).Draw(t, "dustLimit"), + ), + MaxValueInFlight: MilliSatoshi( + rapid.IntRange(10000, 1000000).Draw( + t, "maxValueInFlight", + ), + ), + ChannelReserve: btcutil.Amount( + rapid.IntRange(1000, 10000).Draw(t, "channelReserve"), + ), + HtlcMinimum: MilliSatoshi( + rapid.IntRange(1, 1000).Draw(t, "htlcMinimum"), + ), + MinAcceptDepth: uint32( + rapid.IntRange(1, 10).Draw(t, "minAcceptDepth"), + ), + CsvDelay: uint16( + rapid.IntRange(144, 1000).Draw(t, "csvDelay"), + ), + MaxAcceptedHTLCs: uint16( + rapid.IntRange(10, 500).Draw(t, "maxAcceptedHTLCs"), + ), + FundingKey: RandPubKey(t), + RevocationPoint: RandPubKey(t), + PaymentPoint: RandPubKey(t), + DelayedPaymentPoint: RandPubKey(t), + HtlcPoint: RandPubKey(t), + FirstCommitmentPoint: RandPubKey(t), + UpfrontShutdownScript: RandDeliveryAddress(t), + ChannelType: channelType, + LeaseExpiry: leaseExpiry, + LocalNonce: localNonce, + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure AnnounceSignatures1 implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*AnnounceSignatures1)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *AnnounceSignatures1) RandTestMessage(t *rapid.T) Message { + return &AnnounceSignatures1{ + ChannelID: RandChannelID(t), + ShortChannelID: RandShortChannelID(t), + NodeSignature: RandSignature(t), + BitcoinSignature: RandSignature(t), + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure AnnounceSignatures2 implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*AnnounceSignatures2)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *AnnounceSignatures2) RandTestMessage(t *rapid.T) Message { + return &AnnounceSignatures2{ + ChannelID: RandChannelID(t), + ShortChannelID: RandShortChannelID(t), + PartialSignature: *RandPartialSig(t), + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure ChannelAnnouncement1 implements the +// TestMessage interface. +var _ TestMessage = (*ChannelAnnouncement1)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *ChannelAnnouncement1) RandTestMessage(t *rapid.T) Message { + // Generate Node IDs and Bitcoin keys (compressed public keys) + node1PubKey := RandPubKey(t) + node2PubKey := RandPubKey(t) + bitcoin1PubKey := RandPubKey(t) + bitcoin2PubKey := RandPubKey(t) + + // Convert to byte arrays + var nodeID1, nodeID2, bitcoinKey1, bitcoinKey2 [33]byte + copy(nodeID1[:], node1PubKey.SerializeCompressed()) + copy(nodeID2[:], node2PubKey.SerializeCompressed()) + copy(bitcoinKey1[:], bitcoin1PubKey.SerializeCompressed()) + copy(bitcoinKey2[:], bitcoin2PubKey.SerializeCompressed()) + + // Ensure nodeID1 is numerically less than nodeID2 + // This is a requirement stated in the field description + if bytes.Compare(nodeID1[:], nodeID2[:]) > 0 { + nodeID1, nodeID2 = nodeID2, nodeID1 + } + + // Generate chain hash + chainHash := RandChainHash(t) + var hash chainhash.Hash + copy(hash[:], chainHash[:]) + + return &ChannelAnnouncement1{ + NodeSig1: RandSignature(t), + NodeSig2: RandSignature(t), + BitcoinSig1: RandSignature(t), + BitcoinSig2: RandSignature(t), + Features: RandFeatureVector(t), + ChainHash: hash, + ShortChannelID: RandShortChannelID(t), + NodeID1: nodeID1, + NodeID2: nodeID2, + BitcoinKey1: bitcoinKey1, + BitcoinKey2: bitcoinKey2, + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ChannelAnnouncement2)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message { + features := RandFeatureVector(t) + shortChanID := RandShortChannelID(t) + capacity := uint64(rapid.IntRange(1, 16777215).Draw(t, "capacity")) + + var nodeID1, nodeID2 [33]byte + copy(nodeID1[:], RandPubKey(t).SerializeCompressed()) + copy(nodeID2[:], RandPubKey(t).SerializeCompressed()) + + // Make sure nodeID1 is numerically less than nodeID2 (as per spec). + if bytes.Compare(nodeID1[:], nodeID2[:]) > 0 { + nodeID1, nodeID2 = nodeID2, nodeID1 + } + + chainHash := RandChainHash(t) + var chainHashObj chainhash.Hash + copy(chainHashObj[:], chainHash[:]) + + msg := &ChannelAnnouncement2{ + Signature: RandSignature(t), + ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( + chainHashObj, + ), + Features: tlv.NewRecordT[tlv.TlvType2, RawFeatureVector]( + *features, + ), + ShortChannelID: tlv.NewRecordT[tlv.TlvType4, ShortChannelID]( + shortChanID, + ), + Capacity: tlv.NewPrimitiveRecord[tlv.TlvType6, uint64]( + capacity, + ), + NodeID1: tlv.NewPrimitiveRecord[tlv.TlvType8, [33]byte]( + nodeID1, + ), + NodeID2: tlv.NewPrimitiveRecord[tlv.TlvType10, [33]byte]( + nodeID2, + ), + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } + + msg.Signature.ForceSchnorr() + + // Randomly include optional fields + if rapid.Bool().Draw(t, "includeBitcoinKey1") { + var bitcoinKey1 [33]byte + copy(bitcoinKey1[:], RandPubKey(t).SerializeCompressed()) + msg.BitcoinKey1 = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType12, [33]byte]( + bitcoinKey1, + ), + ) + } + + if rapid.Bool().Draw(t, "includeBitcoinKey2") { + var bitcoinKey2 [33]byte + copy(bitcoinKey2[:], RandPubKey(t).SerializeCompressed()) + msg.BitcoinKey2 = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType14, [33]byte]( + bitcoinKey2, + ), + ) + } + + if rapid.Bool().Draw(t, "includeMerkleRootHash") { + hash := RandSHA256Hash(t) + var merkleRootHash [32]byte + copy(merkleRootHash[:], hash[:]) + msg.MerkleRootHash = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType16, [32]byte]( + merkleRootHash, + ), + ) + } + + return msg +} + +// A compile time check to ensure ChannelReady implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*ChannelReady)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ChannelReady) RandTestMessage(t *rapid.T) Message { + msg := &ChannelReady{ + ChanID: RandChannelID(t), + NextPerCommitmentPoint: RandPubKey(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + includeAliasScid := rapid.Bool().Draw(t, "includeAliasScid") + includeNextLocalNonce := rapid.Bool().Draw(t, "includeNextLocalNonce") + includeAnnouncementNodeNonce := rapid.Bool().Draw( + t, "includeAnnouncementNodeNonce", + ) + includeAnnouncementBitcoinNonce := rapid.Bool().Draw( + t, "includeAnnouncementBitcoinNonce", + ) + + if includeAliasScid { + scid := RandShortChannelID(t) + msg.AliasScid = &scid + } + + if includeNextLocalNonce { + nonce := RandMusig2Nonce(t) + msg.NextLocalNonce = SomeMusig2Nonce(nonce) + } + + if includeAnnouncementNodeNonce { + nonce := RandMusig2Nonce(t) + msg.AnnouncementNodeNonce = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType0, Musig2Nonce](nonce), + ) + } + + if includeAnnouncementBitcoinNonce { + nonce := RandMusig2Nonce(t) + msg.AnnouncementBitcoinNonce = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2, Musig2Nonce](nonce), + ) + } + + return msg +} + +// A compile time check to ensure ChannelReestablish implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ChannelReestablish)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *ChannelReestablish) RandTestMessage(t *rapid.T) Message { + msg := &ChannelReestablish{ + ChanID: RandChannelID(t), + NextLocalCommitHeight: rapid.Uint64().Draw( + t, "nextLocalCommitHeight", + ), + RemoteCommitTailHeight: rapid.Uint64().Draw( + t, "remoteCommitTailHeight", + ), + LastRemoteCommitSecret: RandPaymentPreimage(t), + LocalUnrevokedCommitPoint: RandPubKey(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + // Randomly decide whether to include optional fields + includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce") + includeDynHeight := rapid.Bool().Draw(t, "includeDynHeight") + + if includeLocalNonce { + nonce := RandMusig2Nonce(t) + msg.LocalNonce = SomeMusig2Nonce(nonce) + } + + if includeDynHeight { + height := DynHeight(rapid.Uint64().Draw(t, "dynHeight")) + msg.DynHeight = fn.Some(height) + } + + return msg +} + +// A compile time check to ensure ChannelUpdate1 implements the TestMessage +// interface. +var _ TestMessage = (*ChannelUpdate1)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *ChannelUpdate1) RandTestMessage(t *rapid.T) Message { + // Generate random message flags + // Randomly decide whether to include max HTLC field + includeMaxHtlc := rapid.Bool().Draw(t, "includeMaxHtlc") + var msgFlags ChanUpdateMsgFlags + if includeMaxHtlc { + msgFlags |= ChanUpdateRequiredMaxHtlc + } + + // Generate random channel flags + // Randomly decide direction (node1 or node2) + isNode2 := rapid.Bool().Draw(t, "isNode2") + var chanFlags ChanUpdateChanFlags + if isNode2 { + chanFlags |= ChanUpdateDirection + } + + // Randomly decide if channel is disabled + isDisabled := rapid.Bool().Draw(t, "isDisabled") + if isDisabled { + chanFlags |= ChanUpdateDisabled + } + + // Generate chain hash + chainHash := RandChainHash(t) + var hash chainhash.Hash + copy(hash[:], chainHash[:]) + + // Generate other random fields + maxHtlc := MilliSatoshi(rapid.Uint64().Draw(t, "maxHtlc")) + + // If max HTLC flag is not set, we need to zero the value + if !includeMaxHtlc { + maxHtlc = 0 + } + + return &ChannelUpdate1{ + Signature: RandSignature(t), + ChainHash: hash, + ShortChannelID: RandShortChannelID(t), + Timestamp: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw( + t, "timestamp"), + ), + MessageFlags: msgFlags, + ChannelFlags: chanFlags, + TimeLockDelta: uint16(rapid.IntRange(0, 65535).Draw( + t, "timelockDelta"), + ), + HtlcMinimumMsat: MilliSatoshi(rapid.Uint64().Draw( + t, "htlcMinimum"), + ), + BaseFee: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw( + t, "baseFee"), + ), + FeeRate: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw( + t, "feeRate"), + ), + HtlcMaximumMsat: maxHtlc, + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure ChannelUpdate2 implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ChannelUpdate2)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ChannelUpdate2) RandTestMessage(t *rapid.T) Message { + shortChanID := RandShortChannelID(t) + blockHeight := uint32(rapid.IntRange(0, 1000000).Draw(t, "blockHeight")) + + var disabledFlags ChanUpdateDisableFlags + if rapid.Bool().Draw(t, "disableIncoming") { + disabledFlags |= ChanUpdateDisableIncoming + } + if rapid.Bool().Draw(t, "disableOutgoing") { + disabledFlags |= ChanUpdateDisableOutgoing + } + + cltvExpiryDelta := uint16(rapid.IntRange(10, 200).Draw( + t, "cltvExpiryDelta"), + ) + + htlcMinMsat := MilliSatoshi(rapid.IntRange(1, 10000).Draw( + t, "htlcMinMsat"), + ) + htlcMaxMsat := MilliSatoshi(rapid.IntRange(10000, 100000000).Draw( + t, "htlcMaxMsat"), + ) + feeBaseMsat := uint32(rapid.IntRange(0, 10000).Draw(t, "feeBaseMsat")) + feeProportionalMillionths := uint32(rapid.IntRange(0, 10000).Draw( + t, "feeProportionalMillionths"), + ) + + chainHash := RandChainHash(t) + var chainHashObj chainhash.Hash + copy(chainHashObj[:], chainHash[:]) + + //nolint:ll + msg := &ChannelUpdate2{ + Signature: RandSignature(t), + ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( + chainHashObj, + ), + ShortChannelID: tlv.NewRecordT[tlv.TlvType2, ShortChannelID]( + shortChanID, + ), + BlockHeight: tlv.NewPrimitiveRecord[tlv.TlvType4, uint32]( + blockHeight, + ), + DisabledFlags: tlv.NewPrimitiveRecord[tlv.TlvType6, ChanUpdateDisableFlags]( //nolint:ll + disabledFlags, + ), + CLTVExpiryDelta: tlv.NewPrimitiveRecord[tlv.TlvType10, uint16]( + cltvExpiryDelta, + ), + HTLCMinimumMsat: tlv.NewPrimitiveRecord[tlv.TlvType12, MilliSatoshi]( + htlcMinMsat, + ), + HTLCMaximumMsat: tlv.NewPrimitiveRecord[tlv.TlvType14, MilliSatoshi]( + htlcMaxMsat, + ), + FeeBaseMsat: tlv.NewPrimitiveRecord[tlv.TlvType16, uint32]( + feeBaseMsat, + ), + FeeProportionalMillionths: tlv.NewPrimitiveRecord[tlv.TlvType18, uint32]( + feeProportionalMillionths, + ), + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } + + msg.Signature.ForceSchnorr() + + if rapid.Bool().Draw(t, "isSecondPeer") { + msg.SecondPeer = tlv.SomeRecordT( + tlv.RecordT[tlv.TlvType8, TrueBoolean]{}, + ) + } + + return msg +} + +// A compile time check to ensure ClosingComplete implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ClosingComplete)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ClosingComplete) RandTestMessage(t *rapid.T) Message { + msg := &ClosingComplete{ + ChannelID: RandChannelID(t), + FeeSatoshis: btcutil.Amount(rapid.Int64Range(0, 1000000).Draw( + t, "feeSatoshis"), + ), + LockTime: rapid.Uint32Range(0, 0xffffffff).Draw( + t, "lockTime", + ), + CloseeScript: RandDeliveryAddress(t), + CloserScript: RandDeliveryAddress(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + includeCloserNoClosee := rapid.Bool().Draw(t, "includeCloserNoClosee") + includeNoCloserClosee := rapid.Bool().Draw(t, "includeNoCloserClosee") + includeCloserAndClosee := rapid.Bool().Draw(t, "includeCloserAndClosee") + + // Ensure at least one signature is present. + if !includeCloserNoClosee && !includeNoCloserClosee && + !includeCloserAndClosee { + + // If all are false, enable at least one randomly. + choice := rapid.IntRange(0, 2).Draw(t, "sigChoice") + switch choice { + case 0: + includeCloserNoClosee = true + case 1: + includeNoCloserClosee = true + case 2: + includeCloserAndClosee = true + } + } + + if includeCloserNoClosee { + sig := RandSignature(t) + msg.CloserNoClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType1, Sig](sig), + ) + } + + if includeNoCloserClosee { + sig := RandSignature(t) + msg.NoCloserClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2, Sig](sig), + ) + } + + if includeCloserAndClosee { + sig := RandSignature(t) + msg.CloserAndClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3, Sig](sig), + ) + } + + return msg +} + +// A compile time check to ensure ClosingSig implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*ClosingSig)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ClosingSig) RandTestMessage(t *rapid.T) Message { + msg := &ClosingSig{ + ChannelID: RandChannelID(t), + CloseeScript: RandDeliveryAddress(t), + CloserScript: RandDeliveryAddress(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + includeCloserNoClosee := rapid.Bool().Draw(t, "includeCloserNoClosee") + includeNoCloserClosee := rapid.Bool().Draw(t, "includeNoCloserClosee") + includeCloserAndClosee := rapid.Bool().Draw(t, "includeCloserAndClosee") + + // Ensure at least one signature is present. + if !includeCloserNoClosee && !includeNoCloserClosee && + !includeCloserAndClosee { + + // If all are false, enable at least one randomly. + choice := rapid.IntRange(0, 2).Draw(t, "sigChoice") + switch choice { + case 0: + includeCloserNoClosee = true + case 1: + includeNoCloserClosee = true + case 2: + includeCloserAndClosee = true + } + } + + if includeCloserNoClosee { + sig := RandSignature(t) + msg.CloserNoClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType1, Sig](sig), + ) + } + + if includeNoCloserClosee { + sig := RandSignature(t) + msg.NoCloserClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2, Sig](sig), + ) + } + + if includeCloserAndClosee { + sig := RandSignature(t) + msg.CloserAndClosee = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3, Sig](sig), + ) + } + + return msg +} + +// A compile time check to ensure ClosingSigned implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ClosingSigned)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ClosingSigned) RandTestMessage(t *rapid.T) Message { + // Generate a random boolean to decide whether to include CommitSig or + // PartialSig Since they're mutually exclusive, when one is populated, + // the other must be blank. + usePartialSig := rapid.Bool().Draw(t, "usePartialSig") + + msg := &ClosingSigned{ + ChannelID: RandChannelID(t), + FeeSatoshis: btcutil.Amount( + rapid.Int64Range(0, 1000000).Draw(t, "feeSatoshis"), + ), + ExtraData: RandExtraOpaqueData(t, nil), + } + + if usePartialSig { + sigBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "sigScalar", + ) + var s btcec.ModNScalar + _ = s.SetByteSlice(sigBytes) + + msg.PartialSig = SomePartialSig(NewPartialSig(s)) + msg.Signature = Sig{} + } else { + msg.Signature = RandSignature(t) + } + + return msg +} + +// A compile time check to ensure CommitSig implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*CommitSig)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *CommitSig) RandTestMessage(t *rapid.T) Message { + cr, _ := RandCustomRecords(t, nil, true) + sig := &CommitSig{ + ChanID: RandChannelID(t), + CommitSig: RandSignature(t), + CustomRecords: cr, + } + + numHtlcSigs := rapid.IntRange(0, 20).Draw(t, "numHtlcSigs") + htlcSigs := make([]Sig, numHtlcSigs) + for i := 0; i < numHtlcSigs; i++ { + htlcSigs[i] = RandSignature(t) + } + + if len(htlcSigs) > 0 { + sig.HtlcSigs = htlcSigs + } + + includePartialSig := rapid.Bool().Draw(t, "includePartialSig") + if includePartialSig { + sigWithNonce := RandPartialSigWithNonce(t) + sig.PartialSig = MaybePartialSigWithNonce(sigWithNonce) + } + + return sig +} + +// A compile time check to ensure Custom implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Custom)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *Custom) RandTestMessage(t *rapid.T) Message { + msgType := MessageType( + rapid.IntRange(int(CustomTypeStart), 65535).Draw( + t, "customMsgType", + ), + ) + + dataLen := rapid.IntRange(0, 1000).Draw(t, "customDataLength") + data := rapid.SliceOfN(rapid.Byte(), dataLen, dataLen).Draw( + t, "customData", + ) + + msg, err := NewCustom(msgType, data) + if err != nil { + panic(fmt.Sprintf("Error creating custom message: %v", err)) + } + + return msg +} + +// A compile time check to ensure DynAck implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*DynAck)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (da *DynAck) RandTestMessage(t *rapid.T) Message { + msg := &DynAck{ + ChanID: RandChannelID(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce") + + if includeLocalNonce { + msg.LocalNonce = fn.Some(RandMusig2Nonce(t)) + } + + return msg +} + +// A compile time check to ensure DynPropose implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*DynPropose)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (dp *DynPropose) RandTestMessage(t *rapid.T) Message { + msg := &DynPropose{ + ChanID: RandChannelID(t), + Initiator: rapid.Bool().Draw(t, "initiator"), + ExtraData: RandExtraOpaqueData(t, nil), + } + + // Randomly decide which optional fields to include + includeDustLimit := rapid.Bool().Draw(t, "includeDustLimit") + includeMaxValueInFlight := rapid.Bool().Draw( + t, "includeMaxValueInFlight", + ) + includeChannelReserve := rapid.Bool().Draw(t, "includeChannelReserve") + includeCsvDelay := rapid.Bool().Draw(t, "includeCsvDelay") + includeMaxAcceptedHTLCs := rapid.Bool().Draw( + t, "includeMaxAcceptedHTLCs", + ) + includeFundingKey := rapid.Bool().Draw(t, "includeFundingKey") + includeChannelType := rapid.Bool().Draw(t, "includeChannelType") + includeKickoffFeerate := rapid.Bool().Draw(t, "includeKickoffFeerate") + + // Generate random values for each included field + if includeDustLimit { + dl := btcutil.Amount(rapid.Uint32().Draw(t, "dustLimit")) + msg.DustLimit = fn.Some(dl) + } + + if includeMaxValueInFlight { + mvif := MilliSatoshi(rapid.Uint64().Draw(t, "maxValueInFlight")) + msg.MaxValueInFlight = fn.Some(mvif) + } + + if includeChannelReserve { + cr := btcutil.Amount(rapid.Uint32().Draw(t, "channelReserve")) + msg.ChannelReserve = fn.Some(cr) + } + + if includeCsvDelay { + cd := rapid.Uint16().Draw(t, "csvDelay") + msg.CsvDelay = fn.Some(cd) + } + + if includeMaxAcceptedHTLCs { + mah := rapid.Uint16().Draw(t, "maxAcceptedHTLCs") + msg.MaxAcceptedHTLCs = fn.Some(mah) + } + + if includeFundingKey { + msg.FundingKey = fn.Some(*RandPubKey(t)) + } + + if includeChannelType { + msg.ChannelType = fn.Some(*RandChannelType(t)) + } + + if includeKickoffFeerate { + kf := chainfee.SatPerKWeight(rapid.Uint32().Draw( + t, "kickoffFeerate"), + ) + msg.KickoffFeerate = fn.Some(kf) + } + + return msg +} + +// A compile time check to ensure DynReject implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*DynReject)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (dr *DynReject) RandTestMessage(t *rapid.T) Message { + featureVec := NewRawFeatureVector() + + numFeatures := rapid.IntRange(0, 8).Draw(t, "numRejections") + for i := 0; i < numFeatures; i++ { + bit := FeatureBit( + rapid.IntRange(0, 31).Draw( + t, fmt.Sprintf("rejectionBit-%d", i), + ), + ) + featureVec.Set(bit) + } + + var extraData ExtraOpaqueData + randData := RandExtraOpaqueData(t, nil) + if len(randData) > 0 { + extraData = randData + } + + return &DynReject{ + ChanID: RandChannelID(t), + UpdateRejections: *featureVec, + ExtraData: extraData, + } +} + +// A compile time check to ensure FundingCreated implements the TestMessage +// interface. +var _ TestMessage = (*FundingCreated)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (f *FundingCreated) RandTestMessage(t *rapid.T) Message { + var pendingChanID [32]byte + pendingChanIDBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "pendingChanID", + ) + copy(pendingChanID[:], pendingChanIDBytes) + + includePartialSig := rapid.Bool().Draw(t, "includePartialSig") + var partialSig OptPartialSigWithNonceTLV + var commitSig Sig + + if includePartialSig { + sigWithNonce := RandPartialSigWithNonce(t) + partialSig = MaybePartialSigWithNonce(sigWithNonce) + + // When using partial sig, CommitSig should be empty/blank. + commitSig = Sig{} + } else { + commitSig = RandSignature(t) + } + + return &FundingCreated{ + PendingChannelID: pendingChanID, + FundingPoint: RandOutPoint(t), + CommitSig: commitSig, + PartialSig: partialSig, + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure FundingSigned implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*FundingSigned)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (f *FundingSigned) RandTestMessage(t *rapid.T) Message { + usePartialSig := rapid.Bool().Draw(t, "usePartialSig") + + msg := &FundingSigned{ + ChanID: RandChannelID(t), + ExtraData: RandExtraOpaqueData(t, nil), + } + + if usePartialSig { + sigWithNonce := RandPartialSigWithNonce(t) + msg.PartialSig = MaybePartialSigWithNonce(sigWithNonce) + + msg.CommitSig = Sig{} + } else { + msg.CommitSig = RandSignature(t) + } + + return msg +} + +// A compile time check to ensure GossipTimestampRange implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*GossipTimestampRange)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (g *GossipTimestampRange) RandTestMessage(t *rapid.T) Message { + var chainHash chainhash.Hash + hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash") + copy(chainHash[:], hashBytes) + + msg := &GossipTimestampRange{ + ChainHash: chainHash, + FirstTimestamp: rapid.Uint32().Draw(t, "firstTimestamp"), + TimestampRange: rapid.Uint32().Draw(t, "timestampRange"), + ExtraData: RandExtraOpaqueData(t, nil), + } + + includeFirstBlockHeight := rapid.Bool().Draw( + t, "includeFirstBlockHeight", + ) + includeBlockRange := rapid.Bool().Draw(t, "includeBlockRange") + + if includeFirstBlockHeight { + height := rapid.Uint32().Draw(t, "firstBlockHeight") + msg.FirstBlockHeight = tlv.SomeRecordT( + tlv.RecordT[tlv.TlvType2, uint32]{Val: height}, + ) + } + + if includeBlockRange { + blockRange := rapid.Uint32().Draw(t, "blockRange") + msg.BlockRange = tlv.SomeRecordT( + tlv.RecordT[tlv.TlvType4, uint32]{Val: blockRange}, + ) + } + + return msg +} + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (msg *Init) RandTestMessage(t *rapid.T) Message { + global := NewRawFeatureVector() + local := NewRawFeatureVector() + + numGlobalFeatures := rapid.IntRange(0, 20).Draw(t, "numGlobalFeatures") + for i := 0; i < numGlobalFeatures; i++ { + bit := FeatureBit( + rapid.IntRange(0, 100).Draw( + t, fmt.Sprintf("globalFeatureBit%d", i), + ), + ) + global.Set(bit) + } + + numLocalFeatures := rapid.IntRange(0, 20).Draw(t, "numLocalFeatures") + for i := 0; i < numLocalFeatures; i++ { + bit := FeatureBit( + rapid.IntRange(0, 100).Draw( + t, fmt.Sprintf("localFeatureBit%d", i), + ), + ) + local.Set(bit) + } + + return NewInitMessage(global, local) +} + +// A compile time check to ensure KickoffSig implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*KickoffSig)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (ks *KickoffSig) RandTestMessage(t *rapid.T) Message { + return &KickoffSig{ + ChanID: RandChannelID(t), + Signature: RandSignature(t), + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure NodeAnnouncement implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*NodeAnnouncement)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (a *NodeAnnouncement) RandTestMessage(t *rapid.T) Message { + // Generate random compressed public key for node ID + pubKey := RandPubKey(t) + var nodeID [33]byte + copy(nodeID[:], pubKey.SerializeCompressed()) + + // Generate random RGB color + rgbColor := color.RGBA{ + R: uint8(rapid.IntRange(0, 255).Draw(t, "rgbR")), + G: uint8(rapid.IntRange(0, 255).Draw(t, "rgbG")), + B: uint8(rapid.IntRange(0, 255).Draw(t, "rgbB")), + } + + return &NodeAnnouncement{ + Signature: RandSignature(t), + Features: RandFeatureVector(t), + Timestamp: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw( + t, "timestamp"), + ), + NodeID: nodeID, + RGBColor: rgbColor, + Alias: RandNodeAlias(t), + Addresses: RandNetAddrs(t), + ExtraOpaqueData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure OpenChannel implements the TestMessage +// interface. +var _ TestMessage = (*OpenChannel)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (o *OpenChannel) RandTestMessage(t *rapid.T) Message { + chainHash := RandChainHash(t) + var hash chainhash.Hash + copy(hash[:], chainHash[:]) + + var pendingChanID [32]byte + pendingChanIDBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "pendingChanID", + ) + copy(pendingChanID[:], pendingChanIDBytes) + + includeChannelType := rapid.Bool().Draw(t, "includeChannelType") + includeLeaseExpiry := rapid.Bool().Draw(t, "includeLeaseExpiry") + includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce") + + var channelFlags FundingFlag + if rapid.Bool().Draw(t, "announceChannel") { + channelFlags |= FFAnnounceChannel + } + + var localNonce OptMusig2NonceTLV + if includeLocalNonce { + nonce := RandMusig2Nonce(t) + localNonce = tlv.SomeRecordT( + tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce), + ) + } + + var channelType *ChannelType + if includeChannelType { + channelType = RandChannelType(t) + } + + var leaseExpiry *LeaseExpiry + if includeLeaseExpiry { + leaseExpiry = RandLeaseExpiry(t) + } + + return &OpenChannel{ + ChainHash: hash, + PendingChannelID: pendingChanID, + FundingAmount: btcutil.Amount( + rapid.IntRange(5000, 10000000).Draw(t, "fundingAmount"), + ), + PushAmount: MilliSatoshi( + rapid.IntRange(0, 1000000).Draw(t, "pushAmount"), + ), + DustLimit: btcutil.Amount( + rapid.IntRange(100, 1000).Draw(t, "dustLimit"), + ), + MaxValueInFlight: MilliSatoshi( + rapid.IntRange(10000, 1000000).Draw( + t, "maxValueInFlight", + ), + ), + ChannelReserve: btcutil.Amount( + rapid.IntRange(1000, 10000).Draw(t, "channelReserve"), + ), + HtlcMinimum: MilliSatoshi( + rapid.IntRange(1, 1000).Draw(t, "htlcMinimum"), + ), + FeePerKiloWeight: uint32( + rapid.IntRange(250, 10000).Draw(t, "feePerKw"), + ), + CsvDelay: uint16( + rapid.IntRange(144, 1000).Draw(t, "csvDelay"), + ), + MaxAcceptedHTLCs: uint16( + rapid.IntRange(10, 500).Draw(t, "maxAcceptedHTLCs"), + ), + FundingKey: RandPubKey(t), + RevocationPoint: RandPubKey(t), + PaymentPoint: RandPubKey(t), + DelayedPaymentPoint: RandPubKey(t), + HtlcPoint: RandPubKey(t), + FirstCommitmentPoint: RandPubKey(t), + ChannelFlags: channelFlags, + UpfrontShutdownScript: RandDeliveryAddress(t), + ChannelType: channelType, + LeaseExpiry: leaseExpiry, + LocalNonce: localNonce, + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure Ping implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Ping)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (p *Ping) RandTestMessage(t *rapid.T) Message { + numPongBytes := uint16(rapid.IntRange(0, int(MaxPongBytes)).Draw( + t, "numPongBytes"), + ) + + // Generate padding bytes (but keeping within allowed message size) + // MaxMsgBody - 2 (for NumPongBytes) - 2 (for padding length) + maxPaddingLen := MaxMsgBody - 4 + paddingLen := rapid.IntRange(0, maxPaddingLen).Draw( + t, "paddingLen", + ) + padding := make(PingPayload, paddingLen) + + // Fill padding with random bytes + for i := 0; i < paddingLen; i++ { + padding[i] = byte(rapid.IntRange(0, 255).Draw( + t, fmt.Sprintf("paddingByte%d", i)), + ) + } + + return &Ping{ + NumPongBytes: numPongBytes, + PaddingBytes: padding, + } +} + +// A compile time check to ensure Pong implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Pong)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (p *Pong) RandTestMessage(t *rapid.T) Message { + payloadLen := rapid.IntRange(0, 1000).Draw(t, "pongPayloadLength") + payload := rapid.SliceOfN(rapid.Byte(), payloadLen, payloadLen).Draw( + t, "pongPayload", + ) + + return &Pong{ + PongBytes: payload, + } +} + +// A compile time check to ensure QueryChannelRange implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*QueryChannelRange)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (q *QueryChannelRange) RandTestMessage(t *rapid.T) Message { + msg := &QueryChannelRange{ + FirstBlockHeight: uint32(rapid.IntRange(0, 1000000).Draw( + t, "firstBlockHeight"), + ), + NumBlocks: uint32(rapid.IntRange(1, 10000).Draw( + t, "numBlocks"), + ), + ExtraData: RandExtraOpaqueData(t, nil), + } + + // Generate chain hash + chainHash := RandChainHash(t) + var chainHashObj chainhash.Hash + copy(chainHashObj[:], chainHash[:]) + msg.ChainHash = chainHashObj + + // Randomly include QueryOptions + if rapid.Bool().Draw(t, "includeQueryOptions") { + queryOptions := &QueryOptions{} + *queryOptions = QueryOptions(*RandFeatureVector(t)) + msg.QueryOptions = queryOptions + } + + return msg +} + +// A compile time check to ensure QueryShortChanIDs implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*QueryShortChanIDs)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (q *QueryShortChanIDs) RandTestMessage(t *rapid.T) Message { + var chainHash chainhash.Hash + hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash") + copy(chainHash[:], hashBytes) + + encodingType := EncodingSortedPlain + if rapid.Bool().Draw(t, "useZlibEncoding") { + encodingType = EncodingSortedZlib + } + + msg := &QueryShortChanIDs{ + ChainHash: chainHash, + EncodingType: encodingType, + ExtraData: RandExtraOpaqueData(t, nil), + noSort: false, + } + + numIDs := rapid.IntRange(2, 20).Draw(t, "numShortChanIDs") + + // Generate sorted short channel IDs. + shortChanIDs := make([]ShortChannelID, numIDs) + for i := 0; i < numIDs; i++ { + shortChanIDs[i] = RandShortChannelID(t) + + // Ensure they're properly sorted. + if i > 0 && shortChanIDs[i].ToUint64() <= + shortChanIDs[i-1].ToUint64() { + + // Ensure this ID is larger than the previous one. + shortChanIDs[i] = NewShortChanIDFromInt( + shortChanIDs[i-1].ToUint64() + 1, + ) + } + } + + msg.ShortChanIDs = shortChanIDs + + return msg +} + +// A compile time check to ensure ReplyChannelRange implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ReplyChannelRange)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ReplyChannelRange) RandTestMessage(t *rapid.T) Message { + msg := &ReplyChannelRange{ + FirstBlockHeight: uint32(rapid.IntRange(0, 1000000).Draw( + t, "firstBlockHeight"), + ), + NumBlocks: uint32(rapid.IntRange(1, 10000).Draw( + t, "numBlocks"), + ), + Complete: uint8(rapid.IntRange(0, 1).Draw(t, "complete")), + EncodingType: QueryEncoding( + rapid.IntRange(0, 1).Draw(t, "encodingType"), + ), + ExtraData: RandExtraOpaqueData(t, nil), + } + + msg.ChainHash = RandChainHash(t) + + numShortChanIDs := rapid.IntRange(0, 20).Draw(t, "numShortChanIDs") + if numShortChanIDs == 0 { + return msg + } + + scidSet := fn.NewSet[ShortChannelID]() + scids := make([]ShortChannelID, numShortChanIDs) + for i := 0; i < numShortChanIDs; i++ { + scid := RandShortChannelID(t) + for scidSet.Contains(scid) { + scid = RandShortChannelID(t) + } + + scids[i] = scid + + scidSet.Add(scid) + } + + // Make sure there're no duplicates. + msg.ShortChanIDs = scids + + if rapid.Bool().Draw(t, "includeTimestamps") && numShortChanIDs > 0 { + msg.Timestamps = make(Timestamps, numShortChanIDs) + for i := 0; i < numShortChanIDs; i++ { + msg.Timestamps[i] = ChanUpdateTimestamps{ + Timestamp1: uint32(rapid.IntRange(0, math.MaxInt32).Draw(t, fmt.Sprintf("timestamp-1-%d", i))), //nolint:ll + Timestamp2: uint32(rapid.IntRange(0, math.MaxInt32).Draw(t, fmt.Sprintf("timestamp-2-%d", i))), //nolint:ll + } + } + } + + return msg +} + +// A compile time check to ensure ReplyShortChanIDsEnd implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*ReplyShortChanIDsEnd)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *ReplyShortChanIDsEnd) RandTestMessage(t *rapid.T) Message { + var chainHash chainhash.Hash + hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash") + copy(chainHash[:], hashBytes) + + complete := uint8(rapid.IntRange(0, 1).Draw(t, "complete")) + + return &ReplyShortChanIDsEnd{ + ChainHash: chainHash, + Complete: complete, + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// RandTestMessage returns a RevokeAndAck message populated with random data. +// +// This is part of the TestMessage interface. +func (c *RevokeAndAck) RandTestMessage(t *rapid.T) Message { + msg := NewRevokeAndAck() + + var chanID ChannelID + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "channelID") + copy(chanID[:], bytes) + msg.ChanID = chanID + + revBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "revocation") + copy(msg.Revocation[:], revBytes) + + msg.NextRevocationKey = RandPubKey(t) + + if rapid.Bool().Draw(t, "includeLocalNonce") { + var nonce Musig2Nonce + nonceBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "nonce", + ) + copy(nonce[:], nonceBytes) + + msg.LocalNonce = tlv.SomeRecordT( + tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce), + ) + } + + return msg +} + +// A compile-time check to ensure Shutdown implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Shutdown)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (s *Shutdown) RandTestMessage(t *rapid.T) Message { + // Generate random delivery address + // First decide the address type (P2PKH, P2SH, P2WPKH, P2WSH, P2TR) + addrType := rapid.IntRange(0, 4).Draw(t, "addrType") + + // Generate random address length based on type + var addrLen int + switch addrType { + // P2PKH + case 0: + addrLen = 25 + // P2SH + case 1: + addrLen = 23 + // P2WPKH + case 2: + addrLen = 22 + // P2WSH + case 3: + addrLen = 34 + // P2TR + case 4: + addrLen = 34 + } + + addr := rapid.SliceOfN(rapid.Byte(), addrLen, addrLen).Draw( + t, "address", + ) + + // Randomly decide whether to include a shutdown nonce + includeNonce := rapid.Bool().Draw(t, "includeNonce") + var shutdownNonce ShutdownNonceTLV + + if includeNonce { + shutdownNonce = SomeShutdownNonce(RandMusig2Nonce(t)) + } + + cr, _ := RandCustomRecords(t, nil, true) + + return &Shutdown{ + ChannelID: RandChannelID(t), + Address: addr, + ShutdownNonce: shutdownNonce, + CustomRecords: cr, + } +} + +// A compile time check to ensure Stfu implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Stfu)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (s *Stfu) RandTestMessage(t *rapid.T) Message { + m := &Stfu{ + ChanID: RandChannelID(t), + Initiator: rapid.Bool().Draw(t, "initiator"), + } + + extraData := RandExtraOpaqueData(t, nil) + if len(extraData) > 0 { + m.ExtraData = extraData + } + + return m +} + +// A compile time check to ensure UpdateAddHTLC implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*UpdateAddHTLC)(nil) + +// RandTestMessage returns an UpdateAddHTLC message populated with random data. +// +// This is part of the TestMessage interface. +func (c *UpdateAddHTLC) RandTestMessage(t *rapid.T) Message { + msg := &UpdateAddHTLC{ + ChanID: RandChannelID(t), + ID: rapid.Uint64().Draw(t, "id"), + Amount: MilliSatoshi(rapid.Uint64().Draw(t, "amount")), + Expiry: rapid.Uint32().Draw(t, "expiry"), + } + + hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "paymentHash") + copy(msg.PaymentHash[:], hashBytes) + + onionBytes := rapid.SliceOfN( + rapid.Byte(), OnionPacketSize, OnionPacketSize, + ).Draw(t, "onionBlob") + copy(msg.OnionBlob[:], onionBytes) + + numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords") + if numRecords > 0 { + msg.CustomRecords, _ = RandCustomRecords(t, nil, true) + } + + // 50/50 chance to add a blinding point + if rapid.Bool().Draw(t, "includeBlindingPoint") { + pubKey := RandPubKey(t) + + msg.BlindingPoint = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[BlindingPointTlvType](pubKey), + ) + } + + return msg +} + +// A compile time check to ensure UpdateFailHTLC implements the TestMessage +// interface. +var _ TestMessage = (*UpdateFailHTLC)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *UpdateFailHTLC) RandTestMessage(t *rapid.T) Message { + return &UpdateFailHTLC{ + ChanID: RandChannelID(t), + ID: rapid.Uint64().Draw(t, "id"), + Reason: RandOpaqueReason(t), + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure UpdateFailMalformedHTLC implements the +// TestMessage interface. +var _ TestMessage = (*UpdateFailMalformedHTLC)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *UpdateFailMalformedHTLC) RandTestMessage(t *rapid.T) Message { + return &UpdateFailMalformedHTLC{ + ChanID: RandChannelID(t), + ID: rapid.Uint64().Draw(t, "id"), + ShaOnionBlob: RandSHA256Hash(t), + FailureCode: RandFailCode(t), + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure UpdateFee implements the TestMessage +// interface. +var _ TestMessage = (*UpdateFee)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *UpdateFee) RandTestMessage(t *rapid.T) Message { + return &UpdateFee{ + ChanID: RandChannelID(t), + FeePerKw: uint32(rapid.IntRange(1, 10000).Draw(t, "feePerKw")), + ExtraData: RandExtraOpaqueData(t, nil), + } +} + +// A compile time check to ensure UpdateFulfillHTLC implements the TestMessage +// interface. +var _ TestMessage = (*UpdateFulfillHTLC)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *UpdateFulfillHTLC) RandTestMessage(t *rapid.T) Message { + msg := &UpdateFulfillHTLC{ + ChanID: RandChannelID(t), + ID: rapid.Uint64().Draw(t, "id"), + PaymentPreimage: RandPaymentPreimage(t), + } + + cr, ignoreRecords := RandCustomRecords(t, nil, true) + msg.CustomRecords = cr + + randData := RandExtraOpaqueData(t, ignoreRecords) + if len(randData) > 0 { + msg.ExtraData = randData + } + + return msg +} + +// A compile time check to ensure Warning implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Warning)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *Warning) RandTestMessage(t *rapid.T) Message { + msg := &Warning{ + ChanID: RandChannelID(t), + } + + useASCII := rapid.Bool().Draw(t, "useASCII") + if useASCII { + length := rapid.IntRange(1, 100).Draw(t, "warningDataLength") + data := make([]byte, length) + for i := 0; i < length; i++ { + data[i] = byte( + rapid.IntRange(32, 126).Draw( + t, fmt.Sprintf("warningDataByte-%d", i), + ), + ) + } + msg.Data = data + } else { + length := rapid.IntRange(1, 100).Draw(t, "warningDataLength") + msg.Data = rapid.SliceOfN(rapid.Byte(), length, length).Draw( + t, "warningData", + ) + } + + return msg +} + +// A compile time check to ensure Error implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*Error)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (c *Error) RandTestMessage(t *rapid.T) Message { + msg := &Error{ + ChanID: RandChannelID(t), + } + + useASCII := rapid.Bool().Draw(t, "useASCII") + if useASCII { + length := rapid.IntRange(1, 100).Draw(t, "errorDataLength") + data := make([]byte, length) + for i := 0; i < length; i++ { + data[i] = byte( + rapid.IntRange(32, 126).Draw( + t, fmt.Sprintf("errorDataByte-%d", i), + ), + ) + } + msg.Data = data + } else { + // Generate random binary data + length := rapid.IntRange(1, 100).Draw(t, "errorDataLength") + msg.Data = rapid.SliceOfN( + rapid.Byte(), length, length, + ).Draw(t, "errorData") + } + + return msg +} diff --git a/lnwire/test_utils.go b/lnwire/test_utils.go new file mode 100644 index 000000000..1065cbacf --- /dev/null +++ b/lnwire/test_utils.go @@ -0,0 +1,360 @@ +package lnwire + +import ( + "crypto/sha256" + "fmt" + "net" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/stretchr/testify/require" + "pgregory.net/rapid" +) + +// RandChannelUpdate generates a random ChannelUpdate message using rapid's +// generators. +func RandPartialSig(t *rapid.T) *PartialSig { + // Generate random private key bytes + sigBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "privKeyBytes") + + var s btcec.ModNScalar + s.SetByteSlice(sigBytes) + + return &PartialSig{ + Sig: s, + } +} + +// RandPartialSigWithNonce generates a random PartialSigWithNonce using rapid +// generators. +func RandPartialSigWithNonce(t *rapid.T) *PartialSigWithNonce { + sigLen := rapid.IntRange(1, 65).Draw(t, "partialSigLen") + sigBytes := rapid.SliceOfN( + rapid.Byte(), sigLen, sigLen, + ).Draw(t, "partialSig") + + sigScalar := new(btcec.ModNScalar) + sigScalar.SetByteSlice(sigBytes) + + return NewPartialSigWithNonce( + RandMusig2Nonce(t), *sigScalar, + ) +} + +// RandPubKey generates a random public key using rapid's generators. +func RandPubKey(t *rapid.T) *btcec.PublicKey { + privKeyBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw( + t, "privKeyBytes", + ) + _, pub := btcec.PrivKeyFromBytes(privKeyBytes) + + return pub +} + +// RandChannelID generates a random channel ID. +func RandChannelID(t *rapid.T) ChannelID { + var c ChannelID + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "channelID") + copy(c[:], bytes) + + return c +} + +// RandShortChannelID generates a random short channel ID. +func RandShortChannelID(t *rapid.T) ShortChannelID { + return NewShortChanIDFromInt( + uint64(rapid.IntRange(1, 100000).Draw(t, "shortChanID")), + ) +} + +// RandFeatureVector generates a random feature vector. +func RandFeatureVector(t *rapid.T) *RawFeatureVector { + featureVec := NewRawFeatureVector() + + // Add a random number of random feature bits + numFeatures := rapid.IntRange(0, 20).Draw(t, "numFeatures") + for i := 0; i < numFeatures; i++ { + bit := FeatureBit(rapid.IntRange(0, 100).Draw( + t, fmt.Sprintf("featureBit-%d", i)), + ) + featureVec.Set(bit) + } + + return featureVec +} + +// RandSignature generates a signature for testing. +func RandSignature(t *rapid.T) Sig { + testRScalar := new(btcec.ModNScalar) + testSScalar := new(btcec.ModNScalar) + + // Generate random bytes for R and S + rBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "rBytes") + sBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "sBytes") + _ = testRScalar.SetByteSlice(rBytes) + _ = testSScalar.SetByteSlice(sBytes) + + testSig := ecdsa.NewSignature(testRScalar, testSScalar) + + sig, err := NewSigFromSignature(testSig) + if err != nil { + panic(fmt.Sprintf("unable to create signature: %v", err)) + } + + return sig +} + +// RandPaymentHash generates a random payment hash. +func RandPaymentHash(t *rapid.T) [32]byte { + var hash [32]byte + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "paymentHash") + copy(hash[:], bytes) + + return hash +} + +// RandPaymentPreimage generates a random payment preimage. +func RandPaymentPreimage(t *rapid.T) [32]byte { + var preimage [32]byte + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "preimage") + copy(preimage[:], bytes) + + return preimage +} + +// RandChainHash generates a random chain hash. +func RandChainHash(t *rapid.T) chainhash.Hash { + var hash [32]byte + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash") + copy(hash[:], bytes) + + return hash +} + +// RandNodeAlias generates a random node alias. +func RandNodeAlias(t *rapid.T) NodeAlias { + var alias NodeAlias + aliasLength := rapid.IntRange(0, 32).Draw(t, "aliasLength") + + aliasBytes := rapid.StringN( + 0, aliasLength, aliasLength, + ).Draw(t, "alias") + + copy(alias[:], aliasBytes) + + return alias +} + +// RandNetAddrs generates random network addresses. +func RandNetAddrs(t *rapid.T) []net.Addr { + numAddresses := rapid.IntRange(0, 5).Draw(t, "numAddresses") + if numAddresses == 0 { + return nil + } + + addresses := make([]net.Addr, numAddresses) + for i := 0; i < numAddresses; i++ { + addressType := rapid.IntRange(0, 1).Draw( + t, fmt.Sprintf("addressType-%d", i), + ) + + switch addressType { + // IPv4. + case 0: + ipBytes := rapid.SliceOfN(rapid.Byte(), 4, 4).Draw( + t, fmt.Sprintf("ipv4-%d", i), + ) + port := rapid.IntRange(1, 65535).Draw( + t, fmt.Sprintf("port-%d", i), + ) + addresses[i] = &net.TCPAddr{ + IP: ipBytes, + Port: port, + } + + // IPv6. + case 1: + ipBytes := rapid.SliceOfN(rapid.Byte(), 16, 16).Draw( + t, fmt.Sprintf("ipv6-%d", i), + ) + port := rapid.IntRange(1, 65535).Draw( + t, fmt.Sprintf("port-%d", i), + ) + addresses[i] = &net.TCPAddr{ + IP: ipBytes, + Port: port, + } + } + } + + return addresses +} + +// RandCustomRecords generates random custom TLV records. +func RandCustomRecords(t *rapid.T, + ignoreRecords fn.Set[uint64], + custom bool) (CustomRecords, fn.Set[uint64]) { + + numRecords := rapid.IntRange(0, 5).Draw(t, "numCustomRecords") + customRecords := make(CustomRecords) + + if numRecords == 0 { + return nil, nil + } + + rangeStart := 0 + rangeStop := int(CustomTypeStart) + if custom { + rangeStart = 70_000 + rangeStop = 100_000 + } + + ignoreSet := fn.NewSet[uint64]() + for i := 0; i < numRecords; i++ { + recordType := uint64( + rapid.IntRange(rangeStart, rangeStop). + Filter(func(i int) bool { + return !ignoreRecords.Contains( + uint64(i), + ) + }). + Draw( + t, fmt.Sprintf("recordType-%d", i), + ), + ) + recordLen := rapid.IntRange(4, 64).Draw( + t, fmt.Sprintf("recordLen-%d", i), + ) + record := rapid.SliceOfN( + rapid.Byte(), recordLen, recordLen, + ).Draw(t, fmt.Sprintf("record-%d", i)) + + customRecords[recordType] = record + + ignoreSet.Add(recordType) + } + + return customRecords, ignoreSet +} + +// RandMusig2Nonce generates a random musig2 nonce. +func RandMusig2Nonce(t *rapid.T) Musig2Nonce { + var nonce Musig2Nonce + bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "nonce") + copy(nonce[:], bytes) + + return nonce +} + +// RandExtraOpaqueData generates random extra opaque data. +func RandExtraOpaqueData(t *rapid.T, + ignoreRecords fn.Set[uint64]) ExtraOpaqueData { + + // Make some random records. + cRecords, _ := RandCustomRecords(t, ignoreRecords, false) + if cRecords == nil { + return ExtraOpaqueData{} + } + + // Encode those records as opaque data. + recordBytes, err := cRecords.Serialize() + require.NoError(t, err) + + return ExtraOpaqueData(recordBytes) +} + +// RandOpaqueReason generates a random opaque reason for HTLC failures. +func RandOpaqueReason(t *rapid.T) OpaqueReason { + reasonLen := rapid.IntRange(32, 300).Draw(t, "reasonLen") + return rapid.SliceOfN(rapid.Byte(), reasonLen, reasonLen).Draw( + t, "opaqueReason", + ) +} + +// RandFailCode generates a random HTLC failure code. +func RandFailCode(t *rapid.T) FailCode { + // List of known failure codes to choose from Using only the documented + // codes. + validCodes := []FailCode{ + CodeInvalidRealm, + CodeTemporaryNodeFailure, + CodePermanentNodeFailure, + CodeRequiredNodeFeatureMissing, + CodePermanentChannelFailure, + CodeRequiredChannelFeatureMissing, + CodeUnknownNextPeer, + CodeIncorrectOrUnknownPaymentDetails, + CodeIncorrectPaymentAmount, + CodeFinalExpiryTooSoon, + CodeInvalidOnionVersion, + CodeInvalidOnionHmac, + CodeInvalidOnionKey, + CodeTemporaryChannelFailure, + CodeChannelDisabled, + CodeExpiryTooSoon, + CodeMPPTimeout, + CodeInvalidOnionPayload, + CodeFeeInsufficient, + } + + // Choose a random code from the list. + idx := rapid.IntRange(0, len(validCodes)-1).Draw(t, "failCodeIndex") + + return validCodes[idx] +} + +// RandSHA256Hash generates a random SHA256 hash. +func RandSHA256Hash(t *rapid.T) [sha256.Size]byte { + var hash [sha256.Size]byte + bytes := rapid.SliceOfN(rapid.Byte(), sha256.Size, sha256.Size).Draw( + t, "sha256Hash", + ) + copy(hash[:], bytes) + + return hash +} + +// RandDeliveryAddress generates a random delivery address (script). +func RandDeliveryAddress(t *rapid.T) DeliveryAddress { + addrLen := rapid.IntRange(1, 34).Draw(t, "addrLen") + + return rapid.SliceOfN(rapid.Byte(), addrLen, addrLen).Draw( + t, "deliveryAddress", + ) +} + +// RandChannelType generates a random channel type. +func RandChannelType(t *rapid.T) *ChannelType { + vec := RandFeatureVector(t) + chanType := ChannelType(*vec) + + return &chanType +} + +// RandLeaseExpiry generates a random lease expiry. +func RandLeaseExpiry(t *rapid.T) *LeaseExpiry { + exp := LeaseExpiry( + uint32(rapid.IntRange(1000, 1000000).Draw(t, "leaseExpiry")), + ) + + return &exp +} + +// RandOutPoint generates a random transaction outpoint. +func RandOutPoint(t *rapid.T) wire.OutPoint { + // Generate a random transaction ID + var txid chainhash.Hash + txidBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "txid") + copy(txid[:], txidBytes) + + // Generate a random output index + vout := uint32(rapid.IntRange(0, 10).Draw(t, "vout")) + + return wire.OutPoint{ + Hash: txid, + Index: vout, + } +} diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 7976a13c5..e627dbf4e 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -110,10 +110,6 @@ func NewUpdateAddHTLC() *UpdateAddHTLC { // interface. var _ Message = (*UpdateAddHTLC)(nil) -// A compile time check to ensure UpdateAddHTLC implements the lnwire.SizeableMessage -// interface. -var _ SizeableMessage = (*UpdateAddHTLC)(nil) - // Decode deserializes a serialized UpdateAddHTLC message stored in the passed // io.Reader observing the specified protocol version. // @@ -223,3 +219,7 @@ func (c *UpdateAddHTLC) TargetChanID() ChannelID { func (c *UpdateAddHTLC) SerializedSize() (uint32, error) { return MessageSerializedSize(c) } + +// A compile time check to ensure UpdateAddHTLC implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateAddHTLC)(nil) diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 397c70084..1d26444ba 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -38,8 +38,8 @@ type UpdateFailHTLC struct { // interface. var _ Message = (*UpdateFailHTLC)(nil) -// A compile time check to ensure UpdateFailHTLC implements the lnwire.SizeableMessage -// interface. +// A compile time check to ensure UpdateFailHTLC implements the +// lnwire.SizeableMessage interface. var _ SizeableMessage = (*UpdateFailHTLC)(nil) // Decode deserializes a serialized UpdateFailHTLC message stored in the passed @@ -55,8 +55,8 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { ) } -// Encode serializes the target UpdateFailHTLC into the passed io.Writer observing -// the protocol version specified. +// Encode serializes the target UpdateFailHTLC into the passed io.Writer +// observing the protocol version specified. // // This is part of the lnwire.Message interface. func (c *UpdateFailHTLC) Encode(w *bytes.Buffer, pver uint32) error { From b2f24789dc087ae00ee78e6d9542c14f1db3f98f Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 20 Mar 2025 15:02:05 -0700 Subject: [PATCH 3/4] lnwire: revamp TestLightningWireProtocol using new rapid test gen With what we added in the prior commit, we can significantly shrink the size of this test. We also make it easier to extend in the future, as this will fail if a new message is added, that doesn't have the needed methods, as long as MsgEnd is updated. --- lnwire/lnwire_test.go | 1828 ++--------------------------------------- 1 file changed, 53 insertions(+), 1775 deletions(-) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 123689902..7f068e526 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -3,31 +3,20 @@ package lnwire import ( "bytes" crand "crypto/rand" - "encoding/binary" "encoding/hex" - "fmt" - "image/color" - "io" "math" "math/rand" "net" - "reflect" "testing" - "testing/quick" "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" - "github.com/btcsuite/btcd/btcutil" - "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn/v2" - "github.com/lightningnetwork/lnd/lnwallet/chainfee" - "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tor" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "pgregory.net/rapid" ) var ( @@ -54,106 +43,6 @@ var ( const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -func randLocalNonce(r *rand.Rand) Musig2Nonce { - var nonce Musig2Nonce - _, _ = io.ReadFull(r, nonce[:]) - - return nonce -} - -func someLocalNonce[T tlv.TlvType]( - r *rand.Rand) tlv.OptionalRecordT[T, Musig2Nonce] { - - return tlv.SomeRecordT(tlv.NewRecordT[T, Musig2Nonce]( - randLocalNonce(r), - )) -} - -func randPartialSig(r *rand.Rand) (*PartialSig, error) { - var sigBytes [32]byte - if _, err := r.Read(sigBytes[:]); err != nil { - return nil, fmt.Errorf("unable to generate sig: %w", err) - } - - var s btcec.ModNScalar - s.SetByteSlice(sigBytes[:]) - - return &PartialSig{ - Sig: s, - }, nil -} - -func somePartialSig(t *testing.T, - r *rand.Rand) tlv.OptionalRecordT[PartialSigType, PartialSig] { - - sig, err := randPartialSig(r) - if err != nil { - t.Fatal(err) - } - - return tlv.SomeRecordT(tlv.NewRecordT[PartialSigType, PartialSig]( - *sig, - )) -} - -func randPartialSigWithNonce(r *rand.Rand) (*PartialSigWithNonce, error) { - var sigBytes [32]byte - if _, err := r.Read(sigBytes[:]); err != nil { - return nil, fmt.Errorf("unable to generate sig: %w", err) - } - - var s btcec.ModNScalar - s.SetByteSlice(sigBytes[:]) - - return &PartialSigWithNonce{ - PartialSig: NewPartialSig(s), - Nonce: randLocalNonce(r), - }, nil -} - -func somePartialSigWithNonce(t *testing.T, - r *rand.Rand) OptPartialSigWithNonceTLV { - - sig, err := randPartialSigWithNonce(r) - if err != nil { - t.Fatal(err) - } - - return tlv.SomeRecordT( - tlv.NewRecordT[PartialSigWithNonceType, PartialSigWithNonce]( - *sig, - ), - ) -} - -func randAlias(r *rand.Rand) NodeAlias { - var a NodeAlias - for i := range a { - a[i] = letterBytes[r.Intn(len(letterBytes))] - } - - return a -} - -func randPubKey() (*btcec.PublicKey, error) { - priv, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } - - return priv.PubKey(), nil -} - -// pubkeyFromHex parses a Bitcoin public key from a hex encoded string. -func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) { - pubKeyBytes, err := hex.DecodeString(keyHex) - if err != nil { - return nil, err - } - - return btcec.ParsePubKey(pubKeyBytes) -} - // generateRandomBytes returns a slice of n random bytes. func generateRandomBytes(n int) ([]byte, error) { b := make([]byte, n) @@ -176,140 +65,23 @@ func randRawKey(t *testing.T) [33]byte { return n } -func randDeliveryAddress(r *rand.Rand) (DeliveryAddress, error) { - // Generate size minimum one. Empty scripts should be tested specifically. - size := r.Intn(deliveryAddressMaxSize) + 1 - da := DeliveryAddress(make([]byte, size)) - - _, err := r.Read(da) - return da, err -} - -func randRawFeatureVector(r *rand.Rand) *RawFeatureVector { - featureVec := NewRawFeatureVector() - for i := 0; i < 10000; i++ { - if r.Int31n(2) == 0 { - featureVec.Set(FeatureBit(i)) - } - } - return featureVec -} - -func randTCP4Addr(r *rand.Rand) (*net.TCPAddr, error) { - var ip [4]byte - if _, err := r.Read(ip[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - addrIP := net.IP(ip[:]) - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil -} - -func randTCP6Addr(r *rand.Rand) (*net.TCPAddr, error) { - var ip [16]byte - if _, err := r.Read(ip[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - addrIP := net.IP(ip[:]) - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil -} - -func randV2OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) { - var serviceID [tor.V2DecodedLen]byte - if _, err := r.Read(serviceID[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) - onionService += tor.OnionSuffix - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil -} - -func randV3OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) { - var serviceID [tor.V3DecodedLen]byte - if _, err := r.Read(serviceID[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) - onionService += tor.OnionSuffix - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil -} - -func randOpaqueAddr(r *rand.Rand) (*OpaqueAddrs, error) { - payloadLen := r.Int63n(64) + 1 - payload := make([]byte, payloadLen) - - // The first byte is the address type. So set it to one that we - // definitely don't know about. - payload[0] = math.MaxUint8 - - // Generate random bytes for the rest of the payload. - if _, err := r.Read(payload[1:]); err != nil { - return nil, err - } - - return &OpaqueAddrs{Payload: payload}, nil -} - -func randAddrs(r *rand.Rand) ([]net.Addr, error) { - tcp4Addr, err := randTCP4Addr(r) +func randPubKey() (*btcec.PublicKey, error) { + priv, err := btcec.NewPrivateKey() if err != nil { return nil, err } - tcp6Addr, err := randTCP6Addr(r) + return priv.PubKey(), nil +} + +// pubkeyFromHex parses a Bitcoin public key from a hex encoded string. +func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) { + pubKeyBytes, err := hex.DecodeString(keyHex) if err != nil { return nil, err } - v2OnionAddr, err := randV2OnionAddr(r) - if err != nil { - return nil, err - } - - v3OnionAddr, err := randV3OnionAddr(r) - if err != nil { - return nil, err - } - - opaqueAddrs, err := randOpaqueAddr(r) - if err != nil { - return nil, err - } - - return []net.Addr{ - tcp4Addr, tcp6Addr, v2OnionAddr, v3OnionAddr, opaqueAddrs, - }, nil + return btcec.ParsePubKey(pubKeyBytes) } // TestChanUpdateChanFlags ensures that converting the ChanUpdateChanFlags and @@ -418,1558 +190,64 @@ func TestEmptyMessageUnknownType(t *testing.T) { } } -// randCustomRecords generates a random set of custom records for testing. -func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords { - var ( - customRecords = CustomRecords{} - - // We'll generate a random number of records, between 1 and 10. - numRecords = r.Intn(9) + 1 - ) - - // For each record, we'll generate a random key and value. - for i := 0; i < numRecords; i++ { - // Keys must be equal to or greater than - // MinCustomRecordsTlvType. - keyOffset := uint64(r.Intn(100)) - key := MinCustomRecordsTlvType + keyOffset - - // Values are byte slices of any length. - value := make([]byte, r.Intn(10)) - _, err := r.Read(value) - require.NoError(t, err) - - customRecords[key] = value - } - - // Validate the custom records as a sanity check. - err := customRecords.Validate() - require.NoError(t, err) - - return customRecords -} - -// TestLightningWireProtocol uses the testing/quick package to create a series -// of fuzz tests to attempt to break a primary scenario which is implemented as -// property based testing scenario. +// TestLightningWireProtocol uses the rapid property-based testing framework to +// verify that all message types can be serialized and deserialized correctly. func TestLightningWireProtocol(t *testing.T) { t.Parallel() - // mainScenario is the primary test that will programmatically be - // executed for all registered wire messages. The quick-checker within - // testing/quick will attempt to find an input to this function, s.t - // the function returns false, if so then we've found an input that - // violates our model of the system. - mainScenario := func(msg Message) bool { - // Give a new message, we'll serialize the message into a new - // bytes buffer. - var b bytes.Buffer - if _, err := WriteMessage(&b, msg, 0); err != nil { - t.Fatalf("unable to write msg: %v", err) - return false + for msgType := MessageType(0); msgType < MsgEnd; msgType++ { + // If MakeEmptyMessage returns an error, then this isn't yet a + // used message type. + if _, err := MakeEmptyMessage(msgType); err != nil { + continue } - // Next, we'll ensure that the serialized payload (subtracting - // the 2 bytes for the message type) is _below_ the specified - // max payload size for this message. - payloadLen := uint32(b.Len()) - 2 - if payloadLen > MaxMsgBody { - t.Fatalf("msg payload constraint violated: %v > %v", - payloadLen, MaxMsgBody) - return false - } + t.Run(msgType.String(), rapid.MakeCheck(func(t *rapid.T) { + // Create an empty message of the given type. + m, err := MakeEmptyMessage(msgType) - // Finally, we'll deserialize the message from the written - // buffer, and finally assert that the messages are equal. - newMsg, err := ReadMessage(&b, 0) - if err != nil { - t.Fatalf("unable to read msg: %v", err) - return false - } - if !assert.Equalf(t, msg, newMsg, "message mismatch") { - return false - } - - return true - } - - // customTypeGen is a map of functions that are able to randomly - // generate a given type. These functions are needed for types which - // are too complex for the testing/quick package to automatically - // generate. - customTypeGen := map[MessageType]func([]reflect.Value, *rand.Rand){ - MsgStfu: func(v []reflect.Value, r *rand.Rand) { - req := Stfu{} - if _, err := r.Read(req.ChanID[:]); err != nil { - t.Fatalf("unable to generate ChanID: %v", err) + // An error means this isn't a valid message type, so we + // skip it. + if err != nil { + return } - // 1/2 chance of being initiator - req.Initiator = r.Intn(2) == 1 - - // 1/2 chance additional TLV data. - if r.Intn(2) == 0 { - req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} - } - - v[0] = reflect.ValueOf(req) - }, - MsgInit: func(v []reflect.Value, r *rand.Rand) { - req := NewInitMessage( - randRawFeatureVector(r), - randRawFeatureVector(r), + // The message must support the message type interface. + testMsg, ok := m.(TestMessage) + require.True( + t, ok, "message %v doesn't support TestMessage", + msgType, ) - v[0] = reflect.ValueOf(*req) - }, - MsgOpenChannel: func(v []reflect.Value, r *rand.Rand) { - req := OpenChannel{ - FundingAmount: btcutil.Amount(r.Int63()), - PushAmount: MilliSatoshi(r.Int63()), - DustLimit: btcutil.Amount(r.Int63()), - MaxValueInFlight: MilliSatoshi(r.Int63()), - ChannelReserve: btcutil.Amount(r.Int63()), - HtlcMinimum: MilliSatoshi(r.Int31()), - FeePerKiloWeight: uint32(r.Int63()), - CsvDelay: uint16(r.Int31()), - MaxAcceptedHTLCs: uint16(r.Int31()), - ChannelFlags: FundingFlag(uint8(r.Int31())), - } - - if _, err := r.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to generate chain hash: %v", err) - return - } - - if _, err := r.Read(req.PendingChannelID[:]); err != nil { - t.Fatalf("unable to generate pending chan id: %v", err) - return - } - - var err error - req.FundingKey, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.RevocationPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.PaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.DelayedPaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.HtlcPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.FirstCommitmentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - // 1/2 chance empty TLV records. - if r.Intn(2) == 0 { - req.UpfrontShutdownScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery address: %v", err) - return - } - - req.ChannelType = new(ChannelType) - *req.ChannelType = ChannelType(*randRawFeatureVector(r)) - - req.LeaseExpiry = new(LeaseExpiry) - *req.LeaseExpiry = LeaseExpiry(1337) - - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } else { - req.UpfrontShutdownScript = []byte{} - } - - // 1/2 chance additional TLV data. - if r.Intn(2) == 0 { - req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} - } - - v[0] = reflect.ValueOf(req) - }, - MsgAcceptChannel: func(v []reflect.Value, r *rand.Rand) { - req := AcceptChannel{ - DustLimit: btcutil.Amount(r.Int63()), - MaxValueInFlight: MilliSatoshi(r.Int63()), - ChannelReserve: btcutil.Amount(r.Int63()), - MinAcceptDepth: uint32(r.Int31()), - HtlcMinimum: MilliSatoshi(r.Int31()), - CsvDelay: uint16(r.Int31()), - MaxAcceptedHTLCs: uint16(r.Int31()), - } - - if _, err := r.Read(req.PendingChannelID[:]); err != nil { - t.Fatalf("unable to generate pending chan id: %v", err) - return - } - - var err error - req.FundingKey, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.RevocationPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.PaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.DelayedPaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.HtlcPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.FirstCommitmentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - // 1/2 chance empty TLV records. - if r.Intn(2) == 0 { - req.UpfrontShutdownScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery address: %v", err) - return - } - - req.ChannelType = new(ChannelType) - *req.ChannelType = ChannelType(*randRawFeatureVector(r)) - - req.LeaseExpiry = new(LeaseExpiry) - *req.LeaseExpiry = LeaseExpiry(1337) - - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } else { - req.UpfrontShutdownScript = []byte{} - } - - // 1/2 chance additional TLV data. - if r.Intn(2) == 0 { - req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} - } - - v[0] = reflect.ValueOf(req) - }, - MsgFundingCreated: func(v []reflect.Value, r *rand.Rand) { - req := FundingCreated{ - ExtraData: make([]byte, 0), - } - - if _, err := r.Read(req.PendingChannelID[:]); err != nil { - t.Fatalf("unable to generate pending chan id: %v", err) - return - } - - if _, err := r.Read(req.FundingPoint.Hash[:]); err != nil { - t.Fatalf("unable to generate hash: %v", err) - return - } - req.FundingPoint.Index = uint32(r.Int31()) % math.MaxUint16 - - var err error - req.CommitSig, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - // 1/2 chance to attach a partial sig. - if r.Intn(2) == 0 { - req.PartialSig = somePartialSigWithNonce(t, r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgFundingSigned: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - req := FundingSigned{ - ChanID: ChannelID(c), - ExtraData: make([]byte, 0), - } - req.CommitSig, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - // 1/2 chance to attach a partial sig. - if r.Intn(2) == 0 { - req.PartialSig = somePartialSigWithNonce(t, r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelReady: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - require.NoError(t, err) - - pubKey, err := randPubKey() - require.NoError(t, err) - - req := NewChannelReady(c, pubKey) - - if r.Int31()%2 == 0 { - scid := NewShortChanIDFromInt(uint64(r.Int63())) - req.AliasScid = &scid - - //nolint:ll - req.NextLocalNonce = someLocalNonce[NonceRecordTypeT](r) - } - - if r.Int31()%2 == 0 { - nodeNonce := tlv.ZeroRecordT[ - tlv.TlvType0, Musig2Nonce, - ]() - nodeNonce.Val = randLocalNonce(r) - req.AnnouncementNodeNonce = tlv.SomeRecordT( - nodeNonce, - ) - - btcNonce := tlv.ZeroRecordT[ - tlv.TlvType2, Musig2Nonce, - ]() - btcNonce.Val = randLocalNonce(r) - req.AnnouncementBitcoinNonce = tlv.SomeRecordT( - btcNonce, - ) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgShutdown: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - shutdownAddr, err := randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - - req := Shutdown{ - ChannelID: ChannelID(c), - Address: shutdownAddr, - } - - if r.Int31()%2 == 0 { - //nolint:ll - req.ShutdownNonce = someLocalNonce[ShutdownNonceType](r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgClosingSigned: func(v []reflect.Value, r *rand.Rand) { - req := ClosingSigned{ - FeeSatoshis: btcutil.Amount(r.Int63()), - ExtraData: make([]byte, 0), - } - var err error - req.Signature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChannelID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - if r.Int31()%2 == 0 { - req.PartialSig = somePartialSig(t, r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgDynPropose: func(v []reflect.Value, r *rand.Rand) { - var dp DynPropose - rand.Read(dp.ChanID[:]) - - if rand.Uint32()%2 == 0 { - v := btcutil.Amount(rand.Uint32()) - dp.DustLimit = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := MilliSatoshi(rand.Uint32()) - dp.MaxValueInFlight = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := btcutil.Amount(rand.Uint32()) - dp.ChannelReserve = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := uint16(rand.Uint32()) - dp.CsvDelay = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := uint16(rand.Uint32()) - dp.MaxAcceptedHTLCs = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v, _ := btcec.NewPrivateKey() - dp.FundingKey = fn.Some(*v.PubKey()) - } - - if rand.Uint32()%2 == 0 { - v := ChannelType(*NewRawFeatureVector()) - dp.ChannelType = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := chainfee.SatPerKWeight(rand.Uint32()) - dp.KickoffFeerate = fn.Some(v) - } - - v[0] = reflect.ValueOf(dp) - }, - MsgDynReject: func(v []reflect.Value, r *rand.Rand) { - var dr DynReject - rand.Read(dr.ChanID[:]) - - features := NewRawFeatureVector() - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPDustLimitSatoshis)) - } - - if rand.Uint32()%2 == 0 { - features.Set( - FeatureBit(DPMaxHtlcValueInFlightMsat), - ) - } - - if rand.Uint32()%2 == 0 { - features.Set( - FeatureBit(DPChannelReserveSatoshis), - ) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPToSelfDelay)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPMaxAcceptedHtlcs)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPFundingPubkey)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPChannelType)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPKickoffFeerate)) - } - dr.UpdateRejections = *features - - v[0] = reflect.ValueOf(dr) - }, - MsgDynAck: func(v []reflect.Value, r *rand.Rand) { - var da DynAck - - rand.Read(da.ChanID[:]) - if rand.Uint32()%2 == 0 { - var nonce Musig2Nonce - rand.Read(nonce[:]) - da.LocalNonce = fn.Some(nonce) - } - - v[0] = reflect.ValueOf(da) - }, - MsgKickoffSig: func(v []reflect.Value, r *rand.Rand) { - ks := KickoffSig{ - ExtraData: make([]byte, 0), - } - - rand.Read(ks.ChanID[:]) - rand.Read(ks.Signature.bytes[:]) - - v[0] = reflect.ValueOf(ks) - }, - MsgCommitSig: func(v []reflect.Value, r *rand.Rand) { - req := NewCommitSig() - if _, err := r.Read(req.ChanID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - var err error - req.CommitSig, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - // Only create the slice if there will be any signatures - // in it to prevent false positive test failures due to - // an empty slice versus a nil slice. - numSigs := uint16(r.Int31n(500)) - if numSigs > 0 { - req.HtlcSigs = make([]Sig, numSigs) - } - for i := 0; i < int(numSigs); i++ { - req.HtlcSigs[i], err = NewSigFromSignature( - testSig, - ) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - } - - req.CustomRecords = randCustomRecords(t, r) - - // 50/50 chance to attach a partial sig. - if r.Int31()%2 == 0 { - req.PartialSig = somePartialSigWithNonce(t, r) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgRevokeAndAck: func(v []reflect.Value, r *rand.Rand) { - req := NewRevokeAndAck() - if _, err := r.Read(req.ChanID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - if _, err := r.Read(req.Revocation[:]); err != nil { - t.Fatalf("unable to generate bytes: %v", err) - return - } - var err error - req.NextRevocationKey, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - // 50/50 chance to attach a local nonce. - if r.Int31()%2 == 0 { - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgChannelAnnouncement: func(v []reflect.Value, r *rand.Rand) { - var err error - req := ChannelAnnouncement1{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - NodeID1: randRawKey(t), - NodeID2: randRawKey(t), - BitcoinKey1: randRawKey(t), - BitcoinKey2: randRawKey(t), - Features: randRawFeatureVector(r), - ExtraOpaqueData: make([]byte, 0), - } - req.NodeSig1, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - req.NodeSig2, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - req.BitcoinSig1, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - req.BitcoinSig2, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to generate chain hash: %v", err) - return - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgNodeAnnouncement: func(v []reflect.Value, r *rand.Rand) { - var err error - req := NodeAnnouncement{ - NodeID: randRawKey(t), - Features: randRawFeatureVector(r), - Timestamp: uint32(r.Int31()), - Alias: randAlias(r), - RGBColor: color.RGBA{ - R: uint8(r.Int31()), - G: uint8(r.Int31()), - B: uint8(r.Int31()), - }, - ExtraOpaqueData: make([]byte, 0), - } - req.Signature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - req.Addresses, err = randAddrs(r) - if err != nil { - t.Fatalf("unable to generate addresses: %v", err) - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelUpdate: func(v []reflect.Value, r *rand.Rand) { - var err error - - msgFlags := ChanUpdateMsgFlags(r.Int31()) - maxHtlc := MilliSatoshi(r.Int63()) - - // We make the max_htlc field zero if it is not flagged - // as being part of the ChannelUpdate, to pass - // serialization tests, as it will be ignored if the bit - // is not set. - if msgFlags&ChanUpdateRequiredMaxHtlc == 0 { - maxHtlc = 0 - } - - req := ChannelUpdate1{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - Timestamp: uint32(r.Int31()), - MessageFlags: msgFlags, - ChannelFlags: ChanUpdateChanFlags(r.Int31()), - TimeLockDelta: uint16(r.Int31()), - HtlcMinimumMsat: MilliSatoshi(r.Int63()), - HtlcMaximumMsat: maxHtlc, - BaseFee: uint32(r.Int31()), - FeeRate: uint32(r.Int31()), - ExtraOpaqueData: make([]byte, 0), - } - req.Signature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to generate chain hash: %v", err) - return - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgAnnounceSignatures: func(v []reflect.Value, r *rand.Rand) { - var err error - req := AnnounceSignatures1{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), - ExtraOpaqueData: make([]byte, 0), - } - - req.NodeSignature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - req.BitcoinSignature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChannelID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelReestablish: func(v []reflect.Value, r *rand.Rand) { - req := ChannelReestablish{ - NextLocalCommitHeight: uint64(r.Int63()), - RemoteCommitTailHeight: uint64(r.Int63()), - ExtraData: make([]byte, 0), - } - - // With a 50/50 probability, we'll include the - // additional fields so we can test our ability to - // properly parse, and write out the optional fields. - if r.Int()%2 == 0 { - _, err := r.Read(req.LastRemoteCommitSecret[:]) - if err != nil { - t.Fatalf("unable to read commit secret: %v", err) - return - } - - req.LocalUnrevokedCommitPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgGossipTimestampRange: func(v []reflect.Value, r *rand.Rand) { - req := GossipTimestampRange{ - FirstTimestamp: rand.Uint32(), - TimestampRange: rand.Uint32(), - ExtraData: make([]byte, 0), - } - - _, err := rand.Read(req.ChainHash[:]) - require.NoError(t, err) - - // Sometimes add a block range. - if r.Int31()%2 == 0 { - firstBlock := tlv.ZeroRecordT[ - tlv.TlvType2, uint32, - ]() - firstBlock.Val = rand.Uint32() - req.FirstBlockHeight = tlv.SomeRecordT( - firstBlock, - ) - - blockRange := tlv.ZeroRecordT[ - tlv.TlvType4, uint32, - ]() - blockRange.Val = rand.Uint32() - req.BlockRange = tlv.SomeRecordT(blockRange) - } - - v[0] = reflect.ValueOf(req) - }, - MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { - req := QueryShortChanIDs{ - ExtraData: make([]byte, 0), - } - - // With a 50/50 change, we'll either use zlib encoding, - // or regular encoding. - if r.Int31()%2 == 0 { - req.EncodingType = EncodingSortedZlib - } else { - req.EncodingType = EncodingSortedPlain - } - - if _, err := rand.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to read chain hash: %v", err) - return - } - - numChanIDs := rand.Int31n(5000) - for i := int32(0); i < numChanIDs; i++ { - req.ShortChanIDs = append(req.ShortChanIDs, - NewShortChanIDFromInt(uint64(r.Int63()))) - } - - v[0] = reflect.ValueOf(req) - }, - MsgReplyChannelRange: func(v []reflect.Value, r *rand.Rand) { - req := ReplyChannelRange{ - FirstBlockHeight: uint32(r.Int31()), - NumBlocks: uint32(r.Int31()), - ExtraData: make([]byte, 0), - } - - if _, err := rand.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to read chain hash: %v", err) - return - } - - req.Complete = uint8(r.Int31n(2)) - - // With a 50/50 change, we'll either use zlib encoding, - // or regular encoding. - if r.Int31()%2 == 0 { - req.EncodingType = EncodingSortedZlib - } else { - req.EncodingType = EncodingSortedPlain - } - - numChanIDs := rand.Int31n(4000) - for i := int32(0); i < numChanIDs; i++ { - req.ShortChanIDs = append(req.ShortChanIDs, - NewShortChanIDFromInt(uint64(r.Int63()))) - } - - // With a 50/50 chance, add some timestamps. - if r.Int31()%2 == 0 { - for i := int32(0); i < numChanIDs; i++ { - timestamps := ChanUpdateTimestamps{ - Timestamp1: rand.Uint32(), - Timestamp2: rand.Uint32(), - } - req.Timestamps = append( - req.Timestamps, timestamps, - ) - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgQueryChannelRange: func(v []reflect.Value, r *rand.Rand) { - req := QueryChannelRange{ - FirstBlockHeight: uint32(r.Int31()), - NumBlocks: uint32(r.Int31()), - ExtraData: make([]byte, 0), - } - - _, err := rand.Read(req.ChainHash[:]) - require.NoError(t, err) - - // With a 50/50 change, we'll set a query option. - if r.Int31()%2 == 0 { - req.QueryOptions = NewTimestampQueryOption() - } - - v[0] = reflect.ValueOf(req) - }, - MsgPing: func(v []reflect.Value, r *rand.Rand) { - // We use a special message generator here to ensure we - // don't generate ping messages that are too large, - // which'll cause the test to fail. - // - // We'll allow the test to generate padding bytes up to - // the max message limit, factoring in the 2 bytes for - // the num pong bytes and 2 bytes for encoding the - // length of the padding bytes. - paddingBytes := make([]byte, rand.Intn(MaxMsgBody-3)) - req := Ping{ - NumPongBytes: uint16(r.Intn(MaxPongBytes + 1)), - PaddingBytes: paddingBytes, - } - - v[0] = reflect.ValueOf(req) - }, - MsgClosingComplete: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", - err) - return - } - - req := ClosingComplete{ - ChannelID: ChannelID(c), - FeeSatoshis: btcutil.Amount(r.Int63()), - LockTime: uint32(r.Int63()), - ClosingSigs: ClosingSigs{}, - } - req.CloserScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - req.CloseeScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - - if r.Intn(2) == 0 { - sig := req.CloserNoClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserNoClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.NoCloserClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.NoCloserClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.CloserAndClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserAndClosee = tlv.SomeRecordT(sig) - } - - v[0] = reflect.ValueOf(req) - }, - MsgClosingSig: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - req := ClosingSig{ - ChannelID: ChannelID(c), - ClosingSigs: ClosingSigs{}, - FeeSatoshis: btcutil.Amount(r.Int63()), - LockTime: uint32(r.Int63()), - } - req.CloserScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - req.CloseeScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - - if r.Intn(2) == 0 { - sig := req.CloserNoClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserNoClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.NoCloserClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.NoCloserClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.CloserAndClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserAndClosee = tlv.SomeRecordT(sig) - } - - v[0] = reflect.ValueOf(req) - }, - MsgUpdateAddHTLC: func(v []reflect.Value, r *rand.Rand) { - req := &UpdateAddHTLC{ - ID: r.Uint64(), - Amount: MilliSatoshi(r.Uint64()), - Expiry: r.Uint32(), - } - - _, err := r.Read(req.ChanID[:]) - require.NoError(t, err) - - _, err = r.Read(req.PaymentHash[:]) - require.NoError(t, err) - - _, err = r.Read(req.OnionBlob[:]) - require.NoError(t, err) - - req.CustomRecords = randCustomRecords(t, r) - - // Generate a blinding point 50% of the time, since not - // all update adds will use route blinding. - if r.Int31()%2 == 0 { - pubkey, err := randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", - err) - - return - } - - req.BlindingPoint = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType0]( - pubkey, - ), - ) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgUpdateFulfillHTLC: func(v []reflect.Value, r *rand.Rand) { - req := &UpdateFulfillHTLC{ - ID: r.Uint64(), - } - - _, err := r.Read(req.ChanID[:]) - require.NoError(t, err) - - _, err = r.Read(req.PaymentPreimage[:]) - require.NoError(t, err) - - req.CustomRecords = randCustomRecords(t, r) - - // Generate some random TLV records 50% of the time. - if r.Int31()%2 == 0 { - req.ExtraData = []byte{ - 0x01, 0x03, 1, 2, 3, - 0x02, 0x03, 4, 5, 6, - } - } - - v[0] = reflect.ValueOf(*req) - }, - MsgAnnounceSignatures2: func(v []reflect.Value, - r *rand.Rand) { - - req := AnnounceSignatures2{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - ExtraOpaqueData: make([]byte, 0), - } - - _, err := r.Read(req.ChannelID[:]) - require.NoError(t, err) - - partialSig, err := randPartialSig(r) - require.NoError(t, err) - - req.PartialSignature = *partialSig - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelAnnouncement2: func(v []reflect.Value, r *rand.Rand) { - req := ChannelAnnouncement2{ - Signature: testSchnorrSig, - ExtraOpaqueData: make([]byte, 0), - } - - req.ShortChannelID.Val = NewShortChanIDFromInt( - uint64(r.Int63()), + // Use the RandTestMessage method to create a randomized + // message. + msg := testMsg.RandTestMessage(t) + + // Serialize the message to a buffer. + var b bytes.Buffer + writtenBytes, err := WriteMessage(&b, msg, 0) + require.NoError(t, err, "unable to write msg") + + // Check that the serialized payload is below the max + // payload size, accounting for the message type size. + payloadLen := uint32(writtenBytes) - 2 + require.LessOrEqual( + t, payloadLen, uint32(MaxMsgBody), + "msg payload constraint violated: %v > %v", + payloadLen, MaxMsgBody, ) - req.Capacity.Val = rand.Uint64() - req.Features.Val = *randRawFeatureVector(r) + // Deserialize the message from the buffer. + newMsg, err := ReadMessage(&b, 0) + require.NoError(t, err, "unable to read msg") - req.NodeID1.Val = randRawKey(t) - req.NodeID2.Val = randRawKey(t) - - // Sometimes set chain hash to bitcoin mainnet genesis - // hash. - req.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash - if r.Int31()%2 == 0 { - _, err := r.Read(req.ChainHash.Val[:]) - require.NoError(t, err) - } - - // Sometimes set the bitcoin keys. - if r.Int31()%2 == 0 { - btcKey1 := tlv.ZeroRecordT[ - tlv.TlvType12, [33]byte, - ]() - btcKey1.Val = randRawKey(t) - req.BitcoinKey1 = tlv.SomeRecordT(btcKey1) - - btcKey2 := tlv.ZeroRecordT[ - tlv.TlvType14, [33]byte, - ]() - btcKey2.Val = randRawKey(t) - req.BitcoinKey2 = tlv.SomeRecordT(btcKey2) - - // Occasionally also set the merkle root hash. - if r.Int31()%2 == 0 { - hash := tlv.ZeroRecordT[ - tlv.TlvType16, [32]byte, - ]() - - _, err := r.Read(hash.Val[:]) - require.NoError(t, err) - - req.MerkleRootHash = tlv.SomeRecordT( - hash, - ) - } - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelUpdate2: func(v []reflect.Value, r *rand.Rand) { - req := ChannelUpdate2{ - Signature: testSchnorrSig, - ExtraOpaqueData: make([]byte, 0), - } - - req.ShortChannelID.Val = NewShortChanIDFromInt( - uint64(r.Int63()), + // Verify the deserialized message matches the original. + require.Equal( + t, msg, newMsg, + "message mismatch for type %s", msgType, ) - req.BlockHeight.Val = r.Uint32() - req.HTLCMaximumMsat.Val = MilliSatoshi(r.Uint64()) - - // Sometimes set chain hash to bitcoin mainnet genesis - // hash. - req.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash - if r.Int31()%2 == 0 { - _, err := r.Read(req.ChainHash.Val[:]) - require.NoError(t, err) - } - - // Sometimes use default htlc min msat. - req.HTLCMinimumMsat.Val = defaultHtlcMinMsat - if r.Int31()%2 == 0 { - req.HTLCMinimumMsat.Val = MilliSatoshi( - r.Uint64(), - ) - } - - // Sometimes set the cltv expiry delta to the default. - req.CLTVExpiryDelta.Val = defaultCltvExpiryDelta - if r.Int31()%2 == 0 { - req.CLTVExpiryDelta.Val = uint16(r.Int31()) - } - - // Sometimes use default fee base. - req.FeeBaseMsat.Val = defaultFeeBaseMsat - if r.Int31()%2 == 0 { - req.FeeBaseMsat.Val = r.Uint32() - } - - // Sometimes use default proportional fee. - req.FeeProportionalMillionths.Val = - defaultFeeProportionalMillionths - if r.Int31()%2 == 0 { - req.FeeProportionalMillionths.Val = r.Uint32() - } - - // Alternate between the two direction possibilities. - if r.Int31()%2 == 0 { - req.SecondPeer = tlv.SomeRecordT( - tlv.ZeroRecordT[tlv.TlvType8, TrueBoolean](), //nolint:ll - ) - } - - // Sometimes set the incoming disabled flag. - if r.Int31()%2 == 0 { - req.DisabledFlags.Val |= - ChanUpdateDisableIncoming - } - - // Sometimes set the outgoing disabled flag. - if r.Int31()%2 == 0 { - req.DisabledFlags.Val |= - ChanUpdateDisableOutgoing - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - - v[0] = reflect.ValueOf(req) - }, + })) } - - // With the above types defined, we'll now generate a slice of - // scenarios to feed into quick.Check. The function scans in input - // space of the target function under test, so we'll need to create a - // series of wrapper functions to force it to iterate over the target - // types, but re-use the mainScenario defined above. - tests := []struct { - msgType MessageType - scenario interface{} - }{ - { - msgType: MsgStfu, - scenario: func(m Stfu) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgInit, - scenario: func(m Init) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgWarning, - scenario: func(m Warning) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgError, - scenario: func(m Error) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgPing, - scenario: func(m Ping) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgPong, - scenario: func(m Pong) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgOpenChannel, - scenario: func(m OpenChannel) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgAcceptChannel, - scenario: func(m AcceptChannel) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgFundingCreated, - scenario: func(m FundingCreated) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgFundingSigned, - scenario: func(m FundingSigned) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelReady, - scenario: func(m ChannelReady) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgShutdown, - scenario: func(m Shutdown) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgClosingSigned, - scenario: func(m ClosingSigned) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgDynPropose, - scenario: func(m DynPropose) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgDynReject, - scenario: func(m DynReject) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgDynAck, - scenario: func(m DynAck) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgKickoffSig, - scenario: func(m KickoffSig) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateAddHTLC, - scenario: func(m UpdateAddHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateFulfillHTLC, - scenario: func(m UpdateFulfillHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateFailHTLC, - scenario: func(m UpdateFailHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgCommitSig, - scenario: func(m CommitSig) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgRevokeAndAck, - scenario: func(m RevokeAndAck) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateFee, - scenario: func(m UpdateFee) bool { - return mainScenario(&m) - }, - }, - { - - msgType: MsgUpdateFailMalformedHTLC, - scenario: func(m UpdateFailMalformedHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelReestablish, - scenario: func(m ChannelReestablish) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelAnnouncement, - scenario: func(m ChannelAnnouncement1) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgNodeAnnouncement, - scenario: func(m NodeAnnouncement) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelUpdate, - scenario: func(m ChannelUpdate1) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgAnnounceSignatures, - scenario: func(m AnnounceSignatures1) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgGossipTimestampRange, - scenario: func(m GossipTimestampRange) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgQueryShortChanIDs, - scenario: func(m QueryShortChanIDs) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgReplyShortChanIDsEnd, - scenario: func(m ReplyShortChanIDsEnd) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgQueryChannelRange, - scenario: func(m QueryChannelRange) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgReplyChannelRange, - scenario: func(m ReplyChannelRange) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgClosingComplete, - scenario: func(m ClosingComplete) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgClosingSig, - scenario: func(m ClosingSig) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgAnnounceSignatures2, - scenario: func(m AnnounceSignatures2) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelAnnouncement2, - scenario: func(m ChannelAnnouncement2) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelUpdate2, - scenario: func(m ChannelUpdate2) bool { - return mainScenario(&m) - }, - }, - } - for _, test := range tests { - t.Run(test.msgType.String(), func(t *testing.T) { - var config *quick.Config - - // If the type defined is within the custom type gen - // map above, then we'll modify the default config to - // use this Value function that knows how to generate - // the proper types. - if valueGen, ok := customTypeGen[test.msgType]; ok { - config = &quick.Config{ - Values: valueGen, - } - } - - t.Logf("Running fuzz tests for msgType=%v", - test.msgType) - - err := quick.Check(test.scenario, config) - if err != nil { - t.Fatalf("fuzz checks for msg=%v failed: %v", - test.msgType, err) - } - }) - } - } func init() { From 05a6b6838fcd7b0e5733e29c36e6626dafcd3e13 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 19 Mar 2025 15:10:53 -0500 Subject: [PATCH 4/4] lnwire: add new TestSerializedSize method This uses all the interfaces and implementations added in the prior test. --- lnwire/pong.go | 3 +- lnwire/serialized_size_test.go | 68 ++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 lnwire/serialized_size_test.go diff --git a/lnwire/pong.go b/lnwire/pong.go index c33e904e1..b5fca24c9 100644 --- a/lnwire/pong.go +++ b/lnwire/pong.go @@ -39,7 +39,8 @@ func NewPong(pongBytes []byte) *Pong { // A compile time check to ensure Pong implements the lnwire.Message interface. var _ Message = (*Pong)(nil) -// A compile time check to ensure Pong implements the lnwire.SizeableMessage interface. +// A compile time check to ensure Pong implements the lnwire.SizeableMessage +// interface. var _ SizeableMessage = (*Pong)(nil) // Decode deserializes a serialized Pong message stored in the passed io.Reader diff --git a/lnwire/serialized_size_test.go b/lnwire/serialized_size_test.go new file mode 100644 index 000000000..cce315fa2 --- /dev/null +++ b/lnwire/serialized_size_test.go @@ -0,0 +1,68 @@ +package lnwire + +import ( + "bytes" + "math" + "testing" + + "github.com/stretchr/testify/require" + "pgregory.net/rapid" +) + +// TestSerializedSize uses property-based testing to verify that +// SerializedSize returns the correct value for randomly generated messages. +func TestSerializedSize(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + // Pick a random message type. + msgType := rapid.Custom(func(t *rapid.T) MessageType { + return MessageType( + rapid.IntRange( + 0, int(math.MaxUint16), + ).Draw(t, "msgType"), + ) + }).Draw(t, "msgType") + + // Create an empty message of the given type. + m, err := MakeEmptyMessage(msgType) + + // An error means this isn't a valid message type, so we skip + // it. + if err != nil { + return + } + + testMsg, ok := m.(TestMessage) + require.True( + t, ok, "message type %s does not "+ + "implement TestMessage", msgType, + ) + + // Use the testMsg to make a new random message. + msg := testMsg.RandTestMessage(t) + + // Type assertion to ensure the message implements + // SizeableMessage. + sizeMsg, ok := msg.(SizeableMessage) + require.True( + t, ok, "message type %s does not "+ + "implement SizeableMessage", msgType, + ) + + // Get the size using SerializedSize. + size, err := sizeMsg.SerializedSize() + require.NoError(t, err, "SerializedSize error") + + // Get the size by actually serializing the message. + var buf bytes.Buffer + writtenBytes, err := WriteMessage(&buf, msg, 0) + require.NoError(t, err, "WriteMessage error") + + // The SerializedSize should match the number of bytes written. + require.Equal(t, uint32(writtenBytes), size, + "SerializedSize = %d, actual bytes "+ + "written = %d for message type %s (populated)", + size, writtenBytes, msgType) + }) +}