diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 7c7d922b6..d3085cf54 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -19,6 +19,8 @@ from open_webui.env import ( DATABASE_URL, ENV, REDIS_URL, + REDIS_SENTINEL_HOSTS, + REDIS_SENTINEL_PORT, FRONTEND_BUILD_DIR, OFFLINE_MODE, OPEN_WEBUI_DIR, @@ -28,7 +30,7 @@ from open_webui.env import ( log, ) from open_webui.internal.db import Base, get_db - +from open_webui.utils.redis import get_redis_connection class EndpointFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: @@ -252,11 +254,11 @@ class AppConfig: _state: dict[str, PersistentConfig] _redis: Optional[redis.Redis] = None - def __init__(self, redis_url: Optional[str] = None): + def __init__(self, redis_url: Optional[str] = None, redis_sentinels: Optional[list] = []): super().__setattr__("_state", {}) if redis_url: super().__setattr__( - "_redis", redis.Redis.from_url(redis_url, decode_responses=True) + "_redis", get_redis_connection(redis_url, redis_sentinels, decode_responses=True) ) def __setattr__(self, key, value): diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 2a327aa5d..e3819fdc5 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -323,6 +323,8 @@ ENABLE_REALTIME_CHAT_SAVE = ( #################################### REDIS_URL = os.environ.get("REDIS_URL", "") +REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "") +REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379") #################################### # WEBUI_AUTH (Required for security) @@ -379,6 +381,10 @@ WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60) +WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "") + +WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379") + AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") if AIOHTTP_CLIENT_TIMEOUT == "": diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 7fc4e3983..34cf68069 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -317,6 +317,8 @@ from open_webui.env import ( AUDIT_LOG_LEVEL, CHANGELOG, REDIS_URL, + REDIS_SENTINEL_HOSTS, + REDIS_SENTINEL_PORT, GLOBAL_LOG_LEVEL, MAX_BODY_LOG_SIZE, SAFE_MODE, @@ -360,6 +362,9 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.tasks import stop_task, list_tasks # Import from tasks.py +from open_webui.utils.redis import get_sentinels_from_env + + if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() @@ -423,7 +428,7 @@ app = FastAPI( oauth_manager = OAuthManager(app) -app.state.config = AppConfig(redis_url=REDIS_URL) +app.state.config = AppConfig(redis_url=REDIS_URL, redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT)) app.state.WEBUI_NAME = WEBUI_NAME app.state.LICENSE_METADATA = None diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 3bf964e5b..f92b7b5e3 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -3,16 +3,20 @@ import socketio import logging import sys import time +from redis import asyncio as aioredis from open_webui.models.users import Users, UserNameResponse from open_webui.models.channels import Channels from open_webui.models.chats import Chats +from open_webui.utils.redis import parse_redis_sentinel_url, get_sentinels_from_env, AsyncRedisSentinelManager from open_webui.env import ( ENABLE_WEBSOCKET_SUPPORT, WEBSOCKET_MANAGER, WEBSOCKET_REDIS_URL, WEBSOCKET_REDIS_LOCK_TIMEOUT, + WEBSOCKET_SENTINEL_PORT, + WEBSOCKET_SENTINEL_HOSTS, ) from open_webui.utils.auth import decode_token from open_webui.socket.utils import RedisDict, RedisLock @@ -29,7 +33,12 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"]) if WEBSOCKET_MANAGER == "redis": - mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL) + if WEBSOCKET_SENTINEL_HOSTS: + redis_config = parse_redis_sentinel_url(WEBSOCKET_REDIS_URL) + mgr = AsyncRedisSentinelManager(WEBSOCKET_SENTINEL_HOSTS.split(','), sentinel_port=int(WEBSOCKET_SENTINEL_PORT), redis_port=redis_config["port"], + service=redis_config["service"], db=redis_config["db"], username=redis_config["username"], password=redis_config["password"]) + else: + mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL) sio = socketio.AsyncServer( cors_allowed_origins=[], async_mode="asgi", @@ -55,14 +64,16 @@ TIMEOUT_DURATION = 3 if WEBSOCKET_MANAGER == "redis": log.debug("Using Redis to manage websockets.") - SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL) - USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL) - USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL) + redis_sentinels=get_sentinels_from_env(WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT) + SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels) + USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels) + USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels) clean_up_lock = RedisLock( redis_url=WEBSOCKET_REDIS_URL, lock_name="usage_cleanup_lock", timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT, + redis_sentinels=redis_sentinels, ) aquire_func = clean_up_lock.aquire_lock renew_func = clean_up_lock.renew_lock diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py index 46fafbb9e..a63815c02 100644 --- a/backend/open_webui/socket/utils.py +++ b/backend/open_webui/socket/utils.py @@ -1,15 +1,14 @@ import json -import redis import uuid - +from open_webui.utils.redis import get_redis_connection class RedisLock: - def __init__(self, redis_url, lock_name, timeout_secs): + def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]): self.lock_name = lock_name self.lock_id = str(uuid.uuid4()) self.timeout_secs = timeout_secs self.lock_obtained = False - self.redis = redis.Redis.from_url(redis_url, decode_responses=True) + self.redis = get_redis_connection(redis_url, redis_sentinels, decode_responses=True) def aquire_lock(self): # nx=True will only set this key if it _hasn't_ already been set @@ -31,9 +30,9 @@ class RedisLock: class RedisDict: - def __init__(self, name, redis_url): + def __init__(self, name, redis_url, redis_sentinels=[]): self.name = name - self.redis = redis.Redis.from_url(redis_url, decode_responses=True) + self.redis = get_redis_connection(redis_url, redis_sentinels, decode_responses=True) def __setitem__(self, key, value): serialized_value = json.dumps(value) diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py new file mode 100644 index 000000000..fa90a26db --- /dev/null +++ b/backend/open_webui/utils/redis.py @@ -0,0 +1,91 @@ +import socketio +import redis +from redis import asyncio as aioredis +from urllib.parse import urlparse + +def parse_redis_sentinel_url(redis_url): + parsed_url = urlparse(redis_url) + if parsed_url.scheme != "redis": + raise ValueError("Invalid Redis URL scheme. Must be 'redis'.") + + return { + "username": parsed_url.username or None, + "password": parsed_url.password or None, + "service": parsed_url.hostname or 'mymaster', + "port": parsed_url.port or 6379, + "db": int(parsed_url.path.lstrip("/") or 0), + } + +def get_redis_connection(redis_url, redis_sentinels, decode_responses=True): + if redis_sentinels: + redis_config = parse_redis_sentinel_url(redis_url) + sentinel = redis.sentinel.Sentinel( + redis_sentinels, + port=redis_config['port'], + db=redis_config['db'], + username=redis_config['username'], + password=redis_config['password'], + decode_responses=decode_responses + ) + + # Get a master connection from Sentinel + return sentinel.master_for(redis_config['service']) + else: + # Standard Redis connection + return redis.Redis.from_url(redis_url, decode_responses=decode_responses) + +def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env): + if sentinel_hosts_env: + sentinel_hosts=sentinel_hosts_env.split(',') + sentinel_port=int(sentinel_port_env) + return [(host, sentinel_port) for host in sentinel_hosts] + return [] + +class AsyncRedisSentinelManager(socketio.AsyncRedisManager): + def __init__(self, sentinel_hosts, sentinel_port=26379, redis_port=6379, service="mymaster", db=0, + username=None, password=None, channel='socketio', write_only=False, logger=None, redis_options=None): + """ + Initialize the Redis Sentinel Manager. + This implementation mostly replicates the __init__ of AsyncRedisManager and + overrides _redis_connect() with a version that uses Redis Sentinel + + :param sentinel_hosts: List of Sentinel hosts + :param sentinel_port: Sentinel Port + :param redis_port: Redis Port (currently unsupported by aioredis!) + :param service: Master service name in Sentinel + :param db: Redis database to use + :param username: Redis username (if any) (currently unsupported by aioredis!) + :param password: Redis password (if any) + :param channel: The channel name on which the server sends and receives + notifications. Must be the same in all the servers. + :param write_only: If set to ``True``, only initialize to emit events. The + default of ``False`` initializes the class for emitting + and receiving. + :param redis_options: additional keyword arguments to be passed to + ``aioredis.from_url()``. + """ + self._sentinels = [(host, sentinel_port) for host in sentinel_hosts] + self._redis_port=redis_port + self._service = service + self._db = db + self._username = username + self._password = password + self._channel = channel + self.redis_options = redis_options or {} + + # connect and call grandparent constructor + self._redis_connect() + super(socketio.AsyncRedisManager, self).__init__(channel=channel, write_only=write_only, logger=logger) + + def _redis_connect(self): + """Establish connections to Redis through Sentinel.""" + sentinel = aioredis.sentinel.Sentinel( + self._sentinels, + port=self._redis_port, + db=self._db, + password=self._password, + **self.redis_options + ) + + self.redis = sentinel.master_for(self._service) + self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True)