Minor update to llm image ability tracking (#2423)

* minor update to llm image ability tracking

* quick robustification
This commit is contained in:
pablodanswer
2024-09-13 10:24:51 -07:00
committed by GitHub
parent 2fe49e5efb
commit 566f44fcd6
3 changed files with 30 additions and 51 deletions

View File

@@ -132,11 +132,6 @@ export function AssistantEditor({
const [isIconDropdownOpen, setIsIconDropdownOpen] = useState(false); const [isIconDropdownOpen, setIsIconDropdownOpen] = useState(false);
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [finalPrompt, setFinalPrompt] = useState<string | null>(""); const [finalPrompt, setFinalPrompt] = useState<string | null>("");
const [finalPromptError, setFinalPromptError] = useState<string>(""); const [finalPromptError, setFinalPromptError] = useState<string>("");
const [removePersonaImage, setRemovePersonaImage] = useState(false); const [removePersonaImage, setRemovePersonaImage] = useState(false);
@@ -172,7 +167,7 @@ export function AssistantEditor({
const defaultProvider = llmProviders.find( const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider (llmProvider) => llmProvider.is_default_provider
); );
const defaultProviderName = defaultProvider?.provider;
const defaultModelName = defaultProvider?.default_model_name; const defaultModelName = defaultProvider?.default_model_name;
const providerDisplayNameToProviderName = new Map<string, string>(); const providerDisplayNameToProviderName = new Map<string, string>();
llmProviders.forEach((llmProvider) => { llmProviders.forEach((llmProvider) => {
@@ -193,7 +188,8 @@ export function AssistantEditor({
modelOptionsByProvider.set(llmProvider.name, providerOptions); modelOptionsByProvider.set(llmProvider.name, providerOptions);
}); });
const providerSupportingImageGenerationExists = llmProviders.some( const providerSupportingImageGenerationExists = llmProviders.some(
(provider) => provider.provider === "openai" (provider) =>
provider.provider === "openai" || provider.provider === "anthropic"
); );
const personaCurrentToolIds = const personaCurrentToolIds =
@@ -347,11 +343,6 @@ export function AssistantEditor({
if (imageGenerationToolEnabled) { if (imageGenerationToolEnabled) {
if ( if (
!checkLLMSupportsImageInput( !checkLLMSupportsImageInput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) ||
defaultProviderName ||
"",
values.llm_model_version_override || defaultModelName || "" values.llm_model_version_override || defaultModelName || ""
) )
) { ) {
@@ -767,9 +758,6 @@ export function AssistantEditor({
<div <div
className={`w-fit ${ className={`w-fit ${
!checkLLMSupportsImageInput( !checkLLMSupportsImageInput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) || "",
values.llm_model_version_override || "" values.llm_model_version_override || ""
) )
? "opacity-70 cursor-not-allowed" ? "opacity-70 cursor-not-allowed"
@@ -785,9 +773,6 @@ export function AssistantEditor({
}} }}
disabled={ disabled={
!checkLLMSupportsImageInput( !checkLLMSupportsImageInput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) || "",
values.llm_model_version_override || "" values.llm_model_version_override || ""
) )
} }
@@ -795,9 +780,6 @@ export function AssistantEditor({
</div> </div>
</TooltipTrigger> </TooltipTrigger>
{!checkLLMSupportsImageInput( {!checkLLMSupportsImageInput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) || "",
values.llm_model_version_override || "" values.llm_model_version_override || ""
) && ( ) && (
<TooltipContent side="top" align="center"> <TooltipContent side="top" align="center">

View File

@@ -1432,13 +1432,13 @@ export function ChatPage({
}; };
const handleImageUpload = (acceptedFiles: File[]) => { const handleImageUpload = (acceptedFiles: File[]) => {
const llmAcceptsImages = checkLLMSupportsImageInput( const [_, llmModel] = getFinalLLM(
...getFinalLLM( llmProviders,
llmProviders, liveAssistant,
liveAssistant, llmOverrideManager.llmOverride
llmOverrideManager.llmOverride
)
); );
const llmAcceptsImages = checkLLMSupportsImageInput(llmModel);
const imageFiles = acceptedFiles.filter((file) => const imageFiles = acceptedFiles.filter((file) =>
file.type.startsWith("image/") file.type.startsWith("image/")
); );

View File

@@ -62,33 +62,30 @@ export function getLLMProviderOverrideForPersona(
return null; return null;
} }
const MODELS_SUPPORTING_IMAGES = [ const MODEL_NAMES_SUPPORTING_IMAGES = [
["openai", "gpt-4o"], "gpt-4o",
["openai", "gpt-4o-mini"], "gpt-4o-mini",
["openai", "gpt-4-vision-preview"], "gpt-4-vision-preview",
["openai", "gpt-4-turbo"], "gpt-4-turbo",
["openai", "gpt-4-1106-vision-preview"], "gpt-4-1106-vision-preview",
["azure", "gpt-4o"], "gpt-4o",
["azure", "gpt-4o-mini"], "gpt-4o-mini",
["azure", "gpt-4-vision-preview"], "gpt-4-vision-preview",
["azure", "gpt-4-turbo"], "gpt-4-turbo",
["azure", "gpt-4-1106-vision-preview"], "gpt-4-1106-vision-preview",
["anthropic", "claude-3-5-sonnet-20240620"], "claude-3-5-sonnet-20240620",
["anthropic", "claude-3-opus-20240229"], "claude-3-opus-20240229",
["anthropic", "claude-3-sonnet-20240229"], "claude-3-sonnet-20240229",
["anthropic", "claude-3-haiku-20240307"], "claude-3-haiku-20240307",
["bedrock", "anthropic.claude-3-opus-20240229-v1:0"], "anthropic.claude-3-opus-20240229-v1:0",
["bedrock", "anthropic.claude-3-sonnet-20240229-v1:0"], "anthropic.claude-3-sonnet-20240229-v1:0",
["bedrock", "anthropic.claude-3-haiku-20240307-v1:0"], "anthropic.claude-3-haiku-20240307-v1:0",
["bedrock", "anthropic.claude-3-5-sonnet-20240620-v1:0"], "anthropic.claude-3-5-sonnet-20240620-v1:0",
]; ];
export function checkLLMSupportsImageInput(provider: string, model: string) { export function checkLLMSupportsImageInput(model: string) {
return MODELS_SUPPORTING_IMAGES.some( return MODEL_NAMES_SUPPORTING_IMAGES.some((modelName) => modelName === model);
([p, m]) => p === provider && m === model
);
} }
export const structureValue = ( export const structureValue = (
name: string, name: string,
provider: string, provider: string,