mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 13:53:28 +02:00
Basic multi tenant api key (#3004)
* basic multi tenant api key * organization * nit * clean
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import secrets
|
||||
import uuid
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import unquote
|
||||
|
||||
from fastapi import Request
|
||||
from passlib.hash import sha256_crypt
|
||||
@@ -30,8 +32,35 @@ class ApiKeyDescriptor(BaseModel):
|
||||
user_id: uuid.UUID
|
||||
|
||||
|
||||
def generate_api_key() -> str:
|
||||
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
|
||||
def generate_api_key(tenant_id: str | None = None) -> str:
|
||||
# For backwards compatibility, if no tenant_id, generate old style key
|
||||
if not tenant_id:
|
||||
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
|
||||
|
||||
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
|
||||
return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}"
|
||||
|
||||
|
||||
def extract_tenant_from_api_key_header(request: Request) -> str | None:
|
||||
"""Extract tenant ID from request. Returns None if auth is disabled or invalid format."""
|
||||
raw_api_key_header = request.headers.get(
|
||||
_API_KEY_HEADER_ALTERNATIVE_NAME
|
||||
) or request.headers.get(_API_KEY_HEADER_NAME)
|
||||
|
||||
if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX):
|
||||
return None
|
||||
|
||||
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
|
||||
|
||||
if not api_key.startswith(_API_KEY_PREFIX):
|
||||
return None
|
||||
|
||||
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
tenant_id = parts[0]
|
||||
return unquote(tenant_id) if tenant_id else None
|
||||
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
|
@@ -15,6 +15,8 @@ from ee.danswer.auth.api_key import build_displayable_api_key
|
||||
from ee.danswer.auth.api_key import generate_api_key
|
||||
from ee.danswer.auth.api_key import hash_api_key
|
||||
from ee.danswer.server.api_key.models import APIKeyArgs
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
def get_api_key_email_pattern() -> str:
|
||||
@@ -64,7 +66,11 @@ def insert_api_key(
|
||||
db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None
|
||||
) -> ApiKeyDescriptor:
|
||||
std_password_helper = PasswordHelper()
|
||||
api_key = generate_api_key()
|
||||
|
||||
# Get tenant_id from context var (will be default schema for single tenant)
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
|
||||
api_key_user_id = uuid.uuid4()
|
||||
|
||||
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
|
@@ -10,6 +10,7 @@ from fastapi import Response
|
||||
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.db.engine import is_valid_schema_name
|
||||
from ee.danswer.auth.api_key import extract_tenant_from_api_key_header
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -21,40 +22,54 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
try:
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
else:
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if token:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
USER_AUTH_SECRET,
|
||||
audience=["fastapi-users:auth"],
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid tenant ID format"
|
||||
)
|
||||
except jwt.InvalidTokenError:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error in set_tenant_id_middleware: {str(e)}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error"
|
||||
)
|
||||
else:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
if MULTI_TENANT:
|
||||
tenant_id = _get_tenant_id_from_request(request, logger)
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
response = await call_next(request)
|
||||
return response
|
||||
return await call_next(request)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in tenant ID middleware: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
|
||||
# First check for API key
|
||||
tenant_id = extract_tenant_from_api_key_header(request)
|
||||
if tenant_id is not None:
|
||||
return tenant_id
|
||||
|
||||
# Check for cookie-based auth
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
USER_AUTH_SECRET,
|
||||
audience=["fastapi-users:auth"],
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
# Since payload.get() can return None, ensure we have a string
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
|
||||
return tenant_id
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
Reference in New Issue
Block a user