diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index d153c7dda..900645c83 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -19,6 +19,8 @@ from open_webui.env import ( DATABASE_URL, ENV, REDIS_URL, + SENTINEL_PORT, + SENTINEL_HOSTS, FRONTEND_BUILD_DIR, OFFLINE_MODE, OPEN_WEBUI_DIR, @@ -28,7 +30,7 @@ from open_webui.env import ( log, ) from open_webui.internal.db import Base, get_db - +from open_webui.utils.redis import get_redis_connection class EndpointFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: @@ -252,11 +254,11 @@ class AppConfig: _state: dict[str, PersistentConfig] _redis: Optional[redis.Redis] = None - def __init__(self, redis_url: Optional[str] = None): + def __init__(self, redis_url: Optional[str] = None, sentinels: Optional[list] = []): super().__setattr__("_state", {}) if redis_url: super().__setattr__( - "_redis", redis.Redis.from_url(redis_url, decode_responses=True) + "_redis", get_redis_connection(redis_url, sentinels, decode_responses=True) ) def __setattr__(self, key, value): diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 274cd9245..1f7f58445 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -323,6 +323,8 @@ ENABLE_REALTIME_CHAT_SAVE = ( #################################### REDIS_URL = os.environ.get("REDIS_URL", "") +SENTINEL_HOSTS = os.environ.get("SENTINEL_HOSTS", "") +SENTINEL_PORT = os.environ.get("SENTINEL_PORT", "26379") #################################### # WEBUI_AUTH (Required for security) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 1ea79aa26..eeac2b3c7 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -315,6 +315,8 @@ from open_webui.env import ( AUDIT_LOG_LEVEL, CHANGELOG, REDIS_URL, + SENTINEL_HOSTS, + SENTINEL_PORT, GLOBAL_LOG_LEVEL, MAX_BODY_LOG_SIZE, SAFE_MODE, @@ -358,6 +360,9 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.tasks import stop_task, list_tasks # Import from tasks.py +from open_webui.utils.redis import get_sentinels_from_env + + if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() @@ -421,7 +426,7 @@ app = FastAPI( oauth_manager = OAuthManager(app) -app.state.config = AppConfig(redis_url=REDIS_URL) +app.state.config = AppConfig(redis_url=REDIS_URL, sentinels=get_sentinels_from_env(SENTINEL_HOSTS, SENTINEL_PORT)) app.state.WEBUI_NAME = WEBUI_NAME app.state.LICENSE_METADATA = None diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 925e28fbe..b681d7e3b 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -8,7 +8,7 @@ from redis import asyncio as aioredis from open_webui.models.users import Users, UserNameResponse from open_webui.models.channels import Channels from open_webui.models.chats import Chats -from open_webui.utils.redis import parse_redis_sentinel_url, AsyncRedisSentinelManager +from open_webui.utils.redis import parse_redis_sentinel_url, get_sentinels_from_env, AsyncRedisSentinelManager from open_webui.env import ( ENABLE_WEBSOCKET_SUPPORT, @@ -64,9 +64,7 @@ TIMEOUT_DURATION = 3 if WEBSOCKET_MANAGER == "redis": log.debug("Using Redis to manage websockets.") - sentinel_hosts=WEBSOCKET_SENTINEL_HOSTS.split(',') - sentinel_port=int(WEBSOCKET_SENTINEL_PORT) - sentinels=[(host, sentinel_port) for host in sentinel_hosts] + sentinels=get_sentinels_from_env(WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT) SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL, sentinels=sentinels) USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL, sentinels=sentinels) USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL, sentinels=sentinels) diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py index 07512fe13..ff39f96ac 100644 --- a/backend/open_webui/utils/redis.py +++ b/backend/open_webui/utils/redis.py @@ -33,7 +33,12 @@ def get_redis_connection(redis_url, sentinels, decode_responses=True): else: # Standard Redis connection return redis.Redis.from_url(redis_url, decode_responses=decode_responses) - + +def get_sentinels_from_env(SENTINEL_HOSTS, SENTINEL_PORT): + sentinel_hosts=SENTINEL_HOSTS.split(',') + sentinel_port=int(SENTINEL_PORT) + return [(host, sentinel_port) for host in sentinel_hosts] + class AsyncRedisSentinelManager(socketio.AsyncRedisManager): def __init__(self, sentinel_hosts, sentinel_port=26379, redis_port=6379, service="mymaster", db=0, username=None, password=None, channel='socketio', write_only=False, logger=None, redis_options=None): @@ -81,4 +86,4 @@ class AsyncRedisSentinelManager(socketio.AsyncRedisManager): ) self.redis = sentinel.master_for(self._service) - self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) \ No newline at end of file + self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True)