From 831fe9f509d0f7438369edda9497e761e3494acf Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 6 Aug 2024 10:15:29 +0100 Subject: [PATCH 01/10] cleanup --- backend/apps/ollama/main.py | 63 ++++++++++++++----------------------- 1 file changed, 24 insertions(+), 39 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 442d99ff2..88b0bc9f2 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -1,27 +1,21 @@ from fastapi import ( FastAPI, Request, - Response, HTTPException, Depends, - status, UploadFile, File, - BackgroundTasks, ) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse -from fastapi.concurrency import run_in_threadpool from pydantic import BaseModel, ConfigDict import os import re -import copy import random import requests import json -import uuid import aiohttp import asyncio import logging @@ -32,11 +26,8 @@ from typing import Optional, List, Union from starlette.background import BackgroundTask from apps.webui.models.models import Models -from apps.webui.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( - decode_token, - get_current_user, get_verified_user, get_admin_user, ) @@ -183,7 +174,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True): res = await r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -238,7 +229,7 @@ async def get_all_models(): async def get_ollama_tags( url_idx: Optional[int] = None, user=Depends(get_verified_user) ): - if url_idx == None: + if url_idx is None: models = await get_all_models() if app.state.config.ENABLE_MODEL_FILTER: @@ -269,7 +260,7 @@ async def get_ollama_tags( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -282,8 +273,7 @@ async def get_ollama_tags( @app.get("/api/version/{url_idx}") async def get_ollama_versions(url_idx: Optional[int] = None): if app.state.config.ENABLE_OLLAMA_API: - if url_idx == None: - + if url_idx is None: # returns lowest version tasks = [ fetch_url(f"{url}/api/version") @@ -323,7 +313,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -367,7 +357,7 @@ async def push_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - if url_idx == None: + if url_idx is None: if form_data.name in app.state.MODELS: url_idx = app.state.MODELS[form_data.name]["urls"][0] else: @@ -417,7 +407,7 @@ async def copy_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - if url_idx == None: + if url_idx is None: if form_data.source in app.state.MODELS: url_idx = app.state.MODELS[form_data.source]["urls"][0] else: @@ -448,7 +438,7 @@ async def copy_model( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -464,7 +454,7 @@ async def delete_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - if url_idx == None: + if url_idx is None: if form_data.name in app.state.MODELS: url_idx = app.state.MODELS[form_data.name]["urls"][0] else: @@ -495,7 +485,7 @@ async def delete_model( res = r.json() if "error" in res: error_detail 
= f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -533,7 +523,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -556,7 +546,7 @@ async def generate_embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - if url_idx == None: + if url_idx is None: model = form_data.model if ":" not in model: @@ -590,7 +580,7 @@ async def generate_embeddings( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -603,10 +593,9 @@ def generate_ollama_embeddings( form_data: GenerateEmbeddingsForm, url_idx: Optional[int] = None, ): - log.info(f"generate_ollama_embeddings {form_data}") - if url_idx == None: + if url_idx is None: model = form_data.model if ":" not in model: @@ -638,7 +627,7 @@ def generate_ollama_embeddings( if "embedding" in data: return data["embedding"] else: - raise "Something went wrong :/" + raise Exception("Something went wrong :/") except Exception as e: log.exception(e) error_detail = "Open WebUI: Server Connection Error" @@ -647,10 +636,10 @@ def generate_ollama_embeddings( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" - raise error_detail + raise Exception(error_detail) class GenerateCompletionForm(BaseModel): @@ -674,8 +663,7 @@ async def generate_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - - if url_idx == None: + if url_idx is None: model = form_data.model if ":" not in model: @@ -720,7 +708,6 @@ async def generate_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - log.debug( "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( form_data.model_dump_json(exclude_none=True).encode() @@ -906,7 +893,7 @@ async def generate_chat_completion( system, payload["messages"] ) - if url_idx == None: + if url_idx is None: if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" @@ -1016,7 +1003,7 @@ async def generate_openai_chat_completion( }, ) - if url_idx == None: + if url_idx is None: if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" @@ -1044,7 +1031,7 @@ async def get_openai_models( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - if url_idx == None: + if url_idx is None: models = await get_all_models() if app.state.config.ENABLE_MODEL_FILTER: @@ -1099,7 +1086,7 @@ async def get_openai_models( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -1125,7 +1112,6 @@ def parse_huggingface_url(hf_url): path_components = parsed_url.path.split("/") # Extract the desired output - user_repo = "/".join(path_components[1:3]) model_file = path_components[-1] return model_file @@ -1190,7 +1176,6 @@ async def download_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - allowed_hosts = ["https://huggingface.co/", "https://github.com/"] if not any(form_data.url.startswith(host) for host in allowed_hosts): @@ -1199,7 +1184,7 @@ async def download_model( detail="Invalid file_url. 
Only URLs from allowed hosts are permitted.", ) - if url_idx == None: + if url_idx is None: url_idx = 0 url = app.state.config.OLLAMA_BASE_URLS[url_idx] @@ -1222,7 +1207,7 @@ def upload_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - if url_idx == None: + if url_idx is None: url_idx = 0 ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] From 44c781f414007a7968d7c38560668a457c0f1900 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 6 Aug 2024 10:50:22 +0100 Subject: [PATCH 02/10] cleanup --- backend/apps/ollama/main.py | 173 +++++++++++++++--------------------- 1 file changed, 70 insertions(+), 103 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 88b0bc9f2..f1544c80b 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -336,8 +336,6 @@ async def pull_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - r = None - # Admin should be able to pull models from any source payload = {**form_data.model_dump(exclude_none=True), "insecure": True} @@ -418,13 +416,13 @@ async def copy_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + r = requests.request( + method="POST", + url=f"{url}/api/copy", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) try: - r = requests.request( - method="POST", - url=f"{url}/api/copy", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) r.raise_for_status() log.debug(f"r.text: {r.text}") @@ -466,12 +464,12 @@ async def delete_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + r = requests.request( + method="DELETE", + url=f"{url}/api/delete", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) try: - r = requests.request( - method="DELETE", - url=f"{url}/api/delete", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) r.raise_for_status() log.debug(f"r.text: {r.text}") @@ -506,12 +504,12 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + r = requests.request( + method="POST", + url=f"{url}/api/show", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) try: - r = requests.request( - method="POST", - url=f"{url}/api/show", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) r.raise_for_status() return r.json() @@ -563,12 +561,12 @@ async def generate_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) try: - r = requests.request( - method="POST", - url=f"{url}/api/embeddings", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) r.raise_for_status() return r.json() @@ -612,12 +610,12 @@ def generate_ollama_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) try: - r = requests.request( - method="POST", - url=f"{url}/api/embeddings", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) r.raise_for_status() data = r.json() @@ -727,152 +725,121 @@ async def generate_chat_completion( if model_info.base_model_id: payload["model"] = model_info.base_model_id - model_info.params = model_info.params.model_dump() + params = 
model_info.params.model_dump() - if model_info.params: + if params: if payload.get("options") is None: payload["options"] = {} if ( - model_info.params.get("mirostat", None) + params.get("mirostat", None) and payload["options"].get("mirostat") is None ): - payload["options"]["mirostat"] = model_info.params.get("mirostat", None) + payload["options"]["mirostat"] = params.get("mirostat", None) if ( - model_info.params.get("mirostat_eta", None) + params.get("mirostat_eta", None) and payload["options"].get("mirostat_eta") is None ): - payload["options"]["mirostat_eta"] = model_info.params.get( - "mirostat_eta", None - ) + payload["options"]["mirostat_eta"] = params.get("mirostat_eta", None) if ( - model_info.params.get("mirostat_tau", None) + params.get("mirostat_tau", None) and payload["options"].get("mirostat_tau") is None ): - payload["options"]["mirostat_tau"] = model_info.params.get( - "mirostat_tau", None - ) + payload["options"]["mirostat_tau"] = params.get("mirostat_tau", None) if ( - model_info.params.get("num_ctx", None) + params.get("num_ctx", None) and payload["options"].get("num_ctx") is None ): - payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None) + payload["options"]["num_ctx"] = params.get("num_ctx", None) if ( - model_info.params.get("num_batch", None) + params.get("num_batch", None) and payload["options"].get("num_batch") is None ): - payload["options"]["num_batch"] = model_info.params.get( - "num_batch", None - ) + payload["options"]["num_batch"] = params.get("num_batch", None) if ( - model_info.params.get("num_keep", None) + params.get("num_keep", None) and payload["options"].get("num_keep") is None ): - payload["options"]["num_keep"] = model_info.params.get("num_keep", None) + payload["options"]["num_keep"] = params.get("num_keep", None) if ( - model_info.params.get("repeat_last_n", None) + params.get("repeat_last_n", None) and payload["options"].get("repeat_last_n") is None ): - payload["options"]["repeat_last_n"] = model_info.params.get( - "repeat_last_n", None - ) + payload["options"]["repeat_last_n"] = params.get("repeat_last_n", None) if ( - model_info.params.get("frequency_penalty", None) + params.get("frequency_penalty", None) and payload["options"].get("frequency_penalty") is None ): - payload["options"]["repeat_penalty"] = model_info.params.get( + payload["options"]["repeat_penalty"] = params.get( "frequency_penalty", None ) if ( - model_info.params.get("temperature", None) is not None + params.get("temperature", None) is not None and payload["options"].get("temperature") is None ): - payload["options"]["temperature"] = model_info.params.get( - "temperature", None - ) + payload["options"]["temperature"] = params.get("temperature", None) if ( - model_info.params.get("seed", None) is not None + params.get("seed", None) is not None and payload["options"].get("seed") is None ): - payload["options"]["seed"] = model_info.params.get("seed", None) + payload["options"]["seed"] = params.get("seed", None) - if ( - model_info.params.get("stop", None) - and payload["options"].get("stop") is None - ): + if params.get("stop", None) and payload["options"].get("stop") is None: payload["options"]["stop"] = ( [ bytes(stop, "utf-8").decode("unicode_escape") - for stop in model_info.params["stop"] + for stop in params["stop"] ] - if model_info.params.get("stop", None) + if params.get("stop", None) else None ) - if ( - model_info.params.get("tfs_z", None) - and payload["options"].get("tfs_z") is None - ): - payload["options"]["tfs_z"] = model_info.params.get("tfs_z", 
None) + if params.get("tfs_z", None) and payload["options"].get("tfs_z") is None: + payload["options"]["tfs_z"] = params.get("tfs_z", None) if ( - model_info.params.get("max_tokens", None) + params.get("max_tokens", None) and payload["options"].get("max_tokens") is None ): - payload["options"]["num_predict"] = model_info.params.get( - "max_tokens", None - ) + payload["options"]["num_predict"] = params.get("max_tokens", None) + + if params.get("top_k", None) and payload["options"].get("top_k") is None: + payload["options"]["top_k"] = params.get("top_k", None) + + if params.get("top_p", None) and payload["options"].get("top_p") is None: + payload["options"]["top_p"] = params.get("top_p", None) + + if params.get("min_p", None) and payload["options"].get("min_p") is None: + payload["options"]["min_p"] = params.get("min_p", None) if ( - model_info.params.get("top_k", None) - and payload["options"].get("top_k") is None - ): - payload["options"]["top_k"] = model_info.params.get("top_k", None) - - if ( - model_info.params.get("top_p", None) - and payload["options"].get("top_p") is None - ): - payload["options"]["top_p"] = model_info.params.get("top_p", None) - - if ( - model_info.params.get("min_p", None) - and payload["options"].get("min_p") is None - ): - payload["options"]["min_p"] = model_info.params.get("min_p", None) - - if ( - model_info.params.get("use_mmap", None) + params.get("use_mmap", None) and payload["options"].get("use_mmap") is None ): - payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None) + payload["options"]["use_mmap"] = params.get("use_mmap", None) if ( - model_info.params.get("use_mlock", None) + params.get("use_mlock", None) and payload["options"].get("use_mlock") is None ): - payload["options"]["use_mlock"] = model_info.params.get( - "use_mlock", None - ) + payload["options"]["use_mlock"] = params.get("use_mlock", None) if ( - model_info.params.get("num_thread", None) + params.get("num_thread", None) and payload["options"].get("num_thread") is None ): - payload["options"]["num_thread"] = model_info.params.get( - "num_thread", None - ) + payload["options"]["num_thread"] = params.get("num_thread", None) - system = model_info.params.get("system", None) + system = params.get("system", None) if system: system = prompt_template( system, From fc31267a54c4ba20a4a0252b05f46d878791eb55 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 6 Aug 2024 11:31:45 +0100 Subject: [PATCH 03/10] refac: re-use utils.misc --- backend/apps/ollama/main.py | 252 ++++++------------------------------ backend/apps/openai/main.py | 7 +- backend/apps/webui/main.py | 4 +- backend/utils/misc.py | 51 ++++++-- 4 files changed, 85 insertions(+), 229 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index f1544c80b..19d914c4b 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -44,7 +44,13 @@ from config import ( UPLOAD_DIR, AppConfig, ) -from utils.misc import calculate_sha256, add_or_update_system_message +from utils.misc import ( + apply_model_params_to_body_ollama, + calculate_sha256, + add_or_update_system_message, + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) @@ -699,6 +705,18 @@ class GenerateChatCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None +def get_ollama_url(url_idx: Optional[int], model: str): + if url_idx is None: + if model not in app.state.MODELS: + raise HTTPException( + 
status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), + ) + url_idx = random.choice(app.state.MODELS[model]["urls"]) + url = app.state.config.OLLAMA_BASE_URLS[url_idx] + return url + + @app.post("/api/chat") @app.post("/api/chat/{url_idx}") async def generate_chat_completion( @@ -706,17 +724,12 @@ async def generate_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - log.debug( - "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( - form_data.model_dump_json(exclude_none=True).encode() - ) - ) + log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=") payload = { **form_data.model_dump(exclude_none=True, exclude=["metadata"]), } - if "metadata" in payload: - del payload["metadata"] + payload.pop("metadata") model_id = form_data.model model_info = Models.get_model_by_id(model_id) @@ -731,148 +744,15 @@ async def generate_chat_completion( if payload.get("options") is None: payload["options"] = {} - if ( - params.get("mirostat", None) - and payload["options"].get("mirostat") is None - ): - payload["options"]["mirostat"] = params.get("mirostat", None) - - if ( - params.get("mirostat_eta", None) - and payload["options"].get("mirostat_eta") is None - ): - payload["options"]["mirostat_eta"] = params.get("mirostat_eta", None) - - if ( - params.get("mirostat_tau", None) - and payload["options"].get("mirostat_tau") is None - ): - payload["options"]["mirostat_tau"] = params.get("mirostat_tau", None) - - if ( - params.get("num_ctx", None) - and payload["options"].get("num_ctx") is None - ): - payload["options"]["num_ctx"] = params.get("num_ctx", None) - - if ( - params.get("num_batch", None) - and payload["options"].get("num_batch") is None - ): - payload["options"]["num_batch"] = params.get("num_batch", None) - - if ( - params.get("num_keep", None) - and payload["options"].get("num_keep") is None - ): - payload["options"]["num_keep"] = params.get("num_keep", None) - - if ( - params.get("repeat_last_n", None) - and payload["options"].get("repeat_last_n") is None - ): - payload["options"]["repeat_last_n"] = params.get("repeat_last_n", None) - - if ( - params.get("frequency_penalty", None) - and payload["options"].get("frequency_penalty") is None - ): - payload["options"]["repeat_penalty"] = params.get( - "frequency_penalty", None - ) - - if ( - params.get("temperature", None) is not None - and payload["options"].get("temperature") is None - ): - payload["options"]["temperature"] = params.get("temperature", None) - - if ( - params.get("seed", None) is not None - and payload["options"].get("seed") is None - ): - payload["options"]["seed"] = params.get("seed", None) - - if params.get("stop", None) and payload["options"].get("stop") is None: - payload["options"]["stop"] = ( - [ - bytes(stop, "utf-8").decode("unicode_escape") - for stop in params["stop"] - ] - if params.get("stop", None) - else None - ) - - if params.get("tfs_z", None) and payload["options"].get("tfs_z") is None: - payload["options"]["tfs_z"] = params.get("tfs_z", None) - - if ( - params.get("max_tokens", None) - and payload["options"].get("max_tokens") is None - ): - payload["options"]["num_predict"] = params.get("max_tokens", None) - - if params.get("top_k", None) and payload["options"].get("top_k") is None: - payload["options"]["top_k"] = params.get("top_k", None) - - if params.get("top_p", None) and payload["options"].get("top_p") is None: - payload["options"]["top_p"] = params.get("top_p", None) - - if params.get("min_p", None) and 
payload["options"].get("min_p") is None: - payload["options"]["min_p"] = params.get("min_p", None) - - if ( - params.get("use_mmap", None) - and payload["options"].get("use_mmap") is None - ): - payload["options"]["use_mmap"] = params.get("use_mmap", None) - - if ( - params.get("use_mlock", None) - and payload["options"].get("use_mlock") is None - ): - payload["options"]["use_mlock"] = params.get("use_mlock", None) - - if ( - params.get("num_thread", None) - and payload["options"].get("num_thread") is None - ): - payload["options"]["num_thread"] = params.get("num_thread", None) - - system = params.get("system", None) - if system: - system = prompt_template( - system, - **( - { - "user_name": user.name, - "user_location": ( - user.info.get("location") if user.info else None - ), - } - if user - else {} - ), + payload["options"] = apply_model_params_to_body_ollama( + params, payload["options"] ) + payload = apply_model_system_prompt_to_body(params, payload, user) - if payload.get("messages"): - payload["messages"] = add_or_update_system_message( - system, payload["messages"] - ) + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if url_idx is None: - if ":" not in payload["model"]: - payload["model"] = f"{payload['model']}:latest" - - if payload["model"] in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") log.debug(payload) @@ -906,83 +786,27 @@ async def generate_openai_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - form_data = OpenAIChatCompletionForm(**form_data) - payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} + completion_form = OpenAIChatCompletionForm(**form_data) + payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])} + payload.pop("metadata") - if "metadata" in payload: - del payload["metadata"] - - model_id = form_data.model + model_id = completion_form.model model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id - model_info.params = model_info.params.model_dump() + params = model_info.params.model_dump() - if model_info.params: - payload["temperature"] = model_info.params.get("temperature", None) - payload["top_p"] = model_info.params.get("top_p", None) - payload["max_tokens"] = model_info.params.get("max_tokens", None) - payload["frequency_penalty"] = model_info.params.get( - "frequency_penalty", None - ) - payload["seed"] = model_info.params.get("seed", None) - payload["stop"] = ( - [ - bytes(stop, "utf-8").decode("unicode_escape") - for stop in model_info.params["stop"] - ] - if model_info.params.get("stop", None) - else None - ) + if params: + payload = apply_model_params_to_body_openai(params, payload) + payload = apply_model_system_prompt_to_body(params, payload, user) - system = model_info.params.get("system", None) + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if system: - system = prompt_template( - system, - **( - { - "user_name": user.name, - "user_location": ( - user.info.get("location") if user.info else None - ), - } - if user - else {} - ), - ) - # Check if the payload already has a system message - # If not, add a system message to the payload - 
if payload.get("messages"): - for message in payload["messages"]: - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - payload["messages"].insert( - 0, - { - "role": "system", - "content": system, - }, - ) - - if url_idx is None: - if ":" not in payload["model"]: - payload["model"] = f"{payload['model']}:latest" - - if payload["model"] in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") return await post_streaming_url( diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index a0d8f3750..1313d2091 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -17,7 +17,10 @@ from utils.utils import ( get_verified_user, get_admin_user, ) -from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body +from utils.misc import ( + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) from config import ( SRC_LOG_LEVELS, @@ -366,7 +369,7 @@ async def generate_chat_completion( payload["model"] = model_info.base_model_id params = model_info.params.model_dump() - payload = apply_model_params_to_body(params, payload) + payload = apply_model_params_to_body_openai(params, payload) payload = apply_model_system_prompt_to_body(params, payload, user) model = app.state.MODELS[payload.get("model")] diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index a0b9f5008..6848fdd4d 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id from utils.misc import ( openai_chat_chunk_message_template, openai_chat_completion_message_template, - apply_model_params_to_body, + apply_model_params_to_body_openai, apply_model_system_prompt_to_body, ) @@ -289,7 +289,7 @@ async def generate_function_chat_completion(form_data, user): form_data["model"] = model_info.base_model_id params = model_info.params.model_dump() - form_data = apply_model_params_to_body(params, form_data) + form_data = apply_model_params_to_body_openai(params, form_data) form_data = apply_model_system_prompt_to_body(params, form_data, user) pipe_id = get_pipe_id(form_data) diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 25dd4dd5b..ffe6a6e53 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -2,7 +2,7 @@ from pathlib import Path import hashlib import re from datetime import timedelta -from typing import Optional, List, Tuple +from typing import Optional, List, Tuple, Callable import uuid import time @@ -135,19 +135,12 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di # inplace function: form_data is modified -def apply_model_params_to_body(params: dict, form_data: dict) -> dict: +def apply_model_params_to_body( + params: dict, form_data: dict, mappings: dict[str, Callable] +) -> dict: if not params: return form_data - mappings = { - "temperature": float, - "top_p": int, - "max_tokens": int, - "frequency_penalty": int, - "seed": lambda x: x, - "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], - } - for key, cast_func in mappings.items(): if (value := params.get(key)) is not None: form_data[key] = cast_func(value) @@ -155,6 +148,42 @@ def 
apply_model_params_to_body(params: dict, form_data: dict) -> dict: return form_data +OPENAI_MAPPINGS = { + "temperature": float, + "top_p": int, + "max_tokens": int, + "frequency_penalty": int, + "seed": lambda x: x, + "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], +} + + +# inplace function: form_data is modified +def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: + return apply_model_params_to_body(params, form_data, OPENAI_MAPPINGS) + + +def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: + opts = [ + "mirostat", + "mirostat_eta", + "mirostat_tau", + "num_ctx", + "num_batch", + "num_keep", + "repeat_last_n", + "tfs_z", + "top_k", + "min_p", + "use_mmap", + "use_mlock", + "num_thread", + ] + mappings = {i: lambda x: x for i in opts} + mappings = {**mappings, **OPENAI_MAPPINGS} + return apply_model_params_to_body(params, form_data, mappings) + + def get_gravatar_url(email): # Trim leading and trailing whitespace from # an email address and force all characters From ed205d82e8799a06ea9db4e6334228d99fad6678 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 6 Aug 2024 12:25:00 +0100 Subject: [PATCH 04/10] fix: pop --- backend/apps/ollama/main.py | 4 ++-- backend/apps/openai/main.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 19d914c4b..79a2773ba 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -729,7 +729,7 @@ async def generate_chat_completion( payload = { **form_data.model_dump(exclude_none=True, exclude=["metadata"]), } - payload.pop("metadata") + payload.pop("metadata", None) model_id = form_data.model model_info = Models.get_model_by_id(model_id) @@ -788,7 +788,7 @@ async def generate_openai_chat_completion( ): completion_form = OpenAIChatCompletionForm(**form_data) payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])} - payload.pop("metadata") + payload.pop("metadata", None) model_id = completion_form.model model_info = Models.get_model_by_id(model_id) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 1313d2091..44b3151e0 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -359,7 +359,7 @@ async def generate_chat_completion( ): idx = 0 payload = {**form_data} - payload.pop("metadata") + payload.pop("metadata", None) model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) From e6bbce439d81443cf37da9573df1fc9bc4d813e8 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 8 Aug 2024 10:52:09 +0100 Subject: [PATCH 05/10] fix: repeat_penalty --- backend/apps/ollama/main.py | 3 --- backend/utils/misc.py | 7 ++++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 79a2773ba..3d86b852a 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -31,8 +31,6 @@ from utils.utils import ( get_verified_user, get_admin_user, ) -from utils.task import prompt_template - from config import ( SRC_LOG_LEVELS, @@ -47,7 +45,6 @@ from config import ( from utils.misc import ( apply_model_params_to_body_ollama, calculate_sha256, - add_or_update_system_message, apply_model_params_to_body_openai, apply_model_system_prompt_to_body, ) diff --git a/backend/utils/misc.py b/backend/utils/misc.py index ffe6a6e53..993aa9f60 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -181,7 +181,12 @@ def 
apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: ] mappings = {i: lambda x: x for i in opts} mappings = {**mappings, **OPENAI_MAPPINGS} - return apply_model_params_to_body(params, form_data, mappings) + form_data = apply_model_params_to_body(params, form_data, mappings) + + # only param that changes name + if (param := params.get("frequency_penalty", None)) is not None: + form_data["repeat_penalty"] = param + return form_data def get_gravatar_url(email): From 8cdf9814bde8049a31c09961263f3ef91d4607c0 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 8 Aug 2024 11:01:00 +0100 Subject: [PATCH 06/10] fix: name differences --- backend/utils/misc.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 993aa9f60..3dc1cf7ee 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -148,23 +148,24 @@ def apply_model_params_to_body( return form_data -OPENAI_MAPPINGS = { - "temperature": float, - "top_p": int, - "max_tokens": int, - "frequency_penalty": int, - "seed": lambda x: x, - "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], -} - - # inplace function: form_data is modified def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: - return apply_model_params_to_body(params, form_data, OPENAI_MAPPINGS) + mappings = { + "temperature": float, + "top_p": int, + "max_tokens": int, + "frequency_penalty": int, + "seed": lambda x: x, + "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], + } + return apply_model_params_to_body(params, form_data, mappings) def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: opts = [ + "temperature", + "top_p", + "seed", "mirostat", "mirostat_eta", "mirostat_tau", @@ -180,12 +181,18 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: "num_thread", ] mappings = {i: lambda x: x for i in opts} - mappings = {**mappings, **OPENAI_MAPPINGS} form_data = apply_model_params_to_body(params, form_data, mappings) - # only param that changes name - if (param := params.get("frequency_penalty", None)) is not None: - form_data["repeat_penalty"] = param + name_differences = { + "max_tokens": "num_predict", + "frequency_penalty": "repeat_penalty", + } + + for key, value in name_differences.items(): + if (param := params.get(key, None)) is not None: + form_data[value] = param + + print(form_data) return form_data From a725801e5559956f33291706b89af774cbefb472 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 8 Aug 2024 11:30:13 +0100 Subject: [PATCH 07/10] fix: formatting test errors, remove print, merge dev --- backend/apps/images/main.py | 2 +- backend/apps/images/utils/comfyui.py | 11 ++++++----- backend/apps/openai/main.py | 4 ++-- backend/utils/misc.py | 1 - 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 4239f3f45..a418f2693 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -514,7 +514,7 @@ async def image_generations( data = ImageGenerationPayload(**data) - res = comfyui_generate_image( + res = await comfyui_generate_image( app.state.config.MODEL, data, user.id, diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index 6c37f0c49..ec0f8c59e 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -1,3 +1,4 @@ +import asyncio import 
websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) import uuid import json @@ -328,7 +329,7 @@ class ImageGenerationPayload(BaseModel): flux_fp8_clip: Optional[bool] = None -def comfyui_generate_image( +async def comfyui_generate_image( model: str, payload: ImageGenerationPayload, client_id, base_url ): ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") @@ -377,9 +378,9 @@ def comfyui_generate_image( comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype if payload.flux_fp8_clip: - comfyui_prompt["11"]["inputs"][ - "clip_name2" - ] = "t5xxl_fp8_e4m3fn.safetensors" + comfyui_prompt["11"]["inputs"]["clip_name2"] = ( + "t5xxl_fp8_e4m3fn.safetensors" + ) comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n comfyui_prompt["5"]["inputs"]["width"] = payload.width @@ -397,7 +398,7 @@ def comfyui_generate_image( return None try: - images = get_images(ws, comfyui_prompt, client_id, base_url) + images = await asyncio.to_thread(get_images, ws, comfyui_prompt, client_id, base_url) except Exception as e: log.exception(f"Error while receiving images: {e}") images = None diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 831da783b..50de53a53 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -359,10 +359,10 @@ async def generate_chat_completion( ): idx = 0 payload = {**form_data} - + if "metadata" in payload: del payload["metadata"] - + model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 3dc1cf7ee..9de19d3f6 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -192,7 +192,6 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: if (param := params.get(key, None)) is not None: form_data[value] = param - print(form_data) return form_data From 309cd645f13de63d36650293f1424195955168d5 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 8 Aug 2024 12:30:07 +0100 Subject: [PATCH 08/10] undo del --- backend/apps/ollama/main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 3d86b852a..d5ef82942 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -726,7 +726,8 @@ async def generate_chat_completion( payload = { **form_data.model_dump(exclude_none=True, exclude=["metadata"]), } - payload.pop("metadata", None) + if "metadata" in payload: + del payload["metadata"] model_id = form_data.model model_info = Models.get_model_by_id(model_id) @@ -785,7 +786,8 @@ async def generate_openai_chat_completion( ): completion_form = OpenAIChatCompletionForm(**form_data) payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])} - payload.pop("metadata", None) + if "metadata" in payload: + del payload["metadata"] model_id = completion_form.model model_info = Models.get_model_by_id(model_id) From fa4d1d42a53bd6df6e62467bc825c866cab23451 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 8 Aug 2024 12:41:41 +0100 Subject: [PATCH 09/10] fix: backend format test --- backend/apps/images/utils/comfyui.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index ec0f8c59e..ab6f4e407 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -1,6 +1,5 @@ import asyncio import websocket # NOTE: 
websocket-client (https://github.com/websocket-client/websocket-client) -import uuid import json import urllib.request import urllib.parse @@ -398,7 +397,9 @@ async def comfyui_generate_image( return None try: - images = await asyncio.to_thread(get_images, ws, comfyui_prompt, client_id, base_url) + images = await asyncio.to_thread( + get_images, ws, comfyui_prompt, client_id, base_url + ) except Exception as e: log.exception(f"Error while receiving images: {e}") images = None From 204a4fbe7a3cf77f83f5b310fdd75b9853f5361d Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 8 Aug 2024 12:45:23 +0100 Subject: [PATCH 10/10] fix: backend format test --- backend/apps/images/utils/comfyui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index ab6f4e407..94875d959 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -377,9 +377,9 @@ async def comfyui_generate_image( comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype if payload.flux_fp8_clip: - comfyui_prompt["11"]["inputs"]["clip_name2"] = ( - "t5xxl_fp8_e4m3fn.safetensors" - ) + comfyui_prompt["11"]["inputs"][ + "clip_name2" + ] = "t5xxl_fp8_e4m3fn.safetensors" comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n comfyui_prompt["5"]["inputs"]["width"] = payload.width
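
Taken together, patches 03 through 06 replace the long per-parameter if-chains in both chat routes with a single table-driven helper. The sketch below is a self-contained condensation of the final state of backend/utils/misc.py in this series (the Ollama option list is abridged, and the sample params are illustrative, not taken from the diff):

from typing import Callable


def apply_model_params_to_body(
    params: dict, form_data: dict, mappings: dict[str, Callable]
) -> dict:
    # Shared core (patch 03): copy each known param into form_data,
    # passing it through its cast function.
    if not params:
        return form_data
    for key, cast_func in mappings.items():
        if (value := params.get(key)) is not None:
            form_data[key] = cast_func(value)
    return form_data


def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
    # Ollama variant (patches 05-06): most options pass through unchanged...
    opts = ["temperature", "top_p", "seed", "mirostat", "num_ctx", "top_k", "min_p"]  # abridged
    form_data = apply_model_params_to_body(
        params, form_data, {opt: lambda x: x for opt in opts}
    )
    # ...and the two whose Ollama option names differ are renamed afterwards.
    name_differences = {
        "max_tokens": "num_predict",
        "frequency_penalty": "repeat_penalty",
    }
    for key, renamed in name_differences.items():
        if (param := params.get(key, None)) is not None:
            form_data[renamed] = param
    return form_data


params = {"temperature": 0.2, "max_tokens": 128, "frequency_penalty": 1.1}
print(apply_model_params_to_body_ollama(params, {}))
# {'temperature': 0.2, 'num_predict': 128, 'repeat_penalty': 1.1}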
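
The companion helper apply_model_system_prompt_to_body absorbs the system-prompt handling that patches 02 and 03 delete from both chat routes. Its body is not part of this diff, so the following usage sketch only shows the contract implied by the removed inline code; the import path assumes the repository context, and with user=None the template receives no name/location substitutions:

from utils.misc import apply_model_system_prompt_to_body

payload = {"model": "llama3", "messages": [{"role": "user", "content": "hi"}]}
params = {"system": "You are concise."}

payload = apply_model_system_prompt_to_body(params, payload, None)
# Per the logic removed in patch 03: no system message exists yet, so one
# is inserted at the front of payload["messages"].
assert payload["messages"][0]["role"] == "system"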
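
Patch 04 ("fix: pop") exists because the cleanup in patch 02 dropped the membership check before removing the metadata key; patch 08 ("undo del") then restores the original if/del form in the Ollama routes. Both guarded forms are equivalent, and the failure mode they avoid is a one-liner:

payload = {"model": "llama3:latest"}  # request carrying no "metadata" key

payload.pop("metadata", None)         # patch 04: safe no-op
# payload.pop("metadata")             # patch 02: raises KeyError
# del payload["metadata"]             # also raises KeyError if unguarded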
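
Patch 07 also turns comfyui_generate_image into a coroutine and moves the blocking websocket receive loop onto a worker thread via asyncio.to_thread, so a long ComfyUI render no longer stalls the FastAPI event loop. A minimal sketch of that pattern, with an illustrative stand-in for the blocking call:

import asyncio
import time


def get_images_blocking() -> dict:
    # Stand-in for comfyui.get_images(ws, comfyui_prompt, client_id, base_url),
    # which blocks on the websocket until the workflow finishes.
    time.sleep(2)
    return {"9": ["<png bytes>"]}


async def generate() -> dict:
    # asyncio.to_thread (Python 3.9+) runs the function in the default
    # executor and awaits the result without blocking the event loop.
    return await asyncio.to_thread(get_images_blocking)


print(asyncio.run(generate()))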