diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index 7f18d594298..417c5afd8ea 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -19,6 +19,7 @@ from danswer.utils.variable_functionality import global_version from ee.danswer.configs.app_configs import OPENID_CONFIG_URL from ee.danswer.server.analytics.api import router as analytics_router from ee.danswer.server.api_key.api import router as api_key_router +from ee.danswer.server.auth_check import check_ee_router_auth from ee.danswer.server.enterprise_settings.api import ( admin_router as enterprise_settings_admin_router, ) @@ -85,6 +86,10 @@ def get_ee_application() -> FastAPI: application, enterprise_settings_admin_router ) include_router_with_global_prefix_prepended(application, enterprise_settings_router) + + # Ensure all routes have auth enabled or are explicitly marked as public + check_ee_router_auth(application) + return application diff --git a/backend/ee/danswer/server/auth_check.py b/backend/ee/danswer/server/auth_check.py new file mode 100644 index 00000000000..d0ba3ffe46c --- /dev/null +++ b/backend/ee/danswer/server/auth_check.py @@ -0,0 +1,28 @@ +from fastapi import FastAPI + +from danswer.server.auth_check import check_router_auth +from danswer.server.auth_check import PUBLIC_ENDPOINT_SPECS + + +EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [ + # needs to be accessible prior to user login + ("/enterprise-settings", {"GET"}), + ("/enterprise-settings/logo", {"GET"}), + ("/enterprise-settings/custom-analytics-script", {"GET"}), + # oidc + ("/auth/oidc/authorize", {"GET"}), + ("/auth/oidc/callback", {"GET"}), + # saml + ("/auth/saml/authorize", {"GET"}), + ("/auth/saml/callback", {"POST"}), + ("/auth/saml/logout", {"POST"}), +] + + +def check_ee_router_auth( + application: FastAPI, + public_endpoint_specs: list[tuple[str, set[str]]] = EE_PUBLIC_ENDPOINT_SPECS, +) -> None: + # similar to the open source version of this function, but checking for the EE-only + # endpoints as well + check_router_auth(application, public_endpoint_specs)