diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 05d7c68006..5c5a2dcd90 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -1,6 +1,9 @@ from typing import Optional import io import base64 +import json +import asyncio +import logging from open_webui.models.models import ( ModelForm, @@ -12,7 +15,14 @@ from open_webui.models.models import ( from pydantic import BaseModel from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status, Response +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + status, + Response, +) from fastapi.responses import FileResponse, StreamingResponse @@ -20,6 +30,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR +log = logging.getLogger(__name__) + router = APIRouter() @@ -93,6 +105,50 @@ async def export_models(user=Depends(get_admin_user)): return Models.get_models() +############################ +# ImportModels +############################ + + +class ModelsImportForm(BaseModel): + models: list[dict] + + +@router.post("/import", response_model=bool) +async def import_models( + user: str = Depends(get_admin_user), form_data: ModelsImportForm = (...) +): + try: + data = form_data.models + if isinstance(data, list): + for model_data in data: + # Here, you can add logic to validate model_data if needed + model_id = model_data.get("id") + if model_id: + existing_model = Models.get_model_by_id(model_id) + if existing_model: + # Update existing model + model_data["meta"] = model_data.get("meta", {}) + model_data["params"] = model_data.get("params", {}) + + updated_model = ModelForm( + **{**existing_model.model_dump(), **model_data} + ) + Models.update_model_by_id(model_id, updated_model) + else: + # Insert new model + model_data["meta"] = model_data.get("meta", {}) + model_data["params"] = model_data.get("params", {}) + new_model = ModelForm(**model_data) + Models.insert_new_model(user_id=user.id, form_data=new_model) + return True + else: + raise HTTPException(status_code=400, detail="Invalid JSON format") + except Exception as e: + log.exception(e) + raise HTTPException(status_code=500, detail=str(e)) + + ############################ # SyncModels ############################ diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index 3e6e0d0c0b..d324fa9173 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -31,6 +31,34 @@ export const getModels = async (token: string = '') => { return res; }; +export const importModels = async (token: string, models: object[]) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/import`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ models: models }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getBaseModels = async (token: string = '') => { let error = null; diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index f3df30377f..1f0e33d1dd 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -12,7 +12,8 @@ deleteAllModels, getBaseModels, toggleModelById, - updateModelById + updateModelById, + importModels } from '$lib/apis/models'; import { copyToClipboard } from '$lib/utils'; import { page } from '$app/stores'; @@ -40,6 +41,7 @@ let shiftKey = false; +let modelsImportInProgress = false; let importFiles; let modelsImportInputElement: HTMLInputElement; @@ -464,47 +466,41 @@ accept=".json" hidden on:change={() => { - console.log(importFiles); + if (importFiles.length > 0) { + const reader = new FileReader(); + reader.onload = async (event) => { + try { + const models = JSON.parse(String(event.target.result)); + modelsImportInProgress = true; + const res = await importModels(localStorage.token, models); + modelsImportInProgress = false; - let reader = new FileReader(); - reader.onload = async (event) => { - let savedModels = JSON.parse(event.target.result); - console.log(savedModels); - - for (const model of savedModels) { - if (Object.keys(model).includes('base_model_id')) { - if (model.base_model_id === null) { - upsertModelHandler(model); - } - } else { - if (model?.info ?? false) { - if (model.info.base_model_id === null) { - upsertModelHandler(model.info); - } + if (res) { + toast.success($i18n.t('Models imported successfully')); + await init(); + } else { + toast.error($i18n.t('Failed to import models')); } + } catch (e) { + toast.error($i18n.t('Invalid JSON file')); + console.error(e); } - } - - await _models.set( - await getModels( - localStorage.token, - $config?.features?.enable_direct_connections && - ($settings?.directConnections ?? null) - ) - ); - init(); - }; - - reader.readAsText(importFiles[0]); + }; + reader.readAsText(importFiles[0]); + } }} />