diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 22a30474e..455dc89a5 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -53,7 +53,7 @@ from config import ( UPLOAD_DIR, AppConfig, ) -from utils.misc import calculate_sha256 +from utils.misc import calculate_sha256, add_or_update_system_message log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) @@ -834,18 +834,9 @@ async def generate_chat_completion( ) if payload.get("messages"): - for message in payload["messages"]: - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - payload["messages"].insert( - 0, - { - "role": "system", - "content": system, - }, - ) + payload["messages"] = add_or_update_system_message( + system, payload["messages"] + ) if url_idx == None: if ":" not in payload["model"]: diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index c60c52fad..302dd8d98 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -432,7 +432,12 @@ async def generate_chat_completion( idx = model["urlIdx"] if "pipeline" in model and model.get("pipeline"): - payload["user"] = {"name": user.name, "id": user.id} + payload["user"] = { + "name": user.name, + "id": user.id, + "email": user.email, + "role": user.role, + } # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 # This is a workaround until OpenAI fixes the issue with this model diff --git a/backend/apps/webui/internal/migrations/015_add_functions.py b/backend/apps/webui/internal/migrations/015_add_functions.py new file mode 100644 index 000000000..8316a9333 --- /dev/null +++ b/backend/apps/webui/internal/migrations/015_add_functions.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Function(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + + name = pw.TextField() + type = pw.TextField() + + content = pw.TextField() + meta = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "function" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("function") diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index bdc6ec4f4..ce58047ed 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -13,7 +13,11 @@ from apps.webui.routers import ( memories, utils, files, + functions, ) +from apps.webui.models.functions import Functions +from apps.webui.utils import load_function_module_by_id + from config import ( WEBUI_BUILD_HASH, SHOW_ADMIN_DETAILS, @@ -60,7 +64,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.MODELS = {} app.state.TOOLS = {} - +app.state.FUNCTIONS = {} app.add_middleware( CORSMiddleware, @@ -70,19 +74,22 @@ app.add_middleware( allow_headers=["*"], ) + +app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(documents.router, prefix="/documents", tags=["documents"]) -app.include_router(tools.router, prefix="/tools", tags=["tools"]) app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) -app.include_router(memories.router, prefix="/memories", tags=["memories"]) -app.include_router(configs.router, prefix="/configs", tags=["configs"]) -app.include_router(utils.router, prefix="/utils", tags=["utils"]) +app.include_router(memories.router, prefix="/memories", tags=["memories"]) app.include_router(files.router, prefix="/files", tags=["files"]) +app.include_router(tools.router, prefix="/tools", tags=["tools"]) +app.include_router(functions.router, prefix="/functions", tags=["functions"]) + +app.include_router(utils.router, prefix="/utils", tags=["utils"]) @app.get("/") @@ -93,3 +100,58 @@ async def get_status(): "default_models": app.state.config.DEFAULT_MODELS, "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, } + + +async def get_pipe_models(): + pipes = Functions.get_functions_by_type("pipe") + pipe_models = [] + + for pipe in pipes: + # Check if function is already loaded + if pipe.id not in app.state.FUNCTIONS: + function_module, function_type = load_function_module_by_id(pipe.id) + app.state.FUNCTIONS[pipe.id] = function_module + else: + function_module = app.state.FUNCTIONS[pipe.id] + + # Check if function is a manifold + if hasattr(function_module, "type"): + if function_module.type == "manifold": + manifold_pipes = [] + + # Check if pipes is a function or a list + if callable(function_module.pipes): + manifold_pipes = function_module.pipes() + else: + manifold_pipes = function_module.pipes + + for p in manifold_pipes: + manifold_pipe_id = f'{pipe.id}.{p["id"]}' + manifold_pipe_name = p["name"] + + if hasattr(function_module, "name"): + manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}" + + pipe_models.append( + { + "id": manifold_pipe_id, + "name": manifold_pipe_name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": {"type": pipe.type}, + } + ) + else: + pipe_models.append( + { + "id": pipe.id, + "name": pipe.name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": {"type": "pipe"}, + } + ) + + return pipe_models diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index cd877434d..f5fab34db 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -55,6 +55,7 @@ class FunctionModel(BaseModel): class FunctionResponse(BaseModel): id: str user_id: str + type: str name: str meta: FunctionMeta updated_at: int # timestamp in epoch @@ -64,23 +65,23 @@ class FunctionResponse(BaseModel): class FunctionForm(BaseModel): id: str name: str - type: str content: str meta: FunctionMeta -class ToolsTable: +class FunctionsTable: def __init__(self, db): self.db = db self.db.create_tables([Function]) def insert_new_function( - self, user_id: str, form_data: FunctionForm + self, user_id: str, type: str, form_data: FunctionForm ) -> Optional[FunctionModel]: function = FunctionModel( **{ **form_data.model_dump(), "user_id": user_id, + "type": type, "updated_at": int(time.time()), "created_at": int(time.time()), } @@ -137,4 +138,4 @@ class ToolsTable: return False -Tools = ToolsTable(DB) +Functions = FunctionsTable(DB) diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py new file mode 100644 index 000000000..ea5fde336 --- /dev/null +++ b/backend/apps/webui/routers/functions.py @@ -0,0 +1,180 @@ +from fastapi import Depends, FastAPI, HTTPException, status, Request +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json + +from apps.webui.models.functions import ( + Functions, + FunctionForm, + FunctionModel, + FunctionResponse, +) +from apps.webui.utils import load_function_module_by_id +from utils.utils import get_verified_user, get_admin_user +from constants import ERROR_MESSAGES + +from importlib import util +import os +from pathlib import Path + +from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR + + +router = APIRouter() + +############################ +# GetFunctions +############################ + + +@router.get("/", response_model=List[FunctionResponse]) +async def get_functions(user=Depends(get_verified_user)): + return Functions.get_functions() + + +############################ +# ExportFunctions +############################ + + +@router.get("/export", response_model=List[FunctionModel]) +async def get_functions(user=Depends(get_admin_user)): + return Functions.get_functions() + + +############################ +# CreateNewFunction +############################ + + +@router.post("/create", response_model=Optional[FunctionResponse]) +async def create_new_function( + request: Request, form_data: FunctionForm, user=Depends(get_admin_user) +): + if not form_data.id.isidentifier(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only alphanumeric characters and underscores are allowed in the id", + ) + + form_data.id = form_data.id.lower() + + function = Functions.get_function_by_id(form_data.id) + if function == None: + function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py") + try: + with open(function_path, "w") as function_file: + function_file.write(form_data.content) + + function_module, function_type = load_function_module_by_id(form_data.id) + + FUNCTIONS = request.app.state.FUNCTIONS + FUNCTIONS[form_data.id] = function_module + + function = Functions.insert_new_function(user.id, function_type, form_data) + + function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id + function_cache_dir.mkdir(parents=True, exist_ok=True) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error creating function"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ID_TAKEN, + ) + + +############################ +# GetFunctionById +############################ + + +@router.get("/id/{id}", response_model=Optional[FunctionModel]) +async def get_function_by_id(id: str, user=Depends(get_admin_user)): + function = Functions.get_function_by_id(id) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateFunctionById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[FunctionModel]) +async def update_toolkit_by_id( + request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user) +): + function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") + + try: + with open(function_path, "w") as function_file: + function_file.write(form_data.content) + + function_module, function_type = load_function_module_by_id(id) + + FUNCTIONS = request.app.state.FUNCTIONS + FUNCTIONS[id] = function_module + + updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} + print(updated) + + function = Functions.update_function_by_id(id, updated) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# DeleteFunctionById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_function_by_id( + request: Request, id: str, user=Depends(get_admin_user) +): + result = Functions.delete_function_by_id(id) + + if result: + FUNCTIONS = request.app.state.FUNCTIONS + if id in FUNCTIONS: + del FUNCTIONS[id] + + # delete the function file + function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") + os.remove(function_path) + + return result diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py index 19a8615bc..3e075a8a8 100644 --- a/backend/apps/webui/utils.py +++ b/backend/apps/webui/utils.py @@ -1,7 +1,7 @@ from importlib import util import os -from config import TOOLS_DIR +from config import TOOLS_DIR, FUNCTIONS_DIR def load_toolkit_module_by_id(toolkit_id): @@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id): # Move the file to the error folder os.rename(toolkit_path, f"{toolkit_path}.error") raise e + + +def load_function_module_by_id(function_id): + function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py") + + spec = util.spec_from_file_location(function_id, function_path) + module = util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + print(f"Loaded module: {module.__name__}") + if hasattr(module, "Pipe"): + return module.Pipe(), "pipe" + elif hasattr(module, "Filter"): + return module.Filter(), "filter" + else: + raise Exception("No Function class found") + except Exception as e: + print(f"Error loading module: {function_id}") + # Move the file to the error folder + os.rename(function_path, f"{function_path}.error") + raise e diff --git a/backend/config.py b/backend/config.py index 01ce060a3..842cea1ba 100644 --- a/backend/config.py +++ b/backend/config.py @@ -377,6 +377,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) +#################################### +# Functions DIR +#################################### + +FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions") +Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True) + + #################################### # LITELLM_CONFIG #################################### diff --git a/backend/main.py b/backend/main.py index 0a0587159..47078b681 100644 --- a/backend/main.py +++ b/backend/main.py @@ -15,6 +15,7 @@ import uuid import inspect import asyncio +from fastapi.concurrency import run_in_threadpool from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse @@ -42,15 +43,17 @@ from apps.openai.main import ( from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app -from apps.webui.main import app as webui_app +from apps.webui.main import app as webui_app, get_pipe_models from pydantic import BaseModel -from typing import List, Optional +from typing import List, Optional, Iterator, Generator, Union from apps.webui.models.models import Models, ModelModel from apps.webui.models.tools import Tools -from apps.webui.utils import load_toolkit_module_by_id +from apps.webui.models.functions import Functions + +from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id from utils.utils import ( @@ -64,7 +67,11 @@ from utils.task import ( search_query_generation_template, tools_function_calling_generation_template, ) -from utils.misc import get_last_user_message, add_or_update_system_message +from utils.misc import ( + get_last_user_message, + add_or_update_system_message, + stream_message_template, +) from apps.rag.utils import get_rag_context, rag_template @@ -170,6 +177,13 @@ app.state.MODELS = {} origins = ["*"] +################################## +# +# ChatCompletion Middleware +# +################################## + + async def get_function_call_response( messages, files, tool_id, template, task_model_id, user ): @@ -309,41 +323,72 @@ async def get_function_call_response( class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - return_citations = False + data_items = [] - if request.method == "POST" and ( - "/ollama/api/chat" in request.url.path - or "/chat/completions" in request.url.path + if request.method == "POST" and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] ): log.debug(f"request.url.path: {request.url.path}") # Read the original request body body = await request.body() - # Decode body to string body_str = body.decode("utf-8") - # Parse string to JSON data = json.loads(body_str) if body_str else {} user = get_current_user( request, get_http_authorization_cred(request.headers.get("Authorization")), ) + # Flag to skip RAG completions if file_handler is present in tools/functions + skip_files = False - # Remove the citations from the body - return_citations = data.get("citations", False) - if "citations" in data: - del data["citations"] - - # Set the task model - task_model_id = data["model"] - if task_model_id not in app.state.MODELS: + model_id = data["model"] + if model_id not in app.state.MODELS: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", ) + model = app.state.MODELS[model_id] - # Check if the user has a custom task model - # If the user has a custom task model, use that model + # Check if the model has any filters + if "info" in model and "meta" in model["info"]: + for filter_id in model["info"]["meta"].get("filterIds", []): + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id( + filter_id + ) + webui_app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if getattr(function_module, "file_handler"): + skip_files = True + + try: + if hasattr(function_module, "inlet"): + data = function_module.inlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + # Set the task model + task_model_id = data["model"] + # Check if the user has a custom task model and use that model if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if ( app.state.config.TASK_MODEL @@ -361,8 +406,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): context = "" # If tool_ids field is present, call the functions - - skip_files = False if "tool_ids" in data: print(data["tool_ids"]) for tool_id in data["tool_ids"]: @@ -408,18 +451,22 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): context += ("\n" if context != "" else "") + rag_context log.debug(f"rag_context: {rag_context}, citations: {citations}") - else: - return_citations = False + + if citations and data.get("citations"): + data_items.append({"citations": citations}) del data["files"] + if data.get("citations"): + del data["citations"] + if context != "": system_prompt = rag_template( rag_app.state.config.RAG_TEMPLATE, context, prompt ) print(system_prompt) data["messages"] = add_or_update_system_message( - f"\n{system_prompt}", data["messages"] + system_prompt, data["messages"] ) modified_body_bytes = json.dumps(data).encode("utf-8") @@ -435,40 +482,51 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ], ] - response = await call_next(request) - - if return_citations: - # Inject the citations into the response + response = await call_next(request) if isinstance(response, StreamingResponse): # If it's a streaming response, inject it as SSE event or NDJSON line content_type = response.headers.get("Content-Type") if "text/event-stream" in content_type: return StreamingResponse( - self.openai_stream_wrapper(response.body_iterator, citations), + self.openai_stream_wrapper(response.body_iterator, data_items), ) if "application/x-ndjson" in content_type: return StreamingResponse( - self.ollama_stream_wrapper(response.body_iterator, citations), + self.ollama_stream_wrapper(response.body_iterator, data_items), ) + else: + return response + # If it's not a chat completion request, just pass it through + response = await call_next(request) return response async def _receive(self, body: bytes): return {"type": "http.request", "body": body, "more_body": False} - async def openai_stream_wrapper(self, original_generator, citations): - yield f"data: {json.dumps({'citations': citations})}\n\n" + async def openai_stream_wrapper(self, original_generator, data_items): + for item in data_items: + yield f"data: {json.dumps(item)}\n\n" + async for data in original_generator: yield data - async def ollama_stream_wrapper(self, original_generator, citations): - yield f"{json.dumps({'citations': citations})}\n" + async def ollama_stream_wrapper(self, original_generator, data_items): + for item in data_items: + yield f"{json.dumps(item)}\n" + async for data in original_generator: yield data app.add_middleware(ChatCompletionMiddleware) +################################## +# +# Pipeline Middleware +# +################################## + def filter_pipeline(payload, user): user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} @@ -628,7 +686,6 @@ async def update_embedding_function(request: Request, call_next): app.mount("/ws", socket_app) - app.mount("/ollama", ollama_app) app.mount("/openai", openai_app) @@ -642,17 +699,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION async def get_all_models(): + pipe_models = [] openai_models = [] ollama_models = [] + pipe_models = await get_pipe_models() + if app.state.config.ENABLE_OPENAI_API: openai_models = await get_openai_models() - openai_models = openai_models["data"] if app.state.config.ENABLE_OLLAMA_API: ollama_models = await get_ollama_models() - ollama_models = [ { "id": model["model"], @@ -665,9 +723,9 @@ async def get_all_models(): for model in ollama_models["models"] ] - models = openai_models + ollama_models - custom_models = Models.get_all_models() + models = pipe_models + openai_models + ollama_models + custom_models = Models.get_all_models() for custom_model in custom_models: if custom_model.base_model_id == None: for model in models: @@ -730,6 +788,234 @@ async def get_models(user=Depends(get_verified_user)): return {"data": models} +@app.post("/api/chat/completions") +async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = app.state.MODELS[model_id] + print(model) + + pipe = model.get("pipe") + if pipe: + form_data["user"] = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + def job(): + pipe_id = form_data["model"] + if "." in pipe_id: + pipe_id, sub_pipe_id = pipe_id.split(".", 1) + print(pipe_id) + + pipe = webui_app.state.FUNCTIONS[pipe_id].pipe + if form_data["stream"]: + + def stream_content(): + res = pipe(body=form_data) + + if isinstance(res, str): + message = stream_message_template(form_data["model"], res) + yield f"data: {json.dumps(message)}\n\n" + + if isinstance(res, Iterator): + for line in res: + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + try: + line = line.decode("utf-8") + except: + pass + + if line.startswith("data:"): + yield f"{line}\n\n" + else: + line = stream_message_template(form_data["model"], line) + yield f"data: {json.dumps(line)}\n\n" + + if isinstance(res, str) or isinstance(res, Generator): + finish_message = { + "id": f"{form_data['model']}-{str(uuid.uuid4())}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": form_data["model"], + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": "stop", + } + ], + } + + yield f"data: {json.dumps(finish_message)}\n\n" + yield f"data: [DONE]" + + return StreamingResponse( + stream_content(), media_type="text/event-stream" + ) + else: + res = pipe(body=form_data) + + if isinstance(res, dict): + return res + elif isinstance(res, BaseModel): + return res.model_dump() + else: + message = "" + if isinstance(res, str): + message = res + if isinstance(res, Generator): + for stream in res: + message = f"{message}{stream}" + + return { + "id": f"{form_data['model']}-{str(uuid.uuid4())}", + "object": "chat.completion", + "created": int(time.time()), + "model": form_data["model"], + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": message, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + } + + return await run_in_threadpool(job) + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion(form_data, user=user) + else: + return await generate_openai_chat_completion(form_data, user=user) + + +@app.post("/api/chat/completed") +async def chat_completed(form_data: dict, user=Depends(get_verified_user)): + data = form_data + model_id = data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + model = app.state.MODELS[model_id] + + filters = [ + model + for model in app.state.MODELS.values() + if "pipeline" in model + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) + ] + + sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + if "pipeline" in model: + sorted_filters = [model] + sorted_filters + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/outlet", + headers=headers, + json={ + "user": {"id": user.id, "name": user.name, "role": user.role}, + "body": data, + }, + ) + + r.raise_for_status() + data = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return JSONResponse( + status_code=r.status_code, + content=res, + ) + except: + pass + + else: + pass + + # Check if the model has any filters + if "info" in model and "meta" in model["info"]: + for filter_id in model["info"]["meta"].get("filterIds", []): + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id( + filter_id + ) + webui_app.state.FUNCTIONS[filter_id] = function_module + + try: + if hasattr(function_module, "outlet"): + data = function_module.outlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + return data + + +################################## +# +# Task Endpoints +# +################################## + + +# TODO: Refactor task API endpoints below into a separate file + + @app.get("/api/task/config") async def get_task_config(user=Depends(get_verified_user)): return { @@ -1015,92 +1301,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ ) -@app.post("/api/chat/completions") -async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - model = app.state.MODELS[model_id] - print(model) - - if model["owned_by"] == "ollama": - return await generate_ollama_chat_completion(form_data, user=user) - else: - return await generate_openai_chat_completion(form_data, user=user) +################################## +# +# Pipelines Endpoints +# +################################## -@app.post("/api/chat/completed") -async def chat_completed(form_data: dict, user=Depends(get_verified_user)): - data = form_data - model_id = data["model"] - - filters = [ - model - for model in app.state.MODELS.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - - print(model_id) - - if model_id in app.state.MODELS: - model = app.state.MODELS[model_id] - if "pipeline" in model: - sorted_filters = [model] + sorted_filters - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/outlet", - headers=headers, - json={ - "user": {"id": user.id, "name": user.name, "role": user.role}, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except: - pass - - else: - pass - - return data +# TODO: Refactor pipelines API endpoints below into a separate file @app.get("/api/pipelines/list") @@ -1423,6 +1631,13 @@ async def update_pipeline_valves( ) +################################## +# +# Config Endpoints +# +################################## + + @app.get("/api/config") async def get_app_config(): # Checking and Handling the Absence of 'ui' in CONFIG_DATA @@ -1486,6 +1701,9 @@ async def update_model_filter_config( } +# TODO: webhook endpoint should be under config endpoints + + @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return { diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 41fbdcc75..b4e499df8 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -4,6 +4,8 @@ import json import re from datetime import timedelta from typing import Optional, List, Tuple +import uuid +import time def get_last_user_message(messages: List[dict]) -> str: @@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]): return messages +def stream_message_template(model: str, message: str): + return { + "id": f"{model}-{str(uuid.uuid4())}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": message}, + "logprobs": None, + "finish_reason": None, + } + ], + } + + def get_gravatar_url(email): # Trim leading and trailing whitespace from # an email address and force all characters diff --git a/src/lib/apis/functions/index.ts b/src/lib/apis/functions/index.ts new file mode 100644 index 000000000..e035ef1c1 --- /dev/null +++ b/src/lib/apis/functions/index.ts @@ -0,0 +1,193 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const createNewFunction = async (token: string, func: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...func + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getFunctions = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const exportFunctions = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/export`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getFunctionById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateFunctionById = async (token: string, id: string, func: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...func + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteFunctionById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index b33b26fa3..d83eb3cb2 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -278,7 +278,9 @@ })), chat_id: $chatId }).catch((error) => { - console.error(error); + toast.error(error); + messages.at(-1).error = { content: error }; + return null; }); @@ -323,6 +325,13 @@ } else if (messages.length != 0 && messages.at(-1).done != true) { // Response not done console.log('wait'); + } else if (messages.length != 0 && messages.at(-1).error) { + // Error in response + toast.error( + $i18n.t( + `Oops! There was an error in the previous response. Please try again or contact admin.` + ) + ); } else if ( files.length > 0 && files.filter((file) => file.type !== 'image' && file.status !== 'processed').length > 0 @@ -630,7 +639,7 @@ keep_alive: $settings.keepAlive ?? undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, - citations: files.length > 0, + citations: files.length > 0 ? true : undefined, chat_id: $chatId }); @@ -928,10 +937,11 @@ max_tokens: $settings?.params?.max_tokens ?? undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, - citations: files.length > 0, + citations: files.length > 0 ? true : undefined, + chat_id: $chatId }, - `${OPENAI_API_BASE_URL}` + `${WEBUI_BASE_URL}/api` ); // Wait until history/message have been updated diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index f00e9ad2f..35e308220 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -3,25 +3,27 @@ import fileSaver from 'file-saver'; const { saveAs } = fileSaver; + import { WEBUI_NAME, functions, models } from '$lib/stores'; import { onMount, getContext } from 'svelte'; - import { WEBUI_NAME, prompts, tools } from '$lib/stores'; import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { goto } from '$app/navigation'; import { - createNewTool, - deleteToolById, - exportTools, - getToolById, - getTools - } from '$lib/apis/tools'; + createNewFunction, + deleteFunctionById, + exportFunctions, + getFunctionById, + getFunctions + } from '$lib/apis/functions'; + import ArrowDownTray from '../icons/ArrowDownTray.svelte'; import Tooltip from '../common/Tooltip.svelte'; import ConfirmDialog from '../common/ConfirmDialog.svelte'; + import { getModels } from '$lib/apis'; const i18n = getContext('i18n'); - let toolsImportInputElement: HTMLInputElement; + let functionsImportInputElement: HTMLInputElement; let importFiles; let showConfirm = false; @@ -64,7 +66,7 @@
- {#each $tools.filter((t) => query === '' || t.name + {#each $functions.filter((f) => query === '' || f.name .toLowerCase() - .includes(query.toLowerCase()) || t.id.toLowerCase().includes(query.toLowerCase())) as tool} + .includes(query.toLowerCase()) || f.id.toLowerCase().includes(query.toLowerCase())) as func}
diff --git a/src/lib/components/workspace/Functions/FunctionEditor.svelte b/src/lib/components/workspace/Functions/FunctionEditor.svelte index e69de29bb..6e30013cc 100644 --- a/src/lib/components/workspace/Functions/FunctionEditor.svelte +++ b/src/lib/components/workspace/Functions/FunctionEditor.svelte @@ -0,0 +1,235 @@ + + +
+
+
{ + if (edit) { + submitHandler(); + } else { + showConfirm = true; + } + }} + > +
+ +
+ + +
+
+
+ + { + submitHandler(); + }} +> +
+
+
Please carefully review the following warnings:
+ +
    +
  • Functions allow arbitrary code execution.
  • +
  • Do not install functions from sources you do not fully trust.
  • +
