use coder/websocket for everything, get rid of gobwas.

supposedly it is faster, and anyway it's better to use it since we're already using it for wasm/js.

(previously named nhooyr/websocket).
This commit is contained in:
fiatjaf 2025-01-03 01:15:12 -03:00
parent b33cfb19fa
commit defc349e57
9 changed files with 101 additions and 243 deletions

View File

@ -150,7 +150,7 @@ To use it, use `-tags=libsecp256k1` whenever you're compiling your program that
Install [wasmbrowsertest](https://github.com/agnivade/wasmbrowsertest), then run tests: Install [wasmbrowsertest](https://github.com/agnivade/wasmbrowsertest), then run tests:
```sh ```sh
TEST_RELAY_URL=<relay_url> GOOS=js GOARCH=wasm go test -short ./... GOOS=js GOARCH=wasm go test -short ./...
``` ```
## Warning: risk of goroutine bloat (if used incorrectly) ## Warning: risk of goroutine bloat (if used incorrectly)

View File

@ -1,187 +1,53 @@
//go:build !js
package nostr package nostr
import ( import (
"bytes"
"compress/flate"
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"github.com/gobwas/httphead" ws "github.com/coder/websocket"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
"github.com/gobwas/ws/wsutil"
) )
type Connection struct { type Connection struct {
conn net.Conn conn *ws.Conn
enableCompression bool
controlHandler wsutil.FrameHandlerFunc
flateReader *wsflate.Reader
reader *wsutil.Reader
flateWriter *wsflate.Writer
writer *wsutil.Writer
msgStateR *wsflate.MessageState
msgStateW *wsflate.MessageState
} }
func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) { func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) {
dialer := ws.Dialer{ c, _, err := ws.Dial(ctx, url, getConnectionOptions(requestHeader, tlsConfig))
Header: ws.HandshakeHeaderHTTP(requestHeader),
Extensions: []httphead.Option{
wsflate.DefaultParameters.Option(),
},
TLSConfig: tlsConfig,
}
conn, _, hs, err := dialer.Dial(ctx, url)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to dial: %w", err) return nil, err
} }
enableCompression := false
state := ws.StateClientSide
for _, extension := range hs.Extensions {
if string(extension.Name) == wsflate.ExtensionName {
enableCompression = true
state |= ws.StateExtended
break
}
}
// reader
var flateReader *wsflate.Reader
var msgStateR wsflate.MessageState
if enableCompression {
msgStateR.SetCompressed(true)
flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
return flate.NewReader(r)
})
}
controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide)
reader := &wsutil.Reader{
Source: conn,
State: state,
OnIntermediate: controlHandler,
CheckUTF8: false,
Extensions: []wsutil.RecvExtension{
&msgStateR,
},
}
// writer
var flateWriter *wsflate.Writer
var msgStateW wsflate.MessageState
if enableCompression {
msgStateW.SetCompressed(true)
flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor {
fw, err := flate.NewWriter(w, 4)
if err != nil {
InfoLogger.Printf("Failed to create flate writer: %v", err)
}
return fw
})
}
writer := wsutil.NewWriter(conn, state, ws.OpText)
writer.SetExtensions(&msgStateW)
return &Connection{ return &Connection{
conn: conn, conn: c,
enableCompression: enableCompression,
controlHandler: controlHandler,
flateReader: flateReader,
reader: reader,
msgStateR: &msgStateR,
flateWriter: flateWriter,
writer: writer,
msgStateW: &msgStateW,
}, nil }, nil
} }
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error { func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
select { if err := c.conn.Write(ctx, ws.MessageText, data); err != nil {
case <-ctx.Done(): return fmt.Errorf("failed to write message: %w", err)
return errors.New("context canceled")
default:
}
if c.msgStateW.IsCompressed() && c.enableCompression {
c.flateWriter.Reset(c.writer)
if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
if err := c.flateWriter.Close(); err != nil {
return fmt.Errorf("failed to close flate writer: %w", err)
}
} else {
if _, err := io.Copy(c.writer, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
}
if err := c.writer.Flush(); err != nil {
return fmt.Errorf("failed to flush writer: %w", err)
} }
return nil return nil
} }
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error { func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
for { _, reader, err := c.conn.Reader(ctx)
select { if err != nil {
case <-ctx.Done(): return fmt.Errorf("failed to get reader: %w", err)
return errors.New("context canceled")
default:
}
h, err := c.reader.NextFrame()
if err != nil {
c.conn.Close()
return fmt.Errorf("failed to advance frame: %w", err)
}
if h.OpCode.IsControl() {
if err := c.controlHandler(h, c.reader); err != nil {
return fmt.Errorf("failed to handle control frame: %w", err)
}
} else if h.OpCode == ws.OpBinary ||
h.OpCode == ws.OpText {
break
}
if err := c.reader.Discard(); err != nil {
return fmt.Errorf("failed to discard: %w", err)
}
} }
if _, err := io.Copy(buf, reader); err != nil {
if c.msgStateR.IsCompressed() && c.enableCompression { return fmt.Errorf("failed to read message: %w", err)
c.flateReader.Reset(c.reader)
if _, err := io.Copy(buf, c.flateReader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
} else {
if _, err := io.Copy(buf, c.reader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
} }
return nil return nil
} }
func (c *Connection) Close() error { func (c *Connection) Close() error {
return c.conn.Close() return c.conn.Close(ws.StatusNormalClosure, "")
} }
func (c *Connection) Ping(ctx context.Context) error { func (c *Connection) Ping(ctx context.Context) error {
return wsutil.WriteClientMessage(c.conn, ws.OpPing, nil) return c.conn.Ping(ctx)
} }

View File

@ -1,55 +0,0 @@
//go:build js
package nostr
import (
"context"
"crypto/tls"
"fmt"
"io"
"net/http"
ws "github.com/coder/websocket"
)
type Connection struct {
conn *ws.Conn
}
func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) {
c, _, err := ws.Dial(ctx, url, nil)
if err != nil {
return nil, err
}
return &Connection{
conn: c,
}, nil
}
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
if err := c.conn.Write(ctx, ws.MessageBinary, data); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
return nil
}
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
_, reader, err := c.conn.Reader(ctx)
if err != nil {
return fmt.Errorf("failed to get reader: %w", err)
}
if _, err := io.Copy(buf, reader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
return nil
}
func (c *Connection) Close() error {
return c.conn.Close(ws.StatusNormalClosure, "")
}
func (c *Connection) Ping(ctx context.Context) error {
return c.conn.Ping(ctx)
}

