Better support for image generation capable models (#2725)

This commit is contained in:
Chris Weaver
2024-10-08 12:41:14 -07:00
committed by GitHub
parent aa69fe762b
commit 21a3921790
2 changed files with 28 additions and 37 deletions

View File

@@ -26,7 +26,7 @@ import { getDisplayNameForModel } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { Option } from "@/components/Dropdown"; import { Option } from "@/components/Dropdown";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences"; import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
import { checkLLMSupportsImageOutput, destructureValue } from "@/lib/llm/utils"; import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
import { ToolSnapshot } from "@/lib/tools/interfaces"; import { ToolSnapshot } from "@/lib/tools/interfaces";
import { checkUserIsNoAuthUser } from "@/lib/user"; import { checkUserIsNoAuthUser } from "@/lib/user";
@@ -349,12 +349,9 @@ export function AssistantEditor({
if (imageGenerationToolEnabled) { if (imageGenerationToolEnabled) {
if ( if (
!checkLLMSupportsImageOutput( // model must support image input for image generation
providerDisplayNameToProviderName.get( // to work
values.llm_model_provider_override || "" !checkLLMSupportsImageInput(
) ||
defaultProviderName ||
"",
values.llm_model_version_override || defaultModelName || "" values.llm_model_version_override || defaultModelName || ""
) )
) { ) {
@@ -469,12 +466,9 @@ export function AssistantEditor({
: false; : false;
} }
const currentLLMSupportsImageOutput = checkLLMSupportsImageOutput( // model must support image input for image generation
providerDisplayNameToProviderName.get( // to work
values.llm_model_provider_override || "" const currentLLMSupportsImageOutput = checkLLMSupportsImageInput(
) ||
defaultProviderName ||
"",
values.llm_model_version_override || defaultModelName || "" values.llm_model_version_override || defaultModelName || ""
); );

View File

@@ -68,15 +68,17 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"gpt-4-vision-preview", "gpt-4-vision-preview",
"gpt-4-turbo", "gpt-4-turbo",
"gpt-4-1106-vision-preview", "gpt-4-1106-vision-preview",
"gpt-4o", // standard claude names
"gpt-4o-mini",
"gpt-4-vision-preview",
"gpt-4-turbo",
"gpt-4-1106-vision-preview",
"claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620",
"claude-3-opus-20240229", "claude-3-opus-20240229",
"claude-3-sonnet-20240229", "claude-3-sonnet-20240229",
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
// claude names with AWS Bedrock Suffix
"claude-3-opus-20240229-v1:0",
"claude-3-sonnet-20240229-v1:0",
"claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620-v1:0",
// claude names with full AWS Bedrock names
"anthropic.claude-3-opus-20240229-v1:0", "anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-haiku-20240307-v1:0",
@@ -84,29 +86,24 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
]; ];
export function checkLLMSupportsImageInput(model: string) { export function checkLLMSupportsImageInput(model: string) {
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some( // Original exact match check
const exactMatch = MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
(modelName) => modelName === model (modelName) => modelName === model
); );
if (exactMatch) {
return true;
} }
const MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT = [ // Additional check for the last part of the model name
["openai", "gpt-4o"], const modelParts = model.split(/[/.]/);
["openai", "gpt-4o-mini"], const lastPart = modelParts[modelParts.length - 1];
["openai", "gpt-4-vision-preview"],
["openai", "gpt-4-turbo"],
["openai", "gpt-4-1106-vision-preview"],
["azure", "gpt-4o"],
["azure", "gpt-4o-mini"],
["azure", "gpt-4-vision-preview"],
["azure", "gpt-4-turbo"],
["azure", "gpt-4-1106-vision-preview"],
];
export function checkLLMSupportsImageOutput(provider: string, model: string) { return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) => {
return MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT.some( const modelNameParts = modelName.split(/[/.]/);
(modelProvider) => const modelNameLastPart = modelNameParts[modelNameParts.length - 1];
modelProvider[0] === provider && modelProvider[1] === model return modelNameLastPart === lastPart;
); });
} }
export const structureValue = ( export const structureValue = (