Files
multica/server/internal/realtime/hub_test.go

417 lines
12 KiB
Go

package realtime
import (
"bytes"
"context"
"encoding/json"
"errors"
"log/slog"
"net/http"
"net/http/httptest"
"net/netip"
"strings"
"sync"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/websocket"
"github.com/multica-ai/multica/server/internal/auth"
)
// Fixed identifiers shared by the tests in this file: the workspace ID sent
// in the workspace_id query parameter and the subject placed in the test JWT.
const (
	testWorkspaceID = "test-workspace"
	testUserID      = "test-user"
)
// mockMembershipChecker is a test stub whose IsMember answers true for every
// query, regardless of the arguments.
type mockMembershipChecker struct{}

// IsMember ignores the context and both string arguments and reports true.
func (*mockMembershipChecker) IsMember(context.Context, string, string) bool {
	return true
}
// makeTestToken returns an HS256-signed JWT whose "sub" claim is testUserID,
// signed with the key returned by auth.JWTSecret. A signing failure aborts
// the calling test.
func makeTestToken(t *testing.T) string {
	t.Helper()

	claims := jwt.MapClaims{
		"sub": testUserID,
	}
	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(auth.JWTSecret())
	if err != nil {
		t.Fatalf("failed to sign test JWT: %v", err)
	}
	return signed
}
// newTestHub creates a Hub, starts its Run method in a background goroutine,
// and returns it together with an httptest server whose /ws route is served by
// HandleWebSocket using a membership checker that accepts everyone. The two
// remaining HandleWebSocket dependencies are passed as nil. Callers are
// responsible for closing the returned server.
func newTestHub(t *testing.T) (*Hub, *httptest.Server) {
	t.Helper()

	h := NewHub()
	go h.Run()

	checker := &mockMembershipChecker{}
	serveWS := func(w http.ResponseWriter, r *http.Request) {
		HandleWebSocket(h, checker, nil, nil, w, r)
	}

	routes := http.NewServeMux()
	routes.HandleFunc("/ws", serveWS)
	return h, httptest.NewServer(routes)
}
// connectWS dials the test server's /ws endpoint for testWorkspaceID, sends an
// "auth" frame carrying a freshly signed test token, and waits up to two
// seconds for a reply that contains "auth_ack". On success it returns the
// connection with the read deadline cleared. Any failure aborts the calling
// test; failures after the dial close the connection first so an aborted
// handshake does not leave a socket open.
func connectWS(t *testing.T, server *httptest.Server) *websocket.Conn {
	t.Helper()

	token := makeTestToken(t)
	// Build the auth frame before dialing so a marshal failure cannot leak a
	// connection.
	authMsg, err := json.Marshal(map[string]any{
		"type":    "auth",
		"payload": map[string]string{"token": token},
	})
	if err != nil {
		t.Fatalf("failed to marshal auth message: %v", err)
	}

	// server.URL starts with "http"; swapping that prefix for "ws" yields the
	// matching WebSocket URL.
	wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?workspace_id=" + testWorkspaceID
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("failed to connect WebSocket: %v", err)
	}

	// fail releases the connection and then aborts the test.
	fail := func(format string, args ...any) {
		t.Helper()
		conn.Close()
		t.Fatalf(format, args...)
	}

	if err := conn.WriteMessage(websocket.TextMessage, authMsg); err != nil {
		fail("failed to send auth message: %v", err)
	}

	// Read auth_ack before returning the connection.
	if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		fail("failed to set read deadline: %v", err)
	}
	_, ack, err := conn.ReadMessage()
	if err != nil {
		fail("failed to read auth_ack: %v", err)
	}
	if !strings.Contains(string(ack), "auth_ack") {
		fail("expected auth_ack, got %s", ack)
	}

	// A zero time removes the deadline so later reads are governed only by
	// whatever deadline the caller sets.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		fail("failed to clear read deadline: %v", err)
	}
	return conn
}
// totalClients reports how many entries hub.clients currently holds. The
// length is read while holding the hub's read lock.
func totalClients(hub *Hub) int {
	hub.mu.RLock()
	n := len(hub.clients)
	hub.mu.RUnlock()
	return n
}
// TestHub_ClientRegistration verifies that a single authenticated WebSocket
// connection shows up as exactly one registered client in the hub.
func TestHub_ClientRegistration(t *testing.T) {
	hub, server := newTestHub(t)
	defer server.Close()

	conn := connectWS(t, server)
	defer conn.Close()

	// The hub's client table may be updated slightly after connectWS returns
	// (presumably by the hub's Run goroutine), so poll with a bounded deadline
	// instead of relying on a fixed sleep that can be too short on a loaded
	// machine.
	deadline := time.Now().Add(2 * time.Second)
	count := totalClients(hub)
	for count != 1 && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
		count = totalClients(hub)
	}
	if count != 1 {
		t.Fatalf("expected 1 client, got %d", count)
	}
}
// TestHub_Broadcast checks that one message handed to hub.Broadcast arrives
// unchanged on each of two connected clients.
func TestHub_Broadcast(t *testing.T) {
	hub, server := newTestHub(t)
	defer server.Close()

	first := connectWS(t, server)
	defer first.Close()
	second := connectWS(t, server)
	defer second.Close()

	// Give the hub a moment to finish registering both clients.
	time.Sleep(50 * time.Millisecond)

	payload := []byte(`{"type":"issue:created","data":"test"}`)
	hub.Broadcast(payload)

	// nextFrame reads a single message, allowing up to two seconds for delivery.
	nextFrame := func(c *websocket.Conn) ([]byte, error) {
		c.SetReadDeadline(time.Now().Add(2 * time.Second))
		_, data, err := c.ReadMessage()
		return data, err
	}

	gotFirst, err := nextFrame(first)
	if err != nil {
		t.Fatalf("client 1 read error: %v", err)
	}
	if !bytes.Equal(gotFirst, payload) {
		t.Fatalf("client 1: expected %s, got %s", payload, gotFirst)
	}

	gotSecond, err := nextFrame(second)
	if err != nil {
		t.Fatalf("client 2 read error: %v", err)
	}
	if !bytes.Equal(gotSecond, payload) {
		t.Fatalf("client 2: expected %s, got %s", payload, gotSecond)
	}
}
// TestHub_ClientDisconnect verifies that closing a client's WebSocket
// connection removes that client from the hub's client table.
func TestHub_ClientDisconnect(t *testing.T) {
	hub, server := newTestHub(t)
	defer server.Close()

	conn := connectWS(t, server)

	// waitForClients polls the hub until it reports want clients or two
	// seconds elapse, and returns the last observed count. The client table is
	// updated asynchronously relative to this goroutine, so a bounded poll is
	// used instead of a fixed sleep that may be too short on a slow machine.
	waitForClients := func(want int) int {
		deadline := time.Now().Add(2 * time.Second)
		got := totalClients(hub)
		for got != want && time.Now().Before(deadline) {
			time.Sleep(10 * time.Millisecond)
			got = totalClients(hub)
		}
		return got
	}

	if countBefore := waitForClients(1); countBefore != 1 {
		// Release the socket before aborting; the normal path closes it below.
		conn.Close()
		t.Fatalf("expected 1 client before disconnect, got %d", countBefore)
	}

	conn.Close()

	if countAfter := waitForClients(0); countAfter != 0 {
		t.Fatalf("expected 0 clients after disconnect, got %d", countAfter)
	}
}
// TestHub_BroadcastToMultipleClients connects five clients, confirms the hub
// reports all of them, and checks that a single broadcast is delivered
// unchanged to every connection.
func TestHub_BroadcastToMultipleClients(t *testing.T) {
	hub, server := newTestHub(t)
	defer server.Close()

	const numClients = 5
	conns := make([]*websocket.Conn, 0, numClients)
	for i := 0; i < numClients; i++ {
		conn := connectWS(t, server)
		// Deferred inside the loop on purpose: every connection has to stay
		// open until the test function returns.
		defer conn.Close()
		conns = append(conns, conn)
	}

	// The hub's client table may trail the auth handshake, so poll with a
	// bounded deadline rather than sleeping for a fixed interval that can be
	// too short on a loaded machine.
	deadline := time.Now().Add(2 * time.Second)
	count := totalClients(hub)
	for count != numClients && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
		count = totalClients(hub)
	}
	if count != numClients {
		t.Fatalf("expected %d clients, got %d", numClients, count)
	}

	msg := []byte(`{"type":"test","count":5}`)
	hub.Broadcast(msg)

	for i, conn := range conns {
		// Bound each read so a missing delivery fails the test instead of
		// hanging it.
		conn.SetReadDeadline(time.Now().Add(2 * time.Second))
		_, received, err := conn.ReadMessage()
		if err != nil {
			t.Fatalf("client %d read error: %v", i, err)
		}
		if string(received) != string(msg) {
			t.Fatalf("client %d: expected %s, got %s", i, msg, received)
		}
	}
}
// TestHub_MultipleBroadcasts sends three broadcasts back to back and checks
// that a single client receives all of them, in the order they were sent.
func TestHub_MultipleBroadcasts(t *testing.T) {
	hub, server := newTestHub(t)
	defer server.Close()

	conn := connectWS(t, server)
	defer conn.Close()

	// Give the hub a moment to finish registering the client.
	time.Sleep(50 * time.Millisecond)

	want := []string{
		`{"type":"issue:created"}`,
		`{"type":"issue:updated"}`,
		`{"type":"issue:deleted"}`,
	}
	for i := range want {
		hub.Broadcast([]byte(want[i]))
	}

	for idx := 0; idx < len(want); idx++ {
		// Each read gets its own two-second budget.
		conn.SetReadDeadline(time.Now().Add(2 * time.Second))
		_, frame, err := conn.ReadMessage()
		if err != nil {
			t.Fatalf("message %d read error: %v", idx, err)
		}
		if got := string(frame); got != want[idx] {
			t.Fatalf("message %d: expected %s, got %s", idx, want[idx], got)
		}
	}
}
// TestHandleWebSocket_ClientIdentityFromQuery verifies that client_platform,
// client_version, and client_os query params on the WS upgrade URL are read
// by the handler and surfaced to the access log. Browsers cannot set custom
// headers on WS upgrades, so this query-param channel is the only way to
// preserve the same observability dimensions HTTP clients get via X-Client-*.
//
// The test swaps the process-wide slog default for a JSON handler writing to
// an in-memory buffer and restores the previous default on cleanup.
func TestHandleWebSocket_ClientIdentityFromQuery(t *testing.T) {
	var buf bytes.Buffer
	var mu sync.Mutex
	handler := slog.NewJSONHandler(&lockedWriter{w: &buf, mu: &mu}, &slog.HandlerOptions{Level: slog.LevelDebug})
	prevDefault := slog.Default()
	slog.SetDefault(slog.New(handler))
	t.Cleanup(func() { slog.SetDefault(prevDefault) })

	// snapshot copies the captured log output while holding mu. Server
	// goroutines may still be writing to buf through lockedWriter, so every
	// read of buf — including the one in the failure message — must take the
	// same lock to stay race-free.
	snapshot := func() string {
		mu.Lock()
		defer mu.Unlock()
		return buf.String()
	}

	_, server := newTestHub(t)
	defer server.Close()

	token := makeTestToken(t)
	wsURL := "ws" + strings.TrimPrefix(server.URL, "http") +
		"/ws?workspace_id=" + testWorkspaceID +
		"&client_platform=desktop&client_version=1.2.3&client_os=macos"
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	authMsg, err := json.Marshal(map[string]any{
		"type":    "auth",
		"payload": map[string]string{"token": token},
	})
	if err != nil {
		t.Fatalf("marshal auth: %v", err)
	}
	if err := conn.WriteMessage(websocket.TextMessage, authMsg); err != nil {
		t.Fatalf("write auth: %v", err)
	}
	conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	if _, _, err := conn.ReadMessage(); err != nil {
		t.Fatalf("read auth_ack: %v", err)
	}

	// Wait briefly for the "websocket connected" log line to be flushed.
	deadline := time.Now().Add(2 * time.Second)
	var found map[string]any
	for time.Now().Before(deadline) {
		for _, line := range strings.Split(snapshot(), "\n") {
			if line == "" {
				continue
			}
			var entry map[string]any
			if err := json.Unmarshal([]byte(line), &entry); err != nil {
				// Skip anything that is not a complete JSON log record.
				continue
			}
			if msg, _ := entry["msg"].(string); msg == "websocket connected" {
				found = entry
				break
			}
		}
		if found != nil {
			break
		}
		time.Sleep(20 * time.Millisecond)
	}
	if found == nil {
		t.Fatalf("did not observe \"websocket connected\" log entry; buffered logs:\n%s", snapshot())
	}

	if got, _ := found["client_platform"].(string); got != "desktop" {
		t.Errorf("client_platform = %q, want %q", got, "desktop")
	}
	if got, _ := found["client_version"].(string); got != "1.2.3" {
		t.Errorf("client_version = %q, want %q", got, "1.2.3")
	}
	if got, _ := found["client_os"].(string); got != "macos" {
		t.Errorf("client_os = %q, want %q", got, "macos")
	}
}
// lockedWriter guards a bytes.Buffer with a mutex so log output written from
// several goroutines can be captured safely. Both fields are pointers, which
// lets a test keep its own handle on the buffer and take the same lock when
// reading it.
type lockedWriter struct {
	w  *bytes.Buffer
	mu *sync.Mutex
}

