diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index 5a85d0879..438666133 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -32,7 +32,7 @@ assignees: ''
 
 **Confirmation:**
 
 - [ ] I have read and followed all the instructions provided in the README.md.
-- [ ] I have reviewed the troubleshooting.md document.
+- [ ] I am on the latest version of both Open WebUI and Ollama.
 - [ ] I have included the browser console logs.
 - [ ] I have included the Docker container logs.
diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 7965ff325..d4d1e91a6 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -3,16 +3,23 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 
+from pydantic import BaseModel, ConfigDict
+
+import random
 import requests
 import json
 import uuid
-from pydantic import BaseModel
+import aiohttp
+import asyncio
 
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user, get_admin_user
 from config import OLLAMA_BASE_URL, WEBUI_AUTH
 
+from typing import Optional, List, Union
+
+
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
@@ -23,26 +30,44 @@ app.add_middleware(
 )
 
 app.state.OLLAMA_BASE_URL = OLLAMA_BASE_URL
-
-# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
+app.state.OLLAMA_BASE_URLS = [OLLAMA_BASE_URL]
+app.state.MODELS = {}
 
 REQUEST_POOL = []
 
 
-@app.get("/url")
-async def get_ollama_api_url(user=Depends(get_admin_user)):
-    return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL}
+# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
+# The current implementation picks a backend at random (random.choice). Consider incorporating algorithms like weighted round-robin,
+# least connections, or least response time for better resource utilization and performance optimization.
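# A minimal sketch of the "least connections" idea mentioned in the TODO above; it is
# not part of this patch. The idea: track in-flight requests per backend index and pick
# the least busy candidate instead of calling random.choice. ACTIVE_REQUESTS,
# pick_least_busy, and release_backend are hypothetical helper names introduced here
# purely for illustration.
from collections import defaultdict
from typing import Dict, List

ACTIVE_REQUESTS: Dict[int, int] = defaultdict(int)  # backend index -> in-flight request count


def pick_least_busy(candidate_idxs: List[int]) -> int:
    # Choose the candidate backend currently serving the fewest requests,
    # e.g. candidate_idxs = app.state.MODELS[model_name]["urls"].
    idx = min(candidate_idxs, key=lambda i: ACTIVE_REQUESTS[i])
    ACTIVE_REQUESTS[idx] += 1
    return idx


def release_backend(idx: int) -> None:
    # Call once the proxied request finishes (e.g. in a finally block of stream_content),
    # so the counter reflects only in-flight requests.
    ACTIVE_REQUESTS[idx] = max(0, ACTIVE_REQUESTS[idx] - 1)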
+ + +@app.middleware("http") +async def check_url(request: Request, call_next): + if len(app.state.MODELS) == 0: + await get_all_models() + else: + pass + + response = await call_next(request) + return response + + +@app.get("/urls") +async def get_ollama_api_urls(user=Depends(get_admin_user)): + return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} class UrlUpdateForm(BaseModel): - url: str + urls: List[str] -@app.post("/url/update") +@app.post("/urls/update") async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.OLLAMA_BASE_URL = form_data.url - return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL} + app.state.OLLAMA_BASE_URLS = form_data.urls + + print(app.state.OLLAMA_BASE_URLS) + return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} @app.get("/cancel/{request_id}") @@ -55,9 +80,806 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +async def fetch_url(url): + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + return await response.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + return None + + +def merge_models_lists(model_lists): + merged_models = {} + + for idx, model_list in enumerate(model_lists): + for model in model_list: + digest = model["digest"] + if digest not in merged_models: + model["urls"] = [idx] + merged_models[digest] = model + else: + merged_models[digest]["urls"].append(idx) + + return list(merged_models.values()) + + +# user=Depends(get_current_user) + + +async def get_all_models(): + print("get_all_models") + tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] + responses = await asyncio.gather(*tasks) + responses = list(filter(lambda x: x is not None, responses)) + + models = { + "models": merge_models_lists( + map(lambda response: response["models"], responses) + ) + } + app.state.MODELS = {model["model"]: model for model in models["models"]} + + return models + + +@app.get("/api/tags") +@app.get("/api/tags/{url_idx}") +async def get_ollama_tags( + url_idx: Optional[int] = None, user=Depends(get_current_user) +): + + if url_idx == None: + return await get_all_models() + else: + url = app.state.OLLAMA_BASE_URLS[url_idx] + try: + r = requests.request(method="GET", url=f"{url}/api/tags") + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +@app.get("/api/version") +@app.get("/api/version/{url_idx}") +async def get_ollama_versions(url_idx: Optional[int] = None): + + if url_idx == None: + + # returns lowest version + tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS] + responses = await asyncio.gather(*tasks) + responses = list(filter(lambda x: x is not None, responses)) + + lowest_version = min( + responses, key=lambda x: tuple(map(int, x["version"].split("."))) + ) + + return {"version": lowest_version["version"]} + else: + url = app.state.OLLAMA_BASE_URLS[url_idx] + try: + r = requests.request(method="GET", url=f"{url}/api/version") + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = 
"Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class ModelNameForm(BaseModel): + name: str + + +@app.post("/api/pull") +@app.post("/api/pull/{url_idx}") +async def pull_model( + form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) +): + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + def get_request(): + nonlocal url + nonlocal r + try: + + def stream_content(): + for chunk in r.iter_content(chunk_size=8192): + yield chunk + + r = requests.request( + method="POST", + url=f"{url}/api/pull", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class PushModelForm(BaseModel): + name: str + insecure: Optional[bool] = None + stream: Optional[bool] = None + + +@app.delete("/api/push") +@app.delete("/api/push/{url_idx}") +async def push_model( + form_data: PushModelForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx == None: + if form_data.name in app.state.MODELS: + url_idx = app.state.MODELS[form_data.name]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + def get_request(): + nonlocal url + nonlocal r + try: + + def stream_content(): + for chunk in r.iter_content(chunk_size=8192): + yield chunk + + r = requests.request( + method="POST", + url=f"{url}/api/push", + data=form_data.model_dump_json(exclude_none=True), + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class CreateModelForm(BaseModel): + name: str + modelfile: Optional[str] = None + stream: Optional[bool] = None + path: Optional[str] = None + + +@app.post("/api/create") +@app.post("/api/create/{url_idx}") +async def create_model( + form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) +): + print(form_data) + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + def get_request(): + nonlocal url + nonlocal r + try: + + def stream_content(): + for chunk in r.iter_content(chunk_size=8192): + yield chunk + + r = requests.request( + method="POST", + url=f"{url}/api/create", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + 
r.raise_for_status() + + print(r) + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class CopyModelForm(BaseModel): + source: str + destination: str + + +@app.post("/api/copy") +@app.post("/api/copy/{url_idx}") +async def copy_model( + form_data: CopyModelForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx == None: + if form_data.source in app.state.MODELS: + url_idx = app.state.MODELS[form_data.source]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/copy", + data=form_data.model_dump_json(exclude_none=True), + ) + r.raise_for_status() + + print(r.text) + + return True + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +@app.delete("/api/delete") +@app.delete("/api/delete/{url_idx}") +async def delete_model( + form_data: ModelNameForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx == None: + if form_data.name in app.state.MODELS: + url_idx = app.state.MODELS[form_data.name]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + try: + r = requests.request( + method="DELETE", + url=f"{url}/api/delete", + data=form_data.model_dump_json(exclude_none=True), + ) + r.raise_for_status() + + print(r.text) + + return True + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +@app.post("/api/show") +async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)): + if form_data.name not in app.state.MODELS: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + ) + + url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/show", + data=form_data.model_dump_json(exclude_none=True), + ) + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + 
+ +class GenerateEmbeddingsForm(BaseModel): + model: str + prompt: str + options: Optional[dict] = None + keep_alive: Optional[Union[int, str]] = None + + +@app.post("/api/embeddings") +@app.post("/api/embeddings/{url_idx}") +async def generate_embeddings( + form_data: GenerateEmbeddingsForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + data=form_data.model_dump_json(exclude_none=True), + ) + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class GenerateCompletionForm(BaseModel): + model: str + prompt: str + images: Optional[List[str]] = None + format: Optional[str] = None + options: Optional[dict] = None + system: Optional[str] = None + template: Optional[str] = None + context: Optional[str] = None + stream: Optional[bool] = True + raw: Optional[bool] = None + keep_alive: Optional[Union[int, str]] = None + + +@app.post("/api/generate") +@app.post("/api/generate/{url_idx}") +async def generate_completion( + form_data: GenerateCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail="error_detail", + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + def get_request(): + nonlocal form_data + nonlocal r + + request_id = str(uuid.uuid4()) + try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if form_data.stream: + yield json.dumps({"id": request_id, "done": False}) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + + r = requests.request( + method="POST", + url=f"{url}/api/generate", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class ChatMessage(BaseModel): + role: str + content: str + images: Optional[List[str]] = None + + +class GenerateChatCompletionForm(BaseModel): + model: str + messages: List[ChatMessage] + format: Optional[str] = None + options: Optional[dict] = None + template: Optional[str] = 
None + stream: Optional[bool] = True + keep_alive: Optional[Union[int, str]] = None + + +@app.post("/api/chat") +@app.post("/api/chat/{url_idx}") +async def generate_chat_completion( + form_data: GenerateChatCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + print(form_data.model_dump_json(exclude_none=True)) + + def get_request(): + nonlocal form_data + nonlocal r + + request_id = str(uuid.uuid4()) + try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if form_data.stream: + yield json.dumps({"id": request_id, "done": False}) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + + r = requests.request( + method="POST", + url=f"{url}/api/chat", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +# TODO: we should update this part once Ollama supports other types +class OpenAIChatMessage(BaseModel): + role: str + content: str + + model_config = ConfigDict(extra="allow") + + +class OpenAIChatCompletionForm(BaseModel): + model: str + messages: List[OpenAIChatMessage] + + model_config = ConfigDict(extra="allow") + + +@app.post("/v1/chat/completions") +@app.post("/v1/chat/completions/{url_idx}") +async def generate_openai_chat_completion( + form_data: OpenAIChatCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + def get_request(): + nonlocal form_data + nonlocal r + + request_id = str(uuid.uuid4()) + try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if form_data.stream: + yield json.dumps( + {"request_id": request_id, "done": False} + ) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + + r = requests.request( + method="POST", + url=f"{url}/v1/chat/completions", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + 
headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_current_user)): - target_url = f"{app.state.OLLAMA_BASE_URL}/{path}" +async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): + url = app.state.OLLAMA_BASE_URLS[0] + target_url = f"{url}/{path}" body = await request.body() headers = dict(request.headers) diff --git a/backend/apps/ollama/old_main.py b/backend/apps/ollama/old_main.py deleted file mode 100644 index 5e5b88111..000000000 --- a/backend/apps/ollama/old_main.py +++ /dev/null @@ -1,127 +0,0 @@ -from fastapi import FastAPI, Request, Response, HTTPException, Depends -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse - -import requests -import json -from pydantic import BaseModel - -from apps.web.models.users import Users -from constants import ERROR_MESSAGES -from utils.utils import decode_token, get_current_user -from config import OLLAMA_API_BASE_URL, WEBUI_AUTH - -import aiohttp - -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL - -# TARGET_SERVER_URL = OLLAMA_API_BASE_URL - - -@app.get("/url") -async def get_ollama_api_url(user=Depends(get_current_user)): - if user and user.role == "admin": - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) - - -class UrlUpdateForm(BaseModel): - url: str - - -@app.post("/url/update") -async def update_ollama_api_url( - form_data: UrlUpdateForm, user=Depends(get_current_user) -): - if user and user.role == "admin": - app.state.OLLAMA_API_BASE_URL = form_data.url - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) - - -# async def fetch_sse(method, target_url, body, headers): -# async with aiohttp.ClientSession() as session: -# try: -# async with session.request( -# method, target_url, data=body, headers=headers -# ) as response: -# print(response.status) -# async for line in response.content: -# yield line -# except Exception as e: -# print(e) -# error_detail = "Open WebUI: Server Connection Error" -# yield json.dumps({"error": error_detail, "message": str(e)}).encode() - - -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_current_user)): - target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" - print(target_url) - - body = await request.body() - headers = dict(request.headers) - - if user.role in ["user", "admin"]: - if path in ["pull", "delete", "push", "copy", "create"]: - if user.role != "admin": - raise HTTPException( - status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) - - headers.pop("Host", None) - 
headers.pop("Authorization", None) - headers.pop("Origin", None) - headers.pop("Referer", None) - - session = aiohttp.ClientSession() - response = None - try: - response = await session.request( - request.method, target_url, data=body, headers=headers - ) - - print(response) - if not response.ok: - data = await response.json() - print(data) - response.raise_for_status() - - async def generate(): - async for line in response.content: - print(line) - yield line - await session.close() - - return StreamingResponse(generate(), response.status) - - except Exception as e: - print(e) - error_detail = "Open WebUI: Server Connection Error" - - if response is not None: - try: - res = await response.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except: - error_detail = f"Ollama: {e}" - - await session.close() - raise HTTPException( - status_code=response.status if response else 500, - detail=error_detail, - ) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 2a8b2a49e..99aa69594 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -108,7 +108,7 @@ class StoreWebForm(CollectionNameForm): url: str -def store_data_in_vector_db(data, collection_name) -> bool: +def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: text_splitter = RecursiveCharacterTextSplitter( chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP ) @@ -118,6 +118,12 @@ def store_data_in_vector_db(data, collection_name) -> bool: metadatas = [doc.metadata for doc in docs] try: + if overwrite: + for collection in CHROMA_CLIENT.list_collections(): + if collection_name == collection.name: + print(f"deleting existing collection {collection_name}") + CHROMA_CLIENT.delete_collection(name=collection_name) + collection = CHROMA_CLIENT.create_collection( name=collection_name, embedding_function=app.state.sentence_transformer_ef, @@ -355,7 +361,7 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): if collection_name == "": collection_name = calculate_sha256_string(form_data.url)[:63] - store_data_in_vector_db(data, collection_name) + store_data_in_vector_db(data, collection_name, overwrite=True) return { "status": True, "collection_name": collection_name, diff --git a/backend/constants.py b/backend/constants.py index 006fa7bbe..b2bbe9aae 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -48,3 +48,5 @@ class ERROR_MESSAGES(str, Enum): lambda err="": f"Invalid format. 
Please use the correct format{err if err else ''}" ) RATE_LIMIT_EXCEEDED = "API rate limit exceeded" + + MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" diff --git a/backend/main.py b/backend/main.py index 94938b249..5f6b44410 100644 --- a/backend/main.py +++ b/backend/main.py @@ -104,7 +104,7 @@ async def auth_middleware(request: Request, call_next): app.mount("/api/v1", webui_app) app.mount("/litellm/api", litellm_app) -app.mount("/ollama/api", ollama_app) +app.mount("/ollama", ollama_app) app.mount("/openai/api", openai_app) app.mount("/images/api/v1", images_app) @@ -125,6 +125,14 @@ async def get_app_config(): } +@app.get("/api/version") +async def get_app_config(): + + return { + "version": VERSION, + } + + @app.get("/api/changelog") async def get_app_changelog(): return CHANGELOG diff --git a/backend/requirements.txt b/backend/requirements.txt index 0cacacd80..6d3d044dc 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -22,6 +22,7 @@ google-generativeai langchain langchain-community +fake_useragent chromadb sentence_transformers pypdf diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 0c96b2ab9..2047fedef 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -1,9 +1,9 @@ import { OLLAMA_API_BASE_URL } from '$lib/constants'; -export const getOllamaAPIUrl = async (token: string = '') => { +export const getOllamaUrls = async (token: string = '') => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/url`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/urls`, { method: 'GET', headers: { Accept: 'application/json', @@ -29,13 +29,13 @@ export const getOllamaAPIUrl = async (token: string = '') => { throw error; } - return res.OLLAMA_BASE_URL; + return res.OLLAMA_BASE_URLS; }; -export const updateOllamaAPIUrl = async (token: string = '', url: string) => { +export const updateOllamaUrls = async (token: string = '', urls: string[]) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/url/update`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/urls/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -43,7 +43,7 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => { ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - url: url + urls: urls }) }) .then(async (res) => { @@ -64,7 +64,7 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => { throw error; } - return res.OLLAMA_BASE_URL; + return res.OLLAMA_BASE_URLS; }; export const getOllamaVersion = async (token: string = '') => { @@ -151,7 +151,8 @@ export const generateTitle = async ( const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -189,7 +190,8 @@ export const generatePrompt = async (token: string = '', model: string, conversa const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -223,7 +225,8 @@ export const generateTextCompletion = async (token: string = '', model: string, const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { method: 'POST', headers: { - 'Content-Type': 
'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -251,7 +254,8 @@ export const generateChatCompletion = async (token: string = '', body: object) = signal: controller.signal, method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify(body) @@ -294,7 +298,8 @@ export const createModel = async (token: string, tagName: string, content: strin const res = await fetch(`${OLLAMA_API_BASE_URL}/api/create`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -313,19 +318,23 @@ export const createModel = async (token: string, tagName: string, content: strin return res; }; -export const deleteModel = async (token: string, tagName: string) => { +export const deleteModel = async (token: string, tagName: string, urlIdx: string | null = null) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/api/delete`, { - method: 'DELETE', - headers: { - 'Content-Type': 'text/event-stream', - Authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - name: tagName - }) - }) + const res = await fetch( + `${OLLAMA_API_BASE_URL}/api/delete${urlIdx !== null ? `/${urlIdx}` : ''}`, + { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + name: tagName + }) + } + ) .then(async (res) => { if (!res.ok) throw await res.json(); return res.json(); @@ -336,7 +345,12 @@ export const deleteModel = async (token: string, tagName: string) => { }) .catch((err) => { console.log(err); - error = err.error; + error = err; + + if ('detail' in err) { + error = err.detail; + } + return null; }); @@ -347,13 +361,14 @@ export const deleteModel = async (token: string, tagName: string) => { return res; }; -export const pullModel = async (token: string, tagName: string) => { +export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? 
`/${urlIdx}` : ''}`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 7e98a706c..077bafd93 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -21,7 +21,7 @@ export let suggestionPrompts = []; export let autoScroll = true; - + let chatTextAreaElement:HTMLTextAreaElement let filesInputElement; let promptsElement; @@ -45,11 +45,9 @@ let speechRecognition; $: if (prompt) { - const chatInput = document.getElementById('chat-textarea'); - - if (chatInput) { - chatInput.style.height = ''; - chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px'; + if (chatTextAreaElement) { + chatTextAreaElement.style.height = ''; + chatTextAreaElement.style.height = Math.min(chatTextAreaElement.scrollHeight, 200) + 'px'; } } @@ -88,9 +86,7 @@ if (res) { prompt = res.text; await tick(); - - const inputElement = document.getElementById('chat-textarea'); - inputElement?.focus(); + chatTextAreaElement?.focus(); if (prompt !== '' && $settings?.speechAutoSend === true) { submitPrompt(prompt, user); @@ -193,8 +189,7 @@ prompt = `${prompt}${transcript}`; await tick(); - const inputElement = document.getElementById('chat-textarea'); - inputElement?.focus(); + chatTextAreaElement?.focus(); // Restart the inactivity timeout timeoutId = setTimeout(() => { @@ -296,8 +291,7 @@ }; onMount(() => { - const chatInput = document.getElementById('chat-textarea'); - window.setTimeout(() => chatInput?.focus(), 0); + window.setTimeout(() => chatTextAreaElement?.focus(), 0); const dropZone = document.querySelector('body'); @@ -671,6 +665,7 @@