danswer/backend/onyx/server/auth_check.py
pablonyx d9e9c6973d
Multitenant anonymous (#3595)
* anonymous users for multi tenant setting

* nit

* k
2025-01-07 02:57:20 +00:00

129 lines
4.6 KiB
Python

from typing import cast
from fastapi import FastAPI
from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.auth.users import current_user_with_expired_token
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.server.onyx_api.ingestion import api_key_dep
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
# Endpoints that may be registered without an auth dependency. Each entry is a
# (path, allowed-methods) pair; a route only matches when both its path and its
# exact method set equal an entry.
PUBLIC_ENDPOINT_SPECS: list[tuple[str, set[str]]] = [
    # built-in documentation endpoints
    ("/openapi.json", {"GET", "HEAD"}),
    ("/docs", {"GET", "HEAD"}),
    ("/docs/oauth2-redirect", {"GET", "HEAD"}),
    ("/redoc", {"GET", "HEAD"}),
    # always callable; simply answers 401 when the caller is not authenticated
    ("/me", {"GET"}),
    # liveness probe - returns 200 to show that the server is up
    ("/health", {"GET"}),
    # reports the auth type only; must be reachable before login so the
    # client can decide which login flow to present
    ("/auth/type", {"GET"}),
    # reports the running Onyx version (e.g. 0.3.11)
    ("/version", {"GET"}),
    # basic auth flows
    ("/auth/register", {"POST"}),
    ("/auth/login", {"POST"}),
    ("/auth/logout", {"POST"}),
    ("/auth/forgot-password", {"POST"}),
    ("/auth/reset-password", {"POST"}),
    ("/auth/request-verify-token", {"POST"}),
    ("/auth/verify", {"POST"}),
    ("/users/me", {"GET"}),
    ("/users/me", {"PATCH"}),
    ("/users/{id}", {"GET"}),
    ("/users/{id}", {"PATCH"}),
    ("/users/{id}", {"DELETE"}),
    # oauth
    ("/auth/oauth/authorize", {"GET"}),
    ("/auth/oauth/callback", {"GET"}),
    # anonymous user on cloud
    ("/tenants/anonymous-user", {"POST"}),
]
def is_route_in_spec_list(
    route: BaseRoute, public_endpoint_specs: list[tuple[str, set[str]]]
) -> bool:
    """Return whether `route` matches one of the given (path, methods) specs.

    A route matches when its path and its method set both equal a spec entry,
    either verbatim or with the global `APP_API_PREFIX` prepended to the
    spec's path. Routes lacking a `path` or `methods` attribute never match.
    """
    if not (hasattr(route, "path") and hasattr(route, "methods")):
        return False

    route_path = route.path
    route_methods = route.methods

    # Some endpoints are registered without the global prefix
    # (e.g. /openapi.json), so try the un-prefixed form first.
    if (route_path, route_methods) in public_endpoint_specs:
        return True

    # Without a configured prefix there is no second form to try.
    if not APP_API_PREFIX:
        return False
    global_prefix = f"/{APP_API_PREFIX.strip('/')}"

    return any(
        f"{global_prefix}/{spec_path.strip('/')}" == route_path
        and route_methods == spec_methods
        for spec_path, spec_methods in public_endpoint_specs
    )
def check_router_auth(
    application: FastAPI,
    public_endpoint_specs: list[tuple[str, set[str]]] = PUBLIC_ENDPOINT_SPECS,
) -> None:
    """Verify that every route of `application` is either protected or public.

    A route passes when it is listed in `public_endpoint_specs`, or when one
    of the entries in its dependant's `dependencies` is a recognised
    auth dependency.

    Raises:
        RuntimeError: on the first route that is neither public nor guarded
            by one of the recognised auth dependencies.
    """
    control_plane_dep = fetch_ee_implementation_or_noop(
        "onyx.server.tenants.access", "control_plane_dep"
    )
    current_cloud_superuser = fetch_ee_implementation_or_noop(
        "onyx.auth.users", "current_cloud_superuser"
    )

    # Every dependency callable that counts as "auth is enforced".
    recognised_auth_deps = (
        current_limited_user,
        current_user,
        current_admin_user,
        current_curator_or_admin_user,
        api_key_dep,
        current_user_with_expired_token,
        current_chat_accesssible_user,
        control_plane_dep,
        current_cloud_superuser,
    )

    for route in application.routes:
        # explicitly whitelisted as public - nothing to verify
        if is_route_in_spec_list(route, public_endpoint_specs):
            continue

        # not every route type carries a dependant
        dependant = cast(Dependant | None, getattr(route, "dependant", None))

        has_auth = False
        if dependant:
            for dependency in dependant.dependencies:
                # the first element of cache_key is the dependency callable
                dependency_fn = dependency.cache_key[0]
                if any(dependency_fn == auth_dep for auth_dep in recognised_auth_deps):
                    has_auth = True
                    break

        if has_auth:
            continue

        # uncomment to print out all route(s) that are missing auth
        # print(f"(\"{route.path}\", {set(route.methods)}),")
        raise RuntimeError(
            f"Did not find user dependency in private route - {route}"
        )