From 9069341657cd2fcec68f90e2a211092862609074 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Wed, 16 Oct 2024 15:06:47 -0300 Subject: [PATCH] a context that is canceled whenever a websocket is dropped. --- get-started.go | 1 + handlers.go | 15 +++++++++------ websocket.go | 5 +++++ 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/get-started.go b/get-started.go index fcc5503..f25b4e3 100644 --- a/get-started.go +++ b/get-started.go @@ -53,6 +53,7 @@ func (rl *Relay) Shutdown(ctx context.Context) { defer rl.clientsMutex.Unlock() for ws := range rl.clients { ws.conn.WriteControl(websocket.CloseMessage, nil, time.Now().Add(time.Second)) + ws.cancel() ws.conn.Close() } clear(rl.clients) diff --git a/handlers.go b/handlers.go index 90a5a6f..7a953bf 100644 --- a/handlers.go +++ b/handlers.go @@ -58,6 +58,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { Request: r, Challenge: hex.EncodeToString(challenge), } + ws.Context, ws.cancel = context.WithCancel(context.Background()) rl.clientsMutex.Lock() rl.clients[ws] = make([]listenerSpec, 0, 2) @@ -77,7 +78,8 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { ticker.Stop() cancel() - conn.Close() + ws.cancel() + ws.conn.Close() rl.removeClientAndListeners(ws) } @@ -85,10 +87,10 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { go func() { defer kill() - conn.SetReadLimit(rl.MaxMessageSize) - conn.SetReadDeadline(time.Now().Add(rl.PongWait)) - conn.SetPongHandler(func(string) error { - conn.SetReadDeadline(time.Now().Add(rl.PongWait)) + ws.conn.SetReadLimit(rl.MaxMessageSize) + ws.conn.SetReadDeadline(time.Now().Add(rl.PongWait)) + ws.conn.SetPongHandler(func(string) error { + ws.conn.SetReadDeadline(time.Now().Add(rl.PongWait)) return nil }) @@ -97,7 +99,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { } for { - typ, message, err := conn.ReadMessage() + typ, message, err := ws.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError( err, @@ -109,6 +111,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { ) { rl.Log.Printf("unexpected close error from %s: %v\n", r.Header.Get("X-Forwarded-For"), err) } + ws.cancel() return } diff --git a/websocket.go b/websocket.go index 21f0922..7962f17 100644 --- a/websocket.go +++ b/websocket.go @@ -1,6 +1,7 @@ package khatru import ( + "context" "net/http" "sync" @@ -15,6 +16,10 @@ type WebSocket struct { // original request Request *http.Request + // this Context will be canceled whenever the connection is closed from the client side or server-side. + Context context.Context + cancel context.CancelFunc + // nip42 Challenge string AuthedPublicKey string