mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-10-03 20:14:49 +02:00
Message interface and stuff.
* Added Message interface (similar to btcd's) * Moved Funding Request to its own file * Refacored Funding Request Code (*MUCH* better) * Various fixes
This commit is contained in:
209
lnwire/message.go
Normal file
209
lnwire/message.go
Normal file
@@ -0,0 +1,209 @@
|
||||
//Code derived from https://github.com/btcsuite/btcd/blob/master/wire/message.go
|
||||
package lnwire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"io"
|
||||
)
|
||||
|
||||
//4-byte network + 4-byte message id + payload-length 4-byte
|
||||
//Maybe add checksum or something?
|
||||
const MessageHeaderSize = 12
|
||||
|
||||
const MaxMessagePayload = 1024 * 1024 * 32 // 32MB
|
||||
|
||||
const (
|
||||
CmdFundingRequest = uint32(20)
|
||||
)
|
||||
|
||||
//Every message has these functions:
|
||||
type Message interface {
|
||||
Decode(io.Reader, uint32) error //(io, protocol version)
|
||||
Encode(io.Writer, uint32) error //(io, protocol version)
|
||||
Command() uint32 //returns ID of the message
|
||||
MaxPayloadLength(uint32) uint32 //(version) maxpayloadsize
|
||||
Validate() error //Validates the data struct
|
||||
String() string
|
||||
}
|
||||
|
||||
func makeEmptyMessage(command uint32) (Message, error) {
|
||||
var msg Message
|
||||
|
||||
switch command {
|
||||
case CmdFundingRequest:
|
||||
msg = &FundingRequest{}
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled command [%x]", command)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
type messageHeader struct {
|
||||
//NOTE(j): We don't need to worry about the magic overlapping with
|
||||
//bitcoin since this is inside encrypted comms anyway, but maybe we
|
||||
//should use the XOR (^wire.TestNet3) just in case???
|
||||
magic wire.BitcoinNet
|
||||
command uint32
|
||||
length uint32
|
||||
//Do we need a checksum here?
|
||||
}
|
||||
|
||||
func readMessageHeader(r io.Reader) (int, *messageHeader, error) {
|
||||
var headerBytes [MessageHeaderSize]byte
|
||||
n, err := io.ReadFull(r, headerBytes[:])
|
||||
if err != nil {
|
||||
return n, nil, err
|
||||
}
|
||||
hr := bytes.NewReader(headerBytes[:])
|
||||
|
||||
hdr := messageHeader{}
|
||||
|
||||
err = readElements(hr, false,
|
||||
&hdr.magic,
|
||||
&hdr.command,
|
||||
&hdr.length)
|
||||
if err != nil {
|
||||
return n, nil, err
|
||||
}
|
||||
|
||||
return n, &hdr, nil
|
||||
}
|
||||
|
||||
// discardInput reads n bytes from reader r in chunks and discards the read
|
||||
// bytes. This is used to skip payloads when various errors occur and helps
|
||||
// prevent rogue nodes from causing massive memory allocation through forging
|
||||
// header length.
|
||||
func discardInput(r io.Reader, n uint32) {
|
||||
maxSize := uint32(10 * 1024) //10k at a time
|
||||
numReads := n / maxSize
|
||||
bytesRemaining := n % maxSize
|
||||
if n > 0 {
|
||||
buf := make([]byte, maxSize)
|
||||
for i := uint32(0); i < numReads; i++ {
|
||||
io.ReadFull(r, buf)
|
||||
}
|
||||
}
|
||||
if bytesRemaining > 0 {
|
||||
buf := make([]byte, bytesRemaining)
|
||||
io.ReadFull(r, buf)
|
||||
}
|
||||
}
|
||||
|
||||
func WriteMessage(w io.Writer, msg Message, pver uint32, btcnet wire.BitcoinNet) (int, error) {
|
||||
totalBytes := 0
|
||||
|
||||
cmd := msg.Command()
|
||||
|
||||
//Encode the message payload
|
||||
var bw bytes.Buffer
|
||||
err := msg.Encode(&bw, pver)
|
||||
if err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
payload := bw.Bytes()
|
||||
lenp := len(payload)
|
||||
|
||||
//Enforce maximum overall message payload
|
||||
if lenp > MaxMessagePayload {
|
||||
return totalBytes, fmt.Errorf("message payload is too large - encoded %d bytes, but maximum message payload is %d bytes", lenp, MaxMessagePayload)
|
||||
}
|
||||
|
||||
//Enforce maximum message payload on the message type
|
||||
mpl := msg.MaxPayloadLength(pver)
|
||||
if uint32(lenp) > mpl {
|
||||
return totalBytes, fmt.Errorf("message payload is too large - encoded %d bytes, but maximum message payload of type %x is %d bytes", lenp, cmd, mpl)
|
||||
}
|
||||
|
||||
//Create header for the message
|
||||
hdr := messageHeader{}
|
||||
hdr.magic = btcnet
|
||||
hdr.command = cmd
|
||||
hdr.length = uint32(lenp)
|
||||
//Checksum goes here if needed... and also add to writeElements
|
||||
|
||||
// Encode the header for the message. This is done to a buffer
|
||||
// rather than directly to the writer since writeElements doesn't
|
||||
// return the number of bytes written.
|
||||
hw := bytes.NewBuffer(make([]byte, 0, MessageHeaderSize))
|
||||
writeElements(hw, false, hdr.magic, hdr.command, hdr.length)
|
||||
|
||||
//Write header
|
||||
n, err := w.Write(hw.Bytes())
|
||||
totalBytes += n
|
||||
if err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
|
||||
//Write payload
|
||||
n, err = w.Write(payload)
|
||||
totalBytes += n
|
||||
if err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
|
||||
return totalBytes, nil
|
||||
}
|
||||
|
||||
func ReadMessage(r io.Reader, pver uint32, btcnet wire.BitcoinNet) (int, Message, []byte, error) {
|
||||
totalBytes := 0
|
||||
n, hdr, err := readMessageHeader(r)
|
||||
totalBytes += n
|
||||
if err != nil {
|
||||
return totalBytes, nil, nil, err
|
||||
}
|
||||
|
||||
//Enforce maximum message payload
|
||||
if hdr.length > MaxMessagePayload {
|
||||
return totalBytes, nil, nil, fmt.Errorf("message payload is too large - header indicates %d bytes, but max message payload is %d bytes.", hdr.length, MaxMessagePayload)
|
||||
}
|
||||
|
||||
//Check for messages in the wrong bitcoin network
|
||||
if hdr.magic != btcnet {
|
||||
discardInput(r, hdr.length)
|
||||
return totalBytes, nil, nil, fmt.Errorf("message from other network [%v]", hdr.magic)
|
||||
}
|
||||
|
||||
//Create struct of appropriate message type based on the command
|
||||
command := hdr.command
|
||||
msg, err := makeEmptyMessage(command)
|
||||
if err != nil {
|
||||
discardInput(r, hdr.length)
|
||||
return totalBytes, nil, nil, fmt.Errorf("ReadMessage %s", err.Error())
|
||||
}
|
||||
|
||||
//Check for maximum length based on the message type
|
||||
mpl := msg.MaxPayloadLength(pver)
|
||||
if hdr.length > mpl {
|
||||
discardInput(r, hdr.length)
|
||||
return totalBytes, nil, nil, fmt.Errorf("payload exceeds max length. indicates %v bytes, but max of message type %v is %v.", hdr.length, command, mpl)
|
||||
}
|
||||
|
||||
//Read payload
|
||||
payload := make([]byte, hdr.length)
|
||||
n, err = io.ReadFull(r, payload)
|
||||
totalBytes += n
|
||||
if err != nil {
|
||||
return totalBytes, nil, nil, err
|
||||
}
|
||||
|
||||
//If we want to use checksums, test it here
|
||||
|
||||
//Unmarshal message
|
||||
pr := bytes.NewBuffer(payload)
|
||||
err = msg.Decode(pr, pver)
|
||||
if err != nil {
|
||||
return totalBytes, nil, nil, err
|
||||
}
|
||||
|
||||
//Validate the data
|
||||
msg.Validate()
|
||||
if err != nil {
|
||||
return totalBytes, nil, nil, err
|
||||
}
|
||||
|
||||
//We're good!
|
||||
return totalBytes, msg, payload, nil
|
||||
}
|
Reference in New Issue
Block a user