JWT -> Redis (#3574)

* functional v1

* functional logout

* minor clean up

* quick clean up

* update configuration

* ni

* nit

* finalize

* update login page

* delete unused import

* quick nit

* updates

* clean up

* ni

* k

* k
This commit is contained in:
pablonyx 2025-01-04 11:35:43 -08:00 committed by GitHub
parent 67d2c86250
commit ffec19645b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 163 additions and 100 deletions

View File

@ -2,15 +2,14 @@ import logging
from collections.abc import Awaitable
from collections.abc import Callable
import jwt
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@ -22,11 +21,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
tenant_id = (
_get_tenant_id_from_request(request, logger)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
if MULTI_TENANT:
tenant_id = await _get_tenant_id_from_request(request, logger)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return await call_next(request)
@ -35,27 +34,36 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
raise
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
# First check for API key
async def _get_tenant_id_from_request(
request: Request, logger: logging.LoggerAdapter
) -> str:
"""
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id is not None:
if tenant_id:
return tenant_id
# Check for cookie-based auth
token = request.cookies.get("fastapiusersauth")
if not token:
return POSTGRES_DEFAULT_SCHEMA
try:
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
# Since payload.get() can return None, ensure we have a string
if not token_data:
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
# Since token_data.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
@ -67,9 +75,6 @@ def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter)
return tenant_id
except jwt.InvalidTokenError:
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@ -19,7 +19,7 @@ from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_jwt_strategy
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.auth import get_user_count
@ -112,7 +112,7 @@ async def impersonate_user(
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_jwt_strategy().write_token(user_to_impersonate)
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(

View File

@ -1,3 +1,5 @@
import json
import secrets
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
@ -29,10 +31,8 @@ from fastapi_users import schemas
from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import JWTStrategy
from fastapi_users.authentication import RedisStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt
@ -59,6 +59,8 @@ from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
@ -73,7 +75,6 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
from onyx.db.auth import get_access_token_db
from onyx.db.auth import get_default_admin_user_emails
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
@ -81,10 +82,10 @@ from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
@ -581,49 +582,70 @@ cookie_transport = CookieTransport(
)
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
"""
A custom strategy that fetches the actual async Redis connection inside each method.
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
"""
def __init__(
self,
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
):
self.lifetime_seconds = lifetime_seconds
self.key_prefix = key_prefix
async def write_token(self, user: User) -> str:
redis = await get_async_redis_connection()
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=user.email,
)
)(email=user.email)
data = {
token_data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data
async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
token = secrets.token_urlsafe()
await redis.set(
f"{self.key_prefix}{token}",
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID]
) -> Optional[User]:
redis = await get_async_redis_connection()
token_data_str = await redis.get(f"{self.key_prefix}{token}")
if not token_data_str:
return None
def get_jwt_strategy() -> TenantAwareJWTStrategy:
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
)
try:
token_data = json.loads(token_data_str)
user_id = token_data["sub"]
parsed_id = user_manager.parse_id(user_id)
return await user_manager.get(parsed_id)
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
return None
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
async def destroy_token(self, token: str, user: User) -> None:
"""Properly delete the token from async redis."""
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
auth_backend = AuthenticationBackend(
name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
) # type: ignore
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):

View File

@ -54,6 +54,10 @@ MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 3600
)
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
@ -188,9 +192,11 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
@ -570,7 +576,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
#####
# API Key Configs
#####

View File

@ -1,4 +1,5 @@
import contextlib
import json
import os
import re
import ssl
@ -14,7 +15,6 @@ from typing import ContextManager
import asyncpg # type: ignore
import boto3
import jwt
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
@ -40,9 +40,9 @@ from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@ -322,31 +322,33 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
return _ASYNC_ENGINE
def get_current_tenant_id(request: Request) -> str:
async def get_current_tenant_id(request: Request) -> str:
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
token = request.cookies.get("fastapiusersauth")
if not token:
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
return current_value
try:
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if not token_data:
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
logger.debug(
f"Token data not found or expired in Redis, defaulting to {current_value}"
)
return current_value
tenant_id = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
return CURRENT_TENANT_ID_CONTEXTVAR.get()
except json.JSONDecodeError:
logger.error("Error decoding token data from Redis")
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@ -1,14 +1,17 @@
import asyncio
import functools
import json
import threading
from collections.abc import Callable
from typing import Any
from typing import Optional
import redis
from fastapi import Request
from redis import asyncio as aioredis
from redis.client import Redis
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import REDIS_DB_NUMBER
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from onyx.configs.app_configs import REDIS_HOST
@ -228,3 +231,31 @@ async def get_async_redis_connection() -> aioredis.Redis:
# Return the established connection (or pool) for all future operations
return _async_redis_connection
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
token = request.cookies.get("fastapiusersauth")
if not token:
logger.debug("No auth token cookie found")
return None
try:
redis = await get_async_redis_connection()
redis_key = REDIS_AUTH_KEY_PREFIX + token
token_data_str = await redis.get(redis_key)
if not token_data_str:
logger.debug(f"Token key {redis_key} not found or expired in Redis")
return None
return json.loads(token_data_str)
except json.JSONDecodeError:
logger.error("Error decoding token data from Redis")
return None
except Exception as e:
logger.error(
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
)
raise ValueError(
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
)

View File

@ -11,6 +11,7 @@ import React, { useContext, useState, useEffect } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { Modal } from "@/components/Modal";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
export function Checkbox({
label,
@ -218,14 +219,19 @@ export function SettingsForm() {
handleToggleSettingsField("auto_scroll", e.target.checked)
}
/>
<Checkbox
label="Anonymous Users"
sublabel="If set, users will not be required to sign in to use Danswer."
checked={settings.anonymous_user_enabled}
onChange={(e) =>
handleToggleSettingsField("anonymous_user_enabled", e.target.checked)
}
/>
{!NEXT_PUBLIC_CLOUD_ENABLED && (
<Checkbox
label="Anonymous Users"
sublabel="If set, users will not be required to sign in to use Onyx."
checked={settings.anonymous_user_enabled}
onChange={(e) =>
handleToggleSettingsField(
"anonymous_user_enabled",
e.target.checked
)
}
/>
)}
{showConfirmModal && (
<Modal
width="max-w-3xl w-full"

View File

@ -104,18 +104,10 @@ const Page = async (props: {
<span className="px-4 text-gray-500">or</span>
<div className="flex-grow border-t border-gray-300"></div>
</div>
<EmailPasswordForm shouldVerify={true} nextUrl={nextUrl} />
<div className="flex mt-4 justify-between">
<Link
href={`/auth/signup${
searchParams?.next ? `?next=${searchParams.next}` : ""
}`}
className="text-link font-medium"
>
Create an account
</Link>
{NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED && (
<Link
href="/auth/forgot-password"

View File

@ -16,7 +16,7 @@ export const getCurrentUser = async (): Promise<User | null> => {
};
export const logout = async (): Promise<Response> => {
const response = await fetch("/auth/logout", {
const response = await fetch("/api/auth/logout", {
method: "POST",
credentials: "include",
});