diff --git a/backend/alembic/versions/df46c75b714e_add_default_vision_provider_to_llm_.py b/backend/alembic/versions/df46c75b714e_add_default_vision_provider_to_llm_.py new file mode 100644 index 000000000..15f0760ed --- /dev/null +++ b/backend/alembic/versions/df46c75b714e_add_default_vision_provider_to_llm_.py @@ -0,0 +1,45 @@ +"""add_default_vision_provider_to_llm_provider + +Revision ID: df46c75b714e +Revises: 3934b1bc7b62 +Create Date: 2025-03-11 16:20:19.038945 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "df46c75b714e" +down_revision = "3934b1bc7b62" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "llm_provider", + sa.Column( + "is_default_vision_provider", + sa.Boolean(), + nullable=True, + server_default=sa.false(), + ), + ) + op.add_column( + "llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True) + ) + # Add unique constraint for is_default_vision_provider + op.create_unique_constraint( + "uq_llm_provider_is_default_vision_provider", + "llm_provider", + ["is_default_vision_provider"], + ) + + +def downgrade() -> None: + op.drop_constraint( + "uq_llm_provider_is_default_vision_provider", "llm_provider", type_="unique" + ) + op.drop_column("llm_provider", "default_vision_model") + op.drop_column("llm_provider", "is_default_vision_provider") diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index de720430e..74be29c6d 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -8,6 +8,9 @@ from onyx.configs.constants import AuthType from onyx.configs.constants import DocumentIndexType from onyx.configs.constants import QueryHistoryType from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy +from onyx.prompts.image_analysis import DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT +from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT +from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT ##### # App Configs @@ -646,3 +649,21 @@ DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20 # Number of pre-provisioned tenants to maintain TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5")) + + +# Image summarization configuration +IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get( + "IMAGE_SUMMARIZATION_SYSTEM_PROMPT", + DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT, +) + +# The user prompt for image summarization - the image filename will be automatically prepended +IMAGE_SUMMARIZATION_USER_PROMPT = os.environ.get( + "IMAGE_SUMMARIZATION_USER_PROMPT", + DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT, +) + +IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get( + "IMAGE_ANALYSIS_SYSTEM_PROMPT", + DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT, +) diff --git a/backend/onyx/connectors/vision_enabled_connector.py b/backend/onyx/connectors/vision_enabled_connector.py index 021fb759b..03b523f26 100644 --- a/backend/onyx/connectors/vision_enabled_connector.py +++ b/backend/onyx/connectors/vision_enabled_connector.py @@ -30,6 +30,7 @@ class VisionEnabledConnector: Sets self.image_analysis_llm to the LLM instance or None if disabled. """ self.image_analysis_llm: LLM | None = None + if get_image_extraction_and_analysis_enabled(): try: self.image_analysis_llm = get_default_llm_with_vision() diff --git a/backend/onyx/context/search/postprocessing/postprocessing.py b/backend/onyx/context/search/postprocessing/postprocessing.py index 7f88ec00e..41243eec7 100644 --- a/backend/onyx/context/search/postprocessing/postprocessing.py +++ b/backend/onyx/context/search/postprocessing/postprocessing.py @@ -10,6 +10,7 @@ from langchain_core.messages import SystemMessage from onyx.chat.models import SectionRelevancePiece from onyx.configs.app_configs import BLURB_SIZE +from onyx.configs.app_configs import IMAGE_ANALYSIS_SYSTEM_PROMPT from onyx.configs.constants import RETURN_SEPARATOR from onyx.configs.llm_configs import get_search_time_image_analysis_enabled from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX @@ -31,7 +32,6 @@ from onyx.file_store.file_store import get_default_file_store from onyx.llm.interfaces import LLM from onyx.llm.utils import message_to_string from onyx.natural_language_processing.search_nlp_models import RerankingModel -from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import FunctionCall diff --git a/backend/onyx/db/llm.py b/backend/onyx/db/llm.py index eff919295..e5b1602b7 100644 --- a/backend/onyx/db/llm.py +++ b/backend/onyx/db/llm.py @@ -13,6 +13,7 @@ from onyx.db.models import SearchSettings from onyx.db.models import Tool as ToolModel from onyx.db.models import User from onyx.db.models import User__UserGroup +from onyx.llm.utils import model_supports_image_input from onyx.server.manage.embedding.models import CloudEmbeddingProvider from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest from onyx.server.manage.llm.models import FullLLMProvider @@ -187,6 +188,17 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None: return FullLLMProvider.from_model(provider_model) +def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None: + provider_model = db_session.scalar( + select(LLMProviderModel).where( + LLMProviderModel.is_default_vision_provider == True # noqa: E712 + ) + ) + if not provider_model: + return None + return FullLLMProvider.from_model(provider_model) + + def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None: provider_model = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.name == provider_name) @@ -246,3 +258,39 @@ def update_default_provider(provider_id: int, db_session: Session) -> None: new_default.is_default_provider = True db_session.commit() + + +def update_default_vision_provider( + provider_id: int, vision_model: str | None, db_session: Session +) -> None: + new_default = db_session.scalar( + select(LLMProviderModel).where(LLMProviderModel.id == provider_id) + ) + if not new_default: + raise ValueError(f"LLM Provider with id {provider_id} does not exist") + + # Validate that the specified vision model supports image input + model_to_validate = vision_model or new_default.default_model_name + if model_to_validate: + if not model_supports_image_input(model_to_validate, new_default.provider): + raise ValueError( + f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input" + ) + else: + raise ValueError( + f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'" + ) + + existing_default = db_session.scalar( + select(LLMProviderModel).where( + LLMProviderModel.is_default_vision_provider == True # noqa: E712 + ) + ) + if existing_default: + existing_default.is_default_vision_provider = None + # required to ensure that the below does not cause a unique constraint violation + db_session.flush() + + new_default.is_default_vision_provider = True + new_default.default_vision_model = vision_model + db_session.commit() diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index c7f614370..32aa3f189 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -1489,6 +1489,10 @@ class LLMProvider(Base): # should only be set for a single provider is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True) + is_default_vision_provider: Mapped[bool | None] = mapped_column( + Boolean, unique=True + ) + default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True) # EE only is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) groups: Mapped[list["UserGroup"]] = relationship( diff --git a/backend/onyx/file_processing/image_summarization.py b/backend/onyx/file_processing/image_summarization.py index b81da25ec..69d4c3999 100644 --- a/backend/onyx/file_processing/image_summarization.py +++ b/backend/onyx/file_processing/image_summarization.py @@ -6,10 +6,10 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import SystemMessage from PIL import Image +from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT +from onyx.configs.app_configs import IMAGE_SUMMARIZATION_USER_PROMPT from onyx.llm.interfaces import LLM from onyx.llm.utils import message_to_string -from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT -from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT from onyx.utils.logger import setup_logger logger = setup_logger() @@ -62,7 +62,7 @@ def summarize_image_with_error_handling( image_data: The raw image bytes context_name: Name or title of the image for context system_prompt: System prompt to use for the LLM - user_prompt_template: Template for the user prompt, should contain {title} placeholder + user_prompt_template: User prompt to use (without title) Returns: The image summary text, or None if summarization failed or is disabled @@ -70,7 +70,10 @@ def summarize_image_with_error_handling( if llm is None: return None - user_prompt = user_prompt_template.format(title=context_name) + # Prepend the image filename to the user prompt + user_prompt = ( + f"The image has the file name '{context_name}'.\n{user_prompt_template}" + ) return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt) diff --git a/backend/onyx/llm/factory.py b/backend/onyx/llm/factory.py index 4c8a5f093..3d0bb6b3b 100644 --- a/backend/onyx/llm/factory.py +++ b/backend/onyx/llm/factory.py @@ -5,7 +5,9 @@ from onyx.configs.app_configs import DISABLE_GENERATIVE_AI from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_TEMPERATURE from onyx.db.engine import get_session_context_manager +from onyx.db.engine import get_session_with_current_tenant from onyx.db.llm import fetch_default_provider +from onyx.db.llm import fetch_default_vision_provider from onyx.db.llm import fetch_existing_llm_providers from onyx.db.llm import fetch_provider from onyx.db.models import Persona @@ -14,6 +16,7 @@ from onyx.llm.exceptions import GenAIDisabledException from onyx.llm.interfaces import LLM from onyx.llm.override_models import LLMOverride from onyx.llm.utils import model_supports_image_input +from onyx.server.manage.llm.models import FullLLMProvider from onyx.utils.headers import build_llm_extra_headers from onyx.utils.logger import setup_logger from onyx.utils.long_term_log import LongTermLogger @@ -94,40 +97,61 @@ def get_default_llm_with_vision( additional_headers: dict[str, str] | None = None, long_term_logger: LongTermLogger | None = None, ) -> LLM | None: + """Get an LLM that supports image input, with the following priority: + 1. Use the designated default vision provider if it exists and supports image input + 2. Fall back to the first LLM provider that supports image input + + Returns None if no providers exist or if no provider supports images. + """ if DISABLE_GENERATIVE_AI: raise GenAIDisabledException() - with get_session_context_manager() as db_session: - llm_providers = fetch_existing_llm_providers(db_session) - - if not llm_providers: - return None - - for provider in llm_providers: - model_name = provider.default_model_name - fast_model_name = ( - provider.fast_default_model_name or provider.default_model_name + def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM: + """Helper to create an LLM if the provider supports image input.""" + return get_llm( + provider=provider.provider, + model=model, + deployment_name=provider.deployment_name, + api_key=provider.api_key, + api_base=provider.api_base, + api_version=provider.api_version, + custom_config=provider.custom_config, + timeout=timeout, + temperature=temperature, + additional_headers=additional_headers, + long_term_logger=long_term_logger, ) - if not model_name or not fast_model_name: - continue - - if model_supports_image_input(model_name, provider.provider): - return get_llm( - provider=provider.provider, - model=model_name, - deployment_name=provider.deployment_name, - api_key=provider.api_key, - api_base=provider.api_base, - api_version=provider.api_version, - custom_config=provider.custom_config, - timeout=timeout, - temperature=temperature, - additional_headers=additional_headers, - long_term_logger=long_term_logger, + with get_session_with_current_tenant() as db_session: + # Try the default vision provider first + default_provider = fetch_default_vision_provider(db_session) + if ( + default_provider + and default_provider.default_vision_model + and model_supports_image_input( + default_provider.default_vision_model, default_provider.provider + ) + ): + return create_vision_llm( + default_provider, default_provider.default_vision_model ) - raise ValueError("No LLM provider found that supports image input") + # Fall back to searching all providers + providers = fetch_existing_llm_providers(db_session) + + if not providers: + return None + + # Find the first provider that supports image input + for provider in providers: + if provider.default_vision_model and model_supports_image_input( + provider.default_vision_model, provider.provider + ): + return create_vision_llm( + FullLLMProvider.from_model(provider), provider.default_vision_model + ) + + return None def get_default_llms( diff --git a/backend/onyx/prompts/image_analysis.py b/backend/onyx/prompts/image_analysis.py index 290f80526..1bc105c1d 100644 --- a/backend/onyx/prompts/image_analysis.py +++ b/backend/onyx/prompts/image_analysis.py @@ -1,5 +1,5 @@ # Used for creating embeddings of images for vector search -IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """ +DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """ You are an assistant for summarizing images for retrieval. Summarize the content of the following image and be as precise as possible. The summary will be embedded and used to retrieve the original image. @@ -7,14 +7,13 @@ Therefore, write a concise summary of the image that is optimized for retrieval. """ # Prompt for generating image descriptions with filename context -IMAGE_SUMMARIZATION_USER_PROMPT = """ -The image has the file name '{title}'. +DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT = """ Describe precisely and concisely what the image shows. """ # Used for analyzing images in response to user queries at search time -IMAGE_ANALYSIS_SYSTEM_PROMPT = ( +DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT = ( "You are an AI assistant specialized in describing images.\n" "You will receive a user question plus an image URL. Provide a concise textual answer.\n" "Focus on aspects of the image that are relevant to the user's question.\n" diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index 7a76ed196..ceafca2e3 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -14,6 +14,7 @@ from onyx.db.llm import fetch_existing_llm_providers_for_user from onyx.db.llm import fetch_provider from onyx.db.llm import remove_llm_provider from onyx.db.llm import update_default_provider +from onyx.db.llm import update_default_vision_provider from onyx.db.llm import upsert_llm_provider from onyx.db.models import User from onyx.llm.factory import get_default_llms @@ -21,11 +22,13 @@ from onyx.llm.factory import get_llm from onyx.llm.llm_provider_options import fetch_available_well_known_llms from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor from onyx.llm.utils import litellm_exception_to_error_msg +from onyx.llm.utils import model_supports_image_input from onyx.llm.utils import test_llm from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import LLMProviderDescriptor from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import TestLLMRequest +from onyx.server.manage.llm.models import VisionProviderResponse from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel @@ -186,6 +189,62 @@ def set_provider_as_default( update_default_provider(provider_id=provider_id, db_session=db_session) +@admin_router.post("/provider/{provider_id}/default-vision") +def set_provider_as_default_vision( + provider_id: int, + vision_model: str + | None = Query(None, description="The default vision model to use"), + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + update_default_vision_provider( + provider_id=provider_id, vision_model=vision_model, db_session=db_session + ) + + +@admin_router.get("/vision-providers") +def get_vision_capable_providers( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[VisionProviderResponse]: + """Return a list of LLM providers and their models that support image input""" + + providers = fetch_existing_llm_providers(db_session) + vision_providers = [] + + logger.info("Fetching vision-capable providers") + + for provider in providers: + vision_models = [] + + # Check model names in priority order + model_names_to_check = [] + if provider.model_names: + model_names_to_check = provider.model_names + elif provider.display_model_names: + model_names_to_check = provider.display_model_names + elif provider.default_model_name: + model_names_to_check = [provider.default_model_name] + + # Check each model for vision capability + for model_name in model_names_to_check: + if model_supports_image_input(model_name, provider.provider): + vision_models.append(model_name) + logger.debug(f"Vision model found: {provider.provider}/{model_name}") + + # Only include providers with at least one vision-capable model + if vision_models: + provider_dict = FullLLMProvider.from_model(provider).model_dump() + provider_dict["vision_models"] = vision_models + logger.info( + f"Vision provider: {provider.provider} with models: {vision_models}" + ) + vision_providers.append(VisionProviderResponse(**provider_dict)) + + logger.info(f"Found {len(vision_providers)} vision-capable providers") + return vision_providers + + """Endpoints for all""" diff --git a/backend/onyx/server/manage/llm/models.py b/backend/onyx/server/manage/llm/models.py index 91c59fb15..3172f5adf 100644 --- a/backend/onyx/server/manage/llm/models.py +++ b/backend/onyx/server/manage/llm/models.py @@ -34,6 +34,8 @@ class LLMProviderDescriptor(BaseModel): default_model_name: str fast_default_model_name: str | None is_default_provider: bool | None + is_default_vision_provider: bool | None + default_vision_model: str | None display_model_names: list[str] | None @classmethod @@ -46,11 +48,10 @@ class LLMProviderDescriptor(BaseModel): default_model_name=llm_provider_model.default_model_name, fast_default_model_name=llm_provider_model.fast_default_model_name, is_default_provider=llm_provider_model.is_default_provider, - model_names=( - llm_provider_model.model_names - or fetch_models_for_provider(llm_provider_model.provider) - or [llm_provider_model.default_model_name] - ), + is_default_vision_provider=llm_provider_model.is_default_vision_provider, + default_vision_model=llm_provider_model.default_vision_model, + model_names=llm_provider_model.model_names + or fetch_models_for_provider(llm_provider_model.provider), display_model_names=llm_provider_model.display_model_names, ) @@ -68,6 +69,7 @@ class LLMProvider(BaseModel): groups: list[int] = Field(default_factory=list) display_model_names: list[str] | None = None deployment_name: str | None = None + default_vision_model: str | None = None class LLMProviderUpsertRequest(LLMProvider): @@ -79,6 +81,7 @@ class LLMProviderUpsertRequest(LLMProvider): class FullLLMProvider(LLMProvider): id: int is_default_provider: bool | None = None + is_default_vision_provider: bool | None = None model_names: list[str] @classmethod @@ -94,6 +97,8 @@ class FullLLMProvider(LLMProvider): default_model_name=llm_provider_model.default_model_name, fast_default_model_name=llm_provider_model.fast_default_model_name, is_default_provider=llm_provider_model.is_default_provider, + is_default_vision_provider=llm_provider_model.is_default_vision_provider, + default_vision_model=llm_provider_model.default_vision_model, display_model_names=llm_provider_model.display_model_names, model_names=( llm_provider_model.model_names @@ -104,3 +109,9 @@ class FullLLMProvider(LLMProvider): groups=[group.id for group in llm_provider_model.groups], deployment_name=llm_provider_model.deployment_name, ) + + +class VisionProviderResponse(FullLLMProvider): + """Response model for vision providers endpoint, including vision-specific fields.""" + + vision_models: list[str] diff --git a/web/src/app/admin/configuration/llm/interfaces.ts b/web/src/app/admin/configuration/llm/interfaces.ts index 6f40e2078..641a372d2 100644 --- a/web/src/app/admin/configuration/llm/interfaces.ts +++ b/web/src/app/admin/configuration/llm/interfaces.ts @@ -49,6 +49,8 @@ export interface LLMProvider { groups: number[]; display_model_names: string[] | null; deployment_name: string | null; + default_vision_model: string | null; + is_default_vision_provider: boolean | null; } export interface FullLLMProvider extends LLMProvider { @@ -58,6 +60,10 @@ export interface FullLLMProvider extends LLMProvider { icon?: React.FC<{ size?: number; className?: string }>; } +export interface VisionProvider extends FullLLMProvider { + vision_models: string[]; +} + export interface LLMProviderDescriptor { name: string; provider: string; diff --git a/web/src/app/admin/settings/SettingsForm.tsx b/web/src/app/admin/settings/SettingsForm.tsx index f6c09bfc7..a62667caf 100644 --- a/web/src/app/admin/settings/SettingsForm.tsx +++ b/web/src/app/admin/settings/SettingsForm.tsx @@ -13,6 +13,9 @@ import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidE import { Modal } from "@/components/Modal"; import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants"; import { AnonymousUserPath } from "./AnonymousUserPath"; +import { useChatContext } from "@/components/context/ChatContext"; +import { LLMSelector } from "@/components/llm/LLMSelector"; +import { useVisionProviders } from "./hooks/useVisionProviders"; export function Checkbox({ label, @@ -111,6 +114,14 @@ export function SettingsForm() { const { popup, setPopup } = usePopup(); const isEnterpriseEnabled = usePaidEnterpriseFeaturesEnabled(); + // Pass setPopup to the hook + const { + visionProviders, + visionLLM, + setVisionLLM, + updateDefaultVisionProvider, + } = useVisionProviders(setPopup); + const combinedSettings = useContext(SettingsContext); useEffect(() => { @@ -120,6 +131,7 @@ export function SettingsForm() { combinedSettings.settings.maximum_chat_retention_days?.toString() || "" ); } + // We don't need to fetch vision providers here anymore as the hook handles it }, []); if (!settings) { @@ -354,6 +366,49 @@ export function SettingsForm() { id="image-analysis-max-size" placeholder="Enter maximum size in MB" /> + {/* Default Vision LLM Section */} +
+ + + Select the default LLM to use for image analysis. This model will be + utilized during image indexing and at query time for search results, + if the above settings are enabled. + + +
+ {!visionProviders || visionProviders.length === 0 ? ( +
+ No vision providers found. Please add a vision provider. +
+ ) : visionProviders.length > 0 ? ( + <> + ({ + ...provider, + model_names: provider.vision_models, + display_model_names: provider.vision_models, + }))} + currentLlm={visionLLM} + onSelect={(value) => setVisionLLM(value)} + /> + + + ) : ( +
+ No vision-capable LLMs found. Please add an LLM provider that + supports image input. +
+ )} +
+
); diff --git a/web/src/app/admin/settings/hooks/useVisionProviders.ts b/web/src/app/admin/settings/hooks/useVisionProviders.ts new file mode 100644 index 000000000..2583109e7 --- /dev/null +++ b/web/src/app/admin/settings/hooks/useVisionProviders.ts @@ -0,0 +1,123 @@ +import { useState, useEffect, useCallback } from "react"; +import { VisionProvider } from "@/app/admin/configuration/llm/interfaces"; +import { + fetchVisionProviders, + setDefaultVisionProvider, +} from "@/lib/llm/visionLLM"; +import { destructureValue, structureValue } from "@/lib/llm/utils"; + +// Define a type for the popup setter function +type SetPopup = (popup: { + message: string; + type: "success" | "error" | "info"; +}) => void; + +// Accept the setPopup function as a parameter +export function useVisionProviders(setPopup: SetPopup) { + const [visionProviders, setVisionProviders] = useState([]); + const [visionLLM, setVisionLLM] = useState(null); + const [isLoading, setIsLoading] = useState(false); + const [error, setError] = useState(null); + + const loadVisionProviders = useCallback(async () => { + setIsLoading(true); + setError(null); + try { + const data = await fetchVisionProviders(); + setVisionProviders(data); + + // Find the default vision provider and set it + const defaultProvider = data.find( + (provider) => provider.is_default_vision_provider + ); + + if (defaultProvider) { + const modelToUse = + defaultProvider.default_vision_model || + defaultProvider.default_model_name; + + if (modelToUse && defaultProvider.vision_models.includes(modelToUse)) { + setVisionLLM( + structureValue( + defaultProvider.name, + defaultProvider.provider, + modelToUse + ) + ); + } + } + } catch (error) { + console.error("Error fetching vision providers:", error); + setError( + error instanceof Error ? error.message : "Unknown error occurred" + ); + setPopup({ + message: `Failed to load vision providers: ${ + error instanceof Error ? error.message : "Unknown error" + }`, + type: "error", + }); + } finally { + setIsLoading(false); + } + }, []); + + const updateDefaultVisionProvider = useCallback( + async (llmValue: string | null) => { + if (!llmValue) { + setPopup({ + message: "Please select a valid vision model", + type: "error", + }); + return false; + } + + try { + const { name, modelName } = destructureValue(llmValue); + + // Find the provider ID + const providerObj = visionProviders.find((p) => p.name === name); + if (!providerObj) { + throw new Error("Provider not found"); + } + + await setDefaultVisionProvider(providerObj.id, modelName); + + setPopup({ + message: "Default vision provider updated successfully!", + type: "success", + }); + setVisionLLM(llmValue); + + // Refresh the list to reflect the change + await loadVisionProviders(); + return true; + } catch (error: unknown) { + console.error("Error setting default vision provider:", error); + const errorMessage = + error instanceof Error ? error.message : "Unknown error occurred"; + setPopup({ + message: `Failed to update default vision provider: ${errorMessage}`, + type: "error", + }); + return false; + } + }, + [visionProviders, setPopup, loadVisionProviders] + ); + + // Load providers on mount + useEffect(() => { + loadVisionProviders(); + }, [loadVisionProviders]); + + return { + visionProviders, + visionLLM, + isLoading, + error, + setVisionLLM, + updateDefaultVisionProvider, + refreshVisionProviders: loadVisionProviders, + }; +} diff --git a/web/src/app/admin/settings/interfaces.ts b/web/src/app/admin/settings/interfaces.ts index 5ad2793f7..8715a922d 100644 --- a/web/src/app/admin/settings/interfaces.ts +++ b/web/src/app/admin/settings/interfaces.ts @@ -12,11 +12,12 @@ export enum QueryHistoryType { export interface Settings { anonymous_user_enabled: boolean; - maximum_chat_retention_days: number | null; + anonymous_user_path?: string; + maximum_chat_retention_days?: number | null; notifications: Notification[]; needs_reindexing: boolean; gpu_enabled: boolean; - pro_search_enabled: boolean | null; + pro_search_enabled?: boolean; application_status: ApplicationStatus; auto_scroll: boolean; temperature_override_enabled: boolean; @@ -25,7 +26,7 @@ export interface Settings { // Image processing settings image_extraction_and_analysis_enabled?: boolean; search_time_image_analysis_enabled?: boolean; - image_analysis_max_size_mb?: number; + image_analysis_max_size_mb?: number | null; } export enum NotificationType { diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 06d3d1550..fbe55a91b 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -243,6 +243,7 @@ export const AIMessage = ({ return preprocessLaTeX(content); } } + // return content; return ( preprocessLaTeX(content) + diff --git a/web/src/components/llm/LLMSelector.tsx b/web/src/components/llm/LLMSelector.tsx index 4434824ce..13372ccc9 100644 --- a/web/src/components/llm/LLMSelector.tsx +++ b/web/src/components/llm/LLMSelector.tsx @@ -103,7 +103,6 @@ export const LLMSelector: React.FC = ({ ); } - return null; })} diff --git a/web/src/lib/llm/visionLLM.ts b/web/src/lib/llm/visionLLM.ts new file mode 100644 index 000000000..e60aa207c --- /dev/null +++ b/web/src/lib/llm/visionLLM.ts @@ -0,0 +1,37 @@ +import { VisionProvider } from "@/app/admin/configuration/llm/interfaces"; + +export async function fetchVisionProviders(): Promise { + const response = await fetch("/api/admin/llm/vision-providers", { + headers: { + "Content-Type": "application/json", + }, + }); + if (!response.ok) { + throw new Error( + `Failed to fetch vision providers: ${await response.text()}` + ); + } + return response.json(); +} + +export async function setDefaultVisionProvider( + providerId: number, + visionModel: string +): Promise { + const response = await fetch( + `/api/admin/llm/provider/${providerId}/default-vision?vision_model=${encodeURIComponent( + visionModel + )}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + } + ); + + if (!response.ok) { + const errorMsg = await response.text(); + throw new Error(errorMsg); + } +}