// Write appends data to the wrapped buffer while holding the mutex and
// returns whatever the buffer's Write returns.
func (lw *lockedWriter) Write(data []byte) (int, error) {
	lw.mu.Lock()
	defer lw.mu.Unlock()

	n, err := lw.w.Write(data)
	return n, err
}
// failingWSWriter is a test stub whose WriteMessage never writes anything and
// simply returns the configured error.
type failingWSWriter struct {
	err error
}

// WriteMessage ignores the message type and payload and returns the stored
// error (nil if none was set).
func (w failingWSWriter) WriteMessage(_ int, _ []byte) error {
	return w.err
}
func TestWriteWSAuthFrameLogsWriteErrors(t *testing.T) {
var buf bytes.Buffer
prevDefault := slog.Default()
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn})))
t.Cleanup(func() { slog.SetDefault(prevDefault) })
ok := writeWSAuthFrame(
failingWSWriter{err: errors.New("write blocked")},
[]byte(`{"error":"invalid token"}`),
"auth_error",
"workspace_id", testWorkspaceID,
)
if ok {
t.Fatal("expected writeWSAuthFrame to report failed write")
}
logs := buf.String()
if !strings.Contains(logs, "ws: failed to send auth frame") {
t.Fatalf("expected auth frame write failure log, got:\n%s", logs)
}
if !strings.Contains(logs, "write blocked") {
t.Fatalf("expected write error in log, got:\n%s", logs)
}
if !strings.Contains(logs, "auth_error") {
t.Fatalf("expected frame kind in log, got:\n%s", logs)
}
if !strings.Contains(logs, testWorkspaceID) {
t.Fatalf("expected workspace id in log, got:\n%s", logs)
}
}
// TestCheckOrigin runs checkOrigin against a table of Host / Origin /
// X-Forwarded-Host / remote-address combinations, with a fixed origin
// allow-list and a fixed set of trusted proxy prefixes installed for the
// duration of the test. The previous allow-list and proxy set are restored on
// cleanup.
func TestCheckOrigin(t *testing.T) {
	// The unchecked type assertions assume both package-level values have
	// already been stored; they panic otherwise.
	savedOrigins := allowedWSOrigins.Load().([]string)
	SetAllowedOrigins([]string{
		"http://localhost:3000",
		"https://multica.ai",
	})
	t.Cleanup(func() { SetAllowedOrigins(savedOrigins) })

	savedProxies := trustedProxies.Load().([]netip.Prefix)
	SetTrustedProxies([]netip.Prefix{
		netip.MustParsePrefix("127.0.0.1/32"),
		netip.MustParsePrefix("10.0.0.0/8"),
		netip.MustParsePrefix("::1/128"),
	})
	t.Cleanup(func() { SetTrustedProxies(savedProxies) })

	// originCase describes one request and the verdict checkOrigin must give.
	// Empty origin / fwdHost mean the corresponding header is not sent.
	type originCase struct {
		name       string
		host       string
		origin     string
		fwdHost    string
		remoteAddr string
		want       bool
	}

	cases := []originCase{
		{"empty origin allowed", "api.multica.ai", "", "", "1.2.3.4:5678", true},
		{"same-origin allowed (native client default)", "localhost:8080", "http://localhost:8080", "", "1.2.3.4:5678", true},
		{"same-origin allowed (https)", "api.multica.ai", "https://api.multica.ai", "", "1.2.3.4:5678", true},
		{"same-origin allowed (case-insensitive host, RFC 7230)", "API.Multica.AI", "https://api.multica.ai", "", "1.2.3.4:5678", true},
		{"whitelisted origin allowed (web cross-origin)", "localhost:8080", "http://localhost:3000", "", "1.2.3.4:5678", true},
		{"whitelisted origin allowed (prod web)", "api.multica.ai", "https://multica.ai", "", "1.2.3.4:5678", true},
		{"unknown origin rejected (CSWSH defense)", "api.multica.ai", "https://evil.com", "", "1.2.3.4:5678", false},
		{"different port rejected", "localhost:8080", "http://localhost:9999", "", "1.2.3.4:5678", false},
		{"X-Forwarded-Host from trusted proxy matches origin", "internal.proxy", "https://multica.ai", "multica.ai", "127.0.0.1:5678", true},
		{"X-Forwarded-Host from trusted proxy case-insensitive", "internal.proxy", "https://Multica.AI", "multica.ai", "10.0.0.1:5678", true},
		{"X-Forwarded-Host from untrusted source rejected", "internal.proxy", "https://example.com", "example.com", "1.2.3.4:5678", false},
		{"X-Forwarded-Host from trusted proxy but evil origin rejected", "internal.proxy", "https://evil.com", "multica.ai", "127.0.0.1:5678", false},
		{"X-Forwarded-Host present but origin matches direct Host", "multica.ai", "https://multica.ai", "other.host", "1.2.3.4:5678", true},
		{"X-Forwarded-Host spoofed by attacker rejected", "internal.proxy", "https://evil.com", "evil.com", "1.2.3.4:5678", false},
		{"X-Forwarded-Host from trusted CIDR range matches origin", "internal.proxy", "https://multica.ai", "multica.ai", "10.5.6.7:5678", true},
		{"X-Forwarded-Host from trusted IPv6 proxy matches origin", "internal.proxy", "https://multica.ai", "multica.ai", "[::1]:5678", true},
		{"X-Forwarded-Host comma list uses first (client-facing) value", "internal.proxy", "https://multica.ai", "multica.ai, proxy.internal", "127.0.0.1:5678", true},
		{"X-Forwarded-Host comma list ignores trailing values", "internal.proxy", "https://app.multica.ai", "proxy.internal, app.multica.ai", "127.0.0.1:5678", false},
	}

	// buildRequest turns a table row into the GET /ws request handed to
	// checkOrigin, setting only the headers the row specifies.
	buildRequest := func(c originCase) *http.Request {
		req := httptest.NewRequest(http.MethodGet, "/ws", nil)
		req.Host = c.host
		req.RemoteAddr = c.remoteAddr
		if c.origin != "" {
			req.Header.Set("Origin", c.origin)
		}
		if c.fwdHost != "" {
			req.Header.Set("X-Forwarded-Host", c.fwdHost)
		}
		return req
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := checkOrigin(buildRequest(tc))
			if got == tc.want {
				return
			}
			t.Fatalf("checkOrigin(host=%q, origin=%q, X-Forwarded-Host=%q, remoteAddr=%q) = %v, want %v", tc.host, tc.origin, tc.fwdHost, tc.remoteAddr, got, tc.want)
		})
	}
}