Merge pull request #3321 from open-webui/functions

feat: functions
This commit is contained in:
Timothy Jaeryang Baek 2024-06-20 04:52:32 -07:00 committed by GitHub
commit 09a81eb225
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 1365 additions and 248 deletions

View File

@ -53,7 +53,7 @@ from config import (
UPLOAD_DIR,
AppConfig,
)
from utils.misc import calculate_sha256
from utils.misc import calculate_sha256, add_or_update_system_message
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@ -834,18 +834,9 @@ async def generate_chat_completion(
)
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
payload["messages"] = add_or_update_system_message(
system, payload["messages"]
)
if url_idx == None:
if ":" not in payload["model"]:

View File

@ -432,7 +432,12 @@ async def generate_chat_completion(
idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id}
payload["user"] = {
"name": user.name,
"id": user.id,
"email": user.email,
"role": user.role,
}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model

View File

@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Function(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
type = pw.TextField()
content = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "function"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("function")

View File

@ -13,7 +13,11 @@ from apps.webui.routers import (
memories,
utils,
files,
functions,
)
from apps.webui.models.functions import Functions
from apps.webui.utils import load_function_module_by_id
from config import (
WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS,
@ -60,7 +64,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
app.add_middleware(
CORSMiddleware,
@ -70,19 +74,22 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
@app.get("/")
@ -93,3 +100,58 @@ async def get_status():
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
}
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe")
pipe_models = []
for pipe in pipes:
# Check if function is already loaded
if pipe.id not in app.state.FUNCTIONS:
function_module, function_type = load_function_module_by_id(pipe.id)
app.state.FUNCTIONS[pipe.id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe.id]
# Check if function is a manifold
if hasattr(function_module, "type"):
if function_module.type == "manifold":
manifold_pipes = []
# Check if pipes is a function or a list
if callable(function_module.pipes):
manifold_pipes = function_module.pipes()
else:
manifold_pipes = function_module.pipes
for p in manifold_pipes:
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
manifold_pipe_name = p["name"]
if hasattr(function_module, "name"):
manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
pipe_models.append(
{
"id": manifold_pipe_id,
"name": manifold_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": {"type": pipe.type},
}
)
else:
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": {"type": "pipe"},
}
)
return pipe_models

View File

@ -55,6 +55,7 @@ class FunctionModel(BaseModel):
class FunctionResponse(BaseModel):
id: str
user_id: str
type: str
name: str
meta: FunctionMeta
updated_at: int # timestamp in epoch
@ -64,23 +65,23 @@ class FunctionResponse(BaseModel):
class FunctionForm(BaseModel):
id: str
name: str
type: str
content: str
meta: FunctionMeta
class ToolsTable:
class FunctionsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Function])
def insert_new_function(
self, user_id: str, form_data: FunctionForm
self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]:
function = FunctionModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"type": type,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
@ -137,4 +138,4 @@ class ToolsTable:
return False
Tools = ToolsTable(DB)
Functions = FunctionsTable(DB)

View File

@ -0,0 +1,180 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.functions import (
Functions,
FunctionForm,
FunctionModel,
FunctionResponse,
)
from apps.webui.utils import load_function_module_by_id
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
from importlib import util
import os
from pathlib import Path
from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
router = APIRouter()
############################
# GetFunctions
############################
@router.get("/", response_model=List[FunctionResponse])
async def get_functions(user=Depends(get_verified_user)):
return Functions.get_functions()
############################
# ExportFunctions
############################
@router.get("/export", response_model=List[FunctionModel])
async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions()
############################
# CreateNewFunction
############################
@router.post("/create", response_model=Optional[FunctionResponse])
async def create_new_function(
request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(form_data.id)
if function == None:
function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
try:
with open(function_path, "w") as function_file:
function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(form_data.id)
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(user.id, function_type, form_data)
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetFunctionById
############################
@router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateFunctionById
############################
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
):
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
try:
with open(function_path, "w") as function_file:
function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(id)
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[id] = function_module
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated)
function = Functions.update_function_by_id(id, updated)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteFunctionById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
result = Functions.delete_function_by_id(id)
if result:
FUNCTIONS = request.app.state.FUNCTIONS
if id in FUNCTIONS:
del FUNCTIONS[id]
# delete the function file
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
os.remove(function_path)
return result

