mirror of
https://github.com/fiatjaf/khatru.git
synced 2025-03-17 21:32:55 +01:00
guard all websocket writes with mutexes.
This commit is contained in:
parent
ed829ac5f8
commit
ea7d2eeb3e
19
handlers.go
19
handlers.go
@ -40,6 +40,8 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
|
||||
}
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
|
||||
ws := &WebSocket{conn: conn}
|
||||
|
||||
// reader
|
||||
go func() {
|
||||
defer func() {
|
||||
@ -65,7 +67,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
|
||||
}
|
||||
|
||||
if typ == websocket.PingMessage {
|
||||
conn.WriteMessage(websocket.PongMessage, nil)
|
||||
ws.WriteMessage(websocket.PongMessage, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -73,7 +75,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
|
||||
var notice string
|
||||
defer func() {
|
||||
if notice != "" {
|
||||
conn.WriteJSON([]interface{}{"NOTICE", notice})
|
||||
ws.WriteJSON([]interface{}{"NOTICE", notice})
|
||||
}
|
||||
}()
|
||||
|
||||
@ -123,6 +125,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
|
||||
}
|
||||
|
||||
notifyListeners(&evt)
|
||||
break
|
||||
case "REQ":
|
||||
var id string
|
||||
json.Unmarshal(request[1], &id)
|
||||
@ -144,13 +147,13 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
|
||||
events, err := relay.QueryEvents(&filters[i])
|
||||
if err == nil {
|
||||
for _, event := range events {
|
||||
conn.WriteJSON([]interface{}{"EVENT", id, event})
|
||||
ws.WriteJSON([]interface{}{"EVENT", id, event})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setListener(id, conn, filters)
|
||||
|
||||
setListener(id, ws, filters)
|
||||
break
|
||||
case "CLOSE":
|
||||
var id string
|
||||
json.Unmarshal(request[0], &id)
|
||||
@ -159,8 +162,8 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
removeListener(conn, id)
|
||||
|
||||
removeListener(ws, id)
|
||||
break
|
||||
default:
|
||||
notice = "unknown message type " + typ
|
||||
return
|
||||
@ -179,7 +182,7 @@ func handleWebsocket(relay Relay) func(http.ResponseWriter, *http.Request) {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := conn.WriteMessage(websocket.PingMessage, nil)
|
||||
err := ws.WriteMessage(websocket.PingMessage, nil)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("error writing ping, closing websocket")
|
||||
return
|
||||
|
21
listener.go
21
listener.go
@ -4,14 +4,13 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/fiatjaf/go-nostr"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
filters nostr.EventFilters
|
||||
}
|
||||
|
||||
var listeners = make(map[*websocket.Conn]map[string]*Listener)
|
||||
var listeners = make(map[*WebSocket]map[string]*Listener)
|
||||
var listenersMutex = sync.Mutex{}
|
||||
|
||||
func GetListeningFilters() nostr.EventFilters {
|
||||
@ -47,16 +46,16 @@ func GetListeningFilters() nostr.EventFilters {
|
||||
return respfilters
|
||||
}
|
||||
|
||||
func setListener(id string, conn *websocket.Conn, filters nostr.EventFilters) {
|
||||
func setListener(id string, ws *WebSocket, filters nostr.EventFilters) {
|
||||
listenersMutex.Lock()
|
||||
defer func() {
|
||||
listenersMutex.Unlock()
|
||||
}()
|
||||
|
||||
subs, ok := listeners[conn]
|
||||
subs, ok := listeners[ws]
|
||||
if !ok {
|
||||
subs = make(map[string]*Listener)
|
||||
listeners[conn] = subs
|
||||
listeners[ws] = subs
|
||||
}
|
||||
|
||||
subs[id] = &Listener{
|
||||
@ -64,17 +63,17 @@ func setListener(id string, conn *websocket.Conn, filters nostr.EventFilters) {
|
||||
}
|
||||
}
|
||||
|
||||
func removeListener(conn *websocket.Conn, id string) {
|
||||
func removeListener(ws *WebSocket, id string) {
|
||||
listenersMutex.Lock()
|
||||
defer func() {
|
||||
listenersMutex.Unlock()
|
||||
}()
|
||||
|
||||
subs, ok := listeners[conn]
|
||||
subs, ok := listeners[ws]
|
||||
if ok {
|
||||
delete(listeners[conn], id)
|
||||
delete(listeners[ws], id)
|
||||
if len(subs) == 0 {
|
||||
delete(listeners, conn)
|
||||
delete(listeners, ws)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -85,12 +84,12 @@ func notifyListeners(event *nostr.Event) {
|
||||
listenersMutex.Unlock()
|
||||
}()
|
||||
|
||||
for conn, subs := range listeners {
|
||||
for ws, subs := range listeners {
|
||||
for id, listener := range subs {
|
||||
if !listener.filters.Match(event) {
|
||||
continue
|
||||
}
|
||||
conn.WriteJSON([]interface{}{"EVENT", id, event})
|
||||
ws.WriteJSON([]interface{}{"EVENT", id, event})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
24
websocket.go
Normal file
24
websocket.go
Normal file
@ -0,0 +1,24 @@
|
||||
package relayer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type WebSocket struct {
|
||||
conn *websocket.Conn
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func (ws *WebSocket) WriteJSON(any interface{}) error {
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
return ws.conn.WriteJSON(any)
|
||||
}
|
||||
|
||||
func (ws *WebSocket) WriteMessage(t int, b []byte) error {
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
return ws.conn.WriteMessage(t, b)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user