34
connection_options.go Normal file
View File

@ -0,0 +1,34 @@
//go:build !js
package nostr
import (
"crypto/tls"
"net/http"
"net/textproto"
ws "github.com/coder/websocket"
)
var defaultConnectionOptions = &ws.DialOptions{
CompressionMode: ws.CompressionContextTakeover,
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
},
}
func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions {
if requestHeader == nil && tlsConfig == nil {
return defaultConnectionOptions
}
return &ws.DialOptions{
HTTPHeader: requestHeader,
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
}
}

15
connection_options_js.go Normal file
View File

@ -0,0 +1,15 @@
package nostr
import (
"crypto/tls"
"net/http"
ws "github.com/coder/websocket"
)
var emptyOptions = ws.DialOptions{}
func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions {
// on javascript we ignore everything because there is nothing else we can do
return &emptyOptions
}

29
pool.go
View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"log" "log"
"math" "math"
"net/http"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@ -31,7 +32,7 @@ type SimplePool struct {
// custom things not often used // custom things not often used
penaltyBoxMu sync.Mutex penaltyBoxMu sync.Mutex
penaltyBox map[string][2]float64 penaltyBox map[string][2]float64
userAgent string relayOptions []RelayOption
} }
type DirectedFilters struct { type DirectedFilters struct {
@ -69,6 +70,17 @@ func NewSimplePool(ctx context.Context, opts ...PoolOption) *SimplePool {
return pool return pool
} }
// WithRelayOptions sets options that will be used on every relay instance created by this pool.
func WithRelayOptions(ropts ...RelayOption) withRelayOptionsOpt {
return ropts
}
type withRelayOptionsOpt []RelayOption
func (h withRelayOptionsOpt) ApplyPoolOption(pool *SimplePool) {
pool.relayOptions = h
}
// WithAuthHandler must be a function that signs the auth event when called. // WithAuthHandler must be a function that signs the auth event when called.
// it will be called whenever any relay in the pool returns a `CLOSED` message // it will be called whenever any relay in the pool returns a `CLOSED` message
// with the "auth-required:" prefix, only once for each relay // with the "auth-required:" prefix, only once for each relay
@ -129,20 +141,11 @@ func (h WithAuthorKindQueryMiddleware) ApplyPoolOption(pool *SimplePool) {
pool.queryMiddleware = h pool.queryMiddleware = h
} }
// WithUserAgent sets the user-agent header for all relay connections in the pool.
func WithUserAgent(userAgent string) withUserAgentOpt { return withUserAgentOpt(userAgent) }
type withUserAgentOpt string
func (h withUserAgentOpt) ApplyPoolOption(pool *SimplePool) {
pool.userAgent = string(h)
}
var ( var (
_ PoolOption = (WithAuthHandler)(nil) _ PoolOption = (WithAuthHandler)(nil)
_ PoolOption = (WithEventMiddleware)(nil) _ PoolOption = (WithEventMiddleware)(nil)
_ PoolOption = WithPenaltyBox() _ PoolOption = WithPenaltyBox()
_ PoolOption = WithUserAgent("") _ PoolOption = WithRelayOptions(WithRequestHeader(http.Header{}))
) )
func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) { func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
@ -169,9 +172,7 @@ func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
ctx, cancel := context.WithTimeout(pool.Context, time.Second*15) ctx, cancel := context.WithTimeout(pool.Context, time.Second*15)
defer cancel() defer cancel()
relay = NewRelay(context.Background(), url) relay = NewRelay(context.Background(), url, pool.relayOptions...)
relay.RequestHeader.Set("User-Agent", pool.userAgent)
if err := relay.Connect(ctx); err != nil { if err := relay.Connect(ctx); err != nil {
if pool.penaltyBox != nil { if pool.penaltyBox != nil {
// putting relay in penalty box // putting relay in penalty box

View File

@ -23,7 +23,7 @@ type Relay struct {
closeMutex sync.Mutex closeMutex sync.Mutex
URL string URL string
RequestHeader http.Header // e.g. for origin header requestHeader http.Header // e.g. for origin header
Connection *Connection Connection *Connection
Subscriptions *xsync.MapOf[int64, *Subscription] Subscriptions *xsync.MapOf[int64, *Subscription]
@ -60,7 +60,7 @@ func NewRelay(ctx context.Context, url string, opts ...RelayOption) *Relay {
okCallbacks: xsync.NewMapOf[string, func(bool, string)](), okCallbacks: xsync.NewMapOf[string, func(bool, string)](),
writeQueue: make(chan writeRequest), writeQueue: make(chan writeRequest),
subscriptionChannelCloseQueue: make(chan *Subscription), subscriptionChannelCloseQueue: make(chan *Subscription),
RequestHeader: make(http.Header, 1), requestHeader: nil,
} }
for _, opt := range opts { for _, opt := range opts {
@ -88,6 +88,7 @@ type RelayOption interface {
var ( var (
_ RelayOption = (WithNoticeHandler)(nil) _ RelayOption = (WithNoticeHandler)(nil)
_ RelayOption = (WithCustomHandler)(nil) _ RelayOption = (WithCustomHandler)(nil)
_ RelayOption = (WithRequestHeader)(nil)
) )
// WithNoticeHandler just takes notices and is expected to do something with them. // WithNoticeHandler just takes notices and is expected to do something with them.
@ -106,6 +107,13 @@ func (ch WithCustomHandler) ApplyRelayOption(r *Relay) {
r.customHandler = ch r.customHandler = ch
} }
// WithRequestHeader sets the HTTP request header of the websocket preflight request.
type WithRequestHeader http.Header
func (ch WithRequestHeader) ApplyRelayOption(r *Relay) {
r.requestHeader = http.Header(ch)
}
// String just returns the relay URL. // String just returns the relay URL.
func (r *Relay) String() string { func (r *Relay) String() string {
return r.URL return r.URL
@ -146,11 +154,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
defer cancel() defer cancel()
} }
if r.RequestHeader.Get("User-Agent") == "" { conn, err := NewConnection(ctx, r.URL, r.requestHeader, tlsConfig)
r.RequestHeader.Set("User-Agent", "github.com/nbd-wtf/go-nostr")
}
conn, err := NewConnection(ctx, r.URL, r.RequestHeader, tlsConfig)
if err != nil { if err != nil {
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
} }

View File

@ -12,40 +12,33 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestConnectContext(t *testing.T) { var testRelayURL = func() string {
url := os.Getenv("TEST_RELAY_URL") url := os.Getenv("TEST_RELAY_URL")
if url == "" { if url != "" {
t.Fatal("please set the environment: $TEST_RELAY_URL") return url
} }
return "wss://nos.lol"
}()
func TestConnectContext(t *testing.T) {
// relay client // relay client
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer cancel()
r, err := RelayConnect(ctx, url) r, err := RelayConnect(ctx, testRelayURL)
assert.NoError(t, err) assert.NoError(t, err)
defer r.Close() defer r.Close()
} }
func TestConnectContextCanceled(t *testing.T) { func TestConnectContextCanceled(t *testing.T) {
url := os.Getenv("TEST_RELAY_URL")
if url == "" {
t.Fatal("please set the environment: $TEST_RELAY_URL")
}
// relay client // relay client
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() // make ctx expired cancel() // make ctx expired
_, err := RelayConnect(ctx, url) _, err := RelayConnect(ctx, testRelayURL)
assert.ErrorIs(t, err, context.Canceled) assert.ErrorIs(t, err, context.Canceled)
} }
func TestPublish(t *testing.T) { func TestPublish(t *testing.T) {
url := os.Getenv("TEST_RELAY_URL")
if url == "" {
t.Fatal("please set the environment: $TEST_RELAY_URL")
}
// test note to be sent over websocket // test note to be sent over websocket
priv, pub := makeKeyPair(t) priv, pub := makeKeyPair(t)
textNote := Event{ textNote := Event{
@ -59,7 +52,7 @@ func TestPublish(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// connect a client and send the text note // connect a client and send the text note
rl := mustRelayConnect(t, url) rl := mustRelayConnect(t, testRelayURL)
err = rl.Publish(context.Background(), textNote) err = rl.Publish(context.Background(), textNote)
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@ -149,8 +149,8 @@ func TestConnectWithOrigin(t *testing.T) {
defer ws.Close() defer ws.Close()
// relay client // relay client
r := NewRelay(context.Background(), NormalizeURL(ws.URL)) r := NewRelay(context.Background(), NormalizeURL(ws.URL),
r.RequestHeader = http.Header{"origin": {"https://example.com"}} WithRequestHeader(http.Header{"origin": {"https://example.com"}}))
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer cancel()
err := r.Connect(ctx) err := r.Connect(ctx)