mirror of
https://github.com/open-webui/open-webui.git
synced 2025-10-05 11:25:12 +02:00
This refactors the model import functionality to improve performance and user experience by centralizing the logic on the backend. Previously, the frontend would parse an imported JSON file and send an individual API request for each model, which was slow and inefficient. This change introduces a new backend endpoint, `/api/v1/models/import`, that accepts a list of model objects. The frontend now reads the selected JSON file, parses it, and sends the entire payload to the backend in a single request. The backend then processes this list, creating or updating models as necessary. This commit also includes the following fixes: - Handles cases where the imported JSON contains models without `meta` or `params` fields by providing default empty values.
324 lines
9.2 KiB
Python
324 lines
9.2 KiB
Python
from typing import Optional
|
|
import io
|
|
import base64
|
|
import json
|
|
import asyncio
|
|
import logging
|
|
|
|
from open_webui.models.models import (
|
|
ModelForm,
|
|
ModelModel,
|
|
ModelResponse,
|
|
ModelUserResponse,
|
|
Models,
|
|
)
|
|
|
|
from pydantic import BaseModel
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
HTTPException,
|
|
Request,
|
|
status,
|
|
Response,
|
|
)
|
|
from fastapi.responses import FileResponse, StreamingResponse
|
|
|
|
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
from open_webui.utils.access_control import has_access, has_permission
|
|
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
###########################
|
|
# GetModels
|
|
###########################
|
|
|
|
|
|
@router.get("/", response_model=list[ModelUserResponse])
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
    """List the models available to the requesting user.

    An admin receives the full model list when ``BYPASS_ADMIN_ACCESS_CONTROL``
    is truthy; every other caller receives the per-user lookup result.

    Note: the ``id`` query parameter is accepted but not read by this handler.
    """
    sees_everything = user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL
    if not sees_everything:
        return Models.get_models_by_user_id(user.id)
    return Models.get_models()
|
|
|
|
|
|
###########################
|
|
# GetBaseModels
|
|
###########################
|
|
|
|
|
|
@router.get("/base", response_model=list[ModelResponse])
async def get_base_models(user=Depends(get_admin_user)):
    """Return whatever ``Models.get_base_models`` yields.

    The ``get_admin_user`` dependency must resolve before this handler runs.
    """
    base_models = Models.get_base_models()
    return base_models
|
|
|
|
|
|
############################
|
|
# CreateNewModel
|
|
############################
|
|
|
|
|
|
@router.post("/create", response_model=Optional[ModelModel])
async def create_new_model(
    request: Request,
    form_data: ModelForm,
    user=Depends(get_verified_user),
):
    """Create a new model owned by the requesting user.

    Raises:
        HTTPException: 401 when the caller is neither an admin nor holds the
            ``workspace.models`` permission, when ``form_data.id`` already
            belongs to a stored model, or when the insert returns nothing.
    """
    # Admins skip the permission lookup entirely (short-circuit).
    may_create = user.role == "admin" or has_permission(
        user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
    )
    if not may_create:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    already_stored = Models.get_model_by_id(form_data.id)
    if already_stored:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
        )

    created = Models.insert_new_model(form_data, user.id)
    if not created:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    return created
|
|
|
|
|
|
############################
|
|
# ExportModels
|
|
############################
|
|
|
|
|
|
@router.get("/export", response_model=list[ModelModel])
async def export_models(user=Depends(get_admin_user)):
    """Return every stored model, serialized as ``ModelModel``, for export.

    The ``get_admin_user`` dependency must resolve before this handler runs.
    """
    exported = Models.get_models()
    return exported
|
|
|
|
|
|
############################
|
|
# ImportModels
|
|
############################
|
|
|
|
|
|
class ModelsImportForm(BaseModel):
    # Request body for the bulk import endpoint.
    # Entries are kept as plain dicts, so their individual fields are not
    # validated by this form.
    models: list[dict]
