from datetime import datetime
from datetime import timezone
from typing import Any

import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi import status
from fastapi import UploadFile
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session

from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import _LOGO_FILENAME
from ee.onyx.server.enterprise_settings.store import _LOGOTYPE_FILENAME
from ee.onyx.server.enterprise_settings.store import load_analytics_script
from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import store_settings
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger

admin_router = APIRouter(prefix="/admin/enterprise-settings")
basic_router = APIRouter(prefix="/enterprise-settings")

logger = setup_logger()


class RefreshTokenData(BaseModel):
    """Request body for ``POST /enterprise-settings/refresh-token``.

    Besides the field types, ``__init__`` requires that ``session`` contains
    an ``"exp"`` key and that ``userinfo`` contains both ``"userId"`` and
    ``"email"``; otherwise it raises ``ValueError``.
    """

    access_token: str
    refresh_token: str
    session: dict = Field(..., description="Contains session information")
    userinfo: dict = Field(..., description="Contains user information")

    def __init__(self, **data: Any) -> None:
        super().__init__(**data)
        if "exp" not in self.session:
            raise ValueError("'exp' must be set in the session dictionary")
        if "userId" not in self.userinfo or "email" not in self.userinfo:
            raise ValueError(
                "'userId' and 'email' must be set in the userinfo dictionary"
            )


@basic_router.post("/refresh-token")
async def refresh_access_token(
    refresh_token: RefreshTokenData,
    user: User = Depends(current_user_with_expired_token),
    user_manager: UserManager = Depends(get_user_manager),
) -> None:
    """Store client-supplied refreshed OAuth tokens for the current user.

    The new access/refresh tokens, the ``userinfo`` identity and the
    ``session["exp"]`` expiry are handed to ``user_manager.oauth_callback``
    under the ``"custom"`` OAuth name with ``associate_by_email=True``.

    Raises:
        HTTPException: 401 if an ``httpx.HTTPStatusError`` with status 401
            escapes the body; 500 for any other ``HTTPStatusError`` or any
            other exception.
    """
    try:
        # NOTE(review): this handler itself issues no HTTP request; the
        # HTTPStatusError branch below only fires if oauth_callback raises one.
        logger.debug(f"Received response from Meechum auth URL for user {user.id}")

        new_access_token = refresh_token.access_token
        new_refresh_token = refresh_token.refresh_token
        # `exp` is divided by 1000, i.e. treated as milliseconds since the epoch.
        new_expiry = datetime.fromtimestamp(
            refresh_token.session["exp"] / 1000, tz=timezone.utc
        )
        expires_at_timestamp = int(new_expiry.timestamp())

        logger.debug(f"Access token has been refreshed for user {user.id}")

        await user_manager.oauth_callback(
            oauth_name="custom",
            access_token=new_access_token,
            account_id=refresh_token.userinfo["userId"],
            account_email=refresh_token.userinfo["email"],
            expires_at=expires_at_timestamp,
            refresh_token=new_refresh_token,
            associate_by_email=True,
        )

        logger.info(f"Successfully refreshed tokens for user {user.id}")

    except httpx.HTTPStatusError as e:
        if e.response.status_code == 401:
            logger.warning(f"Full authentication required for user {user.id}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Full authentication required",
            ) from e
        logger.error(
            f"HTTP error occurred while refreshing token for user {user.id}: {str(e)}"
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to refresh token",
        ) from e
    except Exception as e:
        # logger.exception logs at ERROR level like before, but also keeps the
        # traceback, which is otherwise lost once this becomes a generic 500.
        logger.exception(
            f"Unexpected error occurred while refreshing token for user {user.id}: {str(e)}"
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred",
        ) from e


@admin_router.put("")
def put_settings(
    settings: EnterpriseSettings, _: User | None = Depends(current_admin_user)
) -> None:
    """Persist enterprise settings (guarded by ``current_admin_user``)."""
    store_settings(settings)


@basic_router.get("")
def fetch_settings() -> EnterpriseSettings:
    """Return the enterprise settings loaded via ``load_settings``."""
    return load_settings()


@admin_router.put("/logo")
def put_logo(
    file: UploadFile,
    is_logotype: bool = False,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> None:
    """Upload the logo, or the logotype when ``is_logotype`` is true."""
    upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)


# Leading "magic" bytes of the image formats we can recognise. Anything not
# matched keeps the historical "image/jpeg" media type.
_IMAGE_SIGNATURES: tuple[tuple[bytes, str], ...] = (
    (b"\x89PNG\r\n\x1a\n", "image/png"),
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"GIF87a", "image/gif"),
    (b"GIF89a", "image/gif"),
)
_DEFAULT_IMAGE_MEDIA_TYPE = "image/jpeg"


def _guess_image_media_type(content: bytes) -> str:
    """Return the image media type implied by ``content``'s leading bytes."""
    for signature, media_type in _IMAGE_SIGNATURES:
        if content.startswith(signature):
            return media_type
    # WebP container: b"RIFF" + 4-byte little-endian size + b"WEBP".
    if content[:4] == b"RIFF" and content[8:12] == b"WEBP":
        return "image/webp"
    return _DEFAULT_IMAGE_MEDIA_TYPE


def fetch_logo_or_logotype(is_logotype: bool, db_session: Session) -> Response:
    """Read the stored logo (or logotype) and return it as an HTTP response.

    The response media type is sniffed from the file's leading bytes
    (PNG/JPEG/GIF/WebP), defaulting to ``image/jpeg``.

    Raises:
        HTTPException: 404 if the file store cannot be opened or the file
            cannot be read.
    """
    filename = _LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME
    try:
        file_store = get_default_file_store(db_session)
        file_io = file_store.read_file(filename, mode="b")
        # Close the handle once its bytes are in memory; it was leaked before.
        with file_io:
            content = file_io.read()
    except Exception as e:
        raise HTTPException(
            status_code=404,
            detail=f"No {'logotype' if is_logotype else 'logo'} file found",
        ) from e

    return Response(content=content, media_type=_guess_image_media_type(content))


@basic_router.get("/logotype")
def fetch_logotype(db_session: Session = Depends(get_session)) -> Response:
    """Return the stored logotype image."""
    return fetch_logo_or_logotype(is_logotype=True, db_session=db_session)


@basic_router.get("/logo")
def fetch_logo(
    is_logotype: bool = False, db_session: Session = Depends(get_session)
) -> Response:
    """Return the stored logo, or the logotype when ``is_logotype`` is true."""
    return fetch_logo_or_logotype(is_logotype=is_logotype, db_session=db_session)


@admin_router.put("/custom-analytics-script")
def upload_custom_analytics_script(
    script_upload: AnalyticsScriptUpload, _: User | None = Depends(current_admin_user)
) -> None:
    """Store a custom analytics script.

    Raises:
        HTTPException: 400 carrying the message of any ``ValueError`` raised
            by ``store_analytics_script``.
    """
    try:
        store_analytics_script(script_upload)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e


@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
    """Return the stored custom analytics script, or ``None`` if there is none."""
    return load_analytics_script()