use coder/websocket for everything, get rid of gobwas.

supposedly it is faster, and anyway it's better to use it since we're already using it for wasm/js.

(previously named nhooyr/websocket).
This commit is contained in:
fiatjaf 2025-01-03 01:15:12 -03:00
parent b33cfb19fa
commit defc349e57
9 changed files with 101 additions and 243 deletions

View File

@ -150,7 +150,7 @@ To use it, use `-tags=libsecp256k1` whenever you're compiling your program that
Install [wasmbrowsertest](https://github.com/agnivade/wasmbrowsertest), then run tests:
```sh
TEST_RELAY_URL=<relay_url> GOOS=js GOARCH=wasm go test -short ./...
GOOS=js GOARCH=wasm go test -short ./...
```
## Warning: risk of goroutine bloat (if used incorrectly)

View File

@ -1,187 +1,53 @@
//go:build !js
package nostr
import (
"bytes"
"compress/flate"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"github.com/gobwas/httphead"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
"github.com/gobwas/ws/wsutil"
ws "github.com/coder/websocket"
)
type Connection struct {
conn net.Conn
enableCompression bool
controlHandler wsutil.FrameHandlerFunc
flateReader *wsflate.Reader
reader *wsutil.Reader
flateWriter *wsflate.Writer
writer *wsutil.Writer
msgStateR *wsflate.MessageState
msgStateW *wsflate.MessageState
conn *ws.Conn
}
func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) {
dialer := ws.Dialer{
Header: ws.HandshakeHeaderHTTP(requestHeader),
Extensions: []httphead.Option{
wsflate.DefaultParameters.Option(),
},
TLSConfig: tlsConfig,
}
conn, _, hs, err := dialer.Dial(ctx, url)
c, _, err := ws.Dial(ctx, url, getConnectionOptions(requestHeader, tlsConfig))
if err != nil {
return nil, fmt.Errorf("failed to dial: %w", err)
return nil, err
}
enableCompression := false
state := ws.StateClientSide
for _, extension := range hs.Extensions {
if string(extension.Name) == wsflate.ExtensionName {
enableCompression = true
state |= ws.StateExtended
break
}
}
// reader
var flateReader *wsflate.Reader
var msgStateR wsflate.MessageState
if enableCompression {
msgStateR.SetCompressed(true)
flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
return flate.NewReader(r)
})
}
controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide)
reader := &wsutil.Reader{
Source: conn,
State: state,
OnIntermediate: controlHandler,
CheckUTF8: false,
Extensions: []wsutil.RecvExtension{
&msgStateR,
},
}
// writer
var flateWriter *wsflate.Writer
var msgStateW wsflate.MessageState
if enableCompression {
msgStateW.SetCompressed(true)
flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor {
fw, err := flate.NewWriter(w, 4)
if err != nil {
InfoLogger.Printf("Failed to create flate writer: %v", err)
}
return fw
})
}
writer := wsutil.NewWriter(conn, state, ws.OpText)
writer.SetExtensions(&msgStateW)
return &Connection{
conn: conn,
enableCompression: enableCompression,
controlHandler: controlHandler,
flateReader: flateReader,
reader: reader,
msgStateR: &msgStateR,
flateWriter: flateWriter,
writer: writer,
msgStateW: &msgStateW,
conn: c,
}, nil
}
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
select {
case <-ctx.Done():
return errors.New("context canceled")
default:
}
if c.msgStateW.IsCompressed() && c.enableCompression {
c.flateWriter.Reset(c.writer)
if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
if err := c.flateWriter.Close(); err != nil {
return fmt.Errorf("failed to close flate writer: %w", err)
}
} else {
if _, err := io.Copy(c.writer, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
}
if err := c.writer.Flush(); err != nil {
return fmt.Errorf("failed to flush writer: %w", err)
if err := c.conn.Write(ctx, ws.MessageText, data); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
return nil
}
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
for {
select {
case <-ctx.Done():
return errors.New("context canceled")
default:
}
h, err := c.reader.NextFrame()
if err != nil {
c.conn.Close()
return fmt.Errorf("failed to advance frame: %w", err)
}
if h.OpCode.IsControl() {
if err := c.controlHandler(h, c.reader); err != nil {
return fmt.Errorf("failed to handle control frame: %w", err)
}
} else if h.OpCode == ws.OpBinary ||
h.OpCode == ws.OpText {
break
}
if err := c.reader.Discard(); err != nil {
return fmt.Errorf("failed to discard: %w", err)
}
_, reader, err := c.conn.Reader(ctx)
if err != nil {
return fmt.Errorf("failed to get reader: %w", err)
}
if c.msgStateR.IsCompressed() && c.enableCompression {
c.flateReader.Reset(c.reader)
if _, err := io.Copy(buf, c.flateReader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
} else {
if _, err := io.Copy(buf, c.reader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
if _, err := io.Copy(buf, reader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
return nil
}
func (c *Connection) Close() error {
return c.conn.Close()
return c.conn.Close(ws.StatusNormalClosure, "")
}
func (c *Connection) Ping(ctx context.Context) error {
return wsutil.WriteClientMessage(c.conn, ws.OpPing, nil)
return c.conn.Ping(ctx)
}

View File

@ -1,55 +0,0 @@
//go:build js
package nostr
import (
"context"
"crypto/tls"
"fmt"
"io"
"net/http"
ws "github.com/coder/websocket"
)
type Connection struct {
conn *ws.Conn
}
func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) {
c, _, err := ws.Dial(ctx, url, nil)
if err != nil {
return nil, err
}
return &Connection{
conn: c,
}, nil
}
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
if err := c.conn.Write(ctx, ws.MessageBinary, data); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
return nil
}
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
_, reader, err := c.conn.Reader(ctx)
if err != nil {
return fmt.Errorf("failed to get reader: %w", err)
}
if _, err := io.Copy(buf, reader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
return nil
}
func (c *Connection) Close() error {
return c.conn.Close(ws.StatusNormalClosure, "")
}
func (c *Connection) Ping(ctx context.Context) error {
return c.conn.Ping(ctx)
}

34
connection_options.go Normal file
View File

@ -0,0 +1,34 @@
//go:build !js
package nostr
import (
"crypto/tls"
"net/http"
"net/textproto"
ws "github.com/coder/websocket"
)
var defaultConnectionOptions = &ws.DialOptions{
CompressionMode: ws.CompressionContextTakeover,
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
},
}
func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions {
if requestHeader == nil && tlsConfig == nil {
return defaultConnectionOptions
}
return &ws.DialOptions{
HTTPHeader: requestHeader,
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
}
}

15
connection_options_js.go Normal file
View File

@ -0,0 +1,15 @@
package nostr
import (
"crypto/tls"
"net/http"
ws "github.com/coder/websocket"
)
var emptyOptions = ws.DialOptions{}
func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions {
// on javascript we ignore everything because there is nothing else we can do
return &emptyOptions
}

29
pool.go
View File

@ -5,6 +5,7 @@ import (
"fmt"
"log"
"math"
"net/http"
"slices"
"strings"
"sync"
@ -31,7 +32,7 @@ type SimplePool struct {
// custom things not often used
penaltyBoxMu sync.Mutex
penaltyBox map[string][2]float64
userAgent string
relayOptions []RelayOption
}
type DirectedFilters struct {
@ -69,6 +70,17 @@ func NewSimplePool(ctx context.Context, opts ...PoolOption) *SimplePool {
return pool
}
// WithRelayOptions sets options that will be used on every relay instance created by this pool.
func WithRelayOptions(ropts ...RelayOption) withRelayOptionsOpt {
return ropts
}
type withRelayOptionsOpt []RelayOption
func (h withRelayOptionsOpt) ApplyPoolOption(pool *SimplePool) {
pool.relayOptions = h
}
// WithAuthHandler must be a function that signs the auth event when called.
// it will be called whenever any relay in the pool returns a `CLOSED` message
// with the "auth-required:" prefix, only once for each relay
@ -129,20 +141,11 @@ func (h WithAuthorKindQueryMiddleware) ApplyPoolOption(pool *SimplePool) {
pool.queryMiddleware = h
}
// WithUserAgent sets the user-agent header for all relay connections in the pool.
func WithUserAgent(userAgent string) withUserAgentOpt { return withUserAgentOpt(userAgent) }
type withUserAgentOpt string
func (h withUserAgentOpt) ApplyPoolOption(pool *SimplePool) {
pool.userAgent = string(h)
}
var (
_ PoolOption = (WithAuthHandler)(nil)
_ PoolOption = (WithEventMiddleware)(nil)
_ PoolOption = WithPenaltyBox()
_ PoolOption = WithUserAgent("")
_ PoolOption = WithRelayOptions(WithRequestHeader(http.Header{}))
)
func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
@ -169,9 +172,7 @@ func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
ctx, cancel := context.WithTimeout(pool.Context, time.Second*15)
defer cancel()
relay = NewRelay(context.Background(), url)
relay.RequestHeader.Set("User-Agent", pool.userAgent)
relay = NewRelay(context.Background(), url, pool.relayOptions...)
if err := relay.Connect(ctx); err != nil {
if pool.penaltyBox != nil {
// putting relay in penalty box

View File

@ -23,7 +23,7 @@ type Relay struct {
closeMutex sync.Mutex
URL string
RequestHeader http.Header // e.g. for origin header
requestHeader http.Header // e.g. for origin header
Connection *Connection
Subscriptions *xsync.MapOf[int64, *Subscription]
@ -60,7 +60,7 @@ func NewRelay(ctx context.Context, url string, opts ...RelayOption) *Relay {
okCallbacks: xsync.NewMapOf[string, func(bool, string)](),
writeQueue: make(chan writeRequest),
subscriptionChannelCloseQueue: make(chan *Subscription),
RequestHeader: make(http.Header, 1),
requestHeader: nil,
}
for _, opt := range opts {
@ -88,6 +88,7 @@ type RelayOption interface {
var (
_ RelayOption = (WithNoticeHandler)(nil)
_ RelayOption = (WithCustomHandler)(nil)
_ RelayOption = (WithRequestHeader)(nil)
)
// WithNoticeHandler just takes notices and is expected to do something with them.
@ -106,6 +107,13 @@ func (ch WithCustomHandler) ApplyRelayOption(r *Relay) {
r.customHandler = ch
}
// WithRequestHeader sets the HTTP request header of the websocket preflight request.
type WithRequestHeader http.Header
func (ch WithRequestHeader) ApplyRelayOption(r *Relay) {
r.requestHeader = http.Header(ch)
}
// String just returns the relay URL.
func (r *Relay) String() string {
return r.URL
@ -146,11 +154,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
defer cancel()
}
if r.RequestHeader.Get("User-Agent") == "" {
r.RequestHeader.Set("User-Agent", "github.com/nbd-wtf/go-nostr")
}
conn, err := NewConnection(ctx, r.URL, r.RequestHeader, tlsConfig)
conn, err := NewConnection(ctx, r.URL, r.requestHeader, tlsConfig)
if err != nil {
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
}

View File

@ -12,40 +12,33 @@ import (
"github.com/stretchr/testify/require"
)
func TestConnectContext(t *testing.T) {
var testRelayURL = func() string {
url := os.Getenv("TEST_RELAY_URL")
if url == "" {
t.Fatal("please set the environment: $TEST_RELAY_URL")
if url != "" {
return url
}
return "wss://nos.lol"
}()
func TestConnectContext(t *testing.T) {
// relay client
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
r, err := RelayConnect(ctx, url)
r, err := RelayConnect(ctx, testRelayURL)
assert.NoError(t, err)
defer r.Close()
}
func TestConnectContextCanceled(t *testing.T) {
url := os.Getenv("TEST_RELAY_URL")
if url == "" {
t.Fatal("please set the environment: $TEST_RELAY_URL")
}
// relay client
ctx, cancel := context.WithCancel(context.Background())
cancel() // make ctx expired
_, err := RelayConnect(ctx, url)
_, err := RelayConnect(ctx, testRelayURL)
assert.ErrorIs(t, err, context.Canceled)
}
func TestPublish(t *testing.T) {
url := os.Getenv("TEST_RELAY_URL")
if url == "" {
t.Fatal("please set the environment: $TEST_RELAY_URL")
}
// test note to be sent over websocket
priv, pub := makeKeyPair(t)
textNote := Event{
@ -59,7 +52,7 @@ func TestPublish(t *testing.T) {
assert.NoError(t, err)
// connect a client and send the text note
rl := mustRelayConnect(t, url)
rl := mustRelayConnect(t, testRelayURL)
err = rl.Publish(context.Background(), textNote)
assert.NoError(t, err)
}

View File

@ -149,8 +149,8 @@ func TestConnectWithOrigin(t *testing.T) {
defer ws.Close()
// relay client
r := NewRelay(context.Background(), NormalizeURL(ws.URL))
r.RequestHeader = http.Header{"origin": {"https://example.com"}}
r := NewRelay(context.Background(), NormalizeURL(ws.URL),
WithRequestHeader(http.Header{"origin": {"https://example.com"}}))
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err := r.Connect(ctx)