Compare commits

...

3 Commits

Author SHA1 Message Date
yushen
26267d33ac fix(test): update WebSocket integration test for first-message auth
The integration test still passed the token as a URL query param,
causing a timeout since the server now expects first-message auth
for non-cookie clients.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-13 17:09:24 +08:00
yushen
66f78a900d fix(security): add auth_ack and fix test JSON construction
Server sends auth_ack after successful first-message auth so the client
knows auth completed before firing reconnect callbacks. Test now uses
json.Marshal instead of string concatenation for the auth message.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-13 17:06:37 +08:00
Jiang Bohan
9af66746a9 fix(security): use first-message auth for WebSocket instead of URL query param
Token was exposed in URL query parameters (HIGH-4 from security audit),
visible in server/proxy logs, browser history, and referrer headers.

Now non-cookie clients (desktop, CLI) send the token as the first
WebSocket message after the connection opens. Cookie-based auth (web)
continues to work unchanged. Server-side auth priority flipped to
cookie-first.

Closes MUL-580
2026-04-13 16:57:00 +08:00
4 changed files with 166 additions and 69 deletions

View File

@@ -29,31 +29,32 @@ export class WSClient {
connect() {
const url = new URL(this.baseUrl);
// In cookie mode, the browser sends the HttpOnly cookie automatically
// with the WebSocket upgrade request — no token in URL needed.
if (!this.cookieAuth && this.token)
url.searchParams.set("token", this.token);
// Token is never sent as a URL query parameter — it would be logged by
// proxies, CDNs, and browser history. In cookie mode the HttpOnly cookie
// is sent automatically with the upgrade request. In token mode the token
// is delivered as the first WebSocket message after the connection opens.
if (this.workspaceId)
url.searchParams.set("workspace_id", this.workspaceId);
this.ws = new WebSocket(url.toString());
this.ws.onopen = () => {
this.logger.info("connected");
if (this.hasConnectedBefore) {
for (const cb of this.onReconnectCallbacks) {
try {
cb();
} catch {
// ignore reconnect callback errors
}
}
if (!this.cookieAuth && this.token) {
this.ws!.send(
JSON.stringify({ type: "auth", payload: { token: this.token } }),
);
return;
}
this.hasConnectedBefore = true;
this.onAuthenticated();
};
this.ws.onmessage = (event) => {
const msg = JSON.parse(event.data as string) as WSMessage;
if ((msg as any).type === "auth_ack") {
this.onAuthenticated();
return;
}
this.logger.debug("received", msg.type);
const eventHandlers = this.handlers.get(msg.type);
if (eventHandlers) {
@@ -77,6 +78,20 @@ export class WSClient {
};
}
private onAuthenticated() {
this.logger.info("connected");
if (this.hasConnectedBefore) {
for (const cb of this.onReconnectCallbacks) {
try {
cb();
} catch {
// ignore reconnect callback errors
}
}
}
this.hasConnectedBefore = true;
}
disconnect() {
if (this.reconnectTimer) {
clearTimeout(this.reconnectTimer);

View File

@@ -747,14 +747,34 @@ func TestInvalidRequestBodies(t *testing.T) {
// ---- WebSocket integration through full router ----
func TestWebSocketIntegration(t *testing.T) {
// Connect WebSocket client
wsURL := "ws" + strings.TrimPrefix(testServer.URL, "http") + "/ws?token=" + testToken + "&workspace_id=" + testWorkspaceID
// Connect WebSocket client (no token in URL — first-message auth)
wsURL := "ws" + strings.TrimPrefix(testServer.URL, "http") + "/ws?workspace_id=" + testWorkspaceID
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("WebSocket connection failed: %v", err)
}
defer conn.Close()
// First-message auth
authMsg, _ := json.Marshal(map[string]any{
"type": "auth",
"payload": map[string]string{"token": testToken},
})
if err := conn.WriteMessage(websocket.TextMessage, authMsg); err != nil {
t.Fatalf("failed to send auth message: %v", err)
}
// Read auth_ack
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
_, ack, err := conn.ReadMessage()
if err != nil {
t.Fatalf("failed to read auth_ack: %v", err)
}
if !strings.Contains(string(ack), "auth_ack") {
t.Fatalf("expected auth_ack, got %s", ack)
}
conn.SetReadDeadline(time.Time{})
// Allow Hub goroutine to process the register and add client to room
time.Sleep(100 * time.Millisecond)

View File

@@ -2,12 +2,14 @@ package realtime
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/websocket"
@@ -272,7 +274,67 @@ func (h *Hub) Broadcast(message []byte) {
h.broadcast <- message
}
// HandleWebSocket upgrades an HTTP connection to WebSocket with JWT, PAT, or cookie auth.
// authenticateToken validates a JWT or PAT string and returns the user ID.
func authenticateToken(tokenStr string, pr PATResolver, ctx context.Context) (string, string) {
if strings.HasPrefix(tokenStr, "mul_") {
if pr == nil {
return "", `{"error":"invalid token"}`
}
uid, ok := pr.ResolveToken(ctx, tokenStr)
if !ok {
return "", `{"error":"invalid token"}`
}
return uid, ""
}
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return auth.JWTSecret(), nil
})
if err != nil || !token.Valid {
return "", `{"error":"invalid token"}`
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return "", `{"error":"invalid claims"}`
}
uid, ok := claims["sub"].(string)
if !ok || strings.TrimSpace(uid) == "" {
return "", `{"error":"invalid claims"}`
}
return uid, ""
}
// firstMessageAuth reads the first WebSocket message expecting an auth payload.
// Message format: {"type":"auth","payload":{"token":"..."}}
// Returns the token string or an error description.
func firstMessageAuth(conn *websocket.Conn) (string, string) {
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
defer conn.SetReadDeadline(time.Time{}) // clear deadline for subsequent reads
_, raw, err := conn.ReadMessage()
if err != nil {
return "", `{"error":"auth timeout or read error"}`
}
var msg struct {
Type string `json:"type"`
Payload struct {
Token string `json:"token"`
} `json:"payload"`
}
if err := json.Unmarshal(raw, &msg); err != nil || msg.Type != "auth" || msg.Payload.Token == "" {
return "", `{"error":"expected auth message as first frame"}`
}
return msg.Payload.Token, ""
}
// HandleWebSocket upgrades an HTTP connection to WebSocket with cookie or first-message auth.
func HandleWebSocket(hub *Hub, mc MembershipChecker, pr PATResolver, w http.ResponseWriter, r *http.Request) {
workspaceID := r.URL.Query().Get("workspace_id")
if workspaceID == "" {
@@ -280,71 +342,53 @@ func HandleWebSocket(hub *Hub, mc MembershipChecker, pr PATResolver, w http.Resp
return
}
// Resolve token: query param first, then cookie fallback.
tokenStr := r.URL.Query().Get("token")
if tokenStr == "" {
if cookie, err := r.Cookie(auth.AuthCookieName); err == nil && cookie.Value != "" {
tokenStr = cookie.Value
}
}
if tokenStr == "" {
http.Error(w, `{"error":"authentication required"}`, http.StatusUnauthorized)
return
}
// Try cookie auth first (web clients).
var userID string
if strings.HasPrefix(tokenStr, "mul_") {
// PAT authentication
if pr == nil {
http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized)
if cookie, err := r.Cookie(auth.AuthCookieName); err == nil && cookie.Value != "" {
uid, errMsg := authenticateToken(cookie.Value, pr, r.Context())
if errMsg != "" {
http.Error(w, errMsg, http.StatusUnauthorized)
return
}
uid, ok := pr.ResolveToken(r.Context(), tokenStr)
if !ok {
http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized)
return
}
userID = uid
} else {
// JWT authentication
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return auth.JWTSecret(), nil
})
if err != nil || !token.Valid {
http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
http.Error(w, `{"error":"invalid claims"}`, http.StatusUnauthorized)
return
}
uid, ok := claims["sub"].(string)
if !ok || strings.TrimSpace(uid) == "" {
http.Error(w, `{"error":"invalid claims"}`, http.StatusUnauthorized)
if !mc.IsMember(r.Context(), uid, workspaceID) {
http.Error(w, `{"error":"not a member of this workspace"}`, http.StatusForbidden)
return
}
userID = uid
}
// Verify user is a member of the workspace
if !mc.IsMember(r.Context(), userID, workspaceID) {
http.Error(w, `{"error":"not a member of this workspace"}`, http.StatusForbidden)
return
}
// Upgrade the connection. Clients without cookies (desktop) will authenticate
// via the first WebSocket message, so we must upgrade before we have a token.
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("websocket upgrade failed", "error", err)
return
}
// First-message auth for non-cookie clients (desktop, CLI).
if userID == "" {
tokenStr, errMsg := firstMessageAuth(conn)
if errMsg != "" {
conn.WriteMessage(websocket.TextMessage, []byte(errMsg))
conn.Close()
return
}
uid, errMsg := authenticateToken(tokenStr, pr, r.Context())
if errMsg != "" {
conn.WriteMessage(websocket.TextMessage, []byte(errMsg))
conn.Close()
return
}
if !mc.IsMember(r.Context(), uid, workspaceID) {
conn.WriteMessage(websocket.TextMessage, []byte(`{"error":"not a member of this workspace"}`))
conn.Close()
return
}
userID = uid
conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"auth_ack"}`))
}
client := &Client{
hub: hub,
conn: conn,

View File

@@ -2,6 +2,7 @@ package realtime
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
@@ -52,11 +53,28 @@ func newTestHub(t *testing.T) (*Hub, *httptest.Server) {
func connectWS(t *testing.T, server *httptest.Server) *websocket.Conn {
t.Helper()
token := makeTestToken(t)
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?token=" + token + "&workspace_id=" + testWorkspaceID
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?workspace_id=" + testWorkspaceID
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("failed to connect WebSocket: %v", err)
}
authMsg, _ := json.Marshal(map[string]any{
"type": "auth",
"payload": map[string]string{"token": token},
})
if err := conn.WriteMessage(websocket.TextMessage, authMsg); err != nil {
t.Fatalf("failed to send auth message: %v", err)
}
// Read auth_ack before returning the connection.
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, ack, err := conn.ReadMessage()
if err != nil {
t.Fatalf("failed to read auth_ack: %v", err)
}
if !strings.Contains(string(ack), "auth_ack") {
t.Fatalf("expected auth_ack, got %s", ack)
}
conn.SetReadDeadline(time.Time{})
return conn
}