enable compression by default

This commit is contained in:
Marc Tarnutzer 2023-05-05 22:00:25 +02:00
parent ee9502bc3e
commit c86e907142
3 changed files with 17 additions and 54 deletions

View File

@ -29,23 +29,28 @@ type Connection struct {
mutex sync.Mutex
}
func NewConnection(ctx context.Context, url string, requestHeader http.Header, enableCompression bool) (*Connection, error) {
func NewConnection(ctx context.Context, url string, requestHeader http.Header) (*Connection, error) {
dialer := ws.Dialer{
Header: ws.HandshakeHeaderHTTP(requestHeader),
}
state := ws.StateClientSide
if enableCompression {
state |= ws.StateExtended
dialer.Extensions = []httphead.Option{
Extensions: []httphead.Option{
wsflate.DefaultParameters.Option(),
}
},
}
conn, _, _, err := dialer.Dial(ctx, url)
conn, _, hs, err := dialer.Dial(ctx, url)
if err != nil {
return nil, fmt.Errorf("failed to dial: %w", err)
}
enableCompression := false
state := ws.StateClientSide
for _, extension := range hs.Extensions {
if string(extension.Name) == wsflate.ExtensionName {
enableCompression = true
state |= ws.StateExtended
break
}
}
// reader
var flateReader *wsflate.Reader
var msgState wsflate.MessageState

View File

@ -38,9 +38,8 @@ type Relay struct {
URL string
RequestHeader http.Header // e.g. for origin header
Connection *Connection
EnableCompression bool
subscriptions s.MapOf[string, *Subscription]
Connection *Connection
subscriptions s.MapOf[string, *Subscription]
Challenges chan string // NIP-42 Challenges
Notices chan string
@ -90,7 +89,7 @@ func (r *Relay) Connect(ctx context.Context) error {
defer cancel()
}
conn, err := NewConnection(ctx, r.URL, r.RequestHeader, r.EnableCompression)
conn, err := NewConnection(ctx, r.URL, r.RequestHeader)
if err != nil {
cancel()
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)

View File

@ -43,44 +43,3 @@ end:
t.Errorf("expected 2 events, got %d", events)
}
}
func TestSubscribeEnableCompression(t *testing.T) {
rl := &Relay{URL: NormalizeURL("wss://relay.damus.io"), EnableCompression: true}
err := rl.Connect(context.Background())
if err != nil {
t.Fatalf("connection failed: %v", err)
}
defer rl.Close()
sub, err := rl.Subscribe(context.Background(), Filters{{Kinds: []int{1}, Limit: 2}})
if err != nil {
t.Errorf("subscription failed: %v", err)
return
}
timeout := time.After(5 * time.Second)
events := 0
for {
select {
case event := <-sub.Events:
if event == nil {
t.Errorf("event is nil: %v", event)
}
events++
case <-sub.EndOfStoredEvents:
goto end
case <-rl.ConnectionContext.Done():
t.Errorf("connection closed: %v", rl.ConnectionContext.Err())
goto end
case <-timeout:
t.Errorf("timeout")
goto end
}
}
end:
if events != 2 {
t.Errorf("expected 2 events, got %d", events)
}
}