danswer/backend/onyx/db/api_key.py
pablonyx a98dcbc7de
Update tenant logic (#4122)
* k

* k

* k

* quick nit

* nit
2025-02-26 03:53:46 +00:00

182 lines
6.1 KiB
Python

import uuid
from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
from onyx.auth.api_key import build_displayable_api_key
from onyx.auth.api_key import generate_api_key
from onyx.auth.api_key import hash_api_key
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.server.api_key.models import APIKeyArgs
from shared_configs.contextvars import get_current_tenant_id
def get_api_key_email_pattern() -> str:
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
def is_api_key_email_address(email: str) -> bool:
return email.endswith(get_api_key_email_pattern())
def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
api_keys = (
db_session.scalars(select(ApiKey).options(joinedload(ApiKey.user)))
.unique()
.all()
)
return [
ApiKeyDescriptor(
api_key_id=api_key.id,
api_key_role=api_key.user.role,
api_key_display=api_key.api_key_display,
api_key_name=api_key.name,
user_id=api_key.user_id,
)
for api_key in api_keys
]
async def fetch_user_for_api_key(
hashed_api_key: str, async_db_session: AsyncSession
) -> User | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
return await async_db_session.scalar(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
)
def get_api_key_fake_email(
name: str,
unique_id: str,
) -> str:
return f"{DANSWER_API_KEY_PREFIX}{name}@{unique_id}{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}"
def insert_api_key(
db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None
) -> ApiKeyDescriptor:
std_password_helper = PasswordHelper()
# Get tenant_id from context var (will be default schema for single tenant)
tenant_id = get_current_tenant_id()
api_key = generate_api_key(tenant_id)
api_key_user_id = uuid.uuid4()
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
api_key_user_row = User(
id=api_key_user_id,
email=get_api_key_fake_email(display_name, str(api_key_user_id)),
# a random password for the "user"
hashed_password=std_password_helper.hash(std_password_helper.generate()),
is_active=True,
is_superuser=False,
is_verified=True,
role=api_key_args.role,
)
db_session.add(api_key_user_row)
api_key_row = ApiKey(
name=api_key_args.name,
hashed_api_key=hash_api_key(api_key),
api_key_display=build_displayable_api_key(api_key),
user_id=api_key_user_id,
owner_id=user_id,
)
db_session.add(api_key_row)
db_session.commit()
return ApiKeyDescriptor(
api_key_id=api_key_row.id,
api_key_role=api_key_user_row.role,
api_key_display=api_key_row.api_key_display,
api_key=api_key,
api_key_name=api_key_args.name,
user_id=api_key_user_id,
)
def update_api_key(
db_session: Session, api_key_id: int, api_key_args: APIKeyArgs
) -> ApiKeyDescriptor:
existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id))
if existing_api_key is None:
raise ValueError(f"API key with id {api_key_id} does not exist")
existing_api_key.name = api_key_args.name
api_key_user = db_session.scalar(
select(User).where(User.id == existing_api_key.user_id) # type: ignore
)
if api_key_user is None:
raise RuntimeError("API Key does not have associated user.")
email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id))
api_key_user.role = api_key_args.role
db_session.commit()
return ApiKeyDescriptor(
api_key_id=existing_api_key.id,
api_key_display=existing_api_key.api_key_display,
api_key_name=api_key_args.name,
api_key_role=api_key_user.role,
user_id=existing_api_key.user_id,
)
def regenerate_api_key(db_session: Session, api_key_id: int) -> ApiKeyDescriptor:
"""NOTE: currently, any admin can regenerate any API key."""
existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id))
if existing_api_key is None:
raise ValueError(f"API key with id {api_key_id} does not exist")
api_key_user = db_session.scalar(
select(User).where(User.id == existing_api_key.user_id) # type: ignore
)
if api_key_user is None:
raise RuntimeError("API Key does not have associated user.")
new_api_key = generate_api_key()
existing_api_key.hashed_api_key = hash_api_key(new_api_key)
existing_api_key.api_key_display = build_displayable_api_key(new_api_key)
db_session.commit()
return ApiKeyDescriptor(
api_key_id=existing_api_key.id,
api_key_display=existing_api_key.api_key_display,
api_key=new_api_key,
api_key_name=existing_api_key.name,
api_key_role=api_key_user.role,
user_id=existing_api_key.user_id,
)
def remove_api_key(db_session: Session, api_key_id: int) -> None:
existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id))
if existing_api_key is None:
raise ValueError(f"API key with id {api_key_id} does not exist")
user_associated_with_key = db_session.scalar(
select(User).where(User.id == existing_api_key.user_id) # type: ignore
)
if user_associated_with_key is None:
raise ValueError(
f"User associated with API key with id {api_key_id} does not exist. This should not happen."
)
db_session.delete(existing_api_key)
db_session.delete(user_associated_with_key)
db_session.commit()