Add support for global API prefix env variable

This commit is contained in:
Weves 2023-12-07 12:32:45 -08:00 committed by Chris Weaver
parent 56785e6065
commit ddf3f99da4
2 changed files with 56 additions and 19 deletions

View File

@ -9,6 +9,10 @@ from danswer.configs.constants import DocumentIndexType
#####
APP_HOST = "0.0.0.0"
APP_PORT = 8080
# API_PREFIX is used to prepend a base path for all API routes
# generally used if using a reverse proxy which doesn't support stripping the `/api`
# prefix from requests directed towards the API server. In these cases, set this to `/api`
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
#####

View File

@ -1,6 +1,10 @@
from typing import Any
from typing import cast
import nltk # type:ignore
import torch
import uvicorn
from fastapi import APIRouter
from fastapi import FastAPI
from fastapi import Request
from fastapi.exceptions import RequestValidationError
@ -16,6 +20,7 @@ from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.chat.personas import load_personas_from_yaml
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import AUTH_TYPE
@ -87,47 +92,73 @@ def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:
)
def include_router_with_global_prefix_prepended(
application: FastAPI, router: APIRouter, **kwargs: Any
) -> None:
"""Adds the global prefix to all routes in the router."""
processed_global_prefix = f"/{APP_API_PREFIX.strip('/')}" if APP_API_PREFIX else ""
passed_in_prefix = cast(str | None, kwargs.get("prefix"))
if passed_in_prefix:
final_prefix = f"{processed_global_prefix}/{passed_in_prefix.strip('/')}"
else:
final_prefix = f"{processed_global_prefix}"
final_kwargs: dict[str, Any] = {
**kwargs,
"prefix": final_prefix,
}
application.include_router(router, **final_kwargs)
def get_application() -> FastAPI:
application = FastAPI(title="Danswer Backend", version=__version__)
application.include_router(backend_router)
application.include_router(chat_router)
application.include_router(admin_router)
application.include_router(user_router)
application.include_router(connector_router)
application.include_router(credential_router)
application.include_router(cc_pair_router)
application.include_router(document_set_router)
application.include_router(slack_bot_management_router)
application.include_router(persona_router)
application.include_router(state_router)
application.include_router(danswer_api_router)
include_router_with_global_prefix_prepended(application, backend_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, admin_router)
include_router_with_global_prefix_prepended(application, user_router)
include_router_with_global_prefix_prepended(application, connector_router)
include_router_with_global_prefix_prepended(application, credential_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(
application, slack_bot_management_router
)
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, state_router)
include_router_with_global_prefix_prepended(application, danswer_api_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step
pass
elif AUTH_TYPE == AuthType.BASIC:
application.include_router(
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
application.include_router(
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
application.include_router(
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
)
application.include_router(
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
)
application.include_router(
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
tags=["users"],
@ -135,7 +166,8 @@ def get_application() -> FastAPI:
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
application.include_router(
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_oauth_router(
oauth_client,
auth_backend,
@ -149,7 +181,8 @@ def get_application() -> FastAPI:
tags=["auth"],
)
# need basic auth router for `logout` endpoint
application.include_router(
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],