From 6d62e71c3431f3305a3802970bf5a18055a7fe39 Mon Sep 17 00:00:00 2001 From: Didier FOURNOUT Date: Thu, 13 Feb 2025 15:29:26 +0000 Subject: [PATCH 1/2] Add x-Open-Webui headers for ollama + more for openai --- backend/open_webui/main.py | 4 +- backend/open_webui/routers/ollama.py | 144 ++++++++++++++++++++++++--- backend/open_webui/routers/openai.py | 43 ++++++-- backend/open_webui/utils/chat.py | 4 +- backend/open_webui/utils/models.py | 11 +- 5 files changed, 173 insertions(+), 33 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 88b5b3f69..3e5f20cee 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -858,7 +858,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)): return filtered_models - models = await get_all_models(request) + models = await get_all_models(request, user=user) # Filter out filter pipelines models = [ @@ -898,7 +898,7 @@ async def chat_completion( user=Depends(get_verified_user), ): if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) model_item = form_data.pop("model_item", {}) tasks = form_data.pop("background_tasks", None) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 64373c616..e825848d4 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -14,6 +14,11 @@ from urllib.parse import urlparse import aiohttp from aiocache import cached import requests +from open_webui.models.users import UserModel + +from open_webui.env import ( + ENABLE_FORWARD_USER_INFO_HEADERS, +) from fastapi import ( Depends, @@ -66,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) ########################################## -async def send_get_request(url, key=None): +async def send_get_request(url, key=None, user: UserModel = None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + url, + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) as response: return await response.json() except Exception as e: @@ -96,6 +115,7 @@ async def send_post_request( stream: bool = True, key: Optional[str] = None, content_type: Optional[str] = None, + user: UserModel = None ): r = None @@ -110,6 +130,16 @@ async def send_post_request( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, ) r.raise_for_status() @@ -191,7 +221,19 @@ async def verify_connection( try: async with session.get( f"{url}/api/version", - headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) as r: if r.status != 200: detail = f"HTTP Error: {r.status}" @@ -254,7 +296,7 @@ async def update_config( @cached(ttl=3) -async def get_all_models(request: Request): +async def get_all_models(request: Request, user: UserModel=None): log.info("get_all_models()") if request.app.state.config.ENABLE_OLLAMA_API: request_tasks = [] @@ -262,7 +304,7 @@ async def get_all_models(request: Request): if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and ( url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support ): - request_tasks.append(send_get_request(f"{url}/api/tags")) + request_tasks.append(send_get_request(f"{url}/api/tags", user=user)) else: api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(idx), @@ -275,7 +317,7 @@ async def get_all_models(request: Request): key = api_config.get("key", None) if enable: - request_tasks.append(send_get_request(f"{url}/api/tags", key)) + request_tasks.append(send_get_request(f"{url}/api/tags", key, user=user)) else: request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) @@ -360,7 +402,7 @@ async def get_ollama_tags( models = [] if url_idx is None: - models = await get_all_models(request) + models = await get_all_models(request, user=user) else: url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) @@ -370,7 +412,19 @@ async def get_ollama_tags( r = requests.request( method="GET", url=f"{url}/api/tags", - headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) r.raise_for_status() @@ -477,6 +531,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u url, {} ), # Legacy support ).get("key", None), + user=user ) for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS) ] @@ -509,6 +564,7 @@ async def pull_model( url=f"{url}/api/pull", payload=json.dumps(payload), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -527,7 +583,7 @@ async def push_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name in models: @@ -545,6 +601,7 @@ async def push_model( url=f"{url}/api/push", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -571,6 +628,7 @@ async def create_model( url=f"{url}/api/create", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -588,7 +646,7 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.source in models: @@ -609,6 +667,16 @@ async def copy_model( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -643,7 +711,7 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name in models: @@ -665,6 +733,16 @@ async def delete_model( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, ) r.raise_for_status() @@ -693,7 +771,7 @@ async def delete_model( async def show_model_info( request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) ): - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name not in models: @@ -714,6 +792,16 @@ async def show_model_info( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -757,7 +845,7 @@ async def embed( log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -783,6 +871,16 @@ async def embed( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -826,7 +924,7 @@ async def embeddings( log.info(f"generate_ollama_embeddings {form_data}") if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -852,6 +950,16 @@ async def embeddings( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -901,7 +1009,7 @@ async def generate_completion( user=Depends(get_verified_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -931,6 +1039,7 @@ async def generate_completion( url=f"{url}/api/generate", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1047,6 +1156,7 @@ async def generate_chat_completion( stream=form_data.stream, key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), content_type="application/x-ndjson", + user=user, ) @@ -1149,6 +1259,7 @@ async def generate_openai_completion( payload=json.dumps(payload), stream=payload.get("stream", False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1227,6 +1338,7 @@ async def generate_openai_chat_completion( payload=json.dumps(payload), stream=payload.get("stream", False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1240,7 +1352,7 @@ async def get_openai_models( models = [] if url_idx is None: - model_list = await get_all_models(request) + model_list = await get_all_models(request, user=user) models = [ { "id": model["model"], diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index afda36237..f0d5d81dd 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -26,6 +26,7 @@ from open_webui.env import ( ENABLE_FORWARD_USER_INFO_HEADERS, BYPASS_MODEL_ACCESS_CONTROL, ) +from open_webui.models.users import UserModel from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENV, SRC_LOG_LEVELS @@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"]) ########################################## -async def send_get_request(url, key=None): +async def send_get_request(url, key=None, user: UserModel=None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + url, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + } ) as response: return await response.json() except Exception as e: @@ -247,7 +261,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def get_all_models_responses(request: Request) -> list: +async def get_all_models_responses(request: Request, user: UserModel) -> list: if not request.app.state.config.ENABLE_OPENAI_API: return [] @@ -271,7 +285,9 @@ async def get_all_models_responses(request: Request) -> list: ): request_tasks.append( send_get_request( - f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] + f"{url}/models", + request.app.state.config.OPENAI_API_KEYS[idx], + user=user, ) ) else: @@ -291,6 +307,7 @@ async def get_all_models_responses(request: Request) -> list: send_get_request( f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx], + user=user, ) ) else: @@ -352,13 +369,13 @@ async def get_filtered_models(models, user): @cached(ttl=3) -async def get_all_models(request: Request) -> dict[str, list]: +async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: log.info("get_all_models()") if not request.app.state.config.ENABLE_OPENAI_API: return {"data": []} - responses = await get_all_models_responses(request) + responses = await get_all_models_responses(request, user=user) def extract_data(response): if response and "data" in response: @@ -418,7 +435,7 @@ async def get_models( } if url_idx is None: - models = await get_all_models(request) + models = await get_all_models(request, user=user) else: url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] key = request.app.state.config.OPENAI_API_KEYS[url_idx] @@ -515,6 +532,16 @@ async def verify_connection( headers={ "Authorization": f"Bearer {key}", "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), }, ) as r: if r.status != 200: @@ -587,7 +614,7 @@ async def generate_chat_completion( detail="Model not found", ) - await get_all_models(request) + await get_all_models(request, user=user) model = request.app.state.OPENAI_MODELS.get(model_id) if model: idx = model["urlIdx"] diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 253eaedfb..d8f44590b 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -285,7 +285,7 @@ chat_completion = generate_chat_completion async def chat_completed(request: Request, form_data: dict, user: Any): if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { @@ -351,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A raise Exception(f"Action not found: {action_id}") if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 975f8cb09..00f8fd666 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -22,6 +22,7 @@ from open_webui.config import ( ) from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL +from open_webui.models.users import UserModel logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) @@ -29,17 +30,17 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -async def get_all_base_models(request: Request): +async def get_all_base_models(request: Request, user: UserModel=None): function_models = [] openai_models = [] ollama_models = [] if request.app.state.config.ENABLE_OPENAI_API: - openai_models = await openai.get_all_models(request) + openai_models = await openai.get_all_models(request, user=user) openai_models = openai_models["data"] if request.app.state.config.ENABLE_OLLAMA_API: - ollama_models = await ollama.get_all_models(request) + ollama_models = await ollama.get_all_models(request, user=user) ollama_models = [ { "id": model["model"], @@ -58,8 +59,8 @@ async def get_all_base_models(request: Request): return models -async def get_all_models(request): - models = await get_all_base_models(request) +async def get_all_models(request, user: UserModel=None): + models = await get_all_base_models(request, user=user) # If there are no models, return an empty list if len(models) == 0: From 06062568c7f24a6c3a7e76c9070c11941b4018d5 Mon Sep 17 00:00:00 2001 From: Didier FOURNOUT Date: Thu, 13 Feb 2025 16:12:46 +0000 Subject: [PATCH 2/2] black formatting --- backend/open_webui/routers/ollama.py | 10 ++++++---- backend/open_webui/routers/openai.py | 4 ++-- backend/open_webui/utils/models.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index e825848d4..a3d506449 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -115,7 +115,7 @@ async def send_post_request( stream: bool = True, key: Optional[str] = None, content_type: Optional[str] = None, - user: UserModel = None + user: UserModel = None, ): r = None @@ -296,7 +296,7 @@ async def update_config( @cached(ttl=3) -async def get_all_models(request: Request, user: UserModel=None): +async def get_all_models(request: Request, user: UserModel = None): log.info("get_all_models()") if request.app.state.config.ENABLE_OLLAMA_API: request_tasks = [] @@ -317,7 +317,9 @@ async def get_all_models(request: Request, user: UserModel=None): key = api_config.get("key", None) if enable: - request_tasks.append(send_get_request(f"{url}/api/tags", key, user=user)) + request_tasks.append( + send_get_request(f"{url}/api/tags", key, user=user) + ) else: request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) @@ -531,7 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u url, {} ), # Legacy support ).get("key", None), - user=user + user=user, ) for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS) ] diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index f0d5d81dd..1ef913df4 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -52,7 +52,7 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"]) ########################################## -async def send_get_request(url, key=None, user: UserModel=None): +async def send_get_request(url, key=None, user: UserModel = None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: @@ -70,7 +70,7 @@ async def send_get_request(url, key=None, user: UserModel=None): if ENABLE_FORWARD_USER_INFO_HEADERS else {} ), - } + }, ) as response: return await response.json() except Exception as e: diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 00f8fd666..872049f0f 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -30,7 +30,7 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -async def get_all_base_models(request: Request, user: UserModel=None): +async def get_all_base_models(request: Request, user: UserModel = None): function_models = [] openai_models = [] ollama_models = [] @@ -59,7 +59,7 @@ async def get_all_base_models(request: Request, user: UserModel=None): return models -async def get_all_models(request, user: UserModel=None): +async def get_all_models(request, user: UserModel = None): models = await get_all_base_models(request, user=user) # If there are no models, return an empty list