mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-12-08 11:53:16 +01:00
lnwire: modify ReadMessage to no longer return the total bytes read
This commit modifies ReadMessage to no longer return the total bytes read as this value will now be calculated at a higher level. The io.Reader that’s passed to ReadMessage is expected to contain the _entire_ message rather than be a pointer into a stream that contains the message itself.
This commit is contained in:
@@ -111,7 +111,7 @@ func TestLightningWireProtocol(t *testing.T) {
|
|||||||
|
|
||||||
// Finally, we'll deserialize the message from the written
|
// Finally, we'll deserialize the message from the written
|
||||||
// buffer, and finally assert that the messages are equal.
|
// buffer, and finally assert that the messages are equal.
|
||||||
_, newMsg, err := ReadMessage(&b, 0)
|
newMsg, err := ReadMessage(&b, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to read msg: %v", err)
|
t.Fatalf("unable to read msg: %v", err)
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -167,21 +167,14 @@ func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) {
|
|||||||
return totalBytes, err
|
return totalBytes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadMessage reads, validates, and parses the next bitcoin Message from r for
|
// ReadMessage reads, validates, and parses the next Lightning message from r
|
||||||
// the provided protocol version. It returns the number of bytes read in
|
// for the provided protocol version.
|
||||||
// addition to the parsed Message and raw bytes which comprise the message.
|
func ReadMessage(r io.Reader, pver uint32) (Message, error) {
|
||||||
func ReadMessage(r io.Reader, pver uint32) (int, Message, error) {
|
|
||||||
// TODO(roasbeef): need to explicitly enforce max message payload, or
|
|
||||||
// just allow it to be done by the MaxPayloadLength?
|
|
||||||
totalBytes := 0
|
|
||||||
|
|
||||||
// First, we'll read out the first two bytes of the message so we can
|
// First, we'll read out the first two bytes of the message so we can
|
||||||
// create the proper empty message.
|
// create the proper empty message.
|
||||||
var mType [2]byte
|
var mType [2]byte
|
||||||
n, err := io.ReadFull(r, mType[:])
|
if _, err := io.ReadFull(r, mType[:]); err != nil {
|
||||||
totalBytes += n
|
return nil, err
|
||||||
if err != nil {
|
|
||||||
return totalBytes, nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType := MessageType(binary.BigEndian.Uint16(mType[:]))
|
msgType := MessageType(binary.BigEndian.Uint16(mType[:]))
|
||||||
@@ -190,12 +183,11 @@ func ReadMessage(r io.Reader, pver uint32) (int, Message, error) {
|
|||||||
// empty message type and decode the message into it.
|
// empty message type and decode the message into it.
|
||||||
msg, err := makeEmptyMessage(msgType)
|
msg, err := makeEmptyMessage(msgType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return totalBytes, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := msg.Decode(r, pver); err != nil {
|
if err := msg.Decode(r, pver); err != nil {
|
||||||
return totalBytes, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
totalBytes += int(msg.MaxPayloadLength(pver))
|
|
||||||
|
|
||||||
return totalBytes, msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user