diff --git a/backend/ee/danswer/auth/api_key.py b/backend/ee/danswer/auth/api_key.py index 74e391ad1da..9ea827d27dc 100644 --- a/backend/ee/danswer/auth/api_key.py +++ b/backend/ee/danswer/auth/api_key.py @@ -1,5 +1,7 @@ import secrets import uuid +from urllib.parse import quote +from urllib.parse import unquote from fastapi import Request from passlib.hash import sha256_crypt @@ -30,8 +32,35 @@ class ApiKeyDescriptor(BaseModel): user_id: uuid.UUID -def generate_api_key() -> str: - return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN) +def generate_api_key(tenant_id: str | None = None) -> str: + # For backwards compatibility, if no tenant_id, generate old style key + if not tenant_id: + return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN) + + encoded_tenant = quote(tenant_id) # URL encode the tenant ID + return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}" + + +def extract_tenant_from_api_key_header(request: Request) -> str | None: + """Extract tenant ID from request. Returns None if auth is disabled or invalid format.""" + raw_api_key_header = request.headers.get( + _API_KEY_HEADER_ALTERNATIVE_NAME + ) or request.headers.get(_API_KEY_HEADER_NAME) + + if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX): + return None + + api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip() + + if not api_key.startswith(_API_KEY_PREFIX): + return None + + parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1) + if len(parts) != 2: + return None + + tenant_id = parts[0] + return unquote(tenant_id) if tenant_id else None def hash_api_key(api_key: str) -> str: diff --git a/backend/ee/danswer/db/api_key.py b/backend/ee/danswer/db/api_key.py index 8bbdf7eaa9a..4cc0774f129 100644 --- a/backend/ee/danswer/db/api_key.py +++ b/backend/ee/danswer/db/api_key.py @@ -15,6 +15,8 @@ from ee.danswer.auth.api_key import build_displayable_api_key from ee.danswer.auth.api_key import generate_api_key from ee.danswer.auth.api_key import hash_api_key from ee.danswer.server.api_key.models import APIKeyArgs +from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR +from shared_configs.configs import MULTI_TENANT def get_api_key_email_pattern() -> str: @@ -64,7 +66,11 @@ def insert_api_key( db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None ) -> ApiKeyDescriptor: std_password_helper = PasswordHelper() - api_key = generate_api_key() + + # Get tenant_id from context var (will be default schema for single tenant) + tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() + + api_key = generate_api_key(tenant_id if MULTI_TENANT else None) api_key_user_id = uuid.uuid4() display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER diff --git a/backend/ee/danswer/server/middleware/tenant_tracking.py b/backend/ee/danswer/server/middleware/tenant_tracking.py index 8c076e57d5c..f9fe75425e0 100644 --- a/backend/ee/danswer/server/middleware/tenant_tracking.py +++ b/backend/ee/danswer/server/middleware/tenant_tracking.py @@ -10,6 +10,7 @@ from fastapi import Response from danswer.configs.app_configs import USER_AUTH_SECRET from danswer.db.engine import is_valid_schema_name +from ee.danswer.auth.api_key import extract_tenant_from_api_key_header from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA @@ -21,40 +22,54 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: try: - if not MULTI_TENANT: - tenant_id = POSTGRES_DEFAULT_SCHEMA - else: - token = request.cookies.get("fastapiusersauth") + tenant_id = POSTGRES_DEFAULT_SCHEMA - if token: - try: - payload = jwt.decode( - token, - USER_AUTH_SECRET, - audience=["fastapi-users:auth"], - algorithms=["HS256"], - ) - tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA) - if not is_valid_schema_name(tenant_id): - raise HTTPException( - status_code=400, detail="Invalid tenant ID format" - ) - except jwt.InvalidTokenError: - tenant_id = POSTGRES_DEFAULT_SCHEMA - except Exception as e: - logger.error( - f"Unexpected error in set_tenant_id_middleware: {str(e)}" - ) - raise HTTPException( - status_code=500, detail="Internal server error" - ) - else: - tenant_id = POSTGRES_DEFAULT_SCHEMA + if MULTI_TENANT: + tenant_id = _get_tenant_id_from_request(request, logger) CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - response = await call_next(request) - return response + return await call_next(request) except Exception as e: logger.error(f"Error in tenant ID middleware: {str(e)}") raise + + +def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str: + # First check for API key + tenant_id = extract_tenant_from_api_key_header(request) + if tenant_id is not None: + return tenant_id + + # Check for cookie-based auth + token = request.cookies.get("fastapiusersauth") + if not token: + return POSTGRES_DEFAULT_SCHEMA + + try: + payload = jwt.decode( + token, + USER_AUTH_SECRET, + audience=["fastapi-users:auth"], + algorithms=["HS256"], + ) + tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA) + + # Since payload.get() can return None, ensure we have a string + tenant_id = ( + str(tenant_id_from_payload) + if tenant_id_from_payload is not None + else POSTGRES_DEFAULT_SCHEMA + ) + + if not is_valid_schema_name(tenant_id): + raise HTTPException(status_code=400, detail="Invalid tenant ID format") + + return tenant_id + + except jwt.InvalidTokenError: + return POSTGRES_DEFAULT_SCHEMA + + except Exception as e: + logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error")