diff --git a/README.md b/README.md index 444002fbc..5f6e4550b 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature- - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query. -- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, and `Serply` and inject the results directly into your chat experience. +- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, and `TavilySearch` and inject the results directly into your chat experience. - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions. diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md index 8e8f89da0..9bf242381 100644 --- a/TROUBLESHOOTING.md +++ b/TROUBLESHOOTING.md @@ -18,6 +18,10 @@ If you're experiencing connection issues, it’s often due to the WebUI docker c docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main ``` +### Error on Slow Responses for Ollama + +Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable `AIOHTTP_CLIENT_TIMEOUT`, which sets the timeout in seconds. + ### General Connection Errors **Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates. 
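For example, to give slow models ten minutes instead of the default five, the variable can be passed to the container alongside the `docker run` command quoted above (the value `600` here is purely illustrative):

```bash
# Same invocation as in TROUBLESHOOTING.md, plus the new timeout override (in seconds).
docker run -d --network=host -v open-webui:/app/backend/data \
  -e OLLAMA_BASE_URL=http://127.0.0.1:11434 \
  -e AIOHTTP_CLIENT_TIMEOUT=600 \
  --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```

The backend reads the variable in `backend/config.py` (defaulting to `300`) and applies it as the total `aiohttp.ClientTimeout` for streaming requests to Ollama, as the changes below show.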
diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 144755418..118c688d3 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -46,6 +46,7 @@ from config import ( SRC_LOG_LEVELS, OLLAMA_BASE_URLS, ENABLE_OLLAMA_API, + AIOHTTP_CLIENT_TIMEOUT, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, @@ -154,7 +155,9 @@ async def cleanup_response( async def post_streaming_url(url: str, payload: str): r = None try: - session = aiohttp.ClientSession(trust_env=True) + session = aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) r = await session.post(url, data=payload) r.raise_for_status() @@ -751,6 +754,14 @@ async def generate_chat_completion( if model_info.params.get("num_ctx", None): payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None) + if model_info.params.get("num_batch", None): + payload["options"]["num_batch"] = model_info.params.get( + "num_batch", None + ) + + if model_info.params.get("num_keep", None): + payload["options"]["num_keep"] = model_info.params.get("num_keep", None) + if model_info.params.get("repeat_last_n", None): payload["options"]["repeat_last_n"] = model_info.params.get( "repeat_last_n", None diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 37da4db5a..62be56aa4 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -73,6 +73,7 @@ from apps.rag.search.serper import search_serper from apps.rag.search.serpstack import search_serpstack from apps.rag.search.serply import search_serply from apps.rag.search.duckduckgo import search_duckduckgo +from apps.rag.search.tavily import search_tavily from utils.misc import ( calculate_sha256, @@ -120,6 +121,7 @@ from config import ( SERPSTACK_HTTPS, SERPER_API_KEY, SERPLY_API_KEY, + TAVILY_API_KEY, RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_EMBEDDING_OPENAI_BATCH_SIZE, @@ -174,6 +176,7 @@ app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS app.state.config.SERPER_API_KEY = SERPER_API_KEY app.state.config.SERPLY_API_KEY = SERPLY_API_KEY +app.state.config.TAVILY_API_KEY = TAVILY_API_KEY app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS @@ -402,6 +405,7 @@ async def get_rag_config(user=Depends(get_admin_user)): "serpstack_https": app.state.config.SERPSTACK_HTTPS, "serper_api_key": app.state.config.SERPER_API_KEY, "serply_api_key": app.state.config.SERPLY_API_KEY, + "tavily_api_key": app.state.config.TAVILY_API_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, @@ -430,6 +434,7 @@ class WebSearchConfig(BaseModel): serpstack_https: Optional[bool] = None serper_api_key: Optional[str] = None serply_api_key: Optional[str] = None + tavily_api_key: Optional[str] = None result_count: Optional[int] = None concurrent_requests: Optional[int] = None @@ -481,6 +486,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key + app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count 
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.search.concurrent_requests @@ -510,6 +516,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ "serpstack_https": app.state.config.SERPSTACK_HTTPS, "serper_api_key": app.state.config.SERPER_API_KEY, "serply_api_key": app.state.config.SERPLY_API_KEY, + "tavily_api_key": app.state.config.TAVILY_API_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, @@ -758,7 +765,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: - SERPSTACK_API_KEY - SERPER_API_KEY - SERPLY_API_KEY - + - TAVILY_API_KEY Args: query (str): The query to search for """ @@ -833,6 +840,15 @@ def search_web(engine: str, query: str) -> list[SearchResult]: raise Exception("No SERPLY_API_KEY found in environment variables") elif engine == "duckduckgo": return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS) + elif engine == "tavily": + if app.state.config.TAVILY_API_KEY: + return search_tavily( + app.state.config.TAVILY_API_KEY, + query, + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + ) + else: + raise Exception("No TAVILY_API_KEY found in environment variables") else: raise Exception("No search engine API key found in environment variables") diff --git a/backend/apps/rag/search/tavily.py b/backend/apps/rag/search/tavily.py new file mode 100644 index 000000000..b15d6ef9d --- /dev/null +++ b/backend/apps/rag/search/tavily.py @@ -0,0 +1,39 @@ +import logging + +import requests + +from apps.rag.search.main import SearchResult +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: + """Search using Tavily's Search API and return the results as a list of SearchResult objects. 
+ + Args: + api_key (str): A Tavily Search API key + query (str): The query to search for + + Returns: + List[SearchResult]: A list of search results + """ + url = "https://api.tavily.com/search" + data = {"query": query, "api_key": api_key} + + response = requests.post(url, json=data) + response.raise_for_status() + + json_response = response.json() + + raw_search_results = json_response.get("results", []) + + return [ + SearchResult( + link=result["url"], + title=result.get("title", ""), + snippet=result.get("content"), + ) + for result in raw_search_results[:count] + ] diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 70e5577e9..ef63674ab 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -65,6 +65,20 @@ class MemoriesTable: else: return None + def update_memory_by_id( + self, + id: str, + content: str, + ) -> Optional[MemoryModel]: + try: + memory = Memory.get(Memory.id == id) + memory.content = content + memory.updated_at = int(time.time()) + memory.save() + return MemoryModel(**model_to_dict(memory)) + except: + return None + def get_memories(self) -> List[MemoryModel]: try: memories = Memory.select() diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py index 6448ebe1e..3832fe9a1 100644 --- a/backend/apps/webui/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -44,6 +44,10 @@ class AddMemoryForm(BaseModel): content: str +class MemoryUpdateModel(BaseModel): + content: Optional[str] = None + + @router.post("/add", response_model=Optional[MemoryModel]) async def add_memory( request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user) @@ -62,6 +66,34 @@ async def add_memory( return memory +@router.post("/{memory_id}/update", response_model=Optional[MemoryModel]) +async def update_memory_by_id( + memory_id: str, + request: Request, + form_data: MemoryUpdateModel, + user=Depends(get_verified_user), +): + memory = Memories.update_memory_by_id(memory_id, form_data.content) + if memory is None: + raise HTTPException(status_code=404, detail="Memory not found") + + if form_data.content is not None: + memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) + collection = CHROMA_CLIENT.get_or_create_collection( + name=f"user-memory-{user.id}" + ) + collection.upsert( + documents=[form_data.content], + ids=[memory.id], + embeddings=[memory_embedding], + metadatas=[ + {"created_at": memory.created_at, "updated_at": memory.updated_at} + ], + ) + + return memory + + ############################ # QueryMemory ############################ diff --git a/backend/config.py b/backend/config.py index 6d145465a..e0190f645 100644 --- a/backend/config.py +++ b/backend/config.py @@ -425,6 +425,7 @@ OLLAMA_API_BASE_URL = os.environ.get( ) OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") +AIOHTTP_CLIENT_TIMEOUT = int(os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300")) K8S_FLAG = os.environ.get("K8S_FLAG", "") USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false") @@ -951,6 +952,11 @@ SERPLY_API_KEY = PersistentConfig( os.getenv("SERPLY_API_KEY", ""), ) +TAVILY_API_KEY = PersistentConfig( + "TAVILY_API_KEY", + "rag.web.search.tavily_api_key", + os.getenv("TAVILY_API_KEY", ""), +) RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig( "RAG_WEB_SEARCH_RESULT_COUNT", diff --git a/backend/main.py b/backend/main.py index de8827d12..e42c4ed9c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -494,6 +494,9 @@ def 
filter_pipeline(payload, user): if "title" in payload: del payload["title"] + if "task" in payload: + del payload["task"] + return payload @@ -835,6 +838,71 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) "messages": [{"role": "user", "content": content}], "stream": False, "max_tokens": 30, + "task": True, + } + + print(payload) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + return await generate_openai_chat_completion(payload, user=user) + + +@app.post("/api/task/emoji/completions") +async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): + print("generate_emoji") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + model = app.state.MODELS[model_id] + + template = ''' +Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). + +Message: """{{prompt}}""" +''' + + content = title_generation_template( + template, form_data["prompt"], user.model_dump() + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "max_tokens": 4, + "chat_id": form_data.get("chat_id", None), + "task": True, } print(payload) diff --git a/src/app.css b/src/app.css index da1d961e5..baf620845 100644 --- a/src/app.css +++ b/src/app.css @@ -28,6 +28,10 @@ math { @apply rounded-lg; } +.markdown a { + @apply underline; +} + ol > li { counter-increment: list-number; display: block; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index c40815611..9558e98f5 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -205,6 +205,54 @@ export const generateTitle = async ( return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat'; }; +export const generateEmoji = async ( + token: string = '', + model: string, + prompt: string, + chat_id?: string +) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/task/emoji/completions`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: prompt, + ...(chat_id && { chat_id: chat_id }) + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } + return null; + }); + + if (error) { + throw error; + } + + const response = res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 
null; + + if (response) { + if (/\p{Extended_Pictographic}/u.test(response)) { + return response.match(/\p{Extended_Pictographic}/gu)[0]; + } + } + + return null; +}; + export const generateSearchQuery = async ( token: string = '', model: string, diff --git a/src/lib/apis/memories/index.ts b/src/lib/apis/memories/index.ts index 6cbb89f14..c3c122adf 100644 --- a/src/lib/apis/memories/index.ts +++ b/src/lib/apis/memories/index.ts @@ -3,7 +3,7 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; export const getMemories = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/memories`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/memories/`, { method: 'GET', headers: { Accept: 'application/json', @@ -59,6 +59,37 @@ export const addNewMemory = async (token: string, content: string) => { return res; }; +export const updateMemoryById = async (token: string, id: string, content: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/memories/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + content: content + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const queryMemory = async (token: string, content: string) => { let error = null; diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index ab8996d92..af2dfcdc3 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -13,6 +13,8 @@ getRAGConfig, updateRAGConfig } from '$lib/apis/rag'; + import ResetUploadDirConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; + import ResetVectorDBConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; import { documents, models } from '$lib/stores'; import { onMount, getContext } from 'svelte'; @@ -213,6 +215,34 @@ }); + { + const res = resetUploadDir(localStorage.token).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + toast.success($i18n.t('Success')); + } + }} +/> + + { + const res = resetVectorDB(localStorage.token).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + toast.success($i18n.t('Success')); + } + }} +/> +
{ @@ -640,199 +670,56 @@
- {#if showResetUploadDirConfirm} -
-
- - - - - {$i18n.t('Are you sure?')} -
- -
- - -
+ - {/if} +
{$i18n.t('Reset Upload Directory')}
+ - {#if showResetConfirm} -
-
- - - - {$i18n.t('Are you sure?')} -
- -
- - -
+ - {/if} +
{$i18n.t('Reset Vector Storage')}
+
diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index a943c6fb0..57d0be135 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -1,5 +1,10 @@ + { + deleteModelHandler(); + }} +/> +
{#if ollamaEnabled} @@ -763,7 +773,7 @@
+ {:else if webConfig.search.engine === 'tavily'} +
+
+ {$i18n.t('Tavily API Key')} +
+ +
+
+ +
+
+
{/if}
{/if} diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 359056cfd..73b480796 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -30,6 +30,7 @@ import { convertMessagesToHistory, copyToClipboard, + extractSentencesForAudio, promptTemplate, splitStream } from '$lib/utils'; @@ -64,6 +65,8 @@ export let chatIdProp = ''; let loaded = false; + const eventTarget = new EventTarget(); + let stopResponseFlag = false; let autoScroll = true; let processing = ''; @@ -108,7 +111,8 @@ $: if (chatIdProp) { (async () => { - if (await loadChat()) { + console.log(chatIdProp); + if (chatIdProp && (await loadChat())) { await tick(); loaded = true; @@ -123,7 +127,11 @@ onMount(async () => { if (!$chatId) { - await initNewChat(); + chatId.subscribe(async (value) => { + if (!value) { + await initNewChat(); + } + }); } else { if (!($settings.saveChatHistory ?? true)) { await goto('/'); @@ -300,7 +308,7 @@ // Chat functions ////////////////////////// - const submitPrompt = async (userPrompt, _user = null) => { + const submitPrompt = async (userPrompt, { _raw = false } = {}) => { let _responses = []; console.log('submitPrompt', $chatId); @@ -344,7 +352,6 @@ parentId: messages.length !== 0 ? messages.at(-1).id : null, childrenIds: [], role: 'user', - user: _user ?? undefined, content: userPrompt, files: _files.length > 0 ? _files : undefined, timestamp: Math.floor(Date.now() / 1000), // Unix epoch @@ -362,15 +369,13 @@ // Wait until history/message have been updated await tick(); - - // Send prompt - _responses = await sendPrompt(userPrompt, userMessageId); + _responses = await sendPrompt(userPrompt, userMessageId, { newChat: true }); } return _responses; }; - const sendPrompt = async (prompt, parentId, modelId = null, newChat = true) => { + const sendPrompt = async (prompt, parentId, { modelId = null, newChat = false } = {}) => { let _responses = []; // If modelId is provided, use it, else use selected model @@ -490,7 +495,6 @@ responseMessage.userContext = userContext; const chatEventEmitter = await getChatEventEmitter(model.id, _chatId); - if (webSearchEnabled) { await getWebSearchResults(model.id, parentId, responseMessageId); } @@ -503,8 +507,6 @@ } _responses.push(_response); - console.log('chatEventEmitter', chatEventEmitter); - if (chatEventEmitter) clearInterval(chatEventEmitter); } else { toast.error($i18n.t(`Model {{modelId}} not found`, { modelId })); @@ -513,88 +515,9 @@ ); await chats.set(await getChatList(localStorage.token)); - return _responses; }; - const getWebSearchResults = async (model: string, parentId: string, responseId: string) => { - const responseMessage = history.messages[responseId]; - - responseMessage.statusHistory = [ - { - done: false, - action: 'web_search', - description: $i18n.t('Generating search query') - } - ]; - messages = messages; - - const prompt = history.messages[parentId].content; - let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch( - (error) => { - console.log(error); - return prompt; - } - ); - - if (!searchQuery) { - toast.warning($i18n.t('No search query generated')); - responseMessage.statusHistory.push({ - done: true, - error: true, - action: 'web_search', - description: 'No search query generated' - }); - - messages = messages; - } - - responseMessage.statusHistory.push({ - done: false, - action: 'web_search', - description: $i18n.t(`Searching "{{searchQuery}}"`, { searchQuery }) - }); - messages = messages; - - const results = await 
runWebSearch(localStorage.token, searchQuery).catch((error) => { - console.log(error); - toast.error(error); - - return null; - }); - - if (results) { - responseMessage.statusHistory.push({ - done: true, - action: 'web_search', - description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }), - query: searchQuery, - urls: results.filenames - }); - - if (responseMessage?.files ?? undefined === undefined) { - responseMessage.files = []; - } - - responseMessage.files.push({ - collection_name: results.collection_name, - name: searchQuery, - type: 'web_search_results', - urls: results.filenames - }); - - messages = messages; - } else { - responseMessage.statusHistory.push({ - done: true, - error: true, - action: 'web_search', - description: 'No search results found' - }); - messages = messages; - } - }; - const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => { let _response = null; @@ -676,6 +599,16 @@ array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index ); + eventTarget.dispatchEvent( + new CustomEvent('chat:start', { + detail: { + id: responseMessageId + } + }) + ); + + await tick(); + const [res, controller] = await generateChatCompletion(localStorage.token, { model: model.id, messages: messagesBody, @@ -745,6 +678,23 @@ continue; } else { responseMessage.content += data.message.content; + + const sentences = extractSentencesForAudio(responseMessage.content); + sentences.pop(); + + // dispatch only last sentence and make sure it hasn't been dispatched before + if ( + sentences.length > 0 && + sentences[sentences.length - 1] !== responseMessage.lastSentence + ) { + responseMessage.lastSentence = sentences[sentences.length - 1]; + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { id: responseMessageId, content: sentences[sentences.length - 1] } + }) + ); + } + messages = messages; } } else { @@ -771,21 +721,13 @@ messages = messages; if ($settings.notificationEnabled && !document.hasFocus()) { - const notification = new Notification( - selectedModelfile - ? `${ - selectedModelfile.title.charAt(0).toUpperCase() + - selectedModelfile.title.slice(1) - }` - : `${model.id}`, - { - body: responseMessage.content, - icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png` - } - ); + const notification = new Notification(`${model.id}`, { + body: responseMessage.content, + icon: `${WEBUI_BASE_URL}/static/favicon.png` + }); } - if ($settings.responseAutoCopy) { + if ($settings?.responseAutoCopy ?? false) { copyToClipboard(responseMessage.content); } @@ -847,6 +789,23 @@ stopResponseFlag = false; await tick(); + let lastSentence = extractSentencesForAudio(responseMessage.content)?.at(-1) ?? 
''; + if (lastSentence) { + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { id: responseMessageId, content: lastSentence } + }) + ); + } + eventTarget.dispatchEvent( + new CustomEvent('chat:finish', { + detail: { + id: responseMessageId, + content: responseMessage.content + } + }) + ); + if (autoScroll) { scrollToBottom(); } @@ -887,6 +846,15 @@ scrollToBottom(); + eventTarget.dispatchEvent( + new CustomEvent('chat:start', { + detail: { + id: responseMessageId + } + }) + ); + await tick(); + try { const [res, controller] = await generateOpenAIChatCompletion( localStorage.token, @@ -1007,6 +975,23 @@ continue; } else { responseMessage.content += value; + + const sentences = extractSentencesForAudio(responseMessage.content); + sentences.pop(); + + // dispatch only last sentence and make sure it hasn't been dispatched before + if ( + sentences.length > 0 && + sentences[sentences.length - 1] !== responseMessage.lastSentence + ) { + responseMessage.lastSentence = sentences[sentences.length - 1]; + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { id: responseMessageId, content: sentences[sentences.length - 1] } + }) + ); + } + messages = messages; } @@ -1057,6 +1042,24 @@ stopResponseFlag = false; await tick(); + let lastSentence = extractSentencesForAudio(responseMessage.content)?.at(-1) ?? ''; + if (lastSentence) { + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { id: responseMessageId, content: lastSentence } + }) + ); + } + + eventTarget.dispatchEvent( + new CustomEvent('chat:finish', { + detail: { + id: responseMessageId, + content: responseMessage.content + } + }) + ); + if (autoScroll) { scrollToBottom(); } @@ -1123,9 +1126,12 @@ let userPrompt = userMessage.content; if ((userMessage?.models ?? [...selectedModels]).length == 1) { - await sendPrompt(userPrompt, userMessage.id, undefined, false); + // If user message has only one model selected, sendPrompt automatically selects it for regeneration + await sendPrompt(userPrompt, userMessage.id); } else { - await sendPrompt(userPrompt, userMessage.id, message.model, false); + // If there are multiple models selected, use the model of the response message for regeneration + // e.g. 
many model chat + await sendPrompt(userPrompt, userMessage.id, { modelId: message.model }); } } }; @@ -1191,6 +1197,84 @@ } }; + const getWebSearchResults = async (model: string, parentId: string, responseId: string) => { + const responseMessage = history.messages[responseId]; + + responseMessage.statusHistory = [ + { + done: false, + action: 'web_search', + description: $i18n.t('Generating search query') + } + ]; + messages = messages; + + const prompt = history.messages[parentId].content; + let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch( + (error) => { + console.log(error); + return prompt; + } + ); + + if (!searchQuery) { + toast.warning($i18n.t('No search query generated')); + responseMessage.statusHistory.push({ + done: true, + error: true, + action: 'web_search', + description: 'No search query generated' + }); + + messages = messages; + } + + responseMessage.statusHistory.push({ + done: false, + action: 'web_search', + description: $i18n.t(`Searching "{{searchQuery}}"`, { searchQuery }) + }); + messages = messages; + + const results = await runWebSearch(localStorage.token, searchQuery).catch((error) => { + console.log(error); + toast.error(error); + + return null; + }); + + if (results) { + responseMessage.statusHistory.push({ + done: true, + action: 'web_search', + description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }), + query: searchQuery, + urls: results.filenames + }); + + if (responseMessage?.files ?? undefined === undefined) { + responseMessage.files = []; + } + + responseMessage.files.push({ + collection_name: results.collection_name, + name: searchQuery, + type: 'web_search_results', + urls: results.filenames + }); + + messages = messages; + } else { + responseMessage.statusHistory.push({ + done: true, + error: true, + action: 'web_search', + description: 'No search results found' + }); + messages = messages; + } + }; + const getTags = async () => { return await getTagsById(localStorage.token, $chatId).catch(async (error) => { return []; @@ -1206,7 +1290,18 @@ - +
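As a footnote to the `backend/apps/rag/search/tavily.py` file added earlier in this diff: the new `search_tavily` helper amounts to a single JSON POST against Tavily's `/search` endpoint. A minimal sketch of the equivalent request, with a placeholder API key and query:

```bash
# Illustrative request only; the key and query are placeholders.
curl -s -X POST https://api.tavily.com/search \
  -H 'Content-Type: application/json' \
  -d '{"query": "open webui rag", "api_key": "tvly-xxxxxxxx"}'
```

The helper then maps each entry of the response's `results` array to a `SearchResult` (`url` as the link, `title` as the title, `content` as the snippet), truncated to `RAG_WEB_SEARCH_RESULT_COUNT` items.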