mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-22 22:44:21 +02:00
138 lines
4.3 KiB
Python
138 lines
4.3 KiB
Python
from datetime import datetime
|
|
from functools import lru_cache
|
|
|
|
import jwt
|
|
import requests
|
|
from fastapi import Depends
|
|
from fastapi import HTTPException
|
|
from fastapi import Request
|
|
from fastapi import status
|
|
from jwt import decode as jwt_decode
|
|
from jwt import InvalidTokenError
|
|
from jwt import PyJWTError
|
|
from sqlalchemy import func
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from ee.onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
|
|
from ee.onyx.configs.app_configs import SUPER_CLOUD_API_KEY
|
|
from ee.onyx.configs.app_configs import SUPER_USERS
|
|
from ee.onyx.db.saml import get_saml_account
|
|
from ee.onyx.server.seeding import get_seed_config
|
|
from ee.onyx.utils.secrets import extract_hashed_cookie
|
|
from onyx.auth.users import current_admin_user
|
|
from onyx.configs.app_configs import AUTH_TYPE
|
|
from onyx.configs.app_configs import USER_AUTH_SECRET
|
|
from onyx.configs.constants import AuthType
|
|
from onyx.db.models import User
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
|
|
logger = setup_logger()
|
|
|
|
|
|
@lru_cache()
|
|
def get_public_key() -> str | None:
|
|
if JWT_PUBLIC_KEY_URL is None:
|
|
logger.error("JWT_PUBLIC_KEY_URL is not set")
|
|
return None
|
|
|
|
response = requests.get(JWT_PUBLIC_KEY_URL)
|
|
response.raise_for_status()
|
|
return response.text
|
|
|
|
|
|
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
|
|
try:
|
|
public_key_pem = get_public_key()
|
|
if public_key_pem is None:
|
|
logger.error("Failed to retrieve public key")
|
|
return None
|
|
|
|
payload = jwt_decode(
|
|
token,
|
|
public_key_pem,
|
|
algorithms=["RS256"],
|
|
audience=None,
|
|
)
|
|
email = payload.get("email")
|
|
if email:
|
|
result = await async_db_session.execute(
|
|
select(User).where(func.lower(User.email) == func.lower(email))
|
|
)
|
|
return result.scalars().first()
|
|
except InvalidTokenError:
|
|
logger.error("Invalid JWT token")
|
|
get_public_key.cache_clear()
|
|
except PyJWTError as e:
|
|
logger.error(f"JWT decoding error: {str(e)}")
|
|
get_public_key.cache_clear()
|
|
return None
|
|
|
|
|
|
def verify_auth_setting() -> None:
|
|
# All the Auth flows are valid for EE version
|
|
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
|
|
|
|
|
async def optional_user_(
|
|
request: Request,
|
|
user: User | None,
|
|
async_db_session: AsyncSession,
|
|
) -> User | None:
|
|
# Check if the user has a session cookie from SAML
|
|
if AUTH_TYPE == AuthType.SAML:
|
|
saved_cookie = extract_hashed_cookie(request)
|
|
|
|
if saved_cookie:
|
|
saml_account = await get_saml_account(
|
|
cookie=saved_cookie, async_db_session=async_db_session
|
|
)
|
|
user = saml_account.user if saml_account else None
|
|
|
|
# If user is still None, check for JWT in Authorization header
|
|
if user is None and JWT_PUBLIC_KEY_URL is not None:
|
|
auth_header = request.headers.get("Authorization")
|
|
if auth_header and auth_header.startswith("Bearer "):
|
|
token = auth_header[len("Bearer ") :].strip()
|
|
user = await verify_jwt_token(token, async_db_session)
|
|
|
|
return user
|
|
|
|
|
|
def get_default_admin_user_emails_() -> list[str]:
|
|
seed_config = get_seed_config()
|
|
if seed_config and seed_config.admin_user_emails:
|
|
return seed_config.admin_user_emails
|
|
return []
|
|
|
|
|
|
async def current_cloud_superuser(
|
|
request: Request,
|
|
user: User | None = Depends(current_admin_user),
|
|
) -> User | None:
|
|
api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
|
|
if api_key != SUPER_CLOUD_API_KEY:
|
|
raise HTTPException(status_code=401, detail="Invalid API key")
|
|
|
|
if user and user.email not in SUPER_USERS:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
detail="Access denied. User must be a cloud superuser to perform this action.",
|
|
)
|
|
return user
|
|
|
|
|
|
def generate_anonymous_user_jwt_token(tenant_id: str) -> str:
|
|
payload = {
|
|
"tenant_id": tenant_id,
|
|
# Token does not expire
|
|
"iat": datetime.utcnow(), # Issued at time
|
|
}
|
|
|
|
return jwt.encode(payload, USER_AUTH_SECRET, algorithm="HS256")
|
|
|
|
|
|
def decode_anonymous_user_jwt_token(token: str) -> dict:
|
|
return jwt.decode(token, USER_AUTH_SECRET, algorithms=["HS256"])
|