mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-31 10:10:21 +02:00
Add check to ensure auth is enabled for every endpoint unless explicitly whitelisted
This commit is contained in:
parent
e361e92230
commit
82b9cb4cc1
@ -52,6 +52,7 @@ from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.utils import get_default_llm_version
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.search.search_nlp_models import warm_up_encoders
|
||||
from danswer.server.auth_check import check_router_auth
|
||||
from danswer.server.danswer_api.ingestion import get_danswer_api_key
|
||||
from danswer.server.danswer_api.ingestion import router as danswer_api_router
|
||||
from danswer.server.documents.cc_pair import router as cc_pair_router
|
||||
@ -353,6 +354,9 @@ def get_application() -> FastAPI:
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_router_auth(application)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
|
@ -32,7 +32,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def download_nltk_data():
|
||||
def download_nltk_data() -> None:
|
||||
resources = {
|
||||
"stopwords": "corpora/stopwords",
|
||||
"wordnet": "corpora/wordnet",
|
||||
|
81
backend/danswer/server/auth_check.py
Normal file
81
backend/danswer/server/auth_check.py
Normal file
@ -0,0 +1,81 @@
|
||||
from typing import cast
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.dependencies.models import Dependant
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.server.danswer_api.ingestion import api_key_dep
|
||||
|
||||
|
||||
PUBLIC_ENDPOINT_SPECS = [
|
||||
# built-in documentation functions
|
||||
("/openapi.json", {"GET", "HEAD"}),
|
||||
("/docs", {"GET", "HEAD"}),
|
||||
("/docs/oauth2-redirect", {"GET", "HEAD"}),
|
||||
("/redoc", {"GET", "HEAD"}),
|
||||
# should always be callable, will just return 401 if not authenticated
|
||||
("/manage/me", {"GET"}),
|
||||
# just returns 200 to validate that the server is up
|
||||
("/health", {"GET"}),
|
||||
# just returns auth type, needs to be accessible before the user is logged
|
||||
# in to determine what flow to give the user
|
||||
("/auth/type", {"GET"}),
|
||||
# just gets the version of Danswer (e.g. 0.3.11)
|
||||
("/version", {"GET"}),
|
||||
# stuff related to basic auth
|
||||
("/auth/register", {"POST"}),
|
||||
("/auth/login", {"POST"}),
|
||||
("/auth/logout", {"POST"}),
|
||||
("/auth/forgot-password", {"POST"}),
|
||||
("/auth/reset-password", {"POST"}),
|
||||
("/auth/request-verify-token", {"POST"}),
|
||||
("/auth/verify", {"POST"}),
|
||||
("/users/me", {"GET"}),
|
||||
("/users/me", {"PATCH"}),
|
||||
("/users/{id}", {"GET"}),
|
||||
("/users/{id}", {"PATCH"}),
|
||||
("/users/{id}", {"DELETE"}),
|
||||
# oauth
|
||||
("/auth/oauth/authorize", {"GET"}),
|
||||
("/auth/oauth/callback", {"GET"}),
|
||||
]
|
||||
|
||||
|
||||
def check_router_auth(application: FastAPI) -> None:
|
||||
"""Ensures that all endpoints on the passed in application either
|
||||
(1) have auth enabled OR
|
||||
(2) are explicitly marked as a public endpoint
|
||||
"""
|
||||
for route in application.routes:
|
||||
# explicitly marked as public
|
||||
if (
|
||||
hasattr(route, "path")
|
||||
and hasattr(route, "methods")
|
||||
and (route.path, route.methods) in PUBLIC_ENDPOINT_SPECS
|
||||
):
|
||||
continue
|
||||
|
||||
# check for auth
|
||||
found_auth = False
|
||||
route_dependant_obj = cast(
|
||||
Dependant | None, route.dependant if hasattr(route, "dependant") else None
|
||||
)
|
||||
if route_dependant_obj:
|
||||
for dependency in route_dependant_obj.dependencies:
|
||||
depends_fn = dependency.cache_key[0]
|
||||
if (
|
||||
depends_fn == current_user
|
||||
or depends_fn == current_admin_user
|
||||
or depends_fn == api_key_dep
|
||||
):
|
||||
found_auth = True
|
||||
break
|
||||
|
||||
if not found_auth:
|
||||
# uncomment to print out all route(s) that are missing auth
|
||||
# print(f"(\"{route.path}\", {set(route.methods)}),")
|
||||
|
||||
raise RuntimeError(
|
||||
f"Did not find current_user or current_admin_user dependency in route - {route}"
|
||||
)
|
@ -1,5 +1,4 @@
|
||||
import secrets
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
@ -25,7 +24,6 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.server.danswer_api.models import IngestionDocument
|
||||
from danswer.server.danswer_api.models import IngestionResult
|
||||
from danswer.server.models import ApiKey
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -69,26 +67,6 @@ def api_key_dep(authorization: str = Header(...)) -> str:
|
||||
return token
|
||||
|
||||
|
||||
# Provides a way to recover if the api key is deleted for some reason
|
||||
# Can also just restart the server to regenerate a new one
|
||||
def api_key_dep_if_exist(authorization: str | None = Header(None)) -> str | None:
|
||||
token = authorization.removeprefix("Bearer ").strip() if authorization else None
|
||||
saved_key = get_danswer_api_key(dont_regenerate=True)
|
||||
if not saved_key:
|
||||
return None
|
||||
|
||||
if token != saved_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
@router.post("/regenerate-key")
|
||||
def regenerate_key(_: str | None = Depends(api_key_dep_if_exist)) -> ApiKey:
|
||||
delete_danswer_api_key()
|
||||
return ApiKey(api_key=cast(str, get_danswer_api_key()))
|
||||
|
||||
|
||||
@router.post("/doc-ingestion")
|
||||
def document_ingestion(
|
||||
doc_info: IngestionDocument,
|
||||
|
Loading…
x
Reference in New Issue
Block a user