diff --git a/go.mod b/go.mod index 10f5790..16c69c5 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,10 @@ go 1.21.0 require ( github.com/fasthttp/websocket v1.5.3 github.com/fiatjaf/eventstore v0.1.0 - github.com/nbd-wtf/go-nostr v0.26.0 + github.com/nbd-wtf/go-nostr v0.27.0 github.com/puzpuzpuz/xsync/v2 v2.5.1 github.com/rs/cors v1.7.0 + github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 ) @@ -44,7 +45,6 @@ require ( github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect - github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect diff --git a/go.sum b/go.sum index c5181a2..90ef5b2 100644 --- a/go.sum +++ b/go.sum @@ -90,8 +90,8 @@ github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJ github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= -github.com/nbd-wtf/go-nostr v0.26.0 h1:Tofbs9i8DD5iEKIhLlWFO7kfWpvmUG16fEyW30MzHVQ= -github.com/nbd-wtf/go-nostr v0.26.0/go.mod h1:bkffJI+x914sPQWum9ZRUn66D7NpDnAoWo1yICvj3/0= +github.com/nbd-wtf/go-nostr v0.27.0 h1:h6JmMMmfNcAORTL2kk/K3+U6Mju6rk/IjcHA/PMeOc8= +github.com/nbd-wtf/go-nostr v0.27.0/go.mod h1:bkffJI+x914sPQWum9ZRUn66D7NpDnAoWo1yICvj3/0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/handlers.go b/handlers.go index 11036a2..5e3bbf9 100644 --- a/handlers.go +++ b/handlers.go @@ -145,9 +145,10 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { if err == nil { ok = true } else { - reason = nostr.NormalizeOKMessage(err.Error(), "blocked") - if isAuthRequired(reason) { + if strings.HasPrefix(reason, "auth-required:") { ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge}) + } else { + reason = nostr.NormalizeOKMessage(err.Error(), "blocked") } } ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: ok, Reason: reason}) @@ -173,9 +174,11 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { err := rl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter) if err != nil { // fail everything if any filter is rejected - reason := nostr.NormalizeOKMessage(err.Error(), "blocked") - if isAuthRequired(reason) { + reason := err.Error() + if strings.HasPrefix(reason, "auth-required:") { ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge}) + } else { + reason = nostr.NormalizeOKMessage(reason, "blocked") } ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: reason}) cancelReqCtx(fmt.Errorf("filter rejected")) diff --git a/helpers.go b/helpers.go index 163ae17..5e08da8 100644 --- a/helpers.go +++ b/helpers.go @@ -27,11 +27,6 @@ func isOlder(previous, next *nostr.Event) bool { (previous.CreatedAt == next.CreatedAt && previous.ID > next.ID) } -func isAuthRequired(msg string) bool { - idx := strings.IndexByte(msg, ':') - return msg[0:idx] == "auth-required" -} - func getServiceBaseURL(r *http.Request) string { host := r.Header.Get("X-Forwarded-Host") if host == "" { diff --git a/utils.go b/utils.go index 453511b..b2dbe84 100644 --- a/utils.go +++ b/utils.go @@ -20,10 +20,10 @@ func GetIP(ctx context.Context) string { } func GetOpenSubscriptions(ctx context.Context) []nostr.Filter { - if listeners, ok := listeners.Load(GetConnection(ctx)); ok { + if subs, ok := listeners.Load(GetConnection(ctx)); ok { res := make([]nostr.Filter, 0, listeners.Size()*2) - listeners.Range(func(_ string, listener *Listener) bool { - res = append(res, listener.filters...) + subs.Range(func(_ string, sub *Listener) bool { + res = append(res, sub.filters...) return true }) return res