diff --git a/go.mod b/go.mod index edfe471..609128d 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/lib/pq v1.10.3 github.com/mattn/go-sqlite3 v1.14.6 github.com/nbd-wtf/go-nostr v0.20.0 + github.com/puzpuzpuz/xsync/v2 v2.5.1 github.com/rs/cors v1.7.0 github.com/stretchr/testify v1.8.2 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 diff --git a/go.sum b/go.sum index 69ab76d..b3d369d 100644 --- a/go.sum +++ b/go.sum @@ -96,6 +96,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/puzpuzpuz/xsync v1.5.2 h1:yRAP4wqSOZG+/4pxJ08fPTwrfL0IzE/LKQ/cw509qGY= github.com/puzpuzpuz/xsync v1.5.2/go.mod h1:K98BYhX3k1dQ2M63t1YNVDanbwUPmBCAhNmVrrxfiGg= +github.com/puzpuzpuz/xsync/v2 v2.5.1 h1:mVGYAvzDSu52+zaGyNjC+24Xw2bQi3kTr4QJ6N9pIIU= +github.com/puzpuzpuz/xsync/v2 v2.5.1/go.mod h1:gD2H2krq/w52MfPLE+Uy64TzJDVY7lP2znR9qmR35kU= github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= diff --git a/handlers.go b/handlers.go index 004ba47..7db1c6b 100644 --- a/handlers.go +++ b/handlers.go @@ -35,9 +35,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { rl.Log.Printf("failed to upgrade websocket: %v\n", err) return } - rl.clientsMu.Lock() - defer rl.clientsMu.Unlock() - rl.clients[conn] = struct{}{} + rl.clients.Store(conn, struct{}{}) ticker := time.NewTicker(rl.PingPeriod) // NIP-42 challenge @@ -54,13 +52,11 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { go func() { defer func() { ticker.Stop() - rl.clientsMu.Lock() - if _, ok := rl.clients[conn]; ok { + if _, ok := rl.clients.Load(conn); ok { conn.Close() - delete(rl.clients, conn) + rl.clients.Delete(conn) removeListener(ws) } - rl.clientsMu.Unlock() }() conn.SetReadLimit(rl.MaxMessageSize) diff --git a/listener.go b/listener.go index ca4bd0c..95acc51 100644 --- a/listener.go +++ b/listener.go @@ -1,29 +1,22 @@ package khatru import ( - "sync" - "github.com/nbd-wtf/go-nostr" + "github.com/puzpuzpuz/xsync/v2" ) type Listener struct { filters nostr.Filters } -var ( - listeners = make(map[*WebSocket]map[string]*Listener) - listenersMutex = sync.Mutex{} -) +var listeners = xsync.NewTypedMapOf[*WebSocket, map[string]*Listener](pointerHasher[WebSocket]) func GetListeningFilters() nostr.Filters { - respfilters := make(nostr.Filters, 0, len(listeners)*2) - - listenersMutex.Lock() - defer listenersMutex.Unlock() + respfilters := make(nostr.Filters, 0, listeners.Size()*2) // here we go through all the existing listeners - for _, connlisteners := range listeners { - for _, listener := range connlisteners { + listeners.Range(func(_ *WebSocket, subs map[string]*Listener) bool { + for _, listener := range subs { for _, listenerfilter := range listener.filters { for _, respfilter := range respfilters { // check if this filter specifically is already added to respfilters @@ -40,55 +33,42 @@ func GetListeningFilters() nostr.Filters { continue } } - } + + return true + }) // respfilters will be a slice with all the distinct filter we currently have active return respfilters } func setListener(id string, ws *WebSocket, filters nostr.Filters) { - listenersMutex.Lock() - defer listenersMutex.Unlock() - - subs, ok := listeners[ws] - if !ok { - subs = make(map[string]*Listener) - listeners[ws] = subs - } - + subs, _ := listeners.LoadOrCompute(ws, func() map[string]*Listener { return make(map[string]*Listener) }) subs[id] = &Listener{filters: filters} } // Remove a specific subscription id from listeners for a given ws client func removeListenerId(ws *WebSocket, id string) { - listenersMutex.Lock() - defer listenersMutex.Unlock() - - if subs, ok := listeners[ws]; ok { - delete(listeners[ws], id) + if subs, ok := listeners.Load(ws); ok { + delete(subs, id) if len(subs) == 0 { - delete(listeners, ws) + listeners.Delete(ws) } } } // Remove WebSocket conn from listeners func removeListener(ws *WebSocket) { - listenersMutex.Lock() - defer listenersMutex.Unlock() - delete(listeners, ws) + listeners.Delete(ws) } func notifyListeners(event *nostr.Event) { - listenersMutex.Lock() - defer listenersMutex.Unlock() - - for ws, subs := range listeners { + listeners.Range(func(ws *WebSocket, subs map[string]*Listener) bool { for id, listener := range subs { if !listener.filters.Match(event) { continue } ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &id, Event: *event}) } - } + return true + }) } diff --git a/relay.go b/relay.go index 7326f00..dbf39c0 100644 --- a/relay.go +++ b/relay.go @@ -5,12 +5,12 @@ import ( "log" "net/http" "os" - "sync" "time" "github.com/fasthttp/websocket" "github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr/nip11" + "github.com/puzpuzpuz/xsync/v2" ) func NewRelay() *Relay { @@ -23,7 +23,7 @@ func NewRelay() *Relay { CheckOrigin: func(r *http.Request) bool { return true }, }, - clients: make(map[*websocket.Conn]struct{}), + clients: xsync.NewTypedMapOf[*websocket.Conn, struct{}](pointerHasher[websocket.Conn]), serveMux: &http.ServeMux{}, WriteWait: 10 * time.Second, @@ -61,8 +61,7 @@ type Relay struct { upgrader websocket.Upgrader // keep a connection reference to all connected clients for Server.Shutdown - clientsMu sync.Mutex - clients map[*websocket.Conn]struct{} + clients *xsync.MapOf[*websocket.Conn, struct{}] // in case you call Server.Start Addr string diff --git a/start.go b/start.go index 923aa1a..29b243f 100644 --- a/start.go +++ b/start.go @@ -50,11 +50,10 @@ func (rl *Relay) Start(host string, port int, started ...chan bool) error { func (rl *Relay) Shutdown(ctx context.Context) { rl.httpServer.Shutdown(ctx) - rl.clientsMu.Lock() - defer rl.clientsMu.Unlock() - for conn := range rl.clients { + rl.clients.Range(func(conn *websocket.Conn, _ struct{}) bool { conn.WriteControl(websocket.CloseMessage, nil, time.Now().Add(time.Second)) conn.Close() - delete(rl.clients, conn) - } + rl.clients.Delete(conn) + return true + }) } diff --git a/utils.go b/utils.go index ee616e0..3e693a6 100644 --- a/utils.go +++ b/utils.go @@ -2,7 +2,9 @@ package khatru import ( "context" + "hash/maphash" "regexp" + "unsafe" ) const ( @@ -23,3 +25,5 @@ func GetAuthed(ctx context.Context) string { } return authedPubkey.(string) } + +func pointerHasher[V any](_ maphash.Seed, k *V) uint64 { return uint64(uintptr(unsafe.Pointer(k))) }