mirror of
https://github.com/nbd-wtf/go-nostr.git
synced 2025-06-20 13:50:55 +02:00
use coder/websocket for everything, get rid of gobwas.
supposedly it is faster, and anyway it's better to use it since we're already using it for wasm/js. (previously named nhooyr/websocket).
This commit is contained in:
parent
b33cfb19fa
commit
defc349e57
@ -150,7 +150,7 @@ To use it, use `-tags=libsecp256k1` whenever you're compiling your program that
|
||||
Install [wasmbrowsertest](https://github.com/agnivade/wasmbrowsertest), then run tests:
|
||||
|
||||
```sh
|
||||
TEST_RELAY_URL=<relay_url> GOOS=js GOARCH=wasm go test -short ./...
|
||||
GOOS=js GOARCH=wasm go test -short ./...
|
||||
```
|
||||
|
||||
## Warning: risk of goroutine bloat (if used incorrectly)
|
||||
|
162
connection.go
162
connection.go
@ -1,187 +1,53 @@
|
||||
//go:build !js
|
||||
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/gobwas/httphead"
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsflate"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
ws "github.com/coder/websocket"
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
conn net.Conn
|
||||
enableCompression bool
|
||||
controlHandler wsutil.FrameHandlerFunc
|
||||
flateReader *wsflate.Reader
|
||||
reader *wsutil.Reader
|
||||
flateWriter *wsflate.Writer
|
||||
writer *wsutil.Writer
|
||||
msgStateR *wsflate.MessageState
|
||||
msgStateW *wsflate.MessageState
|
||||
conn *ws.Conn
|
||||
}
|
||||
|
||||
func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) {
|
||||
dialer := ws.Dialer{
|
||||
Header: ws.HandshakeHeaderHTTP(requestHeader),
|
||||
Extensions: []httphead.Option{
|
||||
wsflate.DefaultParameters.Option(),
|
||||
},
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
conn, _, hs, err := dialer.Dial(ctx, url)
|
||||
c, _, err := ws.Dial(ctx, url, getConnectionOptions(requestHeader, tlsConfig))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
enableCompression := false
|
||||
state := ws.StateClientSide
|
||||
for _, extension := range hs.Extensions {
|
||||
if string(extension.Name) == wsflate.ExtensionName {
|
||||
enableCompression = true
|
||||
state |= ws.StateExtended
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// reader
|
||||
var flateReader *wsflate.Reader
|
||||
var msgStateR wsflate.MessageState
|
||||
if enableCompression {
|
||||
msgStateR.SetCompressed(true)
|
||||
|
||||
flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
|
||||
return flate.NewReader(r)
|
||||
})
|
||||
}
|
||||
|
||||
controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide)
|
||||
reader := &wsutil.Reader{
|
||||
Source: conn,
|
||||
State: state,
|
||||
OnIntermediate: controlHandler,
|
||||
CheckUTF8: false,
|
||||
Extensions: []wsutil.RecvExtension{
|
||||
&msgStateR,
|
||||
},
|
||||
}
|
||||
|
||||
// writer
|
||||
var flateWriter *wsflate.Writer
|
||||
var msgStateW wsflate.MessageState
|
||||
if enableCompression {
|
||||
msgStateW.SetCompressed(true)
|
||||
|
||||
flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor {
|
||||
fw, err := flate.NewWriter(w, 4)
|
||||
if err != nil {
|
||||
InfoLogger.Printf("Failed to create flate writer: %v", err)
|
||||
}
|
||||
return fw
|
||||
})
|
||||
}
|
||||
|
||||
writer := wsutil.NewWriter(conn, state, ws.OpText)
|
||||
writer.SetExtensions(&msgStateW)
|
||||
|
||||
return &Connection{
|
||||
conn: conn,
|
||||
enableCompression: enableCompression,
|
||||
controlHandler: controlHandler,
|
||||
flateReader: flateReader,
|
||||
reader: reader,
|
||||
msgStateR: &msgStateR,
|
||||
flateWriter: flateWriter,
|
||||
writer: writer,
|
||||
msgStateW: &msgStateW,
|
||||
conn: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.New("context canceled")
|
||||
default:
|
||||
}
|
||||
|
||||
if c.msgStateW.IsCompressed() && c.enableCompression {
|
||||
c.flateWriter.Reset(c.writer)
|
||||
if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
|
||||
if err := c.flateWriter.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close flate writer: %w", err)
|
||||
}
|
||||
} else {
|
||||
if _, err := io.Copy(c.writer, bytes.NewReader(data)); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.writer.Flush(); err != nil {
|
||||
return fmt.Errorf("failed to flush writer: %w", err)
|
||||
if err := c.conn.Write(ctx, ws.MessageText, data); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.New("context canceled")
|
||||
default:
|
||||
}
|
||||
|
||||
h, err := c.reader.NextFrame()
|
||||
if err != nil {
|
||||
c.conn.Close()
|
||||
return fmt.Errorf("failed to advance frame: %w", err)
|
||||
}
|
||||
|
||||
if h.OpCode.IsControl() {
|
||||
if err := c.controlHandler(h, c.reader); err != nil {
|
||||
return fmt.Errorf("failed to handle control frame: %w", err)
|
||||
}
|
||||
} else if h.OpCode == ws.OpBinary ||
|
||||
h.OpCode == ws.OpText {
|
||||
break
|
||||
}
|
||||
|
||||
if err := c.reader.Discard(); err != nil {
|
||||
return fmt.Errorf("failed to discard: %w", err)
|
||||
}
|
||||
_, reader, err := c.conn.Reader(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get reader: %w", err)
|
||||
}
|
||||
|
||||
if c.msgStateR.IsCompressed() && c.enableCompression {
|
||||
c.flateReader.Reset(c.reader)
|
||||
if _, err := io.Copy(buf, c.flateReader); err != nil {
|
||||
return fmt.Errorf("failed to read message: %w", err)
|
||||
}
|
||||
} else {
|
||||
if _, err := io.Copy(buf, c.reader); err != nil {
|
||||
return fmt.Errorf("failed to read message: %w", err)
|
||||
}
|
||||
if _, err := io.Copy(buf, reader); err != nil {
|
||||
return fmt.Errorf("failed to read message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) Close() error {
|
||||
return c.conn.Close()
|
||||
return c.conn.Close(ws.StatusNormalClosure, "")
|
||||
}
|
||||
|
||||
func (c *Connection) Ping(ctx context.Context) error {
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpPing, nil)
|
||||
return c.conn.Ping(ctx)
|
||||
}
|
||||
|
@ -1,55 +0,0 @@
|
||||
//go:build js
|
||||
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
ws "github.com/coder/websocket"
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
conn *ws.Conn
|
||||
}
|
||||
|
||||
func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) {
|
||||
c, _, err := ws.Dial(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Connection{
|
||||
conn: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
|
||||
if err := c.conn.Write(ctx, ws.MessageBinary, data); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
|
||||
_, reader, err := c.conn.Reader(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get reader: %w", err)
|
||||
}
|
||||
if _, err := io.Copy(buf, reader); err != nil {
|
||||
return fmt.Errorf("failed to read message: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) Close() error {
|
||||
return c.conn.Close(ws.StatusNormalClosure, "")
|
||||
}
|
||||
|
||||
func (c *Connection) Ping(ctx context.Context) error {
|
||||
return c.conn.Ping(ctx)
|
||||
}
|
34
connection_options.go
Normal file
34
connection_options.go
Normal file
@ -0,0 +1,34 @@
|
||||
//go:build !js
|
||||
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
|
||||
ws "github.com/coder/websocket"
|
||||
)
|
||||
|
||||
var defaultConnectionOptions = &ws.DialOptions{
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPHeader: http.Header{
|
||||
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
|
||||
},
|
||||
}
|
||||
|
||||
func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions {
|
||||
if requestHeader == nil && tlsConfig == nil {
|
||||
return defaultConnectionOptions
|
||||
}
|
||||
|
||||
return &ws.DialOptions{
|
||||
HTTPHeader: requestHeader,
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPClient: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
15
connection_options_js.go
Normal file
15
connection_options_js.go
Normal file
@ -0,0 +1,15 @@
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
|
||||
ws "github.com/coder/websocket"
|
||||
)
|
||||
|
||||
var emptyOptions = ws.DialOptions{}
|
||||
|
||||
func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions {
|
||||
// on javascript we ignore everything because there is nothing else we can do
|
||||
return &emptyOptions
|
||||
}
|
29
pool.go
29
pool.go
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -31,7 +32,7 @@ type SimplePool struct {
|
||||
// custom things not often used
|
||||
penaltyBoxMu sync.Mutex
|
||||
penaltyBox map[string][2]float64
|
||||
userAgent string
|
||||
relayOptions []RelayOption
|
||||
}
|
||||
|
||||
type DirectedFilters struct {
|
||||
@ -69,6 +70,17 @@ func NewSimplePool(ctx context.Context, opts ...PoolOption) *SimplePool {
|
||||
return pool
|
||||
}
|
||||
|
||||
// WithRelayOptions sets options that will be used on every relay instance created by this pool.
|
||||
func WithRelayOptions(ropts ...RelayOption) withRelayOptionsOpt {
|
||||
return ropts
|
||||
}
|
||||
|
||||
type withRelayOptionsOpt []RelayOption
|
||||
|
||||
func (h withRelayOptionsOpt) ApplyPoolOption(pool *SimplePool) {
|
||||
pool.relayOptions = h
|
||||
}
|
||||
|
||||
// WithAuthHandler must be a function that signs the auth event when called.
|
||||
// it will be called whenever any relay in the pool returns a `CLOSED` message
|
||||
// with the "auth-required:" prefix, only once for each relay
|
||||
@ -129,20 +141,11 @@ func (h WithAuthorKindQueryMiddleware) ApplyPoolOption(pool *SimplePool) {
|
||||
pool.queryMiddleware = h
|
||||
}
|
||||
|
||||
// WithUserAgent sets the user-agent header for all relay connections in the pool.
|
||||
func WithUserAgent(userAgent string) withUserAgentOpt { return withUserAgentOpt(userAgent) }
|
||||
|
||||
type withUserAgentOpt string
|
||||
|
||||
func (h withUserAgentOpt) ApplyPoolOption(pool *SimplePool) {
|
||||
pool.userAgent = string(h)
|
||||
}
|
||||
|
||||
var (
|
||||
_ PoolOption = (WithAuthHandler)(nil)
|
||||
_ PoolOption = (WithEventMiddleware)(nil)
|
||||
_ PoolOption = WithPenaltyBox()
|
||||
_ PoolOption = WithUserAgent("")
|
||||
_ PoolOption = WithRelayOptions(WithRequestHeader(http.Header{}))
|
||||
)
|
||||
|
||||
func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
|
||||
@ -169,9 +172,7 @@ func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
|
||||
ctx, cancel := context.WithTimeout(pool.Context, time.Second*15)
|
||||
defer cancel()
|
||||
|
||||
relay = NewRelay(context.Background(), url)
|
||||
relay.RequestHeader.Set("User-Agent", pool.userAgent)
|
||||
|
||||
relay = NewRelay(context.Background(), url, pool.relayOptions...)
|
||||
if err := relay.Connect(ctx); err != nil {
|
||||
if pool.penaltyBox != nil {
|
||||
// putting relay in penalty box
|
||||
|
18
relay.go
18
relay.go
@ -23,7 +23,7 @@ type Relay struct {
|
||||
closeMutex sync.Mutex
|
||||
|
||||
URL string
|
||||
RequestHeader http.Header // e.g. for origin header
|
||||
requestHeader http.Header // e.g. for origin header
|
||||
|
||||
Connection *Connection
|
||||
Subscriptions *xsync.MapOf[int64, *Subscription]
|
||||
@ -60,7 +60,7 @@ func NewRelay(ctx context.Context, url string, opts ...RelayOption) *Relay {
|
||||
okCallbacks: xsync.NewMapOf[string, func(bool, string)](),
|
||||
writeQueue: make(chan writeRequest),
|
||||
subscriptionChannelCloseQueue: make(chan *Subscription),
|
||||
RequestHeader: make(http.Header, 1),
|
||||
requestHeader: nil,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@ -88,6 +88,7 @@ type RelayOption interface {
|
||||
var (
|
||||
_ RelayOption = (WithNoticeHandler)(nil)
|
||||
_ RelayOption = (WithCustomHandler)(nil)
|
||||
_ RelayOption = (WithRequestHeader)(nil)
|
||||
)
|
||||
|
||||
// WithNoticeHandler just takes notices and is expected to do something with them.
|
||||
@ -106,6 +107,13 @@ func (ch WithCustomHandler) ApplyRelayOption(r *Relay) {
|
||||
r.customHandler = ch
|
||||
}
|
||||
|
||||
// WithRequestHeader sets the HTTP request header of the websocket preflight request.
|
||||
type WithRequestHeader http.Header
|
||||
|
||||
func (ch WithRequestHeader) ApplyRelayOption(r *Relay) {
|
||||
r.requestHeader = http.Header(ch)
|
||||
}
|
||||
|
||||
// String just returns the relay URL.
|
||||
func (r *Relay) String() string {
|
||||
return r.URL
|
||||
@ -146,11 +154,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
if r.RequestHeader.Get("User-Agent") == "" {
|
||||
r.RequestHeader.Set("User-Agent", "github.com/nbd-wtf/go-nostr")
|
||||
}
|
||||
|
||||
conn, err := NewConnection(ctx, r.URL, r.RequestHeader, tlsConfig)
|
||||
conn, err := NewConnection(ctx, r.URL, r.requestHeader, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
|
||||
}
|
||||
|
@ -12,40 +12,33 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnectContext(t *testing.T) {
|
||||
var testRelayURL = func() string {
|
||||
url := os.Getenv("TEST_RELAY_URL")
|
||||
if url == "" {
|
||||
t.Fatal("please set the environment: $TEST_RELAY_URL")
|
||||
if url != "" {
|
||||
return url
|
||||
}
|
||||
return "wss://nos.lol"
|
||||
}()
|
||||
|
||||
func TestConnectContext(t *testing.T) {
|
||||
// relay client
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
r, err := RelayConnect(ctx, url)
|
||||
r, err := RelayConnect(ctx, testRelayURL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer r.Close()
|
||||
}
|
||||
|
||||
func TestConnectContextCanceled(t *testing.T) {
|
||||
url := os.Getenv("TEST_RELAY_URL")
|
||||
if url == "" {
|
||||
t.Fatal("please set the environment: $TEST_RELAY_URL")
|
||||
}
|
||||
|
||||
// relay client
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // make ctx expired
|
||||
_, err := RelayConnect(ctx, url)
|
||||
_, err := RelayConnect(ctx, testRelayURL)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
url := os.Getenv("TEST_RELAY_URL")
|
||||
if url == "" {
|
||||
t.Fatal("please set the environment: $TEST_RELAY_URL")
|
||||
}
|
||||
|
||||
// test note to be sent over websocket
|
||||
priv, pub := makeKeyPair(t)
|
||||
textNote := Event{
|
||||
@ -59,7 +52,7 @@ func TestPublish(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// connect a client and send the text note
|
||||
rl := mustRelayConnect(t, url)
|
||||
rl := mustRelayConnect(t, testRelayURL)
|
||||
err = rl.Publish(context.Background(), textNote)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
@ -149,8 +149,8 @@ func TestConnectWithOrigin(t *testing.T) {
|
||||
defer ws.Close()
|
||||
|
||||
// relay client
|
||||
r := NewRelay(context.Background(), NormalizeURL(ws.URL))
|
||||
r.RequestHeader = http.Header{"origin": {"https://example.com"}}
|
||||
r := NewRelay(context.Background(), NormalizeURL(ws.URL),
|
||||
WithRequestHeader(http.Header{"origin": {"https://example.com"}}))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
err := r.Connect(ctx)
|
||||
|
Loading…
x
Reference in New Issue
Block a user