mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 20:38:32 +02:00
JWT -> Redis (#3574)
* functional v1 * functional logout * minor clean up * quick clean up * update configuration * ni * nit * finalize * update login page * delete unused import * quick nit * updates * clean up * ni * k * k
This commit is contained in:
@@ -2,15 +2,14 @@ import logging
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
|
||||
import jwt
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
|
||||
from onyx.auth.api_key import extract_tenant_from_api_key_header
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.db.engine import is_valid_schema_name
|
||||
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -22,11 +21,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
try:
|
||||
tenant_id = (
|
||||
_get_tenant_id_from_request(request, logger)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
if MULTI_TENANT:
|
||||
tenant_id = await _get_tenant_id_from_request(request, logger)
|
||||
else:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
return await call_next(request)
|
||||
|
||||
@@ -35,27 +34,36 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
raise
|
||||
|
||||
|
||||
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
|
||||
# First check for API key
|
||||
async def _get_tenant_id_from_request(
|
||||
request: Request, logger: logging.LoggerAdapter
|
||||
) -> str:
|
||||
"""
|
||||
Attempt to extract tenant_id from:
|
||||
1) The API key header
|
||||
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||
"""
|
||||
# Check for API key
|
||||
tenant_id = extract_tenant_from_api_key_header(request)
|
||||
if tenant_id is not None:
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Check for cookie-based auth
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
USER_AUTH_SECRET,
|
||||
audience=["fastapi-users:auth"],
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
# Look up token data in Redis
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
# Since payload.get() can return None, ensure we have a string
|
||||
if not token_data:
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
)
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
# Since token_data.get() can return None, ensure we have a string
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
@@ -67,9 +75,6 @@ def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
|
||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
@@ -19,7 +19,7 @@ from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_jwt_strategy
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.auth import get_user_count
|
||||
@@ -112,7 +112,7 @@ async def impersonate_user(
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_jwt_strategy().write_token(user_to_impersonate)
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
|
Reference in New Issue
Block a user