mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-06 13:09:39 +02:00
Add check to ensure auth is enabled for every endpoint unless explicitly whitelisted
This commit is contained in:
parent
e361e92230
commit
82b9cb4cc1
@ -52,6 +52,7 @@ from danswer.llm.factory import get_default_llm
|
|||||||
from danswer.llm.utils import get_default_llm_version
|
from danswer.llm.utils import get_default_llm_version
|
||||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||||
from danswer.search.search_nlp_models import warm_up_encoders
|
from danswer.search.search_nlp_models import warm_up_encoders
|
||||||
|
from danswer.server.auth_check import check_router_auth
|
||||||
from danswer.server.danswer_api.ingestion import get_danswer_api_key
|
from danswer.server.danswer_api.ingestion import get_danswer_api_key
|
||||||
from danswer.server.danswer_api.ingestion import router as danswer_api_router
|
from danswer.server.danswer_api.ingestion import router as danswer_api_router
|
||||||
from danswer.server.documents.cc_pair import router as cc_pair_router
|
from danswer.server.documents.cc_pair import router as cc_pair_router
|
||||||
@ -353,6 +354,9 @@ def get_application() -> FastAPI:
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||||
|
check_router_auth(application)
|
||||||
|
|
||||||
return application
|
return application
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
|||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
def download_nltk_data():
|
def download_nltk_data() -> None:
|
||||||
resources = {
|
resources = {
|
||||||
"stopwords": "corpora/stopwords",
|
"stopwords": "corpora/stopwords",
|
||||||
"wordnet": "corpora/wordnet",
|
"wordnet": "corpora/wordnet",
|
||||||
|
81
backend/danswer/server/auth_check.py
Normal file
81
backend/danswer/server/auth_check.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.dependencies.models import Dependant
|
||||||
|
|
||||||
|
from danswer.auth.users import current_admin_user
|
||||||
|
from danswer.auth.users import current_user
|
||||||
|
from danswer.server.danswer_api.ingestion import api_key_dep
|
||||||
|
|
||||||
|
|
||||||
|
PUBLIC_ENDPOINT_SPECS = [
|
||||||
|
# built-in documentation functions
|
||||||
|
("/openapi.json", {"GET", "HEAD"}),
|
||||||
|
("/docs", {"GET", "HEAD"}),
|
||||||
|
("/docs/oauth2-redirect", {"GET", "HEAD"}),
|
||||||
|
("/redoc", {"GET", "HEAD"}),
|
||||||
|
# should always be callable, will just return 401 if not authenticated
|
||||||
|
("/manage/me", {"GET"}),
|
||||||
|
# just returns 200 to validate that the server is up
|
||||||
|
("/health", {"GET"}),
|
||||||
|
# just returns auth type, needs to be accessible before the user is logged
|
||||||
|
# in to determine what flow to give the user
|
||||||
|
("/auth/type", {"GET"}),
|
||||||
|
# just gets the version of Danswer (e.g. 0.3.11)
|
||||||
|
("/version", {"GET"}),
|
||||||
|
# stuff related to basic auth
|
||||||
|
("/auth/register", {"POST"}),
|
||||||
|
("/auth/login", {"POST"}),
|
||||||
|
("/auth/logout", {"POST"}),
|
||||||
|
("/auth/forgot-password", {"POST"}),
|
||||||
|
("/auth/reset-password", {"POST"}),
|
||||||
|
("/auth/request-verify-token", {"POST"}),
|
||||||
|
("/auth/verify", {"POST"}),
|
||||||
|
("/users/me", {"GET"}),
|
||||||
|
("/users/me", {"PATCH"}),
|
||||||
|
("/users/{id}", {"GET"}),
|
||||||
|
("/users/{id}", {"PATCH"}),
|
||||||
|
("/users/{id}", {"DELETE"}),
|
||||||
|
# oauth
|
||||||
|
("/auth/oauth/authorize", {"GET"}),
|
||||||
|
("/auth/oauth/callback", {"GET"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def check_router_auth(application: FastAPI) -> None:
|
||||||
|
"""Ensures that all endpoints on the passed in application either
|
||||||
|
(1) have auth enabled OR
|
||||||
|
(2) are explicitly marked as a public endpoint
|
||||||
|
"""
|
||||||
|
for route in application.routes:
|
||||||
|
# explicitly marked as public
|
||||||
|
if (
|
||||||
|
hasattr(route, "path")
|
||||||
|
and hasattr(route, "methods")
|
||||||
|
and (route.path, route.methods) in PUBLIC_ENDPOINT_SPECS
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# check for auth
|
||||||
|
found_auth = False
|
||||||
|
route_dependant_obj = cast(
|
||||||
|
Dependant | None, route.dependant if hasattr(route, "dependant") else None
|
||||||
|
)
|
||||||
|
if route_dependant_obj:
|
||||||
|
for dependency in route_dependant_obj.dependencies:
|
||||||
|
depends_fn = dependency.cache_key[0]
|
||||||
|
if (
|
||||||
|
depends_fn == current_user
|
||||||
|
or depends_fn == current_admin_user
|
||||||
|
or depends_fn == api_key_dep
|
||||||
|
):
|
||||||
|
found_auth = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found_auth:
|
||||||
|
# uncomment to print out all route(s) that are missing auth
|
||||||
|
# print(f"(\"{route.path}\", {set(route.methods)}),")
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Did not find current_user or current_admin_user dependency in route - {route}"
|
||||||
|
)
|
@ -1,5 +1,4 @@
|
|||||||
import secrets
|
import secrets
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
@ -25,7 +24,6 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder
|
|||||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||||
from danswer.server.danswer_api.models import IngestionDocument
|
from danswer.server.danswer_api.models import IngestionDocument
|
||||||
from danswer.server.danswer_api.models import IngestionResult
|
from danswer.server.danswer_api.models import IngestionResult
|
||||||
from danswer.server.models import ApiKey
|
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@ -69,26 +67,6 @@ def api_key_dep(authorization: str = Header(...)) -> str:
|
|||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
# Provides a way to recover if the api key is deleted for some reason
|
|
||||||
# Can also just restart the server to regenerate a new one
|
|
||||||
def api_key_dep_if_exist(authorization: str | None = Header(None)) -> str | None:
|
|
||||||
token = authorization.removeprefix("Bearer ").strip() if authorization else None
|
|
||||||
saved_key = get_danswer_api_key(dont_regenerate=True)
|
|
||||||
if not saved_key:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if token != saved_key:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
|
||||||
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/regenerate-key")
|
|
||||||
def regenerate_key(_: str | None = Depends(api_key_dep_if_exist)) -> ApiKey:
|
|
||||||
delete_danswer_api_key()
|
|
||||||
return ApiKey(api_key=cast(str, get_danswer_api_key()))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/doc-ingestion")
|
@router.post("/doc-ingestion")
|
||||||
def document_ingestion(
|
def document_ingestion(
|
||||||
doc_info: IngestionDocument,
|
doc_info: IngestionDocument,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user