diff --git a/chanbackup/single.go b/chanbackup/single.go index e51700c88..3007559e0 100644 --- a/chanbackup/single.go +++ b/chanbackup/single.go @@ -265,8 +265,15 @@ func (s *Single) Serialize(w io.Writer) error { return err } + // TODO(yy): remove the type assertion when we finished refactoring db + // into using write buffer. + buf, ok := w.(*bytes.Buffer) + if !ok { + return fmt.Errorf("expect io.Writer to be *bytes.Buffer") + } + return lnwire.WriteElements( - w, + buf, byte(s.Version), uint16(len(singleBytes.Bytes())), singleBytes.Bytes(), diff --git a/channeldb/waitingproof.go b/channeldb/waitingproof.go index b35586bf0..e8a09b758 100644 --- a/channeldb/waitingproof.go +++ b/channeldb/waitingproof.go @@ -2,6 +2,7 @@ package channeldb import ( "encoding/binary" + "fmt" "sync" "io" @@ -232,7 +233,14 @@ func (p *WaitingProof) Encode(w io.Writer) error { return err } - if err := p.AnnounceSignatures.Encode(w, 0); err != nil { + // TODO(yy): remove the type assertion when we finished refactoring db + // into using write buffer. + buf, ok := w.(*bytes.Buffer) + if !ok { + return fmt.Errorf("expect io.Writer to be *bytes.Buffer") + } + + if err := p.AnnounceSignatures.Encode(buf, 0); err != nil { return err } diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index f88a1a08f..f356e8943 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "fmt" "io" @@ -115,7 +116,7 @@ var _ Message = (*AcceptChannel)(nil) // protocol version. // // This is part of the lnwire.Message interface. -func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { +func (a *AcceptChannel) Encode(w *bytes.Buffer, pver uint32) error { // Since the upfront script is encoded as a TLV record, concatenate it // with the ExtraData, and write them as one. tlvRecords, err := packShutdownScript( diff --git a/lnwire/announcement_signatures.go b/lnwire/announcement_signatures.go index 41372db19..2550d1c1a 100644 --- a/lnwire/announcement_signatures.go +++ b/lnwire/announcement_signatures.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" ) @@ -64,7 +65,7 @@ func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (a *AnnounceSignatures) Encode(w io.Writer, pver uint32) error { +func (a *AnnounceSignatures) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, a.ChannelID, a.ShortChannelID, diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index 865b7b52c..9e57d17bc 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -87,7 +87,7 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (a *ChannelAnnouncement) Encode(w io.Writer, pver uint32) error { +func (a *ChannelAnnouncement) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, a.NodeSig1, a.NodeSig2, diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index ebf2fcc1f..44cc6f3f6 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" "github.com/btcsuite/btcd/btcec" @@ -75,7 +76,7 @@ var _ Message = (*ChannelReestablish)(nil) // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { +func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error { err := WriteElements(w, a.ChanID, a.NextLocalCommitHeight, diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index 87a7d30a1..e1bac9f96 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -159,7 +159,7 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (a *ChannelUpdate) Encode(w io.Writer, pver uint32) error { +func (a *ChannelUpdate) Encode(w *bytes.Buffer, pver uint32) error { err := WriteElements(w, a.Signature, a.ChainHash[:], diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 3e4dcdf26..66ccb7bcb 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" "github.com/btcsuite/btcutil" @@ -63,7 +64,7 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *ClosingSigned) Encode(w io.Writer, pver uint32) error { +func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements( w, c.ChannelID, c.FeeSatoshis, c.Signature, c.ExtraData, ) diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 3251fe8dc..59a785a0d 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" ) @@ -69,7 +70,7 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *CommitSig) Encode(w io.Writer, pver uint32) error { +func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, c.ChanID, c.CommitSig, diff --git a/lnwire/error.go b/lnwire/error.go index 0aa2d4c31..0b95607f2 100644 --- a/lnwire/error.go +++ b/lnwire/error.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "fmt" "io" ) @@ -103,7 +104,7 @@ func (c *Error) Decode(r io.Reader, pver uint32) error { // protocol version specified. // // This is part of the lnwire.Message interface. -func (c *Error) Encode(w io.Writer, pver uint32) error { +func (c *Error) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, c.ChanID, c.Data, diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index 22fd20bd8..f2dbec45b 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -16,7 +16,7 @@ import ( type ExtraOpaqueData []byte // Encode attempts to encode the raw extra bytes into the passed io.Writer. -func (e *ExtraOpaqueData) Encode(w io.Writer) error { +func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error { eBytes := []byte((*e)[:]) if err := WriteElements(w, eBytes); err != nil { return err diff --git a/lnwire/funding_created.go b/lnwire/funding_created.go index 0f10214ed..6ca8c1b01 100644 --- a/lnwire/funding_created.go +++ b/lnwire/funding_created.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" "github.com/btcsuite/btcd/wire" @@ -40,7 +41,7 @@ var _ Message = (*FundingCreated)(nil) // protocol version. // // This is part of the lnwire.Message interface. -func (f *FundingCreated) Encode(w io.Writer, pver uint32) error { +func (f *FundingCreated) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements( w, f.PendingChannelID[:], f.FundingPoint, f.CommitSig, f.ExtraData, diff --git a/lnwire/funding_locked.go b/lnwire/funding_locked.go index 0f8e4ca1e..7c3b2eb3a 100644 --- a/lnwire/funding_locked.go +++ b/lnwire/funding_locked.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" "github.com/btcsuite/btcd/btcec" @@ -58,7 +59,7 @@ func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { // protocol version. // // This is part of the lnwire.Message interface. -func (c *FundingLocked) Encode(w io.Writer, pver uint32) error { +func (c *FundingLocked) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, c.ChanID, c.NextPerCommitmentPoint, diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index 15c62cf05..d45fbd40e 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -1,6 +1,9 @@ package lnwire -import "io" +import ( + "bytes" + "io" +) // FundingSigned is sent from Bob (the responder) to Alice (the initiator) // after receiving the funding outpoint and her signature for Bob's version of @@ -29,7 +32,7 @@ var _ Message = (*FundingSigned)(nil) // protocol version. // // This is part of the lnwire.Message interface. -func (f *FundingSigned) Encode(w io.Writer, pver uint32) error { +func (f *FundingSigned) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, f.ChanID, f.CommitSig, f.ExtraData) } diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go index 6827b9818..e9665d157 100644 --- a/lnwire/gossip_timestamp_range.go +++ b/lnwire/gossip_timestamp_range.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -57,7 +58,7 @@ func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (g *GossipTimestampRange) Encode(w io.Writer, pver uint32) error { +func (g *GossipTimestampRange) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, g.ChainHash[:], g.FirstTimestamp, diff --git a/lnwire/init_message.go b/lnwire/init_message.go index 18fa693b4..be25642e1 100644 --- a/lnwire/init_message.go +++ b/lnwire/init_message.go @@ -1,6 +1,9 @@ package lnwire -import "io" +import ( + "bytes" + "io" +) // Init is the first message reveals the features supported or required by this // node. Nodes wait for receipt of the other's features to simplify error @@ -56,7 +59,7 @@ func (msg *Init) Decode(r io.Reader, pver uint32) error { // the protocol version specified. // // This is part of the lnwire.Message interface. -func (msg *Init) Encode(w io.Writer, pver uint32) error { +func (msg *Init) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, msg.GlobalFeatures, msg.Features, diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index c180cad38..8400c210e 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -75,13 +75,8 @@ func (a addressType) AddrLen() uint16 { } // WriteElement is a one-stop shop to write the big endian representation of -// any element which is to be serialized for the wire protocol. The passed -// io.Writer should be backed by an appropriately sized byte slice, or be able -// to dynamically expand to accommodate additional data. -// -// TODO(roasbeef): this should eventually draw from a buffer pool for -// serialization. -func WriteElement(w io.Writer, element interface{}) error { +// any element which is to be serialized for the wire protocol. +func WriteElement(w *bytes.Buffer, element interface{}) error { switch e := element.(type) { case NodeAlias: if _, err := w.Write(e[:]); err != nil { @@ -437,10 +432,10 @@ func WriteElement(w io.Writer, element interface{}) error { } // WriteElements is writes each element in the elements slice to the passed -// io.Writer using WriteElement. -func WriteElements(w io.Writer, elements ...interface{}) error { +// buffer using WriteElement. +func WriteElements(buf *bytes.Buffer, elements ...interface{}) error { for _, element := range elements { - err := WriteElement(w, element) + err := WriteElement(buf, element) if err != nil { return err } diff --git a/lnwire/message.go b/lnwire/message.go index 466dc50f2..ccd24d10b 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -158,8 +158,8 @@ type Serializable interface { Decode(io.Reader, uint32) error // Encode converts object to the bytes stream and write it into the - // writer. - Encode(io.Writer, uint32) error + // write buffer. + Encode(*bytes.Buffer, uint32) error } // Message is an interface that defines a lightning wire protocol message. The diff --git a/lnwire/message_test.go b/lnwire/message_test.go index 4626eb3e7..55e0c0153 100644 --- a/lnwire/message_test.go +++ b/lnwire/message_test.go @@ -52,7 +52,7 @@ func (m *mockMsg) Decode(r io.Reader, pver uint32) error { return args.Error(0) } -func (m *mockMsg) Encode(w io.Writer, pver uint32) error { +func (m *mockMsg) Encode(w *bytes.Buffer, pver uint32) error { args := m.Called(w, pver) return args.Error(0) } diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index 540137aae..664bda84d 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -124,7 +124,7 @@ func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { // Encode serializes the target NodeAnnouncement into the passed io.Writer // observing the protocol version specified. // -func (a *NodeAnnouncement) Encode(w io.Writer, pver uint32) error { +func (a *NodeAnnouncement) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, a.Signature, a.Features, diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index 35555e266..f9c86db61 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -419,7 +419,7 @@ func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailIncorrectDetails) Encode(w io.Writer, pver uint32) error { +func (f *FailIncorrectDetails) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, f.amount, f.height) } @@ -485,7 +485,7 @@ func (f *FailInvalidOnionVersion) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailInvalidOnionVersion) Encode(w io.Writer, pver uint32) error { +func (f *FailInvalidOnionVersion) Encode(w *bytes.Buffer, pver uint32) error { return WriteElement(w, f.OnionSHA256[:]) } @@ -519,7 +519,7 @@ func (f *FailInvalidOnionHmac) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailInvalidOnionHmac) Encode(w io.Writer, pver uint32) error { +func (f *FailInvalidOnionHmac) Encode(w *bytes.Buffer, pver uint32) error { return WriteElement(w, f.OnionSHA256[:]) } @@ -561,7 +561,7 @@ func (f *FailInvalidOnionKey) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailInvalidOnionKey) Encode(w io.Writer, pver uint32) error { +func (f *FailInvalidOnionKey) Encode(w *bytes.Buffer, pver uint32) error { return WriteElement(w, f.OnionSHA256[:]) } @@ -670,7 +670,9 @@ func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailTemporaryChannelFailure) Encode(w io.Writer, pver uint32) error { +func (f *FailTemporaryChannelFailure) Encode(w *bytes.Buffer, + pver uint32) error { + var payload []byte if f.Update != nil { var bw bytes.Buffer @@ -749,7 +751,7 @@ func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailAmountBelowMinimum) Encode(w io.Writer, pver uint32) error { +func (f *FailAmountBelowMinimum) Encode(w *bytes.Buffer, pver uint32) error { if err := WriteElement(w, f.HtlcMsat); err != nil { return err } @@ -817,7 +819,7 @@ func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailFeeInsufficient) Encode(w io.Writer, pver uint32) error { +func (f *FailFeeInsufficient) Encode(w *bytes.Buffer, pver uint32) error { if err := WriteElement(w, f.HtlcMsat); err != nil { return err } @@ -885,7 +887,7 @@ func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error { +func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error { if err := WriteElement(w, f.CltvExpiry); err != nil { return err } @@ -942,7 +944,7 @@ func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailExpiryTooSoon) Encode(w io.Writer, pver uint32) error { +func (f *FailExpiryTooSoon) Encode(w *bytes.Buffer, pver uint32) error { return writeOnionErrorChanUpdate(w, &f.Update, pver) } @@ -1006,7 +1008,7 @@ func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailChannelDisabled) Encode(w io.Writer, pver uint32) error { +func (f *FailChannelDisabled) Encode(w *bytes.Buffer, pver uint32) error { if err := WriteElement(w, f.Flags); err != nil { return err } @@ -1056,7 +1058,9 @@ func (f *FailFinalIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailFinalIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error { +func (f *FailFinalIncorrectCltvExpiry) Encode(w *bytes.Buffer, + pver uint32) error { + return WriteElement(w, f.CltvExpiry) } @@ -1102,7 +1106,9 @@ func (f *FailFinalIncorrectHtlcAmount) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *FailFinalIncorrectHtlcAmount) Encode(w io.Writer, pver uint32) error { +func (f *FailFinalIncorrectHtlcAmount) Encode(w *bytes.Buffer, + pver uint32) error { + return WriteElement(w, f.IncomingHTLCAmount) } @@ -1177,7 +1183,7 @@ func (f *InvalidOnionPayload) Decode(r io.Reader, pver uint32) error { // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. -func (f *InvalidOnionPayload) Encode(w io.Writer, pver uint32) error { +func (f *InvalidOnionPayload) Encode(w *bytes.Buffer, pver uint32) error { var buf [8]byte if err := tlv.WriteVarInt(w, f.Type, &buf); err != nil { return err @@ -1263,7 +1269,7 @@ func DecodeFailureMessage(r io.Reader, pver uint32) (FailureMessage, error) { // EncodeFailure encodes, including the necessary onion failure header // information. -func EncodeFailure(w io.Writer, failure FailureMessage, pver uint32) error { +func EncodeFailure(w *bytes.Buffer, failure FailureMessage, pver uint32) error { var failureMessageBuffer bytes.Buffer err := EncodeFailureMessage(&failureMessageBuffer, failure, pver) @@ -1293,7 +1299,9 @@ func EncodeFailure(w io.Writer, failure FailureMessage, pver uint32) error { // EncodeFailureMessage encodes just the failure message without adding a length // and padding the message for the onion protocol. -func EncodeFailureMessage(w io.Writer, failure FailureMessage, pver uint32) error { +func EncodeFailureMessage(w *bytes.Buffer, + failure FailureMessage, pver uint32) error { + // First, we'll write out the error code itself into the failure // buffer. var codeBytes [2]byte @@ -1401,7 +1409,7 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) { // writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error // format. The format is that we first write out the true serialized length of // the channel update, followed by the serialized channel update itself. -func writeOnionErrorChanUpdate(w io.Writer, chanUpdate *ChannelUpdate, +func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate *ChannelUpdate, pver uint32) error { // First, we encode the channel update in a temporary buffer in order diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 8c4c131c6..927b6ee9c 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -275,6 +275,8 @@ func (f *mockFailIncorrectDetailsNoHeight) Decode(r io.Reader, pver uint32) erro return nil } -func (f *mockFailIncorrectDetailsNoHeight) Encode(w io.Writer, pver uint32) error { +func (f *mockFailIncorrectDetailsNoHeight) Encode(w *bytes.Buffer, + pver uint32) error { + return WriteElement(w, f.amount) } diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 9407d0037..d975749d2 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" "github.com/btcsuite/btcd/btcec" @@ -150,7 +151,7 @@ var _ Message = (*OpenChannel)(nil) // protocol version. // // This is part of the lnwire.Message interface. -func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { +func (o *OpenChannel) Encode(w *bytes.Buffer, pver uint32) error { // Since the upfront script is encoded as a TLV record, concatenate it // with the ExtraData, and write them as one. tlvRecords, err := packShutdownScript( diff --git a/lnwire/ping.go b/lnwire/ping.go index e7e160d44..9218fe7f6 100644 --- a/lnwire/ping.go +++ b/lnwire/ping.go @@ -1,6 +1,9 @@ package lnwire -import "io" +import ( + "bytes" + "io" +) // PingPayload is a set of opaque bytes used to pad out a ping message. type PingPayload []byte @@ -44,7 +47,7 @@ func (p *Ping) Decode(r io.Reader, pver uint32) error { // protocol version specified. // // This is part of the lnwire.Message interface. -func (p *Ping) Encode(w io.Writer, pver uint32) error { +func (p *Ping) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, p.NumPongBytes, p.PaddingBytes) diff --git a/lnwire/pong.go b/lnwire/pong.go index b0d54523b..6c70c120f 100644 --- a/lnwire/pong.go +++ b/lnwire/pong.go @@ -1,6 +1,9 @@ package lnwire -import "io" +import ( + "bytes" + "io" +) // PongPayload is a set of opaque bytes sent in response to a ping message. type PongPayload []byte @@ -40,7 +43,7 @@ func (p *Pong) Decode(r io.Reader, pver uint32) error { // protocol version specified. // // This is part of the lnwire.Message interface. -func (p *Pong) Encode(w io.Writer, pver uint32) error { +func (p *Pong) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, p.PongBytes, ) diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index 167a4c64e..a6a20e1fd 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" "math" @@ -58,7 +59,7 @@ func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (q *QueryChannelRange) Encode(w io.Writer, pver uint32) error { +func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, q.ChainHash[:], q.FirstBlockHeight, diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 92ffaf585..930becdbb 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -291,7 +291,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { +func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error { // First, we'll write out the chain hash. err := WriteElements(w, q.ChainHash[:]) if err != nil { @@ -310,7 +310,7 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { // encodeShortChanIDs encodes the passed short channel ID's into the passed // io.Writer, respecting the specified encoding type. -func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, +func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, shortChanIDs []ShortChannelID, noSort bool) error { // For both of the current encoding types, the channel ID's are to be @@ -360,29 +360,40 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, // TODO(roasbeef): assumes the caller knows the proper chunk size to // pass to avoid bin-packing here case EncodingSortedZlib: - // We'll make a new buffer, then wrap that with a zlib writer - // so we can write directly to the buffer and encode in a - // streaming manner. - var buf bytes.Buffer - zlibWriter := zlib.NewWriter(&buf) - // If we don't have anything at all to write, then we'll write // an empty payload so we don't include things like the zlib // header when the remote party is expecting no actual short // channel IDs. var compressedPayload []byte if len(shortChanIDs) > 0 { + // We'll make a new write buffer to hold the bytes of + // shortChanIDs. + var wb bytes.Buffer + // Next, we'll write out all the channel ID's directly // into the zlib writer, which will do compressing on // the fly. for _, chanID := range shortChanIDs { - err := WriteElements(zlibWriter, chanID) + err := WriteElements(&wb, chanID) if err != nil { - return fmt.Errorf("unable to write short chan "+ - "ID: %v", err) + return fmt.Errorf( + "unable to write short chan "+ + "ID: %v", err, + ) } } + // With shortChanIDs written into wb, we'll create a + // zlib writer and write all the compressed bytes. + var zlibBuffer bytes.Buffer + zlibWriter := zlib.NewWriter(&zlibBuffer) + + if _, err := zlibWriter.Write(wb.Bytes()); err != nil { + return fmt.Errorf( + "unable to write compressed short chan"+ + "ID: %w", err) + } + // Now that we've written all the elements, we'll // ensure the compressed stream is written to the // underlying buffer. @@ -391,7 +402,7 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, "compression: %v", err) } - compressedPayload = buf.Bytes() + compressedPayload = zlibBuffer.Bytes() } // Now that we have all the items compressed, we can compute diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 2e8085c65..00bb10eb2 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" "math" @@ -85,7 +86,7 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { +func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error { err := WriteElements(w, c.ChainHash[:], c.FirstBlockHeight, diff --git a/lnwire/reply_short_chan_ids_end.go b/lnwire/reply_short_chan_ids_end.go index 2280dee4b..6d909febf 100644 --- a/lnwire/reply_short_chan_ids_end.go +++ b/lnwire/reply_short_chan_ids_end.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -54,7 +55,7 @@ func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *ReplyShortChanIDsEnd) Encode(w io.Writer, pver uint32) error { +func (c *ReplyShortChanIDsEnd) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, c.ChainHash[:], c.Complete, diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index 6562547fe..acbfcae6e 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" "github.com/btcsuite/btcd/btcec" @@ -65,7 +66,7 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *RevokeAndAck) Encode(w io.Writer, pver uint32) error { +func (c *RevokeAndAck) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, c.ChanID, c.Revocation[:], diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index 4eec13523..e7b26283b 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" ) @@ -46,7 +47,7 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { // the protocol version specified. // // This is part of the lnwire.Message interface. -func (s *Shutdown) Encode(w io.Writer, pver uint32) error { +func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, s.ChannelID, s.Address, s.ExtraData) } diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 35fc5fc36..a7756994c 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" ) @@ -88,7 +89,7 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { // the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *UpdateAddHTLC) Encode(w io.Writer, pver uint32) error { +func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, c.ChanID, c.ID, diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 8706c1dc8..2344b5683 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" ) @@ -54,7 +55,7 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { // the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *UpdateFailHTLC) Encode(w io.Writer, pver uint32) error { +func (c *UpdateFailHTLC) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, c.ChanID, c.ID, diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go index e7994fccc..120f89541 100644 --- a/lnwire/update_fail_malformed_htlc.go +++ b/lnwire/update_fail_malformed_htlc.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "crypto/sha256" "io" ) @@ -53,7 +54,7 @@ func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error { // io.Writer observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *UpdateFailMalformedHTLC) Encode(w io.Writer, pver uint32) error { +func (c *UpdateFailMalformedHTLC) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, c.ChanID, c.ID, diff --git a/lnwire/update_fee.go b/lnwire/update_fee.go index 375b5c671..a30634d79 100644 --- a/lnwire/update_fee.go +++ b/lnwire/update_fee.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" ) @@ -51,7 +52,7 @@ func (c *UpdateFee) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *UpdateFee) Encode(w io.Writer, pver uint32) error { +func (c *UpdateFee) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, c.ChanID, c.FeePerKw, diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go index 2b7f2e494..4cd3a9720 100644 --- a/lnwire/update_fulfill_htlc.go +++ b/lnwire/update_fulfill_htlc.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "io" ) @@ -60,7 +61,7 @@ func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *UpdateFulfillHTLC) Encode(w io.Writer, pver uint32) error { +func (c *UpdateFulfillHTLC) Encode(w *bytes.Buffer, pver uint32) error { return WriteElements(w, c.ChanID, c.ID,