mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-18 13:51:46 +01:00
182 lines
6.1 KiB
Python
182 lines
6.1 KiB
Python
import uuid
|
|
|
|
from fastapi_users.password import PasswordHelper
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import joinedload
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.auth.api_key import ApiKeyDescriptor
|
|
from onyx.auth.api_key import build_displayable_api_key
|
|
from onyx.auth.api_key import generate_api_key
|
|
from onyx.auth.api_key import hash_api_key
|
|
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
|
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
|
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
|
from onyx.db.models import ApiKey
|
|
from onyx.db.models import User
|
|
from onyx.server.api_key.models import APIKeyArgs
|
|
from shared_configs.contextvars import get_current_tenant_id
|
|
|
|
|
|
def get_api_key_email_pattern() -> str:
    """Return the dummy email domain used as the API-key email pattern."""
    pattern: str = DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
    return pattern
|
|
|
|
|
|
def is_api_key_email_address(email: str) -> bool:
    """Report whether ``email`` ends with the API-key email pattern.

    Args:
        email: The email address to inspect.

    Returns:
        True when the address ends with the value returned by
        ``get_api_key_email_pattern()``, False otherwise.
    """
    suffix = get_api_key_email_pattern()
    return email.endswith(suffix)
|
|
|
|
|
|
def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
    """Load every API key row and describe each one as an ApiKeyDescriptor.

    The related user is eagerly joined so that reading the role below does
    not issue an additional query per key.

    Args:
        db_session: Synchronous SQLAlchemy session used for the query.

    Returns:
        One descriptor per stored API key (the raw key itself is not set).
    """
    stmt = select(ApiKey).options(joinedload(ApiKey.user))
    key_rows = db_session.scalars(stmt).unique().all()

    descriptors: list[ApiKeyDescriptor] = []
    for key_row in key_rows:
        descriptors.append(
            ApiKeyDescriptor(
                api_key_id=key_row.id,
                api_key_role=key_row.user.role,
                api_key_display=key_row.api_key_display,
                api_key_name=key_row.name,
                user_id=key_row.user_id,
            )
        )
    return descriptors
|
|
|
|
|
|
async def fetch_user_for_api_key(
    hashed_api_key: str, async_db_session: AsyncSession
) -> User | None:
    """Look up the user that owns the API key with the given hash.

    NOTE: this is async, since it's used during auth
    (which is necessarily async due to FastAPI Users)

    Args:
        hashed_api_key: Hash of the API key to match against ApiKey rows.
        async_db_session: Async SQLAlchemy session used for the lookup.

    Returns:
        The scalar result of the query: a User joined to a matching ApiKey,
        or None when nothing matches.
    """
    query = (
        select(User)
        .join(ApiKey, ApiKey.user_id == User.id)
        .where(ApiKey.hashed_api_key == hashed_api_key)
    )
    matched_user = await async_db_session.scalar(query)
    return matched_user
|
|
|
|
|
|
def get_api_key_fake_email(
    name: str,
    unique_id: str,
) -> str:
    """Build the placeholder email address for an API key's user.

    The address is the API-key prefix followed by ``name``, an ``@``, then
    ``unique_id`` immediately followed by the dummy email domain.

    Args:
        name: Name placed in the local part of the address.
        unique_id: Identifier placed directly after the ``@``.

    Returns:
        The assembled email string.
    """
    fake_email = f"{DANSWER_API_KEY_PREFIX}{name}@{unique_id}{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}"
    return fake_email
|
|
|
|
|
|
def insert_api_key(
    db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None
) -> ApiKeyDescriptor:
    """Create an API key plus the dedicated user row that represents it.

    A fresh user with a generated UUID, a fake email and a random password is
    added alongside the ApiKey row, and both are committed together.

    Args:
        db_session: Synchronous SQLAlchemy session; committed by this call.
        api_key_args: Requested name and role for the new key.
        user_id: Stored as the key's ``owner_id``; may be None.

    Returns:
        A descriptor for the new key, including the raw key in ``api_key``.
    """
    password_helper = PasswordHelper()

    # Get tenant_id from context var (will be default schema for single tenant)
    current_tenant = get_current_tenant_id()
    raw_key = generate_api_key(current_tenant)

    key_user_id = uuid.uuid4()
    # The email needs some name, so unnamed keys use the placeholder there.
    name_for_email = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
    fake_email = get_api_key_fake_email(name_for_email, str(key_user_id))
    # a random password for the "user"
    random_password_hash = password_helper.hash(password_helper.generate())

    key_user = User(
        id=key_user_id,
        email=fake_email,
        hashed_password=random_password_hash,
        is_active=True,
        is_superuser=False,
        is_verified=True,
        role=api_key_args.role,
    )
    db_session.add(key_user)

    key_row = ApiKey(
        name=api_key_args.name,
        hashed_api_key=hash_api_key(raw_key),
        api_key_display=build_displayable_api_key(raw_key),
        user_id=key_user_id,
        owner_id=user_id,
    )
    db_session.add(key_row)

    db_session.commit()

    descriptor = ApiKeyDescriptor(
        api_key_id=key_row.id,
        api_key_role=key_user.role,
        api_key_display=key_row.api_key_display,
        api_key=raw_key,
        api_key_name=api_key_args.name,
        user_id=key_user_id,
    )
    return descriptor
