Async Redis (#3618)

* k

* update configs for clarity

* typing

* update
This commit is contained in:
pablonyx
2025-01-07 11:34:57 -08:00
committed by GitHub
parent d9e9c6973d
commit 5b5c1166ca
4 changed files with 56 additions and 25 deletions

View File

@@ -195,7 +195,6 @@ REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:" REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
# Rate limiting for auth endpoints # Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS") _rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
@@ -213,6 +212,7 @@ if _rate_limit_max_requests_str is not None:
except ValueError: except ValueError:
pass pass
AUTH_RATE_LIMITING_ENABLED = RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS
# Used for general redis things # Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0)) REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))

View File

@@ -30,6 +30,7 @@ from onyx.auth.users import fastapi_users
from onyx.configs.app_configs import APP_API_PREFIX from onyx.configs.app_configs import APP_API_PREFIX
from onyx.configs.app_configs import APP_HOST from onyx.configs.app_configs import APP_HOST
from onyx.configs.app_configs import APP_PORT from onyx.configs.app_configs import APP_PORT
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import AUTH_TYPE from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
@@ -74,9 +75,9 @@ from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router from onyx.server.manage.users import router as user_router
from onyx.server.middleware.latency_logging import add_latency_logging_middleware from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_limiter from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_limiter from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import ( from onyx.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router, get_full_openai_assistants_api_router,
@@ -174,7 +175,7 @@ def include_auth_router_with_prefix(
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator: async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Set recursion limit # Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None: if SYSTEM_RECURSION_LIMIT is not None:
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT) sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
@@ -216,13 +217,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
# Set up rate limiter if AUTH_RATE_LIMITING_ENABLED:
await setup_limiter() await setup_auth_limiter()
yield yield
# Close rate limiter if AUTH_RATE_LIMITING_ENABLED:
await close_limiter() await close_auth_limiter()
def log_http_error(_: Request, exc: Exception) -> JSONResponse: def log_http_error(_: Request, exc: Exception) -> JSONResponse:

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import functools import functools
import json import json
import ssl
import threading import threading
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
@@ -194,10 +195,6 @@ class RedisPool:
redis_pool = RedisPool() redis_pool = RedisPool()
def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)
# # Usage example # # Usage example
# redis_pool = RedisPool() # redis_pool = RedisPool()
# redis_client = redis_pool.get_client() # redis_client = redis_pool.get_client()
@@ -207,6 +204,18 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
# value = redis_client.get('key') # value = redis_client.get('key')
# print(value.decode()) # Output: 'value' # print(value.decode()) # Output: 'value'
def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)
SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,
"required": ssl.CERT_REQUIRED,
}
_async_redis_connection: aioredis.Redis | None = None _async_redis_connection: aioredis.Redis | None = None
_async_lock = asyncio.Lock() _async_lock = asyncio.Lock()
@@ -224,15 +233,35 @@ async def get_async_redis_connection() -> aioredis.Redis:
async with _async_lock: async with _async_lock:
# Double-check inside the lock to avoid race conditions # Double-check inside the lock to avoid race conditions
if _async_redis_connection is None: if _async_redis_connection is None:
scheme = "rediss" if REDIS_SSL else "redis" # Load env vars or your config variables
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"
# Create a new Redis connection (or connection pool) from the URL connection_kwargs: dict[str, Any] = {
_async_redis_connection = aioredis.from_url( "host": REDIS_HOST,
url, "port": REDIS_PORT,
password=REDIS_PASSWORD, "db": REDIS_DB_NUMBER,
max_connections=REDIS_POOL_MAX_CONNECTIONS, "password": REDIS_PASSWORD,
) "max_connections": REDIS_POOL_MAX_CONNECTIONS,
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}
if REDIS_SSL:
ssl_context = ssl.create_default_context()
if REDIS_SSL_CA_CERTS:
ssl_context.load_verify_locations(REDIS_SSL_CA_CERTS)
ssl_context.check_hostname = False
# Map your string to the proper ssl.CERT_* constant
ssl_context.verify_mode = SSL_CERT_REQS_MAP.get(
REDIS_SSL_CERT_REQS, ssl.CERT_NONE
)
connection_kwargs["ssl"] = ssl_context
# Create a new Redis connection (or connection pool) with SSL configuration
_async_redis_connection = aioredis.Redis(**connection_kwargs)
# Return the established connection (or pool) for all future operations # Return the established connection (or pool) for all future operations
return _async_redis_connection return _async_redis_connection

View File

@@ -6,18 +6,19 @@ from fastapi import Request
from fastapi_limiter import FastAPILimiter from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter from fastapi_limiter.depends import RateLimiter
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
from onyx.redis.redis_pool import get_async_redis_connection from onyx.redis.redis_pool import get_async_redis_connection
async def setup_limiter() -> None: async def setup_auth_limiter() -> None:
# Use the centralized async Redis connection # Use the centralized async Redis connection
redis = await get_async_redis_connection() redis = await get_async_redis_connection()
await FastAPILimiter.init(redis) await FastAPILimiter.init(redis)
async def close_limiter() -> None: async def close_auth_limiter() -> None:
# This closes the FastAPILimiter connection so we don't leave open connections to Redis. # This closes the FastAPILimiter connection so we don't leave open connections to Redis.
await FastAPILimiter.close() await FastAPILimiter.close()
@@ -32,14 +33,14 @@ async def rate_limit_key(request: Request) -> str:
def get_auth_rate_limiters() -> List[Callable]: def get_auth_rate_limiters() -> List[Callable]:
if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS): if not AUTH_RATE_LIMITING_ENABLED:
return [] return []
return [ return [
Depends( Depends(
RateLimiter( RateLimiter(
times=RATE_LIMIT_MAX_REQUESTS, times=RATE_LIMIT_MAX_REQUESTS or 100,
seconds=RATE_LIMIT_WINDOW_SECONDS, seconds=RATE_LIMIT_WINDOW_SECONDS or 60,
# Use the custom key function to distinguish users # Use the custom key function to distinguish users
identifier=rate_limit_key, identifier=rate_limit_key,
) )