View File

@ -1,7 +1,7 @@
from importlib import util
import os
from config import TOOLS_DIR
from config import TOOLS_DIR, FUNCTIONS_DIR
def load_toolkit_module_by_id(toolkit_id):
@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id):
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e
def load_function_module_by_id(function_id):
function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py")
spec = util.spec_from_file_location(function_id, function_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Pipe"):
return module.Pipe(), "pipe"
elif hasattr(module, "Filter"):
return module.Filter(), "filter"
else:
raise Exception("No Function class found")
except Exception as e:
print(f"Error loading module: {function_id}")
# Move the file to the error folder
os.rename(function_path, f"{function_path}.error")
raise e

View File

@ -377,6 +377,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Functions DIR
####################################
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# LITELLM_CONFIG
####################################

View File

@ -15,6 +15,7 @@ import uuid
import inspect
import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
@ -42,15 +43,17 @@ from apps.openai.main import (
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app
from apps.webui.main import app as webui_app, get_pipe_models
from pydantic import BaseModel
from typing import List, Optional
from typing import List, Optional, Iterator, Generator, Union
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id
from apps.webui.models.functions import Functions
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from utils.utils import (
@ -64,7 +67,11 @@ from utils.task import (
search_query_generation_template,
tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from utils.misc import (
get_last_user_message,
add_or_update_system_message,
stream_message_template,
)
from apps.rag.utils import get_rag_context, rag_template
@ -170,6 +177,13 @@ app.state.MODELS = {}
origins = ["*"]
##################################
#
# ChatCompletion Middleware
#
##################################
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user
):
@ -309,41 +323,72 @@ async def get_function_call_response(
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
return_citations = False
data_items = []
if request.method == "POST" and (
"/ollama/api/chat" in request.url.path
or "/chat/completions" in request.url.path
if request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
):
log.debug(f"request.url.path: {request.url.path}")
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
# Remove the citations from the body
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Set the task model
task_model_id = data["model"]
if task_model_id not in app.state.MODELS:
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
# Check if the user has a custom task model
# If the user has a custom task model, use that model
# Check if the model has any filters
if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if getattr(function_module, "file_handler"):
skip_files = True
try:
if hasattr(function_module, "inlet"):
data = function_module.inlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Set the task model
task_model_id = data["model"]
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
@ -361,8 +406,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context = ""
# If tool_ids field is present, call the functions
skip_files = False
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
@ -408,18 +451,22 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context += ("\n" if context != "" else "") + rag_context
log.debug(f"rag_context: {rag_context}, citations: {citations}")
else:
return_citations = False
if citations and data.get("citations"):
data_items.append({"citations": citations})
del data["files"]
if data.get("citations"):
del data["citations"]
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
system_prompt, data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
@ -435,40 +482,51 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
],
]
response = await call_next(request)
if return_citations:
# Inject the citations into the response
response = await call_next(request)
if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type:
return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, citations),
self.openai_stream_wrapper(response.body_iterator, data_items),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, citations),
self.ollama_stream_wrapper(response.body_iterator, data_items),
)
else:
return response
# If it's not a chat completion request, just pass it through
response = await call_next(request)
return response
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, citations):
yield f"data: {json.dumps({'citations': citations})}\n\n"
async def openai_stream_wrapper(self, original_generator, data_items):
for item in data_items:
yield f"data: {json.dumps(item)}\n\n"
async for data in original_generator:
yield data
async def ollama_stream_wrapper(self, original_generator, citations):
yield f"{json.dumps({'citations': citations})}\n"
async def ollama_stream_wrapper(self, original_generator, data_items):
for item in data_items:
yield f"{json.dumps(item)}\n"
async for data in original_generator:
yield data
app.add_middleware(ChatCompletionMiddleware)
##################################
#
# Pipeline Middleware
#
##################################
def filter_pipeline(payload, user):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
@ -628,7 +686,6 @@ async def update_embedding_function(request: Request, call_next):
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
@ -642,17 +699,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models():
pipe_models = []
openai_models = []
ollama_models = []
pipe_models = await get_pipe_models()
if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models()
openai_models = openai_models["data"]
if app.state.config.ENABLE_OLLAMA_API:
ollama_models = await get_ollama_models()
ollama_models = [
{
"id": model["model"],
@ -665,9 +723,9 @@ async def get_all_models():
for model in ollama_models["models"]
]
models = openai_models + ollama_models
custom_models = Models.get_all_models()
models = pipe_models + openai_models + ollama_models
custom_models = Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id == None:
for model in models:
@ -730,6 +788,234 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models}
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
pipe = model.get("pipe")
if pipe:
form_data["user"] = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
if form_data["stream"]:
def stream_content():
res = pipe(body=form_data)
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(
stream_content(), media_type="text/event-stream"
)
else:
res = pipe(body=form_data)
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await run_in_threadpool(job)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
else:
return await generate_openai_chat_completion(form_data, user=user)
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {"id": user.id, "name": user.name, "role": user.role},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
# Check if the model has any filters
if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
try:
if hasattr(function_module, "outlet"):
data = function_module.outlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data
##################################
#
# Task Endpoints
#
##################################
# TODO: Refactor task API endpoints below into a separate file
@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
return {
@ -1015,92 +1301,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
)
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
else:
return await generate_openai_chat_completion(form_data, user=user)
##################################
#
# Pipelines Endpoints
#
##################################
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
model_id = data["model"]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
print(model_id)
if model_id in app.state.MODELS:
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {"id": user.id, "name": user.name, "role": user.role},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
return data
# TODO: Refactor pipelines API endpoints below into a separate file
@app.get("/api/pipelines/list")
@ -1423,6 +1631,13 @@ async def update_pipeline_valves(
)
##################################
#
# Config Endpoints
#
##################################
@app.get("/api/config")
async def get_app_config():
# Checking and Handling the Absence of 'ui' in CONFIG_DATA
@ -1486,6 +1701,9 @@ async def update_model_filter_config(
}
# TODO: webhook endpoint should be under config endpoints
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
return {

View File

@ -4,6 +4,8 @@ import json
import re
from datetime import timedelta
from typing import Optional, List, Tuple
import uuid
import time
def get_last_user_message(messages: List[dict]) -> str:
@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return messages
def stream_message_template(model: str, message: str):
return {
"id": f"{model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": message},
"logprobs": None,
"finish_reason": None,
}
],
}
def get_gravatar_url(email):
# Trim leading and trailing whitespace from
# an email address and force all characters

