Auth specific rate limiting (#3463)

* k

* v1

* fully functional

* finalize

* nit

* nit

* nit

* clean up with wrapper + comments

* k

* update

* minor clean
This commit is contained in:
pablonyx 2024-12-29 18:34:23 -05:00 committed by GitHub
parent d14ef431a7
commit 27acd3387a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 149 additions and 24 deletions

View File

@ -40,6 +40,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base
from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
@ -62,7 +63,7 @@ def get_application() -> FastAPI:
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
oauth_client,
@ -74,19 +75,17 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
@ -97,19 +96,21 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
),
prefix="/auth/oidc",
tags=["auth"],
)
# need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
elif AUTH_TYPE == AuthType.SAML:
include_router_with_global_prefix_prepended(application, saml_router)
include_auth_router_with_prefix(
application,
saml_router,
prefix="/auth/saml",
)
# RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router)

View File

@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
the files in the existing huggingface cache that don't exist in the temp
huggingface cache.
"""
for item in source.iterdir():
target_path = dest / item.relative_to(source)
if item.is_dir():

View File

@ -185,6 +185,25 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
try:
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
except ValueError:
pass
RATE_LIMIT_MAX_REQUESTS: int | None = None
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
if _rate_limit_max_requests_str is not None:
try:
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
except ValueError:
pass
# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))

View File

@ -74,6 +74,9 @@ from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
@ -153,6 +156,20 @@ def include_router_with_global_prefix_prepended(
application.include_router(router, **final_kwargs)
def include_auth_router_with_prefix(
application: FastAPI, router: APIRouter, prefix: str, tags: list[str] | None = None
) -> None:
"""Wrapper function to include an 'auth' router with prefix + rate-limiting dependencies."""
final_tags = tags or ["auth"]
include_router_with_global_prefix_prepended(
application,
router,
prefix=prefix,
tags=final_tags,
dependencies=get_auth_rate_limiters(),
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
# Set recursion limit
@ -194,8 +211,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
setup_multitenant_onyx()
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
# Set up rate limiter
await setup_limiter()
yield
# Close rate limiter
await close_limiter()
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
status_code = getattr(exc, "status_code", 500)
@ -283,42 +307,37 @@ def get_application() -> FastAPI:
pass
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
tags=["users"],
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
oauth_client,
@ -330,15 +349,13 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
application.add_exception_handler(

View File

@ -1,3 +1,4 @@
import asyncio
import functools
import threading
from collections.abc import Callable
@ -5,6 +6,7 @@ from typing import Any
from typing import Optional
import redis
from redis import asyncio as aioredis
from redis.client import Redis
from onyx.configs.app_configs import REDIS_DB_NUMBER
@ -196,3 +198,33 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
# redis_client.set('key', 'value')
# value = redis_client.get('key')
# print(value.decode()) # Output: 'value'
_async_redis_connection = None
_async_lock = asyncio.Lock()
async def get_async_redis_connection() -> aioredis.Redis:
"""
Provides a shared async Redis connection, using the same configs (host, port, SSL, etc.).
Ensures that the connection is created only once (lazily) and reused for all future calls.
"""
global _async_redis_connection
# If we haven't yet created an async Redis connection, we need to create one
if _async_redis_connection is None:
# Acquire the lock to ensure that only one coroutine attempts to create the connection
async with _async_lock:
# Double-check inside the lock to avoid race conditions
if _async_redis_connection is None:
scheme = "rediss" if REDIS_SSL else "redis"
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"
# Create a new Redis connection (or connection pool) from the URL
_async_redis_connection = aioredis.from_url(
url,
password=REDIS_PASSWORD,
max_connections=REDIS_POOL_MAX_CONNECTIONS,
)
# Return the established connection (or pool) for all future operations
return _async_redis_connection

View File

@ -0,0 +1,47 @@
from collections.abc import Callable
from typing import List
from fastapi import Depends
from fastapi import Request
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
from onyx.redis.redis_pool import get_async_redis_connection
async def setup_limiter() -> None:
# Use the centralized async Redis connection
redis = await get_async_redis_connection()
await FastAPILimiter.init(redis)
async def close_limiter() -> None:
# This closes the FastAPILimiter connection so we don't leave open connections to Redis.
await FastAPILimiter.close()
async def rate_limit_key(request: Request) -> str:
# Uses both IP and User-Agent to make collisions less likely if IP is behind NAT.
# If request.client is None, a fallback is used to avoid completely unknown keys.
# This helps ensure we have a unique key for each 'user' in simple scenarios.
ip_part = request.client.host if request.client else "unknown"
ua_part = request.headers.get("user-agent", "none").replace(" ", "_")
return f"{ip_part}-{ua_part}"
def get_auth_rate_limiters() -> List[Callable]:
if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS):
return []
return [
Depends(
RateLimiter(
times=RATE_LIMIT_MAX_REQUESTS,
seconds=RATE_LIMIT_WINDOW_SECONDS,
# Use the custom key function to distinguish users
identifier=rate_limit_key,
)
)
]

View File

@ -81,4 +81,5 @@ stripe==10.12.0
urllib3==2.2.3
mistune==0.8.4
sentry-sdk==2.14.0
prometheus_client==0.21.0
prometheus_client==0.21.0
fastapi-limiter==0.1.6

View File

@ -59,10 +59,14 @@ export function EmailPasswordForm({
errorMsg =
"An account already exists with the specified email.";
}
if (response.status === 429) {
errorMsg = "Too many requests. Please try again later.";
}
setPopup({
type: "error",
message: `Failed to sign up - ${errorMsg}`,
});
setIsWorking(false);
return;
}
}
@ -89,6 +93,9 @@ export function EmailPasswordForm({
} else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") {
errorMsg = "Create an account to set a password";
}
if (loginResponse.status === 429) {
errorMsg = "Too many requests. Please try again later.";
}
setPopup({
type: "error",
message: `Failed to login - ${errorMsg}`,