Single source of truth for image capability (#4612)

* Single source of truth for image capability

* Update web/src/app/admin/assistants/AssistantEditor.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Fix tests

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
This commit is contained in:
Chris Weaver
2025-04-25 13:37:16 -07:00
committed by GitHub
parent ad76e6ac9e
commit 23c6e0f3bf
11 changed files with 91 additions and 113 deletions

View File

@@ -24,7 +24,7 @@ branch_labels = None
depends_on = None
class ModelConfiguration(BaseModel):
class _SimpleModelConfiguration(BaseModel):
# Configure model to read from attributes
model_config = ConfigDict(from_attributes=True)
@@ -82,7 +82,7 @@ def upgrade() -> None:
)
model_configurations = [
ModelConfiguration.model_validate(model_configuration)
_SimpleModelConfiguration.model_validate(model_configuration)
for model_configuration in connection.execute(
sa.select(
model_configuration_table.c.id,

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel
from pydantic import Field
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import model_supports_image_input
if TYPE_CHECKING:
@@ -152,6 +153,7 @@ class ModelConfigurationView(BaseModel):
name: str
is_visible: bool | None = False
max_input_tokens: int | None = None
supports_image_input: bool
@classmethod
def from_model(
@@ -166,6 +168,10 @@ class ModelConfigurationView(BaseModel):
or get_max_input_tokens(
model_name=model_configuration_model.name, model_provider=provider_name
),
supports_image_input=model_supports_image_input(
model_name=model_configuration_model.name,
model_provider=provider_name,
),
)

View File

@@ -1,19 +1,18 @@
import uuid
from typing import Any
import pytest
import requests
from requests.models import Response
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
"""Utility function to fetch an LLM provider by ID"""
response = requests.get(
@@ -40,10 +39,10 @@ def assert_response_is_equivalent(
assert provider_data["default_model_name"] == default_model_name
def fill_max_input_tokens_if_none(
def fill_max_input_tokens_and_supports_image_input(
req: ModelConfigurationUpsertRequest,
) -> ModelConfigurationUpsertRequest:
return ModelConfigurationUpsertRequest(
) -> dict[str, Any]:
filled_with_max_input_tokens = ModelConfigurationUpsertRequest(
name=req.name,
is_visible=req.is_visible,
max_input_tokens=req.max_input_tokens
@@ -51,13 +50,21 @@ def assert_response_is_equivalent(
model_name=req.name, model_provider=default_model_name
),
)
return {
**filled_with_max_input_tokens.model_dump(),
"supports_image_input": model_supports_image_input(
req.name, created_provider["provider"]
),
}
actual = set(
tuple(model_configuration.items())
for model_configuration in provider_data["model_configurations"]
)
expected = set(
tuple(fill_max_input_tokens_if_none(model_configuration).dict().items())
tuple(
fill_max_input_tokens_and_supports_image_input(model_configuration).items()
)
for model_configuration in model_configurations
)
assert actual == expected
@@ -150,7 +157,7 @@ def test_create_llm_provider(
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": default_model_name,
"model_configurations": [
model_configuration.dict()
model_configuration.model_dump()
for model_configuration in model_configurations
],
"is_public": True,

View File

@@ -25,8 +25,8 @@ import { getDisplayNameForModel, useLabels } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
import {
checkLLMSupportsImageInput,
destructureValue,
modelSupportsImageInput,
structureValue,
} from "@/lib/llm/utils";
import { ToolSnapshot } from "@/lib/tools/interfaces";
@@ -139,6 +139,7 @@ export function AssistantEditor({
admin?: boolean;
}) {
const { refreshAssistants, isImageGenerationAvailable } = useAssistants();
const router = useRouter();
const searchParams = useSearchParams();
const isAdminPage = searchParams?.get("admin") === "true";
@@ -643,7 +644,8 @@ export function AssistantEditor({
// model must support image input for image generation
// to work
const currentLLMSupportsImageOutput = checkLLMSupportsImageInput(
const currentLLMSupportsImageOutput = modelSupportsImageInput(
llmProviders,
values.llm_model_version_override || defaultModelName || ""
);

View File

@@ -71,6 +71,7 @@ export interface ModelConfiguration {
name: string;
is_visible: boolean;
max_input_tokens: number | null;
supports_image_input: boolean;
}
export interface VisionProvider extends LLMProviderView {

View File

@@ -90,8 +90,8 @@ import { buildFilters } from "@/lib/search/utils";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import {
checkLLMSupportsImageInput,
getFinalLLM,
modelSupportsImageInput,
structureValue,
} from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
@@ -1952,7 +1952,7 @@ export function ChatPage({
liveAssistant,
llmManager.currentLlm
);
const llmAcceptsImages = checkLLMSupportsImageInput(llmModel);
const llmAcceptsImages = modelSupportsImageInput(llmProviders, llmModel);
const imageFiles = acceptedFiles.filter((file) =>
file.type.startsWith("image/")

View File

@@ -6,7 +6,7 @@ import {
} from "@/components/ui/popover";
import { getDisplayNameForModel } from "@/lib/hooks";
import {
checkLLMSupportsImageInput,
modelSupportsImageInput,
destructureValue,
structureValue,
} from "@/lib/llm/utils";
@@ -175,7 +175,10 @@ export default function LLMPopover({
>
<div className="flex-grow max-h-[300px] default-scrollbar overflow-y-auto">
{llmOptions.map(({ name, icon, value }, index) => {
if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
if (
!requiresImageGeneration ||
modelSupportsImageInput(llmProviders, name)
) {
return (
<button
key={index}
@@ -206,7 +209,7 @@ export default function LLMPopover({
}
})()}
{llmManager.imageFilesPresent &&
!checkLLMSupportsImageInput(name) && (
!modelSupportsImageInput(llmProviders, name) && (
<TooltipProvider>
<Tooltip delayDuration={0}>
<TooltipTrigger className="my-auto flex items-center ml-auto">

View File

@@ -1,8 +1,8 @@
import React from "react";
import { getDisplayNameForModel } from "@/lib/hooks";
import {
checkLLMSupportsImageInput,
destructureValue,
modelSupportsImageInput,
structureValue,
} from "@/lib/llm/utils";
import {
@@ -96,7 +96,7 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
{llmOptions.map((option) => {
if (
!requiresImageGeneration ||
checkLLMSupportsImageInput(option.name)
modelSupportsImageInput(llmProviders, option.name)
) {
return (
<SelectItem key={option.value} value={option.value}>

View File

@@ -2,7 +2,7 @@ import { fetchSS } from "@/lib/utilsSS";
import { Persona } from "@/app/admin/assistants/interfaces";
import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs";
import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS";
import { checkLLMSupportsImageInput } from "../llm/utils";
import { modelSupportsImageInput } from "../llm/utils";
import { filterAssistants } from "../assistants/utils";
interface AssistantData {
@@ -47,7 +47,7 @@ export async function fetchAssistantData(): Promise<AssistantData> {
(provider) =>
provider.provider === "openai" ||
provider.model_configurations.some((modelConfiguration) =>
checkLLMSupportsImageInput(modelConfiguration.name)
modelSupportsImageInput(llmProviders, modelConfiguration.name)
)
);

View File

@@ -25,7 +25,6 @@ import {
import { hasCompletedWelcomeFlowSS } from "@/components/initialSetup/welcome/WelcomeModalWrapper";
import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS";
import { NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN } from "../constants";
import { checkLLMSupportsImageInput } from "../llm/utils";
interface FetchChatDataResult {
user?: User | null;
@@ -173,8 +172,8 @@ export async function fetchSomeChatData(
const hasImageCompatibleModel = result.llmProviders?.some(
(provider) =>
provider.provider === "openai" ||
provider.model_configurations.some((modelConfiguration) =>
checkLLMSupportsImageInput(modelConfiguration.name)
provider.model_configurations.some(
(modelConfiguration) => modelConfiguration.supports_image_input
)
);

View File

@@ -1,5 +1,8 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
LLMProviderDescriptor,
ModelConfiguration,
} from "@/app/admin/configuration/llm/interfaces";
import { LlmDescriptor } from "@/lib/hooks";
export function getFinalLLM(
@@ -64,94 +67,6 @@ export function getLLMProviderOverrideForPersona(
return null;
}
/**
 * Hard-coded list of model names known to accept image input.
 * Matching in checkLLMSupportsImageInput below also applies suffix
 * heuristics, so provider-prefixed variants of these names match too.
 */
const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
  "gpt-4o",
  "gpt-4.1",
  "gpt-4o-mini",
  "gpt-4-vision-preview",
  "gpt-4-turbo",
  "gpt-4-1106-vision-preview",
  // standard claude names
  "claude-3-5-sonnet-20240620",
  "claude-3-5-sonnet-20241022",
  "claude-3-7-sonnet-20250219",
  "claude-3-opus-20240229",
  "claude-3-sonnet-20240229",
  "claude-3-haiku-20240307",
  // custom claude names
  "claude-3.5-sonnet-v2@20241022",
  "claude-3-7-sonnet@20250219",
  // claude names with AWS Bedrock Suffix
  "claude-3-opus-20240229-v1:0",
  "claude-3-sonnet-20240229-v1:0",
  "claude-3-haiku-20240307-v1:0",
  "claude-3-5-sonnet-20240620-v1:0",
  "claude-3-5-sonnet-20241022-v2:0",
  // claude names with full AWS Bedrock names
  "anthropic.claude-3-opus-20240229-v1:0",
  "anthropic.claude-3-sonnet-20240229-v1:0",
  "anthropic.claude-3-haiku-20240307-v1:0",
  "anthropic.claude-3-5-sonnet-20240620-v1:0",
  "anthropic.claude-3-5-sonnet-20241022-v2:0",
  "anthropic.claude-3-7-sonnet-20250219-v1:0",
  // fix: this date previously had a stray digit ("202502019"), so the real
  // model name could never match exactly. (The sibling dashed entry with the
  // same typo is dropped: once corrected it duplicates
  // "claude-3-7-sonnet-20250219" above.)
  "claude-3.7-sonnet@20250219",
  // google gemini model names
  "gemini-1.5-pro",
  "gemini-1.5-flash",
  "gemini-1.5-pro-001",
  "gemini-1.5-flash-001",
  "gemini-1.5-pro-002",
  "gemini-1.5-flash-002",
  "gemini-2.0-flash-exp",
  "gemini-2.0-flash-001",
  "gemini-2.0-pro-exp-02-05",
  // amazon models
  "amazon.nova-lite@v1",
  "amazon.nova-pro@v1",
  // meta models
  "llama-3.2-90b-vision-instruct",
  "llama-3.2-11b-vision-instruct",
  "Llama-3-2-11B-Vision-Instruct-yb",
];

/**
 * Returns true if `model` is known to support image input.
 *
 * Matching strategy, in order:
 *   1. case-insensitive exact match against the list above;
 *   2. case-insensitive match on the final `/`- or `.`-separated segment
 *      (handles provider-prefixed names such as "bedrock/anthropic.claude-...");
 *   3. substring match on the text after the first "/".
 */
export function checkLLMSupportsImageInput(model: string) {
  // 1. exact match
  const exactMatch = MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
    (modelName) => modelName.toLowerCase() === model.toLowerCase()
  );
  if (exactMatch) {
    return true;
  }

  // 2. match on the last part of the model name
  const modelParts = model.split(/[/.]/);
  const lastPart = modelParts[modelParts.length - 1]?.toLowerCase();

  const lastPartMatch = MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) => {
    const modelNameParts = modelName.split(/[/.]/);
    const modelNameLastPart = modelNameParts[modelNameParts.length - 1];
    // lastPart is already lowercased above for tiny performance gain
    return modelNameLastPart?.toLowerCase() === lastPart;
  });
  if (lastPartMatch) {
    return true;
  }

  // 3. substring match on the text after the first slash.
  // Fix: guard against an empty segment — previously a model name ending in
  // "/" produced afterSlash === "", and String.includes("") is true for every
  // string, so such names were always (wrongly) reported as image-capable.
  if (model.includes("/")) {
    const afterSlash = model.split("/")[1]?.toLowerCase();
    if (!afterSlash) {
      return false;
    }
    return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) =>
      modelName.toLowerCase().includes(afterSlash)
    );
  }

  return false;
}
export const structureValue = (
name: string,
provider: string,
@@ -180,3 +95,48 @@ export const findProviderForModel = (
);
return provider ? provider.provider : "";
};
/**
 * Locate the configuration for `modelName` in a list of model
 * configurations, or null when no entry has that name.
 */
export const findModelInModelConfigurations = (
  modelConfigurations: ModelConfiguration[],
  modelName: string
): ModelConfiguration | null => {
  for (const configuration of modelConfigurations) {
    if (configuration.name === modelName) {
      return configuration;
    }
  }
  return null;
};

/**
 * Search the given providers for a configuration matching `modelName`.
 * When `providerName` is supplied, only that provider is consulted;
 * otherwise providers are scanned in order and the first hit wins.
 * Returns null when nothing matches.
 */
export const findModelConfiguration = (
  llmProviders: LLMProviderDescriptor[],
  modelName: string,
  providerName: string | null = null
): ModelConfiguration | null => {
  if (providerName) {
    const matchingProvider = llmProviders.find(
      (candidate) => candidate.name === providerName
    );
    if (!matchingProvider) {
      return null;
    }
    return findModelInModelConfigurations(
      matchingProvider.model_configurations,
      modelName
    );
  }

  for (const candidate of llmProviders) {
    const found = findModelInModelConfigurations(
      candidate.model_configurations,
      modelName
    );
    if (found !== null) {
      return found;
    }
  }
  return null;
};

/**
 * Single source of truth for image capability: reads the
 * `supports_image_input` flag off the model's configuration.
 * Returns false when the model cannot be found.
 */
export const modelSupportsImageInput = (
  llmProviders: LLMProviderDescriptor[],
  modelName: string,
  providerName: string | null = null
): boolean => {
  const configuration = findModelConfiguration(
    llmProviders,
    modelName,
    providerName
  );
  return configuration !== null && configuration.supports_image_input;
};