Add check to ensure auth is enabled for every endpoint unless explicitly whitelisted

This commit is contained in:
Weves 2024-04-18 21:57:01 -07:00 committed by Chris Weaver
parent e361e92230
commit 82b9cb4cc1
4 changed files with 86 additions and 23 deletions

View File

@ -52,6 +52,7 @@ from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_version
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import get_danswer_api_key
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@ -353,6 +354,9 @@ def get_application() -> FastAPI:
allow_headers=["*"],
)
# Ensure all routes have auth enabled or are explicitly marked as public
check_router_auth(application)
return application

View File

@ -32,7 +32,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
def download_nltk_data():
def download_nltk_data() -> None:
resources = {
"stopwords": "corpora/stopwords",
"wordnet": "corpora/wordnet",

View File

@ -0,0 +1,81 @@
from typing import cast
from fastapi import FastAPI
from fastapi.dependencies.models import Dependant
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.server.danswer_api.ingestion import api_key_dep
PUBLIC_ENDPOINT_SPECS = [
# built-in documentation functions
("/openapi.json", {"GET", "HEAD"}),
("/docs", {"GET", "HEAD"}),
("/docs/oauth2-redirect", {"GET", "HEAD"}),
("/redoc", {"GET", "HEAD"}),
# should always be callable, will just return 401 if not authenticated
("/manage/me", {"GET"}),
# just returns 200 to validate that the server is up
("/health", {"GET"}),
# just returns auth type, needs to be accessible before the user is logged
# in to determine what flow to give the user
("/auth/type", {"GET"}),
# just gets the version of Danswer (e.g. 0.3.11)
("/version", {"GET"}),
# stuff related to basic auth
("/auth/register", {"POST"}),
("/auth/login", {"POST"}),
("/auth/logout", {"POST"}),
("/auth/forgot-password", {"POST"}),
("/auth/reset-password", {"POST"}),
("/auth/request-verify-token", {"POST"}),
("/auth/verify", {"POST"}),
("/users/me", {"GET"}),
("/users/me", {"PATCH"}),
("/users/{id}", {"GET"}),
("/users/{id}", {"PATCH"}),
("/users/{id}", {"DELETE"}),
# oauth
("/auth/oauth/authorize", {"GET"}),
("/auth/oauth/callback", {"GET"}),
]
def check_router_auth(application: FastAPI) -> None:
"""Ensures that all endpoints on the passed in application either
(1) have auth enabled OR
(2) are explicitly marked as a public endpoint
"""
for route in application.routes:
# explicitly marked as public
if (
hasattr(route, "path")
and hasattr(route, "methods")
and (route.path, route.methods) in PUBLIC_ENDPOINT_SPECS
):
continue
# check for auth
found_auth = False
route_dependant_obj = cast(
Dependant | None, route.dependant if hasattr(route, "dependant") else None
)
if route_dependant_obj:
for dependency in route_dependant_obj.dependencies:
depends_fn = dependency.cache_key[0]
if (
depends_fn == current_user
or depends_fn == current_admin_user
or depends_fn == api_key_dep
):
found_auth = True
break
if not found_auth:
# uncomment to print out all route(s) that are missing auth
# print(f"(\"{route.path}\", {set(route.methods)}),")
raise RuntimeError(
f"Did not find current_user or current_admin_user dependency in route - {route}"
)

View File

@ -1,5 +1,4 @@
import secrets
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
@ -25,7 +24,6 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.server.danswer_api.models import IngestionDocument
from danswer.server.danswer_api.models import IngestionResult
from danswer.server.models import ApiKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -69,26 +67,6 @@ def api_key_dep(authorization: str = Header(...)) -> str:
return token
# Provides a way to recover if the api key is deleted for some reason
# Can also just restart the server to regenerate a new one
def api_key_dep_if_exist(authorization: str | None = Header(None)) -> str | None:
token = authorization.removeprefix("Bearer ").strip() if authorization else None
saved_key = get_danswer_api_key(dont_regenerate=True)
if not saved_key:
return None
if token != saved_key:
raise HTTPException(status_code=401, detail="Invalid API key")
return token
@router.post("/regenerate-key")
def regenerate_key(_: str | None = Depends(api_key_dep_if_exist)) -> ApiKey:
delete_danswer_api_key()
return ApiKey(api_key=cast(str, get_danswer_api_key()))
@router.post("/doc-ingestion")
def document_ingestion(
doc_info: IngestionDocument,