mirror of
https://github.com/nbd-wtf/go-nostr.git
synced 2025-06-29 10:09:35 +02:00
Add mutexes around websockets
We replace the bare websocket.Conn type with a new Connection type which implements `WriteJSON`, `WriteMessage`, and `Close`. The Connection type adds mutexes around writes since gorilla doesn't support concurrent writes to websockets. Signed-off-by: Honza Pokorny <honza@pokorny.ca>
This commit is contained in:
33
connection.go
Normal file
33
connection.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package nostr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Connection struct {
|
||||||
|
socket *websocket.Conn
|
||||||
|
mutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnection(socket *websocket.Conn) *Connection {
|
||||||
|
return &Connection{
|
||||||
|
socket: socket,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) WriteJSON(v interface{}) error {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
return c.socket.WriteJSON(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) WriteMessage(messageType int, data []byte) error {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
return c.socket.WriteMessage(messageType, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Close() error {
|
||||||
|
return c.socket.Close()
|
||||||
|
}
|
14
relaypool.go
14
relaypool.go
@ -28,7 +28,7 @@ type RelayPool struct {
|
|||||||
SecretKey *string
|
SecretKey *string
|
||||||
|
|
||||||
Relays map[string]RelayPoolPolicy
|
Relays map[string]RelayPoolPolicy
|
||||||
websockets map[string]*websocket.Conn
|
websockets map[string]*Connection
|
||||||
subscriptions map[string]*Subscription
|
subscriptions map[string]*Subscription
|
||||||
|
|
||||||
Notices chan *NoticeMessage
|
Notices chan *NoticeMessage
|
||||||
@ -61,7 +61,7 @@ type NoticeMessage struct {
|
|||||||
func NewRelayPool() *RelayPool {
|
func NewRelayPool() *RelayPool {
|
||||||
return &RelayPool{
|
return &RelayPool{
|
||||||
Relays: make(map[string]RelayPoolPolicy),
|
Relays: make(map[string]RelayPoolPolicy),
|
||||||
websockets: make(map[string]*websocket.Conn),
|
websockets: make(map[string]*Connection),
|
||||||
subscriptions: make(map[string]*Subscription),
|
subscriptions: make(map[string]*Subscription),
|
||||||
|
|
||||||
Notices: make(chan *NoticeMessage),
|
Notices: make(chan *NoticeMessage),
|
||||||
@ -80,11 +80,13 @@ func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error {
|
|||||||
return fmt.Errorf("invalid relay URL '%s'", url)
|
return fmt.Errorf("invalid relay URL '%s'", url)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, _, err := websocket.DefaultDialer.Dial(NormalizeURL(url), nil)
|
socket, _, err := websocket.DefaultDialer.Dial(NormalizeURL(url), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error opening websocket to '%s': %w", nm, err)
|
return fmt.Errorf("error opening websocket to '%s': %w", nm, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn := NewConnection(socket)
|
||||||
|
|
||||||
r.Relays[nm] = policy
|
r.Relays[nm] = policy
|
||||||
r.websockets[nm] = conn
|
r.websockets[nm] = conn
|
||||||
|
|
||||||
@ -94,7 +96,7 @@ func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
typ, message, err := conn.ReadMessage()
|
typ, message, err := conn.socket.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("read error: ", err)
|
log.Println("read error: ", err)
|
||||||
return
|
return
|
||||||
@ -183,7 +185,7 @@ func (r *RelayPool) Sub(filters EventFilters) *Subscription {
|
|||||||
|
|
||||||
subscription := Subscription{filters: filters}
|
subscription := Subscription{filters: filters}
|
||||||
subscription.channel = hex.EncodeToString(random)
|
subscription.channel = hex.EncodeToString(random)
|
||||||
subscription.relays = make(map[string]*websocket.Conn)
|
subscription.relays = make(map[string]*Connection)
|
||||||
for relay, policy := range r.Relays {
|
for relay, policy := range r.Relays {
|
||||||
if policy.ShouldRead(filters) {
|
if policy.ShouldRead(filters) {
|
||||||
ws := r.websockets[relay]
|
ws := r.websockets[relay]
|
||||||
@ -225,7 +227,7 @@ func (r *RelayPool) PublishEvent(evt *Event) (*Event, chan PublishStatus, error)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
go func(relay string, conn *websocket.Conn) {
|
go func(relay string, conn *Connection) {
|
||||||
err := conn.WriteJSON([]interface{}{"EVENT", evt})
|
err := conn.WriteJSON([]interface{}{"EVENT", evt})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error sending event to '%s': %s", relay, err.Error())
|
log.Printf("error sending event to '%s': %s", relay, err.Error())
|
||||||
|
@ -1,12 +1,8 @@
|
|||||||
package nostr
|
package nostr
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Subscription struct {
|
type Subscription struct {
|
||||||
channel string
|
channel string
|
||||||
relays map[string]*websocket.Conn
|
relays map[string]*Connection
|
||||||
|
|
||||||
filters EventFilters
|
filters EventFilters
|
||||||
Events chan EventMessage
|
Events chan EventMessage
|
||||||
@ -21,8 +17,8 @@ type EventMessage struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (subscription Subscription) Unsub() {
|
func (subscription Subscription) Unsub() {
|
||||||
for _, ws := range subscription.relays {
|
for _, conn := range subscription.relays {
|
||||||
ws.WriteJSON([]interface{}{
|
conn.WriteJSON([]interface{}{
|
||||||
"CLOSE",
|
"CLOSE",
|
||||||
subscription.channel,
|
subscription.channel,
|
||||||
})
|
})
|
||||||
@ -37,7 +33,7 @@ func (subscription Subscription) Unsub() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (subscription Subscription) Sub() {
|
func (subscription Subscription) Sub() {
|
||||||
for _, ws := range subscription.relays {
|
for _, conn := range subscription.relays {
|
||||||
message := []interface{}{
|
message := []interface{}{
|
||||||
"REQ",
|
"REQ",
|
||||||
subscription.channel,
|
subscription.channel,
|
||||||
@ -46,7 +42,7 @@ func (subscription Subscription) Sub() {
|
|||||||
message = append(message, filter)
|
message = append(message, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
ws.WriteJSON(message)
|
conn.WriteJSON(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !subscription.started {
|
if !subscription.started {
|
||||||
@ -66,17 +62,17 @@ func (subscription Subscription) startHandlingUnique() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (subscription Subscription) removeRelay(relay string) {
|
func (subscription Subscription) removeRelay(relay string) {
|
||||||
if ws, ok := subscription.relays[relay]; ok {
|
if conn, ok := subscription.relays[relay]; ok {
|
||||||
delete(subscription.relays, relay)
|
delete(subscription.relays, relay)
|
||||||
ws.WriteJSON([]interface{}{
|
conn.WriteJSON([]interface{}{
|
||||||
"CLOSE",
|
"CLOSE",
|
||||||
subscription.channel,
|
subscription.channel,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (subscription Subscription) addRelay(relay string, ws *websocket.Conn) {
|
func (subscription Subscription) addRelay(relay string, conn *Connection) {
|
||||||
subscription.relays[relay] = ws
|
subscription.relays[relay] = conn
|
||||||
|
|
||||||
message := []interface{}{
|
message := []interface{}{
|
||||||
"REQ",
|
"REQ",
|
||||||
@ -86,5 +82,5 @@ func (subscription Subscription) addRelay(relay string, ws *websocket.Conn) {
|
|||||||
message = append(message, filter)
|
message = append(message, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
ws.WriteJSON(message)
|
conn.WriteJSON(message)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user