+
+ +
+ I acknowledge that I have read and I understand the implications of my action. I am aware of + the risks associated with executing arbitrary code and I have verified the trustworthiness of + the source. +
+
+
diff --git a/src/lib/components/workspace/Models/FiltersSelector.svelte b/src/lib/components/workspace/Models/FiltersSelector.svelte new file mode 100644 index 000000000..92f64c2cf --- /dev/null +++ b/src/lib/components/workspace/Models/FiltersSelector.svelte @@ -0,0 +1,60 @@ + + +
+
+
{$i18n.t('Filters')}
+
+ +
+ {$i18n.t('To select filters here, add them to the "Functions" workspace first.')} +
+ + +
+ {#if filters.length > 0} +
+ {#each Object.keys(_filters) as filter, filterIdx} +
+
+ { + _filters[filter].selected = e.detail === 'checked'; + selectedFilterIds = Object.keys(_filters).filter((t) => _filters[t].selected); + }} + /> +
+ +
+ + {_filters[filter].name} + +
+
+ {/each} +
+ {/if} +
+
diff --git a/src/lib/stores/index.ts b/src/lib/stores/index.ts index b0f0061f7..894565ef3 100644 --- a/src/lib/stores/index.ts +++ b/src/lib/stores/index.ts @@ -27,7 +27,9 @@ export const tags = writable([]); export const models: Writable = writable([]); export const prompts: Writable = writable([]); export const documents: Writable = writable([]); + export const tools = writable([]); +export const functions = writable([]); export const banners: Writable = writable([]); diff --git a/src/routes/(app)/workspace/+layout.svelte b/src/routes/(app)/workspace/+layout.svelte index d26f5812c..46e0f63c4 100644 --- a/src/routes/(app)/workspace/+layout.svelte +++ b/src/routes/(app)/workspace/+layout.svelte @@ -1,11 +1,16 @@ diff --git a/src/routes/(app)/workspace/functions/create/+page.svelte b/src/routes/(app)/workspace/functions/create/+page.svelte index c785c74cd..0f73cf94e 100644 --- a/src/routes/(app)/workspace/functions/create/+page.svelte +++ b/src/routes/(app)/workspace/functions/create/+page.svelte @@ -1,18 +1,20 @@ {#if mounted} - { saveHandler(e.detail); diff --git a/src/routes/(app)/workspace/functions/edit/+page.svelte b/src/routes/(app)/workspace/functions/edit/+page.svelte index b8ca32507..21fc5acb6 100644 --- a/src/routes/(app)/workspace/functions/edit/+page.svelte +++ b/src/routes/(app)/workspace/functions/edit/+page.svelte @@ -1,18 +1,21 @@ -{#if tool} - { saveHandler(e.detail); }} diff --git a/src/routes/(app)/workspace/models/edit/+page.svelte b/src/routes/(app)/workspace/models/edit/+page.svelte index ddbe9f682..ef2ed0558 100644 --- a/src/routes/(app)/workspace/models/edit/+page.svelte +++ b/src/routes/(app)/workspace/models/edit/+page.svelte @@ -5,7 +5,7 @@ import { onMount, getContext } from 'svelte'; import { page } from '$app/stores'; - import { settings, user, config, models, tools } from '$lib/stores'; + import { settings, user, config, models, tools, functions } from '$lib/stores'; import { splitStream } from '$lib/utils'; import { getModelInfos, updateModelById } from '$lib/apis/models'; @@ -16,6 +16,7 @@ import Tags from '$lib/components/common/Tags.svelte'; import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte'; import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte'; + import FiltersSelector from '$lib/components/workspace/Models/FiltersSelector.svelte'; const i18n = getContext('i18n'); @@ -62,6 +63,7 @@ let knowledge = []; let toolIds = []; + let filterIds = []; const updateHandler = async () => { loading = true; @@ -86,6 +88,14 @@ } } + if (filterIds.length > 0) { + info.meta.filterIds = filterIds; + } else { + if (info.meta.filterIds) { + delete info.meta.filterIds; + } + } + info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null; Object.keys(info.params).forEach((key) => { if (info.params[key] === '' || info.params[key] === null) { @@ -147,6 +157,10 @@ toolIds = [...model?.info?.meta?.toolIds]; } + if (model?.info?.meta?.filterIds) { + filterIds = [...model?.info?.meta?.filterIds]; + } + if (model?.owned_by === 'openai') { capabilities.usage = false; } @@ -534,6 +548,13 @@
+
+ func.type === 'filter')} + /> +
+
{$i18n.t('Capabilities')}