Make it impossible to switch to non-image (#2440)

* make it impossible to switch to non-image

* revert ports

* proper provider support

* remove unused imports

* minor rename

* simplify interface

* remove logs
This commit is contained in:
pablodanswer
2024-09-16 11:35:40 -07:00
committed by GitHub
parent 66cf67d04d
commit 96b98fbc4a
6 changed files with 140 additions and 68 deletions

View File

@@ -25,10 +25,8 @@ import { usePopup } from "@/components/admin/connectors/Popup";
import { getDisplayNameForModel } from "@/lib/hooks"; import { getDisplayNameForModel } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { Option } from "@/components/Dropdown"; import { Option } from "@/components/Dropdown";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences"; import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
import { useUserGroups } from "@/lib/hooks"; import { checkLLMSupportsImageOutput, destructureValue } from "@/lib/llm/utils";
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
import { ToolSnapshot } from "@/lib/tools/interfaces"; import { ToolSnapshot } from "@/lib/tools/interfaces";
import { checkUserIsNoAuthUser } from "@/lib/user"; import { checkUserIsNoAuthUser } from "@/lib/user";
@@ -47,7 +45,12 @@ import { FullLLMProvider } from "../configuration/llm/interfaces";
import CollapsibleSection from "./CollapsibleSection"; import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums"; import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { Persona, StarterMessage } from "./interfaces"; import { Persona, StarterMessage } from "./interfaces";
import { buildFinalPrompt, createPersona, updatePersona } from "./lib"; import {
buildFinalPrompt,
createPersona,
providersContainImageGeneratingSupport,
updatePersona,
} from "./lib";
import { Popover } from "@/components/popover/Popover"; import { Popover } from "@/components/popover/Popover";
import { import {
CameraIcon, CameraIcon,
@@ -167,7 +170,7 @@ export function AssistantEditor({
const defaultProvider = llmProviders.find( const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider (llmProvider) => llmProvider.is_default_provider
); );
const defaultProviderName = defaultProvider?.provider;
const defaultModelName = defaultProvider?.default_model_name; const defaultModelName = defaultProvider?.default_model_name;
const providerDisplayNameToProviderName = new Map<string, string>(); const providerDisplayNameToProviderName = new Map<string, string>();
llmProviders.forEach((llmProvider) => { llmProviders.forEach((llmProvider) => {
@@ -187,10 +190,9 @@ export function AssistantEditor({
}); });
modelOptionsByProvider.set(llmProvider.name, providerOptions); modelOptionsByProvider.set(llmProvider.name, providerOptions);
}); });
const providerSupportingImageGenerationExists = llmProviders.some(
(provider) => const providerSupportingImageGenerationExists =
provider.provider === "openai" || provider.provider === "anthropic" providersContainImageGeneratingSupport(llmProviders);
);
const personaCurrentToolIds = const personaCurrentToolIds =
existingPersona?.tools.map((tool) => tool.id) || []; existingPersona?.tools.map((tool) => tool.id) || [];
@@ -342,7 +344,12 @@ export function AssistantEditor({
if (imageGenerationToolEnabled) { if (imageGenerationToolEnabled) {
if ( if (
!checkLLMSupportsImageInput( !checkLLMSupportsImageOutput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) ||
defaultProviderName ||
"",
values.llm_model_version_override || defaultModelName || "" values.llm_model_version_override || defaultModelName || ""
) )
) { ) {
@@ -453,6 +460,15 @@ export function AssistantEditor({
: false; : false;
} }
const currentLLMSupportsImageOutput = checkLLMSupportsImageOutput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) ||
defaultProviderName ||
"",
values.llm_model_version_override || defaultModelName || ""
);
return ( return (
<Form className="w-full text-text-950"> <Form className="w-full text-text-950">
<div className="w-full flex gap-x-2 justify-center"> <div className="w-full flex gap-x-2 justify-center">
@@ -757,9 +773,7 @@ export function AssistantEditor({
<TooltipTrigger asChild> <TooltipTrigger asChild>
<div <div
className={`w-fit ${ className={`w-fit ${
!checkLLMSupportsImageInput( !currentLLMSupportsImageOutput
values.llm_model_version_override || ""
)
? "opacity-70 cursor-not-allowed" ? "opacity-70 cursor-not-allowed"
: "" : ""
}`} }`}
@@ -771,17 +785,11 @@ export function AssistantEditor({
onChange={() => { onChange={() => {
toggleToolInValues(imageGenerationTool.id); toggleToolInValues(imageGenerationTool.id);
}} }}
disabled={ disabled={!currentLLMSupportsImageOutput}
!checkLLMSupportsImageInput(
values.llm_model_version_override || ""
)
}
/> />
</div> </div>
</TooltipTrigger> </TooltipTrigger>
{!checkLLMSupportsImageInput( {!currentLLMSupportsImageOutput && (
values.llm_model_version_override || ""
) && (
<TooltipContent side="top" align="center"> <TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white"> <p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
To use Image Generation, select GPT-4o or another To use Image Generation, select GPT-4o or another
@@ -1051,15 +1059,15 @@ export function AssistantEditor({
<Field <Field
name={`starter_messages[${index}].name`} name={`starter_messages[${index}].name`}
className={` className={`
border border
border-border border-border
bg-background bg-background
rounded rounded
w-full w-full
py-2 py-2
px-3 px-3
mr-4 mr-4
`} `}
autoComplete="off" autoComplete="off"
/> />
<ErrorMessage <ErrorMessage
@@ -1081,15 +1089,15 @@ export function AssistantEditor({
<Field <Field
name={`starter_messages.${index}.description`} name={`starter_messages.${index}.description`}
className={` className={`
border border
border-border border-border
bg-background bg-background
rounded rounded
w-full w-full
py-2 py-2
px-3 px-3
mr-4 mr-4
`} `}
autoComplete="off" autoComplete="off"
/> />
<ErrorMessage <ErrorMessage
@@ -1112,15 +1120,15 @@ export function AssistantEditor({
<Field <Field
name={`starter_messages[${index}].message`} name={`starter_messages[${index}].message`}
className={` className={`
border border
border-border border-border
bg-background bg-background
rounded rounded
w-full w-full
py-2 py-2
px-3 px-3
mr-4 mr-4
`} `}
as="textarea" as="textarea"
autoComplete="off" autoComplete="off"
/> />

View File

@@ -1,3 +1,4 @@
import { FullLLMProvider } from "../configuration/llm/interfaces";
import { Persona, Prompt, StarterMessage } from "./interfaces"; import { Persona, Prompt, StarterMessage } from "./interfaces";
interface PersonaCreationRequest { interface PersonaCreationRequest {
@@ -318,3 +319,18 @@ export function personaComparator(a: Persona, b: Persona) {
return closerToZeroNegativesFirstComparator(a.id, b.id); return closerToZeroNegativesFirstComparator(a.id, b.id);
} }
export function checkPersonaRequiresImageGeneration(persona: Persona) {
for (const tool of persona.tools) {
if (tool.name === "ImageGenerationTool") {
return true;
}
}
return false;
}
export function providersContainImageGeneratingSupport(
providers: FullLLMProvider[]
) {
return providers.some((provider) => provider.provider === "openai");
}

View File

@@ -549,6 +549,7 @@ export function ChatInputBar({
tab tab
content={(close, ref) => ( content={(close, ref) => (
<LlmTab <LlmTab
currentAssistant={alternativeAssistant || selectedAssistant}
openModelSettings={openModelSettings} openModelSettings={openModelSettings}
currentLlm={ currentLlm={
llmOverrideManager.llmOverride.modelName || llmOverrideManager.llmOverride.modelName ||

View File

@@ -4,10 +4,15 @@ import React, { forwardRef, useCallback, useState } from "react";
import { debounce } from "lodash"; import { debounce } from "lodash";
import { Text } from "@tremor/react"; import { Text } from "@tremor/react";
import { Persona } from "@/app/admin/assistants/interfaces"; import { Persona } from "@/app/admin/assistants/interfaces";
import { destructureValue, structureValue } from "@/lib/llm/utils"; import {
checkLLMSupportsImageInput,
destructureValue,
structureValue,
} from "@/lib/llm/utils";
import { updateModelOverrideForChatSession } from "../../lib"; import { updateModelOverrideForChatSession } from "../../lib";
import { GearIcon } from "@/components/icons/icons"; import { GearIcon } from "@/components/icons/icons";
import { LlmList } from "@/components/llm/LLMList"; import { LlmList } from "@/components/llm/LLMList";
import { checkPersonaRequiresImageGeneration } from "@/app/admin/assistants/lib";
interface LlmTabProps { interface LlmTabProps {
llmOverrideManager: LlmOverrideManager; llmOverrideManager: LlmOverrideManager;
@@ -15,13 +20,24 @@ interface LlmTabProps {
openModelSettings: () => void; openModelSettings: () => void;
chatSessionId?: number; chatSessionId?: number;
close: () => void; close: () => void;
currentAssistant: Persona;
} }
export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>( export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
( (
{ llmOverrideManager, chatSessionId, currentLlm, close, openModelSettings }, {
llmOverrideManager,
chatSessionId,
currentLlm,
close,
openModelSettings,
currentAssistant,
},
ref ref
) => { ) => {
const requiresImageGeneration =
checkPersonaRequiresImageGeneration(currentAssistant);
const { llmProviders } = useChatContext(); const { llmProviders } = useChatContext();
const { setLlmOverride, temperature, setTemperature } = llmOverrideManager; const { setLlmOverride, temperature, setTemperature } = llmOverrideManager;
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false); const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
@@ -55,6 +71,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
</button> </button>
</div> </div>
<LlmList <LlmList
requiresImageGeneration={requiresImageGeneration}
llmProviders={llmProviders} llmProviders={llmProviders}
currentLlm={currentLlm} currentLlm={currentLlm}
onSelect={(value: string | null) => { onSelect={(value: string | null) => {

View File

@@ -1,6 +1,6 @@
import React from "react"; import React from "react";
import { getDisplayNameForModel } from "@/lib/hooks"; import { getDisplayNameForModel } from "@/lib/hooks";
import { structureValue } from "@/lib/llm/utils"; import { checkLLMSupportsImageInput, structureValue } from "@/lib/llm/utils";
import { import {
getProviderIcon, getProviderIcon,
LLMProviderDescriptor, LLMProviderDescriptor,
@@ -13,6 +13,7 @@ interface LlmListProps {
userDefault?: string | null; userDefault?: string | null;
scrollable?: boolean; scrollable?: boolean;
hideProviderIcon?: boolean; hideProviderIcon?: boolean;
requiresImageGeneration?: boolean;
} }
export const LlmList: React.FC<LlmListProps> = ({ export const LlmList: React.FC<LlmListProps> = ({
@@ -21,6 +22,7 @@ export const LlmList: React.FC<LlmListProps> = ({
onSelect, onSelect,
userDefault, userDefault,
scrollable, scrollable,
requiresImageGeneration,
}) => { }) => {
const llmOptionsByProvider: { const llmOptionsByProvider: {
[provider: string]: { [provider: string]: {
@@ -76,21 +78,26 @@ export const LlmList: React.FC<LlmListProps> = ({
User Default (currently {getDisplayNameForModel(userDefault)}) User Default (currently {getDisplayNameForModel(userDefault)})
</button> </button>
)} )}
{llmOptions.map(({ name, icon, value }, index) => (
<button {llmOptions.map(({ name, icon, value }, index) => {
type="button" if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
key={index} return (
className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${ <button
currentLlm == name type="button"
? "bg-background-200" key={index}
: "bg-background hover:bg-background-100" className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
} text-left rounded`} currentLlm == name
onClick={() => onSelect(value)} ? "bg-background-200"
> : "bg-background hover:bg-background-100"
{icon({ size: 16 })} } text-left rounded`}
{getDisplayNameForModel(name)} onClick={() => onSelect(value)}
</button> >
))} {icon({ size: 16 })}
{getDisplayNameForModel(name)}
</button>
);
}
})}
</div> </div>
); );
}; };

View File

@@ -62,7 +62,7 @@ export function getLLMProviderOverrideForPersona(
return null; return null;
} }
const MODEL_NAMES_SUPPORTING_IMAGES = [ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"gpt-4o", "gpt-4o",
"gpt-4o-mini", "gpt-4o-mini",
"gpt-4-vision-preview", "gpt-4-vision-preview",
@@ -84,8 +84,31 @@ const MODEL_NAMES_SUPPORTING_IMAGES = [
]; ];
export function checkLLMSupportsImageInput(model: string) { export function checkLLMSupportsImageInput(model: string) {
return MODEL_NAMES_SUPPORTING_IMAGES.some((modelName) => modelName === model); return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
(modelName) => modelName === model
);
} }
const MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT = [
["openai", "gpt-4o"],
["openai", "gpt-4o-mini"],
["openai", "gpt-4-vision-preview"],
["openai", "gpt-4-turbo"],
["openai", "gpt-4-1106-vision-preview"],
["azure", "gpt-4o"],
["azure", "gpt-4o-mini"],
["azure", "gpt-4-vision-preview"],
["azure", "gpt-4-turbo"],
["azure", "gpt-4-1106-vision-preview"],
];
export function checkLLMSupportsImageOutput(provider: string, model: string) {
return MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT.some(
(modelProvider) =>
modelProvider[0] === provider && modelProvider[1] === model
);
}
export const structureValue = ( export const structureValue = (
name: string, name: string,
provider: string, provider: string,