mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-07 03:06:01 +02:00
lnwire: refactor WriteMessage to use bytes.Buffer
This commit changes the method WriteMessage to use bytes.Buffer to save heap allocations. A unit test is added to check the method is implemented as expected.
This commit is contained in:
@@ -52,6 +52,27 @@ const (
|
|||||||
MsgGossipTimestampRange = 265
|
MsgGossipTimestampRange = 265
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrorEncodeMessage is used when failed to encode the message payload.
|
||||||
|
func ErrorEncodeMessage(err error) error {
|
||||||
|
return fmt.Errorf("failed to encode message to buffer, got %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorWriteMessageType is used when failed to write the message type.
|
||||||
|
func ErrorWriteMessageType(err error) error {
|
||||||
|
return fmt.Errorf("failed to write message type, got %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorPayloadTooLarge is used when the payload size exceeds the
|
||||||
|
// MaxMsgBody.
|
||||||
|
func ErrorPayloadTooLarge(size int) error {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"message payload is too large - encoded %d bytes, "+
|
||||||
|
"but maximum message payload is %d bytes",
|
||||||
|
size, MaxMsgBody,
|
||||||
|
)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// String return the string representation of message type.
|
// String return the string representation of message type.
|
||||||
func (t MessageType) String() string {
|
func (t MessageType) String() string {
|
||||||
switch t {
|
switch t {
|
||||||
@@ -218,44 +239,49 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteMessage writes a lightning Message to w including the necessary header
|
// WriteMessage writes a lightning Message to a buffer including the necessary
|
||||||
// information and returns the number of bytes written.
|
// header information and returns the number of bytes written. If any error is
|
||||||
func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) {
|
// encountered, the buffer passed will be reset to its original state since we
|
||||||
totalBytes := 0
|
// don't want any broken bytes left. In other words, no bytes will be written
|
||||||
|
// if there's an error. Either all or none of the message bytes will be written
|
||||||
|
// to the buffer.
|
||||||
|
//
|
||||||
|
// NOTE: this method is not concurrent safe.
|
||||||
|
func WriteMessage(buf *bytes.Buffer, msg Message, pver uint32) (int, error) {
|
||||||
|
// Record the size of the bytes already written in buffer.
|
||||||
|
oldByteSize := buf.Len()
|
||||||
|
|
||||||
// Encode the message payload itself into a temporary buffer.
|
// cleanBrokenBytes is a helper closure that helps reset the buffer to
|
||||||
// TODO(roasbeef): create buffer pool
|
// its original state. It truncates all the bytes written in current
|
||||||
var bw bytes.Buffer
|
// scope.
|
||||||
if err := msg.Encode(&bw, pver); err != nil {
|
var cleanBrokenBytes = func(b *bytes.Buffer) int {
|
||||||
return totalBytes, err
|
b.Truncate(oldByteSize)
|
||||||
}
|
return 0
|
||||||
payload := bw.Bytes()
|
|
||||||
lenp := len(payload)
|
|
||||||
|
|
||||||
// Enforce maximum message payload, which means the body cannot be
|
|
||||||
// greater than MaxMsgBody.
|
|
||||||
if lenp > MaxMsgBody {
|
|
||||||
return totalBytes, fmt.Errorf("message payload is too large - "+
|
|
||||||
"encoded %d bytes, but maximum message body is %d bytes",
|
|
||||||
lenp, MaxMsgBody)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// With the initial sanity checks complete, we'll now write out the
|
// Write the message type.
|
||||||
// message type itself.
|
|
||||||
var mType [2]byte
|
var mType [2]byte
|
||||||
binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType()))
|
binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType()))
|
||||||
n, err := w.Write(mType[:])
|
msgTypeBytes, err := buf.Write(mType[:])
|
||||||
totalBytes += n
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return totalBytes, err
|
return cleanBrokenBytes(buf), ErrorWriteMessageType(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// With the message type written, we'll now write out the raw payload
|
// Use the write buffer to encode our message.
|
||||||
// itself.
|
if err := msg.Encode(buf, pver); err != nil {
|
||||||
n, err = w.Write(payload)
|
return cleanBrokenBytes(buf), ErrorEncodeMessage(err)
|
||||||
totalBytes += n
|
}
|
||||||
|
|
||||||
return totalBytes, err
|
// Enforce maximum overall message payload. The write buffer now has
|
||||||
|
// the size of len(originalBytes) + len(payload) + len(type). We want
|
||||||
|
// to enforce the payload here, so we subtract it by the length of the
|
||||||
|
// type and old bytes.
|
||||||
|
lenp := buf.Len() - oldByteSize - msgTypeBytes
|
||||||
|
if lenp > MaxMsgBody {
|
||||||
|
return cleanBrokenBytes(buf), ErrorPayloadTooLarge(lenp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.Len() - oldByteSize, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadMessage reads, validates, and parses the next Lightning message from r
|
// ReadMessage reads, validates, and parses the next Lightning message from r
|
||||||
|
@@ -3,6 +3,7 @@ package lnwire_test
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"image/color"
|
"image/color"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
@@ -16,6 +17,7 @@ import (
|
|||||||
"github.com/btcsuite/btcutil"
|
"github.com/btcsuite/btcutil"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/tor"
|
"github.com/lightningnetwork/lnd/tor"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,6 +43,148 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mockMsg struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMsg) Decode(r io.Reader, pver uint32) error {
|
||||||
|
args := m.Called(r, pver)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMsg) Encode(w io.Writer, pver uint32) error {
|
||||||
|
args := m.Called(w, pver)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMsg) MsgType() lnwire.MessageType {
|
||||||
|
args := m.Called()
|
||||||
|
return lnwire.MessageType(args.Int(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
// A compile time check to ensure mockMsg implements the lnwire.Message
|
||||||
|
// interface.
|
||||||
|
var _ lnwire.Message = (*mockMsg)(nil)
|
||||||
|
|
||||||
|
// TestWriteMessage tests the function lnwire.WriteMessage.
|
||||||
|
func TestWriteMessage(t *testing.T) {
|
||||||
|
var (
|
||||||
|
buf = new(bytes.Buffer)
|
||||||
|
|
||||||
|
// encodeNormalSize specifies a message size that is normal.
|
||||||
|
encodeNormalSize = 1000
|
||||||
|
|
||||||
|
// encodeOversize specifies a message size that's too big.
|
||||||
|
encodeOversize = lnwire.MaxMsgBody + 1
|
||||||
|
|
||||||
|
// errDummy is returned by the msg.Encode when specified.
|
||||||
|
errDummy = errors.New("test error")
|
||||||
|
|
||||||
|
// oneByte is a dummy byte used to fill up the buffer.
|
||||||
|
oneByte = [1]byte{}
|
||||||
|
)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
|
||||||
|
// encodeSize controls how many bytes are written to the buffer
|
||||||
|
// by the method msg.Encode(buf, pver).
|
||||||
|
encodeSize int
|
||||||
|
|
||||||
|
// encodeErr determines the return value of the method
|
||||||
|
// msg.Encode(buf, pver).
|
||||||
|
encodeErr error
|
||||||
|
|
||||||
|
errorExpected error
|
||||||
|
}{
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "successful write",
|
||||||
|
encodeSize: encodeNormalSize,
|
||||||
|
encodeErr: nil,
|
||||||
|
errorExpected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failed to encode payload",
|
||||||
|
encodeSize: encodeNormalSize,
|
||||||
|
encodeErr: errDummy,
|
||||||
|
errorExpected: lnwire.ErrorEncodeMessage(errDummy),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exceeds MaxMsgBody",
|
||||||
|
encodeSize: encodeOversize,
|
||||||
|
encodeErr: nil,
|
||||||
|
errorExpected: lnwire.ErrorPayloadTooLarge(
|
||||||
|
encodeOversize,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
tc := test
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Start the test by creating a mock message and patch
|
||||||
|
// the relevant methods.
|
||||||
|
msg := &mockMsg{}
|
||||||
|
|
||||||
|
// Use message type Ping here since all types are
|
||||||
|
// encoded using 2 bytes, it won't affect anything
|
||||||
|
// here.
|
||||||
|
msg.On("MsgType").Return(lnwire.MsgPing)
|
||||||
|
|
||||||
|
// Encode will return the specified error (could be
|
||||||
|
// nil) and has the side effect of filling up the
|
||||||
|
// buffer by repeating the oneByte encodeSize times.
|
||||||
|
msg.On("Encode", mock.Anything, mock.Anything).Return(
|
||||||
|
tc.encodeErr,
|
||||||
|
).Run(func(_ mock.Arguments) {
|
||||||
|
for i := 0; i < tc.encodeSize; i++ {
|
||||||
|
_, err := buf.Write(oneByte[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Record the initial state of the buffer and write the
|
||||||
|
// message.
|
||||||
|
oldBytesSize := buf.Len()
|
||||||
|
bytesWritten, err := lnwire.WriteMessage(
|
||||||
|
buf, msg, 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check that the returned error is expected.
|
||||||
|
require.Equal(
|
||||||
|
t, tc.errorExpected, err, "unexpected err",
|
||||||
|
)
|
||||||
|
|
||||||
|
// If there's an error, no bytes should be written to
|
||||||
|
// the buf.
|
||||||
|
if tc.errorExpected != nil {
|
||||||
|
require.Equal(
|
||||||
|
t, 0, bytesWritten,
|
||||||
|
"bytes written should be 0",
|
||||||
|
)
|
||||||
|
|
||||||
|
// We also check that the old buf was not
|
||||||
|
// affected.
|
||||||
|
require.Equal(
|
||||||
|
t, oldBytesSize, buf.Len(),
|
||||||
|
"original buffer should not change",
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
expected := buf.Len() - oldBytesSize
|
||||||
|
require.Equal(
|
||||||
|
t, expected, bytesWritten,
|
||||||
|
"bytes written not matched",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, check the mocked methods are called as
|
||||||
|
// expected.
|
||||||
|
msg.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// BenchmarkWriteMessage benchmarks the performance of lnwire.WriteMessage. It
|
// BenchmarkWriteMessage benchmarks the performance of lnwire.WriteMessage. It
|
||||||
// generates a test message for each of the lnwire.Message, calls the
|
// generates a test message for each of the lnwire.Message, calls the
|
||||||
// WriteMessage method and benchmark it.
|
// WriteMessage method and benchmark it.
|
||||||
|
Reference in New Issue
Block a user