Basic multi tenant api key (#3004)

* basic multi tenant api key

* organization

* nit

* clean
This commit is contained in:
pablodanswer
2024-11-01 12:34:51 -07:00
committed by GitHub
parent 6d543f3d4f
commit 753293cefb
3 changed files with 83 additions and 33 deletions

View File

@@ -1,5 +1,7 @@
import secrets
import uuid
from urllib.parse import quote
from urllib.parse import unquote
from fastapi import Request
from passlib.hash import sha256_crypt
@@ -30,8 +32,35 @@ class ApiKeyDescriptor(BaseModel):
user_id: uuid.UUID
def generate_api_key() -> str:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
def generate_api_key(tenant_id: str | None = None) -> str:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}"
def extract_tenant_from_api_key_header(request: Request) -> str | None:
"""Extract tenant ID from request. Returns None if auth is disabled or invalid format."""
raw_api_key_header = request.headers.get(
_API_KEY_HEADER_ALTERNATIVE_NAME
) or request.headers.get(_API_KEY_HEADER_NAME)
if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX):
return None
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
if not api_key.startswith(_API_KEY_PREFIX):
return None
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
if len(parts) != 2:
return None
tenant_id = parts[0]
return unquote(tenant_id) if tenant_id else None
def hash_api_key(api_key: str) -> str:

View File

@@ -15,6 +15,8 @@ from ee.danswer.auth.api_key import build_displayable_api_key
from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
def get_api_key_email_pattern() -> str:
@@ -64,7 +66,11 @@ def insert_api_key(
db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None
) -> ApiKeyDescriptor:
std_password_helper = PasswordHelper()
api_key = generate_api_key()
# Get tenant_id from context var (will be default schema for single tenant)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
api_key_user_id = uuid.uuid4()
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER

View File

@@ -10,6 +10,7 @@ from fastapi import Response
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.db.engine import is_valid_schema_name
from ee.danswer.auth.api_key import extract_tenant_from_api_key_header
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -21,40 +22,54 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
else:
token = request.cookies.get("fastapiusersauth")
tenant_id = POSTGRES_DEFAULT_SCHEMA
if token:
try:
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if not is_valid_schema_name(tenant_id):
raise HTTPException(
status_code=400, detail="Invalid tenant ID format"
)
except jwt.InvalidTokenError:
tenant_id = POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(
f"Unexpected error in set_tenant_id_middleware: {str(e)}"
)
raise HTTPException(
status_code=500, detail="Internal server error"
)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA
if MULTI_TENANT:
tenant_id = _get_tenant_id_from_request(request, logger)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
response = await call_next(request)
return response
return await call_next(request)
except Exception as e:
logger.error(f"Error in tenant ID middleware: {str(e)}")
raise
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
# First check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id is not None:
return tenant_id
# Check for cookie-based auth
token = request.cookies.get("fastapiusersauth")
if not token:
return POSTGRES_DEFAULT_SCHEMA
try:
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
# Since payload.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else POSTGRES_DEFAULT_SCHEMA
)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
return tenant_id
except jwt.InvalidTokenError:
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")