separate msgState for reader/writer

This commit is contained in:
Yasuhiro Matsumoto 2023-11-06 23:07:07 +09:00 committed by fiatjaf_
parent 4fccda5549
commit d6baa2f74c

View File

@ -24,7 +24,8 @@ type Connection struct {
reader *wsutil.Reader
flateWriter *wsflate.Writer
writer *wsutil.Writer
msgState *wsflate.MessageState
msgStateR *wsflate.MessageState
msgStateW *wsflate.MessageState
}
func NewConnection(ctx context.Context, url string, requestHeader http.Header) (*Connection, error) {
@ -51,9 +52,9 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
// reader
var flateReader *wsflate.Reader
var msgState wsflate.MessageState
var msgStateR wsflate.MessageState
if enableCompression {
msgState.SetCompressed(true)
msgStateR.SetCompressed(true)
flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
return flate.NewReader(r)
@ -67,13 +68,16 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
OnIntermediate: controlHandler,
CheckUTF8: false,
Extensions: []wsutil.RecvExtension{
&msgState,
&msgStateR,
},
}
// writer
var flateWriter *wsflate.Writer
var msgStateW wsflate.MessageState
if enableCompression {
msgStateW.SetCompressed(true)
flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor {
fw, err := flate.NewWriter(w, 4)
if err != nil {
@ -84,7 +88,7 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
}
writer := wsutil.NewWriter(conn, state, ws.OpText)
writer.SetExtensions(&msgState)
writer.SetExtensions(&msgStateW)
return &Connection{
conn: conn,
@ -92,14 +96,15 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
controlHandler: controlHandler,
flateReader: flateReader,
reader: reader,
msgStateR: &msgStateR,
flateWriter: flateWriter,
msgState: &msgState,
writer: writer,
msgStateW: &msgStateW,
}, nil
}
func (c *Connection) WriteMessage(data []byte) error {
if c.msgState.IsCompressed() && c.enableCompression {
if c.msgStateW.IsCompressed() && c.enableCompression {
c.flateWriter.Reset(c.writer)
if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
@ -149,7 +154,7 @@ func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
}
}
if c.msgState.IsCompressed() && c.enableCompression {
if c.msgStateR.IsCompressed() && c.enableCompression {
c.flateReader.Reset(c.reader)
if _, err := io.Copy(buf, c.flateReader); err != nil {
return fmt.Errorf("failed to read message: %w", err)