mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-14 05:40:16 +02:00
129 lines
4.6 KiB
Python
129 lines
4.6 KiB
Python
from typing import cast
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.dependencies.models import Dependant
|
|
from starlette.routing import BaseRoute
|
|
|
|
from onyx.auth.users import current_admin_user
|
|
from onyx.auth.users import current_chat_accesssible_user
|
|
from onyx.auth.users import current_curator_or_admin_user
|
|
from onyx.auth.users import current_limited_user
|
|
from onyx.auth.users import current_user
|
|
from onyx.auth.users import current_user_with_expired_token
|
|
from onyx.configs.app_configs import APP_API_PREFIX
|
|
from onyx.server.onyx_api.ingestion import api_key_dep
|
|
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
|
|
|
|
|
# Allow-list of (path, methods) pairs for endpoints that are intentionally
# reachable without an auth dependency. Paths are written WITHOUT the global
# API prefix; a route's method set is compared for exact equality.
PUBLIC_ENDPOINT_SPECS: list[tuple[str, set[str]]] = [
    # --- built-in documentation endpoints ---
    ("/openapi.json", {"GET", "HEAD"}),
    ("/docs", {"GET", "HEAD"}),
    ("/docs/oauth2-redirect", {"GET", "HEAD"}),
    ("/redoc", {"GET", "HEAD"}),
    # must always be callable; simply answers 401 when not authenticated
    ("/me", {"GET"}),
    # liveness probe: only returns 200 to show the server is up
    ("/health", {"GET"}),
    # reports the auth type; has to be reachable before login so the client
    # can decide which login flow to present
    ("/auth/type", {"GET"}),
    # reports the Onyx version (e.g. 0.3.11)
    ("/version", {"GET"}),
    # --- basic auth endpoints ---
    ("/auth/register", {"POST"}),
    ("/auth/login", {"POST"}),
    ("/auth/logout", {"POST"}),
    ("/auth/forgot-password", {"POST"}),
    ("/auth/reset-password", {"POST"}),
    ("/auth/request-verify-token", {"POST"}),
    ("/auth/verify", {"POST"}),
    ("/users/me", {"GET"}),
    ("/users/me", {"PATCH"}),
    ("/users/{id}", {"GET"}),
    ("/users/{id}", {"PATCH"}),
    ("/users/{id}", {"DELETE"}),
    # --- oauth endpoints ---
    ("/auth/oauth/authorize", {"GET"}),
    ("/auth/oauth/callback", {"GET"}),
    # --- anonymous user on cloud ---
    ("/tenants/anonymous-user", {"POST"}),
]
|
|
|
|
|
|
def is_route_in_spec_list(
    route: BaseRoute, public_endpoint_specs: list[tuple[str, set[str]]]
) -> bool:
    """Report whether ``route`` matches one of the given (path, methods) specs.

    A route matches when its path and its exact set of methods equal a spec,
    either verbatim or with ``APP_API_PREFIX`` prepended to the spec's path.
    Both forms are tried because some endpoints (e.g. /openapi.json) are not
    served under the prefix.

    Args:
        route: The route to test. Objects lacking a ``path`` or ``methods``
            attribute never match.
        public_endpoint_specs: Candidate ``(path, methods)`` pairs, with
            paths written without the global prefix.

    Returns:
        True if the route matches any spec, False otherwise.
    """
    if not (hasattr(route, "path") and hasattr(route, "methods")):
        return False

    route_path = route.path
    route_methods = route.methods

    # First attempt: the spec path exactly as written (no prefix).
    if (route_path, route_methods) in public_endpoint_specs:
        return True

    # Second attempt only makes sense when a global prefix is configured.
    if not APP_API_PREFIX:
        return False
    global_prefix = f"/{APP_API_PREFIX.strip('/')}"

    return any(
        f"{global_prefix}/{spec_path.strip('/')}" == route_path
        and route_methods == spec_methods
        for spec_path, spec_methods in public_endpoint_specs
    )
|
|
|
|
|
|
def check_router_auth(
    application: FastAPI,
    public_endpoint_specs: list[tuple[str, set[str]]] = PUBLIC_ENDPOINT_SPECS,
) -> None:
    """Verify that every route on ``application`` is protected or public.

    Each route must either
    (1) declare one of the recognised auth dependencies, OR
    (2) be explicitly listed in ``public_endpoint_specs``.

    Args:
        application: The FastAPI app whose routes are inspected.
        public_endpoint_specs: ``(path, methods)`` pairs that are allowed to
            have no auth dependency.

    Raises:
        RuntimeError: On the first route that is neither public nor guarded
            by a recognised auth dependency.
    """
    # The last two entries are looked up via fetch_ee_implementation_or_noop
    # rather than imported directly.
    recognised_auth_deps = (
        current_limited_user,
        current_user,
        current_admin_user,
        current_curator_or_admin_user,
        api_key_dep,
        current_user_with_expired_token,
        current_chat_accesssible_user,
        fetch_ee_implementation_or_noop(
            "onyx.server.tenants.access", "control_plane_dep"
        ),
        fetch_ee_implementation_or_noop("onyx.auth.users", "current_cloud_superuser"),
    )

    for route in application.routes:
        # Explicitly public endpoints need no auth dependency.
        if is_route_in_spec_list(route, public_endpoint_specs):
            continue

        # Not every route type carries a `dependant`; treat a missing one as
        # having no dependencies at all.
        dependant = cast(Dependant | None, getattr(route, "dependant", None))
        direct_dependencies = dependant.dependencies if dependant else []

        # The first element of a dependency's cache key is the callable that
        # was passed to Depends(...).
        is_guarded = any(
            dep.cache_key[0] in recognised_auth_deps for dep in direct_dependencies
        )
        if is_guarded:
            continue

        # uncomment to print out all route(s) that are missing auth
        # print(f"(\"{route.path}\", {set(route.methods)}),")

        raise RuntimeError(f"Did not find user dependency in private route - {route}")
|