mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 20:38:32 +02:00
Make it impossible to switch to non-image (#2440)
* make it impossible to switch to non-image * revert ports * proper provider support * remove unused imports * minor rename * simplify interface * remove logs
This commit is contained in:
@@ -25,10 +25,8 @@ import { usePopup } from "@/components/admin/connectors/Popup";
|
|||||||
import { getDisplayNameForModel } from "@/lib/hooks";
|
import { getDisplayNameForModel } from "@/lib/hooks";
|
||||||
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
|
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
|
||||||
import { Option } from "@/components/Dropdown";
|
import { Option } from "@/components/Dropdown";
|
||||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
|
||||||
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
|
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
|
||||||
import { useUserGroups } from "@/lib/hooks";
|
import { checkLLMSupportsImageOutput, destructureValue } from "@/lib/llm/utils";
|
||||||
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
|
|
||||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||||
import { checkUserIsNoAuthUser } from "@/lib/user";
|
import { checkUserIsNoAuthUser } from "@/lib/user";
|
||||||
|
|
||||||
@@ -47,7 +45,12 @@ import { FullLLMProvider } from "../configuration/llm/interfaces";
|
|||||||
import CollapsibleSection from "./CollapsibleSection";
|
import CollapsibleSection from "./CollapsibleSection";
|
||||||
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
|
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
|
||||||
import { Persona, StarterMessage } from "./interfaces";
|
import { Persona, StarterMessage } from "./interfaces";
|
||||||
import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
|
import {
|
||||||
|
buildFinalPrompt,
|
||||||
|
createPersona,
|
||||||
|
providersContainImageGeneratingSupport,
|
||||||
|
updatePersona,
|
||||||
|
} from "./lib";
|
||||||
import { Popover } from "@/components/popover/Popover";
|
import { Popover } from "@/components/popover/Popover";
|
||||||
import {
|
import {
|
||||||
CameraIcon,
|
CameraIcon,
|
||||||
@@ -167,7 +170,7 @@ export function AssistantEditor({
|
|||||||
const defaultProvider = llmProviders.find(
|
const defaultProvider = llmProviders.find(
|
||||||
(llmProvider) => llmProvider.is_default_provider
|
(llmProvider) => llmProvider.is_default_provider
|
||||||
);
|
);
|
||||||
|
const defaultProviderName = defaultProvider?.provider;
|
||||||
const defaultModelName = defaultProvider?.default_model_name;
|
const defaultModelName = defaultProvider?.default_model_name;
|
||||||
const providerDisplayNameToProviderName = new Map<string, string>();
|
const providerDisplayNameToProviderName = new Map<string, string>();
|
||||||
llmProviders.forEach((llmProvider) => {
|
llmProviders.forEach((llmProvider) => {
|
||||||
@@ -187,10 +190,9 @@ export function AssistantEditor({
|
|||||||
});
|
});
|
||||||
modelOptionsByProvider.set(llmProvider.name, providerOptions);
|
modelOptionsByProvider.set(llmProvider.name, providerOptions);
|
||||||
});
|
});
|
||||||
const providerSupportingImageGenerationExists = llmProviders.some(
|
|
||||||
(provider) =>
|
const providerSupportingImageGenerationExists =
|
||||||
provider.provider === "openai" || provider.provider === "anthropic"
|
providersContainImageGeneratingSupport(llmProviders);
|
||||||
);
|
|
||||||
|
|
||||||
const personaCurrentToolIds =
|
const personaCurrentToolIds =
|
||||||
existingPersona?.tools.map((tool) => tool.id) || [];
|
existingPersona?.tools.map((tool) => tool.id) || [];
|
||||||
@@ -342,7 +344,12 @@ export function AssistantEditor({
|
|||||||
|
|
||||||
if (imageGenerationToolEnabled) {
|
if (imageGenerationToolEnabled) {
|
||||||
if (
|
if (
|
||||||
!checkLLMSupportsImageInput(
|
!checkLLMSupportsImageOutput(
|
||||||
|
providerDisplayNameToProviderName.get(
|
||||||
|
values.llm_model_provider_override || ""
|
||||||
|
) ||
|
||||||
|
defaultProviderName ||
|
||||||
|
"",
|
||||||
values.llm_model_version_override || defaultModelName || ""
|
values.llm_model_version_override || defaultModelName || ""
|
||||||
)
|
)
|
||||||
) {
|
) {
|
||||||
@@ -453,6 +460,15 @@ export function AssistantEditor({
|
|||||||
: false;
|
: false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const currentLLMSupportsImageOutput = checkLLMSupportsImageOutput(
|
||||||
|
providerDisplayNameToProviderName.get(
|
||||||
|
values.llm_model_provider_override || ""
|
||||||
|
) ||
|
||||||
|
defaultProviderName ||
|
||||||
|
"",
|
||||||
|
values.llm_model_version_override || defaultModelName || ""
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Form className="w-full text-text-950">
|
<Form className="w-full text-text-950">
|
||||||
<div className="w-full flex gap-x-2 justify-center">
|
<div className="w-full flex gap-x-2 justify-center">
|
||||||
@@ -757,9 +773,7 @@ export function AssistantEditor({
|
|||||||
<TooltipTrigger asChild>
|
<TooltipTrigger asChild>
|
||||||
<div
|
<div
|
||||||
className={`w-fit ${
|
className={`w-fit ${
|
||||||
!checkLLMSupportsImageInput(
|
!currentLLMSupportsImageOutput
|
||||||
values.llm_model_version_override || ""
|
|
||||||
)
|
|
||||||
? "opacity-70 cursor-not-allowed"
|
? "opacity-70 cursor-not-allowed"
|
||||||
: ""
|
: ""
|
||||||
}`}
|
}`}
|
||||||
@@ -771,17 +785,11 @@ export function AssistantEditor({
|
|||||||
onChange={() => {
|
onChange={() => {
|
||||||
toggleToolInValues(imageGenerationTool.id);
|
toggleToolInValues(imageGenerationTool.id);
|
||||||
}}
|
}}
|
||||||
disabled={
|
disabled={!currentLLMSupportsImageOutput}
|
||||||
!checkLLMSupportsImageInput(
|
|
||||||
values.llm_model_version_override || ""
|
|
||||||
)
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</TooltipTrigger>
|
</TooltipTrigger>
|
||||||
{!checkLLMSupportsImageInput(
|
{!currentLLMSupportsImageOutput && (
|
||||||
values.llm_model_version_override || ""
|
|
||||||
) && (
|
|
||||||
<TooltipContent side="top" align="center">
|
<TooltipContent side="top" align="center">
|
||||||
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
|
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
|
||||||
To use Image Generation, select GPT-4o or another
|
To use Image Generation, select GPT-4o or another
|
||||||
@@ -1051,15 +1059,15 @@ export function AssistantEditor({
|
|||||||
<Field
|
<Field
|
||||||
name={`starter_messages[${index}].name`}
|
name={`starter_messages[${index}].name`}
|
||||||
className={`
|
className={`
|
||||||
border
|
border
|
||||||
border-border
|
border-border
|
||||||
bg-background
|
bg-background
|
||||||
rounded
|
rounded
|
||||||
w-full
|
w-full
|
||||||
py-2
|
py-2
|
||||||
px-3
|
px-3
|
||||||
mr-4
|
mr-4
|
||||||
`}
|
`}
|
||||||
autoComplete="off"
|
autoComplete="off"
|
||||||
/>
|
/>
|
||||||
<ErrorMessage
|
<ErrorMessage
|
||||||
@@ -1081,15 +1089,15 @@ export function AssistantEditor({
|
|||||||
<Field
|
<Field
|
||||||
name={`starter_messages.${index}.description`}
|
name={`starter_messages.${index}.description`}
|
||||||
className={`
|
className={`
|
||||||
border
|
border
|
||||||
border-border
|
border-border
|
||||||
bg-background
|
bg-background
|
||||||
rounded
|
rounded
|
||||||
w-full
|
w-full
|
||||||
py-2
|
py-2
|
||||||
px-3
|
px-3
|
||||||
mr-4
|
mr-4
|
||||||
`}
|
`}
|
||||||
autoComplete="off"
|
autoComplete="off"
|
||||||
/>
|
/>
|
||||||
<ErrorMessage
|
<ErrorMessage
|
||||||
@@ -1112,15 +1120,15 @@ export function AssistantEditor({
|
|||||||
<Field
|
<Field
|
||||||
name={`starter_messages[${index}].message`}
|
name={`starter_messages[${index}].message`}
|
||||||
className={`
|
className={`
|
||||||
border
|
border
|
||||||
border-border
|
border-border
|
||||||
bg-background
|
bg-background
|
||||||
rounded
|
rounded
|
||||||
w-full
|
w-full
|
||||||
py-2
|
py-2
|
||||||
px-3
|
px-3
|
||||||
mr-4
|
mr-4
|
||||||
`}
|
`}
|
||||||
as="textarea"
|
as="textarea"
|
||||||
autoComplete="off"
|
autoComplete="off"
|
||||||
/>
|
/>
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
||||||
import { Persona, Prompt, StarterMessage } from "./interfaces";
|
import { Persona, Prompt, StarterMessage } from "./interfaces";
|
||||||
|
|
||||||
interface PersonaCreationRequest {
|
interface PersonaCreationRequest {
|
||||||
@@ -318,3 +319,18 @@ export function personaComparator(a: Persona, b: Persona) {
|
|||||||
|
|
||||||
return closerToZeroNegativesFirstComparator(a.id, b.id);
|
return closerToZeroNegativesFirstComparator(a.id, b.id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function checkPersonaRequiresImageGeneration(persona: Persona) {
|
||||||
|
for (const tool of persona.tools) {
|
||||||
|
if (tool.name === "ImageGenerationTool") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function providersContainImageGeneratingSupport(
|
||||||
|
providers: FullLLMProvider[]
|
||||||
|
) {
|
||||||
|
return providers.some((provider) => provider.provider === "openai");
|
||||||
|
}
|
||||||
|
@@ -549,6 +549,7 @@ export function ChatInputBar({
|
|||||||
tab
|
tab
|
||||||
content={(close, ref) => (
|
content={(close, ref) => (
|
||||||
<LlmTab
|
<LlmTab
|
||||||
|
currentAssistant={alternativeAssistant || selectedAssistant}
|
||||||
openModelSettings={openModelSettings}
|
openModelSettings={openModelSettings}
|
||||||
currentLlm={
|
currentLlm={
|
||||||
llmOverrideManager.llmOverride.modelName ||
|
llmOverrideManager.llmOverride.modelName ||
|
||||||
|
@@ -4,10 +4,15 @@ import React, { forwardRef, useCallback, useState } from "react";
|
|||||||
import { debounce } from "lodash";
|
import { debounce } from "lodash";
|
||||||
import { Text } from "@tremor/react";
|
import { Text } from "@tremor/react";
|
||||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||||
import { destructureValue, structureValue } from "@/lib/llm/utils";
|
import {
|
||||||
|
checkLLMSupportsImageInput,
|
||||||
|
destructureValue,
|
||||||
|
structureValue,
|
||||||
|
} from "@/lib/llm/utils";
|
||||||
import { updateModelOverrideForChatSession } from "../../lib";
|
import { updateModelOverrideForChatSession } from "../../lib";
|
||||||
import { GearIcon } from "@/components/icons/icons";
|
import { GearIcon } from "@/components/icons/icons";
|
||||||
import { LlmList } from "@/components/llm/LLMList";
|
import { LlmList } from "@/components/llm/LLMList";
|
||||||
|
import { checkPersonaRequiresImageGeneration } from "@/app/admin/assistants/lib";
|
||||||
|
|
||||||
interface LlmTabProps {
|
interface LlmTabProps {
|
||||||
llmOverrideManager: LlmOverrideManager;
|
llmOverrideManager: LlmOverrideManager;
|
||||||
@@ -15,13 +20,24 @@ interface LlmTabProps {
|
|||||||
openModelSettings: () => void;
|
openModelSettings: () => void;
|
||||||
chatSessionId?: number;
|
chatSessionId?: number;
|
||||||
close: () => void;
|
close: () => void;
|
||||||
|
currentAssistant: Persona;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
|
export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
|
||||||
(
|
(
|
||||||
{ llmOverrideManager, chatSessionId, currentLlm, close, openModelSettings },
|
{
|
||||||
|
llmOverrideManager,
|
||||||
|
chatSessionId,
|
||||||
|
currentLlm,
|
||||||
|
close,
|
||||||
|
openModelSettings,
|
||||||
|
currentAssistant,
|
||||||
|
},
|
||||||
ref
|
ref
|
||||||
) => {
|
) => {
|
||||||
|
const requiresImageGeneration =
|
||||||
|
checkPersonaRequiresImageGeneration(currentAssistant);
|
||||||
|
|
||||||
const { llmProviders } = useChatContext();
|
const { llmProviders } = useChatContext();
|
||||||
const { setLlmOverride, temperature, setTemperature } = llmOverrideManager;
|
const { setLlmOverride, temperature, setTemperature } = llmOverrideManager;
|
||||||
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
|
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
|
||||||
@@ -55,6 +71,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
|
|||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
<LlmList
|
<LlmList
|
||||||
|
requiresImageGeneration={requiresImageGeneration}
|
||||||
llmProviders={llmProviders}
|
llmProviders={llmProviders}
|
||||||
currentLlm={currentLlm}
|
currentLlm={currentLlm}
|
||||||
onSelect={(value: string | null) => {
|
onSelect={(value: string | null) => {
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
import React from "react";
|
import React from "react";
|
||||||
import { getDisplayNameForModel } from "@/lib/hooks";
|
import { getDisplayNameForModel } from "@/lib/hooks";
|
||||||
import { structureValue } from "@/lib/llm/utils";
|
import { checkLLMSupportsImageInput, structureValue } from "@/lib/llm/utils";
|
||||||
import {
|
import {
|
||||||
getProviderIcon,
|
getProviderIcon,
|
||||||
LLMProviderDescriptor,
|
LLMProviderDescriptor,
|
||||||
@@ -13,6 +13,7 @@ interface LlmListProps {
|
|||||||
userDefault?: string | null;
|
userDefault?: string | null;
|
||||||
scrollable?: boolean;
|
scrollable?: boolean;
|
||||||
hideProviderIcon?: boolean;
|
hideProviderIcon?: boolean;
|
||||||
|
requiresImageGeneration?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const LlmList: React.FC<LlmListProps> = ({
|
export const LlmList: React.FC<LlmListProps> = ({
|
||||||
@@ -21,6 +22,7 @@ export const LlmList: React.FC<LlmListProps> = ({
|
|||||||
onSelect,
|
onSelect,
|
||||||
userDefault,
|
userDefault,
|
||||||
scrollable,
|
scrollable,
|
||||||
|
requiresImageGeneration,
|
||||||
}) => {
|
}) => {
|
||||||
const llmOptionsByProvider: {
|
const llmOptionsByProvider: {
|
||||||
[provider: string]: {
|
[provider: string]: {
|
||||||
@@ -76,21 +78,26 @@ export const LlmList: React.FC<LlmListProps> = ({
|
|||||||
User Default (currently {getDisplayNameForModel(userDefault)})
|
User Default (currently {getDisplayNameForModel(userDefault)})
|
||||||
</button>
|
</button>
|
||||||
)}
|
)}
|
||||||
{llmOptions.map(({ name, icon, value }, index) => (
|
|
||||||
<button
|
{llmOptions.map(({ name, icon, value }, index) => {
|
||||||
type="button"
|
if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
|
||||||
key={index}
|
return (
|
||||||
className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
|
<button
|
||||||
currentLlm == name
|
type="button"
|
||||||
? "bg-background-200"
|
key={index}
|
||||||
: "bg-background hover:bg-background-100"
|
className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
|
||||||
} text-left rounded`}
|
currentLlm == name
|
||||||
onClick={() => onSelect(value)}
|
? "bg-background-200"
|
||||||
>
|
: "bg-background hover:bg-background-100"
|
||||||
{icon({ size: 16 })}
|
} text-left rounded`}
|
||||||
{getDisplayNameForModel(name)}
|
onClick={() => onSelect(value)}
|
||||||
</button>
|
>
|
||||||
))}
|
{icon({ size: 16 })}
|
||||||
|
{getDisplayNameForModel(name)}
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
})}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@@ -62,7 +62,7 @@ export function getLLMProviderOverrideForPersona(
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
const MODEL_NAMES_SUPPORTING_IMAGES = [
|
const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
|
||||||
"gpt-4o",
|
"gpt-4o",
|
||||||
"gpt-4o-mini",
|
"gpt-4o-mini",
|
||||||
"gpt-4-vision-preview",
|
"gpt-4-vision-preview",
|
||||||
@@ -84,8 +84,31 @@ const MODEL_NAMES_SUPPORTING_IMAGES = [
|
|||||||
];
|
];
|
||||||
|
|
||||||
export function checkLLMSupportsImageInput(model: string) {
|
export function checkLLMSupportsImageInput(model: string) {
|
||||||
return MODEL_NAMES_SUPPORTING_IMAGES.some((modelName) => modelName === model);
|
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
|
||||||
|
(modelName) => modelName === model
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT = [
|
||||||
|
["openai", "gpt-4o"],
|
||||||
|
["openai", "gpt-4o-mini"],
|
||||||
|
["openai", "gpt-4-vision-preview"],
|
||||||
|
["openai", "gpt-4-turbo"],
|
||||||
|
["openai", "gpt-4-1106-vision-preview"],
|
||||||
|
["azure", "gpt-4o"],
|
||||||
|
["azure", "gpt-4o-mini"],
|
||||||
|
["azure", "gpt-4-vision-preview"],
|
||||||
|
["azure", "gpt-4-turbo"],
|
||||||
|
["azure", "gpt-4-1106-vision-preview"],
|
||||||
|
];
|
||||||
|
|
||||||
|
export function checkLLMSupportsImageOutput(provider: string, model: string) {
|
||||||
|
return MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT.some(
|
||||||
|
(modelProvider) =>
|
||||||
|
modelProvider[0] === provider && modelProvider[1] === model
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export const structureValue = (
|
export const structureValue = (
|
||||||
name: string,
|
name: string,
|
||||||
provider: string,
|
provider: string,
|
||||||
|
Reference in New Issue
Block a user