Auth specific rate limiting (#3463)

* k

* v1

* fully functional

* finalize

* nit

* nit

* nit

* clean up with wrapper + comments

* k

* update

* minor clean
This commit is contained in:
pablonyx
2024-12-29 18:34:23 -05:00
committed by GitHub
parent d14ef431a7
commit 27acd3387a
8 changed files with 149 additions and 24 deletions

View File

@ -40,6 +40,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base from onyx.main import get_application as get_application_base
from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended from onyx.main import include_router_with_global_prefix_prepended
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version from onyx.utils.variable_functionality import global_version
@ -62,7 +63,7 @@ def get_application() -> FastAPI:
if AUTH_TYPE == AuthType.CLOUD: if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
create_onyx_oauth_router( create_onyx_oauth_router(
oauth_client, oauth_client,
@ -74,19 +75,17 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback", redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
), ),
prefix="/auth/oauth", prefix="/auth/oauth",
tags=["auth"],
) )
# Need basic auth router for `logout` endpoint # Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
fastapi_users.get_logout_router(auth_backend), fastapi_users.get_logout_router(auth_backend),
prefix="/auth", prefix="/auth",
tags=["auth"],
) )
if AUTH_TYPE == AuthType.OIDC: if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
create_onyx_oauth_router( create_onyx_oauth_router(
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL), OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
@ -97,19 +96,21 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback", redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
), ),
prefix="/auth/oidc", prefix="/auth/oidc",
tags=["auth"],
) )
# need basic auth router for `logout` endpoint # need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
fastapi_users.get_auth_router(auth_backend), fastapi_users.get_auth_router(auth_backend),
prefix="/auth", prefix="/auth",
tags=["auth"],
) )
elif AUTH_TYPE == AuthType.SAML: elif AUTH_TYPE == AuthType.SAML:
include_router_with_global_prefix_prepended(application, saml_router) include_auth_router_with_prefix(
application,
saml_router,
prefix="/auth/saml",
)
# RBAC / group access control # RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router) include_router_with_global_prefix_prepended(application, user_group_router)

View File

@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
the files in the existing huggingface cache that don't exist in the temp the files in the existing huggingface cache that don't exist in the temp
huggingface cache. huggingface cache.
""" """
for item in source.iterdir(): for item in source.iterdir():
target_path = dest / item.relative_to(source) target_path = dest / item.relative_to(source)
if item.is_dir(): if item.is_dir():

View File

@ -185,6 +185,25 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379)) REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or "" REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
try:
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
except ValueError:
pass
RATE_LIMIT_MAX_REQUESTS: int | None = None
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
if _rate_limit_max_requests_str is not None:
try:
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
except ValueError:
pass
# Used for general redis things # Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0)) REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))

View File

