check if messages are compressed on receive

This commit is contained in:
Marc Tarnutzer 2023-05-06 01:48:01 +02:00
parent c86e907142
commit 69b9d82bb1

View File

@ -26,6 +26,7 @@ type Connection struct {
reader *wsutil.Reader
flateWriter *wsflate.Writer
writer *wsutil.Writer
msgState *wsflate.MessageState
mutex sync.Mutex
}
@ -67,6 +68,7 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
Source: conn,
State: state,
OnIntermediate: controlHandler,
CheckUTF8: false,
Extensions: []wsutil.RecvExtension{
&msgState,
},
@ -98,6 +100,7 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
flateReader: flateReader,
reader: reader,
flateWriter: flateWriter,
msgState: &msgState,
writer: writer,
}, nil
}
@ -106,7 +109,7 @@ func (c *Connection) WriteJSON(v any) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.enableCompression {
if c.enableCompression && c.msgState.IsCompressed() {
c.flateWriter.Reset(c.writer)
if err := json.NewEncoder(c.flateWriter).Encode(v); err != nil {
return fmt.Errorf("failed to encode json: %w", err)
@ -141,7 +144,7 @@ func (c *Connection) WriteMessage(data []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.enableCompression {
if c.msgState.IsCompressed() && c.enableCompression {
c.flateWriter.Reset(c.writer)
if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
@ -194,7 +197,7 @@ func (c *Connection) ReadMessage(ctx context.Context) ([]byte, error) {
}
buf := new(bytes.Buffer)
if c.enableCompression {
if c.msgState.IsCompressed() && c.enableCompression {
c.flateReader.Reset(c.reader)
if _, err := io.Copy(buf, c.flateReader); err != nil {
return nil, fmt.Errorf("failed to read message: %w", err)