diff --git a/backend/ee/onyx/main.py b/backend/ee/onyx/main.py index 4c44a1aec..e47a193cd 100644 --- a/backend/ee/onyx/main.py +++ b/backend/ee/onyx/main.py @@ -1,3 +1,6 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + from fastapi import FastAPI from httpx_oauth.clients.google import GoogleOAuth2 from httpx_oauth.clients.openid import BASE_SCOPES @@ -44,6 +47,7 @@ from onyx.configs.constants import AuthType from onyx.main import get_application as get_application_base from onyx.main import include_auth_router_with_prefix from onyx.main import include_router_with_global_prefix_prepended +from onyx.main import lifespan as lifespan_base from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import global_version from shared_configs.configs import MULTI_TENANT @@ -51,6 +55,20 @@ from shared_configs.configs import MULTI_TENANT logger = setup_logger() +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + """Small wrapper around the lifespan of the MIT application. + Basically just calls the base lifespan, and then adds EE-only + steps after.""" + + async with lifespan_base(app): + # seed the Onyx environment with LLMs, Assistants, etc. based on an optional + # environment variable. Used to automate deployment for multiple environments. + seed_db() + + yield + + def get_application() -> FastAPI: # Anything that happens at import time is not guaranteed to be running ee-version # Anything after the server startup will be running ee version @@ -58,7 +76,7 @@ def get_application() -> FastAPI: test_encryption() - application = get_application_base() + application = get_application_base(lifespan_override=lifespan) if MULTI_TENANT: add_tenant_id_middleware(application, logger) @@ -166,10 +184,6 @@ def get_application() -> FastAPI: # Ensure all routes have auth enabled or are explicitly marked as public check_ee_router_auth(application) - # seed the Onyx environment with LLMs, Assistants, etc. based on an optional - # environment variable. Used to automate deployment for multiple environments. - seed_db() - # for debugging discovered routes # for route in application.router.routes: # print(f"Path: {route.path}, Methods: {route.methods}") diff --git a/backend/onyx/main.py b/backend/onyx/main.py index c4db368d8..4eb9c83a6 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -21,6 +21,7 @@ from prometheus_fastapi_instrumentator import Instrumentator from sentry_sdk.integrations.fastapi import FastApiIntegration from sentry_sdk.integrations.starlette import StarletteIntegration from sqlalchemy.orm import Session +from starlette.types import Lifespan from onyx import __version__ from onyx.auth.schemas import UserCreate @@ -275,8 +276,12 @@ def log_http_error(request: Request, exc: Exception) -> JSONResponse: ) -def get_application() -> FastAPI: - application = FastAPI(title="Onyx Backend", version=__version__, lifespan=lifespan) +def get_application(lifespan_override: Lifespan | None = None) -> FastAPI: + application = FastAPI( + title="Onyx Backend", + version=__version__, + lifespan=lifespan_override or lifespan, + ) if SENTRY_DSN: sentry_sdk.init( dsn=SENTRY_DSN,