diff --git a/connection.go b/connection.go new file mode 100644 index 0000000..50c81fd --- /dev/null +++ b/connection.go @@ -0,0 +1,33 @@ +package nostr + +import ( + "github.com/gorilla/websocket" + "sync" +) + +type Connection struct { + socket *websocket.Conn + mutex sync.Mutex +} + +func NewConnection(socket *websocket.Conn) *Connection { + return &Connection{ + socket: socket, + } +} + +func (c *Connection) WriteJSON(v interface{}) error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.socket.WriteJSON(v) +} + +func (c *Connection) WriteMessage(messageType int, data []byte) error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.socket.WriteMessage(messageType, data) +} + +func (c *Connection) Close() error { + return c.socket.Close() +} diff --git a/relaypool.go b/relaypool.go index deaa330..9a6d4fe 100644 --- a/relaypool.go +++ b/relaypool.go @@ -28,7 +28,7 @@ type RelayPool struct { SecretKey *string Relays map[string]RelayPoolPolicy - websockets map[string]*websocket.Conn + websockets map[string]*Connection subscriptions map[string]*Subscription Notices chan *NoticeMessage @@ -61,7 +61,7 @@ type NoticeMessage struct { func NewRelayPool() *RelayPool { return &RelayPool{ Relays: make(map[string]RelayPoolPolicy), - websockets: make(map[string]*websocket.Conn), + websockets: make(map[string]*Connection), subscriptions: make(map[string]*Subscription), Notices: make(chan *NoticeMessage), @@ -80,11 +80,13 @@ func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error { return fmt.Errorf("invalid relay URL '%s'", url) } - conn, _, err := websocket.DefaultDialer.Dial(NormalizeURL(url), nil) + socket, _, err := websocket.DefaultDialer.Dial(NormalizeURL(url), nil) if err != nil { return fmt.Errorf("error opening websocket to '%s': %w", nm, err) } + conn := NewConnection(socket) + r.Relays[nm] = policy r.websockets[nm] = conn @@ -94,7 +96,7 @@ func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error { go func() { for { - typ, message, err := conn.ReadMessage() + typ, message, err := conn.socket.ReadMessage() if err != nil { log.Println("read error: ", err) return @@ -183,7 +185,7 @@ func (r *RelayPool) Sub(filters EventFilters) *Subscription { subscription := Subscription{filters: filters} subscription.channel = hex.EncodeToString(random) - subscription.relays = make(map[string]*websocket.Conn) + subscription.relays = make(map[string]*Connection) for relay, policy := range r.Relays { if policy.ShouldRead(filters) { ws := r.websockets[relay] @@ -225,7 +227,7 @@ func (r *RelayPool) PublishEvent(evt *Event) (*Event, chan PublishStatus, error) continue } - go func(relay string, conn *websocket.Conn) { + go func(relay string, conn *Connection) { err := conn.WriteJSON([]interface{}{"EVENT", evt}) if err != nil { log.Printf("error sending event to '%s': %s", relay, err.Error()) diff --git a/subscription.go b/subscription.go index 22b8cc5..1dec46d 100644 --- a/subscription.go +++ b/subscription.go @@ -1,12 +1,8 @@ package nostr -import ( - "github.com/gorilla/websocket" -) - type Subscription struct { channel string - relays map[string]*websocket.Conn + relays map[string]*Connection filters EventFilters Events chan EventMessage @@ -21,8 +17,8 @@ type EventMessage struct { } func (subscription Subscription) Unsub() { - for _, ws := range subscription.relays { - ws.WriteJSON([]interface{}{ + for _, conn := range subscription.relays { + conn.WriteJSON([]interface{}{ "CLOSE", subscription.channel, }) @@ -37,7 +33,7 @@ func (subscription Subscription) Unsub() { } func (subscription Subscription) Sub() { - for _, ws := range subscription.relays { + for _, conn := range subscription.relays { message := []interface{}{ "REQ", subscription.channel, @@ -46,7 +42,7 @@ func (subscription Subscription) Sub() { message = append(message, filter) } - ws.WriteJSON(message) + conn.WriteJSON(message) } if !subscription.started { @@ -66,17 +62,17 @@ func (subscription Subscription) startHandlingUnique() { } func (subscription Subscription) removeRelay(relay string) { - if ws, ok := subscription.relays[relay]; ok { + if conn, ok := subscription.relays[relay]; ok { delete(subscription.relays, relay) - ws.WriteJSON([]interface{}{ + conn.WriteJSON([]interface{}{ "CLOSE", subscription.channel, }) } } -func (subscription Subscription) addRelay(relay string, ws *websocket.Conn) { - subscription.relays[relay] = ws +func (subscription Subscription) addRelay(relay string, conn *Connection) { + subscription.relays[relay] = conn message := []interface{}{ "REQ", @@ -86,5 +82,5 @@ func (subscription Subscription) addRelay(relay string, ws *websocket.Conn) { message = append(message, filter) } - ws.WriteJSON(message) + conn.WriteJSON(message) }