Allow all LLMs for image generation assistants (#3730)
* Allow all LLMs for image generation assistants
* ensure pushed
* update color + assistant -> model
* update prompt
* fix silly conditional
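The backend gates image content on litellm's model metadata: a model only receives image history or image-summary prompts if its litellm entry reports vision support. A minimal sketch of that lookup, assuming litellm's bundled model-metadata dict (litellm.model_cost) and a hypothetical supports_vision() helper; the shipped helper is model_supports_image_input(), which additionally normalizes proxied and tag-suffixed names (see the diff below):

# Minimal sketch, not the shipped implementation.
import litellm  # assumption: litellm.model_cost is the bundled model-metadata dict


def supports_vision(model_name: str) -> bool:
    # Hypothetical helper: look the model up and read its vision flag.
    entry = litellm.model_cost.get(model_name)
    return bool(entry) and bool(entry.get("supports_vision", False))


# History translation can then drop image parts for non-vision models:
#     exclude_images = not supports_vision(llm_config.model_name)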
@@ -15,6 +15,7 @@ from onyx.llm.models import PreviousMessage
 from onyx.llm.utils import build_content_with_imgs
 from onyx.llm.utils import check_message_tokens
 from onyx.llm.utils import message_to_prompt_and_imgs
+from onyx.llm.utils import model_supports_image_input
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
 from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
@@ -90,6 +91,7 @@ class AnswerPromptBuilder:
             provider_type=llm_config.model_provider,
             model_name=llm_config.model_name,
         )
+        self.llm_config = llm_config
         self.llm_tokenizer_encode_func = cast(
             Callable[[str], list[int]], llm_tokenizer.encode
         )
@@ -98,12 +100,21 @@ class AnswerPromptBuilder:
         (
             self.message_history,
             self.history_token_cnts,
-        ) = translate_history_to_basemessages(message_history)
+        ) = translate_history_to_basemessages(
+            message_history,
+            exclude_images=not model_supports_image_input(
+                self.llm_config.model_name,
+                self.llm_config.model_provider,
+            ),
+        )

         self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
         self.user_message_and_token_cnt = (
             user_message,
-            check_message_tokens(user_message, self.llm_tokenizer_encode_func),
+            check_message_tokens(
+                user_message,
+                self.llm_tokenizer_encode_func,
+            ),
         )

         self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
@@ -11,6 +11,7 @@ from onyx.llm.utils import build_content_with_imgs

 def translate_onyx_msg_to_langchain(
     msg: ChatMessage | PreviousMessage,
+    exclude_images: bool = False,
 ) -> BaseMessage:
     files: list[InMemoryChatFile] = []

@@ -18,7 +19,9 @@ def translate_onyx_msg_to_langchain(
     # attached. Just ignore them for now.
     if not isinstance(msg, ChatMessage):
         files = msg.files
-    content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
+    content = build_content_with_imgs(
+        msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
+    )

     if msg.message_type == MessageType.SYSTEM:
         raise ValueError("System messages are not currently part of history")
@@ -32,9 +35,12 @@ def translate_onyx_msg_to_langchain(

 def translate_history_to_basemessages(
     history: list[ChatMessage] | list["PreviousMessage"],
+    exclude_images: bool = False,
 ) -> tuple[list[BaseMessage], list[int]]:
     history_basemessages = [
-        translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
+        translate_onyx_msg_to_langchain(msg, exclude_images)
+        for msg in history
+        if msg.token_count != 0
     ]
     history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
     return history_basemessages, history_token_counts
@@ -142,6 +142,7 @@ def build_content_with_imgs(
     img_urls: list[str] | None = None,
     b64_imgs: list[str] | None = None,
     message_type: MessageType = MessageType.USER,
+    exclude_images: bool = False,
 ) -> str | list[str | dict[str, Any]]:  # matching Langchain's BaseMessage content type
     files = files or []

@@ -157,7 +158,7 @@ def build_content_with_imgs(

     message_main_content = _build_content(message, files)

-    if not img_files and not img_urls:
+    if exclude_images or (not img_files and not img_urls):
         return message_main_content

     return cast(
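The exclude_images gate above short-circuits before any image parts are assembled. A minimal sketch of the pattern with hypothetical names (the real build_content_with_imgs also handles attached files and message types):

from typing import Any


def build_content(
    message: str, img_urls: list[str], exclude_images: bool = False
) -> str | list[dict[str, Any]]:
    # A plain string keeps the payload valid for text-only models.
    if exclude_images or not img_urls:
        return message
    # Otherwise emit LangChain-style multimodal content: a list of typed parts.
    return [{"type": "text", "text": message}] + [
        {"type": "image_url", "image_url": {"url": url}} for url in img_urls
    ]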
@@ -382,9 +383,19 @@ def _strip_colon_from_model_name(model_name: str) -> str:
     return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name


-def _find_model_obj(
-    model_map: dict, provider: str, model_names: list[str | None]
-) -> dict | None:
+def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None:
+    stripped_model_name = _strip_extra_provider_from_model_name(model_name)
+
+    model_names = [
+        model_name,
+        # Remove leading extra provider. Usually for cases where user has a
+        # customer model proxy which appends another prefix
+        _strip_extra_provider_from_model_name(model_name),
+        # remove :XXXX from the end, if present. Needed for ollama.
+        _strip_colon_from_model_name(model_name),
+        _strip_colon_from_model_name(stripped_model_name),
+    ]
+
     # Filter out None values and deduplicate model names
     filtered_model_names = [name for name in model_names if name]

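For illustration, the candidate fan-out for a hypothetical proxied, tag-suffixed model name; each variant is tried against the litellm model map in order:

# Hypothetical input; mirrors the model_names list in _find_model_obj above,
# assuming the extra-provider strip removes the leading "provider/" segment.
# "proxy/ollama/llama3:8b" fans out to:
#     "proxy/ollama/llama3:8b"   # as passed in
#     "ollama/llama3:8b"         # extra provider prefix stripped
#     "proxy/ollama/llama3"      # ":8b" suffix stripped
#     "ollama/llama3"            # both stripped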
@@ -417,21 +428,10 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS

     try:
-        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
-            model_name
-        )
         model_obj = _find_model_obj(
             model_map,
             model_provider,
-            [
-                model_name,
-                # Remove leading extra provider. Usually for cases where user has a
-                # customer model proxy which appends another prefix
-                extra_provider_stripped_model_name,
-                # remove :XXXX from the end, if present. Needed for ollama.
-                _strip_colon_from_model_name(model_name),
-                _strip_colon_from_model_name(extra_provider_stripped_model_name),
-            ],
+            model_name,
         )
         if not model_obj:
             raise RuntimeError(
@@ -523,3 +523,23 @@ def get_max_input_tokens(
         raise RuntimeError("No tokens for input for the LLM given settings")

     return input_toks
+
+
+def model_supports_image_input(model_name: str, model_provider: str) -> bool:
+    model_map = get_model_map()
+    try:
+        model_obj = _find_model_obj(
+            model_map,
+            model_provider,
+            model_name,
+        )
+        if not model_obj:
+            raise RuntimeError(
+                f"No litellm entry found for {model_provider}/{model_name}"
+            )
+        return model_obj.get("supports_vision", False)
+    except Exception:
+        logger.exception(
+            f"Failed to get model object for {model_provider}/{model_name}"
+        )
+        return False
@@ -16,6 +16,7 @@ from onyx.llm.interfaces import LLM
 from onyx.llm.models import PreviousMessage
 from onyx.llm.utils import build_content_with_imgs
 from onyx.llm.utils import message_to_string
+from onyx.llm.utils import model_supports_image_input
 from onyx.prompts.constants import GENERAL_SEP_PAT
 from onyx.tools.message import ToolCallSummary
 from onyx.tools.models import ToolResponse
@@ -316,12 +317,22 @@ class ImageGenerationTool(Tool):
             for img in img_generation_response
             if img.image_data is not None
         ]
-        prompt_builder.update_user_prompt(
-            build_image_generation_user_prompt(
-                query=prompt_builder.get_user_message_content(),
-                img_urls=img_urls,
-                b64_imgs=b64_imgs,
-            )
-        )
+        user_prompt = build_image_generation_user_prompt(
+            query=prompt_builder.get_user_message_content(),
+            supports_image_input=model_supports_image_input(
+                prompt_builder.llm_config.model_name,
+                prompt_builder.llm_config.model_provider,
+            ),
+            prompts=[
+                prompt
+                for response in img_generation_response
+                for prompt in response.revised_prompt
+            ],
+            img_urls=img_urls,
+            b64_imgs=b64_imgs,
+        )
+
+        prompt_builder.update_user_prompt(user_prompt)

         return prompt_builder
@@ -9,12 +9,24 @@ You have just created the attached images in response to the following query: "{query}".
 Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
 """

+IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES = """
+You have generated images based on the following query: "{query}".
+The prompts used to create these images were: {prompts}
+
+Describe the two images you generated, summarizing the key elements and content in a sentence or two.
+Be specific about what was generated and respond as if you have seen them,
+without including any disclaimers or speculations.
+"""
+
+
 def build_image_generation_user_prompt(
     query: str,
+    supports_image_input: bool,
     img_urls: list[str] | None = None,
     b64_imgs: list[str] | None = None,
+    prompts: list[str] | None = None,
 ) -> HumanMessage:
+    if supports_image_input:
         return HumanMessage(
             content=build_content_with_imgs(
                 message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
@@ -22,3 +34,9 @@ def build_image_generation_user_prompt(
                 img_urls=img_urls,
             )
         )
+    else:
+        return HumanMessage(
+            content=IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES.format(
+                query=query, prompts=prompts
+            ).strip()
+        )
@@ -444,26 +444,10 @@ export function AssistantEditor({
     let enabledTools = Object.keys(values.enabled_tools_map)
       .map((toolId) => Number(toolId))
       .filter((toolId) => values.enabled_tools_map[toolId]);

     const searchToolEnabled = searchTool
       ? enabledTools.includes(searchTool.id)
       : false;
-    const imageGenerationToolEnabled = imageGenerationTool
-      ? enabledTools.includes(imageGenerationTool.id)
-      : false;
-
-    if (imageGenerationToolEnabled) {
-      if (
-        // model must support image input for image generation
-        // to work
-        !checkLLMSupportsImageInput(
-          values.llm_model_version_override || defaultModelName || ""
-        )
-      ) {
-        enabledTools = enabledTools.filter(
-          (toolId) => toolId !== imageGenerationTool!.id
-        );
-      }
-    }

     // if disable_retrieval is set, set num_chunks to 0
     // to tell the backend to not fetch any documents
@@ -914,25 +898,20 @@ export function AssistantEditor({
                       id={`enabled_tools_map.${imageGenerationTool.id}`}
                       name={`enabled_tools_map.${imageGenerationTool.id}`}
                       onCheckedChange={() => {
-                        if (
-                          currentLLMSupportsImageOutput &&
-                          isImageGenerationAvailable
-                        ) {
+                        if (isImageGenerationAvailable) {
                           toggleToolInValues(
                             imageGenerationTool.id
                           );
                         }
                       }}
                       className={
-                        !currentLLMSupportsImageOutput ||
                         !isImageGenerationAvailable
                           ? "opacity-50 cursor-not-allowed"
                           : ""
                       }
                     />
                   </TooltipTrigger>
-                  {(!currentLLMSupportsImageOutput ||
-                    !isImageGenerationAvailable) && (
+                  {!isImageGenerationAvailable && (
                     <TooltipContent side="top" align="center">
                       <p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
                         {!currentLLMSupportsImageOutput
@@ -49,6 +49,7 @@ import {
   useContext,
   useEffect,
   useLayoutEffect,
+  useMemo,
   useRef,
   useState,
 } from "react";
@@ -1623,7 +1624,7 @@ export function ChatPage({
         setPopup({
           type: "error",
           message:
-            "The current Assistant does not support image input. Please select an assistant with Vision support.",
+            "The current model does not support image input. Please select a model with Vision support.",
         });
         return;
       }
@@ -1841,6 +1842,14 @@ export function ChatPage({
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [messageHistory]);

+  const imageFileInMessageHistory = useMemo(() => {
+    return messageHistory
+      .filter((message) => message.type === "user")
+      .some((message) =>
+        message.files.some((file) => file.type === ChatFileType.IMAGE)
+      );
+  }, [messageHistory]);
+
   const currentVisibleRange = visibleRange.get(currentSessionId()) || {
     start: 0,
     end: 0,
@@ -1921,6 +1930,10 @@ export function ChatPage({
     handleSlackChatRedirect();
   }, [searchParams, router]);

+  useEffect(() => {
+    llmOverrideManager.updateImageFilesPresent(imageFileInMessageHistory);
+  }, [imageFileInMessageHistory]);
+
   useEffect(() => {
     const handleKeyDown = (event: KeyboardEvent) => {
       if (event.metaKey || event.ctrlKey) {
@@ -5,7 +5,6 @@ import {
   PopoverTrigger,
 } from "@/components/ui/popover";
 import { ChatInputOption } from "./ChatInputOption";
-import { AnthropicSVG } from "@/components/icons/icons";
 import { getDisplayNameForModel } from "@/lib/hooks";
 import {
   checkLLMSupportsImageInput,
@@ -19,6 +18,14 @@ import {
 import { Persona } from "@/app/admin/assistants/interfaces";
 import { LlmOverrideManager } from "@/lib/hooks";

+import {
+  Tooltip,
+  TooltipContent,
+  TooltipProvider,
+  TooltipTrigger,
+} from "@/components/ui/tooltip";
+import { FiAlertTriangle } from "react-icons/fi";
+
 interface LLMPopoverProps {
   llmProviders: LLMProviderDescriptor[];
   llmOverrideManager: LlmOverrideManager;
@@ -139,6 +146,22 @@ export default function LLMPopover({
               );
             }
           })()}
+          {llmOverrideManager.imageFilesPresent &&
+            !checkLLMSupportsImageInput(name) && (
+              <TooltipProvider>
+                <Tooltip delayDuration={0}>
+                  <TooltipTrigger className="my-auto flex items-center ml-auto">
+                    <FiAlertTriangle className="text-alert" size={16} />
+                  </TooltipTrigger>
+                  <TooltipContent>
+                    <p className="text-xs">
+                      This LLM is not vision-capable and cannot process
+                      image files present in your chat session.
+                    </p>
+                  </TooltipContent>
+                </Tooltip>
+              </TooltipProvider>
+            )}
         </button>
       );
     }
@@ -1,7 +1,8 @@
 import { useRef, useState } from "react";
 import { cva, type VariantProps } from "class-variance-authority";
 import { cn } from "@/lib/utils";
-import { CheckCircle, XCircle } from "lucide-react";
+import { Check, CheckCircle, XCircle } from "lucide-react";
+import { Warning } from "@phosphor-icons/react";
 const popupVariants = cva(
   "fixed bottom-4 left-4 p-4 rounded-lg shadow-xl text-white z-[10000] flex items-center space-x-3 transition-all duration-300 ease-in-out",
   {
@@ -26,9 +27,9 @@ export interface PopupSpec extends VariantProps<typeof popupVariants> {
 export const Popup: React.FC<PopupSpec> = ({ message, type }) => (
   <div className={cn(popupVariants({ type }))}>
     {type === "success" ? (
-      <CheckCircle className="w-6 h-6 animate-pulse" />
+      <Check className="w-6 h-6" />
     ) : type === "error" ? (
-      <XCircle className="w-6 h-6 animate-pulse" />
+      <Warning className="w-6 h-6 " />
    ) : type === "info" ? (
       <svg
         className="w-6 h-6"
@@ -360,6 +360,8 @@ export interface LlmOverrideManager {
   temperature: number | null;
   updateTemperature: (temperature: number | null) => void;
   updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
+  imageFilesPresent: boolean;
+  updateImageFilesPresent: (present: boolean) => void;
 }
 export function useLlmOverride(
   llmProviders: LLMProviderDescriptor[],
@@ -383,6 +385,11 @@ export function useLlmOverride(
     }
     return { name: "", provider: "", modelName: "" };
   };
+  const [imageFilesPresent, setImageFilesPresent] = useState(false);
+
+  const updateImageFilesPresent = (present: boolean) => {
+    setImageFilesPresent(present);
+  };
   const [globalDefault, setGlobalDefault] = useState<LlmOverride>(
     getValidLlmOverride(globalModel)
   );
@@ -447,6 +454,8 @@ export function useLlmOverride(
     setGlobalDefault,
     temperature,
     updateTemperature,
+    imageFilesPresent,
+    updateImageFilesPresent,
   };
 }