mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 03:48:14 +02:00
Auth specific rate limiting (#3463)
* k * v1 * fully functional * finalize * nit * nit * nit * clean up with wrapper + comments * k * update * minor clean
This commit is contained in:
parent
d14ef431a7
commit
27acd3387a
@ -40,6 +40,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.main import get_application as get_application_base
|
||||
from onyx.main import include_auth_router_with_prefix
|
||||
from onyx.main import include_router_with_global_prefix_prepended
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@ -62,7 +63,7 @@ def get_application() -> FastAPI:
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
oauth_client,
|
||||
@ -74,19 +75,17 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
|
||||
@ -97,19 +96,21 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
|
||||
),
|
||||
prefix="/auth/oidc",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
elif AUTH_TYPE == AuthType.SAML:
|
||||
include_router_with_global_prefix_prepended(application, saml_router)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
saml_router,
|
||||
prefix="/auth/saml",
|
||||
)
|
||||
|
||||
# RBAC / group access control
|
||||
include_router_with_global_prefix_prepended(application, user_group_router)
|
||||
|
@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
|
||||
the files in the existing huggingface cache that don't exist in the temp
|
||||
huggingface cache.
|
||||
"""
|
||||
|
||||
for item in source.iterdir():
|
||||
target_path = dest / item.relative_to(source)
|
||||
if item.is_dir():
|
||||
|
@ -185,6 +185,25 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
# Rate limiting for auth endpoints
|
||||
|
||||
|
||||
RATE_LIMIT_WINDOW_SECONDS: int | None = None
|
||||
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
|
||||
if _rate_limit_window_seconds_str is not None:
|
||||
try:
|
||||
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
RATE_LIMIT_MAX_REQUESTS: int | None = None
|
||||
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
|
||||
if _rate_limit_max_requests_str is not None:
|
||||
try:
|
||||
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Used for general redis things
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
|
@ -74,6 +74,9 @@ from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
|
||||
from onyx.server.middleware.rate_limiting import close_limiter
|
||||
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
|
||||
from onyx.server.middleware.rate_limiting import setup_limiter
|
||||
from onyx.server.onyx_api.ingestion import router as onyx_api_router
|
||||
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
|
||||
get_full_openai_assistants_api_router,
|
||||
@ -153,6 +156,20 @@ def include_router_with_global_prefix_prepended(
|
||||
application.include_router(router, **final_kwargs)
|
||||
|
||||
|
||||
def include_auth_router_with_prefix(
|
||||
application: FastAPI, router: APIRouter, prefix: str, tags: list[str] | None = None
|
||||
) -> None:
|
||||
"""Wrapper function to include an 'auth' router with prefix + rate-limiting dependencies."""
|
||||
final_tags = tags or ["auth"]
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
router,
|
||||
prefix=prefix,
|
||||
tags=final_tags,
|
||||
dependencies=get_auth_rate_limiters(),
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
# Set recursion limit
|
||||
@ -194,8 +211,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
setup_multitenant_onyx()
|
||||
|
||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||
|
||||
# Set up rate limiter
|
||||
await setup_limiter()
|
||||
|
||||
yield
|
||||
|
||||
# Close rate limiter
|
||||
await close_limiter()
|
||||
|
||||
|
||||
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||
status_code = getattr(exc, "status_code", 500)
|
||||
@ -283,42 +307,37 @@ def get_application() -> FastAPI:
|
||||
pass
|
||||
|
||||
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_reset_password_router(),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_verify_router(UserRead),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_users_router(UserRead, UserUpdate),
|
||||
prefix="/users",
|
||||
tags=["users"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
oauth_client,
|
||||
@ -330,15 +349,13 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
application.add_exception_handler(
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
@ -5,6 +6,7 @@ from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import redis
|
||||
from redis import asyncio as aioredis
|
||||
from redis.client import Redis
|
||||
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
@ -196,3 +198,33 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
|
||||
# redis_client.set('key', 'value')
|
||||
# value = redis_client.get('key')
|
||||
# print(value.decode()) # Output: 'value'
|
||||
|
||||
_async_redis_connection = None
|
||||
_async_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_async_redis_connection() -> aioredis.Redis:
|
||||
"""
|
||||
Provides a shared async Redis connection, using the same configs (host, port, SSL, etc.).
|
||||
Ensures that the connection is created only once (lazily) and reused for all future calls.
|
||||
"""
|
||||
global _async_redis_connection
|
||||
|
||||
# If we haven't yet created an async Redis connection, we need to create one
|
||||
if _async_redis_connection is None:
|
||||
# Acquire the lock to ensure that only one coroutine attempts to create the connection
|
||||
async with _async_lock:
|
||||
# Double-check inside the lock to avoid race conditions
|
||||
if _async_redis_connection is None:
|
||||
scheme = "rediss" if REDIS_SSL else "redis"
|
||||
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"
|
||||
|
||||
# Create a new Redis connection (or connection pool) from the URL
|
||||
_async_redis_connection = aioredis.from_url(
|
||||
url,
|
||||
password=REDIS_PASSWORD,
|
||||
max_connections=REDIS_POOL_MAX_CONNECTIONS,
|
||||
)
|
||||
|
||||
# Return the established connection (or pool) for all future operations
|
||||
return _async_redis_connection
|
||||
|
47
backend/onyx/server/middleware/rate_limiting.py
Normal file
47
backend/onyx/server/middleware/rate_limiting.py
Normal file
@ -0,0 +1,47 @@
|
||||
from collections.abc import Callable
|
||||
from typing import List
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import Request
|
||||
from fastapi_limiter import FastAPILimiter
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
|
||||
from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
|
||||
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
|
||||
|
||||
async def setup_limiter() -> None:
|
||||
# Use the centralized async Redis connection
|
||||
redis = await get_async_redis_connection()
|
||||
await FastAPILimiter.init(redis)
|
||||
|
||||
|
||||
async def close_limiter() -> None:
|
||||
# This closes the FastAPILimiter connection so we don't leave open connections to Redis.
|
||||
await FastAPILimiter.close()
|
||||
|
||||
|
||||
async def rate_limit_key(request: Request) -> str:
|
||||
# Uses both IP and User-Agent to make collisions less likely if IP is behind NAT.
|
||||
# If request.client is None, a fallback is used to avoid completely unknown keys.
|
||||
# This helps ensure we have a unique key for each 'user' in simple scenarios.
|
||||
ip_part = request.client.host if request.client else "unknown"
|
||||
ua_part = request.headers.get("user-agent", "none").replace(" ", "_")
|
||||
return f"{ip_part}-{ua_part}"
|
||||
|
||||
|
||||
def get_auth_rate_limiters() -> List[Callable]:
|
||||
if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS):
|
||||
return []
|
||||
|
||||
return [
|
||||
Depends(
|
||||
RateLimiter(
|
||||
times=RATE_LIMIT_MAX_REQUESTS,
|
||||
seconds=RATE_LIMIT_WINDOW_SECONDS,
|
||||
# Use the custom key function to distinguish users
|
||||
identifier=rate_limit_key,
|
||||
)
|
||||
)
|
||||
]
|
@ -81,4 +81,5 @@ stripe==10.12.0
|
||||
urllib3==2.2.3
|
||||
mistune==0.8.4
|
||||
sentry-sdk==2.14.0
|
||||
prometheus_client==0.21.0
|
||||
prometheus_client==0.21.0
|
||||
fastapi-limiter==0.1.6
|
@ -59,10 +59,14 @@ export function EmailPasswordForm({
|
||||
errorMsg =
|
||||
"An account already exists with the specified email.";
|
||||
}
|
||||
if (response.status === 429) {
|
||||
errorMsg = "Too many requests. Please try again later.";
|
||||
}
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to sign up - ${errorMsg}`,
|
||||
});
|
||||
setIsWorking(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -89,6 +93,9 @@ export function EmailPasswordForm({
|
||||
} else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") {
|
||||
errorMsg = "Create an account to set a password";
|
||||
}
|
||||
if (loginResponse.status === 429) {
|
||||
errorMsg = "Too many requests. Please try again later.";
|
||||
}
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to login - ${errorMsg}`,
|
||||
|
Loading…
x
Reference in New Issue
Block a user