separate msgState for reader/writer

This commit is contained in:
Yasuhiro Matsumoto 2023-11-06 23:07:07 +09:00 committed by fiatjaf_
parent 4fccda5549
commit d6baa2f74c

View File

@ -24,7 +24,8 @@ type Connection struct {
reader *wsutil.Reader reader *wsutil.Reader
flateWriter *wsflate.Writer flateWriter *wsflate.Writer
writer *wsutil.Writer writer *wsutil.Writer
msgState *wsflate.MessageState msgStateR *wsflate.MessageState
msgStateW *wsflate.MessageState
} }
func NewConnection(ctx context.Context, url string, requestHeader http.Header) (*Connection, error) { func NewConnection(ctx context.Context, url string, requestHeader http.Header) (*Connection, error) {
@ -51,9 +52,9 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
// reader // reader
var flateReader *wsflate.Reader var flateReader *wsflate.Reader
var msgState wsflate.MessageState var msgStateR wsflate.MessageState
if enableCompression { if enableCompression {
msgState.SetCompressed(true) msgStateR.SetCompressed(true)
flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor { flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
return flate.NewReader(r) return flate.NewReader(r)
@ -67,13 +68,16 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
OnIntermediate: controlHandler, OnIntermediate: controlHandler,
CheckUTF8: false, CheckUTF8: false,
Extensions: []wsutil.RecvExtension{ Extensions: []wsutil.RecvExtension{
&msgState, &msgStateR,
}, },
} }
// writer // writer
var flateWriter *wsflate.Writer var flateWriter *wsflate.Writer
var msgStateW wsflate.MessageState
if enableCompression { if enableCompression {
msgStateW.SetCompressed(true)
flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor { flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor {
fw, err := flate.NewWriter(w, 4) fw, err := flate.NewWriter(w, 4)
if err != nil { if err != nil {
@ -84,7 +88,7 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
} }
writer := wsutil.NewWriter(conn, state, ws.OpText) writer := wsutil.NewWriter(conn, state, ws.OpText)
writer.SetExtensions(&msgState) writer.SetExtensions(&msgStateW)
return &Connection{ return &Connection{
conn: conn, conn: conn,
@ -92,14 +96,15 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
controlHandler: controlHandler, controlHandler: controlHandler,
flateReader: flateReader, flateReader: flateReader,
reader: reader, reader: reader,
msgStateR: &msgStateR,
flateWriter: flateWriter, flateWriter: flateWriter,
msgState: &msgState,
writer: writer, writer: writer,
msgStateW: &msgStateW,
}, nil }, nil
} }
func (c *Connection) WriteMessage(data []byte) error { func (c *Connection) WriteMessage(data []byte) error {
if c.msgState.IsCompressed() && c.enableCompression { if c.msgStateW.IsCompressed() && c.enableCompression {
c.flateWriter.Reset(c.writer) c.flateWriter.Reset(c.writer)
if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil { if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err) return fmt.Errorf("failed to write message: %w", err)
@ -149,7 +154,7 @@ func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
} }
} }
if c.msgState.IsCompressed() && c.enableCompression { if c.msgStateR.IsCompressed() && c.enableCompression {
c.flateReader.Reset(c.reader) c.flateReader.Reset(c.reader)
if _, err := io.Copy(buf, c.flateReader); err != nil { if _, err := io.Copy(buf, c.flateReader); err != nil {
return fmt.Errorf("failed to read message: %w", err) return fmt.Errorf("failed to read message: %w", err)