Fix API keys for MIT users (#3237)

This commit is contained in:
Chris Weaver 2024-11-24 16:55:19 -08:00 committed by GitHub
parent 86d8666481
commit 7573416ca1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 51 additions and 34 deletions

View File

@ -49,7 +49,7 @@ from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
@ -80,8 +80,8 @@ from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
@ -609,7 +609,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
@ -618,13 +618,21 @@ async def optional_user_(
async def optional_user(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
return await versioned_fetch_user(request, user, db_session)
user = await versioned_fetch_user(request, user, async_db_session)
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
return user
async def double_check_user(
@ -910,8 +918,8 @@ def get_oauth_router(
return router
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
@ -921,7 +929,7 @@ def api_key_dep(
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")

View File

@ -2,6 +2,7 @@ import uuid
from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
@ -45,14 +46,16 @@ def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
]
def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None:
api_key = db_session.scalar(
select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)
async def fetch_user_for_api_key(
hashed_api_key: str, async_db_session: AsyncSession
) -> User | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
return await async_db_session.scalar(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
)
if api_key is None:
return None
return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore
def get_api_key_fake_email(

View File

@ -2,15 +2,13 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.db.saml import get_saml_account
@ -28,22 +26,18 @@ def verify_auth_setting() -> None:
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
# Check if the user has a session cookie from SAML
if AUTH_TYPE == AuthType.SAML:
saved_cookie = extract_hashed_cookie(request)
if saved_cookie:
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
user = saml_account.user if saml_account else None
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
return user

View File

@ -5,6 +5,7 @@ from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@ -44,7 +45,11 @@ def upsert_saml_account(
return saml_acc.expires_at
def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
async def get_saml_account(
cookie: str, async_db_session: AsyncSession
) -> SamlAccount | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
stmt = (
select(SamlAccount)
.join(User, User.id == SamlAccount.user_id) # type: ignore
@ -56,10 +61,12 @@ def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
)
)
result = db_session.execute(stmt)
result = await async_db_session.execute(stmt)
return result.scalar_one_or_none()
def expire_saml_account(saml_account: SamlAccount, db_session: Session) -> None:
async def expire_saml_account(
saml_account: SamlAccount, async_db_session: AsyncSession
) -> None:
saml_account.expires_at = func.now()
db_session.commit()
await async_db_session.commit()

View File

@ -12,6 +12,7 @@ from fastapi_users import exceptions
from fastapi_users.password import PasswordHelper
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserCreate
@ -170,15 +171,19 @@ async def saml_login_callback(
@router.post("/logout")
def saml_logout(
async def saml_logout(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
) -> None:
saved_cookie = extract_hashed_cookie(request)
if saved_cookie:
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
if saml_account:
expire_saml_account(saml_account, db_session)
await expire_saml_account(
saml_account=saml_account, async_db_session=async_db_session
)
return