lnwire: refactor Encode to use specific writers - III

This commit refactors the remaining usage of WriteElements. By
replacing the interface types with concrete types for the params used in
the methods, most of the encoding of the messages now takes zero heap
allocations.
This commit is contained in:
yyforyongyu
2021-06-18 15:15:44 +08:00
parent c1ad9cc60f
commit 2cf6969dbc
12 changed files with 289 additions and 133 deletions

View File

@@ -88,20 +88,51 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement) Encode(w *bytes.Buffer, pver uint32) error { func (a *ChannelAnnouncement) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, if err := WriteSig(w, a.NodeSig1); err != nil {
a.NodeSig1, return err
a.NodeSig2, }
a.BitcoinSig1,
a.BitcoinSig2, if err := WriteSig(w, a.NodeSig2); err != nil {
a.Features, return err
a.ChainHash[:], }
a.ShortChannelID,
a.NodeID1, if err := WriteSig(w, a.BitcoinSig1); err != nil {
a.NodeID2, return err
a.BitcoinKey1, }
a.BitcoinKey2,
a.ExtraOpaqueData, if err := WriteSig(w, a.BitcoinSig2); err != nil {
) return err
}
if err := WriteRawFeatureVector(w, a.Features); err != nil {
return err
}
if err := WriteBytes(w, a.ChainHash[:]); err != nil {
return err
}
if err := WriteShortChannelID(w, a.ShortChannelID); err != nil {
return err
}
if err := WriteBytes(w, a.NodeID1[:]); err != nil {
return err
}
if err := WriteBytes(w, a.NodeID2[:]); err != nil {
return err
}
if err := WriteBytes(w, a.BitcoinKey1[:]); err != nil {
return err
}
if err := WriteBytes(w, a.BitcoinKey2[:]); err != nil {
return err
}
return WriteBytes(w, a.ExtraOpaqueData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the
@@ -116,20 +147,40 @@ func (a *ChannelAnnouncement) MsgType() MessageType {
// be signed. // be signed.
func (a *ChannelAnnouncement) DataToSign() ([]byte, error) { func (a *ChannelAnnouncement) DataToSign() ([]byte, error) {
// We should not include the signatures itself. // We should not include the signatures itself.
var w bytes.Buffer b := make([]byte, 0, MaxMsgBody)
err := WriteElements(&w, buf := bytes.NewBuffer(b)
a.Features,
a.ChainHash[:], if err := WriteRawFeatureVector(buf, a.Features); err != nil {
a.ShortChannelID,
a.NodeID1,
a.NodeID2,
a.BitcoinKey1,
a.BitcoinKey2,
a.ExtraOpaqueData,
)
if err != nil {
return nil, err return nil, err
} }
return w.Bytes(), nil if err := WriteBytes(buf, a.ChainHash[:]); err != nil {
return nil, err
}
if err := WriteShortChannelID(buf, a.ShortChannelID); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.NodeID1[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.NodeID2[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.BitcoinKey1[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.BitcoinKey2[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil {
return nil, err
}
return buf.Bytes(), nil
} }

View File

@@ -77,12 +77,15 @@ var _ Message = (*ChannelReestablish)(nil)
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error { func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error {
err := WriteElements(w, if err := WriteChannelID(w, a.ChanID); err != nil {
a.ChanID, return err
a.NextLocalCommitHeight, }
a.RemoteCommitTailHeight,
) if err := WriteUint64(w, a.NextLocalCommitHeight); err != nil {
if err != nil { return err
}
if err := WriteUint64(w, a.RemoteCommitTailHeight); err != nil {
return err return err
} }
@@ -94,15 +97,18 @@ func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error {
// //
// NOTE: This is here primarily for the quickcheck tests, in // NOTE: This is here primarily for the quickcheck tests, in
// practice, we'll always populate this field. // practice, we'll always populate this field.
return WriteElements(w, a.ExtraData) return WriteBytes(w, a.ExtraData)
} }
// Otherwise, we'll write out the remaining elements. // Otherwise, we'll write out the remaining elements.
return WriteElements(w, if err := WriteBytes(w, a.LastRemoteCommitSecret[:]); err != nil {
a.LastRemoteCommitSecret[:], return err
a.LocalUnrevokedCommitPoint, }
a.ExtraData,
) if err := WritePublicKey(w, a.LocalUnrevokedCommitPoint); err != nil {
return err
}
return WriteBytes(w, a.ExtraData)
} }
// Decode deserializes a serialized ChannelReestablish stored in the passed // Decode deserializes a serialized ChannelReestablish stored in the passed

View File

@@ -160,32 +160,57 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (a *ChannelUpdate) Encode(w *bytes.Buffer, pver uint32) error { func (a *ChannelUpdate) Encode(w *bytes.Buffer, pver uint32) error {
err := WriteElements(w, if err := WriteSig(w, a.Signature); err != nil {
a.Signature, return err
a.ChainHash[:], }
a.ShortChannelID,
a.Timestamp, if err := WriteBytes(w, a.ChainHash[:]); err != nil {
a.MessageFlags, return err
a.ChannelFlags, }
a.TimeLockDelta,
a.HtlcMinimumMsat, if err := WriteShortChannelID(w, a.ShortChannelID); err != nil {
a.BaseFee, return err
a.FeeRate, }
)
if err != nil { if err := WriteUint32(w, a.Timestamp); err != nil {
return err
}
if err := WriteChanUpdateMsgFlags(w, a.MessageFlags); err != nil {
return err
}
if err := WriteChanUpdateChanFlags(w, a.ChannelFlags); err != nil {
return err
}
if err := WriteUint16(w, a.TimeLockDelta); err != nil {
return err
}
if err := WriteMilliSatoshi(w, a.HtlcMinimumMsat); err != nil {
return err
}
if err := WriteUint32(w, a.BaseFee); err != nil {
return err
}
if err := WriteUint32(w, a.FeeRate); err != nil {
return err return err
} }
// Now append optional fields if they are set. Currently, the only // Now append optional fields if they are set. Currently, the only
// optional field is max HTLC. // optional field is max HTLC.
if a.MessageFlags.HasMaxHtlc() { if a.MessageFlags.HasMaxHtlc() {
if err := WriteElements(w, a.HtlcMaximumMsat); err != nil { err := WriteMilliSatoshi(w, a.HtlcMaximumMsat)
if err != nil {
return err return err
} }
} }
// Finally, append any extra opaque data. // Finally, append any extra opaque data.
return a.ExtraOpaqueData.Encode(w) return WriteBytes(w, a.ExtraOpaqueData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the
@@ -199,36 +224,58 @@ func (a *ChannelUpdate) MsgType() MessageType {
// DataToSign is used to retrieve part of the announcement message which should // DataToSign is used to retrieve part of the announcement message which should
// be signed. // be signed.
func (a *ChannelUpdate) DataToSign() ([]byte, error) { func (a *ChannelUpdate) DataToSign() ([]byte, error) {
// We should not include the signatures itself. // We should not include the signatures itself.
var w bytes.Buffer b := make([]byte, 0, MaxMsgBody)
err := WriteElements(&w, buf := bytes.NewBuffer(b)
a.ChainHash[:], if err := WriteBytes(buf, a.ChainHash[:]); err != nil {
a.ShortChannelID, return nil, err
a.Timestamp, }
a.MessageFlags,
a.ChannelFlags, if err := WriteShortChannelID(buf, a.ShortChannelID); err != nil {
a.TimeLockDelta, return nil, err
a.HtlcMinimumMsat, }
a.BaseFee,
a.FeeRate, if err := WriteUint32(buf, a.Timestamp); err != nil {
) return nil, err
if err != nil { }
if err := WriteChanUpdateMsgFlags(buf, a.MessageFlags); err != nil {
return nil, err
}
if err := WriteChanUpdateChanFlags(buf, a.ChannelFlags); err != nil {
return nil, err
}
if err := WriteUint16(buf, a.TimeLockDelta); err != nil {
return nil, err
}
if err := WriteMilliSatoshi(buf, a.HtlcMinimumMsat); err != nil {
return nil, err
}
if err := WriteUint32(buf, a.BaseFee); err != nil {
return nil, err
}
if err := WriteUint32(buf, a.FeeRate); err != nil {
return nil, err return nil, err
} }
// Now append optional fields if they are set. Currently, the only // Now append optional fields if they are set. Currently, the only
// optional field is max HTLC. // optional field is max HTLC.
if a.MessageFlags.HasMaxHtlc() { if a.MessageFlags.HasMaxHtlc() {
if err := WriteElements(&w, a.HtlcMaximumMsat); err != nil { err := WriteMilliSatoshi(buf, a.HtlcMaximumMsat)
if err != nil {
return nil, err return nil, err
} }
} }
// Finally, append any extra opaque data. // Finally, append any extra opaque data.
if err := a.ExtraOpaqueData.Encode(&w); err != nil { if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil {
return nil, err return nil, err
} }
return w.Bytes(), nil return buf.Bytes(), nil
} }

View File

@@ -18,7 +18,7 @@ type ExtraOpaqueData []byte
// Encode attempts to encode the raw extra bytes into the passed io.Writer. // Encode attempts to encode the raw extra bytes into the passed io.Writer.
func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error { func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error {
eBytes := []byte((*e)[:]) eBytes := []byte((*e)[:])
if err := WriteElements(w, eBytes); err != nil { if err := WriteBytes(w, eBytes); err != nil {
return err return err
} }

View File

@@ -232,7 +232,7 @@ func TestMaxOutPointIndex(t *testing.T) {
} }
var b bytes.Buffer var b bytes.Buffer
if err := WriteElement(&b, op); err == nil { if err := WriteOutPoint(&b, op); err == nil {
t.Fatalf("write of outPoint should fail, index exceeds 16-bits") t.Fatalf("write of outPoint should fail, index exceeds 16-bits")
} }
} }

View File

@@ -293,19 +293,18 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error { func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error {
// First, we'll write out the chain hash. // First, we'll write out the chain hash.
err := WriteElements(w, q.ChainHash[:]) if err := WriteBytes(w, q.ChainHash[:]); err != nil {
if err != nil {
return err return err
} }
// Base on our encoding type, we'll write out the set of short channel // Base on our encoding type, we'll write out the set of short channel
// ID's. // ID's.
err = encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
if err != nil { if err != nil {
return err return err
} }
return q.ExtraData.Encode(w) return WriteBytes(w, q.ExtraData)
} }
// encodeShortChanIDs encodes the passed short channel ID's into the passed // encodeShortChanIDs encodes the passed short channel ID's into the passed
@@ -332,20 +331,21 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding,
// body. We add 1 as the response will have the encoding type // body. We add 1 as the response will have the encoding type
// prepended to it. // prepended to it.
numBytesBody := uint16(len(shortChanIDs)*8) + 1 numBytesBody := uint16(len(shortChanIDs)*8) + 1
if err := WriteElements(w, numBytesBody); err != nil { if err := WriteUint16(w, numBytesBody); err != nil {
return err return err
} }
// We'll then write out the encoding that that follows the // We'll then write out the encoding that that follows the
// actual encoded short channel ID's. // actual encoded short channel ID's.
if err := WriteElements(w, encodingType); err != nil { err := WriteShortChanIDEncoding(w, encodingType)
if err != nil {
return err return err
} }
// Now that we know they're sorted, we can write out each short // Now that we know they're sorted, we can write out each short
// channel ID to the buffer. // channel ID to the buffer.
for _, chanID := range shortChanIDs { for _, chanID := range shortChanIDs {
if err := WriteElements(w, chanID); err != nil { if err := WriteShortChannelID(w, chanID); err != nil {
return fmt.Errorf("unable to write short chan "+ return fmt.Errorf("unable to write short chan "+
"ID: %v", err) "ID: %v", err)
} }
@@ -374,7 +374,7 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding,
// into the zlib writer, which will do compressing on // into the zlib writer, which will do compressing on
// the fly. // the fly.
for _, chanID := range shortChanIDs { for _, chanID := range shortChanIDs {
err := WriteElements(&wb, chanID) err := WriteShortChannelID(&wb, chanID)
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf(
"unable to write short chan "+ "unable to write short chan "+
@@ -418,15 +418,15 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding,
// Finally, we can write out the number of bytes, the // Finally, we can write out the number of bytes, the
// compression type, and finally the buffer itself. // compression type, and finally the buffer itself.
if err := WriteElements(w, uint16(numBytesBody)); err != nil { if err := WriteUint16(w, uint16(numBytesBody)); err != nil {
return err return err
} }
if err := WriteElements(w, encodingType); err != nil { err := WriteShortChanIDEncoding(w, encodingType)
if err != nil {
return err return err
} }
_, err := w.Write(compressedPayload) return WriteBytes(w, compressedPayload)
return err
default: default:
// If we're trying to encode with an encoding type that we // If we're trying to encode with an encoding type that we

View File

@@ -87,22 +87,28 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error { func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
err := WriteElements(w, if err := WriteBytes(w, c.ChainHash[:]); err != nil {
c.ChainHash[:], return err
c.FirstBlockHeight, }
c.NumBlocks,
c.Complete, if err := WriteUint32(w, c.FirstBlockHeight); err != nil {
) return err
}
if err := WriteUint32(w, c.NumBlocks); err != nil {
return err
}
if err := WriteUint8(w, c.Complete); err != nil {
return err
}
err := encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
if err != nil { if err != nil {
return err return err
} }
err = encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) return WriteBytes(w, c.ExtraData)
if err != nil {
return err
}
return c.ExtraData.Encode(w)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

View File

@@ -85,20 +85,36 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
) )
} }
// Encode serializes the target UpdateAddHTLC into the passed io.Writer observing // Encode serializes the target UpdateAddHTLC into the passed io.Writer
// the protocol version specified. // observing the protocol version specified.
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error { func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, if err := WriteChannelID(w, c.ChanID); err != nil {
c.ChanID, return err
c.ID, }
c.Amount,
c.PaymentHash[:], if err := WriteUint64(w, c.ID); err != nil {
c.Expiry, return err
c.OnionBlob[:], }
c.ExtraData,
) if err := WriteMilliSatoshi(w, c.Amount); err != nil {
return err
}
if err := WriteBytes(w, c.PaymentHash[:]); err != nil {
return err
}
if err := WriteUint32(w, c.Expiry); err != nil {
return err
}
if err := WriteBytes(w, c.OnionBlob[:]); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

View File

@@ -56,12 +56,19 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *UpdateFailHTLC) Encode(w *bytes.Buffer, pver uint32) error { func (c *UpdateFailHTLC) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, if err := WriteChannelID(w, c.ChanID); err != nil {
c.ChanID, return err
c.ID, }
c.Reason,
c.ExtraData, if err := WriteUint64(w, c.ID); err != nil {
) return err
}
if err := WriteOpaqueReason(w, c.Reason); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

View File

@@ -54,14 +54,26 @@ func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error {
// io.Writer observing the protocol version specified. // io.Writer observing the protocol version specified.
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *UpdateFailMalformedHTLC) Encode(w *bytes.Buffer, pver uint32) error { func (c *UpdateFailMalformedHTLC) Encode(w *bytes.Buffer,
return WriteElements(w, pver uint32) error {
c.ChanID,
c.ID, if err := WriteChannelID(w, c.ChanID); err != nil {
c.ShaOnionBlob[:], return err
c.FailureCode, }
c.ExtraData,
) if err := WriteUint64(w, c.ID); err != nil {
return err
}
if err := WriteBytes(w, c.ShaOnionBlob[:]); err != nil {
return err
}
if err := WriteFailCode(w, c.FailureCode); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

View File

@@ -53,11 +53,15 @@ func (c *UpdateFee) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *UpdateFee) Encode(w *bytes.Buffer, pver uint32) error { func (c *UpdateFee) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, if err := WriteChannelID(w, c.ChanID); err != nil {
c.ChanID, return err
c.FeePerKw, }
c.ExtraData,
) if err := WriteUint32(w, c.FeePerKw); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

View File

@@ -62,12 +62,19 @@ func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *UpdateFulfillHTLC) Encode(w *bytes.Buffer, pver uint32) error { func (c *UpdateFulfillHTLC) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, if err := WriteChannelID(w, c.ChanID); err != nil {
c.ChanID, return err
c.ID, }
c.PaymentPreimage[:],
c.ExtraData, if err := WriteUint64(w, c.ID); err != nil {
) return err
}
if err := WriteBytes(w, c.PaymentPreimage[:]); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the