mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-26 20:08:38 +02:00
@@ -195,7 +195,6 @@ REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||
|
||||
|
||||
# Rate limiting for auth endpoints
|
||||
RATE_LIMIT_WINDOW_SECONDS: int | None = None
|
||||
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
|
||||
@@ -213,6 +212,7 @@ if _rate_limit_max_requests_str is not None:
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
AUTH_RATE_LIMITING_ENABLED = RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS
|
||||
# Used for general redis things
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
|
@@ -30,6 +30,7 @@ from onyx.auth.users import fastapi_users
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.configs.app_configs import APP_HOST
|
||||
from onyx.configs.app_configs import APP_PORT
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
@@ -74,9 +75,9 @@ from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
|
||||
from onyx.server.middleware.rate_limiting import close_limiter
|
||||
from onyx.server.middleware.rate_limiting import close_auth_limiter
|
||||
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
|
||||
from onyx.server.middleware.rate_limiting import setup_limiter
|
||||
from onyx.server.middleware.rate_limiting import setup_auth_limiter
|
||||
from onyx.server.onyx_api.ingestion import router as onyx_api_router
|
||||
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
|
||||
get_full_openai_assistants_api_router,
|
||||
@@ -174,7 +175,7 @@ def include_auth_router_with_prefix(
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
|
||||
@@ -216,13 +217,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||
|
||||
# Set up rate limiter
|
||||
await setup_limiter()
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await setup_auth_limiter()
|
||||
|
||||
yield
|
||||
|
||||
# Close rate limiter
|
||||
await close_limiter()
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await close_auth_limiter()
|
||||
|
||||
|
||||
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import ssl
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -194,10 +195,6 @@ class RedisPool:
|
||||
redis_pool = RedisPool()
|
||||
|
||||
|
||||
def get_redis_client(*, tenant_id: str | None) -> Redis:
|
||||
return redis_pool.get_client(tenant_id)
|
||||
|
||||
|
||||
# # Usage example
|
||||
# redis_pool = RedisPool()
|
||||
# redis_client = redis_pool.get_client()
|
||||
@@ -207,6 +204,18 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
|
||||
# value = redis_client.get('key')
|
||||
# print(value.decode()) # Output: 'value'
|
||||
|
||||
|
||||
def get_redis_client(*, tenant_id: str | None) -> Redis:
|
||||
return redis_pool.get_client(tenant_id)
|
||||
|
||||
|
||||
SSL_CERT_REQS_MAP = {
|
||||
"none": ssl.CERT_NONE,
|
||||
"optional": ssl.CERT_OPTIONAL,
|
||||
"required": ssl.CERT_REQUIRED,
|
||||
}
|
||||
|
||||
|
||||
_async_redis_connection: aioredis.Redis | None = None
|
||||
_async_lock = asyncio.Lock()
|
||||
|
||||
@@ -224,15 +233,35 @@ async def get_async_redis_connection() -> aioredis.Redis:
|
||||
async with _async_lock:
|
||||
# Double-check inside the lock to avoid race conditions
|
||||
if _async_redis_connection is None:
|
||||
scheme = "rediss" if REDIS_SSL else "redis"
|
||||
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"
|
||||
# Load env vars or your config variables
|
||||
|
||||
# Create a new Redis connection (or connection pool) from the URL
|
||||
_async_redis_connection = aioredis.from_url(
|
||||
url,
|
||||
password=REDIS_PASSWORD,
|
||||
max_connections=REDIS_POOL_MAX_CONNECTIONS,
|
||||
)
|
||||
connection_kwargs: dict[str, Any] = {
|
||||
"host": REDIS_HOST,
|
||||
"port": REDIS_PORT,
|
||||
"db": REDIS_DB_NUMBER,
|
||||
"password": REDIS_PASSWORD,
|
||||
"max_connections": REDIS_POOL_MAX_CONNECTIONS,
|
||||
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
|
||||
"socket_keepalive": True,
|
||||
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
}
|
||||
|
||||
if REDIS_SSL:
|
||||
ssl_context = ssl.create_default_context()
|
||||
|
||||
if REDIS_SSL_CA_CERTS:
|
||||
ssl_context.load_verify_locations(REDIS_SSL_CA_CERTS)
|
||||
ssl_context.check_hostname = False
|
||||
|
||||
# Map your string to the proper ssl.CERT_* constant
|
||||
ssl_context.verify_mode = SSL_CERT_REQS_MAP.get(
|
||||
REDIS_SSL_CERT_REQS, ssl.CERT_NONE
|
||||
)
|
||||
|
||||
connection_kwargs["ssl"] = ssl_context
|
||||
|
||||
# Create a new Redis connection (or connection pool) with SSL configuration
|
||||
_async_redis_connection = aioredis.Redis(**connection_kwargs)
|
||||
|
||||
# Return the established connection (or pool) for all future operations
|
||||
return _async_redis_connection
|
||||
|
@@ -6,18 +6,19 @@ from fastapi import Request
|
||||
from fastapi_limiter import FastAPILimiter
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
|
||||
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
|
||||
|
||||
async def setup_limiter() -> None:
|
||||
async def setup_auth_limiter() -> None:
|
||||
# Use the centralized async Redis connection
|
||||
redis = await get_async_redis_connection()
|
||||
await FastAPILimiter.init(redis)
|
||||
|
||||
|
||||
async def close_limiter() -> None:
|
||||
async def close_auth_limiter() -> None:
|
||||
# This closes the FastAPILimiter connection so we don't leave open connections to Redis.
|
||||
await FastAPILimiter.close()
|
||||
|
||||
@@ -32,14 +33,14 @@ async def rate_limit_key(request: Request) -> str:
|
||||
|
||||
|
||||
def get_auth_rate_limiters() -> List[Callable]:
|
||||
if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS):
|
||||
if not AUTH_RATE_LIMITING_ENABLED:
|
||||
return []
|
||||
|
||||
return [
|
||||
Depends(
|
||||
RateLimiter(
|
||||
times=RATE_LIMIT_MAX_REQUESTS,
|
||||
seconds=RATE_LIMIT_WINDOW_SECONDS,
|
||||
times=RATE_LIMIT_MAX_REQUESTS or 100,
|
||||
seconds=RATE_LIMIT_WINDOW_SECONDS or 60,
|
||||
# Use the custom key function to distinguish users
|
||||
identifier=rate_limit_key,
|
||||
)
|
||||
|
Reference in New Issue
Block a user