Add mutexes around websockets

We replace the bare websocket.Conn type with a new Connection type which
implements `WriteJSON`, `WriteMessage`, and `Close`.  The Connection
type adds mutexes around writes since gorilla doesn't support concurrent
writes to websockets.

Signed-off-by: Honza Pokorny <honza@pokorny.ca>
This commit is contained in:
Honza Pokorny 2022-01-12 10:54:45 -04:00 committed by fiatjaf
parent ba0507cce7
commit a3df2cb893
3 changed files with 51 additions and 20 deletions

33
connection.go Normal file
View File

@ -0,0 +1,33 @@
package nostr
import (
"github.com/gorilla/websocket"
"sync"
)
type Connection struct {
socket *websocket.Conn
mutex sync.Mutex
}
func NewConnection(socket *websocket.Conn) *Connection {
return &Connection{
socket: socket,
}
}
func (c *Connection) WriteJSON(v interface{}) error {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.socket.WriteJSON(v)
}
func (c *Connection) WriteMessage(messageType int, data []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.socket.WriteMessage(messageType, data)
}
func (c *Connection) Close() error {
return c.socket.Close()
}

View File

@ -28,7 +28,7 @@ type RelayPool struct {
SecretKey *string SecretKey *string
Relays map[string]RelayPoolPolicy Relays map[string]RelayPoolPolicy
websockets map[string]*websocket.Conn websockets map[string]*Connection
subscriptions map[string]*Subscription subscriptions map[string]*Subscription
Notices chan *NoticeMessage Notices chan *NoticeMessage
@ -61,7 +61,7 @@ type NoticeMessage struct {
func NewRelayPool() *RelayPool { func NewRelayPool() *RelayPool {
return &RelayPool{ return &RelayPool{
Relays: make(map[string]RelayPoolPolicy), Relays: make(map[string]RelayPoolPolicy),
websockets: make(map[string]*websocket.Conn), websockets: make(map[string]*Connection),
subscriptions: make(map[string]*Subscription), subscriptions: make(map[string]*Subscription),
Notices: make(chan *NoticeMessage), Notices: make(chan *NoticeMessage),
@ -80,11 +80,13 @@ func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error {
return fmt.Errorf("invalid relay URL '%s'", url) return fmt.Errorf("invalid relay URL '%s'", url)
} }
conn, _, err := websocket.DefaultDialer.Dial(NormalizeURL(url), nil) socket, _, err := websocket.DefaultDialer.Dial(NormalizeURL(url), nil)
if err != nil { if err != nil {
return fmt.Errorf("error opening websocket to '%s': %w", nm, err) return fmt.Errorf("error opening websocket to '%s': %w", nm, err)
} }
conn := NewConnection(socket)
r.Relays[nm] = policy r.Relays[nm] = policy
r.websockets[nm] = conn r.websockets[nm] = conn
@ -94,7 +96,7 @@ func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error {
go func() { go func() {
for { for {
typ, message, err := conn.ReadMessage() typ, message, err := conn.socket.ReadMessage()
if err != nil { if err != nil {
log.Println("read error: ", err) log.Println("read error: ", err)
return return
@ -183,7 +185,7 @@ func (r *RelayPool) Sub(filters EventFilters) *Subscription {
subscription := Subscription{filters: filters} subscription := Subscription{filters: filters}
subscription.channel = hex.EncodeToString(random) subscription.channel = hex.EncodeToString(random)
subscription.relays = make(map[string]*websocket.Conn) subscription.relays = make(map[string]*Connection)
for relay, policy := range r.Relays { for relay, policy := range r.Relays {
if policy.ShouldRead(filters) { if policy.ShouldRead(filters) {
ws := r.websockets[relay] ws := r.websockets[relay]
@ -225,7 +227,7 @@ func (r *RelayPool) PublishEvent(evt *Event) (*Event, chan PublishStatus, error)
continue continue
} }
go func(relay string, conn *websocket.Conn) { go func(relay string, conn *Connection) {
err := conn.WriteJSON([]interface{}{"EVENT", evt}) err := conn.WriteJSON([]interface{}{"EVENT", evt})
if err != nil { if err != nil {
log.Printf("error sending event to '%s': %s", relay, err.Error()) log.Printf("error sending event to '%s': %s", relay, err.Error())

View File

@ -1,12 +1,8 @@
package nostr package nostr
import (
"github.com/gorilla/websocket"
)
type Subscription struct { type Subscription struct {
channel string channel string
relays map[string]*websocket.Conn relays map[string]*Connection
filters EventFilters filters EventFilters
Events chan EventMessage Events chan EventMessage
@ -21,8 +17,8 @@ type EventMessage struct {
} }
func (subscription Subscription) Unsub() { func (subscription Subscription) Unsub() {
for _, ws := range subscription.relays { for _, conn := range subscription.relays {
ws.WriteJSON([]interface{}{ conn.WriteJSON([]interface{}{
"CLOSE", "CLOSE",
subscription.channel, subscription.channel,
}) })
@ -37,7 +33,7 @@ func (subscription Subscription) Unsub() {
} }
func (subscription Subscription) Sub() { func (subscription Subscription) Sub() {
for _, ws := range subscription.relays { for _, conn := range subscription.relays {
message := []interface{}{ message := []interface{}{
"REQ", "REQ",
subscription.channel, subscription.channel,
@ -46,7 +42,7 @@ func (subscription Subscription) Sub() {
message = append(message, filter) message = append(message, filter)
} }
ws.WriteJSON(message) conn.WriteJSON(message)
} }
if !subscription.started { if !subscription.started {
@ -66,17 +62,17 @@ func (subscription Subscription) startHandlingUnique() {
} }
func (subscription Subscription) removeRelay(relay string) { func (subscription Subscription) removeRelay(relay string) {
if ws, ok := subscription.relays[relay]; ok { if conn, ok := subscription.relays[relay]; ok {
delete(subscription.relays, relay) delete(subscription.relays, relay)
ws.WriteJSON([]interface{}{ conn.WriteJSON([]interface{}{
"CLOSE", "CLOSE",
subscription.channel, subscription.channel,
}) })
} }
} }
func (subscription Subscription) addRelay(relay string, ws *websocket.Conn) { func (subscription Subscription) addRelay(relay string, conn *Connection) {
subscription.relays[relay] = ws subscription.relays[relay] = conn
message := []interface{}{ message := []interface{}{
"REQ", "REQ",
@ -86,5 +82,5 @@ func (subscription Subscription) addRelay(relay string, ws *websocket.Conn) {
message = append(message, filter) message = append(message, filter)
} }
ws.WriteJSON(message) conn.WriteJSON(message)
} }