Add API key generation in the UI + allow it to be used across all endpoints

This commit is contained in:
Weves
2024-01-11 22:30:49 -08:00
committed by Chris Weaver
parent 4b44073d9a
commit ae02a5199a
10 changed files with 498 additions and 47 deletions

View File

@@ -0,0 +1,48 @@
import secrets
import uuid
from fastapi import Request
from passlib.hash import sha256_crypt
from pydantic import BaseModel
_API_KEY_HEADER_NAME = "Authorization"
_BEARER_PREFIX = "Bearer "
_API_KEY_PREFIX = "dn_"
_API_KEY_LEN = 192
class ApiKeyDescriptor(BaseModel):
api_key_id: int
api_key_display: str
api_key: str | None = None # only present on initial creation
user_id: uuid.UUID
def generate_api_key() -> str:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
def hash_api_key(api_key: str) -> str:
# NOTE: no salt is needed, as the API key is randomly generated
# and overlaps are impossible
return sha256_crypt.hash(api_key, salt="")
def build_displayable_api_key(api_key: str) -> str:
if api_key.startswith(_API_KEY_PREFIX):
api_key = api_key[len(_API_KEY_PREFIX) :]
return _API_KEY_PREFIX + api_key[:4] + "********" + api_key[-4:]
def get_hashed_api_key_from_request(request: Request) -> str | None:
raw_api_key_header = request.headers.get(_API_KEY_HEADER_NAME)
if raw_api_key_header is None:
return None
if raw_api_key_header.startswith(_BEARER_PREFIX):
raw_api_key_header = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
return hash_api_key(raw_api_key_header)

View File

@@ -1,3 +1,4 @@
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
@@ -6,8 +7,11 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.constants import AuthType
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.auth.api_key import get_hashed_api_key_from_request
from ee.danswer.db.api_key import fetch_user_for_api_key
from ee.danswer.db.saml import get_saml_account
from ee.danswer.utils.secrets import extract_hashed_cookie
@@ -36,6 +40,12 @@ async def double_check_user(
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
user = saml_account.user if saml_account else None
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
if user is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -43,3 +53,17 @@ async def double_check_user(
)
return user
def api_key_dep(request: Request, db_session: Session = Depends(get_session)) -> User:
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return user

View File

@@ -0,0 +1,108 @@
import uuid
from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.db.models import ApiKey
from danswer.db.models import User
from ee.danswer.auth.api_key import ApiKeyDescriptor
from ee.danswer.auth.api_key import build_displayable_api_key
from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key
_DANSWER_API_KEY = "danswer_api_key"
def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
api_keys = db_session.scalars(select(ApiKey)).all()
return [
ApiKeyDescriptor(
api_key_id=api_key.id,
api_key_display=api_key.api_key_display,
user_id=api_key.user_id,
)
for api_key in api_keys
]
def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None:
api_key = db_session.scalar(
select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)
)
if api_key is None:
return None
return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore
def insert_api_key(db_session: Session, user_id: uuid.UUID | None) -> ApiKeyDescriptor:
std_password_helper = PasswordHelper()
api_key = generate_api_key()
api_key_user_id = uuid.uuid4()
api_key_user_row = User(
id=api_key_user_id,
email=f"{_DANSWER_API_KEY}__{api_key_user_id}",
# a random password for the "user"
hashed_password=std_password_helper.hash(std_password_helper.generate()),
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.BASIC,
)
db_session.add(api_key_user_row)
api_key_row = ApiKey(
hashed_api_key=hash_api_key(api_key),
api_key_display=build_displayable_api_key(api_key),
user_id=api_key_user_id,
owner_id=user_id,
)
db_session.add(api_key_row)
db_session.commit()
return ApiKeyDescriptor(
api_key_id=api_key_row.id,
api_key_display=api_key_row.api_key_display,
api_key=api_key,
user_id=api_key_user_id,
)
def regenerate_api_key(db_session: Session, api_key_id: int) -> ApiKeyDescriptor:
"""NOTE: currently, any admin can regenerate any API key."""
existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id))
if existing_api_key is None:
raise ValueError(f"API key with id {api_key_id} does not exist")
new_api_key = generate_api_key()
existing_api_key.hashed_api_key = hash_api_key(new_api_key)
existing_api_key.api_key_display = build_displayable_api_key(new_api_key)
db_session.commit()
return ApiKeyDescriptor(
api_key_id=existing_api_key.id,
api_key_display=existing_api_key.api_key_display,
api_key=new_api_key,
user_id=existing_api_key.user_id,
)
def remove_api_key(db_session: Session, api_key_id: int) -> None:
existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id))
if existing_api_key is None:
raise ValueError(f"API key with id {api_key_id} does not exist")
user_associated_with_key = db_session.scalar(
select(User).where(User.id == existing_api_key.user_id) # type: ignore
)
if user_associated_with_key is None:
raise ValueError(
f"User associated with API key with id {api_key_id} does not exist. This should not happen."
)
db_session.delete(existing_api_key)
db_session.delete(user_associated_with_key)
db_session.commit()

View File

@@ -17,6 +17,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
from ee.danswer.server.analytics.api import router as analytics_router
from ee.danswer.server.api_key.api import router as api_key_router
from ee.danswer.server.query_history.api import router as query_history_router
from ee.danswer.server.saml import router as saml_router
from ee.danswer.server.user_group.api import router as user_group_router
@@ -59,6 +60,8 @@ def get_ee_application() -> FastAPI:
# analytics endpoints
application.include_router(analytics_router)
application.include_router(query_history_router)
# api key management
application.include_router(api_key_router)
return application

View File

@@ -0,0 +1,48 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
import danswer.db.models as db_models
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from ee.danswer.db.api_key import ApiKeyDescriptor
from ee.danswer.db.api_key import fetch_api_keys
from ee.danswer.db.api_key import insert_api_key
from ee.danswer.db.api_key import regenerate_api_key
from ee.danswer.db.api_key import remove_api_key
router = APIRouter(prefix="/admin/api-key")
@router.get("")
def list_api_keys(
_: db_models.User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[ApiKeyDescriptor]:
return fetch_api_keys(db_session)
@router.post("")
def create_api_key(
user: db_models.User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ApiKeyDescriptor:
return insert_api_key(db_session, user.id if user else None)
@router.patch("/{api_key_id}")
def regenerate_existing_api_key(
api_key_id: int,
_: db_models.User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ApiKeyDescriptor:
return regenerate_api_key(db_session, api_key_id)
@router.delete("/{api_key_id}")
def delete_api_key(
api_key_id: int,
_: db_models.User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
remove_api_key(db_session, api_key_id)