diff --git a/brontide/conn.go b/brontide/conn.go index e2f339eb2..4c62bc9ec 100644 --- a/brontide/conn.go +++ b/brontide/conn.go @@ -117,14 +117,13 @@ func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Write(b []byte) (n int, err error) { // If the message doesn't require any chunking, then we can go ahead // with a single write. - if len(b)+macSize <= math.MaxUint16 { + if len(b) <= math.MaxUint16 { return len(b), c.noise.WriteMessage(c.conn, b) } // If we need to split the message into fragments, then we'll write - // chunks which maximize usage of the available payload. To do so, we - // subtract the added overhead of the MAC at the end of the message. - chunkSize := math.MaxUint16 - macSize + // chunks which maximize usage of the available payload. + chunkSize := math.MaxUint16 bytesToWrite := len(b) bytesWritten := 0 diff --git a/brontide/noise.go b/brontide/noise.go index 9fdc794d9..fc925ba81 100644 --- a/brontide/noise.go +++ b/brontide/noise.go @@ -641,12 +641,13 @@ func (b *BrontideMachine) WriteMessage(w io.Writer, p []byte) error { // The total length of each message payload including the MAC size // payload exceed the largest number encodable within a 16-bit unsigned // integer. - if len(p)+macSize > math.MaxUint16 { + if len(p) > math.MaxUint16 { return ErrMaxMessageLengthExceeded } - // The full length of the packet includes the 16 byte MAC. - fullLength := uint16(len(p) + macSize) + // The full length of the packet is only the packet length, and does + // NOT include the MAC. + fullLength := uint16(len(p)) var pktLen [2]byte binary.BigEndian.PutUint16(pktLen[:], fullLength) @@ -684,11 +685,11 @@ func (b *BrontideMachine) ReadMessage(r io.Reader) ([]byte, error) { // Next, using the length read from the packet header, read the // encrypted packet itself. - pktLen := binary.BigEndian.Uint16(pktLenBytes) - ciperText := make([]byte, pktLen) - if _, err := io.ReadFull(r, ciperText[:]); err != nil { + pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize + cipherText := make([]byte, pktLen) + if _, err := io.ReadFull(r, cipherText[:]); err != nil { return nil, err } - return b.recvCipher.Decrypt(nil, nil, ciperText) + return b.recvCipher.Decrypt(nil, nil, cipherText) } diff --git a/brontide/noise_test.go b/brontide/noise_test.go index 5ea6f14cb..69aa90fb7 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -63,6 +63,7 @@ func establishTestConnection() (net.Conn, net.Conn, error) { return localConn, remoteConn, nil } + func TestConnectionCorrectness(t *testing.T) { // Create a test connection, grabbing either side of the connection // into local variables. If the initial crypto handshake fails, then @@ -130,9 +131,9 @@ func TestMaxPayloadLength(t *testing.T) { "should have been rejected") } - // Generate another payload which with the MAC acounted for, should be - // accepted as a valid payload. - payloadToAccept := make([]byte, math.MaxUint16-macSize) + // Generate another payload which should be accepted as a valid + // payload. + payloadToAccept := make([]byte, math.MaxUint16-1) if err := b.WriteMessage(&buf, payloadToAccept); err != nil { t.Fatalf("write for payload was rejected, should have been " + "accepted") @@ -140,7 +141,7 @@ func TestMaxPayloadLength(t *testing.T) { // Generate a final payload which is juuust over the max payload length // when the MAC is accounted for. - payloadToReject = make([]byte, math.MaxUint16-macSize+1) + payloadToReject = make([]byte, math.MaxUint16+1) // This payload should be rejected. err = b.WriteMessage(&buf, payloadToReject) @@ -171,7 +172,7 @@ func TestWriteMessageChunking(t *testing.T) { go func() { bytesWritten, err := localConn.Write(largeMessage) if err != nil { - t.Fatalf("unable to write message") + t.Fatalf("unable to write message: %v", err) } // The entire message should have been written out to the remote @@ -186,7 +187,7 @@ func TestWriteMessageChunking(t *testing.T) { // Attempt to read the entirety of the message generated above. buf := make([]byte, len(largeMessage)) if _, err := io.ReadFull(remoteConn, buf); err != nil { - t.Fatalf("unable to read message") + t.Fatalf("unable to read message: %v", err) } wg.Wait()