2024-12-13 09:56:10 -08:00

174 lines
6.1 KiB
Python

from datetime import datetime
from datetime import timezone
from typing import Any
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi import status
from fastapi import UploadFile
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import _LOGO_FILENAME
from ee.onyx.server.enterprise_settings.store import _LOGOTYPE_FILENAME
from ee.onyx.server.enterprise_settings.store import load_analytics_script
from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import store_settings
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
admin_router = APIRouter(prefix="/admin/enterprise-settings")
basic_router = APIRouter(prefix="/enterprise-settings")
logger = setup_logger()
class RefreshTokenData(BaseModel):
access_token: str
refresh_token: str
session: dict = Field(..., description="Contains session information")
userinfo: dict = Field(..., description="Contains user information")
def __init__(self, **data: Any) -> None:
super().__init__(**data)
if "exp" not in self.session:
raise ValueError("'exp' must be set in the session dictionary")
if "userId" not in self.userinfo or "email" not in self.userinfo:
raise ValueError(
"'userId' and 'email' must be set in the userinfo dictionary"
)
@basic_router.post("/refresh-token")
async def refresh_access_token(
refresh_token: RefreshTokenData,
user: User = Depends(current_user_with_expired_token),
user_manager: UserManager = Depends(get_user_manager),
) -> None:
try:
logger.debug(f"Received response from Meechum auth URL for user {user.id}")
# Extract new tokens
new_access_token = refresh_token.access_token
new_refresh_token = refresh_token.refresh_token
new_expiry = datetime.fromtimestamp(
refresh_token.session["exp"] / 1000, tz=timezone.utc
)
expires_at_timestamp = int(new_expiry.timestamp())
logger.debug(f"Access token has been refreshed for user {user.id}")
await user_manager.oauth_callback(
oauth_name="custom",
access_token=new_access_token,
account_id=refresh_token.userinfo["userId"],
account_email=refresh_token.userinfo["email"],
expires_at=expires_at_timestamp,
refresh_token=new_refresh_token,
associate_by_email=True,
)
logger.info(f"Successfully refreshed tokens for user {user.id}")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
logger.warning(f"Full authentication required for user {user.id}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Full authentication required",
)
logger.error(
f"HTTP error occurred while refreshing token for user {user.id}: {str(e)}"
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to refresh token",
)
except Exception as e:
logger.error(
f"Unexpected error occurred while refreshing token for user {user.id}: {str(e)}"
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred",
)
@admin_router.put("")
def put_settings(
settings: EnterpriseSettings, _: User | None = Depends(current_admin_user)
) -> None:
store_settings(settings)
@basic_router.get("")
def fetch_settings() -> EnterpriseSettings:
return load_settings()
@admin_router.put("/logo")
def put_logo(
file: UploadFile,
is_logotype: bool = False,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> None:
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
def fetch_logo_or_logotype(is_logotype: bool, db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
filename = _LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME
file_io = file_store.read_file(filename, mode="b")
# NOTE: specifying "image/jpeg" here, but it still works for pngs
# TODO: do this properly
return Response(content=file_io.read(), media_type="image/jpeg")
except Exception:
raise HTTPException(
status_code=404,
detail=f"No {'logotype' if is_logotype else 'logo'} file found",
)
@basic_router.get("/logotype")
def fetch_logotype(db_session: Session = Depends(get_session)) -> Response:
return fetch_logo_or_logotype(is_logotype=True, db_session=db_session)
@basic_router.get("/logo")
def fetch_logo(
is_logotype: bool = False, db_session: Session = Depends(get_session)
) -> Response:
return fetch_logo_or_logotype(is_logotype=is_logotype, db_session=db_session)
@admin_router.put("/custom-analytics-script")
def upload_custom_analytics_script(
script_upload: AnalyticsScriptUpload, _: User | None = Depends(current_admin_user)
) -> None:
try:
store_analytics_script(script_upload)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
return load_analytics_script()