@ -74,6 +74,9 @@ from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router from onyx.server.manage.users import router as user_router
from onyx.server.middleware.latency_logging import add_latency_logging_middleware from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import ( from onyx.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router, get_full_openai_assistants_api_router,
@ -153,6 +156,20 @@ def include_router_with_global_prefix_prepended(
application.include_router(router, **final_kwargs) application.include_router(router, **final_kwargs)
def include_auth_router_with_prefix(
application: FastAPI, router: APIRouter, prefix: str, tags: list[str] | None = None
) -> None:
"""Wrapper function to include an 'auth' router with prefix + rate-limiting dependencies."""
final_tags = tags or ["auth"]
include_router_with_global_prefix_prepended(
application,
router,
prefix=prefix,
tags=final_tags,
dependencies=get_auth_rate_limiters(),
)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator: async def lifespan(app: FastAPI) -> AsyncGenerator:
# Set recursion limit # Set recursion limit
@ -194,8 +211,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
setup_multitenant_onyx() setup_multitenant_onyx()
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
# Set up rate limiter
await setup_limiter()
yield yield
# Close rate limiter
await close_limiter()
def log_http_error(_: Request, exc: Exception) -> JSONResponse: def log_http_error(_: Request, exc: Exception) -> JSONResponse:
status_code = getattr(exc, "status_code", 500) status_code = getattr(exc, "status_code", 500)
@ -283,42 +307,37 @@ def get_application() -> FastAPI:
pass pass
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD: if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
fastapi_users.get_auth_router(auth_backend), fastapi_users.get_auth_router(auth_backend),
prefix="/auth", prefix="/auth",
tags=["auth"],
) )
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
fastapi_users.get_register_router(UserRead, UserCreate), fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth", prefix="/auth",
tags=["auth"],
) )
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
fastapi_users.get_reset_password_router(), fastapi_users.get_reset_password_router(),
prefix="/auth", prefix="/auth",
tags=["auth"],
) )
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
fastapi_users.get_verify_router(UserRead), fastapi_users.get_verify_router(UserRead),
prefix="/auth", prefix="/auth",
tags=["auth"],
) )
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
fastapi_users.get_users_router(UserRead, UserUpdate), fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users", prefix="/users",
tags=["users"],
) )
if AUTH_TYPE == AuthType.GOOGLE_OAUTH: if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
create_onyx_oauth_router( create_onyx_oauth_router(
oauth_client, oauth_client,
@ -330,15 +349,13 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback", redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
), ),
prefix="/auth/oauth", prefix="/auth/oauth",
tags=["auth"],
) )
# Need basic auth router for `logout` endpoint # Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended( include_auth_router_with_prefix(
application, application,
fastapi_users.get_logout_router(auth_backend), fastapi_users.get_logout_router(auth_backend),
prefix="/auth", prefix="/auth",
tags=["auth"],
) )
application.add_exception_handler( application.add_exception_handler(

View File

@ -1,3 +1,4 @@
import asyncio
import functools import functools
import threading import threading
from collections.abc import Callable from collections.abc import Callable
@ -5,6 +6,7 @@ from typing import Any
from typing import Optional from typing import Optional
import redis import redis
from redis import asyncio as aioredis
from redis.client import Redis from redis.client import Redis
from onyx.configs.app_configs import REDIS_DB_NUMBER from onyx.configs.app_configs import REDIS_DB_NUMBER
@ -196,3 +198,33 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
# redis_client.set('key', 'value') # redis_client.set('key', 'value')
# value = redis_client.get('key') # value = redis_client.get('key')
# print(value.decode()) # Output: 'value' # print(value.decode()) # Output: 'value'
_async_redis_connection = None
_async_lock = asyncio.Lock()
async def get_async_redis_connection() -> aioredis.Redis:
"""
Provides a shared async Redis connection, using the same configs (host, port, SSL, etc.).
Ensures that the connection is created only once (lazily) and reused for all future calls.
"""
global _async_redis_connection
# If we haven't yet created an async Redis connection, we need to create one
if _async_redis_connection is None:
# Acquire the lock to ensure that only one coroutine attempts to create the connection
async with _async_lock:
# Double-check inside the lock to avoid race conditions
if _async_redis_connection is None:
scheme = "rediss" if REDIS_SSL else "redis"
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"
# Create a new Redis connection (or connection pool) from the URL
_async_redis_connection = aioredis.from_url(
url,
password=REDIS_PASSWORD,
max_connections=REDIS_POOL_MAX_CONNECTIONS,
)
# Return the established connection (or pool) for all future operations
return _async_redis_connection

View File

@ -0,0 +1,47 @@
from collections.abc import Callable
from typing import List
from fastapi import Depends
from fastapi import Request
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
from onyx.redis.redis_pool import get_async_redis_connection
async def setup_limiter() -> None:
# Use the centralized async Redis connection
redis = await get_async_redis_connection()
await FastAPILimiter.init(redis)
async def close_limiter() -> None:
# This closes the FastAPILimiter connection so we don't leave open connections to Redis.
await FastAPILimiter.close()
async def rate_limit_key(request: Request) -> str:
# Uses both IP and User-Agent to make collisions less likely if IP is behind NAT.
# If request.client is None, a fallback is used to avoid completely unknown keys.
# This helps ensure we have a unique key for each 'user' in simple scenarios.
ip_part = request.client.host if request.client else "unknown"
ua_part = request.headers.get("user-agent", "none").replace(" ", "_")
return f"{ip_part}-{ua_part}"
def get_auth_rate_limiters() -> List[Callable]:
if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS):
return []
return [
Depends(
RateLimiter(
times=RATE_LIMIT_MAX_REQUESTS,
seconds=RATE_LIMIT_WINDOW_SECONDS,
# Use the custom key function to distinguish users
identifier=rate_limit_key,
)
)
]

View File

@ -81,4 +81,5 @@ stripe==10.12.0
urllib3==2.2.3 urllib3==2.2.3
mistune==0.8.4 mistune==0.8.4
sentry-sdk==2.14.0 sentry-sdk==2.14.0
prometheus_client==0.21.0 prometheus_client==0.21.0
fastapi-limiter==0.1.6

View File

@ -59,10 +59,14 @@ export function EmailPasswordForm({
errorMsg = errorMsg =
"An account already exists with the specified email."; "An account already exists with the specified email.";
} }
if (response.status === 429) {
errorMsg = "Too many requests. Please try again later.";
}
setPopup({ setPopup({
type: "error", type: "error",
message: `Failed to sign up - ${errorMsg}`, message: `Failed to sign up - ${errorMsg}`,
}); });
setIsWorking(false);
return; return;
} }
} }
@ -89,6 +93,9 @@ export function EmailPasswordForm({
} else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") { } else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") {
errorMsg = "Create an account to set a password"; errorMsg = "Create an account to set a password";
} }
if (loginResponse.status === 429) {
errorMsg = "Too many requests. Please try again later.";
}
setPopup({ setPopup({
type: "error", type: "error",
message: `Failed to login - ${errorMsg}`, message: `Failed to login - ${errorMsg}`,