From c86e907142689830a707251a568dc99bad703468 Mon Sep 17 00:00:00 2001 From: Marc Tarnutzer Date: Fri, 5 May 2023 22:00:25 +0200 Subject: [PATCH] enable compression by default --- connection.go | 23 ++++++++++++++--------- relay.go | 7 +++---- subscription_test.go | 41 ----------------------------------------- 3 files changed, 17 insertions(+), 54 deletions(-) diff --git a/connection.go b/connection.go index ec45fdd..e858d3b 100644 --- a/connection.go +++ b/connection.go @@ -29,23 +29,28 @@ type Connection struct { mutex sync.Mutex } -func NewConnection(ctx context.Context, url string, requestHeader http.Header, enableCompression bool) (*Connection, error) { +func NewConnection(ctx context.Context, url string, requestHeader http.Header) (*Connection, error) { dialer := ws.Dialer{ Header: ws.HandshakeHeaderHTTP(requestHeader), - } - state := ws.StateClientSide - if enableCompression { - state |= ws.StateExtended - dialer.Extensions = []httphead.Option{ + Extensions: []httphead.Option{ wsflate.DefaultParameters.Option(), - } + }, } - - conn, _, _, err := dialer.Dial(ctx, url) + conn, _, hs, err := dialer.Dial(ctx, url) if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } + enableCompression := false + state := ws.StateClientSide + for _, extension := range hs.Extensions { + if string(extension.Name) == wsflate.ExtensionName { + enableCompression = true + state |= ws.StateExtended + break + } + } + // reader var flateReader *wsflate.Reader var msgState wsflate.MessageState diff --git a/relay.go b/relay.go index fcec2db..a4af3d0 100644 --- a/relay.go +++ b/relay.go @@ -38,9 +38,8 @@ type Relay struct { URL string RequestHeader http.Header // e.g. for origin header - Connection *Connection - EnableCompression bool - subscriptions s.MapOf[string, *Subscription] + Connection *Connection + subscriptions s.MapOf[string, *Subscription] Challenges chan string // NIP-42 Challenges Notices chan string @@ -90,7 +89,7 @@ func (r *Relay) Connect(ctx context.Context) error { defer cancel() } - conn, err := NewConnection(ctx, r.URL, r.RequestHeader, r.EnableCompression) + conn, err := NewConnection(ctx, r.URL, r.RequestHeader) if err != nil { cancel() return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) diff --git a/subscription_test.go b/subscription_test.go index 7931487..f004c67 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -43,44 +43,3 @@ end: t.Errorf("expected 2 events, got %d", events) } } - -func TestSubscribeEnableCompression(t *testing.T) { - rl := &Relay{URL: NormalizeURL("wss://relay.damus.io"), EnableCompression: true} - err := rl.Connect(context.Background()) - if err != nil { - t.Fatalf("connection failed: %v", err) - } - defer rl.Close() - - sub, err := rl.Subscribe(context.Background(), Filters{{Kinds: []int{1}, Limit: 2}}) - if err != nil { - t.Errorf("subscription failed: %v", err) - return - } - - timeout := time.After(5 * time.Second) - events := 0 - - for { - select { - case event := <-sub.Events: - if event == nil { - t.Errorf("event is nil: %v", event) - } - events++ - case <-sub.EndOfStoredEvents: - goto end - case <-rl.ConnectionContext.Done(): - t.Errorf("connection closed: %v", rl.ConnectionContext.Err()) - goto end - case <-timeout: - t.Errorf("timeout") - goto end - } - } - -end: - if events != 2 { - t.Errorf("expected 2 events, got %d", events) - } -}