diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index c7fd78819..d1be9c210 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,6 +1,11 @@ import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; +import { getOpenAIModelsDirect } from './openai'; -export const getModels = async (token: string = '', base: boolean = false) => { +export const getModels = async ( + token: string = '', + connections: object | null = null, + base: boolean = false +) => { let error = null; const res = await fetch(`${WEBUI_BASE_URL}/api/models${base ? '/base' : ''}`, { method: 'GET', @@ -25,6 +30,76 @@ export const getModels = async (token: string = '', base: boolean = false) => { } let models = res?.data ?? []; + + if (connections && !base) { + let localModels = []; + + if (connections) { + const OPENAI_API_BASE_URLS = connections.OPENAI_API_BASE_URLS; + const OPENAI_API_KEYS = connections.OPENAI_API_KEYS; + const OPENAI_API_CONFIGS = connections.OPENAI_API_CONFIGS; + + const requests = []; + for (const idx in OPENAI_API_BASE_URLS) { + const url = OPENAI_API_BASE_URLS[idx]; + + if (idx.toString() in OPENAI_API_CONFIGS) { + const apiConfig = OPENAI_API_CONFIGS[idx.toString()] ?? {}; + + const enable = apiConfig?.enable ?? true; + const modelIds = apiConfig?.model_ids ?? []; + + if (enable) { + if (modelIds.length > 0) { + const modelList = { + object: 'list', + data: modelIds.map((modelId) => ({ + id: modelId, + name: modelId, + owned_by: 'openai', + openai: { id: modelId }, + urlIdx: idx + })) + }; + + requests.push(() => modelList); + } else { + requests.push(getOpenAIModelsDirect(url, OPENAI_API_KEYS[idx])); + } + } else { + requests.push(() => {}); + } + } + } + const responses = await Promise.all(requests); + + for (const idx in responses) { + const response = responses[idx]; + const apiConfig = OPENAI_API_CONFIGS[idx.toString()] ?? {}; + + let models = Array.isArray(response) ? response : (response?.data ?? []); + models = models.map((model) => ({ ...model, openai: { id: model.id }, urlIdx: idx })); + + const prefixId = apiConfig.prefix_id; + if (prefixId) { + for (const model of models) { + model.id = `${prefixId}.${model.id}`; + } + } + + localModels = localModels.concat(models); + } + } + + models = models.concat( + localModels.map((model) => ({ + ...model, + name: model?.name ?? model?.id, + direct: true + })) + ); + } + return models; }; diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 53f369e01..bab2d6e36 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -208,6 +208,33 @@ export const updateOpenAIKeys = async (token: string = '', keys: string[]) => { return res.OPENAI_API_KEYS; }; +export const getOpenAIModelsDirect = async (url: string, key: string) => { + let error = null; + + const res = await fetch(`${url}/models`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(key && { authorization: `Bearer ${key}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`; + return []; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getOpenAIModels = async (token: string, urlIdx?: number) => { let error = null; diff --git a/src/lib/components/admin/Functions.svelte b/src/lib/components/admin/Functions.svelte index f7814ce91..67bfc0499 100644 --- a/src/lib/components/admin/Functions.svelte +++ b/src/lib/components/admin/Functions.svelte @@ -3,7 +3,7 @@ import fileSaver from 'file-saver'; const { saveAs } = fileSaver; - import { WEBUI_NAME, config, functions, models } from '$lib/stores'; + import { WEBUI_NAME, config, functions, models, settings } from '$lib/stores'; import { onMount, getContext, tick } from 'svelte'; import { goto } from '$app/navigation'; @@ -126,7 +126,7 @@ toast.success($i18n.t('Function deleted successfully')); functions.set(await getFunctions(localStorage.token)); - models.set(await getModels(localStorage.token)); + models.set(await getModels(localStorage.token, $settings?.directConnections ?? null)); } }; @@ -147,7 +147,7 @@ } functions.set(await getFunctions(localStorage.token)); - models.set(await getModels(localStorage.token)); + models.set(await getModels(localStorage.token, $settings?.directConnections ?? null)); } }; @@ -359,7 +359,9 @@ bind:state={func.is_active} on:change={async (e) => { toggleFunctionById(localStorage.token, func.id); - models.set(await getModels(localStorage.token)); + models.set( + await getModels(localStorage.token, $settings?.directConnections ?? null) + ); }} /> @@ -496,7 +498,7 @@ id={selectedFunction?.id ?? null} on:save={async () => { await tick(); - models.set(await getModels(localStorage.token)); + models.set(await getModels(localStorage.token, $settings?.directConnections ?? null)); }} /> @@ -517,7 +519,7 @@ toast.success($i18n.t('Functions imported successfully')); functions.set(await getFunctions(localStorage.token)); - models.set(await getModels(localStorage.token)); + models.set(await getModels(localStorage.token, $settings?.directConnections ?? null)); }; reader.readAsText(importFiles[0]); diff --git a/src/lib/components/admin/Settings/Audio.svelte b/src/lib/components/admin/Settings/Audio.svelte index 69dcb55fa..ca4401029 100644 --- a/src/lib/components/admin/Settings/Audio.svelte +++ b/src/lib/components/admin/Settings/Audio.svelte @@ -10,7 +10,7 @@ getModels as _getModels, getVoices as _getVoices } from '$lib/apis/audio'; - import { config } from '$lib/stores'; + import { config, settings } from '$lib/stores'; import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; @@ -51,9 +51,11 @@ if (TTS_ENGINE === '') { models = []; } else { - const res = await _getModels(localStorage.token).catch((e) => { - toast.error(`${e}`); - }); + const res = await _getModels(localStorage.token, $settings?.directConnections ?? null).catch( + (e) => { + toast.error(`${e}`); + } + ); if (res) { console.log(res); diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte index 893f45602..a4254ae4c 100644 --- a/src/lib/components/admin/Settings/Connections.svelte +++ b/src/lib/components/admin/Settings/Connections.svelte @@ -9,7 +9,7 @@ import { getModels as _getModels } from '$lib/apis'; import { getDirectConnectionsConfig, setDirectConnectionsConfig } from '$lib/apis/configs'; - import { models, user } from '$lib/stores'; + import { models, settings, user } from '$lib/stores'; import Switch from '$lib/components/common/Switch.svelte'; import Spinner from '$lib/components/common/Spinner.svelte'; @@ -23,7 +23,7 @@ const i18n = getContext('i18n'); const getModels = async () => { - const models = await _getModels(localStorage.token); + const models = await _getModels(localStorage.token, $settings?.directConnections ?? null); return models; }; diff --git a/src/lib/components/admin/Settings/Evaluations.svelte b/src/lib/components/admin/Settings/Evaluations.svelte index c0d1b4f32..41c76d94a 100644 --- a/src/lib/components/admin/Settings/Evaluations.svelte +++ b/src/lib/components/admin/Settings/Evaluations.svelte @@ -1,6 +1,6 @@