From 56a100123b543707cf40af1c3f35cb0cc17626c2 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 18 Mar 2025 18:55:16 -0500 Subject: [PATCH] lnwire: add new SerializedSize method to all wire messages This'll be useful for the bandwidth based rate limiting we'll implement in the next commit. --- lnwire/accept_channel.go | 11 +++++++++ lnwire/announcement_signatures.go | 11 +++++++++ lnwire/announcement_signatures_2.go | 11 +++++++++ lnwire/channel_announcement.go | 11 +++++++++ lnwire/channel_announcement_2.go | 11 +++++++++ lnwire/channel_ready.go | 11 +++++++++ lnwire/channel_reestablish.go | 11 +++++++++ lnwire/channel_update.go | 10 ++++++++ lnwire/channel_update_2.go | 7 ++++++ lnwire/closing_complete.go | 10 ++++++++ lnwire/closing_sig.go | 11 +++++++++ lnwire/closing_signed.go | 10 ++++++++ lnwire/commit_sig.go | 10 ++++++++ lnwire/custom.go | 11 +++++++++ lnwire/dyn_ack.go | 11 +++++++++ lnwire/dyn_propose.go | 11 +++++++++ lnwire/dyn_reject.go | 11 +++++++++ lnwire/error.go | 11 +++++++++ lnwire/funding_created.go | 11 +++++++++ lnwire/funding_signed.go | 11 +++++++++ lnwire/gossip_timestamp_range.go | 10 ++++++++ lnwire/init_message.go | 11 +++++++++ lnwire/message.go | 35 ++++++++++++++++++++++++++++ lnwire/node_announcement.go | 10 ++++++++ lnwire/open_channel.go | 11 +++++++++ lnwire/ping.go | 11 +++++++++ lnwire/pong.go | 10 ++++++++ lnwire/query_channel_range.go | 12 ++++++++++ lnwire/query_short_chan_ids.go | 11 +++++++++ lnwire/reply_channel_range.go | 10 ++++++++ lnwire/reply_short_chan_ids_end.go | 10 ++++++++ lnwire/revoke_and_ack.go | 10 ++++++++ lnwire/shutdown.go | 10 ++++++++ lnwire/stfu.go | 10 ++++++++ lnwire/update_add_htlc.go | 11 +++++++++ lnwire/update_fail_htlc.go | 11 +++++++++ lnwire/update_fail_malformed_htlc.go | 11 +++++++++ lnwire/update_fee.go | 11 +++++++++ lnwire/update_fulfill_htlc.go | 19 +++++++++++---- lnwire/warning.go | 11 +++++++++ 40 files changed, 453 insertions(+), 4 deletions(-) diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index e45bcf9ab..aace5d536 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -128,6 +128,10 @@ type AcceptChannel struct { // interface. var _ Message = (*AcceptChannel)(nil) +// A compile time check to ensure AcceptChannel implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*AcceptChannel)(nil) + // Encode serializes the target AcceptChannel into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -281,3 +285,10 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { func (a *AcceptChannel) MsgType() MessageType { return MsgAcceptChannel } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *AcceptChannel) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/announcement_signatures.go b/lnwire/announcement_signatures.go index a2bc21f39..cf8f68be5 100644 --- a/lnwire/announcement_signatures.go +++ b/lnwire/announcement_signatures.go @@ -47,6 +47,10 @@ type AnnounceSignatures1 struct { // lnwire.Message interface. var _ Message = (*AnnounceSignatures1)(nil) +// A compile time check to ensure AnnounceSignatures1 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*AnnounceSignatures1)(nil) + // A compile time check to ensure AnnounceSignatures1 implements the // lnwire.AnnounceSignatures interface. var _ AnnounceSignatures = (*AnnounceSignatures1)(nil) @@ -97,6 +101,13 @@ func (a *AnnounceSignatures1) MsgType() MessageType { return MsgAnnounceSignatures } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *AnnounceSignatures1) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} + // SCID returns the ShortChannelID of the channel. // // This is part of the lnwire.AnnounceSignatures interface. diff --git a/lnwire/announcement_signatures_2.go b/lnwire/announcement_signatures_2.go index a10447032..6e893dafd 100644 --- a/lnwire/announcement_signatures_2.go +++ b/lnwire/announcement_signatures_2.go @@ -41,6 +41,10 @@ type AnnounceSignatures2 struct { // lnwire.Message interface. var _ Message = (*AnnounceSignatures2)(nil) +// A compile time check to ensure AnnounceSignatures2 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*AnnounceSignatures2)(nil) + // Decode deserializes a serialized AnnounceSignatures2 stored in the passed // io.Reader observing the specified protocol version. // @@ -82,6 +86,13 @@ func (a *AnnounceSignatures2) MsgType() MessageType { return MsgAnnounceSignatures2 } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *AnnounceSignatures2) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} + // SCID returns the ShortChannelID of the channel. // // NOTE: this is part of the AnnounceSignatures interface. diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index ed4c5b97e..0a3989abb 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -62,6 +62,10 @@ type ChannelAnnouncement1 struct { // lnwire.Message interface. var _ Message = (*ChannelAnnouncement1)(nil) +// A compile time check to ensure ChannelAnnouncement1 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelAnnouncement1)(nil) + // Decode deserializes a serialized ChannelAnnouncement stored in the passed // io.Reader observing the specified protocol version. // @@ -143,6 +147,13 @@ func (a *ChannelAnnouncement1) MsgType() MessageType { return MsgChannelAnnouncement } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *ChannelAnnouncement1) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} + // DataToSign is used to retrieve part of the announcement message which should // be signed. func (a *ChannelAnnouncement1) DataToSign() ([]byte, error) { diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go index 074e7d084..57b3a24b8 100644 --- a/lnwire/channel_announcement_2.go +++ b/lnwire/channel_announcement_2.go @@ -194,10 +194,21 @@ func (c *ChannelAnnouncement2) MsgType() MessageType { return MsgChannelAnnouncement2 } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ChannelAnnouncement2) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // A compile time check to ensure ChannelAnnouncement2 implements the // lnwire.Message interface. var _ Message = (*ChannelAnnouncement2)(nil) +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelAnnouncement2)(nil) + // Node1KeyBytes returns the bytes representing the public key of node 1 in the // channel. // diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index 912a068bd..f388db1d1 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -63,6 +63,10 @@ func NewChannelReady(cid ChannelID, npcp *btcec.PublicKey) *ChannelReady { // interface. var _ Message = (*ChannelReady)(nil) +// A compile time check to ensure ChannelReady implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelReady)(nil) + // Decode deserializes the serialized ChannelReady message stored in the // passed io.Reader into the target ChannelReady using the deserialization // rules defined by the passed protocol version. @@ -170,3 +174,10 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { func (c *ChannelReady) MsgType() MessageType { return MsgChannelReady } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ChannelReady) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index 577379623..f26a2fc5d 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -99,6 +99,10 @@ type ChannelReestablish struct { // lnwire.Message interface. var _ Message = (*ChannelReestablish)(nil) +// A compile time check to ensure ChannelReestablish implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelReestablish)(nil) + // Encode serializes the target ChannelReestablish into the passed io.Writer // observing the protocol version specified. // @@ -234,3 +238,10 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { func (a *ChannelReestablish) MsgType() MessageType { return MsgChannelReestablish } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *ChannelReestablish) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index 33cc3bff0..1c7f0eed2 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -124,6 +124,9 @@ type ChannelUpdate1 struct { // interface. var _ Message = (*ChannelUpdate1)(nil) +// A compile time check to ensure ChannelUpdate1 implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ChannelUpdate1)(nil) + // Decode deserializes a serialized ChannelUpdate stored in the passed // io.Reader observing the specified protocol version. // @@ -367,3 +370,10 @@ func (a *ChannelUpdate1) SetSCID(scid ShortChannelID) { // A compile time assertion to ensure ChannelUpdate1 implements the // ChannelUpdate interface. var _ ChannelUpdate = (*ChannelUpdate1)(nil) + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *ChannelUpdate1) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go index b41f2f29f..56f7edf6b 100644 --- a/lnwire/channel_update_2.go +++ b/lnwire/channel_update_2.go @@ -241,6 +241,13 @@ func (c *ChannelUpdate2) MsgType() MessageType { return MsgChannelUpdate2 } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ChannelUpdate2) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + func (c *ChannelUpdate2) ExtraData() ExtraOpaqueData { return c.ExtraOpaqueData } diff --git a/lnwire/closing_complete.go b/lnwire/closing_complete.go index 4d390fd6f..14784a3c1 100644 --- a/lnwire/closing_complete.go +++ b/lnwire/closing_complete.go @@ -169,6 +169,16 @@ func (c *ClosingComplete) MsgType() MessageType { return MsgClosingComplete } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ClosingComplete) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // A compile time check to ensure ClosingComplete implements the lnwire.Message // interface. var _ Message = (*ClosingComplete)(nil) + +// A compile time check to ensure ClosingComplete implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ClosingComplete)(nil) diff --git a/lnwire/closing_sig.go b/lnwire/closing_sig.go index 2c73fa720..1eeec2580 100644 --- a/lnwire/closing_sig.go +++ b/lnwire/closing_sig.go @@ -107,6 +107,17 @@ func (c *ClosingSig) MsgType() MessageType { return MsgClosingSig } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ClosingSig) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // A compile time check to ensure ClosingSig implements the lnwire.Message // interface. var _ Message = (*ClosingSig)(nil) + +// A compile time check to ensure ClosingSig implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*ClosingSig)(nil) diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 08b5bb6a7..a82dbf402 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -59,6 +59,9 @@ func NewClosingSigned(cid ChannelID, fs btcutil.Amount, // interface. var _ Message = (*ClosingSigned)(nil) +// A compile time check to ensure ClosingSigned implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ClosingSigned)(nil) + // Decode deserializes a serialized ClosingSigned message stored in the passed // io.Reader observing the specified protocol version. // @@ -130,3 +133,10 @@ func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error { func (c *ClosingSigned) MsgType() MessageType { return MsgClosingSigned } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ClosingSigned) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 3a475e71f..7c5a41ccc 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -64,6 +64,9 @@ func NewCommitSig() *CommitSig { // interface. var _ Message = (*CommitSig)(nil) +// A compile time check to ensure CommitSig implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*CommitSig)(nil) + // Decode deserializes a serialized CommitSig message stored in the // passed io.Reader observing the specified protocol version. // @@ -151,3 +154,10 @@ func (c *CommitSig) MsgType() MessageType { func (c *CommitSig) TargetChanID() ChannelID { return c.ChanID } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *CommitSig) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/custom.go b/lnwire/custom.go index 232a8be52..e8c299297 100644 --- a/lnwire/custom.go +++ b/lnwire/custom.go @@ -73,6 +73,10 @@ type Custom struct { // interface. var _ Message = (*Custom)(nil) +// A compile time check to ensure Custom implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Custom)(nil) + // NewCustom instantiates a new custom message. func NewCustom(msgType MessageType, data []byte) (*Custom, error) { if msgType < CustomTypeStart && !IsCustomOverride(msgType) { @@ -117,3 +121,10 @@ func (c *Custom) Decode(r io.Reader, pver uint32) error { func (c *Custom) MsgType() MessageType { return c.Type } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *Custom) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/dyn_ack.go b/lnwire/dyn_ack.go index d477461e7..1cc57e955 100644 --- a/lnwire/dyn_ack.go +++ b/lnwire/dyn_ack.go @@ -41,6 +41,10 @@ type DynAck struct { // interface. var _ Message = (*DynAck)(nil) +// A compile time check to ensure DynAck implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*DynAck)(nil) + // Encode serializes the target DynAck into the passed io.Writer. Serialization // will observe the rules defined by the passed protocol version. // @@ -136,3 +140,10 @@ func (da *DynAck) Decode(r io.Reader, _ uint32) error { func (da *DynAck) MsgType() MessageType { return MsgDynAck } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (da *DynAck) SerializedSize() (uint32, error) { + return MessageSerializedSize(da) +} diff --git a/lnwire/dyn_propose.go b/lnwire/dyn_propose.go index 394fff6f3..cc19ec394 100644 --- a/lnwire/dyn_propose.go +++ b/lnwire/dyn_propose.go @@ -105,6 +105,10 @@ type DynPropose struct { // interface. var _ Message = (*DynPropose)(nil) +// A compile time check to ensure DynPropose implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*DynPropose)(nil) + // Encode serializes the target DynPropose into the passed io.Writer. // Serialization will observe the rules defined by the passed protocol version. // @@ -317,3 +321,10 @@ func (dp *DynPropose) Decode(r io.Reader, _ uint32) error { func (dp *DynPropose) MsgType() MessageType { return MsgDynPropose } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (dp *DynPropose) SerializedSize() (uint32, error) { + return MessageSerializedSize(dp) +} diff --git a/lnwire/dyn_reject.go b/lnwire/dyn_reject.go index 2c6484424..51cb8cee2 100644 --- a/lnwire/dyn_reject.go +++ b/lnwire/dyn_reject.go @@ -30,6 +30,10 @@ type DynReject struct { // interface. var _ Message = (*DynReject)(nil) +// A compile time check to ensure DynReject implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*DynReject)(nil) + // Encode serializes the target DynReject into the passed io.Writer. // Serialization will observe the rules defined by the passed protocol version. // @@ -74,3 +78,10 @@ func (dr *DynReject) Decode(r io.Reader, _ uint32) error { func (dr *DynReject) MsgType() MessageType { return MsgDynReject } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (dr *DynReject) SerializedSize() (uint32, error) { + return MessageSerializedSize(dr) +} diff --git a/lnwire/error.go b/lnwire/error.go index 120dfc950..fcc22ef24 100644 --- a/lnwire/error.go +++ b/lnwire/error.go @@ -70,6 +70,10 @@ func NewError() *Error { // interface. var _ Message = (*Error)(nil) +// A compile time check to ensure Error implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Error)(nil) + // Error returns the string representation to Error. // // NOTE: Satisfies the error interface. @@ -113,6 +117,13 @@ func (c *Error) MsgType() MessageType { return MsgError } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *Error) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // isASCII is a helper method that checks whether all bytes in `data` would be // printable ASCII characters if interpreted as a string. func isASCII(data []byte) bool { diff --git a/lnwire/funding_created.go b/lnwire/funding_created.go index 86aa0bb40..82d0ff87c 100644 --- a/lnwire/funding_created.go +++ b/lnwire/funding_created.go @@ -44,6 +44,10 @@ type FundingCreated struct { // interface. var _ Message = (*FundingCreated)(nil) +// A compile time check to ensure FundingCreated implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*FundingCreated)(nil) + // Encode serializes the target FundingCreated into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -117,3 +121,10 @@ func (f *FundingCreated) Decode(r io.Reader, pver uint32) error { func (f *FundingCreated) MsgType() MessageType { return MsgFundingCreated } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (f *FundingCreated) SerializedSize() (uint32, error) { + return MessageSerializedSize(f) +} diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index 2dd62e177..a7f23310a 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -36,6 +36,10 @@ type FundingSigned struct { // interface. var _ Message = (*FundingSigned)(nil) +// A compile time check to ensure FundingSigned implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*FundingSigned)(nil) + // Encode serializes the target FundingSigned into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -103,3 +107,10 @@ func (f *FundingSigned) Decode(r io.Reader, pver uint32) error { func (f *FundingSigned) MsgType() MessageType { return MsgFundingSigned } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (f *FundingSigned) SerializedSize() (uint32, error) { + return MessageSerializedSize(f) +} diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go index 7b628752a..25c2a033b 100644 --- a/lnwire/gossip_timestamp_range.go +++ b/lnwire/gossip_timestamp_range.go @@ -58,6 +58,9 @@ func NewGossipTimestampRange() *GossipTimestampRange { // lnwire.Message interface. var _ Message = (*GossipTimestampRange)(nil) +// A compile time check to ensure GossipTimestampRange implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*GossipTimestampRange)(nil) + // Decode deserializes a serialized GossipTimestampRange message stored in the // passed io.Reader observing the specified protocol version. // @@ -143,3 +146,10 @@ func (g *GossipTimestampRange) Encode(w *bytes.Buffer, pver uint32) error { func (g *GossipTimestampRange) MsgType() MessageType { return MsgGossipTimestampRange } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (g *GossipTimestampRange) SerializedSize() (uint32, error) { + return MessageSerializedSize(g) +} diff --git a/lnwire/init_message.go b/lnwire/init_message.go index dbbddea2b..b88891b08 100644 --- a/lnwire/init_message.go +++ b/lnwire/init_message.go @@ -43,6 +43,10 @@ func NewInitMessage(gf *RawFeatureVector, f *RawFeatureVector) *Init { // interface. var _ Message = (*Init)(nil) +// A compile time check to ensure Init implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Init)(nil) + // Decode deserializes a serialized Init message stored in the passed // io.Reader observing the specified protocol version. // @@ -78,3 +82,10 @@ func (msg *Init) Encode(w *bytes.Buffer, pver uint32) error { func (msg *Init) MsgType() MessageType { return MsgInit } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (msg *Init) SerializedSize() (uint32, error) { + return MessageSerializedSize(msg) +} diff --git a/lnwire/message.go b/lnwire/message.go index 68b09692e..8c310bd78 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -12,6 +12,10 @@ import ( "io" ) +// MessageTypeSize is the size in bytes of the message type field in the header +// of all messages. +const MessageTypeSize = 2 + // MessageType is the unique 2 byte big-endian integer that indicates the type // of message on the wire. All messages have a very simple header which // consists simply of 2-byte message type. We omit a length field, and checksum @@ -234,6 +238,31 @@ type LinkUpdater interface { TargetChanID() ChannelID } +// SizeableMessage is an interface that extends the base Message interface with +// a method to calculate the serialized size of a message. +type SizeableMessage interface { + Message + + // SerializedSize returns the serialized size of the message in bytes. + // The returned size includes the message type header bytes. + SerializedSize() (uint32, error) +} + +// MessageSerializedSize calculates the serialized size of a message in bytes. +// This is a helper function that can be used by all message types to implement +// the SerializedSize method. +func MessageSerializedSize(msg Message) (uint32, error) { + var buf bytes.Buffer + + // Encode the message to the buffer. + if err := msg.Encode(&buf, 0); err != nil { + return 0, err + } + + // Add the size of the message type. + return uint32(buf.Len()) + MessageTypeSize, nil +} + // makeEmptyMessage creates a new empty message of the proper concrete type // based on the passed message type. func makeEmptyMessage(msgType MessageType) (Message, error) { @@ -337,6 +366,12 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { return msg, nil } +// MakeEmptyMessage creates a new empty message of the proper concrete type +// based on the passed message type. This is exported to be used in tests. +func MakeEmptyMessage(msgType MessageType) (Message, error) { + return makeEmptyMessage(msgType) +} + // WriteMessage writes a lightning Message to a buffer including the necessary // header information and returns the number of bytes written. If any error is // encountered, the buffer passed will be reset to its original state since we diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index 4f1620281..ae883e886 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -104,6 +104,9 @@ type NodeAnnouncement struct { // lnwire.Message interface. var _ Message = (*NodeAnnouncement)(nil) +// A compile time check to ensure NodeAnnouncement implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*NodeAnnouncement)(nil) + // Decode deserializes a serialized NodeAnnouncement stored in the passed // io.Reader observing the specified protocol version. // @@ -202,3 +205,10 @@ func (a *NodeAnnouncement) DataToSign() ([]byte, error) { return buf.Bytes(), nil } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (a *NodeAnnouncement) SerializedSize() (uint32, error) { + return MessageSerializedSize(a) +} diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 9694290f7..217ddbea0 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -164,6 +164,10 @@ type OpenChannel struct { // interface. var _ Message = (*OpenChannel)(nil) +// A compile time check to ensure OpenChannel implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*OpenChannel)(nil) + // Encode serializes the target OpenChannel into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. @@ -335,3 +339,10 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { func (o *OpenChannel) MsgType() MessageType { return MsgOpenChannel } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (o *OpenChannel) SerializedSize() (uint32, error) { + return MessageSerializedSize(o) +} diff --git a/lnwire/ping.go b/lnwire/ping.go index a21f2fa8b..230187b84 100644 --- a/lnwire/ping.go +++ b/lnwire/ping.go @@ -33,6 +33,10 @@ func NewPing(numBytes uint16) *Ping { // A compile time check to ensure Ping implements the lnwire.Message interface. var _ Message = (*Ping)(nil) +// A compile time check to ensure Ping implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Ping)(nil) + // Decode deserializes a serialized Ping message stored in the passed io.Reader // observing the specified protocol version. // @@ -69,3 +73,10 @@ func (p *Ping) Encode(w *bytes.Buffer, pver uint32) error { func (p *Ping) MsgType() MessageType { return MsgPing } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (p *Ping) SerializedSize() (uint32, error) { + return MessageSerializedSize(p) +} diff --git a/lnwire/pong.go b/lnwire/pong.go index 3ab80d70f..c33e904e1 100644 --- a/lnwire/pong.go +++ b/lnwire/pong.go @@ -39,6 +39,9 @@ func NewPong(pongBytes []byte) *Pong { // A compile time check to ensure Pong implements the lnwire.Message interface. var _ Message = (*Pong)(nil) +// A compile time check to ensure Pong implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*Pong)(nil) + // Decode deserializes a serialized Pong message stored in the passed io.Reader // observing the specified protocol version. // @@ -64,3 +67,10 @@ func (p *Pong) Encode(w *bytes.Buffer, pver uint32) error { func (p *Pong) MsgType() MessageType { return MsgPong } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (p *Pong) SerializedSize() (uint32, error) { + return MessageSerializedSize(p) +} diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index 1e0dcb0fa..90b144bda 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -49,6 +49,10 @@ func NewQueryChannelRange() *QueryChannelRange { // lnwire.Message interface. var _ Message = (*QueryChannelRange)(nil) +// A compile time check to ensure QueryChannelRange implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*QueryChannelRange)(nil) + // Decode deserializes a serialized QueryChannelRange message stored in the // passed io.Reader observing the specified protocol version. // @@ -121,6 +125,14 @@ func (q *QueryChannelRange) MsgType() MessageType { return MsgQueryChannelRange } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (q *QueryChannelRange) SerializedSize() (uint32, error) { + msgCpy := *q + return MessageSerializedSize(&msgCpy) +} + // LastBlockHeight returns the last block height covered by the range of a // QueryChannelRange message. func (q *QueryChannelRange) LastBlockHeight() uint32 { diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index cb1bfa22c..f12d07abf 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -91,6 +91,10 @@ func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding, // lnwire.Message interface. var _ Message = (*QueryShortChanIDs)(nil) +// A compile time check to ensure QueryShortChanIDs implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*QueryShortChanIDs)(nil) + // Decode deserializes a serialized QueryShortChanIDs message stored in the // passed io.Reader observing the specified protocol version. // @@ -427,3 +431,10 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding, func (q *QueryShortChanIDs) MsgType() MessageType { return MsgQueryShortChanIDs } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (q *QueryShortChanIDs) SerializedSize() (uint32, error) { + return MessageSerializedSize(q) +} diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 591fc2bd6..a3b11c53e 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -70,6 +70,9 @@ func NewReplyChannelRange() *ReplyChannelRange { // lnwire.Message interface. var _ Message = (*ReplyChannelRange)(nil) +// A compile time check to ensure ReplyChannelRange implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ReplyChannelRange)(nil) + // Decode deserializes a serialized ReplyChannelRange message stored in the // passed io.Reader observing the specified protocol version. // @@ -223,3 +226,10 @@ func (c *ReplyChannelRange) LastBlockHeight() uint32 { } return uint32(lastBlockHeight) } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ReplyChannelRange) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/reply_short_chan_ids_end.go b/lnwire/reply_short_chan_ids_end.go index 53676f719..30660a9cf 100644 --- a/lnwire/reply_short_chan_ids_end.go +++ b/lnwire/reply_short_chan_ids_end.go @@ -39,6 +39,9 @@ func NewReplyShortChanIDsEnd() *ReplyShortChanIDsEnd { // lnwire.Message interface. var _ Message = (*ReplyShortChanIDsEnd)(nil) +// A compile time check to ensure ReplyShortChanIDsEnd implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*ReplyShortChanIDsEnd)(nil) + // Decode deserializes a serialized ReplyShortChanIDsEnd message stored in the // passed io.Reader observing the specified protocol version. // @@ -74,3 +77,10 @@ func (c *ReplyShortChanIDsEnd) Encode(w *bytes.Buffer, pver uint32) error { func (c *ReplyShortChanIDsEnd) MsgType() MessageType { return MsgReplyShortChanIDsEnd } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *ReplyShortChanIDsEnd) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index 9dca1631a..aa70b0714 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -55,6 +55,9 @@ func NewRevokeAndAck() *RevokeAndAck { // interface. var _ Message = (*RevokeAndAck)(nil) +// A compile time check to ensure RevokeAndAck implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*RevokeAndAck)(nil) + // Decode deserializes a serialized RevokeAndAck message stored in the // passed io.Reader observing the specified protocol version. // @@ -136,3 +139,10 @@ func (c *RevokeAndAck) MsgType() MessageType { func (c *RevokeAndAck) TargetChanID() ChannelID { return c.ChanID } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *RevokeAndAck) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index b9899fcfb..9ac6eb131 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -61,6 +61,9 @@ func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown { // interface. var _ Message = (*Shutdown)(nil) +// A compile-time check to ensure Shutdown implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*Shutdown)(nil) + // Decode deserializes a serialized Shutdown from the passed io.Reader, // observing the specified protocol version. // @@ -133,3 +136,10 @@ func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { func (s *Shutdown) MsgType() MessageType { return MsgShutdown } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (s *Shutdown) SerializedSize() (uint32, error) { + return MessageSerializedSize(s) +} diff --git a/lnwire/stfu.go b/lnwire/stfu.go index a052a517a..f923c94b8 100644 --- a/lnwire/stfu.go +++ b/lnwire/stfu.go @@ -24,6 +24,9 @@ type Stfu struct { // A compile time check to ensure Stfu implements the lnwire.Message interface. var _ Message = (*Stfu)(nil) +// A compile time check to ensure Stfu implements the lnwire.SizeableMessage interface. +var _ SizeableMessage = (*Stfu)(nil) + // Encode serializes the target Stfu into the passed io.Writer. // Serialization will observe the rules defined by the passed protocol version. // @@ -68,6 +71,13 @@ func (s *Stfu) MsgType() MessageType { return MsgStfu } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (s *Stfu) SerializedSize() (uint32, error) { + return MessageSerializedSize(s) +} + // A compile time check to ensure Stfu implements the // lnwire.LinkUpdater interface. var _ LinkUpdater = (*Stfu)(nil) diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 4873cd84b..7976a13c5 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -110,6 +110,10 @@ func NewUpdateAddHTLC() *UpdateAddHTLC { // interface. var _ Message = (*UpdateAddHTLC)(nil) +// A compile time check to ensure UpdateAddHTLC implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*UpdateAddHTLC)(nil) + // Decode deserializes a serialized UpdateAddHTLC message stored in the passed // io.Reader observing the specified protocol version. // @@ -212,3 +216,10 @@ func (c *UpdateAddHTLC) MsgType() MessageType { func (c *UpdateAddHTLC) TargetChanID() ChannelID { return c.ChanID } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateAddHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 8cd9c7687..397c70084 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -38,6 +38,10 @@ type UpdateFailHTLC struct { // interface. var _ Message = (*UpdateFailHTLC)(nil) +// A compile time check to ensure UpdateFailHTLC implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*UpdateFailHTLC)(nil) + // Decode deserializes a serialized UpdateFailHTLC message stored in the passed // io.Reader observing the specified protocol version. // @@ -79,6 +83,13 @@ func (c *UpdateFailHTLC) MsgType() MessageType { return MsgUpdateFailHTLC } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFailHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go index f28107a9a..bc9bd9aba 100644 --- a/lnwire/update_fail_malformed_htlc.go +++ b/lnwire/update_fail_malformed_htlc.go @@ -36,6 +36,10 @@ type UpdateFailMalformedHTLC struct { // lnwire.Message interface. var _ Message = (*UpdateFailMalformedHTLC)(nil) +// A compile time check to ensure UpdateFailMalformedHTLC implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateFailMalformedHTLC)(nil) + // Decode deserializes a serialized UpdateFailMalformedHTLC message stored in the passed // io.Reader observing the specified protocol version. // @@ -84,6 +88,13 @@ func (c *UpdateFailMalformedHTLC) MsgType() MessageType { return MsgUpdateFailMalformedHTLC } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFailMalformedHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/update_fee.go b/lnwire/update_fee.go index a7026044c..f32a7385a 100644 --- a/lnwire/update_fee.go +++ b/lnwire/update_fee.go @@ -36,6 +36,10 @@ func NewUpdateFee(chanID ChannelID, feePerKw uint32) *UpdateFee { // interface. var _ Message = (*UpdateFee)(nil) +// A compile time check to ensure UpdateFee implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateFee)(nil) + // Decode deserializes a serialized UpdateFee message stored in the passed // io.Reader observing the specified protocol version. // @@ -72,6 +76,13 @@ func (c *UpdateFee) MsgType() MessageType { return MsgUpdateFee } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFee) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go index 35aaa2ff5..9ec747406 100644 --- a/lnwire/update_fulfill_htlc.go +++ b/lnwire/update_fulfill_htlc.go @@ -44,12 +44,16 @@ func NewUpdateFulfillHTLC(chanID ChannelID, id uint64, } } -// A compile time check to ensure UpdateFulfillHTLC implements the lnwire.Message -// interface. +// A compile time check to ensure UpdateFulfillHTLC implements the +// lnwire.Message interface. var _ Message = (*UpdateFulfillHTLC)(nil) -// Decode deserializes a serialized UpdateFulfillHTLC message stored in the passed -// io.Reader observing the specified protocol version. +// A compile time check to ensure UpdateFulfillHTLC implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*UpdateFulfillHTLC)(nil) + +// Decode deserializes a serialized UpdateFulfillHTLC message stored in the +// passed io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { @@ -115,6 +119,13 @@ func (c *UpdateFulfillHTLC) MsgType() MessageType { return MsgUpdateFulfillHTLC } +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *UpdateFulfillHTLC) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +} + // TargetChanID returns the channel id of the link for which this message is // intended. // diff --git a/lnwire/warning.go b/lnwire/warning.go index adb595cc0..2a0b3df19 100644 --- a/lnwire/warning.go +++ b/lnwire/warning.go @@ -29,6 +29,10 @@ type Warning struct { // interface. var _ Message = (*Warning)(nil) +// A compile time check to ensure Warning implements the lnwire.SizeableMessage +// interface. +var _ SizeableMessage = (*Warning)(nil) + // NewWarning creates a new Warning message. func NewWarning() *Warning { return &Warning{} @@ -74,3 +78,10 @@ func (c *Warning) Encode(w *bytes.Buffer, _ uint32) error { func (c *Warning) MsgType() MessageType { return MsgWarning } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (c *Warning) SerializedSize() (uint32, error) { + return MessageSerializedSize(c) +}