diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 0e43e9754..e8afa0838 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -52,6 +52,7 @@ from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import get_default_llm_version
 from danswer.search.retrieval.search_runner import download_nltk_data
 from danswer.search.search_nlp_models import warm_up_encoders
+from danswer.server.auth_check import check_router_auth
 from danswer.server.danswer_api.ingestion import get_danswer_api_key
 from danswer.server.danswer_api.ingestion import router as danswer_api_router
 from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -353,6 +354,9 @@ def get_application() -> FastAPI:
         allow_headers=["*"],
     )
 
+    # Ensure all routes have auth enabled or are explicitly marked as public
+    check_router_auth(application)
+
     return application
 
 
diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py
index 092b755d9..411db5b0f 100644
--- a/backend/danswer/search/retrieval/search_runner.py
+++ b/backend/danswer/search/retrieval/search_runner.py
@@ -32,7 +32,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
 logger = setup_logger()
 
 
-def download_nltk_data():
+def download_nltk_data() -> None:
     resources = {
         "stopwords": "corpora/stopwords",
         "wordnet": "corpora/wordnet",
diff --git a/backend/danswer/server/auth_check.py b/backend/danswer/server/auth_check.py
new file mode 100644
index 000000000..54efc4956
--- /dev/null
+++ b/backend/danswer/server/auth_check.py
@@ -0,0 +1,87 @@
+from typing import cast
+
+from fastapi import FastAPI
+from fastapi.dependencies.models import Dependant
+
+from danswer.auth.users import current_admin_user
+from danswer.auth.users import current_user
+from danswer.server.danswer_api.ingestion import api_key_dep
+
+
+PUBLIC_ENDPOINT_SPECS = [
+    # built-in documentation functions
+    ("/openapi.json", {"GET", "HEAD"}),
+    ("/docs", {"GET", "HEAD"}),
+    ("/docs/oauth2-redirect", {"GET", "HEAD"}),
+    ("/redoc", {"GET", "HEAD"}),
+    # should always be callable, will just return 401 if not authenticated
+    ("/manage/me", {"GET"}),
+    # just returns 200 to validate that the server is up
+    ("/health", {"GET"}),
+    # just returns auth type, needs to be accessible before the user is logged
+    # in to determine what flow to give the user
+    ("/auth/type", {"GET"}),
+    # just gets the version of Danswer (e.g. 0.3.11)
+    ("/version", {"GET"}),
+    # stuff related to basic auth
+    ("/auth/register", {"POST"}),
+    ("/auth/login", {"POST"}),
+    ("/auth/logout", {"POST"}),
+    ("/auth/forgot-password", {"POST"}),
+    ("/auth/reset-password", {"POST"}),
+    ("/auth/request-verify-token", {"POST"}),
+    ("/auth/verify", {"POST"}),
+    ("/users/me", {"GET"}),
+    ("/users/me", {"PATCH"}),
+    ("/users/{id}", {"GET"}),
+    ("/users/{id}", {"PATCH"}),
+    ("/users/{id}", {"DELETE"}),
+    # oauth
+    ("/auth/oauth/authorize", {"GET"}),
+    ("/auth/oauth/callback", {"GET"}),
+]
+
+
+def check_router_auth(application: FastAPI) -> None:
+    """Ensures that all endpoints on the passed in application either
+    (1) have auth enabled OR
+    (2) are explicitly marked as a public endpoint
+
+    A route counts as authenticated when one of its direct dependencies is
+    current_user, current_admin_user or api_key_dep.
+
+    Raises a RuntimeError for the first route that is neither public nor
+    authenticated.
+    """
+    for route in application.routes:
+        # explicitly marked as public
+        if (
+            hasattr(route, "path")
+            and hasattr(route, "methods")
+            and (route.path, route.methods) in PUBLIC_ENDPOINT_SPECS
+        ):
+            continue
+
+        # check for auth
+        found_auth = False
+        # routes without a `dependant` attribute are treated as unauthenticated
+        route_dependant_obj = cast(Dependant | None, getattr(route, "dependant", None))
+        if route_dependant_obj:
+            for dependency in route_dependant_obj.dependencies:
+                depends_fn = dependency.cache_key[0]
+                if (
+                    depends_fn == current_user
+                    or depends_fn == current_admin_user
+                    or depends_fn == api_key_dep
+                ):
+                    found_auth = True
+                    break
+
+        if not found_auth:
+            # uncomment to print out all route(s) that are missing auth
+            # print(f"(\"{route.path}\", {set(route.methods)}),")
+
+            raise RuntimeError(
+                "Did not find current_user, current_admin_user or api_key_dep "
+                f"dependency in route - {route}"
+            )
diff --git a/backend/danswer/server/danswer_api/ingestion.py b/backend/danswer/server/danswer_api/ingestion.py
index 7fce8d1d3..2b0bed6c3 100644
--- a/backend/danswer/server/danswer_api/ingestion.py
+++ b/backend/danswer/server/danswer_api/ingestion.py
@@ -1,5 +1,4 @@
 import secrets
-from typing import cast
 
 from fastapi import APIRouter
 from fastapi import Depends
@@ -25,7 +24,6 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder
 from danswer.indexing.indexing_pipeline import build_indexing_pipeline
 from danswer.server.danswer_api.models import IngestionDocument
 from danswer.server.danswer_api.models import IngestionResult
-from danswer.server.models import ApiKey
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -69,26 +67,6 @@ def api_key_dep(authorization: str = Header(...)) -> str:
     return token
 
 
-# Provides a way to recover if the api key is deleted for some reason
-# Can also just restart the server to regenerate a new one
-def api_key_dep_if_exist(authorization: str | None = Header(None)) -> str | None:
-    token = authorization.removeprefix("Bearer ").strip() if authorization else None
-    saved_key = get_danswer_api_key(dont_regenerate=True)
-    if not saved_key:
-        return None
-
-    if token != saved_key:
-        raise HTTPException(status_code=401, detail="Invalid API key")
-
-    return token
-
-
-@router.post("/regenerate-key")
-def regenerate_key(_: str | None = Depends(api_key_dep_if_exist)) -> ApiKey:
-    delete_danswer_api_key()
-    return ApiKey(api_key=cast(str, get_danswer_api_key()))
-
-
 @router.post("/doc-ingestion")
 def document_ingestion(
     doc_info: IngestionDocument,