|
|
|
|
|
|
@router.post("/import", response_model=bool)
async def import_models(
    user=Depends(get_admin_user), form_data: ModelsImportForm = (...)
):
    """Create or update models in bulk from an imported payload.

    Every entry of ``form_data.models`` with a truthy ``id`` is processed:
    if a model with that id is already stored, the entry is merged over the
    stored model's dump and saved as an update; otherwise it is inserted as a
    new model owned by the requesting user. Entries without an ``id`` are
    skipped.

    A missing (or null) ``meta`` / ``params`` in an entry is replaced by an
    empty dict before the ``ModelForm`` is built. On the update path this
    means the stored ``meta`` / ``params`` are overwritten by the empty
    default, as before.

    Returns:
        ``True`` once all entries have been processed.

    Raises:
        HTTPException: 400 when the payload is not a list; 500 (carrying the
            original error text as detail) when processing an entry fails.
    """
    # The Ellipsis default lets the body parameter follow ``user`` (which has
    # a default) while still having no usable default value of its own.
    entries = form_data.models

    # Validate outside the ``try`` so this 400 is not rewrapped as a 500.
    if not isinstance(entries, list):
        raise HTTPException(status_code=400, detail="Invalid JSON format")

    try:
        for entry in entries:
            model_id = entry.get("id")
            if not model_id:
                continue

            # Build a fresh dict instead of mutating the request payload.
            payload = {
                **entry,
                "meta": entry.get("meta") or {},
                "params": entry.get("params") or {},
            }

            existing_model = Models.get_model_by_id(model_id)
            if existing_model:
                # Imported fields win over the stored ones.
                updated_model = ModelForm(
                    **{**existing_model.model_dump(), **payload}
                )
                Models.update_model_by_id(model_id, updated_model)
            else:
                new_model = ModelForm(**payload)
                Models.insert_new_model(user_id=user.id, form_data=new_model)
        return True
    except HTTPException:
        # Preserve deliberate HTTP errors raised while processing.
        raise
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
############################
|
|
# SyncModels
|
|
############################
|
|
|
|
|
|
class SyncModelsForm(BaseModel):
    # Request body for the sync endpoint.
    # The list of models defaults to empty when the field is omitted.
    models: list[ModelModel] = []
|
|
|
|
|
|
@router.post("/sync", response_model=list[ModelModel])
async def sync_models(
    request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
):
    """Hand the submitted models to ``Models.sync_models`` and return its result.

    ``request`` is part of the signature but is not read by this handler.
    """
    submitted = form_data.models
    return Models.sync_models(user.id, submitted)
