diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 5aa7ddc3cffa..821949a5b793 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -25,10 +25,8 @@ import { usePopup } from "@/components/admin/connectors/Popup"; import { getDisplayNameForModel } from "@/lib/hooks"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; import { Option } from "@/components/Dropdown"; -import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences"; -import { useUserGroups } from "@/lib/hooks"; -import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils"; +import { checkLLMSupportsImageOutput, destructureValue } from "@/lib/llm/utils"; import { ToolSnapshot } from "@/lib/tools/interfaces"; import { checkUserIsNoAuthUser } from "@/lib/user"; @@ -47,7 +45,12 @@ import { FullLLMProvider } from "../configuration/llm/interfaces"; import CollapsibleSection from "./CollapsibleSection"; import { SuccessfulPersonaUpdateRedirectType } from "./enums"; import { Persona, StarterMessage } from "./interfaces"; -import { buildFinalPrompt, createPersona, updatePersona } from "./lib"; +import { + buildFinalPrompt, + createPersona, + providersContainImageGeneratingSupport, + updatePersona, +} from "./lib"; import { Popover } from "@/components/popover/Popover"; import { CameraIcon, @@ -167,7 +170,7 @@ export function AssistantEditor({ const defaultProvider = llmProviders.find( (llmProvider) => llmProvider.is_default_provider ); - + const defaultProviderName = defaultProvider?.provider; const defaultModelName = defaultProvider?.default_model_name; const providerDisplayNameToProviderName = new Map(); llmProviders.forEach((llmProvider) => { @@ -187,10 +190,9 @@ export function AssistantEditor({ }); modelOptionsByProvider.set(llmProvider.name, providerOptions); }); - const providerSupportingImageGenerationExists = llmProviders.some( - (provider) => - provider.provider === "openai" || provider.provider === "anthropic" - ); + + const providerSupportingImageGenerationExists = + providersContainImageGeneratingSupport(llmProviders); const personaCurrentToolIds = existingPersona?.tools.map((tool) => tool.id) || []; @@ -342,7 +344,12 @@ export function AssistantEditor({ if (imageGenerationToolEnabled) { if ( - !checkLLMSupportsImageInput( + !checkLLMSupportsImageOutput( + providerDisplayNameToProviderName.get( + values.llm_model_provider_override || "" + ) || + defaultProviderName || + "", values.llm_model_version_override || defaultModelName || "" ) ) { @@ -453,6 +460,15 @@ export function AssistantEditor({ : false; } + const currentLLMSupportsImageOutput = checkLLMSupportsImageOutput( + providerDisplayNameToProviderName.get( + values.llm_model_provider_override || "" + ) || + defaultProviderName || + "", + values.llm_model_version_override || defaultModelName || "" + ); + return (
@@ -757,9 +773,7 @@ export function AssistantEditor({
{ toggleToolInValues(imageGenerationTool.id); }} - disabled={ - !checkLLMSupportsImageInput( - values.llm_model_version_override || "" - ) - } + disabled={!currentLLMSupportsImageOutput} />
- {!checkLLMSupportsImageInput( - values.llm_model_version_override || "" - ) && ( + {!currentLLMSupportsImageOutput && (

To use Image Generation, select GPT-4o or another @@ -1051,15 +1059,15 @@ export function AssistantEditor({ diff --git a/web/src/app/admin/assistants/lib.ts b/web/src/app/admin/assistants/lib.ts index 613f98145f17..7b5d5b21dbe5 100644 --- a/web/src/app/admin/assistants/lib.ts +++ b/web/src/app/admin/assistants/lib.ts @@ -1,3 +1,4 @@ +import { FullLLMProvider } from "../configuration/llm/interfaces"; import { Persona, Prompt, StarterMessage } from "./interfaces"; interface PersonaCreationRequest { @@ -318,3 +319,18 @@ export function personaComparator(a: Persona, b: Persona) { return closerToZeroNegativesFirstComparator(a.id, b.id); } + +export function checkPersonaRequiresImageGeneration(persona: Persona) { + for (const tool of persona.tools) { + if (tool.name === "ImageGenerationTool") { + return true; + } + } + return false; +} + +export function providersContainImageGeneratingSupport( + providers: FullLLMProvider[] +) { + return providers.some((provider) => provider.provider === "openai"); +} diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index 6ea2ce868a58..77b85509dbc4 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -549,6 +549,7 @@ export function ChatInputBar({ tab content={(close, ref) => ( void; chatSessionId?: number; close: () => void; + currentAssistant: Persona; } export const LlmTab = forwardRef( ( - { llmOverrideManager, chatSessionId, currentLlm, close, openModelSettings }, + { + llmOverrideManager, + chatSessionId, + currentLlm, + close, + openModelSettings, + currentAssistant, + }, ref ) => { + const requiresImageGeneration = + checkPersonaRequiresImageGeneration(currentAssistant); + const { llmProviders } = useChatContext(); const { setLlmOverride, temperature, setTemperature } = llmOverrideManager; const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false); @@ -55,6 +71,7 @@ export const LlmTab = forwardRef(

{ diff --git a/web/src/components/llm/LLMList.tsx b/web/src/components/llm/LLMList.tsx index 191a89485800..ca2d3f577a18 100644 --- a/web/src/components/llm/LLMList.tsx +++ b/web/src/components/llm/LLMList.tsx @@ -1,6 +1,6 @@ import React from "react"; import { getDisplayNameForModel } from "@/lib/hooks"; -import { structureValue } from "@/lib/llm/utils"; +import { checkLLMSupportsImageInput, structureValue } from "@/lib/llm/utils"; import { getProviderIcon, LLMProviderDescriptor, @@ -13,6 +13,7 @@ interface LlmListProps { userDefault?: string | null; scrollable?: boolean; hideProviderIcon?: boolean; + requiresImageGeneration?: boolean; } export const LlmList: React.FC = ({ @@ -21,6 +22,7 @@ export const LlmList: React.FC = ({ onSelect, userDefault, scrollable, + requiresImageGeneration, }) => { const llmOptionsByProvider: { [provider: string]: { @@ -76,21 +78,26 @@ export const LlmList: React.FC = ({ User Default (currently {getDisplayNameForModel(userDefault)}) )} - {llmOptions.map(({ name, icon, value }, index) => ( - - ))} + + {llmOptions.map(({ name, icon, value }, index) => { + if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) { + return ( + + ); + } + })} ); }; diff --git a/web/src/lib/llm/utils.ts b/web/src/lib/llm/utils.ts index 92e75cf4664b..b020854d4fe4 100644 --- a/web/src/lib/llm/utils.ts +++ b/web/src/lib/llm/utils.ts @@ -62,7 +62,7 @@ export function getLLMProviderOverrideForPersona( return null; } -const MODEL_NAMES_SUPPORTING_IMAGES = [ +const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [ "gpt-4o", "gpt-4o-mini", "gpt-4-vision-preview", @@ -84,8 +84,31 @@ const MODEL_NAMES_SUPPORTING_IMAGES = [ ]; export function checkLLMSupportsImageInput(model: string) { - return MODEL_NAMES_SUPPORTING_IMAGES.some((modelName) => modelName === model); + return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some( + (modelName) => modelName === model + ); } + +const MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT = [ + ["openai", "gpt-4o"], + ["openai", "gpt-4o-mini"], + ["openai", "gpt-4-vision-preview"], + ["openai", "gpt-4-turbo"], + ["openai", "gpt-4-1106-vision-preview"], + ["azure", "gpt-4o"], + ["azure", "gpt-4o-mini"], + ["azure", "gpt-4-vision-preview"], + ["azure", "gpt-4-turbo"], + ["azure", "gpt-4-1106-vision-preview"], +]; + +export function checkLLMSupportsImageOutput(provider: string, model: string) { + return MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT.some( + (modelProvider) => + modelProvider[0] === provider && modelProvider[1] === model + ); +} + export const structureValue = ( name: string, provider: string,