Compare commits

...

1 Commits

Author SHA1 Message Date
Jiang Bohan
8f9cea9226 fix(security): add ws: scheme and dynamic origins to CSP connect-src
The CSP connect-src directive only allowed 'self' and wss:, which
blocks WebSocket connections over ws:// in non-HTTPS environments
(e.g. dev deployments). Also, cross-origin API/WS endpoints were not
covered when frontend and backend are on different origins.

Changes:
- Add ws: alongside wss: in connect-src
- Dynamically inject ALLOWED_ORIGINS into connect-src so cross-origin
  connections are permitted by the policy
- Export BuildCSP / InitCSP for testability and router integration

Closes MUL-667
2026-04-13 14:18:42 +08:00
3 changed files with 79 additions and 11 deletions

View File

@@ -81,8 +81,9 @@ func NewRouter(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus) chi.Route
r.Use(middleware.ContentSecurityPolicy)
origins := allowedOrigins()
// Share allowed origins with WebSocket origin checker.
// Share allowed origins with WebSocket origin checker and CSP.
realtime.SetAllowedOrigins(origins)
middleware.InitCSP(origins)
r.Use(cors.Handler(cors.Options{
AllowedOrigins: origins,

View File

@@ -1,20 +1,66 @@
package middleware
import "net/http"
import (
"net/http"
"strings"
"sync"
)
const cspHeader = "default-src 'self'; " +
"script-src 'self'; " +
"style-src 'self' 'unsafe-inline'; " +
"img-src 'self' https: data:; " +
"connect-src 'self' wss:; " +
"frame-ancestors 'none'; " +
"object-src 'none'; " +
"base-uri 'self'; " +
"form-action 'self'"
var (
cspOnce sync.Once
cspHeader string
)
// BuildCSP constructs the Content-Security-Policy header value.
// allowedOrigins are included in connect-src so that cross-origin API calls
// and WebSocket connections are permitted by the policy.
func BuildCSP(allowedOrigins []string) string {
// Deduplicate and collect origins for connect-src.
seen := make(map[string]bool)
var extra []string
for _, origin := range allowedOrigins {
origin = strings.TrimSpace(origin)
if origin == "" || seen[origin] {
continue
}
seen[origin] = true
extra = append(extra, origin)
}
connectSrc := "'self' ws: wss:"
if len(extra) > 0 {
connectSrc += " " + strings.Join(extra, " ")
}
return "default-src 'self'; " +
"script-src 'self'; " +
"style-src 'self' 'unsafe-inline'; " +
"img-src 'self' https: data:; " +
"connect-src " + connectSrc + "; " +
"frame-ancestors 'none'; " +
"object-src 'none'; " +
"base-uri 'self'; " +
"form-action 'self'"
}
// ContentSecurityPolicy returns middleware that sets the CSP header.
// Call InitCSP before mounting this middleware to include allowed origins.
func ContentSecurityPolicy(next http.Handler) http.Handler {
// Ensure a default header exists even if InitCSP was never called.
cspOnce.Do(func() {
if cspHeader == "" {
cspHeader = BuildCSP(nil)
}
})
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Security-Policy", cspHeader)
next.ServeHTTP(w, r)
})
}
// InitCSP pre-computes the CSP header with the given allowed origins.
// Must be called before the first request is served.
func InitCSP(allowedOrigins []string) {
cspHeader = BuildCSP(allowedOrigins)
}

View File

@@ -8,6 +8,9 @@ import (
)
func TestContentSecurityPolicy(t *testing.T) {
// Reset global state for test isolation.
InitCSP(nil)
handler := ContentSecurityPolicy(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
@@ -23,6 +26,7 @@ func TestContentSecurityPolicy(t *testing.T) {
required := []string{
"script-src 'self'",
"connect-src 'self' ws: wss:",
"object-src 'none'",
"frame-ancestors 'none'",
"base-uri 'self'",
@@ -34,3 +38,20 @@ func TestContentSecurityPolicy(t *testing.T) {
}
}
}
func TestBuildCSP_WithOrigins(t *testing.T) {
csp := BuildCSP([]string{"https://app.example.com", "https://dev.example.com"})
if !strings.Contains(csp, "connect-src 'self' ws: wss: https://app.example.com https://dev.example.com") {
t.Errorf("CSP connect-src should include allowed origins; got: %s", csp)
}
}
func TestBuildCSP_DeduplicatesOrigins(t *testing.T) {
csp := BuildCSP([]string{"https://app.example.com", "https://app.example.com", ""})
count := strings.Count(csp, "https://app.example.com")
if count != 1 {
t.Errorf("expected origin to appear once, appeared %d times; got: %s", count, csp)
}
}