mirror of
https://github.com/open-webui/open-webui.git
synced 2025-10-04 19:02:41 +02:00
Merge pull request #17871 from silentoplayz/backend-json-model-import
feat: move JSON model import to backend for massive speedup
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
from typing import Optional
|
||||
import io
|
||||
import base64
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from open_webui.models.models import (
|
||||
ModelForm,
|
||||
@@ -12,7 +15,14 @@ from open_webui.models.models import (
|
||||
|
||||
from pydantic import BaseModel
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status, Response
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
status,
|
||||
Response,
|
||||
)
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
|
||||
@@ -20,6 +30,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -93,6 +105,50 @@ async def export_models(user=Depends(get_admin_user)):
|
||||
return Models.get_models()
|
||||
|
||||
|
||||
############################
|
||||
# ImportModels
|
||||
############################
|
||||
|
||||
|
||||
class ModelsImportForm(BaseModel):
|
||||
models: list[dict]
|
||||
|
||||
|
||||
@router.post("/import", response_model=bool)
|
||||
async def import_models(
|
||||
user: str = Depends(get_admin_user), form_data: ModelsImportForm = (...)
|
||||
):
|
||||
try:
|
||||
data = form_data.models
|
||||
if isinstance(data, list):
|
||||
for model_data in data:
|
||||
# Here, you can add logic to validate model_data if needed
|
||||
model_id = model_data.get("id")
|
||||
if model_id:
|
||||
existing_model = Models.get_model_by_id(model_id)
|
||||
if existing_model:
|
||||
# Update existing model
|
||||
model_data["meta"] = model_data.get("meta", {})
|
||||
model_data["params"] = model_data.get("params", {})
|
||||
|
||||
updated_model = ModelForm(
|
||||
**{**existing_model.model_dump(), **model_data}
|
||||
)
|
||||
Models.update_model_by_id(model_id, updated_model)
|
||||
else:
|
||||
# Insert new model
|
||||
model_data["meta"] = model_data.get("meta", {})
|
||||
model_data["params"] = model_data.get("params", {})
|
||||
new_model = ModelForm(**model_data)
|
||||
Models.insert_new_model(user_id=user.id, form_data=new_model)
|
||||
return True
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON format")
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
############################
|
||||
# SyncModels
|
||||
############################
|
||||
|
@@ -31,6 +31,34 @@ export const getModels = async (token: string = '') => {
|
||||
return res;
|
||||
};
|
||||
|
||||
export const importModels = async (token: string, models: object[]) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${WEBUI_API_BASE_URL}/models/import`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({ models: models })
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
error = err;
|
||||
console.error(err);
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const getBaseModels = async (token: string = '') => {
|
||||
let error = null;
|
||||
|
||||
|
@@ -12,7 +12,8 @@
|
||||
deleteAllModels,
|
||||
getBaseModels,
|
||||
toggleModelById,
|
||||
updateModelById
|
||||
updateModelById,
|
||||
importModels
|
||||
} from '$lib/apis/models';
|
||||
import { copyToClipboard } from '$lib/utils';
|
||||
import { page } from '$app/stores';
|
||||
@@ -40,6 +41,7 @@
|
||||
|
||||
let shiftKey = false;
|
||||
|
||||
let modelsImportInProgress = false;
|
||||
let importFiles;
|
||||
let modelsImportInputElement: HTMLInputElement;
|
||||
|
||||
@@ -464,47 +466,41 @@
|
||||
accept=".json"
|
||||
hidden
|
||||
on:change={() => {
|
||||
console.log(importFiles);
|
||||
if (importFiles.length > 0) {
|
||||
const reader = new FileReader();
|
||||
reader.onload = async (event) => {
|
||||
try {
|
||||
const models = JSON.parse(String(event.target.result));
|
||||
modelsImportInProgress = true;
|
||||
const res = await importModels(localStorage.token, models);
|
||||
modelsImportInProgress = false;
|
||||
|
||||
let reader = new FileReader();
|
||||
reader.onload = async (event) => {
|
||||
let savedModels = JSON.parse(event.target.result);
|
||||
console.log(savedModels);
|
||||
|
||||
for (const model of savedModels) {
|
||||
if (Object.keys(model).includes('base_model_id')) {
|
||||
if (model.base_model_id === null) {
|
||||
upsertModelHandler(model);
|
||||
}
|
||||
} else {
|
||||
if (model?.info ?? false) {
|
||||
if (model.info.base_model_id === null) {
|
||||
upsertModelHandler(model.info);
|
||||
}
|
||||
if (res) {
|
||||
toast.success($i18n.t('Models imported successfully'));
|
||||
await init();
|
||||
} else {
|
||||
toast.error($i18n.t('Failed to import models'));
|
||||
}
|
||||
} catch (e) {
|
||||
toast.error($i18n.t('Invalid JSON file'));
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
|
||||
await _models.set(
|
||||
await getModels(
|
||||
localStorage.token,
|
||||
$config?.features?.enable_direct_connections &&
|
||||
($settings?.directConnections ?? null)
|
||||
)
|
||||
);
|
||||
init();
|
||||
};
|
||||
|
||||
reader.readAsText(importFiles[0]);
|
||||
};
|
||||
reader.readAsText(importFiles[0]);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<button
|
||||
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
|
||||
disabled={modelsImportInProgress}
|
||||
on:click={() => {
|
||||
modelsImportInputElement.click();
|
||||
}}
|
||||
>
|
||||
{#if modelsImportInProgress}
|
||||
<Spinner className="size-3" />
|
||||
{/if}
|
||||
<div class=" self-center mr-2 font-medium line-clamp-1">
|
||||
{$i18n.t('Import Presets')}
|
||||
</div>
|
||||
|
Reference in New Issue
Block a user