From 2cf6969dbc76fb770607b0f6e604d8a5c09d5b7b Mon Sep 17 00:00:00 2001
From: yyforyongyu <yy2452@columbia.edu>
Date: Fri, 18 Jun 2021 15:15:44 +0800
Subject: [PATCH] lnwire: refactor Encode to use specific writers - III

This commit refactors the remaining usage of WriteElements. By
replacing the interface types with concrete types for the params used in
the methods, most of the encoding of the messages now takes zero heap
allocations.
---
 lnwire/channel_announcement.go       | 105 ++++++++++++++++++-------
 lnwire/channel_reestablish.go        |  30 +++++---
 lnwire/channel_update.go             | 111 +++++++++++++++++++--------
 lnwire/extra_bytes.go                |   2 +-
 lnwire/lnwire_test.go                |   2 +-
 lnwire/query_short_chan_ids.go       |  24 +++---
 lnwire/reply_channel_range.go        |  30 +++++---
 lnwire/update_add_htlc.go            |  38 ++++++---
 lnwire/update_fail_htlc.go           |  19 +++--
 lnwire/update_fail_malformed_htlc.go |  28 +++++--
 lnwire/update_fee.go                 |  14 ++--
 lnwire/update_fulfill_htlc.go        |  19 +++--
 12 files changed, 289 insertions(+), 133 deletions(-)

diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go
index 9e57d17bc..2b34c0f99 100644
--- a/lnwire/channel_announcement.go
+++ b/lnwire/channel_announcement.go
@@ -88,20 +88,51 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error {
 //
 // This is part of the lnwire.Message interface.
 func (a *ChannelAnnouncement) Encode(w *bytes.Buffer, pver uint32) error {
-	return WriteElements(w,
-		a.NodeSig1,
-		a.NodeSig2,
-		a.BitcoinSig1,
-		a.BitcoinSig2,
-		a.Features,
-		a.ChainHash[:],
-		a.ShortChannelID,
-		a.NodeID1,
-		a.NodeID2,
-		a.BitcoinKey1,
-		a.BitcoinKey2,
-		a.ExtraOpaqueData,
-	)
+	if err := WriteSig(w, a.NodeSig1); err != nil {
+		return err
+	}
+
+	if err := WriteSig(w, a.NodeSig2); err != nil {
+		return err
+	}
+
+	if err := WriteSig(w, a.BitcoinSig1); err != nil {
+		return err
+	}
+
+	if err := WriteSig(w, a.BitcoinSig2); err != nil {
+		return err
+	}
+
+	if err := WriteRawFeatureVector(w, a.Features); err != nil {
+		return err
+	}
+
+	if err := WriteBytes(w, a.ChainHash[:]); err != nil {
+		return err
+	}
+
+	if err := WriteShortChannelID(w, a.ShortChannelID); err != nil {
+		return err
+	}
+
+	if err := WriteBytes(w, a.NodeID1[:]); err != nil {
+		return err
+	}
+
+	if err := WriteBytes(w, a.NodeID2[:]); err != nil {
+		return err
+	}
+
+	if err := WriteBytes(w, a.BitcoinKey1[:]); err != nil {
+		return err
+	}
+
+	if err := WriteBytes(w, a.BitcoinKey2[:]); err != nil {
+		return err
+	}
+
+	return WriteBytes(w, a.ExtraOpaqueData)
 }
 
 // MsgType returns the integer uniquely identifying this message type on the
@@ -116,20 +147,40 @@ func (a *ChannelAnnouncement) MsgType() MessageType {
 // be signed.
 func (a *ChannelAnnouncement) DataToSign() ([]byte, error) {
 	// We should not include the signatures itself.
-	var w bytes.Buffer
-	err := WriteElements(&w,
-		a.Features,
-		a.ChainHash[:],
-		a.ShortChannelID,
-		a.NodeID1,
-		a.NodeID2,
-		a.BitcoinKey1,
-		a.BitcoinKey2,
-		a.ExtraOpaqueData,
-	)
-	if err != nil {
+	b := make([]byte, 0, MaxMsgBody)
+	buf := bytes.NewBuffer(b)
+
+	if err := WriteRawFeatureVector(buf, a.Features); err != nil {
 		return nil, err
 	}
 
-	return w.Bytes(), nil
+	if err := WriteBytes(buf, a.ChainHash[:]); err != nil {
+		return nil, err
+	}
+
+	if err := WriteShortChannelID(buf, a.ShortChannelID); err != nil {
+		return nil, err
+	}
+
+	if err := WriteBytes(buf, a.NodeID1[:]); err != nil {
+		return nil, err
+	}
+
+	if err := WriteBytes(buf, a.NodeID2[:]); err != nil {
+		return nil, err
+	}
+
+	if err := WriteBytes(buf, a.BitcoinKey1[:]); err != nil {
+		return nil, err
+	}
+
+	if err := WriteBytes(buf, a.BitcoinKey2[:]); err != nil {
+		return nil, err
+	}
+
+	if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil {
+		return nil, err
+	}
+
+	return buf.Bytes(), nil
 }
diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go
index 44cc6f3f6..0de16a365 100644
--- a/lnwire/channel_reestablish.go
+++ b/lnwire/channel_reestablish.go
@@ -77,12 +77,15 @@ var _ Message = (*ChannelReestablish)(nil)
 //
 // This is part of the lnwire.Message interface.
 func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error {
-	err := WriteElements(w,
-		a.ChanID,
-		a.NextLocalCommitHeight,
-		a.RemoteCommitTailHeight,
-	)
-	if err != nil {
+	if err := WriteChannelID(w, a.ChanID); err != nil {
+		return err
+	}
+
+	if err := WriteUint64(w, a.NextLocalCommitHeight); err != nil {
+		return err
+	}
+
+	if err := WriteUint64(w, a.RemoteCommitTailHeight); err != nil {
 		return err
 	}
 
@@ -94,15 +97,18 @@ func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error {
 		//
 		// NOTE: This is here primarily for the quickcheck tests, in
 		// practice, we'll always populate this field.
-		return WriteElements(w, a.ExtraData)
+		return WriteBytes(w, a.ExtraData)
 	}
 
 	// Otherwise, we'll write out the remaining elements.
-	return WriteElements(w,
-		a.LastRemoteCommitSecret[:],
-		a.LocalUnrevokedCommitPoint,
-		a.ExtraData,
-	)
+	if err := WriteBytes(w, a.LastRemoteCommitSecret[:]); err != nil {
+		return err
+	}
+
+	if err := WritePublicKey(w, a.LocalUnrevokedCommitPoint); err != nil {
+		return err
+	}
+	return WriteBytes(w, a.ExtraData)
 }
 
 // Decode deserializes a serialized ChannelReestablish stored in the passed
diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go
index e1bac9f96..7881f972f 100644
--- a/lnwire/channel_update.go
+++ b/lnwire/channel_update.go
@@ -160,32 +160,57 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error {
 //
 // This is part of the lnwire.Message interface.
 func (a *ChannelUpdate) Encode(w *bytes.Buffer, pver uint32) error {
-	err := WriteElements(w,
-		a.Signature,
-		a.ChainHash[:],
-		a.ShortChannelID,
-		a.Timestamp,
-		a.MessageFlags,
-		a.ChannelFlags,
-		a.TimeLockDelta,
-		a.HtlcMinimumMsat,
-		a.BaseFee,
-		a.FeeRate,
-	)
-	if err != nil {
+	if err := WriteSig(w, a.Signature); err != nil {
+		return err
+	}
+
+	if err := WriteBytes(w, a.ChainHash[:]); err != nil {
+		return err
+	}
+
+	if err := WriteShortChannelID(w, a.ShortChannelID); err != nil {
+		return err
+	}
+
+	if err := WriteUint32(w, a.Timestamp); err != nil {
+		return err
+	}
+
+	if err := WriteChanUpdateMsgFlags(w, a.MessageFlags); err != nil {
+		return err
+	}
+
+	if err := WriteChanUpdateChanFlags(w, a.ChannelFlags); err != nil {
+		return err
+	}
+
+	if err := WriteUint16(w, a.TimeLockDelta); err != nil {
+		return err
+	}
+
+	if err := WriteMilliSatoshi(w, a.HtlcMinimumMsat); err != nil {
+		return err
+	}
+
+	if err := WriteUint32(w, a.BaseFee); err != nil {
+		return err
+	}
+
+	if err := WriteUint32(w, a.FeeRate); err != nil {
 		return err
 	}
 
 	// Now append optional fields if they are set. Currently, the only
 	// optional field is max HTLC.
 	if a.MessageFlags.HasMaxHtlc() {
-		if err := WriteElements(w, a.HtlcMaximumMsat); err != nil {
+		err := WriteMilliSatoshi(w, a.HtlcMaximumMsat)
+		if err != nil {
 			return err
 		}
 	}
 
 	// Finally, append any extra opaque data.
-	return a.ExtraOpaqueData.Encode(w)
+	return WriteBytes(w, a.ExtraOpaqueData)
 }
 
 // MsgType returns the integer uniquely identifying this message type on the
@@ -199,36 +224,58 @@ func (a *ChannelUpdate) MsgType() MessageType {
 // DataToSign is used to retrieve part of the announcement message which should
 // be signed.
 func (a *ChannelUpdate) DataToSign() ([]byte, error) {
-
 	// We should not include the signatures itself.
-	var w bytes.Buffer
-	err := WriteElements(&w,
-		a.ChainHash[:],
-		a.ShortChannelID,
-		a.Timestamp,
-		a.MessageFlags,
-		a.ChannelFlags,
-		a.TimeLockDelta,
-		a.HtlcMinimumMsat,
-		a.BaseFee,
-		a.FeeRate,
-	)
-	if err != nil {
+	b := make([]byte, 0, MaxMsgBody)
+	buf := bytes.NewBuffer(b)
+	if err := WriteBytes(buf, a.ChainHash[:]); err != nil {
+		return nil, err
+	}
+
+	if err := WriteShortChannelID(buf, a.ShortChannelID); err != nil {
+		return nil, err
+	}
+
+	if err := WriteUint32(buf, a.Timestamp); err != nil {
+		return nil, err
+	}
+
+	if err := WriteChanUpdateMsgFlags(buf, a.MessageFlags); err != nil {
+		return nil, err
+	}
+
+	if err := WriteChanUpdateChanFlags(buf, a.ChannelFlags); err != nil {
+		return nil, err
+	}
+
+	if err := WriteUint16(buf, a.TimeLockDelta); err != nil {
+		return nil, err
+	}
+
+	if err := WriteMilliSatoshi(buf, a.HtlcMinimumMsat); err != nil {
+		return nil, err
+	}
+
+	if err := WriteUint32(buf, a.BaseFee); err != nil {
+		return nil, err
+	}
+
+	if err := WriteUint32(buf, a.FeeRate); err != nil {
 		return nil, err
 	}
 
 	// Now append optional fields if they are set. Currently, the only
 	// optional field is max HTLC.
 	if a.MessageFlags.HasMaxHtlc() {
-		if err := WriteElements(&w, a.HtlcMaximumMsat); err != nil {
+		err := WriteMilliSatoshi(buf, a.HtlcMaximumMsat)
+		if err != nil {
 			return nil, err
 		}
 	}
 
 	// Finally, append any extra opaque data.
-	if err := a.ExtraOpaqueData.Encode(&w); err != nil {
+	if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil {
 		return nil, err
 	}
 
-	return w.Bytes(), nil
+	return buf.Bytes(), nil
 }
diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go
index f2dbec45b..70554f4f5 100644
--- a/lnwire/extra_bytes.go
+++ b/lnwire/extra_bytes.go
@@ -18,7 +18,7 @@ type ExtraOpaqueData []byte
 // Encode attempts to encode the raw extra bytes into the passed io.Writer.
 func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error {
 	eBytes := []byte((*e)[:])
-	if err := WriteElements(w, eBytes); err != nil {
+	if err := WriteBytes(w, eBytes); err != nil {
 		return err
 	}
 
diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go
index 14082bfe9..4475b3382 100644
--- a/lnwire/lnwire_test.go
+++ b/lnwire/lnwire_test.go
@@ -232,7 +232,7 @@ func TestMaxOutPointIndex(t *testing.T) {
 	}
 
 	var b bytes.Buffer
-	if err := WriteElement(&b, op); err == nil {
+	if err := WriteOutPoint(&b, op); err == nil {
 		t.Fatalf("write of outPoint should fail, index exceeds 16-bits")
 	}
 }
diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go
index 930becdbb..323a936db 100644
--- a/lnwire/query_short_chan_ids.go
+++ b/lnwire/query_short_chan_ids.go
@@ -293,19 +293,18 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
 // This is part of the lnwire.Message interface.
 func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error {
 	// First, we'll write out the chain hash.
-	err := WriteElements(w, q.ChainHash[:])
-	if err != nil {
+	if err := WriteBytes(w, q.ChainHash[:]); err != nil {
 		return err
 	}
 
 	// Base on our encoding type, we'll write out the set of short channel
 	// ID's.
-	err = encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
+	err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
 	if err != nil {
 		return err
 	}
 
-	return q.ExtraData.Encode(w)
+	return WriteBytes(w, q.ExtraData)
 }
 
 // encodeShortChanIDs encodes the passed short channel ID's into the passed
@@ -332,20 +331,21 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding,
 		// body. We add 1 as the response will have the encoding type
 		// prepended to it.
 		numBytesBody := uint16(len(shortChanIDs)*8) + 1
-		if err := WriteElements(w, numBytesBody); err != nil {
+		if err := WriteUint16(w, numBytesBody); err != nil {
 			return err
 		}
 
 		// We'll then write out the encoding that that follows the
 		// actual encoded short channel ID's.
-		if err := WriteElements(w, encodingType); err != nil {
+		err := WriteShortChanIDEncoding(w, encodingType)
+		if err != nil {
 			return err
 		}
 
 		// Now that we know they're sorted, we can write out each short
 		// channel ID to the buffer.
 		for _, chanID := range shortChanIDs {
-			if err := WriteElements(w, chanID); err != nil {
+			if err := WriteShortChannelID(w, chanID); err != nil {
 				return fmt.Errorf("unable to write short chan "+
 					"ID: %v", err)
 			}
@@ -374,7 +374,7 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding,
 			// into the zlib writer, which will do compressing on
 			// the fly.
 			for _, chanID := range shortChanIDs {
-				err := WriteElements(&wb, chanID)
+				err := WriteShortChannelID(&wb, chanID)
 				if err != nil {
 					return fmt.Errorf(
 						"unable to write short chan "+
@@ -418,15 +418,15 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding,
 
 		// Finally, we can write out the number of bytes, the
 		// compression type, and finally the buffer itself.
-		if err := WriteElements(w, uint16(numBytesBody)); err != nil {
+		if err := WriteUint16(w, uint16(numBytesBody)); err != nil {
 			return err
 		}
-		if err := WriteElements(w, encodingType); err != nil {
+		err := WriteShortChanIDEncoding(w, encodingType)
+		if err != nil {
 			return err
 		}
 
-		_, err := w.Write(compressedPayload)
-		return err
+		return WriteBytes(w, compressedPayload)
 
 	default:
 		// If we're trying to encode with an encoding type that we
diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go
index 00bb10eb2..9dc0fca9c 100644
--- a/lnwire/reply_channel_range.go
+++ b/lnwire/reply_channel_range.go
@@ -87,22 +87,28 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error {
 //
 // This is part of the lnwire.Message interface.
 func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
-	err := WriteElements(w,
-		c.ChainHash[:],
-		c.FirstBlockHeight,
-		c.NumBlocks,
-		c.Complete,
-	)
+	if err := WriteBytes(w, c.ChainHash[:]); err != nil {
+		return err
+	}
+
+	if err := WriteUint32(w, c.FirstBlockHeight); err != nil {
+		return err
+	}
+
+	if err := WriteUint32(w, c.NumBlocks); err != nil {
+		return err
+	}
+
+	if err := WriteUint8(w, c.Complete); err != nil {
+		return err
+	}
+
+	err := encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
 	if err != nil {
 		return err
 	}
 
-	err = encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
-	if err != nil {
-		return err
-	}
-
-	return c.ExtraData.Encode(w)
+	return WriteBytes(w, c.ExtraData)
 }
 
 // MsgType returns the integer uniquely identifying this message type on the
diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go
index a7756994c..666a54942 100644
--- a/lnwire/update_add_htlc.go
+++ b/lnwire/update_add_htlc.go
@@ -85,20 +85,36 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
 	)
 }
 
-// Encode serializes the target UpdateAddHTLC into the passed io.Writer observing
-// the protocol version specified.
+// Encode serializes the target UpdateAddHTLC into the passed io.Writer
+// observing the protocol version specified.
 //
 // This is part of the lnwire.Message interface.
 func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error {
-	return WriteElements(w,
-		c.ChanID,
-		c.ID,
-		c.Amount,
-		c.PaymentHash[:],
-		c.Expiry,
-		c.OnionBlob[:],
-		c.ExtraData,
-	)
+	if err := WriteChannelID(w, c.ChanID); err != nil {
+		return err
+	}
+
+	if err := WriteUint64(w, c.ID); err != nil {
+		return err
+	}
+
+	if err := WriteMilliSatoshi(w, c.Amount); err != nil {
+		return err
+	}
+
+	if err := WriteBytes(w, c.PaymentHash[:]); err != nil {
+		return err
+	}
+
+	if err := WriteUint32(w, c.Expiry); err != nil {
+		return err
+	}
+
+	if err := WriteBytes(w, c.OnionBlob[:]); err != nil {
+		return err
+	}
+
+	return WriteBytes(w, c.ExtraData)
 }
 
 // MsgType returns the integer uniquely identifying this message type on the
diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go
index 2344b5683..61f02bac9 100644
--- a/lnwire/update_fail_htlc.go
+++ b/lnwire/update_fail_htlc.go
@@ -56,12 +56,19 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error {
 //
 // This is part of the lnwire.Message interface.
 func (c *UpdateFailHTLC) Encode(w *bytes.Buffer, pver uint32) error {
-	return WriteElements(w,
-		c.ChanID,
-		c.ID,
-		c.Reason,
-		c.ExtraData,
-	)
+	if err := WriteChannelID(w, c.ChanID); err != nil {
+		return err
+	}
+
+	if err := WriteUint64(w, c.ID); err != nil {
+		return err
+	}
+
+	if err := WriteOpaqueReason(w, c.Reason); err != nil {
+		return err
+	}
+
+	return WriteBytes(w, c.ExtraData)
 }
 
 // MsgType returns the integer uniquely identifying this message type on the
diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go
index 120f89541..f28107a9a 100644
--- a/lnwire/update_fail_malformed_htlc.go
+++ b/lnwire/update_fail_malformed_htlc.go
@@ -54,14 +54,26 @@ func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error {
 // io.Writer observing the protocol version specified.
 //
 // This is part of the lnwire.Message interface.
-func (c *UpdateFailMalformedHTLC) Encode(w *bytes.Buffer, pver uint32) error {
-	return WriteElements(w,
-		c.ChanID,
-		c.ID,
-		c.ShaOnionBlob[:],
-		c.FailureCode,
-		c.ExtraData,
-	)
+func (c *UpdateFailMalformedHTLC) Encode(w *bytes.Buffer,
+	pver uint32) error {
+
+	if err := WriteChannelID(w, c.ChanID); err != nil {
+		return err
+	}
+
+	if err := WriteUint64(w, c.ID); err != nil {
+		return err
+	}
+
+	if err := WriteBytes(w, c.ShaOnionBlob[:]); err != nil {
+		return err
+	}
+
+	if err := WriteFailCode(w, c.FailureCode); err != nil {
+		return err
+	}
+
+	return WriteBytes(w, c.ExtraData)
 }
 
 // MsgType returns the integer uniquely identifying this message type on the
diff --git a/lnwire/update_fee.go b/lnwire/update_fee.go
index a30634d79..a7026044c 100644
--- a/lnwire/update_fee.go
+++ b/lnwire/update_fee.go
@@ -53,11 +53,15 @@ func (c *UpdateFee) Decode(r io.Reader, pver uint32) error {
 //
 // This is part of the lnwire.Message interface.
 func (c *UpdateFee) Encode(w *bytes.Buffer, pver uint32) error {
-	return WriteElements(w,
-		c.ChanID,
-		c.FeePerKw,
-		c.ExtraData,
-	)
+	if err := WriteChannelID(w, c.ChanID); err != nil {
+		return err
+	}
+
+	if err := WriteUint32(w, c.FeePerKw); err != nil {
+		return err
+	}
+
+	return WriteBytes(w, c.ExtraData)
 }
 
 // MsgType returns the integer uniquely identifying this message type on the
diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go
index 4cd3a9720..275a37c87 100644
--- a/lnwire/update_fulfill_htlc.go
+++ b/lnwire/update_fulfill_htlc.go
@@ -62,12 +62,19 @@ func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error {
 //
 // This is part of the lnwire.Message interface.
 func (c *UpdateFulfillHTLC) Encode(w *bytes.Buffer, pver uint32) error {
-	return WriteElements(w,
-		c.ChanID,
-		c.ID,
-		c.PaymentPreimage[:],
-		c.ExtraData,
-	)
+	if err := WriteChannelID(w, c.ChanID); err != nil {
+		return err
+	}
+
+	if err := WriteUint64(w, c.ID); err != nil {
+		return err
+	}
+
+	if err := WriteBytes(w, c.PaymentPreimage[:]); err != nil {
+		return err
+	}
+
+	return WriteBytes(w, c.ExtraData)
 }
 
 // MsgType returns the integer uniquely identifying this message type on the