From ea7d2eeb3e6969e0416c017f5749cf36dfb69817 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Tue, 11 Jan 2022 16:00:19 -0300 Subject: [PATCH] guard all websocket writes with mutexes. --- handlers.go | 19 +++++++++++-------- listener.go | 21 ++++++++++----------- websocket.go | 24 ++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 19 deletions(-) create mode 100644 websocket.go diff --git a/handlers.go b/handlers.go index b462589..be1f98b 100644 --- a/handlers.go +++ b/handlers.go @@ -40,6 +40,8 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) { } ticker := time.NewTicker(pingPeriod) + ws := &WebSocket{conn: conn} + // reader go func() { defer func() { @@ -65,7 +67,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) { } if typ == websocket.PingMessage { - conn.WriteMessage(websocket.PongMessage, nil) + ws.WriteMessage(websocket.PongMessage, nil) continue } @@ -73,7 +75,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) { var notice string defer func() { if notice != "" { - conn.WriteJSON([]interface{}{"NOTICE", notice}) + ws.WriteJSON([]interface{}{"NOTICE", notice}) } }() @@ -123,6 +125,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) { } notifyListeners(&evt) + break case "REQ": var id string json.Unmarshal(request[1], &id) @@ -144,13 +147,13 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) { events, err := relay.QueryEvents(&filters[i]) if err == nil { for _, event := range events { - conn.WriteJSON([]interface{}{"EVENT", id, event}) + ws.WriteJSON([]interface{}{"EVENT", id, event}) } } } - setListener(id, conn, filters) - + setListener(id, ws, filters) + break case "CLOSE": var id string json.Unmarshal(request[0], &id) @@ -159,8 +162,8 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) { return } - removeListener(conn, id) - + removeListener(ws, id) + break default: notice = "unknown message type " + typ return @@ -179,7 +182,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) { for { select { case <-ticker.C: - err := conn.WriteMessage(websocket.PingMessage, nil) + err := ws.WriteMessage(websocket.PingMessage, nil) if err != nil { log.Warn().Err(err).Msg("error writing ping, closing websocket") return diff --git a/listener.go b/listener.go index 1055025..3f6b645 100644 --- a/listener.go +++ b/listener.go @@ -4,14 +4,13 @@ import ( "sync" "github.com/fiatjaf/go-nostr" - "github.com/gorilla/websocket" ) type Listener struct { filters nostr.EventFilters } -var listeners = make(map[*websocket.Conn]map[string]*Listener) +var listeners = make(map[*WebSocket]map[string]*Listener) var listenersMutex = sync.Mutex{} func GetListeningFilters() nostr.EventFilters { @@ -47,16 +46,16 @@ func GetListeningFilters() nostr.EventFilters { return respfilters } -func setListener(id string, conn *websocket.Conn, filters nostr.EventFilters) { +func setListener(id string, ws *WebSocket, filters nostr.EventFilters) { listenersMutex.Lock() defer func() { listenersMutex.Unlock() }() - subs, ok := listeners[conn] + subs, ok := listeners[ws] if !ok { subs = make(map[string]*Listener) - listeners[conn] = subs + listeners[ws] = subs } subs[id] = &Listener{ @@ -64,17 +63,17 @@ func setListener(id string, conn *websocket.Conn, filters nostr.EventFilters) { } } -func removeListener(conn *websocket.Conn, id string) { +func removeListener(ws *WebSocket, id string) { listenersMutex.Lock() defer func() { listenersMutex.Unlock() }() - subs, ok := listeners[conn] + subs, ok := listeners[ws] if ok { - delete(listeners[conn], id) + delete(listeners[ws], id) if len(subs) == 0 { - delete(listeners, conn) + delete(listeners, ws) } } } @@ -85,12 +84,12 @@ func notifyListeners(event *nostr.Event) { listenersMutex.Unlock() }() - for conn, subs := range listeners { + for ws, subs := range listeners { for id, listener := range subs { if !listener.filters.Match(event) { continue } - conn.WriteJSON([]interface{}{"EVENT", id, event}) + ws.WriteJSON([]interface{}{"EVENT", id, event}) } } } diff --git a/websocket.go b/websocket.go new file mode 100644 index 0000000..c2f8165 --- /dev/null +++ b/websocket.go @@ -0,0 +1,24 @@ +package relayer + +import ( + "sync" + + "github.com/gorilla/websocket" +) + +type WebSocket struct { + conn *websocket.Conn + mutex sync.Mutex +} + +func (ws *WebSocket) WriteJSON(any interface{}) error { + ws.mutex.Lock() + defer ws.mutex.Unlock() + return ws.conn.WriteJSON(any) +} + +func (ws *WebSocket) WriteMessage(t int, b []byte) error { + ws.mutex.Lock() + defer ws.mutex.Unlock() + return ws.conn.WriteMessage(t, b) +}