From 7573416ca181a6bb134ce300a429b494b6f48f3a Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Sun, 24 Nov 2024 16:55:19 -0800 Subject: [PATCH] Fix API keys for MIT users (#3237) --- backend/danswer/auth/users.py | 24 ++++++++++++++++-------- backend/danswer/db/api_key.py | 17 ++++++++++------- backend/ee/danswer/auth/users.py | 16 +++++----------- backend/ee/danswer/db/saml.py | 15 +++++++++++---- backend/ee/danswer/server/saml.py | 13 +++++++++---- 5 files changed, 51 insertions(+), 34 deletions(-) diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 5fd7d0cae..cf3de018f 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -49,7 +49,7 @@ from httpx_oauth.oauth2 import BaseOAuth2 from httpx_oauth.oauth2 import OAuth2Token from pydantic import BaseModel from sqlalchemy import text -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from danswer.auth.api_key import get_hashed_api_key_from_request from danswer.auth.invited_users import get_invited_users @@ -80,8 +80,8 @@ from danswer.db.auth import get_default_admin_user_emails from danswer.db.auth import get_user_count from danswer.db.auth import get_user_db from danswer.db.auth import SQLAlchemyUserAdminDB +from danswer.db.engine import get_async_session from danswer.db.engine import get_async_session_with_tenant -from danswer.db.engine import get_session from danswer.db.engine import get_session_with_tenant from danswer.db.models import AccessToken from danswer.db.models import OAuthAccount @@ -609,7 +609,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional async def optional_user_( request: Request, user: User | None, - db_session: Session, + async_db_session: AsyncSession, ) -> User | None: """NOTE: `request` and `db_session` are not used here, but are included for the EE version of this function.""" @@ -618,13 +618,21 @@ async def optional_user_( async def optional_user( request: Request, - db_session: Session = Depends(get_session), + async_db_session: AsyncSession = Depends(get_async_session), user: User | None = Depends(optional_fastapi_current_user), ) -> User | None: versioned_fetch_user = fetch_versioned_implementation( "danswer.auth.users", "optional_user_" ) - return await versioned_fetch_user(request, user, db_session) + user = await versioned_fetch_user(request, user, async_db_session) + + # check if an API key is present + if user is None: + hashed_api_key = get_hashed_api_key_from_request(request) + if hashed_api_key: + user = await fetch_user_for_api_key(hashed_api_key, async_db_session) + + return user async def double_check_user( @@ -910,8 +918,8 @@ def get_oauth_router( return router -def api_key_dep( - request: Request, db_session: Session = Depends(get_session) +async def api_key_dep( + request: Request, async_db_session: AsyncSession = Depends(get_async_session) ) -> User | None: if AUTH_TYPE == AuthType.DISABLED: return None @@ -921,7 +929,7 @@ def api_key_dep( raise HTTPException(status_code=401, detail="Missing API key") if hashed_api_key: - user = fetch_user_for_api_key(hashed_api_key, db_session) + user = await fetch_user_for_api_key(hashed_api_key, async_db_session) if user is None: raise HTTPException(status_code=401, detail="Invalid API key") diff --git a/backend/danswer/db/api_key.py b/backend/danswer/db/api_key.py index 1a16d73a9..b4a56f3f2 100644 --- a/backend/danswer/db/api_key.py +++ b/backend/danswer/db/api_key.py @@ -2,6 +2,7 @@ import uuid from fastapi_users.password import PasswordHelper from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session @@ -45,14 +46,16 @@ def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]: ] -def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None: - api_key = db_session.scalar( - select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key) +async def fetch_user_for_api_key( + hashed_api_key: str, async_db_session: AsyncSession +) -> User | None: + """NOTE: this is async, since it's used during auth + (which is necessarily async due to FastAPI Users)""" + return await async_db_session.scalar( + select(User) + .join(ApiKey, ApiKey.user_id == User.id) + .where(ApiKey.hashed_api_key == hashed_api_key) ) - if api_key is None: - return None - - return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore def get_api_key_fake_email( diff --git a/backend/ee/danswer/auth/users.py b/backend/ee/danswer/auth/users.py index 1db90e649..aab88efa8 100644 --- a/backend/ee/danswer/auth/users.py +++ b/backend/ee/danswer/auth/users.py @@ -2,15 +2,13 @@ from fastapi import Depends from fastapi import HTTPException from fastapi import Request from fastapi import status -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession -from danswer.auth.api_key import get_hashed_api_key_from_request from danswer.auth.users import current_admin_user from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import SUPER_CLOUD_API_KEY from danswer.configs.app_configs import SUPER_USERS from danswer.configs.constants import AuthType -from danswer.db.api_key import fetch_user_for_api_key from danswer.db.models import User from danswer.utils.logger import setup_logger from ee.danswer.db.saml import get_saml_account @@ -28,22 +26,18 @@ def verify_auth_setting() -> None: async def optional_user_( request: Request, user: User | None, - db_session: Session, + async_db_session: AsyncSession, ) -> User | None: # Check if the user has a session cookie from SAML if AUTH_TYPE == AuthType.SAML: saved_cookie = extract_hashed_cookie(request) if saved_cookie: - saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session) + saml_account = await get_saml_account( + cookie=saved_cookie, async_db_session=async_db_session + ) user = saml_account.user if saml_account else None - # check if an API key is present - if user is None: - hashed_api_key = get_hashed_api_key_from_request(request) - if hashed_api_key: - user = fetch_user_for_api_key(hashed_api_key, db_session) - return user diff --git a/backend/ee/danswer/db/saml.py b/backend/ee/danswer/db/saml.py index 6689a7a7e..fff18f8a1 100644 --- a/backend/ee/danswer/db/saml.py +++ b/backend/ee/danswer/db/saml.py @@ -5,6 +5,7 @@ from uuid import UUID from sqlalchemy import and_ from sqlalchemy import func from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS @@ -44,7 +45,11 @@ def upsert_saml_account( return saml_acc.expires_at -def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None: +async def get_saml_account( + cookie: str, async_db_session: AsyncSession +) -> SamlAccount | None: + """NOTE: this is async, since it's used during auth + (which is necessarily async due to FastAPI Users)""" stmt = ( select(SamlAccount) .join(User, User.id == SamlAccount.user_id) # type: ignore @@ -56,10 +61,12 @@ def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None: ) ) - result = db_session.execute(stmt) + result = await async_db_session.execute(stmt) return result.scalar_one_or_none() -def expire_saml_account(saml_account: SamlAccount, db_session: Session) -> None: +async def expire_saml_account( + saml_account: SamlAccount, async_db_session: AsyncSession +) -> None: saml_account.expires_at = func.now() - db_session.commit() + await async_db_session.commit() diff --git a/backend/ee/danswer/server/saml.py b/backend/ee/danswer/server/saml.py index 81a7b5d8d..20c786af1 100644 --- a/backend/ee/danswer/server/saml.py +++ b/backend/ee/danswer/server/saml.py @@ -12,6 +12,7 @@ from fastapi_users import exceptions from fastapi_users.password import PasswordHelper from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from danswer.auth.schemas import UserCreate @@ -170,15 +171,19 @@ async def saml_login_callback( @router.post("/logout") -def saml_logout( +async def saml_logout( request: Request, - db_session: Session = Depends(get_session), + async_db_session: AsyncSession = Depends(get_async_session), ) -> None: saved_cookie = extract_hashed_cookie(request) if saved_cookie: - saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session) + saml_account = await get_saml_account( + cookie=saved_cookie, async_db_session=async_db_session + ) if saml_account: - expire_saml_account(saml_account, db_session) + await expire_saml_account( + saml_account=saml_account, async_db_session=async_db_session + ) return