Make it impossible to switch to non-image (#2440)

* make it impossible to switch to non-image

* revert ports

* proper provider support

* remove unused imports

* minor rename

* simplify interface

* remove logs
This commit is contained in:
pablodanswer
2024-09-16 11:35:40 -07:00
committed by GitHub
parent 66cf67d04d
commit 96b98fbc4a
6 changed files with 140 additions and 68 deletions

View File

@@ -25,10 +25,8 @@ import { usePopup } from "@/components/admin/connectors/Popup";
import { getDisplayNameForModel } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { Option } from "@/components/Dropdown";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
import { useUserGroups } from "@/lib/hooks";
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
import { checkLLMSupportsImageOutput, destructureValue } from "@/lib/llm/utils";
import { ToolSnapshot } from "@/lib/tools/interfaces";
import { checkUserIsNoAuthUser } from "@/lib/user";
@@ -47,7 +45,12 @@ import { FullLLMProvider } from "../configuration/llm/interfaces";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { Persona, StarterMessage } from "./interfaces";
import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
import {
buildFinalPrompt,
createPersona,
providersContainImageGeneratingSupport,
updatePersona,
} from "./lib";
import { Popover } from "@/components/popover/Popover";
import {
CameraIcon,
@@ -167,7 +170,7 @@ export function AssistantEditor({
const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
);
const defaultProviderName = defaultProvider?.provider;
const defaultModelName = defaultProvider?.default_model_name;
const providerDisplayNameToProviderName = new Map<string, string>();
llmProviders.forEach((llmProvider) => {
@@ -187,10 +190,9 @@ export function AssistantEditor({
});
modelOptionsByProvider.set(llmProvider.name, providerOptions);
});
const providerSupportingImageGenerationExists = llmProviders.some(
(provider) =>
provider.provider === "openai" || provider.provider === "anthropic"
);
const providerSupportingImageGenerationExists =
providersContainImageGeneratingSupport(llmProviders);
const personaCurrentToolIds =
existingPersona?.tools.map((tool) => tool.id) || [];
@@ -342,7 +344,12 @@ export function AssistantEditor({
if (imageGenerationToolEnabled) {
if (
!checkLLMSupportsImageInput(
!checkLLMSupportsImageOutput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) ||
defaultProviderName ||
"",
values.llm_model_version_override || defaultModelName || ""
)
) {
@@ -453,6 +460,15 @@ export function AssistantEditor({
: false;
}
const currentLLMSupportsImageOutput = checkLLMSupportsImageOutput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) ||
defaultProviderName ||
"",
values.llm_model_version_override || defaultModelName || ""
);
return (
<Form className="w-full text-text-950">
<div className="w-full flex gap-x-2 justify-center">
@@ -757,9 +773,7 @@ export function AssistantEditor({
<TooltipTrigger asChild>
<div
className={`w-fit ${
!checkLLMSupportsImageInput(
values.llm_model_version_override || ""
)
!currentLLMSupportsImageOutput
? "opacity-70 cursor-not-allowed"
: ""
}`}
@@ -771,17 +785,11 @@ export function AssistantEditor({
onChange={() => {
toggleToolInValues(imageGenerationTool.id);
}}
disabled={
!checkLLMSupportsImageInput(
values.llm_model_version_override || ""
)
}
disabled={!currentLLMSupportsImageOutput}
/>
</div>
</TooltipTrigger>
{!checkLLMSupportsImageInput(
values.llm_model_version_override || ""
) && (
{!currentLLMSupportsImageOutput && (
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
To use Image Generation, select GPT-4o or another
@@ -1051,15 +1059,15 @@ export function AssistantEditor({
<Field
name={`starter_messages[${index}].name`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
@@ -1081,15 +1089,15 @@ export function AssistantEditor({
<Field
name={`starter_messages.${index}.description`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
@@ -1112,15 +1120,15 @@ export function AssistantEditor({
<Field
name={`starter_messages[${index}].message`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
as="textarea"
autoComplete="off"
/>

View File

@@ -1,3 +1,4 @@
import { FullLLMProvider } from "../configuration/llm/interfaces";
import { Persona, Prompt, StarterMessage } from "./interfaces";
interface PersonaCreationRequest {
@@ -318,3 +319,18 @@ export function personaComparator(a: Persona, b: Persona) {
return closerToZeroNegativesFirstComparator(a.id, b.id);
}
export function checkPersonaRequiresImageGeneration(persona: Persona) {
for (const tool of persona.tools) {
if (tool.name === "ImageGenerationTool") {
return true;
}
}
return false;
}
export function providersContainImageGeneratingSupport(
providers: FullLLMProvider[]
) {
return providers.some((provider) => provider.provider === "openai");
}

View File

@@ -549,6 +549,7 @@ export function ChatInputBar({
tab
content={(close, ref) => (
<LlmTab
currentAssistant={alternativeAssistant || selectedAssistant}
openModelSettings={openModelSettings}
currentLlm={
llmOverrideManager.llmOverride.modelName ||

View File

@@ -4,10 +4,15 @@ import React, { forwardRef, useCallback, useState } from "react";
import { debounce } from "lodash";
import { Text } from "@tremor/react";
import { Persona } from "@/app/admin/assistants/interfaces";
import { destructureValue, structureValue } from "@/lib/llm/utils";
import {
checkLLMSupportsImageInput,
destructureValue,
structureValue,
} from "@/lib/llm/utils";
import { updateModelOverrideForChatSession } from "../../lib";
import { GearIcon } from "@/components/icons/icons";
import { LlmList } from "@/components/llm/LLMList";
import { checkPersonaRequiresImageGeneration } from "@/app/admin/assistants/lib";
interface LlmTabProps {
llmOverrideManager: LlmOverrideManager;
@@ -15,13 +20,24 @@ interface LlmTabProps {
openModelSettings: () => void;
chatSessionId?: number;
close: () => void;
currentAssistant: Persona;
}
export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
(
{ llmOverrideManager, chatSessionId, currentLlm, close, openModelSettings },
{
llmOverrideManager,
chatSessionId,
currentLlm,
close,
openModelSettings,
currentAssistant,
},
ref
) => {
const requiresImageGeneration =
checkPersonaRequiresImageGeneration(currentAssistant);
const { llmProviders } = useChatContext();
const { setLlmOverride, temperature, setTemperature } = llmOverrideManager;
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
@@ -55,6 +71,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
</button>
</div>
<LlmList
requiresImageGeneration={requiresImageGeneration}
llmProviders={llmProviders}
currentLlm={currentLlm}
onSelect={(value: string | null) => {

View File

@@ -1,6 +1,6 @@
import React from "react";
import { getDisplayNameForModel } from "@/lib/hooks";
import { structureValue } from "@/lib/llm/utils";
import { checkLLMSupportsImageInput, structureValue } from "@/lib/llm/utils";
import {
getProviderIcon,
LLMProviderDescriptor,
@@ -13,6 +13,7 @@ interface LlmListProps {
userDefault?: string | null;
scrollable?: boolean;
hideProviderIcon?: boolean;
requiresImageGeneration?: boolean;
}
export const LlmList: React.FC<LlmListProps> = ({
@@ -21,6 +22,7 @@ export const LlmList: React.FC<LlmListProps> = ({
onSelect,
userDefault,
scrollable,
requiresImageGeneration,
}) => {
const llmOptionsByProvider: {
[provider: string]: {
@@ -76,21 +78,26 @@ export const LlmList: React.FC<LlmListProps> = ({
User Default (currently {getDisplayNameForModel(userDefault)})
</button>
)}
{llmOptions.map(({ name, icon, value }, index) => (
<button
type="button"
key={index}
className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
currentLlm == name
? "bg-background-200"
: "bg-background hover:bg-background-100"
} text-left rounded`}
onClick={() => onSelect(value)}
>
{icon({ size: 16 })}
{getDisplayNameForModel(name)}
</button>
))}
{llmOptions.map(({ name, icon, value }, index) => {
if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
return (
<button
type="button"
key={index}
className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
currentLlm == name
? "bg-background-200"
: "bg-background hover:bg-background-100"
} text-left rounded`}
onClick={() => onSelect(value)}
>
{icon({ size: 16 })}
{getDisplayNameForModel(name)}
</button>
);
}
})}
</div>
);
};

View File

@@ -62,7 +62,7 @@ export function getLLMProviderOverrideForPersona(
return null;
}
const MODEL_NAMES_SUPPORTING_IMAGES = [
const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"gpt-4o",
"gpt-4o-mini",
"gpt-4-vision-preview",
@@ -84,8 +84,31 @@ const MODEL_NAMES_SUPPORTING_IMAGES = [
];
export function checkLLMSupportsImageInput(model: string) {
return MODEL_NAMES_SUPPORTING_IMAGES.some((modelName) => modelName === model);
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
(modelName) => modelName === model
);
}
const MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT = [
["openai", "gpt-4o"],
["openai", "gpt-4o-mini"],
["openai", "gpt-4-vision-preview"],
["openai", "gpt-4-turbo"],
["openai", "gpt-4-1106-vision-preview"],
["azure", "gpt-4o"],
["azure", "gpt-4o-mini"],
["azure", "gpt-4-vision-preview"],
["azure", "gpt-4-turbo"],
["azure", "gpt-4-1106-vision-preview"],
];
export function checkLLMSupportsImageOutput(provider: string, model: string) {
return MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT.some(
(modelProvider) =>
modelProvider[0] === provider && modelProvider[1] === model
);
}
export const structureValue = (
name: string,
provider: string,