mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-21 18:43:30 +02:00
Support image indexing customization (#4261)
* working well * k * ready to go * k * minor nits * k * quick fix * k * k
This commit is contained in:
@ -0,0 +1,45 @@
|
|||||||
|
"""add_default_vision_provider_to_llm_provider
|
||||||
|
|
||||||
|
Revision ID: df46c75b714e
|
||||||
|
Revises: 3934b1bc7b62
|
||||||
|
Create Date: 2025-03-11 16:20:19.038945
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "df46c75b714e"
|
||||||
|
down_revision = "3934b1bc7b62"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add default-vision-provider tracking to the ``llm_provider`` table.

    Adds two nullable columns:

    * ``is_default_vision_provider`` - TRUE on the single provider chosen as
      the default for image analysis, NULL on every other row.
    * ``default_vision_model`` - name of the vision model to use for that
      provider.

    A unique constraint on ``is_default_vision_provider`` enforces that at
    most one row can be TRUE.
    """
    # Deliberately no server default: existing rows must be backfilled with
    # NULL. The unique constraint created below tolerates any number of NULLs
    # but only a single FALSE, so defaulting every existing row to FALSE would
    # make the constraint creation fail as soon as the table holds two or
    # more providers.
    op.add_column(
        "llm_provider",
        sa.Column("is_default_vision_provider", sa.Boolean(), nullable=True),
    )
    op.add_column(
        "llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
    )
    # Add unique constraint for is_default_vision_provider
    op.create_unique_constraint(
        "uq_llm_provider_is_default_vision_provider",
        "llm_provider",
        ["is_default_vision_provider"],
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Undo the upgrade: drop the unique constraint, then both added columns."""
    table_name = "llm_provider"

    # The constraint references is_default_vision_provider, so it goes first.
    op.drop_constraint(
        "uq_llm_provider_is_default_vision_provider", table_name, type_="unique"
    )

    for column_name in ("default_vision_model", "is_default_vision_provider"):
        op.drop_column(table_name, column_name)
|
@ -8,6 +8,9 @@ from onyx.configs.constants import AuthType
|
|||||||
from onyx.configs.constants import DocumentIndexType
|
from onyx.configs.constants import DocumentIndexType
|
||||||
from onyx.configs.constants import QueryHistoryType
|
from onyx.configs.constants import QueryHistoryType
|
||||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||||
|
from onyx.prompts.image_analysis import DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT
|
||||||
|
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||||
|
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT
|
||||||
|
|
||||||
#####
|
#####
|
||||||
# App Configs
|
# App Configs
|
||||||
@ -646,3 +649,21 @@ DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
|
|||||||
|
|
||||||
# Number of pre-provisioned tenants to maintain
|
# Number of pre-provisioned tenants to maintain
|
||||||
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
|
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
|
||||||
|
|
||||||
|
|
||||||
|
# Image summarization configuration
|
||||||
|
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
|
||||||
|
"IMAGE_SUMMARIZATION_SYSTEM_PROMPT",
|
||||||
|
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The user prompt for image summarization - the image filename will be automatically prepended
|
||||||
|
IMAGE_SUMMARIZATION_USER_PROMPT = os.environ.get(
|
||||||
|
"IMAGE_SUMMARIZATION_USER_PROMPT",
|
||||||
|
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT,
|
||||||
|
)
|
||||||
|
|
||||||
|
IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
|
||||||
|
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
|
||||||
|
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
@ -30,6 +30,7 @@ class VisionEnabledConnector:
|
|||||||
Sets self.image_analysis_llm to the LLM instance or None if disabled.
|
Sets self.image_analysis_llm to the LLM instance or None if disabled.
|
||||||
"""
|
"""
|
||||||
self.image_analysis_llm: LLM | None = None
|
self.image_analysis_llm: LLM | None = None
|
||||||
|
|
||||||
if get_image_extraction_and_analysis_enabled():
|
if get_image_extraction_and_analysis_enabled():
|
||||||
try:
|
try:
|
||||||
self.image_analysis_llm = get_default_llm_with_vision()
|
self.image_analysis_llm = get_default_llm_with_vision()
|
||||||
|
@ -10,6 +10,7 @@ from langchain_core.messages import SystemMessage
|
|||||||
|
|
||||||
from onyx.chat.models import SectionRelevancePiece
|
from onyx.chat.models import SectionRelevancePiece
|
||||||
from onyx.configs.app_configs import BLURB_SIZE
|
from onyx.configs.app_configs import BLURB_SIZE
|
||||||
|
from onyx.configs.app_configs import IMAGE_ANALYSIS_SYSTEM_PROMPT
|
||||||
from onyx.configs.constants import RETURN_SEPARATOR
|
from onyx.configs.constants import RETURN_SEPARATOR
|
||||||
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
|
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
|
||||||
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||||
@ -31,7 +32,6 @@ from onyx.file_store.file_store import get_default_file_store
|
|||||||
from onyx.llm.interfaces import LLM
|
from onyx.llm.interfaces import LLM
|
||||||
from onyx.llm.utils import message_to_string
|
from onyx.llm.utils import message_to_string
|
||||||
from onyx.natural_language_processing.search_nlp_models import RerankingModel
|
from onyx.natural_language_processing.search_nlp_models import RerankingModel
|
||||||
from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT
|
|
||||||
from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
|
from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||||
|
@ -13,6 +13,7 @@ from onyx.db.models import SearchSettings
|
|||||||
from onyx.db.models import Tool as ToolModel
|
from onyx.db.models import Tool as ToolModel
|
||||||
from onyx.db.models import User
|
from onyx.db.models import User
|
||||||
from onyx.db.models import User__UserGroup
|
from onyx.db.models import User__UserGroup
|
||||||
|
from onyx.llm.utils import model_supports_image_input
|
||||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||||
from onyx.server.manage.llm.models import FullLLMProvider
|
from onyx.server.manage.llm.models import FullLLMProvider
|
||||||
@ -187,6 +188,17 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
|||||||
return FullLLMProvider.from_model(provider_model)
|
return FullLLMProvider.from_model(provider_model)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
    """Look up the provider flagged as the default vision provider.

    Returns:
        The flagged provider converted to a ``FullLLMProvider``, or ``None``
        when no row has ``is_default_vision_provider`` set to TRUE.
    """
    default_vision_stmt = select(LLMProviderModel).where(
        LLMProviderModel.is_default_vision_provider == True  # noqa: E712
    )
    flagged_provider = db_session.scalar(default_vision_stmt)
    return FullLLMProvider.from_model(flagged_provider) if flagged_provider else None
|
||||||
|
|
||||||
|
|
||||||
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
|
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
|
||||||
provider_model = db_session.scalar(
|
provider_model = db_session.scalar(
|
||||||
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
||||||
@ -246,3 +258,39 @@ def update_default_provider(provider_id: int, db_session: Session) -> None:
|
|||||||
|
|
||||||
new_default.is_default_provider = True
|
new_default.is_default_provider = True
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def update_default_vision_provider(
    provider_id: int, vision_model: str | None, db_session: Session
) -> None:
    """Mark a provider as the default vision provider and record its model.

    Args:
        provider_id: Primary key of the provider to promote.
        vision_model: Model to use for image input. When omitted (or empty),
            the provider's ``default_model_name`` is used instead.
        db_session: Session used for the lookup and the update; committed on
            success.

    Raises:
        ValueError: If the provider does not exist, if no model can be
            resolved, or if the resolved model does not support image input.
    """
    new_default = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    )
    if not new_default:
        raise ValueError(f"LLM Provider with id {provider_id} does not exist")

    # Resolve the model that will actually be used, then validate that one.
    model_to_validate = vision_model or new_default.default_model_name
    if not model_to_validate:
        # Reached only when neither an explicit model nor a provider default
        # is available, so there is no model name worth echoing back.
        raise ValueError(
            f"No vision model was specified and provider '{new_default.provider}' "
            "has no default model to fall back to"
        )
    if not model_supports_image_input(model_to_validate, new_default.provider):
        raise ValueError(
            f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
        )

    existing_default = db_session.scalar(
        select(LLMProviderModel).where(
            LLMProviderModel.is_default_vision_provider == True  # noqa: E712
        )
    )
    if existing_default:
        # NULL (not FALSE) marks "not the default": the column is unique.
        existing_default.is_default_vision_provider = None
        # required to ensure that the below does not cause a unique constraint violation
        db_session.flush()

    new_default.is_default_vision_provider = True
    # Persist the model that was validated. Storing the raw argument would
    # leave the column NULL whenever the caller relied on the fallback, and a
    # default vision provider without a vision model cannot be used.
    new_default.default_vision_model = model_to_validate
    db_session.commit()
|
||||||
|
@ -1489,6 +1489,10 @@ class LLMProvider(Base):
|
|||||||
|
|
||||||
# should only be set for a single provider
|
# should only be set for a single provider
|
||||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||||
|
is_default_vision_provider: Mapped[bool | None] = mapped_column(
|
||||||
|
Boolean, unique=True
|
||||||
|
)
|
||||||
|
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||||
# EE only
|
# EE only
|
||||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
groups: Mapped[list["UserGroup"]] = relationship(
|
groups: Mapped[list["UserGroup"]] = relationship(
|
||||||
|
@ -6,10 +6,10 @@ from langchain_core.messages import HumanMessage
|
|||||||
from langchain_core.messages import SystemMessage
|
from langchain_core.messages import SystemMessage
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||||
|
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_USER_PROMPT
|
||||||
from onyx.llm.interfaces import LLM
|
from onyx.llm.interfaces import LLM
|
||||||
from onyx.llm.utils import message_to_string
|
from onyx.llm.utils import message_to_string
|
||||||
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
|
||||||
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT
|
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@ -62,7 +62,7 @@ def summarize_image_with_error_handling(
|
|||||||
image_data: The raw image bytes
|
image_data: The raw image bytes
|
||||||
context_name: Name or title of the image for context
|
context_name: Name or title of the image for context
|
||||||
system_prompt: System prompt to use for the LLM
|
system_prompt: System prompt to use for the LLM
|
||||||
user_prompt_template: Template for the user prompt, should contain {title} placeholder
|
user_prompt_template: User prompt to use (without title)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The image summary text, or None if summarization failed or is disabled
|
The image summary text, or None if summarization failed or is disabled
|
||||||
@ -70,7 +70,10 @@ def summarize_image_with_error_handling(
|
|||||||
if llm is None:
|
if llm is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
user_prompt = user_prompt_template.format(title=context_name)
|
# Prepend the image filename to the user prompt
|
||||||
|
user_prompt = (
|
||||||
|
f"The image has the file name '{context_name}'.\n{user_prompt_template}"
|
||||||
|
)
|
||||||
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
|
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,7 +5,9 @@ from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
|||||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||||
from onyx.db.engine import get_session_context_manager
|
from onyx.db.engine import get_session_context_manager
|
||||||
|
from onyx.db.engine import get_session_with_current_tenant
|
||||||
from onyx.db.llm import fetch_default_provider
|
from onyx.db.llm import fetch_default_provider
|
||||||
|
from onyx.db.llm import fetch_default_vision_provider
|
||||||
from onyx.db.llm import fetch_existing_llm_providers
|
from onyx.db.llm import fetch_existing_llm_providers
|
||||||
from onyx.db.llm import fetch_provider
|
from onyx.db.llm import fetch_provider
|
||||||
from onyx.db.models import Persona
|
from onyx.db.models import Persona
|
||||||
@ -14,6 +16,7 @@ from onyx.llm.exceptions import GenAIDisabledException
|
|||||||
from onyx.llm.interfaces import LLM
|
from onyx.llm.interfaces import LLM
|
||||||
from onyx.llm.override_models import LLMOverride
|
from onyx.llm.override_models import LLMOverride
|
||||||
from onyx.llm.utils import model_supports_image_input
|
from onyx.llm.utils import model_supports_image_input
|
||||||
|
from onyx.server.manage.llm.models import FullLLMProvider
|
||||||
from onyx.utils.headers import build_llm_extra_headers
|
from onyx.utils.headers import build_llm_extra_headers
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
from onyx.utils.long_term_log import LongTermLogger
|
from onyx.utils.long_term_log import LongTermLogger
|
||||||
@ -94,40 +97,61 @@ def get_default_llm_with_vision(
|
|||||||
additional_headers: dict[str, str] | None = None,
|
additional_headers: dict[str, str] | None = None,
|
||||||
long_term_logger: LongTermLogger | None = None,
|
long_term_logger: LongTermLogger | None = None,
|
||||||
) -> LLM | None:
|
) -> LLM | None:
|
||||||
|
"""Get an LLM that supports image input, with the following priority:
|
||||||
|
1. Use the designated default vision provider if it exists and supports image input
|
||||||
|
2. Fall back to the first LLM provider that supports image input
|
||||||
|
|
||||||
|
Returns None if no providers exist or if no provider supports images.
|
||||||
|
"""
|
||||||
if DISABLE_GENERATIVE_AI:
|
if DISABLE_GENERATIVE_AI:
|
||||||
raise GenAIDisabledException()
|
raise GenAIDisabledException()
|
||||||
|
|
||||||
with get_session_context_manager() as db_session:
|
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
|
||||||
llm_providers = fetch_existing_llm_providers(db_session)
|
"""Helper to create an LLM if the provider supports image input."""
|
||||||
|
return get_llm(
|
||||||
if not llm_providers:
|
provider=provider.provider,
|
||||||
return None
|
model=model,
|
||||||
|
deployment_name=provider.deployment_name,
|
||||||
for provider in llm_providers:
|
api_key=provider.api_key,
|
||||||
model_name = provider.default_model_name
|
api_base=provider.api_base,
|
||||||
fast_model_name = (
|
api_version=provider.api_version,
|
||||||
provider.fast_default_model_name or provider.default_model_name
|
custom_config=provider.custom_config,
|
||||||
|
timeout=timeout,
|
||||||
|
temperature=temperature,
|
||||||
|
additional_headers=additional_headers,
|
||||||
|
long_term_logger=long_term_logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not model_name or not fast_model_name:
|
with get_session_with_current_tenant() as db_session:
|
||||||
continue
|
# Try the default vision provider first
|
||||||
|
default_provider = fetch_default_vision_provider(db_session)
|
||||||
if model_supports_image_input(model_name, provider.provider):
|
if (
|
||||||
return get_llm(
|
default_provider
|
||||||
provider=provider.provider,
|
and default_provider.default_vision_model
|
||||||
model=model_name,
|
and model_supports_image_input(
|
||||||
deployment_name=provider.deployment_name,
|
default_provider.default_vision_model, default_provider.provider
|
||||||
api_key=provider.api_key,
|
)
|
||||||
api_base=provider.api_base,
|
):
|
||||||
api_version=provider.api_version,
|
return create_vision_llm(
|
||||||
custom_config=provider.custom_config,
|
default_provider, default_provider.default_vision_model
|
||||||
timeout=timeout,
|
|
||||||
temperature=temperature,
|
|
||||||
additional_headers=additional_headers,
|
|
||||||
long_term_logger=long_term_logger,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
raise ValueError("No LLM provider found that supports image input")
|
# Fall back to searching all providers
|
||||||
|
providers = fetch_existing_llm_providers(db_session)
|
||||||
|
|
||||||
|
if not providers:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find the first provider that supports image input
|
||||||
|
for provider in providers:
|
||||||
|
if provider.default_vision_model and model_supports_image_input(
|
||||||
|
provider.default_vision_model, provider.provider
|
||||||
|
):
|
||||||
|
return create_vision_llm(
|
||||||
|
FullLLMProvider.from_model(provider), provider.default_vision_model
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_default_llms(
|
def get_default_llms(
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Used for creating embeddings of images for vector search
|
# Used for creating embeddings of images for vector search
|
||||||
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
|
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
|
||||||
You are an assistant for summarizing images for retrieval.
|
You are an assistant for summarizing images for retrieval.
|
||||||
Summarize the content of the following image and be as precise as possible.
|
Summarize the content of the following image and be as precise as possible.
|
||||||
The summary will be embedded and used to retrieve the original image.
|
The summary will be embedded and used to retrieve the original image.
|
||||||
@ -7,14 +7,13 @@ Therefore, write a concise summary of the image that is optimized for retrieval.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Prompt for generating image descriptions with filename context
|
# Prompt for generating image descriptions with filename context
|
||||||
IMAGE_SUMMARIZATION_USER_PROMPT = """
|
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT = """
|
||||||
The image has the file name '{title}'.
|
|
||||||
Describe precisely and concisely what the image shows.
|
Describe precisely and concisely what the image shows.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Used for analyzing images in response to user queries at search time
|
# Used for analyzing images in response to user queries at search time
|
||||||
IMAGE_ANALYSIS_SYSTEM_PROMPT = (
|
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT = (
|
||||||
"You are an AI assistant specialized in describing images.\n"
|
"You are an AI assistant specialized in describing images.\n"
|
||||||
"You will receive a user question plus an image URL. Provide a concise textual answer.\n"
|
"You will receive a user question plus an image URL. Provide a concise textual answer.\n"
|
||||||
"Focus on aspects of the image that are relevant to the user's question.\n"
|
"Focus on aspects of the image that are relevant to the user's question.\n"
|
||||||
|
@ -14,6 +14,7 @@ from onyx.db.llm import fetch_existing_llm_providers_for_user
|
|||||||
from onyx.db.llm import fetch_provider
|
from onyx.db.llm import fetch_provider
|
||||||
from onyx.db.llm import remove_llm_provider
|
from onyx.db.llm import remove_llm_provider
|
||||||
from onyx.db.llm import update_default_provider
|
from onyx.db.llm import update_default_provider
|
||||||
|
from onyx.db.llm import update_default_vision_provider
|
||||||
from onyx.db.llm import upsert_llm_provider
|
from onyx.db.llm import upsert_llm_provider
|
||||||
from onyx.db.models import User
|
from onyx.db.models import User
|
||||||
from onyx.llm.factory import get_default_llms
|
from onyx.llm.factory import get_default_llms
|
||||||
@ -21,11 +22,13 @@ from onyx.llm.factory import get_llm
|
|||||||
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
|
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
|
||||||
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||||
|
from onyx.llm.utils import model_supports_image_input
|
||||||
from onyx.llm.utils import test_llm
|
from onyx.llm.utils import test_llm
|
||||||
from onyx.server.manage.llm.models import FullLLMProvider
|
from onyx.server.manage.llm.models import FullLLMProvider
|
||||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||||
from onyx.server.manage.llm.models import TestLLMRequest
|
from onyx.server.manage.llm.models import TestLLMRequest
|
||||||
|
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||||
|
|
||||||
@ -186,6 +189,62 @@ def set_provider_as_default(
|
|||||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/provider/{provider_id}/default-vision")
def set_provider_as_default_vision(
    provider_id: int,
    vision_model: str
    | None = Query(None, description="The default vision model to use"),
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Admin endpoint: make ``provider_id`` the default vision provider.

    Args:
        provider_id: ID of the LLM provider to promote.
        vision_model: Optional query parameter naming the vision model to use.

    Raises:
        HTTPException: 400 when the update is rejected with a ``ValueError``
            (unknown provider, or a model that cannot take image input).
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module (Query / Depends).
    from fastapi import HTTPException

    try:
        update_default_vision_provider(
            provider_id=provider_id, vision_model=vision_model, db_session=db_session
        )
    except ValueError as e:
        # Validation failures are caller errors; without this they would
        # surface as an opaque 500.
        raise HTTPException(status_code=400, detail=str(e)) from e
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/vision-providers")
def get_vision_capable_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[VisionProviderResponse]:
    """List every LLM provider that has at least one image-capable model.

    Each returned entry is the full provider payload plus a ``vision_models``
    list naming the models that passed ``model_supports_image_input``.
    """
    all_providers = fetch_existing_llm_providers(db_session)
    capable_providers: list[VisionProviderResponse] = []

    logger.info("Fetching vision-capable providers")

    for llm_provider in all_providers:
        # Candidate models in priority order: explicit model names, then the
        # display names, then the single default model (if any).
        candidate_models = (
            llm_provider.model_names
            or llm_provider.display_model_names
            or (
                [llm_provider.default_model_name]
                if llm_provider.default_model_name
                else []
            )
        )

        supported_models = []
        for candidate in candidate_models:
            if not model_supports_image_input(candidate, llm_provider.provider):
                continue
            supported_models.append(candidate)
            logger.debug(f"Vision model found: {llm_provider.provider}/{candidate}")

        # Providers without any vision-capable model are left out entirely.
        if not supported_models:
            continue

        payload = FullLLMProvider.from_model(llm_provider).model_dump()
        payload["vision_models"] = supported_models
        logger.info(
            f"Vision provider: {llm_provider.provider} with models: {supported_models}"
        )
        capable_providers.append(VisionProviderResponse(**payload))

    logger.info(f"Found {len(capable_providers)} vision-capable providers")
    return capable_providers
|
||||||
|
|
||||||
|
|
||||||
"""Endpoints for all"""
|
"""Endpoints for all"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,6 +34,8 @@ class LLMProviderDescriptor(BaseModel):
|
|||||||
default_model_name: str
|
default_model_name: str
|
||||||
fast_default_model_name: str | None
|
fast_default_model_name: str | None
|
||||||
is_default_provider: bool | None
|
is_default_provider: bool | None
|
||||||
|
is_default_vision_provider: bool | None
|
||||||
|
default_vision_model: str | None
|
||||||
display_model_names: list[str] | None
|
display_model_names: list[str] | None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -46,11 +48,10 @@ class LLMProviderDescriptor(BaseModel):
|
|||||||
default_model_name=llm_provider_model.default_model_name,
|
default_model_name=llm_provider_model.default_model_name,
|
||||||
fast_default_model_name=llm_provider_model.fast_default_model_name,
|
fast_default_model_name=llm_provider_model.fast_default_model_name,
|
||||||
is_default_provider=llm_provider_model.is_default_provider,
|
is_default_provider=llm_provider_model.is_default_provider,
|
||||||
model_names=(
|
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
|
||||||
llm_provider_model.model_names
|
default_vision_model=llm_provider_model.default_vision_model,
|
||||||
or fetch_models_for_provider(llm_provider_model.provider)
|
model_names=llm_provider_model.model_names
|
||||||
or [llm_provider_model.default_model_name]
|
or fetch_models_for_provider(llm_provider_model.provider),
|
||||||
),
|
|
||||||
display_model_names=llm_provider_model.display_model_names,
|
display_model_names=llm_provider_model.display_model_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -68,6 +69,7 @@ class LLMProvider(BaseModel):
|
|||||||
groups: list[int] = Field(default_factory=list)
|
groups: list[int] = Field(default_factory=list)
|
||||||
display_model_names: list[str] | None = None
|
display_model_names: list[str] | None = None
|
||||||
deployment_name: str | None = None
|
deployment_name: str | None = None
|
||||||
|
default_vision_model: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class LLMProviderUpsertRequest(LLMProvider):
|
class LLMProviderUpsertRequest(LLMProvider):
|
||||||
@ -79,6 +81,7 @@ class LLMProviderUpsertRequest(LLMProvider):
|
|||||||
class FullLLMProvider(LLMProvider):
|
class FullLLMProvider(LLMProvider):
|
||||||
id: int
|
id: int
|
||||||
is_default_provider: bool | None = None
|
is_default_provider: bool | None = None
|
||||||
|
is_default_vision_provider: bool | None = None
|
||||||
model_names: list[str]
|
model_names: list[str]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -94,6 +97,8 @@ class FullLLMProvider(LLMProvider):
|
|||||||
default_model_name=llm_provider_model.default_model_name,
|
default_model_name=llm_provider_model.default_model_name,
|
||||||
fast_default_model_name=llm_provider_model.fast_default_model_name,
|
fast_default_model_name=llm_provider_model.fast_default_model_name,
|
||||||
is_default_provider=llm_provider_model.is_default_provider,
|
is_default_provider=llm_provider_model.is_default_provider,
|
||||||
|
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
|
||||||
|
default_vision_model=llm_provider_model.default_vision_model,
|
||||||
display_model_names=llm_provider_model.display_model_names,
|
display_model_names=llm_provider_model.display_model_names,
|
||||||
model_names=(
|
model_names=(
|
||||||
llm_provider_model.model_names
|
llm_provider_model.model_names
|
||||||
@ -104,3 +109,9 @@ class FullLLMProvider(LLMProvider):
|
|||||||
groups=[group.id for group in llm_provider_model.groups],
|
groups=[group.id for group in llm_provider_model.groups],
|
||||||
deployment_name=llm_provider_model.deployment_name,
|
deployment_name=llm_provider_model.deployment_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VisionProviderResponse(FullLLMProvider):
    """Full LLM provider payload extended with its image-capable models.

    Used as the response model of the vision providers endpoint.
    """

    # Names of this provider's models that support image input.
    vision_models: list[str]
|
||||||
|
@ -49,6 +49,8 @@ export interface LLMProvider {
|
|||||||
groups: number[];
|
groups: number[];
|
||||||
display_model_names: string[] | null;
|
display_model_names: string[] | null;
|
||||||
deployment_name: string | null;
|
deployment_name: string | null;
|
||||||
|
default_vision_model: string | null;
|
||||||
|
is_default_vision_provider: boolean | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface FullLLMProvider extends LLMProvider {
|
export interface FullLLMProvider extends LLMProvider {
|
||||||
@ -58,6 +60,10 @@ export interface FullLLMProvider extends LLMProvider {
|
|||||||
icon?: React.FC<{ size?: number; className?: string }>;
|
icon?: React.FC<{ size?: number; className?: string }>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface VisionProvider extends FullLLMProvider {
|
||||||
|
vision_models: string[];
|
||||||
|
}
|
||||||
|
|
||||||
export interface LLMProviderDescriptor {
|
export interface LLMProviderDescriptor {
|
||||||
name: string;
|
name: string;
|
||||||
provider: string;
|
provider: string;
|
||||||
|
@ -13,6 +13,9 @@ import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidE
|
|||||||
import { Modal } from "@/components/Modal";
|
import { Modal } from "@/components/Modal";
|
||||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||||
import { AnonymousUserPath } from "./AnonymousUserPath";
|
import { AnonymousUserPath } from "./AnonymousUserPath";
|
||||||
|
import { useChatContext } from "@/components/context/ChatContext";
|
||||||
|
import { LLMSelector } from "@/components/llm/LLMSelector";
|
||||||
|
import { useVisionProviders } from "./hooks/useVisionProviders";
|
||||||
|
|
||||||
export function Checkbox({
|
export function Checkbox({
|
||||||
label,
|
label,
|
||||||
@ -111,6 +114,14 @@ export function SettingsForm() {
|
|||||||
const { popup, setPopup } = usePopup();
|
const { popup, setPopup } = usePopup();
|
||||||
const isEnterpriseEnabled = usePaidEnterpriseFeaturesEnabled();
|
const isEnterpriseEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||||
|
|
||||||
|
// Pass setPopup to the hook
|
||||||
|
const {
|
||||||
|
visionProviders,
|
||||||
|
visionLLM,
|
||||||
|
setVisionLLM,
|
||||||
|
updateDefaultVisionProvider,
|
||||||
|
} = useVisionProviders(setPopup);
|
||||||
|
|
||||||
const combinedSettings = useContext(SettingsContext);
|
const combinedSettings = useContext(SettingsContext);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@ -120,6 +131,7 @@ export function SettingsForm() {
|
|||||||
combinedSettings.settings.maximum_chat_retention_days?.toString() || ""
|
combinedSettings.settings.maximum_chat_retention_days?.toString() || ""
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
// We don't need to fetch vision providers here anymore as the hook handles it
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
if (!settings) {
|
if (!settings) {
|
||||||
@ -354,6 +366,49 @@ export function SettingsForm() {
|
|||||||
id="image-analysis-max-size"
|
id="image-analysis-max-size"
|
||||||
placeholder="Enter maximum size in MB"
|
placeholder="Enter maximum size in MB"
|
||||||
/>
|
/>
|
||||||
|
{/* Default Vision LLM Section */}
|
||||||
|
<div className="mt-4">
|
||||||
|
<Label>Default Vision LLM</Label>
|
||||||
|
<SubLabel>
|
||||||
|
Select the default LLM to use for image analysis. This model will be
|
||||||
|
utilized during image indexing and at query time for search results,
|
||||||
|
if the above settings are enabled.
|
||||||
|
</SubLabel>
|
||||||
|
|
||||||
|
<div className="mt-2 max-w-xs">
|
||||||
|
{!visionProviders || visionProviders.length === 0 ? (
|
||||||
|
<div className="text-sm text-gray-500">
|
||||||
|
No vision providers found. Please add a vision provider.
|
||||||
|
</div>
|
||||||
|
) : visionProviders.length > 0 ? (
|
||||||
|
<>
|
||||||
|
<LLMSelector
|
||||||
|
userSettings={false}
|
||||||
|
llmProviders={visionProviders.map((provider) => ({
|
||||||
|
...provider,
|
||||||
|
model_names: provider.vision_models,
|
||||||
|
display_model_names: provider.vision_models,
|
||||||
|
}))}
|
||||||
|
currentLlm={visionLLM}
|
||||||
|
onSelect={(value) => setVisionLLM(value)}
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
onClick={() => updateDefaultVisionProvider(visionLLM)}
|
||||||
|
className="mt-2"
|
||||||
|
variant="default"
|
||||||
|
size="sm"
|
||||||
|
>
|
||||||
|
Set Default Vision LLM
|
||||||
|
</Button>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<div className="text-sm text-gray-500">
|
||||||
|
No vision-capable LLMs found. Please add an LLM provider that
|
||||||
|
supports image input.
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
123
web/src/app/admin/settings/hooks/useVisionProviders.ts
Normal file
123
web/src/app/admin/settings/hooks/useVisionProviders.ts
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
import { useState, useEffect, useCallback } from "react";
|
||||||
|
import { VisionProvider } from "@/app/admin/configuration/llm/interfaces";
|
||||||
|
import {
|
||||||
|
fetchVisionProviders,
|
||||||
|
setDefaultVisionProvider,
|
||||||
|
} from "@/lib/llm/visionLLM";
|
||||||
|
import { destructureValue, structureValue } from "@/lib/llm/utils";
|
||||||
|
|
||||||
|
// Payload accepted by the popup setter: the text to show and its severity.
interface PopupNotice {
  message: string;
  type: "success" | "error" | "info";
}

// Callback used by the hook below to surface a notification to the user.
type SetPopup = (popup: PopupNotice) => void;
|
||||||
|
|
||||||
|
// Returns the message of a thrown Error, or `fallback` for any other value.
function errorMessageOr(thrown: unknown, fallback: string): string {
  return thrown instanceof Error ? thrown.message : fallback;
}

/**
 * Hook that loads the vision providers, tracks the currently selected vision
 * LLM, and lets the caller persist a new default.
 *
 * @param setPopup callback used to report load/update outcomes to the user.
 * @returns the provider list, the selected LLM value and its setter, loading
 *   and error state, plus `updateDefaultVisionProvider` and
 *   `refreshVisionProviders` actions.
 */
export function useVisionProviders(setPopup: SetPopup) {
  const [visionProviders, setVisionProviders] = useState<VisionProvider[]>([]);
  const [visionLLM, setVisionLLM] = useState<string | null>(null);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);

  // Fetches the provider list and, when a default provider with a usable
  // vision model is present, pre-selects it.
  const loadVisionProviders = useCallback(async () => {
    setIsLoading(true);
    setError(null);

    try {
      const providers = await fetchVisionProviders();
      setVisionProviders(providers);

      const current = providers.find((p) => p.is_default_vision_provider);
      // Prefer the dedicated vision model; fall back to the general default.
      const model =
        current?.default_vision_model || current?.default_model_name;

      // Only pre-select a model that the provider lists as vision-capable.
      if (current && model && current.vision_models.includes(model)) {
        setVisionLLM(structureValue(current.name, current.provider, model));
      }
    } catch (err) {
      console.error("Error fetching vision providers:", err);
      setError(errorMessageOr(err, "Unknown error occurred"));
      setPopup({
        message: `Failed to load vision providers: ${errorMessageOr(
          err,
          "Unknown error"
        )}`,
        type: "error",
      });
    } finally {
      setIsLoading(false);
    }
    // The dependency list is intentionally empty so the callback identity
    // stays stable and the mount effect below runs once.
  }, []);

  // Persists `llmValue` as the default vision LLM. Resolves to true on
  // success and false when nothing is selected or the update fails.
  const updateDefaultVisionProvider = useCallback(
    async (llmValue: string | null) => {
      if (!llmValue) {
        setPopup({
          message: "Please select a valid vision model",
          type: "error",
        });
        return false;
      }

      try {
        const parsed = destructureValue(llmValue);

        // The backend call needs the provider's id, so resolve it by name.
        const match = visionProviders.find((p) => p.name === parsed.name);
        if (!match) {
          throw new Error("Provider not found");
        }

        await setDefaultVisionProvider(match.id, parsed.modelName);

        setPopup({
          message: "Default vision provider updated successfully!",
          type: "success",
        });
        setVisionLLM(llmValue);

        // Re-fetch so the list reflects the new default.
        await loadVisionProviders();
        return true;
      } catch (err: unknown) {
        console.error("Error setting default vision provider:", err);
        const detail = errorMessageOr(err, "Unknown error occurred");
        setPopup({
          message: `Failed to update default vision provider: ${detail}`,
          type: "error",
        });
        return false;
      }
    },
    [visionProviders, setPopup, loadVisionProviders]
  );

  // Initial load on mount.
  useEffect(() => {
    loadVisionProviders();
  }, [loadVisionProviders]);

  return {
    visionProviders,
    visionLLM,
    isLoading,
    error,
    setVisionLLM,
    updateDefaultVisionProvider,
    refreshVisionProviders: loadVisionProviders,
  };
}
|
@ -12,11 +12,12 @@ export enum QueryHistoryType {
|
|||||||
|
|
||||||
export interface Settings {
|
export interface Settings {
|
||||||
anonymous_user_enabled: boolean;
|
anonymous_user_enabled: boolean;
|
||||||
maximum_chat_retention_days: number | null;
|
anonymous_user_path?: string;
|
||||||
|
maximum_chat_retention_days?: number | null;
|
||||||
notifications: Notification[];
|
notifications: Notification[];
|
||||||
needs_reindexing: boolean;
|
needs_reindexing: boolean;
|
||||||
gpu_enabled: boolean;
|
gpu_enabled: boolean;
|
||||||
pro_search_enabled: boolean | null;
|
pro_search_enabled?: boolean;
|
||||||
application_status: ApplicationStatus;
|
application_status: ApplicationStatus;
|
||||||
auto_scroll: boolean;
|
auto_scroll: boolean;
|
||||||
temperature_override_enabled: boolean;
|
temperature_override_enabled: boolean;
|
||||||
@ -25,7 +26,7 @@ export interface Settings {
|
|||||||
// Image processing settings
|
// Image processing settings
|
||||||
image_extraction_and_analysis_enabled?: boolean;
|
image_extraction_and_analysis_enabled?: boolean;
|
||||||
search_time_image_analysis_enabled?: boolean;
|
search_time_image_analysis_enabled?: boolean;
|
||||||
image_analysis_max_size_mb?: number;
|
image_analysis_max_size_mb?: number | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export enum NotificationType {
|
export enum NotificationType {
|
||||||
|
@ -243,6 +243,7 @@ export const AIMessage = ({
|
|||||||
return preprocessLaTeX(content);
|
return preprocessLaTeX(content);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// return content;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
preprocessLaTeX(content) +
|
preprocessLaTeX(content) +
|
||||||
|
@ -103,7 +103,6 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
|
|||||||
</SelectItem>
|
</SelectItem>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return null;
|
|
||||||
})}
|
})}
|
||||||
</SelectContent>
|
</SelectContent>
|
||||||
</Select>
|
</Select>
|
||||||
|
37
web/src/lib/llm/visionLLM.ts
Normal file
37
web/src/lib/llm/visionLLM.ts
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import { VisionProvider } from "@/app/admin/configuration/llm/interfaces";
|
||||||
|
|
||||||
|
/**
 * Requests the list of vision providers from the admin API.
 *
 * @returns the parsed JSON response body.
 * @throws Error carrying the response text when the response is not OK.
 */
export async function fetchVisionProviders(): Promise<VisionProvider[]> {
  const requestInit = {
    headers: {
      "Content-Type": "application/json",
    },
  };
  const res = await fetch("/api/admin/llm/vision-providers", requestInit);

  if (res.ok) {
    return res.json();
  }

  const detail = await res.text();
  throw new Error(`Failed to fetch vision providers: ${detail}`);
}
|
||||||
|
|
||||||
|
/**
 * Asks the admin API to make the given provider and model the default for
 * vision.
 *
 * @param providerId id of the provider, placed in the request path.
 * @param visionModel model name, sent URL-encoded as the `vision_model`
 *   query parameter.
 * @throws Error carrying the response text when the response is not OK.
 */
export async function setDefaultVisionProvider(
  providerId: number,
  visionModel: string
): Promise<void> {
  const query = `vision_model=${encodeURIComponent(visionModel)}`;
  const url = `/api/admin/llm/provider/${providerId}/default-vision?${query}`;

  const res = await fetch(url, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
  });

  if (res.ok) {
    return;
  }

  throw new Error(await res.text());
}
|
Reference in New Issue
Block a user