diff --git a/pool.go b/pool.go index a5e710b..729fdbe 100644 --- a/pool.go +++ b/pool.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "time" "github.com/puzpuzpuz/xsync" ) @@ -40,7 +41,9 @@ func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) { } else { var err error // we use this ctx here so when the pool dies everything dies - relay, err = RelayConnect(pool.Context, nm) + ctx, cancel := context.WithTimeout(pool.Context, time.Second*15) + defer cancel() + relay, err = RelayConnect(ctx, nm) if err != nil { return nil, fmt.Errorf("failed to connect: %w", err) } diff --git a/relay.go b/relay.go index 11c3edd..5143f33 100644 --- a/relay.go +++ b/relay.go @@ -56,11 +56,13 @@ type Relay struct { // NewRelay returns a new relay. The relay connection will be closed when the context is canceled. func NewRelay(ctx context.Context, url string) *Relay { + ctx, cancel := context.WithCancel(ctx) return &Relay{ - URL: NormalizeURL(url), - connectionContext: ctx, - Subscriptions: xsync.NewMapOf[*Subscription](), - okCallbacks: xsync.NewMapOf[func(bool, string)](), + URL: NormalizeURL(url), + connectionContext: ctx, + connectionContextCancel: cancel, + Subscriptions: xsync.NewMapOf[*Subscription](), + okCallbacks: xsync.NewMapOf[func(bool, string)](), } } @@ -89,10 +91,8 @@ func (r *Relay) Context() context.Context { return r.connectionContext } // pass a custom context to the underlying relay connection, use NewRelay() and // then Relay.Connect(). func (r *Relay) Connect(ctx context.Context) error { - if r.connectionContext == nil { - connectionContext, cancel := context.WithCancel(context.Background()) - r.connectionContext = connectionContext - r.connectionContextCancel = cancel + if r.connectionContext == nil || r.Subscriptions == nil { + return fmt.Errorf("relay must be initialized with a call to NewRelay()") } if r.URL == "" { @@ -134,6 +134,7 @@ func (r *Relay) Connect(ctx context.Context) error { err := conn.Ping() if err != nil { InfoLogger.Printf("{%s} error writing ping: %v; closing websocket", r.URL, err) + r.Close() return } } @@ -407,12 +408,6 @@ func (r *Relay) PrepareSubscription(ctx context.Context) *Subscription { ctx, cancel := context.WithCancel(ctx) - go func() { - // ensure the subscription dies if the relay connection dies - <-r.connectionContext.Done() - cancel() - }() - return &Subscription{ Relay: r, Context: ctx, diff --git a/relay_test.go b/relay_test.go index 1436811..8a1ef62 100644 --- a/relay_test.go +++ b/relay_test.go @@ -162,7 +162,8 @@ func TestConnectWithOrigin(t *testing.T) { defer ws.Close() // relay client - r := &Relay{URL: NormalizeURL(ws.URL), RequestHeader: http.Header{"origin": {"https://example.com"}}} + r := NewRelay(context.Background(), NormalizeURL(ws.URL)) + r.RequestHeader = http.Header{"origin": {"https://example.com"}} ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() err := r.Connect(ctx)