From 5cd28c04b8330001f45504b7c0a3c94adfbc090b Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 11 Jun 2024 15:29:46 -0700 Subject: [PATCH] feat: model tools assignment --- src/lib/components/chat/Chat.svelte | 21 ++++--- src/lib/components/chat/MessageInput.svelte | 16 +++++- .../workspace/Models/ToolsSelector.svelte | 57 +++++++++++++++++++ .../workspace/models/create/+page.svelte | 18 +++++- .../(app)/workspace/models/edit/+page.svelte | 20 ++++++- 5 files changed, 116 insertions(+), 16 deletions(-) create mode 100644 src/lib/components/workspace/Models/ToolsSelector.svelte diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index af38d1665..359056cfd 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -74,6 +74,9 @@ let selectedModels = ['']; let atSelectedModel: Model | undefined; + let selectedModelIds = []; + $: selectedModelIds = atSelectedModel !== undefined ? [atSelectedModel.id] : selectedModels; + let selectedToolIds = []; let webSearchEnabled = false; @@ -1281,17 +1284,13 @@ bind:selectedToolIds bind:webSearchEnabled bind:atSelectedModel - availableTools={$user.role === 'admin' - ? $tools.reduce((a, e, i, arr) => { - a[e.id] = { - name: e.name, - description: e.meta.description, - enabled: false - }; - - return a; - }, {}) - : {}} + availableToolIds={selectedModelIds.reduce((a, e, i, arr) => { + const model = $models.find((m) => m.id === e); + if (model?.info?.meta?.toolIds ?? false) { + return [...new Set([...a, ...model.info.meta.toolIds])]; + } + return a; + }, [])} {selectedModels} {messages} {submitPrompt} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 871025e63..50414634b 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -9,7 +9,8 @@ models, config, showCallOverlay, - tools + tools, + user as _user } from '$lib/stores'; import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils'; @@ -59,7 +60,7 @@ export let files = []; - export let availableTools = {}; + export let availableToolIds = []; export let selectedToolIds = []; export let webSearchEnabled = false; @@ -657,7 +658,16 @@ { + if (availableToolIds.includes(e.id) || ($_user?.role ?? 'user') === 'admin') { + a[e.id] = { + name: e.name, + description: e.meta.description, + enabled: false + }; + } + return a; + }, {})} uploadFilesHandler={() => { filesInputElement.click(); }} diff --git a/src/lib/components/workspace/Models/ToolsSelector.svelte b/src/lib/components/workspace/Models/ToolsSelector.svelte new file mode 100644 index 000000000..584e737c5 --- /dev/null +++ b/src/lib/components/workspace/Models/ToolsSelector.svelte @@ -0,0 +1,57 @@ + + +
+
+
{$i18n.t('Tools')}
+
+ +
+ {$i18n.t('To select toolkits here, add them to the "Tools" workspace first.')} +
+ +
+ {#if tools.length > 0} +
+ {#each Object.keys(_tools) as tool, toolIdx} +
+
+ { + _tools[tool].selected = e.detail === 'checked'; + selectedToolIds = Object.keys(_tools).filter((t) => _tools[t].selected); + }} + /> +
+ +
+ {_tools[tool].name} +
+
+ {/each} +
+ {/if} +
+
diff --git a/src/routes/(app)/workspace/models/create/+page.svelte b/src/routes/(app)/workspace/models/create/+page.svelte index 8f1215dde..130818c6a 100644 --- a/src/routes/(app)/workspace/models/create/+page.svelte +++ b/src/routes/(app)/workspace/models/create/+page.svelte @@ -2,7 +2,7 @@ import { v4 as uuidv4 } from 'uuid'; import { toast } from 'svelte-sonner'; import { goto } from '$app/navigation'; - import { settings, user, config, models } from '$lib/stores'; + import { settings, user, config, models, tools } from '$lib/stores'; import { onMount, tick, getContext } from 'svelte'; import { addNewModel, getModelById, getModelInfos } from '$lib/apis/models'; @@ -12,6 +12,8 @@ import Checkbox from '$lib/components/common/Checkbox.svelte'; import Tags from '$lib/components/common/Tags.svelte'; import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte'; + import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte'; + import { stringify } from 'postcss'; const i18n = getContext('i18n'); @@ -54,6 +56,7 @@ vision: true }; + let toolIds = []; let knowledge = []; $: if (name) { @@ -88,6 +91,14 @@ } } + if (toolIds.length > 0) { + info.meta.toolIds = toolIds; + } else { + if (info.meta.toolIds) { + delete info.meta.toolIds; + } + } + info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null; Object.keys(info.params).forEach((key) => { if (info.params[key] === '' || info.params[key] === null) { @@ -154,6 +165,7 @@ params.stop = params?.stop ? (params?.stop ?? []).join(',') : null; capabilities = { ...capabilities, ...(model?.info?.meta?.capabilities ?? {}) }; + toolIds = model?.info?.meta?.toolIds ?? []; info = { ...info, @@ -554,6 +566,10 @@ +
+ +
+
{$i18n.t('Capabilities')}
diff --git a/src/routes/(app)/workspace/models/edit/+page.svelte b/src/routes/(app)/workspace/models/edit/+page.svelte index 970be7a22..95505d3ee 100644 --- a/src/routes/(app)/workspace/models/edit/+page.svelte +++ b/src/routes/(app)/workspace/models/edit/+page.svelte @@ -5,7 +5,7 @@ import { onMount, getContext } from 'svelte'; import { page } from '$app/stores'; - import { settings, user, config, models } from '$lib/stores'; + import { settings, user, config, models, tools } from '$lib/stores'; import { splitStream } from '$lib/utils'; import { getModelInfos, updateModelById } from '$lib/apis/models'; @@ -15,6 +15,7 @@ import Checkbox from '$lib/components/common/Checkbox.svelte'; import Tags from '$lib/components/common/Tags.svelte'; import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte'; + import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte'; const i18n = getContext('i18n'); @@ -60,6 +61,7 @@ }; let knowledge = []; + let toolIds = []; const updateHandler = async () => { loading = true; @@ -76,6 +78,14 @@ } } + if (toolIds.length > 0) { + info.meta.toolIds = toolIds; + } else { + if (info.meta.toolIds) { + delete info.meta.toolIds; + } + } + info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null; Object.keys(info.params).forEach((key) => { if (info.params[key] === '' || info.params[key] === null) { @@ -133,6 +143,10 @@ knowledge = [...model?.info?.meta?.knowledge]; } + if (model?.info?.meta?.toolIds) { + toolIds = [...model?.info?.meta?.toolIds]; + } + if (model?.owned_by === 'openai') { capabilities.usage = false; } @@ -515,6 +529,10 @@
+
+ +
+
{$i18n.t('Capabilities')}