Async Redis (#3618)

* k

* update configs for clarity

* typing

* update
This commit is contained in:
pablonyx
2025-01-07 11:34:57 -08:00
committed by GitHub
parent d9e9c6973d
commit 5b5c1166ca
4 changed files with 56 additions and 25 deletions

View File

@@ -195,7 +195,6 @@ REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
@@ -213,6 +212,7 @@ if _rate_limit_max_requests_str is not None:
except ValueError:
pass
# Auth rate limiting is enabled only when both the request cap and the window
# are set to truthy values. `and` returns one of its operands rather than a
# boolean, so wrap it in bool() to make this flag a real True/False.
AUTH_RATE_LIMITING_ENABLED = bool(
    RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS
)
# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))

View File

@@ -30,6 +30,7 @@ from onyx.auth.users import fastapi_users
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.configs.app_configs import APP_HOST
from onyx.configs.app_configs import APP_PORT
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
@@ -74,9 +75,9 @@ from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_limiter
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_limiter
from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
@@ -174,7 +175,7 @@ def include_auth_router_with_prefix(
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
@@ -216,13 +217,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
# Set up rate limiter
await setup_limiter()
if AUTH_RATE_LIMITING_ENABLED:
await setup_auth_limiter()
yield
# Close rate limiter
await close_limiter()
if AUTH_RATE_LIMITING_ENABLED:
await close_auth_limiter()
def log_http_error(_: Request, exc: Exception) -> JSONResponse:

View File

@@ -1,6 +1,7 @@
import asyncio
import functools
import json
import ssl
import threading
from collections.abc import Callable
from typing import Any
@@ -194,10 +195,6 @@ class RedisPool:
redis_pool = RedisPool()
def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)
# # Usage example
# redis_pool = RedisPool()
# redis_client = redis_pool.get_client()
@@ -207,6 +204,18 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
# value = redis_client.get('key')
# print(value.decode()) # Output: 'value'
def get_redis_client(*, tenant_id: str | None) -> Redis:
    """Return the client that the module-level ``redis_pool`` provides for ``tenant_id``.

    ``tenant_id`` is keyword-only and is passed through to
    ``redis_pool.get_client`` unchanged (``None`` included).
    """
    client = redis_pool.get_client(tenant_id)
    return client
# Lookup from the textual certificate-requirement setting to the matching
# ssl.CERT_* verification constant.
SSL_CERT_REQS_MAP = dict(
    none=ssl.CERT_NONE,
    optional=ssl.CERT_OPTIONAL,
    required=ssl.CERT_REQUIRED,
)
# Module-level async Redis client shared by all callers; it stays None until
# get_async_redis_connection() creates it on first use.
_async_redis_connection: aioredis.Redis | None = None
# Held by get_async_redis_connection() while it creates the client, so that
# concurrent first callers do not each build their own connection.
_async_lock = asyncio.Lock()
@@ -224,16 +233,36 @@ async def get_async_redis_connection() -> aioredis.Redis:
async with _async_lock:
# Double-check inside the lock to avoid race conditions
if _async_redis_connection is None:
scheme = "rediss" if REDIS_SSL else "redis"
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"
# Load env vars or your config variables
# Create a new Redis connection (or connection pool) from the URL
_async_redis_connection = aioredis.from_url(
url,
password=REDIS_PASSWORD,
max_connections=REDIS_POOL_MAX_CONNECTIONS,
connection_kwargs: dict[str, Any] = {
"host": REDIS_HOST,
"port": REDIS_PORT,
"db": REDIS_DB_NUMBER,
"password": REDIS_PASSWORD,
"max_connections": REDIS_POOL_MAX_CONNECTIONS,
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}
if REDIS_SSL:
ssl_context = ssl.create_default_context()
if REDIS_SSL_CA_CERTS:
ssl_context.load_verify_locations(REDIS_SSL_CA_CERTS)
ssl_context.check_hostname = False
# Map your string to the proper ssl.CERT_* constant
ssl_context.verify_mode = SSL_CERT_REQS_MAP.get(
REDIS_SSL_CERT_REQS, ssl.CERT_NONE
)
connection_kwargs["ssl"] = ssl_context
# Create a new Redis connection (or connection pool) with SSL configuration
_async_redis_connection = aioredis.Redis(**connection_kwargs)
# Return the established connection (or pool) for all future operations
return _async_redis_connection

View File

@@ -6,18 +6,19 @@ from fastapi import Request
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
from onyx.redis.redis_pool import get_async_redis_connection
# Initialise FastAPILimiter for the auth rate limiters.
async def setup_limiter() -> None:
async def setup_auth_limiter() -> None:
# Use the centralized async Redis connection
redis = await get_async_redis_connection()
# Hand that connection to FastAPILimiter as its backing store
await FastAPILimiter.init(redis)
# Shut down FastAPILimiter; the counterpart of the limiter setup function.
async def close_limiter() -> None:
async def close_auth_limiter() -> None:
# This closes the FastAPILimiter connection so we don't leave open connections to Redis.
await FastAPILimiter.close()
@@ -32,14 +33,14 @@ async def rate_limit_key(request: Request) -> str:
def get_auth_rate_limiters() -> List[Callable]:
if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS):
if not AUTH_RATE_LIMITING_ENABLED:
return []
return [
Depends(
RateLimiter(
times=RATE_LIMIT_MAX_REQUESTS,
seconds=RATE_LIMIT_WINDOW_SECONDS,
times=RATE_LIMIT_MAX_REQUESTS or 100,
seconds=RATE_LIMIT_WINDOW_SECONDS or 60,
# Use the custom key function to distinguish users
identifier=rate_limit_key,
)