|
|
|
|
|
|
###########################
|
|
# GetModelById
|
|
###########################
|
|
|
|
|
|
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
|
|
@router.get("/model", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
    """Return one model, provided the requesting user may read it.

    Read access is granted when any of the following holds:
    the caller is an admin and ``BYPASS_ADMIN_ACCESS_CONTROL`` is truthy,
    the caller owns the model, or ``has_access`` grants ``"read"``.

    Raises:
        HTTPException: 401 with the not-found message when the model does not
            exist or the caller may not read it.
    """
    model = Models.get_model_by_id(id)

    if model and (
        (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
        or model.user_id == user.id
        or has_access(user.id, "read", model.access_control)
    ):
        return model

    # Missing and unreadable models get the same explicit error, rather than
    # one of the two paths falling through and returning a null body with a
    # 200 status. Using the same message for both avoids revealing whether
    # the id exists.
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.NOT_FOUND,
    )
|
|
|
|
|
|
###########################
|
|
# GetModelProfileImage
|
|
###########################
|
|
|
|
|
|
@router.get("/model/profile/image")
async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
    """Serve the profile image of a model, falling back to the favicon.

    Behavior by ``model.meta.profile_image_url``:
      * starts with ``http``: respond with a 302 redirect to that URL;
      * starts with ``data:image``: decode the base64 payload after the first
        comma and stream the bytes;
      * anything else, an undecodable data URL, a missing URL or a missing
        model: serve ``favicon.png`` from ``STATIC_DIR``.
    """
    model = Models.get_model_by_id(id)
    image_url = model.meta.profile_image_url if model else None

    if image_url:
        if image_url.startswith("http"):
            return Response(
                status_code=status.HTTP_302_FOUND,
                headers={"Location": image_url},
            )

        if image_url.startswith("data:image"):
            try:
                _, base64_data = image_url.split(",", 1)
                image_data = base64.b64decode(base64_data)
            except ValueError as e:
                # Covers a data URL without a comma as well as invalid base64
                # (binascii.Error is a ValueError subclass). Stay best-effort
                # and fall through to the favicon, but leave a trace.
                log.debug("Could not decode profile image for model %s: %s", id, e)
            else:
                # NOTE(review): the media type is always image/png, whatever
                # type the data URL declares.
                return StreamingResponse(
                    io.BytesIO(image_data),
                    media_type="image/png",
                    headers={"Content-Disposition": "inline; filename=image.png"},
                )

    return FileResponse(f"{STATIC_DIR}/favicon.png")
|
|
|
|
|
|
############################
|
|
# ToggleModelById
|
|
############################
|
|
|
|
|
|
@router.post("/model/toggle", response_model=Optional[ModelResponse])
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
    """Toggle a model via ``Models.toggle_model_by_id`` and return the result.

    The caller must be an admin, the model's owner, or hold ``"write"``
    access according to ``has_access``.

    Raises:
        HTTPException: 401 (not found) when no model has this id;
            401 (unauthorized) when the caller may not modify it;
            400 when the toggle itself returns nothing.
    """
    model = Models.get_model_by_id(id)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    can_write = (
        user.role == "admin"
        or model.user_id == user.id
        or has_access(user.id, "write", model.access_control)
    )
    if not can_write:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    toggled = Models.toggle_model_by_id(id)
    if not toggled:
        # The message previously said "function", although this endpoint
        # operates on a model.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error updating model"),
        )
    return toggled
|
|
|
|
|
|
############################
|
|
# UpdateModelById
|
|
############################
|
|
|
|
|
|
@router.post("/model/update", response_model=Optional[ModelModel])
async def update_model_by_id(
    id: str,
    form_data: ModelForm,
    user=Depends(get_verified_user),
):
    """Apply ``form_data`` to the stored model with the given id.

    Raises:
        HTTPException: 401 (not found) when no model has this id; 400 (access
            prohibited) when the caller is not the owner, lacks ``"write"``
            access and is not an admin.
    """
    stored = Models.get_model_by_id(id)
    if not stored:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Evaluated in the same order as before: owner, write access, admin.
    may_edit = (
        stored.user_id == user.id
        or has_access(user.id, "write", stored.access_control)
        or user.role == "admin"
    )
    if not may_edit:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    return Models.update_model_by_id(id, form_data)
|
|
|
|
|
|
############################
|
|
# DeleteModelById
|
|
############################
|
|
|
|
|
|
@router.delete("/model/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
    """Delete the model with the given id and return the deletion result.

    Raises:
        HTTPException: 401 (not found) when no model has this id; 401
            (unauthorized) when the caller is not an admin, not the owner and
            lacks ``"write"`` access.
    """
    target = Models.get_model_by_id(id)
    if not target:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Evaluated in the same order as before: admin, owner, write access.
    may_delete = (
        user.role == "admin"
        or target.user_id == user.id
        or has_access(user.id, "write", target.access_control)
    )
    if not may_delete:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    return Models.delete_model_by_id(id)
|
|
|
|
|
|
@router.delete("/delete/all", response_model=bool)
async def delete_all_models(user=Depends(get_admin_user)):
    """Delete every stored model and return the result of the operation.

    The ``get_admin_user`` dependency must resolve before this handler runs.
    """
    return Models.delete_all_models()
|