import asyncio import functools import json import ssl import threading from collections.abc import Callable from typing import Any from typing import cast from typing import Optional import redis from fastapi import Request from redis import asyncio as aioredis from redis.client import Redis from redis.lock import Lock as RedisLock from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX from onyx.configs.app_configs import REDIS_DB_NUMBER from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL from onyx.configs.app_configs import REDIS_HOST from onyx.configs.app_configs import REDIS_PASSWORD from onyx.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS from onyx.configs.app_configs import REDIS_PORT from onyx.configs.app_configs import REDIS_REPLICA_HOST from onyx.configs.app_configs import REDIS_SSL from onyx.configs.app_configs import REDIS_SSL_CA_CERTS from onyx.configs.app_configs import REDIS_SSL_CERT_REQS from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS from onyx.utils.logger import setup_logger from shared_configs.configs import DEFAULT_REDIS_PREFIX from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() SCAN_ITER_COUNT_DEFAULT = 4096 class TenantRedis(redis.Redis): def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.tenant_id: str = tenant_id def _prefixed(self, key: str | bytes | memoryview) -> str | bytes | memoryview: prefix: str = f"{self.tenant_id}:" if isinstance(key, str): if key.startswith(prefix): return key else: return prefix + key elif isinstance(key, bytes): prefix_bytes = prefix.encode() if key.startswith(prefix_bytes): return key else: return prefix_bytes + key elif isinstance(key, memoryview): key_bytes = key.tobytes() prefix_bytes = prefix.encode() if key_bytes.startswith(prefix_bytes): return key else: return memoryview(prefix_bytes + key_bytes) else: raise TypeError(f"Unsupported key type: {type(key)}") def _prefix_method(self, method: Callable) -> Callable: @functools.wraps(method) def wrapper(*args: Any, **kwargs: Any) -> Any: if "name" in kwargs: kwargs["name"] = self._prefixed(kwargs["name"]) elif len(args) > 0: args = (self._prefixed(args[0]),) + args[1:] return method(*args, **kwargs) return wrapper def _prefix_scan_iter(self, method: Callable) -> Callable: @functools.wraps(method) def wrapper(*args: Any, **kwargs: Any) -> Any: # Prefix the match pattern if provided if "match" in kwargs: kwargs["match"] = self._prefixed(kwargs["match"]) elif len(args) > 0: args = (self._prefixed(args[0]),) + args[1:] # Get the iterator iterator = method(*args, **kwargs) # Remove prefix from returned keys prefix = f"{self.tenant_id}:".encode() prefix_len = len(prefix) for key in iterator: if isinstance(key, bytes) and key.startswith(prefix): yield key[prefix_len:] else: yield key return wrapper def __getattribute__(self, item: str) -> Any: original_attr = super().__getattribute__(item) methods_to_wrap = [ "lock", "unlock", "get", "set", "delete", "exists", "incrby", "hset", "hget", "getset", "owned", "reacquire", "create_lock", "startswith", "smembers", "sismember", "sadd", "srem", "scard", "hexists", "hset", "hdel", "ttl", ] # Regular methods that need simple prefixing if item == "scan_iter" or item == "sscan_iter": return self._prefix_scan_iter(original_attr) elif item in methods_to_wrap and callable(original_attr): return self._prefix_method(original_attr) return original_attr class RedisPool: _instance: Optional["RedisPool"] = None _lock: threading.Lock = threading.Lock() _pool: redis.BlockingConnectionPool _replica_pool: redis.BlockingConnectionPool def __new__(cls) -> "RedisPool": if not cls._instance: with cls._lock: if not cls._instance: cls._instance = super(RedisPool, cls).__new__(cls) cls._instance._init_pools() return cls._instance def _init_pools(self) -> None: self._pool = RedisPool.create_pool(ssl=REDIS_SSL) self._replica_pool = RedisPool.create_pool( host=REDIS_REPLICA_HOST, ssl=REDIS_SSL ) def get_client(self, tenant_id: str) -> Redis: return TenantRedis(tenant_id, connection_pool=self._pool) def get_replica_client(self, tenant_id: str) -> Redis: return TenantRedis(tenant_id, connection_pool=self._replica_pool) @staticmethod def create_pool( host: str = REDIS_HOST, port: int = REDIS_PORT, db: int = REDIS_DB_NUMBER, password: str = REDIS_PASSWORD, max_connections: int = REDIS_POOL_MAX_CONNECTIONS, ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS, ssl_cert_reqs: str = REDIS_SSL_CERT_REQS, ssl: bool = False, ) -> redis.BlockingConnectionPool: """We use BlockingConnectionPool because it will block and wait for a connection rather than error if max_connections is reached. This is far more deterministic behavior and aligned with how we want to use Redis.""" # Using ConnectionPool is not well documented. # Useful examples: https://github.com/redis/redis-py/issues/780 if ssl: return redis.BlockingConnectionPool( host=host, port=port, db=db, password=password, max_connections=max_connections, timeout=None, health_check_interval=REDIS_HEALTH_CHECK_INTERVAL, socket_keepalive=True, socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS, connection_class=redis.SSLConnection, ssl_ca_certs=ssl_ca_certs, ssl_cert_reqs=ssl_cert_reqs, ) return redis.BlockingConnectionPool( host=host, port=port, db=db, password=password, max_connections=max_connections, timeout=None, health_check_interval=REDIS_HEALTH_CHECK_INTERVAL, socket_keepalive=True, socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS, ) redis_pool = RedisPool() # # Usage example # redis_pool = RedisPool() # redis_client = redis_pool.get_client() # # Example of setting and getting a value # redis_client.set('key', 'value') # value = redis_client.get('key') # print(value.decode()) # Output: 'value' def get_redis_client( *, # This argument will be deprecated in the future tenant_id: str | None = None, ) -> Redis: if tenant_id is None: tenant_id = get_current_tenant_id() return redis_pool.get_client(tenant_id) def get_redis_replica_client( *, # this argument will be deprecated in the future tenant_id: str | None = None, ) -> Redis: if tenant_id is None: tenant_id = get_current_tenant_id() return redis_pool.get_replica_client(tenant_id) def get_shared_redis_client() -> Redis: return redis_pool.get_client(DEFAULT_REDIS_PREFIX) def get_shared_redis_replica_client() -> Redis: return redis_pool.get_replica_client(DEFAULT_REDIS_PREFIX) SSL_CERT_REQS_MAP = { "none": ssl.CERT_NONE, "optional": ssl.CERT_OPTIONAL, "required": ssl.CERT_REQUIRED, } _async_redis_connection: aioredis.Redis | None = None _async_lock = asyncio.Lock() async def get_async_redis_connection() -> aioredis.Redis: """ Provides a shared async Redis connection, using the same configs (host, port, SSL, etc.). Ensures that the connection is created only once (lazily) and reused for all future calls. """ global _async_redis_connection # If we haven't yet created an async Redis connection, we need to create one if _async_redis_connection is None: # Acquire the lock to ensure that only one coroutine attempts to create the connection async with _async_lock: # Double-check inside the lock to avoid race conditions if _async_redis_connection is None: # Load env vars or your config variables connection_kwargs: dict[str, Any] = { "host": REDIS_HOST, "port": REDIS_PORT, "db": REDIS_DB_NUMBER, "password": REDIS_PASSWORD, "max_connections": REDIS_POOL_MAX_CONNECTIONS, "health_check_interval": REDIS_HEALTH_CHECK_INTERVAL, "socket_keepalive": True, "socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS, } if REDIS_SSL: ssl_context = ssl.create_default_context() if REDIS_SSL_CA_CERTS: ssl_context.load_verify_locations(REDIS_SSL_CA_CERTS) ssl_context.check_hostname = False # Map your string to the proper ssl.CERT_* constant ssl_context.verify_mode = SSL_CERT_REQS_MAP.get( REDIS_SSL_CERT_REQS, ssl.CERT_NONE ) connection_kwargs["ssl"] = ssl_context # Create a new Redis connection (or connection pool) with SSL configuration _async_redis_connection = aioredis.Redis(**connection_kwargs) # Return the established connection (or pool) for all future operations return _async_redis_connection async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None: token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME) if not token: logger.debug("No auth token cookie found") return None try: redis = await get_async_redis_connection() redis_key = REDIS_AUTH_KEY_PREFIX + token token_data_str = await redis.get(redis_key) if not token_data_str: logger.debug(f"Token key {redis_key} not found or expired in Redis") return None return json.loads(token_data_str) except json.JSONDecodeError: logger.error("Error decoding token data from Redis") return None except Exception as e: logger.error( f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}" ) raise ValueError( f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}" ) def redis_lock_dump(lock: RedisLock, r: Redis) -> None: # diagnostic logging for lock errors name = lock.name ttl = r.ttl(name) locked = lock.locked() owned = lock.owned() local_token: str | None = lock.local.token # type: ignore remote_token_raw = r.get(lock.name) if remote_token_raw: remote_token_bytes = cast(bytes, remote_token_raw) remote_token = remote_token_bytes.decode("utf-8") else: remote_token = None logger.warning( f"RedisLock diagnostic: " f"name={name} " f"locked={locked} " f"owned={owned} " f"local_token={local_token} " f"remote_token={remote_token} " f"ttl={ttl}" )