diff --git a/handlers.go b/handlers.go index 0f9f69b..9b440ee 100644 --- a/handlers.go +++ b/handlers.go @@ -19,6 +19,10 @@ import ( // ServeHTTP implements http.Handler interface. func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if rl.ServiceURL == "" { + rl.ServiceURL = getServiceBaseURL(r) + } + if r.Header.Get("Upgrade") == "websocket" { rl.HandleWebsocket(w, r) } else if r.Header.Get("Accept") == "application/nostr+json" { @@ -29,7 +33,7 @@ func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() + connectionContext := r.Context() conn, err := rl.upgrader.Upgrade(w, r, nil) if err != nil { @@ -50,7 +54,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { Authed: make(chan struct{}), } - ctx = context.WithValue(ctx, WS_KEY, ws) + connectionContext = context.WithValue(connectionContext, WS_KEY, ws) // reader go func() { @@ -71,7 +75,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { }) for _, onconnect := range rl.OnConnect { - onconnect(ctx) + onconnect(connectionContext) } for { @@ -95,7 +99,13 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { } go func(message []byte) { - ctx := context.WithValue(context.Background(), WS_KEY, ws) + ctx := context.WithValue( + context.WithValue( + context.Background(), + AUTH_CONTEXT_KEY, connectionContext.Value(AUTH_CONTEXT_KEY), + ), + WS_KEY, ws, + ) envelope := nostr.ParseMessage(message) if envelope == nil { @@ -174,15 +184,14 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { case *nostr.CloseEnvelope: removeListenerId(ws, string(*env)) case *nostr.AuthEnvelope: - if rl.ServiceURL != "" { - if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, rl.ServiceURL); ok { - ws.AuthedPublicKey = pubkey - close(ws.Authed) - ctx = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey) - ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true}) - } else { - ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"}) - } + wsBaseUrl := strings.Replace(rl.ServiceURL, "http", "ws", 1) + if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, wsBaseUrl); ok { + ws.AuthedPublicKey = pubkey + close(ws.Authed) + connectionContext = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey) + ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true}) + } else { + ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"}) } } }(message) diff --git a/helpers.go b/helpers.go index dd2d858..163ae17 100644 --- a/helpers.go +++ b/helpers.go @@ -2,7 +2,9 @@ package khatru import ( "hash/maphash" + "net/http" "regexp" + "strconv" "strings" "unsafe" @@ -29,3 +31,25 @@ func isAuthRequired(msg string) bool { idx := strings.IndexByte(msg, ':') return msg[0:idx] == "auth-required" } + +func getServiceBaseURL(r *http.Request) string { + host := r.Header.Get("X-Forwarded-Host") + if host == "" { + host = r.Host + } + proto := r.Header.Get("X-Forwarded-Proto") + if proto == "" { + if host == "localhost" { + proto = "http" + } else if strings.Index(host, ":") != -1 { + // has a port number + proto = "http" + } else if _, err := strconv.Atoi(strings.ReplaceAll(host, ".", "")); err == nil { + // it's a naked IP + proto = "http" + } else { + proto = "https" + } + } + return proto + "://" + host +} diff --git a/relay.go b/relay.go index bfe840c..6f928b0 100644 --- a/relay.go +++ b/relay.go @@ -40,7 +40,7 @@ func NewRelay() *Relay { } type Relay struct { - ServiceURL string // required for nip-42 + ServiceURL string RejectEvent []func(ctx context.Context, event *nostr.Event) (reject bool, msg string) RejectFilter []func(ctx context.Context, filter nostr.Filter) (reject bool, msg string)