mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-10 21:09:51 +02:00
Fix API keys for MIT users (#3237)
This commit is contained in:
parent
86d8666481
commit
7573416ca1
@ -49,7 +49,7 @@ from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from danswer.auth.api_key import get_hashed_api_key_from_request
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
@ -80,8 +80,8 @@ from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.auth import SQLAlchemyUserAdminDB
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
@ -609,7 +609,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
|
||||
async def optional_user_(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
async_db_session: AsyncSession,
|
||||
) -> User | None:
|
||||
"""NOTE: `request` and `db_session` are not used here, but are included
|
||||
for the EE version of this function."""
|
||||
@ -618,13 +618,21 @@ async def optional_user_(
|
||||
|
||||
async def optional_user(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
async_db_session: AsyncSession = Depends(get_async_session),
|
||||
user: User | None = Depends(optional_fastapi_current_user),
|
||||
) -> User | None:
|
||||
versioned_fetch_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "optional_user_"
|
||||
)
|
||||
return await versioned_fetch_user(request, user, db_session)
|
||||
user = await versioned_fetch_user(request, user, async_db_session)
|
||||
|
||||
# check if an API key is present
|
||||
if user is None:
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
if hashed_api_key:
|
||||
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
@ -910,8 +918,8 @@ def get_oauth_router(
|
||||
return router
|
||||
|
||||
|
||||
def api_key_dep(
|
||||
request: Request, db_session: Session = Depends(get_session)
|
||||
async def api_key_dep(
|
||||
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
|
||||
) -> User | None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return None
|
||||
@ -921,7 +929,7 @@ def api_key_dep(
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
if hashed_api_key:
|
||||
user = fetch_user_for_api_key(hashed_api_key, db_session)
|
||||
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
@ -2,6 +2,7 @@ import uuid
|
||||
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -45,14 +46,16 @@ def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
|
||||
]
|
||||
|
||||
|
||||
def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None:
|
||||
api_key = db_session.scalar(
|
||||
select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
async def fetch_user_for_api_key(
|
||||
hashed_api_key: str, async_db_session: AsyncSession
|
||||
) -> User | None:
|
||||
"""NOTE: this is async, since it's used during auth
|
||||
(which is necessarily async due to FastAPI Users)"""
|
||||
return await async_db_session.scalar(
|
||||
select(User)
|
||||
.join(ApiKey, ApiKey.user_id == User.id)
|
||||
.where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
)
|
||||
if api_key is None:
|
||||
return None
|
||||
|
||||
return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore
|
||||
|
||||
|
||||
def get_api_key_fake_email(
|
||||
|
@ -2,15 +2,13 @@ from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from danswer.auth.api_key import get_hashed_api_key_from_request
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
|
||||
from danswer.configs.app_configs import SUPER_USERS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.api_key import fetch_user_for_api_key
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.saml import get_saml_account
|
||||
@ -28,22 +26,18 @@ def verify_auth_setting() -> None:
|
||||
async def optional_user_(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
async_db_session: AsyncSession,
|
||||
) -> User | None:
|
||||
# Check if the user has a session cookie from SAML
|
||||
if AUTH_TYPE == AuthType.SAML:
|
||||
saved_cookie = extract_hashed_cookie(request)
|
||||
|
||||
if saved_cookie:
|
||||
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
|
||||
saml_account = await get_saml_account(
|
||||
cookie=saved_cookie, async_db_session=async_db_session
|
||||
)
|
||||
user = saml_account.user if saml_account else None
|
||||
|
||||
# check if an API key is present
|
||||
if user is None:
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
if hashed_api_key:
|
||||
user = fetch_user_for_api_key(hashed_api_key, db_session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
|
@ -5,6 +5,7 @@ from uuid import UUID
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
@ -44,7 +45,11 @@ def upsert_saml_account(
|
||||
return saml_acc.expires_at
|
||||
|
||||
|
||||
def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
|
||||
async def get_saml_account(
|
||||
cookie: str, async_db_session: AsyncSession
|
||||
) -> SamlAccount | None:
|
||||
"""NOTE: this is async, since it's used during auth
|
||||
(which is necessarily async due to FastAPI Users)"""
|
||||
stmt = (
|
||||
select(SamlAccount)
|
||||
.join(User, User.id == SamlAccount.user_id) # type: ignore
|
||||
@ -56,10 +61,12 @@ def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
|
||||
)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
result = await async_db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def expire_saml_account(saml_account: SamlAccount, db_session: Session) -> None:
|
||||
async def expire_saml_account(
|
||||
saml_account: SamlAccount, async_db_session: AsyncSession
|
||||
) -> None:
|
||||
saml_account.expires_at = func.now()
|
||||
db_session.commit()
|
||||
await async_db_session.commit()
|
||||
|
@ -12,6 +12,7 @@ from fastapi_users import exceptions
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserCreate
|
||||
@ -170,15 +171,19 @@ async def saml_login_callback(
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def saml_logout(
|
||||
async def saml_logout(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
async_db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> None:
|
||||
saved_cookie = extract_hashed_cookie(request)
|
||||
|
||||
if saved_cookie:
|
||||
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
|
||||
saml_account = await get_saml_account(
|
||||
cookie=saved_cookie, async_db_session=async_db_session
|
||||
)
|
||||
if saml_account:
|
||||
expire_saml_account(saml_account, db_session)
|
||||
await expire_saml_account(
|
||||
saml_account=saml_account, async_db_session=async_db_session
|
||||
)
|
||||
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user