mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-29 05:15:12 +02:00
@@ -195,7 +195,6 @@ REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
|||||||
|
|
||||||
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||||
|
|
||||||
|
|
||||||
# Rate limiting for auth endpoints
|
# Rate limiting for auth endpoints
|
||||||
RATE_LIMIT_WINDOW_SECONDS: int | None = None
|
RATE_LIMIT_WINDOW_SECONDS: int | None = None
|
||||||
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
|
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
|
||||||
@@ -213,6 +212,7 @@ if _rate_limit_max_requests_str is not None:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
AUTH_RATE_LIMITING_ENABLED = RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS
|
||||||
# Used for general redis things
|
# Used for general redis things
|
||||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||||
|
|
||||||
|
@@ -30,6 +30,7 @@ from onyx.auth.users import fastapi_users
|
|||||||
from onyx.configs.app_configs import APP_API_PREFIX
|
from onyx.configs.app_configs import APP_API_PREFIX
|
||||||
from onyx.configs.app_configs import APP_HOST
|
from onyx.configs.app_configs import APP_HOST
|
||||||
from onyx.configs.app_configs import APP_PORT
|
from onyx.configs.app_configs import APP_PORT
|
||||||
|
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||||
from onyx.configs.app_configs import AUTH_TYPE
|
from onyx.configs.app_configs import AUTH_TYPE
|
||||||
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||||
@@ -74,9 +75,9 @@ from onyx.server.manage.search_settings import router as search_settings_router
|
|||||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||||
from onyx.server.manage.users import router as user_router
|
from onyx.server.manage.users import router as user_router
|
||||||
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
|
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
|
||||||
from onyx.server.middleware.rate_limiting import close_limiter
|
from onyx.server.middleware.rate_limiting import close_auth_limiter
|
||||||
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
|
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
|
||||||
from onyx.server.middleware.rate_limiting import setup_limiter
|
from onyx.server.middleware.rate_limiting import setup_auth_limiter
|
||||||
from onyx.server.onyx_api.ingestion import router as onyx_api_router
|
from onyx.server.onyx_api.ingestion import router as onyx_api_router
|
||||||
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
|
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
|
||||||
get_full_openai_assistants_api_router,
|
get_full_openai_assistants_api_router,
|
||||||
@@ -174,7 +175,7 @@ def include_auth_router_with_prefix(
|
|||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI) -> AsyncGenerator:
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
# Set recursion limit
|
# Set recursion limit
|
||||||
if SYSTEM_RECURSION_LIMIT is not None:
|
if SYSTEM_RECURSION_LIMIT is not None:
|
||||||
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
|
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
|
||||||
@@ -216,13 +217,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
|||||||
|
|
||||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||||
|
|
||||||
# Set up rate limiter
|
if AUTH_RATE_LIMITING_ENABLED:
|
||||||
await setup_limiter()
|
await setup_auth_limiter()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Close rate limiter
|
if AUTH_RATE_LIMITING_ENABLED:
|
||||||
await close_limiter()
|
await close_auth_limiter()
|
||||||
|
|
||||||
|
|
||||||
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
|
import ssl
|
||||||
import threading
|
import threading
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -194,10 +195,6 @@ class RedisPool:
|
|||||||
redis_pool = RedisPool()
|
redis_pool = RedisPool()
|
||||||
|
|
||||||
|
|
||||||
def get_redis_client(*, tenant_id: str | None) -> Redis:
|
|
||||||
return redis_pool.get_client(tenant_id)
|
|
||||||
|
|
||||||
|
|
||||||
# # Usage example
|
# # Usage example
|
||||||
# redis_pool = RedisPool()
|
# redis_pool = RedisPool()
|
||||||
# redis_client = redis_pool.get_client()
|
# redis_client = redis_pool.get_client()
|
||||||
@@ -207,6 +204,18 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
|
|||||||
# value = redis_client.get('key')
|
# value = redis_client.get('key')
|
||||||
# print(value.decode()) # Output: 'value'
|
# print(value.decode()) # Output: 'value'
|
||||||
|
|
||||||
|
|
||||||
|
def get_redis_client(*, tenant_id: str | None) -> Redis:
|
||||||
|
return redis_pool.get_client(tenant_id)
|
||||||
|
|
||||||
|
|
||||||
|
SSL_CERT_REQS_MAP = {
|
||||||
|
"none": ssl.CERT_NONE,
|
||||||
|
"optional": ssl.CERT_OPTIONAL,
|
||||||
|
"required": ssl.CERT_REQUIRED,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
_async_redis_connection: aioredis.Redis | None = None
|
_async_redis_connection: aioredis.Redis | None = None
|
||||||
_async_lock = asyncio.Lock()
|
_async_lock = asyncio.Lock()
|
||||||
|
|
||||||
@@ -224,15 +233,35 @@ async def get_async_redis_connection() -> aioredis.Redis:
|
|||||||
async with _async_lock:
|
async with _async_lock:
|
||||||
# Double-check inside the lock to avoid race conditions
|
# Double-check inside the lock to avoid race conditions
|
||||||
if _async_redis_connection is None:
|
if _async_redis_connection is None:
|
||||||
scheme = "rediss" if REDIS_SSL else "redis"
|
# Load env vars or your config variables
|
||||||
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"
|
|
||||||
|
|
||||||
# Create a new Redis connection (or connection pool) from the URL
|
connection_kwargs: dict[str, Any] = {
|
||||||
_async_redis_connection = aioredis.from_url(
|
"host": REDIS_HOST,
|
||||||
url,
|
"port": REDIS_PORT,
|
||||||
password=REDIS_PASSWORD,
|
"db": REDIS_DB_NUMBER,
|
||||||
max_connections=REDIS_POOL_MAX_CONNECTIONS,
|
"password": REDIS_PASSWORD,
|
||||||
)
|
"max_connections": REDIS_POOL_MAX_CONNECTIONS,
|
||||||
|
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
|
||||||
|
"socket_keepalive": True,
|
||||||
|
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||||
|
}
|
||||||
|
|
||||||
|
if REDIS_SSL:
|
||||||
|
ssl_context = ssl.create_default_context()
|
||||||
|
|
||||||
|
if REDIS_SSL_CA_CERTS:
|
||||||
|
ssl_context.load_verify_locations(REDIS_SSL_CA_CERTS)
|
||||||
|
ssl_context.check_hostname = False
|
||||||
|
|
||||||
|
# Map your string to the proper ssl.CERT_* constant
|
||||||
|
ssl_context.verify_mode = SSL_CERT_REQS_MAP.get(
|
||||||
|
REDIS_SSL_CERT_REQS, ssl.CERT_NONE
|
||||||
|
)
|
||||||
|
|
||||||
|
connection_kwargs["ssl"] = ssl_context
|
||||||
|
|
||||||
|
# Create a new Redis connection (or connection pool) with SSL configuration
|
||||||
|
_async_redis_connection = aioredis.Redis(**connection_kwargs)
|
||||||
|
|
||||||
# Return the established connection (or pool) for all future operations
|
# Return the established connection (or pool) for all future operations
|
||||||
return _async_redis_connection
|
return _async_redis_connection
|
||||||
|
@@ -6,18 +6,19 @@ from fastapi import Request
|
|||||||
from fastapi_limiter import FastAPILimiter
|
from fastapi_limiter import FastAPILimiter
|
||||||
from fastapi_limiter.depends import RateLimiter
|
from fastapi_limiter.depends import RateLimiter
|
||||||
|
|
||||||
|
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||||
from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
|
from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
|
||||||
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
|
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
|
||||||
from onyx.redis.redis_pool import get_async_redis_connection
|
from onyx.redis.redis_pool import get_async_redis_connection
|
||||||
|
|
||||||
|
|
||||||
async def setup_limiter() -> None:
|
async def setup_auth_limiter() -> None:
|
||||||
# Use the centralized async Redis connection
|
# Use the centralized async Redis connection
|
||||||
redis = await get_async_redis_connection()
|
redis = await get_async_redis_connection()
|
||||||
await FastAPILimiter.init(redis)
|
await FastAPILimiter.init(redis)
|
||||||
|
|
||||||
|
|
||||||
async def close_limiter() -> None:
|
async def close_auth_limiter() -> None:
|
||||||
# This closes the FastAPILimiter connection so we don't leave open connections to Redis.
|
# This closes the FastAPILimiter connection so we don't leave open connections to Redis.
|
||||||
await FastAPILimiter.close()
|
await FastAPILimiter.close()
|
||||||
|
|
||||||
@@ -32,14 +33,14 @@ async def rate_limit_key(request: Request) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_auth_rate_limiters() -> List[Callable]:
|
def get_auth_rate_limiters() -> List[Callable]:
|
||||||
if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS):
|
if not AUTH_RATE_LIMITING_ENABLED:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [
|
return [
|
||||||
Depends(
|
Depends(
|
||||||
RateLimiter(
|
RateLimiter(
|
||||||
times=RATE_LIMIT_MAX_REQUESTS,
|
times=RATE_LIMIT_MAX_REQUESTS or 100,
|
||||||
seconds=RATE_LIMIT_WINDOW_SECONDS,
|
seconds=RATE_LIMIT_WINDOW_SECONDS or 60,
|
||||||
# Use the custom key function to distinguish users
|
# Use the custom key function to distinguish users
|
||||||
identifier=rate_limit_key,
|
identifier=rate_limit_key,
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user