Merge pull request from df-cgdm/main

feat: Add  X-OpenWebUI when forwarding to ollama servers
This commit is contained in:
Timothy Jaeryang Baek 2025-02-24 11:55:04 -08:00 committed by GitHub
commit d8bc3098db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 175 additions and 33 deletions
backend/open_webui

@ -919,7 +919,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models
models = await get_all_models(request)
models = await get_all_models(request, user=user)
# Filter out filter pipelines
models = [
@ -959,7 +959,7 @@ async def chat_completion(
user=Depends(get_verified_user),
):
if not request.app.state.MODELS:
await get_all_models(request)
await get_all_models(request, user=user)
model_item = form_data.pop("model_item", {})
tasks = form_data.pop("background_tasks", None)

@ -14,6 +14,11 @@ from urllib.parse import urlparse
import aiohttp
from aiocache import cached
import requests
from open_webui.models.users import UserModel
from open_webui.env import (
ENABLE_FORWARD_USER_INFO_HEADERS,
)
from fastapi import (
Depends,
@ -66,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
##########################################
async def send_get_request(url, key=None):
async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
url,
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) as response:
return await response.json()
except Exception as e:
@ -96,6 +115,7 @@ async def send_post_request(
stream: bool = True,
key: Optional[str] = None,
content_type: Optional[str] = None,
user: UserModel = None,
):
r = None
@ -110,6 +130,16 @@ async def send_post_request(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@ -191,7 +221,19 @@ async def verify_connection(
try:
async with session.get(
f"{url}/api/version",
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) as r:
if r.status != 200:
detail = f"HTTP Error: {r.status}"
@ -254,7 +296,7 @@ async def update_config(
@cached(ttl=3)
async def get_all_models(request: Request):
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = []
@ -262,7 +304,7 @@ async def get_all_models(request: Request):
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
):
request_tasks.append(send_get_request(f"{url}/api/tags"))
request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
else:
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
@ -275,7 +317,9 @@ async def get_all_models(request: Request):
key = api_config.get("key", None)
if enable:
request_tasks.append(send_get_request(f"{url}/api/tags", key))
request_tasks.append(
send_get_request(f"{url}/api/tags", key, user=user)
)
else:
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
@ -360,7 +404,7 @@ async def get_ollama_tags(
models = []
if url_idx is None:
models = await get_all_models(request)
models = await get_all_models(request, user=user)
else:
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@ -370,7 +414,19 @@ async def get_ollama_tags(
r = requests.request(
method="GET",
url=f"{url}/api/tags",
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@ -477,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
url, {}
), # Legacy support
).get("key", None),
user=user,
)
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
]
@ -509,6 +566,7 @@ async def pull_model(
url=f"{url}/api/pull",
payload=json.dumps(payload),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -527,7 +585,7 @@ async def push_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name in models:
@ -545,6 +603,7 @@ async def push_model(
url=f"{url}/api/push",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -571,6 +630,7 @@ async def create_model(
url=f"{url}/api/create",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -588,7 +648,7 @@ async def copy_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.source in models:
@ -609,6 +669,16 @@ async def copy_model(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@ -643,7 +713,7 @@ async def delete_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name in models:
@ -665,6 +735,16 @@ async def delete_model(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@ -693,7 +773,7 @@ async def delete_model(
async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name not in models:
@ -714,6 +794,16 @@ async def show_model_info(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@ -757,7 +847,7 @@ async def embed(
log.info(f"generate_ollama_batch_embeddings {form_data}")
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@ -783,6 +873,16 @@ async def embed(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@ -826,7 +926,7 @@ async def embeddings(
log.info(f"generate_ollama_embeddings {form_data}")
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@ -852,6 +952,16 @@ async def embeddings(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@ -901,7 +1011,7 @@ async def generate_completion(
user=Depends(get_verified_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@ -931,6 +1041,7 @@ async def generate_completion(
url=f"{url}/api/generate",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -1060,6 +1171,7 @@ async def generate_chat_completion(
stream=form_data.stream,
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
content_type="application/x-ndjson",
user=user,
)
@ -1162,6 +1274,7 @@ async def generate_openai_completion(
payload=json.dumps(payload),
stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -1240,6 +1353,7 @@ async def generate_openai_chat_completion(
payload=json.dumps(payload),
stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -1253,7 +1367,7 @@ async def get_openai_models(
models = []
if url_idx is None:
model_list = await get_all_models(request)
model_list = await get_all_models(request, user=user)
models = [
{
"id": model["model"],

@ -26,6 +26,7 @@ from open_webui.env import (
ENABLE_FORWARD_USER_INFO_HEADERS,
BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.models.users import UserModel
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
##########################################
async def send_get_request(url, key=None):
async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
url,
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) as response:
return await response.json()
except Exception as e:
@ -247,7 +261,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def get_all_models_responses(request: Request) -> list:
async def get_all_models_responses(request: Request, user: UserModel) -> list:
if not request.app.state.config.ENABLE_OPENAI_API:
return []
@ -271,7 +285,9 @@ async def get_all_models_responses(request: Request) -> list:
):
request_tasks.append(
send_get_request(
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
user=user,
)
)
else:
@ -291,6 +307,7 @@ async def get_all_models_responses(request: Request) -> list:
send_get_request(
f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
user=user,
)
)
else:
@ -352,13 +369,13 @@ async def get_filtered_models(models, user):
@cached(ttl=3)
async def get_all_models(request: Request) -> dict[str, list]:
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")
if not request.app.state.config.ENABLE_OPENAI_API:
return {"data": []}
responses = await get_all_models_responses(request)
responses = await get_all_models_responses(request, user=user)
def extract_data(response):
if response and "data" in response:
@ -418,7 +435,7 @@ async def get_models(
}
if url_idx is None:
models = await get_all_models(request)
models = await get_all_models(request, user=user)
else:
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
@ -515,6 +532,16 @@ async def verify_connection(
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) as r:
if r.status != 200:
@ -587,7 +614,7 @@ async def generate_chat_completion(
detail="Model not found",
)
await get_all_models(request)
await get_all_models(request, user=user)
model = request.app.state.OPENAI_MODELS.get(model_id)
if model:
idx = model["urlIdx"]

@ -285,7 +285,7 @@ chat_completion = generate_chat_completion
async def chat_completed(request: Request, form_data: dict, user: Any):
if not request.app.state.MODELS:
await get_all_models(request)
await get_all_models(request, user=user)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
@ -351,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
raise Exception(f"Action not found: {action_id}")
if not request.app.state.MODELS:
await get_all_models(request)
await get_all_models(request, user=user)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {

@ -22,6 +22,7 @@ from open_webui.config import (
)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
from open_webui.models.users import UserModel
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
@ -29,17 +30,17 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def get_all_base_models(request: Request):
async def get_all_base_models(request: Request, user: UserModel = None):
function_models = []
openai_models = []
ollama_models = []
if request.app.state.config.ENABLE_OPENAI_API:
openai_models = await openai.get_all_models(request)
openai_models = await openai.get_all_models(request, user=user)
openai_models = openai_models["data"]
if request.app.state.config.ENABLE_OLLAMA_API:
ollama_models = await ollama.get_all_models(request)
ollama_models = await ollama.get_all_models(request, user=user)
ollama_models = [
{
"id": model["model"],
@ -58,8 +59,8 @@ async def get_all_base_models(request: Request):
return models
async def get_all_models(request):
models = await get_all_base_models(request)
async def get_all_models(request, user: UserModel = None):
models = await get_all_base_models(request, user=user)
# If there are no models, return an empty list
if len(models) == 0: