Compare commits

...

2 Commits

Author SHA1 Message Date
Naiyuan Qing
9c4857695e fix(realtime): use case-insensitive Host comparison for same-origin
HTTP host is case-insensitive (RFC 7230 §2.7.3), and gorilla/websocket's
default checkSameOrigin uses equalASCIIFold(u.Host, r.Host). The plain
== comparison would reject legitimate same-origin requests with a
case-mismatched Host header (e.g. Host: LOCALHOST:8080 vs
Origin: http://localhost:8080).

Switch to strings.EqualFold and cover the case with a regression test.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: multica-agent <github@multica.ai>
2026-05-11 13:37:52 +08:00
Naiyuan Qing
826a951572 fix(realtime): allow same-origin WebSocket clients (mobile/CLI)
The previous CheckOrigin implementation (PR #2318) bypassed the Origin
check whenever the request URL carried `client_platform=mobile` and no
browser session cookie. That contract requires every native client to
remember to add a query parameter — and in practice mobile clients hit
ws://localhost:8080/ws with no extra params, so the Origin filled by
the WebSocket library (the server's own host) gets rejected.

Replace the platform-specific bypass with same-origin acceptance: if
Origin's host equals the request Host, allow the upgrade. This is
gorilla/websocket's default CheckOrigin behavior, restored alongside
the existing cross-origin allowlist (for browser web/desktop clients).

Native clients are now zero-config. CSRF defense is unaffected:
SameSite=Strict cookies, the multica_csrf token, workspace membership
check, and the allowlist itself remain in place. Browser CSWSH attacks
fail both same-origin (browser forces Origin = page origin, not the
server's Host) and allowlist checks.

Refs: https://pkg.go.dev/github.com/gorilla/websocket
      https://cheatsheetseries.owasp.org/cheatsheets/WebSocket_Security_Cheat_Sheet.html

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: multica-agent <github@multica.ai>
2026-05-11 13:32:26 +08:00
2 changed files with 48 additions and 47 deletions

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"log/slog"
"net/http"
"net/url"
"os"
"strings"
"sync"
@@ -81,13 +82,15 @@ func checkOrigin(r *http.Request) bool {
if origin == "" {
return true
}
// Native mobile clients authenticate with an explicit first-frame token.
// Origin is a browser CSRF control, so only skip it for mobile requests
// that are not carrying the browser session cookie.
if r.URL.Query().Get("client_platform") == "mobile" {
if _, err := r.Cookie(auth.AuthCookieName); err == http.ErrNoCookie {
return true
}
// Same-origin: native clients (mobile, CLI) have no real page host, so
// their WebSocket library fills Origin with the connection target —
// which equals the server's own Host. They authenticate via bearer
// token, not auto-attached cookies, so CSRF (the attack the explicit
// allowlist below defends against) does not apply. This matches the
// gorilla/websocket default CheckOrigin behavior; the allowlist exists
// in addition to support cross-origin browser clients (web/desktop).
if u, err := url.Parse(origin); err == nil && strings.EqualFold(u.Host, r.Host) {
return true
}
origins := allowedWSOrigins.Load().([]string)
for _, allowed := range origins {

View File

@@ -81,46 +81,6 @@ func connectWS(t *testing.T, server *httptest.Server) *websocket.Conn {
return conn
}
func TestCheckOrigin_AllowsMobileClientWithoutCookie(t *testing.T) {
prevOrigins := allowedWSOrigins.Load().([]string)
SetAllowedOrigins([]string{"https://app.example.com"})
t.Cleanup(func() { SetAllowedOrigins(prevOrigins) })
req := httptest.NewRequest(http.MethodGet, "/ws?client_platform=mobile", nil)
req.Header.Set("Origin", "https://not-allowed.example.com")
if !checkOrigin(req) {
t.Fatal("expected mobile request without browser auth cookie to bypass Origin whitelist")
}
}
func TestCheckOrigin_RejectsDisallowedOriginWithoutMobileClient(t *testing.T) {
prevOrigins := allowedWSOrigins.Load().([]string)
SetAllowedOrigins([]string{"https://app.example.com"})
t.Cleanup(func() { SetAllowedOrigins(prevOrigins) })
req := httptest.NewRequest(http.MethodGet, "/ws", nil)
req.Header.Set("Origin", "https://not-allowed.example.com")
if checkOrigin(req) {
t.Fatal("expected disallowed Origin without mobile client platform to be rejected")
}
}
func TestCheckOrigin_RejectsMobileClientWithBrowserCookie(t *testing.T) {
prevOrigins := allowedWSOrigins.Load().([]string)
SetAllowedOrigins([]string{"https://app.example.com"})
t.Cleanup(func() { SetAllowedOrigins(prevOrigins) })
req := httptest.NewRequest(http.MethodGet, "/ws?client_platform=mobile", nil)
req.Header.Set("Origin", "https://not-allowed.example.com")
req.AddCookie(&http.Cookie{Name: auth.AuthCookieName, Value: "browser-session"})
if checkOrigin(req) {
t.Fatal("expected disallowed mobile Origin with browser auth cookie to be rejected")
}
}
// totalClients counts all currently registered clients.
func totalClients(hub *Hub) int {
hub.mu.RLock()
@@ -351,3 +311,41 @@ func (l *lockedWriter) Write(p []byte) (int, error) {
defer l.mu.Unlock()
return l.w.Write(p)
}
func TestCheckOrigin(t *testing.T) {
prev := allowedWSOrigins.Load().([]string)
SetAllowedOrigins([]string{
"http://localhost:3000",
"https://multica.ai",
})
t.Cleanup(func() { SetAllowedOrigins(prev) })
cases := []struct {
name string
host string
origin string
want bool
}{
{"empty origin allowed", "api.multica.ai", "", true},
{"same-origin allowed (native client default)", "localhost:8080", "http://localhost:8080", true},
{"same-origin allowed (https)", "api.multica.ai", "https://api.multica.ai", true},
{"same-origin allowed (case-insensitive host, RFC 7230)", "API.Multica.AI", "https://api.multica.ai", true},
{"whitelisted origin allowed (web cross-origin)", "localhost:8080", "http://localhost:3000", true},
{"whitelisted origin allowed (prod web)", "api.multica.ai", "https://multica.ai", true},
{"unknown origin rejected (CSWSH defense)", "api.multica.ai", "https://evil.com", false},
{"different port rejected", "localhost:8080", "http://localhost:9999", false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
r.Host = tc.host
if tc.origin != "" {
r.Header.Set("Origin", tc.origin)
}
if got := checkOrigin(r); got != tc.want {
t.Fatalf("checkOrigin(host=%q, origin=%q) = %v, want %v", tc.host, tc.origin, got, tc.want)
}
})
}
}