lnwire: refactor Encode to use specific writers - II

This commit takes another 10 message types and refactors their Encode
method to use specific writers. The following commit will refactor the
rest.
This commit is contained in:
yyforyongyu
2021-06-18 15:11:43 +08:00
parent 563ff7266a
commit c1ad9cc60f
10 changed files with 213 additions and 87 deletions

View File

@ -124,17 +124,37 @@ func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error {
// Encode serializes the target NodeAnnouncement into the passed io.Writer // Encode serializes the target NodeAnnouncement into the passed io.Writer
// observing the protocol version specified. // observing the protocol version specified.
// //
// This is part of the lnwire.Message interface.
func (a *NodeAnnouncement) Encode(w *bytes.Buffer, pver uint32) error { func (a *NodeAnnouncement) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, if err := WriteSig(w, a.Signature); err != nil {
a.Signature, return err
a.Features, }
a.Timestamp,
a.NodeID, if err := WriteRawFeatureVector(w, a.Features); err != nil {
a.RGBColor, return err
a.Alias, }
a.Addresses,
a.ExtraOpaqueData, if err := WriteUint32(w, a.Timestamp); err != nil {
) return err
}
if err := WriteBytes(w, a.NodeID[:]); err != nil {
return err
}
if err := WriteColorRGBA(w, a.RGBColor); err != nil {
return err
}
if err := WriteNodeAlias(w, a.Alias); err != nil {
return err
}
if err := WriteNetAddrs(w, a.Addresses); err != nil {
return err
}
return WriteBytes(w, a.ExtraOpaqueData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the
@ -149,19 +169,36 @@ func (a *NodeAnnouncement) MsgType() MessageType {
func (a *NodeAnnouncement) DataToSign() ([]byte, error) { func (a *NodeAnnouncement) DataToSign() ([]byte, error) {
// We should not include the signatures itself. // We should not include the signatures itself.
var w bytes.Buffer buffer := make([]byte, 0, MaxMsgBody)
err := WriteElements(&w, buf := bytes.NewBuffer(buffer)
a.Features,
a.Timestamp, if err := WriteRawFeatureVector(buf, a.Features); err != nil {
a.NodeID,
a.RGBColor,
a.Alias[:],
a.Addresses,
a.ExtraOpaqueData,
)
if err != nil {
return nil, err return nil, err
} }
return w.Bytes(), nil if err := WriteUint32(buf, a.Timestamp); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.NodeID[:]); err != nil {
return nil, err
}
if err := WriteColorRGBA(buf, a.RGBColor); err != nil {
return nil, err
}
if err := WriteNodeAlias(buf, a.Alias); err != nil {
return nil, err
}
if err := WriteNetAddrs(buf, a.Addresses); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil {
return nil, err
}
return buf.Bytes(), nil
} }

View File

@ -420,7 +420,11 @@ func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error {
// //
// NOTE: Part of the Serializable interface. // NOTE: Part of the Serializable interface.
func (f *FailIncorrectDetails) Encode(w *bytes.Buffer, pver uint32) error { func (f *FailIncorrectDetails) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, f.amount, f.height) if err := WriteMilliSatoshi(w, f.amount); err != nil {
return err
}
return WriteUint32(w, f.height)
} }
// FailFinalExpiryTooSoon is returned if the cltv_expiry is too low, the final // FailFinalExpiryTooSoon is returned if the cltv_expiry is too low, the final
@ -486,7 +490,7 @@ func (f *FailInvalidOnionVersion) Decode(r io.Reader, pver uint32) error {
// //
// NOTE: Part of the Serializable interface. // NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionVersion) Encode(w *bytes.Buffer, pver uint32) error { func (f *FailInvalidOnionVersion) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElement(w, f.OnionSHA256[:]) return WriteBytes(w, f.OnionSHA256[:])
} }
// FailInvalidOnionHmac is return if the onion HMAC is incorrect. // FailInvalidOnionHmac is return if the onion HMAC is incorrect.
@ -520,7 +524,7 @@ func (f *FailInvalidOnionHmac) Decode(r io.Reader, pver uint32) error {
// //
// NOTE: Part of the Serializable interface. // NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionHmac) Encode(w *bytes.Buffer, pver uint32) error { func (f *FailInvalidOnionHmac) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElement(w, f.OnionSHA256[:]) return WriteBytes(w, f.OnionSHA256[:])
} }
// Returns a human readable string describing the target FailureMessage. // Returns a human readable string describing the target FailureMessage.
@ -562,7 +566,7 @@ func (f *FailInvalidOnionKey) Decode(r io.Reader, pver uint32) error {
// //
// NOTE: Part of the Serializable interface. // NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionKey) Encode(w *bytes.Buffer, pver uint32) error { func (f *FailInvalidOnionKey) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElement(w, f.OnionSHA256[:]) return WriteBytes(w, f.OnionSHA256[:])
} }
// Returns a human readable string describing the target FailureMessage. // Returns a human readable string describing the target FailureMessage.
@ -682,7 +686,7 @@ func (f *FailTemporaryChannelFailure) Encode(w *bytes.Buffer,
payload = bw.Bytes() payload = bw.Bytes()
} }
if err := WriteElement(w, uint16(len(payload))); err != nil { if err := WriteUint16(w, uint16(len(payload))); err != nil {
return err return err
} }
@ -752,7 +756,7 @@ func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error {
// //
// NOTE: Part of the Serializable interface. // NOTE: Part of the Serializable interface.
func (f *FailAmountBelowMinimum) Encode(w *bytes.Buffer, pver uint32) error { func (f *FailAmountBelowMinimum) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteElement(w, f.HtlcMsat); err != nil { if err := WriteMilliSatoshi(w, f.HtlcMsat); err != nil {
return err return err
} }
@ -820,7 +824,7 @@ func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error {
// //
// NOTE: Part of the Serializable interface. // NOTE: Part of the Serializable interface.
func (f *FailFeeInsufficient) Encode(w *bytes.Buffer, pver uint32) error { func (f *FailFeeInsufficient) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteElement(w, f.HtlcMsat); err != nil { if err := WriteMilliSatoshi(w, f.HtlcMsat); err != nil {
return err return err
} }
@ -888,7 +892,7 @@ func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error {
// //
// NOTE: Part of the Serializable interface. // NOTE: Part of the Serializable interface.
func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error { func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteElement(w, f.CltvExpiry); err != nil { if err := WriteUint32(w, f.CltvExpiry); err != nil {
return err return err
} }
@ -1009,7 +1013,7 @@ func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error {
// //
// NOTE: Part of the Serializable interface. // NOTE: Part of the Serializable interface.
func (f *FailChannelDisabled) Encode(w *bytes.Buffer, pver uint32) error { func (f *FailChannelDisabled) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteElement(w, f.Flags); err != nil { if err := WriteUint16(w, f.Flags); err != nil {
return err return err
} }
@ -1061,7 +1065,7 @@ func (f *FailFinalIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error {
func (f *FailFinalIncorrectCltvExpiry) Encode(w *bytes.Buffer, func (f *FailFinalIncorrectCltvExpiry) Encode(w *bytes.Buffer,
pver uint32) error { pver uint32) error {
return WriteElement(w, f.CltvExpiry) return WriteUint32(w, f.CltvExpiry)
} }
// FailFinalIncorrectHtlcAmount is returned if the amt_to_forward is higher // FailFinalIncorrectHtlcAmount is returned if the amt_to_forward is higher
@ -1109,7 +1113,7 @@ func (f *FailFinalIncorrectHtlcAmount) Decode(r io.Reader, pver uint32) error {
func (f *FailFinalIncorrectHtlcAmount) Encode(w *bytes.Buffer, func (f *FailFinalIncorrectHtlcAmount) Encode(w *bytes.Buffer,
pver uint32) error { pver uint32) error {
return WriteElement(w, f.IncomingHTLCAmount) return WriteMilliSatoshi(w, f.IncomingHTLCAmount)
} }
// FailExpiryTooFar is returned if the CLTV expiry in the HTLC is too far in the // FailExpiryTooFar is returned if the CLTV expiry in the HTLC is too far in the
@ -1189,7 +1193,7 @@ func (f *InvalidOnionPayload) Encode(w *bytes.Buffer, pver uint32) error {
return err return err
} }
return WriteElements(w, f.Offset) return WriteUint16(w, f.Offset)
} }
// FailMPPTimeout is returned if the complete amount for a multi part payment // FailMPPTimeout is returned if the complete amount for a multi part payment
@ -1289,12 +1293,18 @@ func EncodeFailure(w *bytes.Buffer, failure FailureMessage, pver uint32) error {
// messages are fixed size. // messages are fixed size.
pad := make([]byte, FailureMessageLength-len(failureMessage)) pad := make([]byte, FailureMessageLength-len(failureMessage))
return WriteElements(w, if err := WriteUint16(w, uint16(len(failureMessage))); err != nil {
uint16(len(failureMessage)), return err
failureMessage, }
uint16(len(pad)),
pad, if err := WriteBytes(w, failureMessage); err != nil {
) return err
}
if err := WriteUint16(w, uint16(len(pad))); err != nil {
return err
}
return WriteBytes(w, pad)
} }
// EncodeFailureMessage encodes just the failure message without adding a length // EncodeFailureMessage encodes just the failure message without adding a length
@ -1422,7 +1432,7 @@ func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate *ChannelUpdate,
// Now that we know the size, we can write the length out in the main // Now that we know the size, we can write the length out in the main
// writer. // writer.
updateLen := b.Len() updateLen := b.Len()
if err := WriteElement(w, uint16(updateLen)); err != nil { if err := WriteUint16(w, uint16(updateLen)); err != nil {
return err return err
} }

View File

@ -278,5 +278,5 @@ func (f *mockFailIncorrectDetailsNoHeight) Decode(r io.Reader, pver uint32) erro
func (f *mockFailIncorrectDetailsNoHeight) Encode(w *bytes.Buffer, func (f *mockFailIncorrectDetailsNoHeight) Encode(w *bytes.Buffer,
pver uint32) error { pver uint32) error {
return WriteElement(w, f.amount) return WriteUint64(w, f.amount)
} }

View File

@ -161,27 +161,80 @@ func (o *OpenChannel) Encode(w *bytes.Buffer, pver uint32) error {
return err return err
} }
return WriteElements(w, if err := WriteBytes(w, o.ChainHash[:]); err != nil {
o.ChainHash[:], return err
o.PendingChannelID[:], }
o.FundingAmount,
o.PushAmount, if err := WriteBytes(w, o.PendingChannelID[:]); err != nil {
o.DustLimit, return err
o.MaxValueInFlight, }
o.ChannelReserve,
o.HtlcMinimum, if err := WriteSatoshi(w, o.FundingAmount); err != nil {
o.FeePerKiloWeight, return err
o.CsvDelay, }
o.MaxAcceptedHTLCs,
o.FundingKey, if err := WriteMilliSatoshi(w, o.PushAmount); err != nil {
o.RevocationPoint, return err
o.PaymentPoint, }
o.DelayedPaymentPoint,
o.HtlcPoint, if err := WriteSatoshi(w, o.DustLimit); err != nil {
o.FirstCommitmentPoint, return err
o.ChannelFlags, }
tlvRecords,
) if err := WriteMilliSatoshi(w, o.MaxValueInFlight); err != nil {
return err
}
if err := WriteSatoshi(w, o.ChannelReserve); err != nil {
return err
}
if err := WriteMilliSatoshi(w, o.HtlcMinimum); err != nil {
return err
}
if err := WriteUint32(w, o.FeePerKiloWeight); err != nil {
return err
}
if err := WriteUint16(w, o.CsvDelay); err != nil {
return err
}
if err := WriteUint16(w, o.MaxAcceptedHTLCs); err != nil {
return err
}
if err := WritePublicKey(w, o.FundingKey); err != nil {
return err
}
if err := WritePublicKey(w, o.RevocationPoint); err != nil {
return err
}
if err := WritePublicKey(w, o.PaymentPoint); err != nil {
return err
}
if err := WritePublicKey(w, o.DelayedPaymentPoint); err != nil {
return err
}
if err := WritePublicKey(w, o.HtlcPoint); err != nil {
return err
}
if err := WritePublicKey(w, o.FirstCommitmentPoint); err != nil {
return err
}
if err := WriteFundingFlag(w, o.ChannelFlags); err != nil {
return err
}
return WriteBytes(w, tlvRecords)
} }
// Decode deserializes the serialized OpenChannel stored in the passed // Decode deserializes the serialized OpenChannel stored in the passed

View File

@ -48,9 +48,11 @@ func (p *Ping) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (p *Ping) Encode(w *bytes.Buffer, pver uint32) error { func (p *Ping) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, if err := WriteUint16(w, p.NumPongBytes); err != nil {
p.NumPongBytes, return err
p.PaddingBytes) }
return WritePingPayload(w, p.PaddingBytes)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

View File

@ -44,9 +44,7 @@ func (p *Pong) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (p *Pong) Encode(w *bytes.Buffer, pver uint32) error { func (p *Pong) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, return WritePongPayload(w, p.PongBytes)
p.PongBytes,
)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

View File

@ -60,12 +60,19 @@ func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error { func (q *QueryChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, if err := WriteBytes(w, q.ChainHash[:]); err != nil {
q.ChainHash[:], return err
q.FirstBlockHeight, }
q.NumBlocks,
q.ExtraData, if err := WriteUint32(w, q.FirstBlockHeight); err != nil {
) return err
}
if err := WriteUint32(w, q.NumBlocks); err != nil {
return err
}
return WriteBytes(w, q.ExtraData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

View File

@ -56,11 +56,15 @@ func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *ReplyShortChanIDsEnd) Encode(w *bytes.Buffer, pver uint32) error { func (c *ReplyShortChanIDsEnd) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, if err := WriteBytes(w, c.ChainHash[:]); err != nil {
c.ChainHash[:], return err
c.Complete, }
c.ExtraData,
) if err := WriteUint8(w, c.Complete); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

View File

@ -67,12 +67,19 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *RevokeAndAck) Encode(w *bytes.Buffer, pver uint32) error { func (c *RevokeAndAck) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, if err := WriteChannelID(w, c.ChanID); err != nil {
c.ChanID, return err
c.Revocation[:], }
c.NextRevocationKey,
c.ExtraData, if err := WriteBytes(w, c.Revocation[:]); err != nil {
) return err
}
if err := WritePublicKey(w, c.NextRevocationKey); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

View File

@ -48,7 +48,15 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error {
return WriteElements(w, s.ChannelID, s.Address, s.ExtraData) if err := WriteChannelID(w, s.ChannelID); err != nil {
return err
}
if err := WriteDeliveryAddress(w, s.Address); err != nil {
return err
}
return WriteBytes(w, s.ExtraData)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the