diff --git a/connection.go b/connection.go index e858d3b..a54fdd7 100644 --- a/connection.go +++ b/connection.go @@ -26,6 +26,7 @@ type Connection struct { reader *wsutil.Reader flateWriter *wsflate.Writer writer *wsutil.Writer + msgState *wsflate.MessageState mutex sync.Mutex } @@ -67,6 +68,7 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) ( Source: conn, State: state, OnIntermediate: controlHandler, + CheckUTF8: false, Extensions: []wsutil.RecvExtension{ &msgState, }, @@ -98,6 +100,7 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) ( flateReader: flateReader, reader: reader, flateWriter: flateWriter, + msgState: &msgState, writer: writer, }, nil } @@ -106,7 +109,7 @@ func (c *Connection) WriteJSON(v any) error { c.mutex.Lock() defer c.mutex.Unlock() - if c.enableCompression { + if c.enableCompression && c.msgState.IsCompressed() { c.flateWriter.Reset(c.writer) if err := json.NewEncoder(c.flateWriter).Encode(v); err != nil { return fmt.Errorf("failed to encode json: %w", err) @@ -141,7 +144,7 @@ func (c *Connection) WriteMessage(data []byte) error { c.mutex.Lock() defer c.mutex.Unlock() - if c.enableCompression { + if c.msgState.IsCompressed() && c.enableCompression { c.flateWriter.Reset(c.writer) if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil { return fmt.Errorf("failed to write message: %w", err) @@ -194,7 +197,7 @@ func (c *Connection) ReadMessage(ctx context.Context) ([]byte, error) { } buf := new(bytes.Buffer) - if c.enableCompression { + if c.msgState.IsCompressed() && c.enableCompression { c.flateReader.Reset(c.reader) if _, err := io.Copy(buf, c.flateReader); err != nil { return nil, fmt.Errorf("failed to read message: %w", err)