Auth specific rate limiting (#3463)

* k

* v1

* fully functional

* finalize

* nit

* nit

* nit

* clean up with wrapper + comments

* k

* update

* minor clean
This commit is contained in:
pablonyx
2024-12-29 18:34:23 -05:00
committed by GitHub
parent d14ef431a7
commit 27acd3387a
8 changed files with 149 additions and 24 deletions

View File

@@ -74,6 +74,9 @@ from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
@@ -153,6 +156,20 @@ def include_router_with_global_prefix_prepended(
application.include_router(router, **final_kwargs)
def include_auth_router_with_prefix(
application: FastAPI, router: APIRouter, prefix: str, tags: list[str] | None = None
) -> None:
"""Wrapper function to include an 'auth' router with prefix + rate-limiting dependencies."""
final_tags = tags or ["auth"]
include_router_with_global_prefix_prepended(
application,
router,
prefix=prefix,
tags=final_tags,
dependencies=get_auth_rate_limiters(),
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
# Set recursion limit
@@ -194,8 +211,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
setup_multitenant_onyx()
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
# Set up rate limiter
await setup_limiter()
yield
# Close rate limiter
await close_limiter()
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
status_code = getattr(exc, "status_code", 500)
@@ -283,42 +307,37 @@ def get_application() -> FastAPI:
pass
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
tags=["users"],
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
oauth_client,
@@ -330,15 +349,13 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
application.add_exception_handler(