mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-08-29 15:15:03 +02:00
Single source of truth for image capability (#4612)
* Single source of truth for image capability * Update web/src/app/admin/assistants/AssistantEditor.tsx Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * Fix tests --------- Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
This commit is contained in:
@@ -24,7 +24,7 @@ branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
class ModelConfiguration(BaseModel):
|
||||
class _SimpleModelConfiguration(BaseModel):
|
||||
# Configure model to read from attributes
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -82,7 +82,7 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
model_configurations = [
|
||||
ModelConfiguration.model_validate(model_configuration)
|
||||
_SimpleModelConfiguration.model_validate(model_configuration)
|
||||
for model_configuration in connection.execute(
|
||||
sa.select(
|
||||
model_configuration_table.c.id,
|
||||
|
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -152,6 +153,7 @@ class ModelConfigurationView(BaseModel):
|
||||
name: str
|
||||
is_visible: bool | None = False
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -166,6 +168,10 @@ class ModelConfigurationView(BaseModel):
|
||||
or get_max_input_tokens(
|
||||
model_name=model_configuration_model.name, model_provider=provider_name
|
||||
),
|
||||
supports_image_input=model_supports_image_input(
|
||||
model_name=model_configuration_model.name,
|
||||
model_provider=provider_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1,19 +1,18 @@
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from requests.models import Response
|
||||
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
|
||||
|
||||
|
||||
def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
|
||||
"""Utility function to fetch an LLM provider by ID"""
|
||||
response = requests.get(
|
||||
@@ -40,10 +39,10 @@ def assert_response_is_equivalent(
|
||||
|
||||
assert provider_data["default_model_name"] == default_model_name
|
||||
|
||||
def fill_max_input_tokens_if_none(
|
||||
def fill_max_input_tokens_and_supports_image_input(
|
||||
req: ModelConfigurationUpsertRequest,
|
||||
) -> ModelConfigurationUpsertRequest:
|
||||
return ModelConfigurationUpsertRequest(
|
||||
) -> dict[str, Any]:
|
||||
filled_with_max_input_tokens = ModelConfigurationUpsertRequest(
|
||||
name=req.name,
|
||||
is_visible=req.is_visible,
|
||||
max_input_tokens=req.max_input_tokens
|
||||
@@ -51,13 +50,21 @@ def assert_response_is_equivalent(
|
||||
model_name=req.name, model_provider=default_model_name
|
||||
),
|
||||
)
|
||||
return {
|
||||
**filled_with_max_input_tokens.model_dump(),
|
||||
"supports_image_input": model_supports_image_input(
|
||||
req.name, created_provider["provider"]
|
||||
),
|
||||
}
|
||||
|
||||
actual = set(
|
||||
tuple(model_configuration.items())
|
||||
for model_configuration in provider_data["model_configurations"]
|
||||
)
|
||||
expected = set(
|
||||
tuple(fill_max_input_tokens_if_none(model_configuration).dict().items())
|
||||
tuple(
|
||||
fill_max_input_tokens_and_supports_image_input(model_configuration).items()
|
||||
)
|
||||
for model_configuration in model_configurations
|
||||
)
|
||||
assert actual == expected
|
||||
@@ -150,7 +157,7 @@ def test_create_llm_provider(
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"default_model_name": default_model_name,
|
||||
"model_configurations": [
|
||||
model_configuration.dict()
|
||||
model_configuration.model_dump()
|
||||
for model_configuration in model_configurations
|
||||
],
|
||||
"is_public": True,
|
||||
|
@@ -25,8 +25,8 @@ import { getDisplayNameForModel, useLabels } from "@/lib/hooks";
|
||||
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
|
||||
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
|
||||
import {
|
||||
checkLLMSupportsImageInput,
|
||||
destructureValue,
|
||||
modelSupportsImageInput,
|
||||
structureValue,
|
||||
} from "@/lib/llm/utils";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
@@ -139,6 +139,7 @@ export function AssistantEditor({
|
||||
admin?: boolean;
|
||||
}) {
|
||||
const { refreshAssistants, isImageGenerationAvailable } = useAssistants();
|
||||
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const isAdminPage = searchParams?.get("admin") === "true";
|
||||
@@ -643,7 +644,8 @@ export function AssistantEditor({
|
||||
|
||||
// model must support image input for image generation
|
||||
// to work
|
||||
const currentLLMSupportsImageOutput = checkLLMSupportsImageInput(
|
||||
const currentLLMSupportsImageOutput = modelSupportsImageInput(
|
||||
llmProviders,
|
||||
values.llm_model_version_override || defaultModelName || ""
|
||||
);
|
||||
|
||||
|
@@ -71,6 +71,7 @@ export interface ModelConfiguration {
|
||||
name: string;
|
||||
is_visible: boolean;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean;
|
||||
}
|
||||
|
||||
export interface VisionProvider extends LLMProviderView {
|
||||
|
@@ -90,8 +90,8 @@ import { buildFilters } from "@/lib/search/utils";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import Dropzone from "react-dropzone";
|
||||
import {
|
||||
checkLLMSupportsImageInput,
|
||||
getFinalLLM,
|
||||
modelSupportsImageInput,
|
||||
structureValue,
|
||||
} from "@/lib/llm/utils";
|
||||
import { ChatInputBar } from "./input/ChatInputBar";
|
||||
@@ -1952,7 +1952,7 @@ export function ChatPage({
|
||||
liveAssistant,
|
||||
llmManager.currentLlm
|
||||
);
|
||||
const llmAcceptsImages = checkLLMSupportsImageInput(llmModel);
|
||||
const llmAcceptsImages = modelSupportsImageInput(llmProviders, llmModel);
|
||||
|
||||
const imageFiles = acceptedFiles.filter((file) =>
|
||||
file.type.startsWith("image/")
|
||||
|
@@ -6,7 +6,7 @@ import {
|
||||
} from "@/components/ui/popover";
|
||||
import { getDisplayNameForModel } from "@/lib/hooks";
|
||||
import {
|
||||
checkLLMSupportsImageInput,
|
||||
modelSupportsImageInput,
|
||||
destructureValue,
|
||||
structureValue,
|
||||
} from "@/lib/llm/utils";
|
||||
@@ -175,7 +175,10 @@ export default function LLMPopover({
|
||||
>
|
||||
<div className="flex-grow max-h-[300px] default-scrollbar overflow-y-auto">
|
||||
{llmOptions.map(({ name, icon, value }, index) => {
|
||||
if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
|
||||
if (
|
||||
!requiresImageGeneration ||
|
||||
modelSupportsImageInput(llmProviders, name)
|
||||
) {
|
||||
return (
|
||||
<button
|
||||
key={index}
|
||||
@@ -206,7 +209,7 @@ export default function LLMPopover({
|
||||
}
|
||||
})()}
|
||||
{llmManager.imageFilesPresent &&
|
||||
!checkLLMSupportsImageInput(name) && (
|
||||
!modelSupportsImageInput(llmProviders, name) && (
|
||||
<TooltipProvider>
|
||||
<Tooltip delayDuration={0}>
|
||||
<TooltipTrigger className="my-auto flex items-center ml-auto">
|
||||
|
@@ -1,8 +1,8 @@
|
||||
import React from "react";
|
||||
import { getDisplayNameForModel } from "@/lib/hooks";
|
||||
import {
|
||||
checkLLMSupportsImageInput,
|
||||
destructureValue,
|
||||
modelSupportsImageInput,
|
||||
structureValue,
|
||||
} from "@/lib/llm/utils";
|
||||
import {
|
||||
@@ -96,7 +96,7 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
|
||||
{llmOptions.map((option) => {
|
||||
if (
|
||||
!requiresImageGeneration ||
|
||||
checkLLMSupportsImageInput(option.name)
|
||||
modelSupportsImageInput(llmProviders, option.name)
|
||||
) {
|
||||
return (
|
||||
<SelectItem key={option.value} value={option.value}>
|
||||
|
@@ -2,7 +2,7 @@ import { fetchSS } from "@/lib/utilsSS";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs";
|
||||
import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS";
|
||||
import { checkLLMSupportsImageInput } from "../llm/utils";
|
||||
import { modelSupportsImageInput } from "../llm/utils";
|
||||
import { filterAssistants } from "../assistants/utils";
|
||||
|
||||
interface AssistantData {
|
||||
@@ -47,7 +47,7 @@ export async function fetchAssistantData(): Promise<AssistantData> {
|
||||
(provider) =>
|
||||
provider.provider === "openai" ||
|
||||
provider.model_configurations.some((modelConfiguration) =>
|
||||
checkLLMSupportsImageInput(modelConfiguration.name)
|
||||
modelSupportsImageInput(llmProviders, modelConfiguration.name)
|
||||
)
|
||||
);
|
||||
|
||||
|
@@ -25,7 +25,6 @@ import {
|
||||
import { hasCompletedWelcomeFlowSS } from "@/components/initialSetup/welcome/WelcomeModalWrapper";
|
||||
import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS";
|
||||
import { NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN } from "../constants";
|
||||
import { checkLLMSupportsImageInput } from "../llm/utils";
|
||||
|
||||
interface FetchChatDataResult {
|
||||
user?: User | null;
|
||||
@@ -173,8 +172,8 @@ export async function fetchSomeChatData(
|
||||
const hasImageCompatibleModel = result.llmProviders?.some(
|
||||
(provider) =>
|
||||
provider.provider === "openai" ||
|
||||
provider.model_configurations.some((modelConfiguration) =>
|
||||
checkLLMSupportsImageInput(modelConfiguration.name)
|
||||
provider.model_configurations.some(
|
||||
(modelConfiguration) => modelConfiguration.supports_image_input
|
||||
)
|
||||
);
|
||||
|
||||
|
@@ -1,5 +1,8 @@
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import {
|
||||
LLMProviderDescriptor,
|
||||
ModelConfiguration,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LlmDescriptor } from "@/lib/hooks";
|
||||
|
||||
export function getFinalLLM(
|
||||
@@ -64,94 +67,6 @@ export function getLLMProviderOverrideForPersona(
|
||||
return null;
|
||||
}
|
||||
|
||||
// Known model identifiers that accept image input. Entries are compared
// case-insensitively by checkLLMSupportsImageInput below.
const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
  "gpt-4o",
  "gpt-4.1",
  "gpt-4o-mini",
  "gpt-4-vision-preview",
  "gpt-4-turbo",
  "gpt-4-1106-vision-preview",
  // standard claude names
  "claude-3-5-sonnet-20240620",
  "claude-3-5-sonnet-20241022",
  "claude-3-7-sonnet-20250219",
  "claude-3-opus-20240229",
  "claude-3-sonnet-20240229",
  "claude-3-haiku-20240307",
  // custom claude names
  "claude-3.5-sonnet-v2@20241022",
  "claude-3-7-sonnet@20250219",
  // claude names with AWS Bedrock Suffix
  "claude-3-opus-20240229-v1:0",
  "claude-3-sonnet-20240229-v1:0",
  "claude-3-haiku-20240307-v1:0",
  "claude-3-5-sonnet-20240620-v1:0",
  "claude-3-5-sonnet-20241022-v2:0",
  // claude names with full AWS Bedrock names
  "anthropic.claude-3-opus-20240229-v1:0",
  "anthropic.claude-3-sonnet-20240229-v1:0",
  "anthropic.claude-3-haiku-20240307-v1:0",
  "anthropic.claude-3-5-sonnet-20240620-v1:0",
  "anthropic.claude-3-5-sonnet-20241022-v2:0",
  "anthropic.claude-3-7-sonnet-20250219-v1:0",
  "claude-3.7-sonnet@202502019",
  "claude-3-7-sonnet-202502019",
  // google gemini model names
  "gemini-1.5-pro",
  "gemini-1.5-flash",
  "gemini-1.5-pro-001",
  "gemini-1.5-flash-001",
  "gemini-1.5-pro-002",
  "gemini-1.5-flash-002",
  "gemini-2.0-flash-exp",
  "gemini-2.0-flash-001",
  "gemini-2.0-pro-exp-02-05",
  // amazon models
  "amazon.nova-lite@v1",
  "amazon.nova-pro@v1",
  // meta models
  "llama-3.2-90b-vision-instruct",
  "llama-3.2-11b-vision-instruct",
  "Llama-3-2-11B-Vision-Instruct-yb",
];

/**
 * Reports whether `model` appears to accept image input, based on the
 * hard-coded MODEL_NAMES_SUPPORTING_IMAGE_INPUT list.
 *
 * Matching is case-insensitive and tried in three steps:
 *  1. exact match against a known name;
 *  2. match on the last segment after splitting both names on "/" or ".";
 *  3. if `model` contains "/", the segment right after the first slash is
 *     accepted when it is a substring of any known name.
 *
 * Step 3 is a deliberately loose heuristic and can report true for a name
 * that is only a prefix of a known model (e.g. "azure/gpt-4").
 *
 * @param model - Model identifier, optionally prefixed ("provider/model").
 * @returns true if any of the three matching steps succeeds, otherwise false.
 */
export function checkLLMSupportsImageInput(model: string): boolean {
  const wanted = model.toLowerCase();

  // Step 1: exact (case-insensitive) match.
  const exactMatch = MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
    (modelName) => modelName.toLowerCase() === wanted
  );
  if (exactMatch) {
    return true;
  }

  // Step 2: compare only the last "/"- or "."-separated segment.
  const modelParts = model.split(/[/.]/);
  const lastPart = modelParts[modelParts.length - 1]?.toLowerCase();

  const lastPartMatch = MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) => {
    const modelNameParts = modelName.split(/[/.]/);
    const modelNameLastPart = modelNameParts[modelNameParts.length - 1];
    return modelNameLastPart?.toLowerCase() === lastPart;
  });
  if (lastPartMatch) {
    return true;
  }

  // Step 3: substring match on the segment following the first slash.
  if (model.includes("/")) {
    const afterSlash = model.split("/")[1]?.toLowerCase();
    // An empty segment (e.g. "openai/") must not match: "".includes-style
    // checks are true for every string and would mark any such name as
    // image-capable.
    if (!afterSlash) {
      return false;
    }
    return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) =>
      modelName.toLowerCase().includes(afterSlash)
    );
  }

  return false;
}
|
||||
|
||||
export const structureValue = (
|
||||
name: string,
|
||||
provider: string,
|
||||
@@ -180,3 +95,48 @@ export const findProviderForModel = (
|
||||
);
|
||||
return provider ? provider.provider : "";
|
||||
};
|
||||
|
||||
/**
 * Looks up a model configuration by exact (case-sensitive) name.
 *
 * @param modelConfigurations - Configurations to search, in order.
 * @param modelName - Name to compare against each configuration's `name`.
 * @returns The first configuration whose `name` equals `modelName`, or
 *   null when none matches.
 */
export const findModelInModelConfigurations = (
  modelConfigurations: ModelConfiguration[],
  modelName: string
): ModelConfiguration | null => {
  for (const candidate of modelConfigurations) {
    if (candidate.name === modelName) {
      return candidate;
    }
  }
  return null;
};
|
||||
|
||||
/**
 * Finds the configuration for `modelName` among the given providers.
 *
 * When `providerName` is truthy, only the first provider whose `name`
 * equals it is searched; the result is null if that provider is missing
 * or does not list the model. Otherwise (null or empty string) providers
 * are scanned in order and the first match wins.
 *
 * @param llmProviders - Providers whose `model_configurations` are searched.
 * @param modelName - Exact model name to look for.
 * @param providerName - Optional provider `name` to restrict the search to.
 * @returns The matching configuration, or null when none is found.
 */
export const findModelConfiguration = (
  llmProviders: LLMProviderDescriptor[],
  modelName: string,
  providerName: string | null = null
): ModelConfiguration | null => {
  if (!providerName) {
    // No provider given: return the first hit across all providers.
    for (const { model_configurations } of llmProviders) {
      const hit = findModelInModelConfigurations(
        model_configurations,
        modelName
      );
      if (hit) {
        return hit;
      }
    }
    return null;
  }

  // Provider given: search only that provider, never the others.
  const namedProvider = llmProviders.find(({ name }) => name === providerName);
  if (!namedProvider) {
    return null;
  }
  return findModelInModelConfigurations(
    namedProvider.model_configurations,
    modelName
  );
};
|
||||
|
||||
/**
 * Reports whether the named model is flagged as accepting image input.
 *
 * The answer is read from the `supports_image_input` field of the
 * configuration returned by findModelConfiguration.
 *
 * @param llmProviders - Providers to search for the model.
 * @param modelName - Exact model name to look up.
 * @param providerName - Optional provider `name` to restrict the lookup to.
 * @returns The configuration's `supports_image_input` value, or false when
 *   no configuration is found or the flag is falsy.
 */
export const modelSupportsImageInput = (
  llmProviders: LLMProviderDescriptor[],
  modelName: string,
  providerName: string | null = null
): boolean => {
  const config = findModelConfiguration(llmProviders, modelName, providerName);
  if (!config) {
    // Unknown model: treat as not image-capable.
    return false;
  }
  return config.supports_image_input || false;
};
|
||||
|
Reference in New Issue
Block a user