Files
multica/server/internal/middleware/auth_test.go

414 lines
13 KiB
Go

package middleware
import (
"context"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/multica-ai/multica/server/internal/auth"
"github.com/redis/go-redis/v9"
)
// newRedisTestClient connects to REDIS_TEST_URL, flushes, and skips when
// unset — same gating pattern the rest of the suite uses for Redis-backed
// tests, so `go test ./...` works on a stock laptop without a Redis.
func newRedisTestClient(t *testing.T) *redis.Client {
t.Helper()
url := os.Getenv("REDIS_TEST_URL")
if url == "" {
t.Skip("REDIS_TEST_URL not set")
}
opts, err := redis.ParseURL(url)
if err != nil {
t.Fatalf("parse REDIS_TEST_URL: %v", err)
}
rdb := redis.NewClient(opts)
ctx := context.Background()
if err := rdb.Ping(ctx).Err(); err != nil {
t.Skipf("REDIS_TEST_URL unreachable: %v", err)
}
if err := rdb.FlushDB(ctx).Err(); err != nil {
t.Fatalf("flushdb: %v", err)
}
t.Cleanup(func() {
rdb.FlushDB(context.Background())
rdb.Close()
})
return rdb
}
func generateToken(claims jwt.MapClaims, secret []byte) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
s, _ := token.SignedString(secret)
return s
}
func validClaims() jwt.MapClaims {
return jwt.MapClaims{
"sub": "test-user-id",
"email": "test@multica.ai",
"exp": time.Now().Add(time.Hour).Unix(),
}
}
// authMiddleware returns the Auth middleware with nil queries (JWT-only tests).
func authMiddleware(next http.Handler) http.Handler {
return Auth(nil, nil, nil)(next)
}
func TestAuth_MissingHeader(t *testing.T) {
handler := authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next handler should not be called")
}))
req := httptest.NewRequest("GET", "/api/me", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
if body := w.Body.String(); body != `{"error":"missing authorization"}`+"\n" {
t.Fatalf("unexpected body: %s", body)
}
}
func TestAuth_NoBearerPrefix(t *testing.T) {
handler := authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next handler should not be called")
}))
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Token some-token")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
// Non-Bearer Authorization header with no cookie falls through to "missing authorization".
if body := w.Body.String(); body != `{"error":"missing authorization"}`+"\n" {
t.Fatalf("unexpected body: %s", body)
}
}
func TestAuth_InvalidToken(t *testing.T) {
handler := authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next handler should not be called")
}))
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
func TestAuth_ExpiredToken(t *testing.T) {
handler := authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next handler should not be called")
}))
claims := validClaims()
claims["exp"] = time.Now().Add(-time.Hour).Unix()
token := generateToken(claims, auth.JWTSecret())
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
func TestAuth_WrongSecret(t *testing.T) {
handler := authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next handler should not be called")
}))
token := generateToken(validClaims(), []byte("wrong-secret"))
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
func TestAuth_WrongSigningMethod(t *testing.T) {
handler := authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next handler should not be called")
}))
// Use "none" signing method
token := jwt.NewWithClaims(jwt.SigningMethodNone, validClaims())
s, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer "+s)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
func TestAuth_ValidToken(t *testing.T) {
var gotUserID, gotEmail string
handler := authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUserID = r.Header.Get("X-User-ID")
gotEmail = r.Header.Get("X-User-Email")
w.WriteHeader(http.StatusOK)
}))
token := generateToken(validClaims(), auth.JWTSecret())
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if gotUserID != "test-user-id" {
t.Fatalf("expected X-User-ID 'test-user-id', got '%s'", gotUserID)
}
if gotEmail != "test@multica.ai" {
t.Fatalf("expected X-User-Email 'test@multica.ai', got '%s'", gotEmail)
}
}
func TestAuth_MissingClaims(t *testing.T) {
handler := authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next handler should not be called")
}))
// Token with no sub or email claims, only exp
claims := jwt.MapClaims{
"exp": time.Now().Add(time.Hour).Unix(),
}
token := generateToken(claims, auth.JWTSecret())
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
func TestAuth_InvalidPAT(t *testing.T) {
handler := authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next handler should not be called")
}))
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer mul_invalid_token_here")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
// TestAuth_StripsClientSuppliedActorSource enforces the invariant that
// X-Actor-Source is a server-only header: any value the client sends
// must be discarded before downstream code sees it. Without this
// guarantee a client carrying a normal mul_ PAT could supply a forged
// `X-Actor-Source: task_token` (or any other value) to fool a handler
// into treating the request differently — exactly the kind of trust
// boundary MUL-2600 introduces.
func TestAuth_StripsClientSuppliedActorSource(t *testing.T) {
var gotActorSource string
mw := Auth(nil, nil, nil)
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotActorSource = r.Header.Get("X-Actor-Source")
w.WriteHeader(http.StatusOK)
}))
token := generateToken(validClaims(), auth.JWTSecret())
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer "+token)
// Client tries to forge the actor-source header. The middleware must
// discard it before the JWT branch runs (which doesn't set it again
// for human sessions).
req.Header.Set("X-Actor-Source", "task_token")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if gotActorSource != "" {
t.Fatalf("X-Actor-Source must be cleared on non-task-token paths, got %q", gotActorSource)
}
}
// TestAuth_PATCacheHit pins the optimization: when the PAT cache already
// holds an entry for this token, the middleware MUST NOT call into queries
// — it short-circuits before the DB lookup and the last_used_at update.
//
// We exploit that by passing nil queries: a cache miss would dereference
// the nil and panic; a cache hit must not. Reaching the next handler with
// the cached user_id therefore proves the short-circuit fired.
func TestAuth_PATCacheHit(t *testing.T) {
rdb := newRedisTestClient(t)
cache := auth.NewPATCache(rdb)
if cache == nil {
t.Fatal("expected non-nil cache")
}
const rawToken = "mul_cache_hit_test_token"
hash := auth.HashToken(rawToken)
cache.Set(context.Background(), hash, "cached-user-id", auth.AuthCacheTTL)
var gotUserID string
mw := Auth(nil, cache, nil) // nil queries — only safe on cache hit
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUserID = r.Header.Get("X-User-ID")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer "+rawToken)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 on cache hit, got %d", w.Code)
}
if gotUserID != "cached-user-id" {
t.Fatalf("expected cached X-User-ID, got %q", gotUserID)
}
}
// TestAuth_MCN_NoVerifierConfigured pins the same fail-closed branch
// as the daemon side: with no MULTICA_CLOUD_FLEET_URL configured, an
// mcn_ bearer token must be rejected with 401 at the prefix branch.
// We don't fall through — an mcn_ string can't be a valid mul_ PAT or
// JWT, so any fall-through would be wasted work.
func TestAuth_MCN_NoVerifierConfigured(t *testing.T) {
mw := Auth(nil, nil, nil)
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next must not be called when verifier is unconfigured")
}))
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer mcn_anything")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
// TestAuth_MCN_ValidTokenSetsUserID is the happy-path mirror of the
// daemon test: a successful Fleet verify must surface owner_id as
// X-User-ID for downstream user-scoped handlers.
func TestAuth_MCN_ValidTokenSetsUserID(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"valid": true,
"owner_id": "01972f7e-7e8d-77ef-a13d-1b0ce3e9c001"
}`))
}))
defer srv.Close()
verifier := auth.NewCloudPATVerifier(auth.CloudPATVerifierConfig{FleetBaseURL: srv.URL})
var gotUser, gotActorSource string
mw := Auth(nil, nil, verifier)
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUser = r.Header.Get("X-User-ID")
gotActorSource = r.Header.Get("X-Actor-Source")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer mcn_x")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if gotUser != "01972f7e-7e8d-77ef-a13d-1b0ce3e9c001" {
t.Errorf("expected owner_id propagated as X-User-ID, got %q", gotUser)
}
// Pinned per the cloud-billing review: a successful mcn_ verify
// MUST stamp X-Actor-Source so account-level guards (e.g.
// handler.RequireHumanActor on /api/cloud-billing/*) can tell a
// machine credential apart from a human PAT/JWT. Dropping this
// stamp would silently let an mcn_ holder reach billing.
if gotActorSource != "cloud_pat" {
t.Errorf("expected X-Actor-Source=cloud_pat, got %q", gotActorSource)
}
}
// TestAuth_MCN_InvalidReturns401 confirms that a Fleet valid:false maps
// to 401 — the token is genuinely bad, retrying won't help.
func TestAuth_MCN_InvalidReturns401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"valid":false,"reason":"token_revoked"}`))
}))
defer srv.Close()
verifier := auth.NewCloudPATVerifier(auth.CloudPATVerifierConfig{FleetBaseURL: srv.URL})
mw := Auth(nil, nil, verifier)
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next must not be called when token is invalid")
}))
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer mcn_revoked")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
// TestAuth_MCN_FleetUnreachableReturns503 — the Unavailable branch must
// emit 503 (transient), distinguishing it from a 401 (token is bad).
func TestAuth_MCN_FleetUnreachableReturns503(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer srv.Close()
verifier := auth.NewCloudPATVerifier(auth.CloudPATVerifierConfig{FleetBaseURL: srv.URL})
mw := Auth(nil, nil, verifier)
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next must not be called when fleet is unavailable")
}))
req := httptest.NewRequest("GET", "/api/me", nil)
req.Header.Set("Authorization", "Bearer mcn_x")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("expected 503, got %d", w.Code)
}
}