Support image indexing customization (#4261)

* working well
* k
* ready to go
* k
* minor nits
* k
* quick fix
* k
* k

This commit is contained in:
parent 0153ff6b51
commit 5883336d5e
@@ -0,0 +1,45 @@
"""add_default_vision_provider_to_llm_provider

Revision ID: df46c75b714e
Revises: 3934b1bc7b62
Create Date: 2025-03-11 16:20:19.038945

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "df46c75b714e"
down_revision = "3934b1bc7b62"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "llm_provider",
        sa.Column(
            "is_default_vision_provider",
            sa.Boolean(),
            nullable=True,
            server_default=sa.false(),
        ),
    )
    op.add_column(
        "llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
    )
    # Add unique constraint for is_default_vision_provider
    op.create_unique_constraint(
        "uq_llm_provider_is_default_vision_provider",
        "llm_provider",
        ["is_default_vision_provider"],
    )


def downgrade() -> None:
    op.drop_constraint(
        "uq_llm_provider_is_default_vision_provider", "llm_provider", type_="unique"
    )
    op.drop_column("llm_provider", "default_vision_model")
    op.drop_column("llm_provider", "is_default_vision_provider")
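Note: the migration leans on a standard SQL behavior — a unique constraint on a nullable Boolean permits any number of NULLs but at most one TRUE — to guarantee a single default vision provider. A minimal, self-contained sketch of that invariant (hypothetical standalone table, not the app's actual models):

# Sketch of the single-default invariant (hypothetical standalone setup).
# A unique constraint on a nullable Boolean allows many NULLs but one True,
# so "clear old default, flush, set new default" is the safe update order.
from sqlalchemy import Boolean, Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class LLMProvider(Base):
    __tablename__ = "llm_provider"
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    is_default_vision_provider = Column(Boolean, unique=True, nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    a = LLMProvider(name="a", is_default_vision_provider=True)
    b = LLMProvider(name="b", is_default_vision_provider=None)  # NULLs don't collide
    session.add_all([a, b])
    session.commit()

    # Swap the default: unset the old flag and flush before setting the new one,
    # otherwise two rows would briefly both be True and violate the constraint.
    a.is_default_vision_provider = None
    session.flush()
    b.is_default_vision_provider = True
    session.commit()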
@@ -8,6 +8,9 @@ from onyx.configs.constants import AuthType
 from onyx.configs.constants import DocumentIndexType
 from onyx.configs.constants import QueryHistoryType
 from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
+from onyx.prompts.image_analysis import DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT
+from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT
+from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT

 #####
 # App Configs
@@ -646,3 +649,21 @@ DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20

 # Number of pre-provisioned tenants to maintain
 TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
+
+
+# Image summarization configuration
+IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
+    "IMAGE_SUMMARIZATION_SYSTEM_PROMPT",
+    DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
+)
+
+# The user prompt for image summarization - the image filename will be automatically prepended
+IMAGE_SUMMARIZATION_USER_PROMPT = os.environ.get(
+    "IMAGE_SUMMARIZATION_USER_PROMPT",
+    DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT,
+)
+
+IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
+    "IMAGE_ANALYSIS_SYSTEM_PROMPT",
+    DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
+)
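Note: all three prompts are now read from the environment at module import time, falling back to the defaults in onyx.prompts.image_analysis. A sketch of overriding one of them (the override text is an example):

# Sketch: the override must be in the environment before app_configs is
# imported, because the values are read once at module import time.
import os

os.environ["IMAGE_SUMMARIZATION_SYSTEM_PROMPT"] = (
    "You summarize screenshots and diagrams for a search index."  # example text
)

from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT

assert "search index" in IMAGE_SUMMARIZATION_SYSTEM_PROMPT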
@@ -30,6 +30,7 @@ class VisionEnabledConnector:
         Sets self.image_analysis_llm to the LLM instance or None if disabled.
         """
+        self.image_analysis_llm: LLM | None = None

         if get_image_extraction_and_analysis_enabled():
             try:
                 self.image_analysis_llm = get_default_llm_with_vision()
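Note: connectors built on this mixin are expected to branch on `self.image_analysis_llm` before doing image work. A rough sketch of the pattern, assuming the summarization helper shown later in this diff (the class name and method here are illustrative, not part of the commit):

# Illustrative only: a connector using the VisionEnabledConnector mixin.
# summarize_image_with_error_handling appears later in this diff; the
# keyword names follow its docstring.
class MyImageAwareConnector(VisionEnabledConnector):
    def describe_image(self, image_data: bytes, file_name: str) -> str | None:
        if self.image_analysis_llm is None:
            # Image extraction/analysis disabled, or no vision LLM configured
            return None
        return summarize_image_with_error_handling(
            llm=self.image_analysis_llm,
            image_data=image_data,
            context_name=file_name,
        )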
@@ -10,6 +10,7 @@ from langchain_core.messages import SystemMessage

 from onyx.chat.models import SectionRelevancePiece
 from onyx.configs.app_configs import BLURB_SIZE
+from onyx.configs.app_configs import IMAGE_ANALYSIS_SYSTEM_PROMPT
 from onyx.configs.constants import RETURN_SEPARATOR
 from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
 from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
@@ -31,7 +32,6 @@ from onyx.file_store.file_store import get_default_file_store
 from onyx.llm.interfaces import LLM
 from onyx.llm.utils import message_to_string
 from onyx.natural_language_processing.search_nlp_models import RerankingModel
-from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT
 from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
 from onyx.utils.logger import setup_logger
 from onyx.utils.threadpool_concurrency import FunctionCall
@@ -13,6 +13,7 @@ from onyx.db.models import SearchSettings
 from onyx.db.models import Tool as ToolModel
 from onyx.db.models import User
 from onyx.db.models import User__UserGroup
+from onyx.llm.utils import model_supports_image_input
 from onyx.server.manage.embedding.models import CloudEmbeddingProvider
 from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
 from onyx.server.manage.llm.models import FullLLMProvider
@@ -187,6 +188,17 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
     return FullLLMProvider.from_model(provider_model)


+def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
+    provider_model = db_session.scalar(
+        select(LLMProviderModel).where(
+            LLMProviderModel.is_default_vision_provider == True  # noqa: E712
+        )
+    )
+    if not provider_model:
+        return None
+    return FullLLMProvider.from_model(provider_model)
+
+
 def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
     provider_model = db_session.scalar(
         select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
@@ -246,3 +258,39 @@ def update_default_provider(provider_id: int, db_session: Session) -> None:

     new_default.is_default_provider = True
     db_session.commit()
+
+
+def update_default_vision_provider(
+    provider_id: int, vision_model: str | None, db_session: Session
+) -> None:
+    new_default = db_session.scalar(
+        select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
+    )
+    if not new_default:
+        raise ValueError(f"LLM Provider with id {provider_id} does not exist")
+
+    # Validate that the specified vision model supports image input
+    model_to_validate = vision_model or new_default.default_model_name
+    if model_to_validate:
+        if not model_supports_image_input(model_to_validate, new_default.provider):
+            raise ValueError(
+                f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
+            )
+    else:
+        raise ValueError(
+            f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'"
+        )
+
+    existing_default = db_session.scalar(
+        select(LLMProviderModel).where(
+            LLMProviderModel.is_default_vision_provider == True  # noqa: E712
+        )
+    )
+    if existing_default:
+        existing_default.is_default_vision_provider = None
+        # required to ensure that the below does not cause a unique constraint violation
+        db_session.flush()
+
+    new_default.is_default_vision_provider = True
+    new_default.default_vision_model = vision_model
+    db_session.commit()
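Note: a short sketch of driving the new db helper from a script — the provider id and model name below are example values, not part of the commit:

# Example values only: promote provider 3 and pin "gpt-4o" as its vision model.
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.llm import fetch_default_vision_provider
from onyx.db.llm import update_default_vision_provider

with get_session_with_current_tenant() as db_session:
    # Raises ValueError if the provider id is unknown or the model
    # cannot accept image input.
    update_default_vision_provider(
        provider_id=3, vision_model="gpt-4o", db_session=db_session
    )
    default = fetch_default_vision_provider(db_session)
    assert default is not None and default.default_vision_model == "gpt-4o"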
@@ -1489,6 +1489,10 @@ class LLMProvider(Base):

     # should only be set for a single provider
     is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
+    is_default_vision_provider: Mapped[bool | None] = mapped_column(
+        Boolean, unique=True
+    )
+    default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
     # EE only
     is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
     groups: Mapped[list["UserGroup"]] = relationship(
@@ -6,10 +6,10 @@ from langchain_core.messages import HumanMessage
 from langchain_core.messages import SystemMessage
 from PIL import Image

+from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
+from onyx.configs.app_configs import IMAGE_SUMMARIZATION_USER_PROMPT
 from onyx.llm.interfaces import LLM
 from onyx.llm.utils import message_to_string
-from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
-from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -62,7 +62,7 @@ def summarize_image_with_error_handling(
         image_data: The raw image bytes
         context_name: Name or title of the image for context
         system_prompt: System prompt to use for the LLM
-        user_prompt_template: Template for the user prompt, should contain {title} placeholder
+        user_prompt_template: User prompt to use (without title)

     Returns:
         The image summary text, or None if summarization failed or is disabled
@@ -70,7 +70,10 @@ def summarize_image_with_error_handling(
     if llm is None:
         return None

-    user_prompt = user_prompt_template.format(title=context_name)
+    # Prepend the image filename to the user prompt
+    user_prompt = (
+        f"The image has the file name '{context_name}'.\n{user_prompt_template}"
+    )
     return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
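Note: the net effect of the prompt change — the `{title}` placeholder is gone and the filename line is prepended instead. A standalone sketch of the new assembly (the file name is a made-up example):

# Standalone sketch of the new prompt assembly; no Onyx imports needed.
context_name = "architecture_diagram.png"  # hypothetical image file name
user_prompt_template = "Describe precisely and concisely what the image shows."

user_prompt = (
    f"The image has the file name '{context_name}'.\n{user_prompt_template}"
)
print(user_prompt)
# -> The image has the file name 'architecture_diagram.png'.
#    Describe precisely and concisely what the image shows.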
@@ -5,7 +5,9 @@ from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
 from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
 from onyx.configs.model_configs import GEN_AI_TEMPERATURE
 from onyx.db.engine import get_session_context_manager
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.llm import fetch_default_provider
+from onyx.db.llm import fetch_default_vision_provider
 from onyx.db.llm import fetch_existing_llm_providers
 from onyx.db.llm import fetch_provider
 from onyx.db.models import Persona
@@ -14,6 +16,7 @@ from onyx.llm.exceptions import GenAIDisabledException
 from onyx.llm.interfaces import LLM
 from onyx.llm.override_models import LLMOverride
 from onyx.llm.utils import model_supports_image_input
+from onyx.server.manage.llm.models import FullLLMProvider
 from onyx.utils.headers import build_llm_extra_headers
 from onyx.utils.logger import setup_logger
 from onyx.utils.long_term_log import LongTermLogger
@@ -94,40 +97,61 @@ def get_default_llm_with_vision(
     additional_headers: dict[str, str] | None = None,
     long_term_logger: LongTermLogger | None = None,
 ) -> LLM | None:
     """Get an LLM that supports image input, with the following priority:
     1. Use the designated default vision provider if it exists and supports image input
     2. Fall back to the first LLM provider that supports image input

     Returns None if no providers exist or if no provider supports images.
     """
     if DISABLE_GENERATIVE_AI:
         raise GenAIDisabledException()

-    with get_session_context_manager() as db_session:
-        llm_providers = fetch_existing_llm_providers(db_session)
-
-        if not llm_providers:
-            return None
-
-        for provider in llm_providers:
-            model_name = provider.default_model_name
-            fast_model_name = (
-                provider.fast_default_model_name or provider.default_model_name
-            )
-
-            if not model_name or not fast_model_name:
-                continue
-
-            if model_supports_image_input(model_name, provider.provider):
-                return get_llm(
-                    provider=provider.provider,
-                    model=model_name,
-                    deployment_name=provider.deployment_name,
-                    api_key=provider.api_key,
-                    api_base=provider.api_base,
-                    api_version=provider.api_version,
-                    custom_config=provider.custom_config,
-                    timeout=timeout,
-                    temperature=temperature,
-                    additional_headers=additional_headers,
-                    long_term_logger=long_term_logger,
-                )
-
-    raise ValueError("No LLM provider found that supports image input")
+    def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
+        """Helper to create an LLM if the provider supports image input."""
+        return get_llm(
+            provider=provider.provider,
+            model=model,
+            deployment_name=provider.deployment_name,
+            api_key=provider.api_key,
+            api_base=provider.api_base,
+            api_version=provider.api_version,
+            custom_config=provider.custom_config,
+            timeout=timeout,
+            temperature=temperature,
+            additional_headers=additional_headers,
+            long_term_logger=long_term_logger,
+        )
+
+    with get_session_with_current_tenant() as db_session:
+        # Try the default vision provider first
+        default_provider = fetch_default_vision_provider(db_session)
+        if (
+            default_provider
+            and default_provider.default_vision_model
+            and model_supports_image_input(
+                default_provider.default_vision_model, default_provider.provider
+            )
+        ):
+            return create_vision_llm(
+                default_provider, default_provider.default_vision_model
+            )
+
+        # Fall back to searching all providers
+        providers = fetch_existing_llm_providers(db_session)
+
+        if not providers:
+            return None
+
+        # Find the first provider that supports image input
+        for provider in providers:
+            if provider.default_vision_model and model_supports_image_input(
+                provider.default_vision_model, provider.provider
+            ):
+                return create_vision_llm(
+                    FullLLMProvider.from_model(provider), provider.default_vision_model
+                )
+
+        return None


 def get_default_llms(
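Note: callers of `get_default_llm_with_vision` should now expect `None` rather than a `ValueError` when nothing vision-capable is configured. A minimal caller sketch, assuming the function lives in onyx.llm.factory alongside get_default_llms as the surrounding hunks suggest:

# Minimal caller sketch: handle the None case instead of catching ValueError.
from onyx.llm.factory import get_default_llm_with_vision

vision_llm = get_default_llm_with_vision()
if vision_llm is None:
    # No configured provider has a vision-capable model; skip image analysis.
    print("image analysis disabled: no vision-capable LLM configured")
else:
    print("vision LLM ready")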
@@ -1,5 +1,5 @@
 # Used for creating embeddings of images for vector search
-IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
+DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
 You are an assistant for summarizing images for retrieval.
 Summarize the content of the following image and be as precise as possible.
 The summary will be embedded and used to retrieve the original image.
@@ -7,14 +7,13 @@ Therefore, write a concise summary of the image that is optimized for retrieval.
 """

 # Prompt for generating image descriptions with filename context
-IMAGE_SUMMARIZATION_USER_PROMPT = """
-The image has the file name '{title}'.
+DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT = """
 Describe precisely and concisely what the image shows.
 """


 # Used for analyzing images in response to user queries at search time
-IMAGE_ANALYSIS_SYSTEM_PROMPT = (
+DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT = (
     "You are an AI assistant specialized in describing images.\n"
     "You will receive a user question plus an image URL. Provide a concise textual answer.\n"
     "Focus on aspects of the image that are relevant to the user's question.\n"
@@ -14,6 +14,7 @@ from onyx.db.llm import fetch_existing_llm_providers_for_user
 from onyx.db.llm import fetch_provider
 from onyx.db.llm import remove_llm_provider
 from onyx.db.llm import update_default_provider
+from onyx.db.llm import update_default_vision_provider
 from onyx.db.llm import upsert_llm_provider
 from onyx.db.models import User
 from onyx.llm.factory import get_default_llms
@@ -21,11 +22,13 @@ from onyx.llm.factory import get_llm
 from onyx.llm.llm_provider_options import fetch_available_well_known_llms
 from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
 from onyx.llm.utils import litellm_exception_to_error_msg
+from onyx.llm.utils import model_supports_image_input
 from onyx.llm.utils import test_llm
 from onyx.server.manage.llm.models import FullLLMProvider
 from onyx.server.manage.llm.models import LLMProviderDescriptor
 from onyx.server.manage.llm.models import LLMProviderUpsertRequest
 from onyx.server.manage.llm.models import TestLLMRequest
+from onyx.server.manage.llm.models import VisionProviderResponse
 from onyx.utils.logger import setup_logger
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

@@ -186,6 +189,62 @@ def set_provider_as_default(
     update_default_provider(provider_id=provider_id, db_session=db_session)


+@admin_router.post("/provider/{provider_id}/default-vision")
+def set_provider_as_default_vision(
+    provider_id: int,
+    vision_model: str
+    | None = Query(None, description="The default vision model to use"),
+    _: User | None = Depends(current_admin_user),
+    db_session: Session = Depends(get_session),
+) -> None:
+    update_default_vision_provider(
+        provider_id=provider_id, vision_model=vision_model, db_session=db_session
+    )
+
+
+@admin_router.get("/vision-providers")
+def get_vision_capable_providers(
+    _: User | None = Depends(current_admin_user),
+    db_session: Session = Depends(get_session),
+) -> list[VisionProviderResponse]:
+    """Return a list of LLM providers and their models that support image input"""
+
+    providers = fetch_existing_llm_providers(db_session)
+    vision_providers = []
+
+    logger.info("Fetching vision-capable providers")
+
+    for provider in providers:
+        vision_models = []
+
+        # Check model names in priority order
+        model_names_to_check = []
+        if provider.model_names:
+            model_names_to_check = provider.model_names
+        elif provider.display_model_names:
+            model_names_to_check = provider.display_model_names
+        elif provider.default_model_name:
+            model_names_to_check = [provider.default_model_name]
+
+        # Check each model for vision capability
+        for model_name in model_names_to_check:
+            if model_supports_image_input(model_name, provider.provider):
+                vision_models.append(model_name)
+                logger.debug(f"Vision model found: {provider.provider}/{model_name}")
+
+        # Only include providers with at least one vision-capable model
+        if vision_models:
+            provider_dict = FullLLMProvider.from_model(provider).model_dump()
+            provider_dict["vision_models"] = vision_models
+            logger.info(
+                f"Vision provider: {provider.provider} with models: {vision_models}"
+            )
+            vision_providers.append(VisionProviderResponse(**provider_dict))
+
+    logger.info(f"Found {len(vision_providers)} vision-capable providers")
+    return vision_providers
+
+
 """Endpoints for all"""
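Note: a quick way to exercise the two new admin routes. The `/api/admin/llm` prefix matches the frontend helper added later in this diff; the host URL, auth cookie, provider id, and model name below are placeholders:

# Placeholders throughout: host and auth cookie depend on the deployment.
import requests

BASE = "http://localhost:8080/api/admin/llm"
COOKIES = {"fastapiusersauth": "<admin-session-token>"}

# List providers and the models that accept image input.
resp = requests.get(f"{BASE}/vision-providers", cookies=COOKIES)
resp.raise_for_status()
for provider in resp.json():
    print(provider["name"], provider["vision_models"])

# Make provider 3 the default vision provider with an example model.
resp = requests.post(
    f"{BASE}/provider/3/default-vision",
    params={"vision_model": "gpt-4o"},
    cookies=COOKIES,
)
resp.raise_for_status()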
@@ -34,6 +34,8 @@ class LLMProviderDescriptor(BaseModel):
     default_model_name: str
     fast_default_model_name: str | None
     is_default_provider: bool | None
+    is_default_vision_provider: bool | None
+    default_vision_model: str | None
     display_model_names: list[str] | None

     @classmethod
@@ -46,11 +48,10 @@ class LLMProviderDescriptor(BaseModel):
             default_model_name=llm_provider_model.default_model_name,
             fast_default_model_name=llm_provider_model.fast_default_model_name,
             is_default_provider=llm_provider_model.is_default_provider,
-            model_names=(
-                llm_provider_model.model_names
-                or fetch_models_for_provider(llm_provider_model.provider)
-                or [llm_provider_model.default_model_name]
-            ),
+            is_default_vision_provider=llm_provider_model.is_default_vision_provider,
+            default_vision_model=llm_provider_model.default_vision_model,
+            model_names=llm_provider_model.model_names
+            or fetch_models_for_provider(llm_provider_model.provider),
             display_model_names=llm_provider_model.display_model_names,
         )

@@ -68,6 +69,7 @@ class LLMProvider(BaseModel):
     groups: list[int] = Field(default_factory=list)
     display_model_names: list[str] | None = None
     deployment_name: str | None = None
+    default_vision_model: str | None = None


 class LLMProviderUpsertRequest(LLMProvider):
@@ -79,6 +81,7 @@ class LLMProviderUpsertRequest(LLMProvider):
 class FullLLMProvider(LLMProvider):
     id: int
     is_default_provider: bool | None = None
+    is_default_vision_provider: bool | None = None
     model_names: list[str]

     @classmethod
@@ -94,6 +97,8 @@ class FullLLMProvider(LLMProvider):
             default_model_name=llm_provider_model.default_model_name,
             fast_default_model_name=llm_provider_model.fast_default_model_name,
             is_default_provider=llm_provider_model.is_default_provider,
+            is_default_vision_provider=llm_provider_model.is_default_vision_provider,
+            default_vision_model=llm_provider_model.default_vision_model,
             display_model_names=llm_provider_model.display_model_names,
             model_names=(
                 llm_provider_model.model_names
@@ -104,3 +109,9 @@ class FullLLMProvider(LLMProvider):
             groups=[group.id for group in llm_provider_model.groups],
             deployment_name=llm_provider_model.deployment_name,
         )
+
+
+class VisionProviderResponse(FullLLMProvider):
+    """Response model for vision providers endpoint, including vision-specific fields."""
+
+    vision_models: list[str]
@@ -49,6 +49,8 @@ export interface LLMProvider {
   groups: number[];
   display_model_names: string[] | null;
   deployment_name: string | null;
+  default_vision_model: string | null;
+  is_default_vision_provider: boolean | null;
 }

 export interface FullLLMProvider extends LLMProvider {
@@ -58,6 +60,10 @@ export interface FullLLMProvider extends LLMProvider {
   icon?: React.FC<{ size?: number; className?: string }>;
 }

+export interface VisionProvider extends FullLLMProvider {
+  vision_models: string[];
+}
+
 export interface LLMProviderDescriptor {
   name: string;
   provider: string;
@@ -13,6 +13,9 @@ import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidE
 import { Modal } from "@/components/Modal";
 import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
 import { AnonymousUserPath } from "./AnonymousUserPath";
+import { useChatContext } from "@/components/context/ChatContext";
+import { LLMSelector } from "@/components/llm/LLMSelector";
+import { useVisionProviders } from "./hooks/useVisionProviders";

 export function Checkbox({
   label,
@@ -111,6 +114,14 @@ export function SettingsForm() {
   const { popup, setPopup } = usePopup();
   const isEnterpriseEnabled = usePaidEnterpriseFeaturesEnabled();

+  // Pass setPopup to the hook
+  const {
+    visionProviders,
+    visionLLM,
+    setVisionLLM,
+    updateDefaultVisionProvider,
+  } = useVisionProviders(setPopup);
+
   const combinedSettings = useContext(SettingsContext);

   useEffect(() => {
@@ -120,6 +131,7 @@ export function SettingsForm() {
         combinedSettings.settings.maximum_chat_retention_days?.toString() || ""
       );
     }
+    // We don't need to fetch vision providers here anymore as the hook handles it
   }, []);

   if (!settings) {
@@ -354,6 +366,49 @@ export function SettingsForm() {
             id="image-analysis-max-size"
             placeholder="Enter maximum size in MB"
           />
+          {/* Default Vision LLM Section */}
+          <div className="mt-4">
+            <Label>Default Vision LLM</Label>
+            <SubLabel>
+              Select the default LLM to use for image analysis. This model will be
+              utilized during image indexing and at query time for search results,
+              if the above settings are enabled.
+            </SubLabel>
+
+            <div className="mt-2 max-w-xs">
+              {!visionProviders || visionProviders.length === 0 ? (
+                <div className="text-sm text-gray-500">
+                  No vision providers found. Please add a vision provider.
+                </div>
+              ) : visionProviders.length > 0 ? (
+                <>
+                  <LLMSelector
+                    userSettings={false}
+                    llmProviders={visionProviders.map((provider) => ({
+                      ...provider,
+                      model_names: provider.vision_models,
+                      display_model_names: provider.vision_models,
+                    }))}
+                    currentLlm={visionLLM}
+                    onSelect={(value) => setVisionLLM(value)}
+                  />
+                  <Button
+                    onClick={() => updateDefaultVisionProvider(visionLLM)}
+                    className="mt-2"
+                    variant="default"
+                    size="sm"
+                  >
+                    Set Default Vision LLM
+                  </Button>
+                </>
+              ) : (
+                <div className="text-sm text-gray-500">
+                  No vision-capable LLMs found. Please add an LLM provider that
+                  supports image input.
+                </div>
+              )}
+            </div>
+          </div>
         </div>
       </div>
     </div>
   );
web/src/app/admin/settings/hooks/useVisionProviders.ts (new file, 123 lines)
@@ -0,0 +1,123 @@
import { useState, useEffect, useCallback } from "react";
import { VisionProvider } from "@/app/admin/configuration/llm/interfaces";
import {
  fetchVisionProviders,
  setDefaultVisionProvider,
} from "@/lib/llm/visionLLM";
import { destructureValue, structureValue } from "@/lib/llm/utils";

// Define a type for the popup setter function
type SetPopup = (popup: {
  message: string;
  type: "success" | "error" | "info";
}) => void;

// Accept the setPopup function as a parameter
export function useVisionProviders(setPopup: SetPopup) {
  const [visionProviders, setVisionProviders] = useState<VisionProvider[]>([]);
  const [visionLLM, setVisionLLM] = useState<string | null>(null);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);

  const loadVisionProviders = useCallback(async () => {
    setIsLoading(true);
    setError(null);
    try {
      const data = await fetchVisionProviders();
      setVisionProviders(data);

      // Find the default vision provider and set it
      const defaultProvider = data.find(
        (provider) => provider.is_default_vision_provider
      );

      if (defaultProvider) {
        const modelToUse =
          defaultProvider.default_vision_model ||
          defaultProvider.default_model_name;

        if (modelToUse && defaultProvider.vision_models.includes(modelToUse)) {
          setVisionLLM(
            structureValue(
              defaultProvider.name,
              defaultProvider.provider,
              modelToUse
            )
          );
        }
      }
    } catch (error) {
      console.error("Error fetching vision providers:", error);
      setError(
        error instanceof Error ? error.message : "Unknown error occurred"
      );
      setPopup({
        message: `Failed to load vision providers: ${
          error instanceof Error ? error.message : "Unknown error"
        }`,
        type: "error",
      });
    } finally {
      setIsLoading(false);
    }
  }, []);

  const updateDefaultVisionProvider = useCallback(
    async (llmValue: string | null) => {
      if (!llmValue) {
        setPopup({
          message: "Please select a valid vision model",
          type: "error",
        });
        return false;
      }

      try {
        const { name, modelName } = destructureValue(llmValue);

        // Find the provider ID
        const providerObj = visionProviders.find((p) => p.name === name);
        if (!providerObj) {
          throw new Error("Provider not found");
        }

        await setDefaultVisionProvider(providerObj.id, modelName);

        setPopup({
          message: "Default vision provider updated successfully!",
          type: "success",
        });
        setVisionLLM(llmValue);

        // Refresh the list to reflect the change
        await loadVisionProviders();
        return true;
      } catch (error: unknown) {
        console.error("Error setting default vision provider:", error);
        const errorMessage =
          error instanceof Error ? error.message : "Unknown error occurred";
        setPopup({
          message: `Failed to update default vision provider: ${errorMessage}`,
          type: "error",
        });
        return false;
      }
    },
    [visionProviders, setPopup, loadVisionProviders]
  );

  // Load providers on mount
  useEffect(() => {
    loadVisionProviders();
  }, [loadVisionProviders]);

  return {
    visionProviders,
    visionLLM,
    isLoading,
    error,
    setVisionLLM,
    updateDefaultVisionProvider,
    refreshVisionProviders: loadVisionProviders,
  };
}
@@ -12,11 +12,12 @@ export enum QueryHistoryType {

 export interface Settings {
   anonymous_user_enabled: boolean;
-  maximum_chat_retention_days: number | null;
+  anonymous_user_path?: string;
+  maximum_chat_retention_days?: number | null;
   notifications: Notification[];
   needs_reindexing: boolean;
   gpu_enabled: boolean;
-  pro_search_enabled: boolean | null;
+  pro_search_enabled?: boolean;
   application_status: ApplicationStatus;
   auto_scroll: boolean;
   temperature_override_enabled: boolean;
@@ -25,7 +26,7 @@ export interface Settings {
   // Image processing settings
   image_extraction_and_analysis_enabled?: boolean;
   search_time_image_analysis_enabled?: boolean;
-  image_analysis_max_size_mb?: number;
+  image_analysis_max_size_mb?: number | null;
 }

 export enum NotificationType {
@@ -243,6 +243,7 @@ export const AIMessage = ({
       return preprocessLaTeX(content);
     }
   }
+  // return content;

   return (
     preprocessLaTeX(content) +
@@ -103,7 +103,6 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
           </SelectItem>
         );
       }
-      return null;
     })}
   </SelectContent>
 </Select>
web/src/lib/llm/visionLLM.ts (new file, 37 lines)
@@ -0,0 +1,37 @@
import { VisionProvider } from "@/app/admin/configuration/llm/interfaces";

export async function fetchVisionProviders(): Promise<VisionProvider[]> {
  const response = await fetch("/api/admin/llm/vision-providers", {
    headers: {
      "Content-Type": "application/json",
    },
  });
  if (!response.ok) {
    throw new Error(
      `Failed to fetch vision providers: ${await response.text()}`
    );
  }
  return response.json();
}

export async function setDefaultVisionProvider(
  providerId: number,
  visionModel: string
): Promise<void> {
  const response = await fetch(
    `/api/admin/llm/provider/${providerId}/default-vision?vision_model=${encodeURIComponent(
      visionModel
    )}`,
    {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
    }
  );

  if (!response.ok) {
    const errorMsg = await response.text();
    throw new Error(errorMsg);
  }
}