|
|
|
|
|
|
def update_api_key(
    db_session: Session, api_key_id: int, api_key_args: APIKeyArgs
) -> ApiKeyDescriptor:
    """Rename an existing API key and update its user's email and role.

    Args:
        db_session: Synchronous SQLAlchemy session; committed by this call.
        api_key_id: Primary key of the ApiKey row to update.
        api_key_args: New name and role to apply.

    Returns:
        A descriptor reflecting the updated key (without the raw key).

    Raises:
        ValueError: If no ApiKey row has the given id.
        RuntimeError: If the key's associated user row cannot be found.
    """
    key_lookup = select(ApiKey).where(ApiKey.id == api_key_id)
    key_row = db_session.scalar(key_lookup)
    if key_row is None:
        raise ValueError(f"API key with id {api_key_id} does not exist")

    key_row.name = api_key_args.name

    user_lookup = select(User).where(User.id == key_row.user_id)  # type: ignore
    key_user = db_session.scalar(user_lookup)
    if key_user is None:
        raise RuntimeError("API Key does not have associated user.")

    # Keep the fake email in sync with the (possibly changed) key name.
    name_for_email = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
    key_user.email = get_api_key_fake_email(name_for_email, str(key_user.id))
    key_user.role = api_key_args.role
    db_session.commit()

    descriptor = ApiKeyDescriptor(
        api_key_id=key_row.id,
        api_key_display=key_row.api_key_display,
        api_key_name=api_key_args.name,
        api_key_role=key_user.role,
        user_id=key_row.user_id,
    )
    return descriptor
|
|
|
|
|
|
def regenerate_api_key(db_session: Session, api_key_id: int) -> ApiKeyDescriptor:
    """Replace the secret of an existing API key with a newly generated one.

    NOTE: currently, any admin can regenerate any API key.

    Args:
        db_session: Synchronous SQLAlchemy session; committed by this call.
        api_key_id: Primary key of the ApiKey row to regenerate.

    Returns:
        A descriptor for the key, including the new raw key in ``api_key``.

    Raises:
        ValueError: If no ApiKey row has the given id.
        RuntimeError: If the key's associated user row cannot be found.
    """
    existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id))
    if existing_api_key is None:
        raise ValueError(f"API key with id {api_key_id} does not exist")

    api_key_user = db_session.scalar(
        select(User).where(User.id == existing_api_key.user_id)  # type: ignore
    )
    if api_key_user is None:
        raise RuntimeError("API Key does not have associated user.")

    # Pass the current tenant id exactly as insert_api_key does, so that a
    # regenerated key is produced the same way as a freshly created one.
    # NOTE(review): presumably generate_api_key scopes the key to the tenant
    # when one is given -- confirm against onyx.auth.api_key.
    tenant_id = get_current_tenant_id()
    new_api_key = generate_api_key(tenant_id)

    existing_api_key.hashed_api_key = hash_api_key(new_api_key)
    existing_api_key.api_key_display = build_displayable_api_key(new_api_key)
    db_session.commit()

    return ApiKeyDescriptor(
        api_key_id=existing_api_key.id,
        api_key_display=existing_api_key.api_key_display,
        api_key=new_api_key,
        api_key_name=existing_api_key.name,
        api_key_role=api_key_user.role,
        user_id=existing_api_key.user_id,
    )
|
|
|
|
|
|
def remove_api_key(db_session: Session, api_key_id: int) -> None:
    """Delete an API key and the user row associated with it.

    Args:
        db_session: Synchronous SQLAlchemy session; committed by this call.
        api_key_id: Primary key of the ApiKey row to delete.

    Raises:
        ValueError: If the ApiKey row, or its associated user, does not exist.
    """
    key_lookup = select(ApiKey).where(ApiKey.id == api_key_id)
    key_row = db_session.scalar(key_lookup)
    if key_row is None:
        raise ValueError(f"API key with id {api_key_id} does not exist")

    user_lookup = select(User).where(User.id == key_row.user_id)  # type: ignore
    key_user = db_session.scalar(user_lookup)
    if key_user is None:
        raise ValueError(
            f"User associated with API key with id {api_key_id} does not exist. This should not happen."
        )

    # Remove the key first, then the user that stood in for it.
    for doomed_row in (key_row, key_user):
        db_session.delete(doomed_row)
    db_session.commit()
|