diff --git a/watchtower/wtwire/message.go b/watchtower/wtwire/message.go new file mode 100644 index 000000000..5daaa6de0 --- /dev/null +++ b/watchtower/wtwire/message.go @@ -0,0 +1,179 @@ +package wtwire + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/lnwire" +) + +// MaxMessagePayload is the maximum bytes a message can be regardless of other +// individual limits imposed by messages themselves. +const MaxMessagePayload = 65535 // 65KB + +// MessageType is the unique 2 byte big-endian integer that indicates the type +// of message on the wire. All messages have a very simple header which +// consists simply of 2-byte message type. We omit a length field, and checksum +// as the Watchtower Protocol is intended to be encapsulated within a +// confidential+authenticated cryptographic messaging protocol. +type MessageType uint16 + +// The currently defined message types within this current version of the +// Watchtower protocol. +const ( + // MsgInit identifies an encoded Init message. + MsgInit MessageType = 300 + + // MsgError identifies an encoded Error message. + MsgError = 301 + + // MsgCreateSession identifies an encoded CreateSession message. + MsgCreateSession MessageType = 302 + + // MsgCreateSessionReply identifies an encoded CreateSessionReply message. + MsgCreateSessionReply MessageType = 303 + + // MsgStateUpdate identifies an encoded StateUpdate message. + MsgStateUpdate MessageType = 304 + + // MsgStateUpdateReply identifies an encoded StateUpdateReply message. + MsgStateUpdateReply MessageType = 305 +) + +// String returns a human readable description of the message type. +func (m MessageType) String() string { + switch m { + case MsgInit: + return "Init" + case MsgCreateSession: + return "MsgCreateSession" + case MsgCreateSessionReply: + return "MsgCreateSessionReply" + case MsgStateUpdate: + return "MsgStateUpdate" + case MsgStateUpdateReply: + return "MsgStateUpdateReply" + case MsgError: + return "Error" + default: + return "" + } +} + +// Serializable is an interface which defines a lightning wire serializable +// object. +type Serializable = lnwire.Serializable + +// Message is an interface that defines a lightning wire protocol message. The +// interface is general in order to allow implementing types full control over +// the representation of its data. +type Message interface { + Serializable + + // MsgType returns a MessageType that uniquely identifies the message to + // be encoded. + MsgType() MessageType + + // MaxMessagePayload is the maximum serialized length that a particular + // message type can take. + MaxPayloadLength(uint32) uint32 +} + +// makeEmptyMessage creates a new empty message of the proper concrete type +// based on the passed message type. +func makeEmptyMessage(msgType MessageType) (Message, error) { + var msg Message + + switch msgType { + case MsgInit: + msg = &Init{&lnwire.Init{}} + case MsgCreateSession: + msg = &CreateSession{} + case MsgCreateSessionReply: + msg = &CreateSessionReply{} + case MsgStateUpdate: + msg = &StateUpdate{} + case MsgStateUpdateReply: + msg = &StateUpdateReply{} + case MsgError: + msg = &Error{} + default: + return nil, fmt.Errorf("unknown message type [%d]", msgType) + } + + return msg, nil +} + +// WriteMessage writes a lightning Message to w including the necessary header +// information and returns the number of bytes written. +func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) { + totalBytes := 0 + + // Encode the message payload itself into a temporary buffer. + // TODO(roasbeef): create buffer pool + var bw bytes.Buffer + if err := msg.Encode(&bw, pver); err != nil { + return totalBytes, err + } + payload := bw.Bytes() + lenp := len(payload) + + // Enforce maximum overall message payload. + if lenp > MaxMessagePayload { + return totalBytes, fmt.Errorf("message payload is too large - "+ + "encoded %d bytes, but maximum message payload is %d bytes", + lenp, MaxMessagePayload) + } + + // Enforce maximum message payload on the message type. + mpl := msg.MaxPayloadLength(pver) + if uint32(lenp) > mpl { + return totalBytes, fmt.Errorf("message payload is too large - "+ + "encoded %d bytes, but maximum message payload of "+ + "type %v is %d bytes", lenp, msg.MsgType(), mpl) + } + + // With the initial sanity checks complete, we'll now write out the + // message type itself. + var mType [2]byte + binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType())) + n, err := w.Write(mType[:]) + totalBytes += n + if err != nil { + return totalBytes, err + } + + // With the message type written, we'll now write out the raw payload + // itself. + n, err = w.Write(payload) + totalBytes += n + + return totalBytes, err +} + +// ReadMessage reads, validates, and parses the next Watchtower message from r +// for the provided protocol version. +func ReadMessage(r io.Reader, pver uint32) (Message, error) { + // First, we'll read out the first two bytes of the message so we can + // create the proper empty message. + var mType [2]byte + if _, err := io.ReadFull(r, mType[:]); err != nil { + return nil, err + } + + msgType := MessageType(binary.BigEndian.Uint16(mType[:])) + + // Now that we know the target message type, we can create the proper + // empty message type and decode the message into it. + msg, err := makeEmptyMessage(msgType) + if err != nil { + return nil, err + } + if err := msg.Decode(r, pver); err != nil { + return nil, err + } + + return msg, nil +}