View File

@ -0,0 +1,193 @@
import { WEBUI_API_BASE_URL } from '$lib/constants';
export const createNewFunction = async (token: string, func: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/create`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...func
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctions = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const exportFunctions = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/export`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFunctionById = async (token: string, id: string, func: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...func
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/delete`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};

View File

@ -278,7 +278,9 @@
})),
chat_id: $chatId
}).catch((error) => {
console.error(error);
toast.error(error);
messages.at(-1).error = { content: error };
return null;
});
@ -323,6 +325,13 @@
} else if (messages.length != 0 && messages.at(-1).done != true) {
// Response not done
console.log('wait');
} else if (messages.length != 0 && messages.at(-1).error) {
// Error in response
toast.error(
$i18n.t(
`Oops! There was an error in the previous response. Please try again or contact admin.`
)
);
} else if (
files.length > 0 &&
files.filter((file) => file.type !== 'image' && file.status !== 'processed').length > 0
@ -630,7 +639,7 @@
keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
citations: files.length > 0,
citations: files.length > 0 ? true : undefined,
chat_id: $chatId
});
@ -928,10 +937,11 @@
max_tokens: $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
citations: files.length > 0,
citations: files.length > 0 ? true : undefined,
chat_id: $chatId
},
`${OPENAI_API_BASE_URL}`
`${WEBUI_BASE_URL}/api`
);
// Wait until history/message have been updated

View File

@ -3,25 +3,27 @@
import fileSaver from 'file-saver';
const { saveAs } = fileSaver;
import { WEBUI_NAME, functions, models } from '$lib/stores';
import { onMount, getContext } from 'svelte';
import { WEBUI_NAME, prompts, tools } from '$lib/stores';
import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
import { goto } from '$app/navigation';
import {
createNewTool,
deleteToolById,
exportTools,
getToolById,
getTools
} from '$lib/apis/tools';
createNewFunction,
deleteFunctionById,
exportFunctions,
getFunctionById,
getFunctions
} from '$lib/apis/functions';
import ArrowDownTray from '../icons/ArrowDownTray.svelte';
import Tooltip from '../common/Tooltip.svelte';
import ConfirmDialog from '../common/ConfirmDialog.svelte';
import { getModels } from '$lib/apis';
const i18n = getContext('i18n');
let toolsImportInputElement: HTMLInputElement;
let functionsImportInputElement: HTMLInputElement;
let importFiles;
let showConfirm = false;
@ -64,7 +66,7 @@
<div>
<a
class=" px-2 py-2 rounded-xl border border-gray-200 dark:border-gray-600 dark:border-0 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 transition font-medium text-sm flex items-center space-x-1"
href="/workspace/tools/create"
href="/workspace/functions/create"
>
<svg
xmlns="http://www.w3.org/2000/svg"
@ -82,30 +84,40 @@
<hr class=" dark:border-gray-850 my-2.5" />
<div class="my-3 mb-5">
{#each $tools.filter((t) => query === '' || t.name
{#each $functions.filter((f) => query === '' || f.name
.toLowerCase()
.includes(query.toLowerCase()) || t.id.toLowerCase().includes(query.toLowerCase())) as tool}
.includes(query.toLowerCase()) || f.id.toLowerCase().includes(query.toLowerCase())) as func}
<button
class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl"
type="button"
on:click={() => {
goto(`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`);
goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`);
}}
>
<div class=" flex flex-1 space-x-4 cursor-pointer w-full">
<a
href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}
href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`}
class="flex items-center text-left"
>
<div class=" flex-1 self-center pl-5">
<div class=" flex-1 self-center pl-1">
<div class=" font-semibold flex items-center gap-1.5">
<div>
{tool.name}
<div
class=" text-xs font-black px-1 rounded uppercase line-clamp-1 bg-gray-500/20 text-gray-700 dark:text-gray-200"
>
{func.type}
</div>
<div>
{func.name}
</div>
<div class=" text-gray-500 text-xs font-medium">{tool.id}</div>
</div>
<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
{tool.meta.description}
<div class="flex gap-1.5 px-1">
<div class=" text-gray-500 text-xs font-medium">{func.id}</div>
<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
{func.meta.description}
</div>
</div>
</div>
</a>
@ -115,7 +127,7 @@
<a
class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button"
href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}
href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`}
>
<svg
xmlns="http://www.w3.org/2000/svg"
@ -141,18 +153,20 @@
on:click={async (e) => {
e.stopPropagation();
const _tool = await getToolById(localStorage.token, tool.id).catch((error) => {
toast.error(error);
return null;
});
const _function = await getFunctionById(localStorage.token, func.id).catch(
(error) => {
toast.error(error);
return null;
}
);
if (_tool) {
sessionStorage.tool = JSON.stringify({
..._tool,
id: `${_tool.id}_clone`,
name: `${_tool.name} (Clone)`
if (_function) {
sessionStorage.function = JSON.stringify({
..._function,
id: `${_function.id}_clone`,
name: `${_function.name} (Clone)`
});
goto('/workspace/tools/create');
goto('/workspace/functions/create');
}
}}
>
@ -180,16 +194,18 @@
on:click={async (e) => {
e.stopPropagation();
const _tool = await getToolById(localStorage.token, tool.id).catch((error) => {
toast.error(error);
return null;
});
const _function = await getFunctionById(localStorage.token, func.id).catch(
(error) => {
toast.error(error);
return null;
}
);
if (_tool) {
let blob = new Blob([JSON.stringify([_tool])], {
if (_function) {
let blob = new Blob([JSON.stringify([_function])], {
type: 'application/json'
});
saveAs(blob, `tool-${_tool.id}-export-${Date.now()}.json`);
saveAs(blob, `function-${_function.id}-export-${Date.now()}.json`);
}
}}
>
@ -204,14 +220,16 @@
on:click={async (e) => {
e.stopPropagation();
const res = await deleteToolById(localStorage.token, tool.id).catch((error) => {
const res = await deleteFunctionById(localStorage.token, func.id).catch((error) => {
toast.error(error);
return null;
});
if (res) {
toast.success('Tool deleted successfully');
tools.set(await getTools(localStorage.token));
toast.success('Function deleted successfully');
functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
}
}}
>
@ -246,7 +264,7 @@
<div class="flex space-x-2">
<input
id="documents-import-input"
bind:this={toolsImportInputElement}
bind:this={functionsImportInputElement}
bind:files={importFiles}
type="file"
accept=".json"
@ -260,7 +278,7 @@
<button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
on:click={() => {
toolsImportInputElement.click();
functionsImportInputElement.click();
}}
>
<div class=" self-center mr-2 font-medium">{$i18n.t('Import Functions')}</div>
@ -284,16 +302,16 @@
<button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
on:click={async () => {
const _tools = await exportTools(localStorage.token).catch((error) => {
const _functions = await exportFunctions(localStorage.token).catch((error) => {
toast.error(error);
return null;
});
if (_tools) {
let blob = new Blob([JSON.stringify(_tools)], {
if (_functions) {
let blob = new Blob([JSON.stringify(_functions)], {
type: 'application/json'
});
saveAs(blob, `tools-export-${Date.now()}.json`);
saveAs(blob, `functions-export-${Date.now()}.json`);
}
}}
>
@ -322,18 +340,19 @@
on:confirm={() => {
const reader = new FileReader();
reader.onload = async (event) => {
const _tools = JSON.parse(event.target.result);
console.log(_tools);
const _functions = JSON.parse(event.target.result);
console.log(_functions);
for (const tool of _tools) {
const res = await createNewTool(localStorage.token, tool).catch((error) => {
for (const func of _functions) {
const res = await createNewFunction(localStorage.token, func).catch((error) => {
toast.error(error);
return null;
});
}
toast.success('Tool imported successfully');
tools.set(await getTools(localStorage.token));
toast.success('Functions imported successfully');
functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
};
reader.readAsText(importFiles[0]);
@ -344,8 +363,8 @@
<div>Please carefully review the following warnings:</div>
<ul class=" mt-1 list-disc pl-4 text-xs">
<li>Tools have a function calling system that allows arbitrary code execution.</li>
<li>Do not install tools from sources you do not fully trust.</li>
<li>Functions allow arbitrary code execution.</li>
<li>Do not install functions from sources you do not fully trust.</li>
</ul>
</div>

View File

@ -0,0 +1,235 @@
<script>
import { getContext, createEventDispatcher, onMount } from 'svelte';
import { goto } from '$app/navigation';
const dispatch = createEventDispatcher();
const i18n = getContext('i18n');
import CodeEditor from '$lib/components/common/CodeEditor.svelte';
import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
let formElement = null;
let loading = false;
let showConfirm = false;
export let edit = false;
export let clone = false;
export let id = '';
export let name = '';
export let meta = {
description: ''
};
export let content = '';
$: if (name && !edit && !clone) {
id = name.replace(/\s+/g, '_').toLowerCase();
}
let codeEditor;
let boilerplate = `from pydantic import BaseModel
from typing import Optional
class Filter:
class Valves(BaseModel):
max_turns: int = 4
pass
def __init__(self):
# Indicates custom file handling logic. This flag helps disengage default routines in favor of custom
# implementations, informing the WebUI to defer file-related operations to designated methods within this class.
# Alternatively, you can remove the files directly from the body in from the inlet hook
self.file_handler = True
# Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings,
# which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'.
self.valves = self.Valves(**{"max_turns": 2})
pass
def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
# Modify the request body or validate it before processing by the chat completion API.
# This function is the pre-processor for the API where various checks on the input can be performed.
# It can also modify the request before sending it to the API.
print(f"inlet:{__name__}")
print(f"inlet:body:{body}")
print(f"inlet:user:{user}")
if user.get("role", "admin") in ["user", "admin"]:
messages = body.get("messages", [])
if len(messages) > self.valves.max_turns:
raise Exception(
f"Conversation turn limit exceeded. Max turns: {self.valves.max_turns}"
)
return body
def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
# Modify or analyze the response body after processing by the API.
# This function is the post-processor for the API, which can be used to modify the response
# or perform additional checks and analytics.
print(f"outlet:{__name__}")
print(f"outlet:body:{body}")
print(f"outlet:user:{user}")
messages = [
{
**message,
"content": f"{message['content']} - @@Modified from Filter Outlet",
}
for message in body.get("messages", [])
]
return {"messages": messages}
`;
const saveHandler = async () => {
loading = true;
dispatch('save', {
id,
name,
meta,
content
});
};
const submitHandler = async () => {
if (codeEditor) {
const res = await codeEditor.formatPythonCodeHandler();
if (res) {
console.log('Code formatted successfully');
saveHandler();
}
}
};
</script>
<div class=" flex flex-col justify-between w-full overflow-y-auto h-full">
<div class="mx-auto w-full md:px-0 h-full">
<form
bind:this={formElement}
class=" flex flex-col max-h-[100dvh] h-full"
on:submit|preventDefault={() => {
if (edit) {
submitHandler();
} else {
showConfirm = true;
}
}}
>
<div class="mb-2.5">
<button
class="flex space-x-1"
on:click={() => {
goto('/workspace/functions');
}}
type="button"
>
<div class=" self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M17 10a.75.75 0 01-.75.75H5.612l4.158 3.96a.75.75 0 11-1.04 1.08l-5.5-5.25a.75.75 0 010-1.08l5.5-5.25a.75.75 0 111.04 1.08L5.612 9.25H16.25A.75.75 0 0117 10z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class=" self-center font-medium text-sm">{$i18n.t('Back')}</div>
</button>
</div>
<div class="flex flex-col flex-1 overflow-auto h-0 rounded-lg">
<div class="w-full mb-2 flex flex-col gap-1.5">
<div class="flex gap-2 w-full">
<input
class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
type="text"
placeholder="Function Name (e.g. My Filter)"
bind:value={name}
required
/>
<input
class="w-full px-3 py-2 text-sm font-medium disabled:text-gray-300 dark:disabled:text-gray-700 bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
type="text"
placeholder="Function ID (e.g. my_filter)"
bind:value={id}
required
disabled={edit}
/>
</div>
<input
class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
type="text"
placeholder="Function Description (e.g. A filter to remove profanity from text)"
bind:value={meta.description}
required
/>
</div>
<div class="mb-2 flex-1 overflow-auto h-0 rounded-lg">
<CodeEditor
bind:value={content}
bind:this={codeEditor}
{boilerplate}
on:save={() => {
if (formElement) {
formElement.requestSubmit();
}
}}
/>
</div>
<div class="pb-3 flex justify-between">
<div class="flex-1 pr-3">
<div class="text-xs text-gray-500 line-clamp-2">
<span class=" font-semibold dark:text-gray-200">Warning:</span> Functions allow
arbitrary code execution <br />
<span class=" font-medium dark:text-gray-400"
>don't install random functions from sources you don't trust.</span
>
</div>
</div>
<button
class="px-3 py-1.5 text-sm font-medium bg-emerald-600 hover:bg-emerald-700 text-gray-50 transition rounded-lg"
type="submit"
>
{$i18n.t('Save')}
</button>
</div>
</div>
</form>
</div>
</div>
<ConfirmDialog
bind:show={showConfirm}
on:confirm={() => {
submitHandler();
}}
>
<div class="text-sm text-gray-500">
<div class=" bg-yellow-500/20 text-yellow-700 dark:text-yellow-200 rounded-lg px-4 py-3">
<div>Please carefully review the following warnings:</div>
<ul class=" mt-1 list-disc pl-4 text-xs">
<li>Functions allow arbitrary code execution.</li>
<li>Do not install functions from sources you do not fully trust.</li>
</ul>
</div>
<div class="my-3">
I acknowledge that I have read and I understand the implications of my action. I am aware of
the risks associated with executing arbitrary code and I have verified the trustworthiness of
the source.
</div>
</div>
</ConfirmDialog>

View File

@ -0,0 +1,60 @@
<script lang="ts">
import { getContext, onMount } from 'svelte';
import Checkbox from '$lib/components/common/Checkbox.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
const i18n = getContext('i18n');
export let filters = [];
export let selectedFilterIds = [];
let _filters = {};
onMount(() => {
_filters = filters.reduce((acc, filter) => {
acc[filter.id] = {
...filter,
selected: selectedFilterIds.includes(filter.id)
};
return acc;
}, {});
});
</script>
<div>
<div class="flex w-full justify-between mb-1">
<div class=" self-center text-sm font-semibold">{$i18n.t('Filters')}</div>
</div>
<div class=" text-xs dark:text-gray-500">
{$i18n.t('To select filters here, add them to the "Functions" workspace first.')}
</div>
<!-- TODO: Filer order matters -->
<div class="flex flex-col">
{#if filters.length > 0}
<div class=" flex items-center mt-2 flex-wrap">
{#each Object.keys(_filters) as filter, filterIdx}
<div class=" flex items-center gap-2 mr-3">
<div class="self-center flex items-center">
<Checkbox
state={_filters[filter].selected ? 'checked' : 'unchecked'}
on:change={(e) => {
_filters[filter].selected = e.detail === 'checked';
selectedFilterIds = Object.keys(_filters).filter((t) => _filters[t].selected);
}}
/>
</div>
<div class=" py-0.5 text-sm w-full capitalize font-medium">
<Tooltip content={_filters[filter].meta.description}>
{_filters[filter].name}
</Tooltip>
</div>
</div>
{/each}
</div>
{/if}
</div>
</div>

View File

@ -27,7 +27,9 @@ export const tags = writable([]);
export const models: Writable<Model[]> = writable([]);
export const prompts: Writable<Prompt[]> = writable([]);
export const documents: Writable<Document[]> = writable([]);
export const tools = writable([]);
export const functions = writable([]);
export const banners: Writable<Banner[]> = writable([]);

View File

@ -1,11 +1,16 @@
<script lang="ts">
import { onMount, getContext } from 'svelte';
import { WEBUI_NAME, showSidebar } from '$lib/stores';
import { WEBUI_NAME, showSidebar, functions } from '$lib/stores';
import MenuLines from '$lib/components/icons/MenuLines.svelte';
import { page } from '$app/stores';
import { getFunctions } from '$lib/apis/functions';
const i18n = getContext('i18n');
onMount(async () => {
functions.set(await getFunctions(localStorage.token));
});
</script>
<svelte:head>

View File

@ -1,18 +1,20 @@
<script>
import { goto } from '$app/navigation';
import { createNewTool, getTools } from '$lib/apis/tools';
import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
import { tools } from '$lib/stores';
import { onMount } from 'svelte';
import { toast } from 'svelte-sonner';
import { onMount } from 'svelte';
import { goto } from '$app/navigation';
import { functions, models } from '$lib/stores';
import { createNewFunction, getFunctions } from '$lib/apis/functions';
import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte';
import { getModels } from '$lib/apis';
let mounted = false;
let clone = false;
let tool = null;
let func = null;
const saveHandler = async (data) => {
console.log(data);
const res = await createNewTool(localStorage.token, {
const res = await createNewFunction(localStorage.token, {
id: data.id,
name: data.name,
meta: data.meta,
@ -23,19 +25,20 @@
});
if (res) {
toast.success('Tool created successfully');
tools.set(await getTools(localStorage.token));
toast.success('Function created successfully');
functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
await goto('/workspace/tools');
await goto('/workspace/functions');
}
};
onMount(() => {
if (sessionStorage.tool) {
tool = JSON.parse(sessionStorage.tool);
sessionStorage.removeItem('tool');
if (sessionStorage.function) {
func = JSON.parse(sessionStorage.function);
sessionStorage.removeItem('function');
console.log(tool);
console.log(func);
clone = true;
}
@ -44,11 +47,11 @@
</script>
{#if mounted}
<ToolkitEditor
id={tool?.id ?? ''}
name={tool?.name ?? ''}
meta={tool?.meta ?? { description: '' }}
content={tool?.content ?? ''}
<FunctionEditor
id={func?.id ?? ''}
name={func?.name ?? ''}
meta={func?.meta ?? { description: '' }}
content={func?.content ?? ''}
{clone}
on:save={(e) => {
saveHandler(e.detail);

View File

@ -1,18 +1,21 @@
<script>
import { toast } from 'svelte-sonner';
import { onMount } from 'svelte';
import { goto } from '$app/navigation';
import { page } from '$app/stores';
import { getToolById, getTools, updateToolById } from '$lib/apis/tools';
import Spinner from '$lib/components/common/Spinner.svelte';
import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
import { tools } from '$lib/stores';
import { onMount } from 'svelte';
import { toast } from 'svelte-sonner';
import { functions, models } from '$lib/stores';
import { updateFunctionById, getFunctions, getFunctionById } from '$lib/apis/functions';
let tool = null;
import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
import { getModels } from '$lib/apis';
let func = null;
const saveHandler = async (data) => {
console.log(data);
const res = await updateToolById(localStorage.token, tool.id, {
const res = await updateFunctionById(localStorage.token, func.id, {
id: data.id,
name: data.name,
meta: data.meta,
@ -23,10 +26,9 @@
});
if (res) {
toast.success('Tool updated successfully');
tools.set(await getTools(localStorage.token));
// await goto('/workspace/tools');
toast.success('Function updated successfully');
functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
}
};
@ -35,24 +37,24 @@
const id = $page.url.searchParams.get('id');
if (id) {
tool = await getToolById(localStorage.token, id).catch((error) => {
func = await getFunctionById(localStorage.token, id).catch((error) => {
toast.error(error);
goto('/workspace/tools');
goto('/workspace/functions');
return null;
});
console.log(tool);
console.log(func);
}
});
</script>
{#if tool}
<ToolkitEditor
{#if func}
<FunctionEditor
edit={true}
id={tool.id}
name={tool.name}
meta={tool.meta}
content={tool.content}
id={func.id}
name={func.name}
meta={func.meta}
content={func.content}
on:save={(e) => {
saveHandler(e.detail);
}}

View File

@ -5,7 +5,7 @@
import { onMount, getContext } from 'svelte';
import { page } from '$app/stores';
import { settings, user, config, models, tools } from '$lib/stores';
import { settings, user, config, models, tools, functions } from '$lib/stores';
import { splitStream } from '$lib/utils';
import { getModelInfos, updateModelById } from '$lib/apis/models';
@ -16,6 +16,7 @@
import Tags from '$lib/components/common/Tags.svelte';
import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte';
import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte';
import FiltersSelector from '$lib/components/workspace/Models/FiltersSelector.svelte';
const i18n = getContext('i18n');
@ -62,6 +63,7 @@
let knowledge = [];
let toolIds = [];
let filterIds = [];
const updateHandler = async () => {
loading = true;
@ -86,6 +88,14 @@
}
}
if (filterIds.length > 0) {
info.meta.filterIds = filterIds;
} else {
if (info.meta.filterIds) {
delete info.meta.filterIds;
}
}
info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null;
Object.keys(info.params).forEach((key) => {
if (info.params[key] === '' || info.params[key] === null) {
@ -147,6 +157,10 @@
toolIds = [...model?.info?.meta?.toolIds];
}
if (model?.info?.meta?.filterIds) {
filterIds = [...model?.info?.meta?.filterIds];
}
if (model?.owned_by === 'openai') {
capabilities.usage = false;
}
@ -534,6 +548,13 @@
<ToolsSelector bind:selectedToolIds={toolIds} tools={$tools} />
</div>
<div class="my-2">
<FiltersSelector
bind:selectedFilterIds={filterIds}
filters={$functions.filter((func) => func.type === 'filter')}
/>
</div>
<div class="my-2">
<div class="flex w-full justify-between mb-1">
<div class=" self-center text-sm font-semibold">{$i18n.t('Capabilities')}</div>