guard all websocket writes with mutexes.

This commit is contained in:
fiatjaf 2022-01-11 16:00:19 -03:00
parent ed829ac5f8
commit ea7d2eeb3e
3 changed files with 45 additions and 19 deletions

View File

@ -40,6 +40,8 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
}
ticker := time.NewTicker(pingPeriod)
ws := &WebSocket{conn: conn}
// reader
go func() {
defer func() {
@ -65,7 +67,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
}
if typ == websocket.PingMessage {
conn.WriteMessage(websocket.PongMessage, nil)
ws.WriteMessage(websocket.PongMessage, nil)
continue
}
@ -73,7 +75,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
var notice string
defer func() {
if notice != "" {
conn.WriteJSON([]interface{}{"NOTICE", notice})
ws.WriteJSON([]interface{}{"NOTICE", notice})
}
}()
@ -123,6 +125,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
}
notifyListeners(&evt)
break
case "REQ":
var id string
json.Unmarshal(request[1], &id)
@ -144,13 +147,13 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
events, err := relay.QueryEvents(&filters[i])
if err == nil {
for _, event := range events {
conn.WriteJSON([]interface{}{"EVENT", id, event})
ws.WriteJSON([]interface{}{"EVENT", id, event})
}
}
}
setListener(id, conn, filters)
setListener(id, ws, filters)
break
case "CLOSE":
var id string
json.Unmarshal(request[0], &id)
@ -159,8 +162,8 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
return
}
removeListener(conn, id)
removeListener(ws, id)
break
default:
notice = "unknown message type " + typ
return
@ -179,7 +182,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
for {
select {
case <-ticker.C:
err := conn.WriteMessage(websocket.PingMessage, nil)
err := ws.WriteMessage(websocket.PingMessage, nil)
if err != nil {
log.Warn().Err(err).Msg("error writing ping, closing websocket")
return

View File

@ -4,14 +4,13 @@ import (
"sync"
"github.com/fiatjaf/go-nostr"
"github.com/gorilla/websocket"
)
type Listener struct {
filters nostr.EventFilters
}
var listeners = make(map[*websocket.Conn]map[string]*Listener)
var listeners = make(map[*WebSocket]map[string]*Listener)
var listenersMutex = sync.Mutex{}
func GetListeningFilters() nostr.EventFilters {
@ -47,16 +46,16 @@ func GetListeningFilters() nostr.EventFilters {
return respfilters
}
func setListener(id string, conn *websocket.Conn, filters nostr.EventFilters) {
func setListener(id string, ws *WebSocket, filters nostr.EventFilters) {
listenersMutex.Lock()
defer func() {
listenersMutex.Unlock()
}()
subs, ok := listeners[conn]
subs, ok := listeners[ws]
if !ok {
subs = make(map[string]*Listener)
listeners[conn] = subs
listeners[ws] = subs
}
subs[id] = &Listener{
@ -64,17 +63,17 @@ func setListener(id string, conn *websocket.Conn, filters nostr.EventFilters) {
}
}
func removeListener(conn *websocket.Conn, id string) {
func removeListener(ws *WebSocket, id string) {
listenersMutex.Lock()
defer func() {
listenersMutex.Unlock()
}()
subs, ok := listeners[conn]
subs, ok := listeners[ws]
if ok {
delete(listeners[conn], id)
delete(listeners[ws], id)
if len(subs) == 0 {
delete(listeners, conn)
delete(listeners, ws)
}
}
}
@ -85,12 +84,12 @@ func notifyListeners(event *nostr.Event) {
listenersMutex.Unlock()
}()
for conn, subs := range listeners {
for ws, subs := range listeners {
for id, listener := range subs {
if !listener.filters.Match(event) {
continue
}
conn.WriteJSON([]interface{}{"EVENT", id, event})
ws.WriteJSON([]interface{}{"EVENT", id, event})
}
}
}

24
websocket.go Normal file
View File

@ -0,0 +1,24 @@
package relayer
import (
"sync"
"github.com/gorilla/websocket"
)
type WebSocket struct {
conn *websocket.Conn
mutex sync.Mutex
}
func (ws *WebSocket) WriteJSON(any interface{}) error {
ws.mutex.Lock()
defer ws.mutex.Unlock()
return ws.conn.WriteJSON(any)
}
func (ws *WebSocket) WriteMessage(t int, b []byte) error {
ws.mutex.Lock()
defer ws.mutex.Unlock()
return ws.conn.WriteMessage(t, b)
}