diff --git a/backend/alembic/versions/47a07e1a38f1_fix_invalid_model_configurations_state.py b/backend/alembic/versions/47a07e1a38f1_fix_invalid_model_configurations_state.py
index afa60b19536d..341a5a5389c4 100644
--- a/backend/alembic/versions/47a07e1a38f1_fix_invalid_model_configurations_state.py
+++ b/backend/alembic/versions/47a07e1a38f1_fix_invalid_model_configurations_state.py
@@ -24,7 +24,7 @@ branch_labels = None
 depends_on = None
 
 
-class ModelConfiguration(BaseModel):
+class _SimpleModelConfiguration(BaseModel):
     # Configure model to read from attributes
     model_config = ConfigDict(from_attributes=True)
 
@@ -82,7 +82,7 @@ def upgrade() -> None:
     )
 
     model_configurations = [
-        ModelConfiguration.model_validate(model_configuration)
+        _SimpleModelConfiguration.model_validate(model_configuration)
         for model_configuration in connection.execute(
             sa.select(
                 model_configuration_table.c.id,
diff --git a/backend/onyx/server/manage/llm/models.py b/backend/onyx/server/manage/llm/models.py
index e6446961b0ad..3f559ff0e66b 100644
--- a/backend/onyx/server/manage/llm/models.py
+++ b/backend/onyx/server/manage/llm/models.py
@@ -4,6 +4,7 @@ from pydantic import BaseModel
 from pydantic import Field
 
 from onyx.llm.utils import get_max_input_tokens
+from onyx.llm.utils import model_supports_image_input
 
 
 if TYPE_CHECKING:
@@ -152,6 +153,7 @@ class ModelConfigurationView(BaseModel):
     name: str
     is_visible: bool | None = False
     max_input_tokens: int | None = None
+    supports_image_input: bool
 
     @classmethod
     def from_model(
@@ -166,6 +168,10 @@ class ModelConfigurationView(BaseModel):
             or get_max_input_tokens(
                 model_name=model_configuration_model.name, model_provider=provider_name
             ),
+            supports_image_input=model_supports_image_input(
+                model_name=model_configuration_model.name,
+                model_provider=provider_name,
+            ),
         )
 
 
diff --git a/backend/tests/integration/tests/llm_provider/test_llm_provider.py b/backend/tests/integration/tests/llm_provider/test_llm_provider.py
index 073c28eba715..85d46ef2fe39 100644
--- a/backend/tests/integration/tests/llm_provider/test_llm_provider.py
+++ b/backend/tests/integration/tests/llm_provider/test_llm_provider.py
@@ -1,19 +1,18 @@
 import uuid
+from typing import Any
 
 import pytest
 import requests
 from requests.models import Response
 
 from onyx.llm.utils import get_max_input_tokens
+from onyx.llm.utils import model_supports_image_input
 from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
 from tests.integration.common_utils.constants import API_SERVER_URL
 from tests.integration.common_utils.managers.user import UserManager
 from tests.integration.common_utils.test_models import DATestUser
 
 
-_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
-
-
 def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
     """Utility function to fetch an LLM provider by ID"""
     response = requests.get(
@@ -40,10 +39,10 @@ def assert_response_is_equivalent(
 
     assert provider_data["default_model_name"] == default_model_name
 
-    def fill_max_input_tokens_if_none(
+    def fill_max_input_tokens_and_supports_image_input(
         req: ModelConfigurationUpsertRequest,
-    ) -> ModelConfigurationUpsertRequest:
-        return ModelConfigurationUpsertRequest(
+    ) -> dict[str, Any]:
+        filled_with_max_input_tokens = ModelConfigurationUpsertRequest(
             name=req.name,
             is_visible=req.is_visible,
             max_input_tokens=req.max_input_tokens
@@ -51,13 +50,21 @@ def assert_response_is_equivalent(
                 model_name=req.name, model_provider=default_model_name
             ),
         )
+        return {
+            **filled_with_max_input_tokens.model_dump(),
+            "supports_image_input": model_supports_image_input(
+                req.name, created_provider["provider"]
+            ),
+        }
 
     actual = set(
         tuple(model_configuration.items())
         for model_configuration in provider_data["model_configurations"]
     )
     expected = set(
-        tuple(fill_max_input_tokens_if_none(model_configuration).dict().items())
+        tuple(
+            fill_max_input_tokens_and_supports_image_input(model_configuration).items()
+        )
        for model_configuration in model_configurations
     )
     assert actual == expected
@@ -150,7 +157,7 @@ def test_create_llm_provider(
             "api_key": "sk-000000000000000000000000000000000000000000000000",
             "default_model_name": default_model_name,
             "model_configurations": [
-                model_configuration.dict()
+                model_configuration.model_dump()
                 for model_configuration in model_configurations
             ],
             "is_public": True,
diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx
index b5b389722298..c525e79fdeb9 100644
--- a/web/src/app/admin/assistants/AssistantEditor.tsx
+++ b/web/src/app/admin/assistants/AssistantEditor.tsx
@@ -25,8 +25,8 @@ import { getDisplayNameForModel, useLabels } from "@/lib/hooks";
 import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
 import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
 import {
-  checkLLMSupportsImageInput,
   destructureValue,
+  modelSupportsImageInput,
   structureValue,
 } from "@/lib/llm/utils";
 import { ToolSnapshot } from "@/lib/tools/interfaces";
@@ -139,6 +139,7 @@ export function AssistantEditor({
   admin?: boolean;
 }) {
   const { refreshAssistants, isImageGenerationAvailable } = useAssistants();
+  const router = useRouter();
 
   const searchParams = useSearchParams();
   const isAdminPage = searchParams?.get("admin") === "true";
@@ -643,7 +644,8 @@ export function AssistantEditor({
 
                   // model must support image input for image generation
                   // to work
-                  const currentLLMSupportsImageOutput = checkLLMSupportsImageInput(
+                  const currentLLMSupportsImageOutput = modelSupportsImageInput(
+                    llmProviders,
                     values.llm_model_version_override || defaultModelName || ""
                   );
 
diff --git a/web/src/app/admin/configuration/llm/interfaces.ts b/web/src/app/admin/configuration/llm/interfaces.ts
index 17c8e43838b3..518967c1bcd0 100644
--- a/web/src/app/admin/configuration/llm/interfaces.ts
+++ b/web/src/app/admin/configuration/llm/interfaces.ts
@@ -71,6 +71,7 @@ export interface ModelConfiguration {
   name: string;
   is_visible: boolean;
   max_input_tokens: number | null;
+  supports_image_input: boolean;
 }
 
 export interface VisionProvider extends LLMProviderView {
diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx
index 01c09f6274a2..7371991e5eaf 100644
--- a/web/src/app/chat/ChatPage.tsx
+++ b/web/src/app/chat/ChatPage.tsx
@@ -90,8 +90,8 @@ import { buildFilters } from "@/lib/search/utils";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
 import Dropzone from "react-dropzone";
 import {
-  checkLLMSupportsImageInput,
   getFinalLLM,
+  modelSupportsImageInput,
   structureValue,
 } from "@/lib/llm/utils";
 import { ChatInputBar } from "./input/ChatInputBar";
@@ -1952,7 +1952,7 @@ export function ChatPage({
       liveAssistant,
       llmManager.currentLlm
     );
-    const llmAcceptsImages = checkLLMSupportsImageInput(llmModel);
+    const llmAcceptsImages = modelSupportsImageInput(llmProviders, llmModel);
 
     const imageFiles = acceptedFiles.filter((file) =>
       file.type.startsWith("image/")
diff --git a/web/src/app/chat/input/LLMPopover.tsx b/web/src/app/chat/input/LLMPopover.tsx
index 476d4e611574..88dead2732a8 100644
--- a/web/src/app/chat/input/LLMPopover.tsx
+++ b/web/src/app/chat/input/LLMPopover.tsx
@@ -6,7 +6,7 @@ import {
 } from "@/components/ui/popover";
 import { getDisplayNameForModel } from "@/lib/hooks";
 import {
-  checkLLMSupportsImageInput,
+  modelSupportsImageInput,
   destructureValue,
   structureValue,
 } from "@/lib/llm/utils";
@@ -175,7 +175,10 @@ export default function LLMPopover({
         >
           {llmOptions.map(({ name, icon, value }, index) => {
-            if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
+            if (
+              !requiresImageGeneration ||
+              modelSupportsImageInput(llmProviders, name)
+            ) {
               return (
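
Note: every web-side change above swaps the old name-based `checkLLMSupportsImageInput(modelName)` for a new `modelSupportsImageInput(llmProviders, modelName)` helper, but the helper's definition in `web/src/lib/llm/utils.ts` is not part of this diff. A minimal sketch of what it could look like, assuming `LLMProviderView` exposes a `model_configurations: ModelConfiguration[]` array carrying the new backend-computed `supports_image_input` flag:

```typescript
import {
  LLMProviderView,
  ModelConfiguration,
} from "@/app/admin/configuration/llm/interfaces";

// Sketch only — the real implementation lives in web/src/lib/llm/utils.ts and
// is not shown in this diff. Rather than matching the model name against a
// hardcoded list of vision-capable models, look the model up across all
// configured providers and trust the supports_image_input flag that the
// backend now computes via model_supports_image_input().
export function modelSupportsImageInput(
  llmProviders: LLMProviderView[],
  modelName: string
): boolean {
  return llmProviders.some((provider) =>
    provider.model_configurations.some(
      (modelConfiguration: ModelConfiguration) =>
        modelConfiguration.name === modelName &&
        modelConfiguration.supports_image_input
    )
  );
}
```

This shape would explain why every call site now passes `llmProviders` as an extra argument: image-input support becomes a property of the configured model rather than of the model name alone.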