From 20f2b9b2bb1204a0bf1fa21ee1cfd3562b383eb4 Mon Sep 17 00:00:00 2001 From: pablonyx Date: Wed, 5 Mar 2025 09:44:18 -0800 Subject: [PATCH] Add image support for search (#4090) * add support for image search * quick fix up * k * k * k * k * nit * quick fix for connector tests --- backend/onyx/chat/process_message.py | 1 + backend/onyx/configs/app_configs.py | 13 + .../connectors/airtable/airtable_connector.py | 1 - .../onyx/connectors/confluence/connector.py | 315 +++++++++------ backend/onyx/connectors/confluence/utils.py | 306 +++++++++++++-- backend/onyx/connectors/file/connector.py | 256 +++++++++---- .../onyx/connectors/google_drive/connector.py | 124 +++--- .../connectors/google_drive/doc_conversion.py | 217 +++++++---- backend/onyx/connectors/models.py | 3 +- .../connectors/vision_enabled_connector.py | 45 +++ backend/onyx/connectors/web/connector.py | 2 +- .../search/postprocessing/postprocessing.py | 138 +++++++ backend/onyx/db/pg_file_store.py | 25 ++ .../vespa/app_config/schemas/danswer_chunk.sd | 3 + .../document_index/vespa/chunk_retrieval.py | 3 + .../document_index/vespa/indexing_utils.py | 3 +- .../onyx/document_index/vespa_constants.py | 2 + .../onyx/file_processing/extract_file_text.py | 360 +++++++++++------- .../onyx/file_processing/file_validation.py | 46 +++ .../file_processing/image_summarization.py | 129 +++++++ backend/onyx/file_processing/image_utils.py | 70 ++++ backend/onyx/indexing/chunker.py | 249 +++++++----- backend/onyx/indexing/models.py | 1 + backend/onyx/llm/factory.py | 44 +++ backend/onyx/prompts/image_analysis.py | 22 ++ backend/onyx/seeding/load_docs.py | 7 +- backend/onyx/utils/error_handling.py | 23 ++ backend/scripts/debugging/onyx_vespa.py | 11 +- .../query_time_check/seed_dummy_docs.py | 1 + .../salesforce/test_postprocessing.py | 1 + backend/tests/unit/onyx/chat/conftest.py | 2 + .../test_quotes_processing.py | 2 + .../unit/onyx/chat/test_prune_and_merge.py | 1 + .../tests/unit/onyx/indexing/test_embedder.py | 1 + web/src/app/chat/ChatPage.tsx | 1 - web/src/components/chat/TextView.tsx | 18 +- 36 files changed, 1857 insertions(+), 589 deletions(-) create mode 100644 backend/onyx/connectors/vision_enabled_connector.py create mode 100644 backend/onyx/file_processing/file_validation.py create mode 100644 backend/onyx/file_processing/image_summarization.py create mode 100644 backend/onyx/file_processing/image_utils.py create mode 100644 backend/onyx/prompts/image_analysis.py create mode 100644 backend/onyx/utils/error_handling.py diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 2bc43e368a..7788b95d6c 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -756,6 +756,7 @@ def stream_chat_message_objects( ) # LLM prompt building, response capturing, etc. 
+ answer = Answer( prompt_builder=prompt_builder, is_connected=is_connected, diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 3cd33fe42b..ceb38ca820 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -640,3 +640,16 @@ TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true" MOCK_LLM_RESPONSE = ( os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None ) + + +# Image processing configurations +ENABLE_IMAGE_EXTRACTION = ( + os.environ.get("ENABLE_IMAGE_EXTRACTION", "true").lower() == "true" +) +ENABLE_INDEXING_TIME_IMAGE_ANALYSIS = not ( + os.environ.get("DISABLE_INDEXING_TIME_IMAGE_ANALYSIS", "false").lower() == "true" +) +ENABLE_SEARCH_TIME_IMAGE_ANALYSIS = not ( + os.environ.get("DISABLE_SEARCH_TIME_IMAGE_ANALYSIS", "false").lower() == "true" +) +IMAGE_ANALYSIS_MAX_SIZE_MB = int(os.environ.get("IMAGE_ANALYSIS_MAX_SIZE_MB", "20")) diff --git a/backend/onyx/connectors/airtable/airtable_connector.py b/backend/onyx/connectors/airtable/airtable_connector.py index bb2990d022..689c7b45b0 100644 --- a/backend/onyx/connectors/airtable/airtable_connector.py +++ b/backend/onyx/connectors/airtable/airtable_connector.py @@ -200,7 +200,6 @@ class AirtableConnector(LoadConnector): return attachment_response.content logger.error(f"Failed to refresh attachment for {filename}") - raise attachment_content = get_attachment_with_retry(url, record_id) diff --git a/backend/onyx/connectors/confluence/connector.py b/backend/onyx/connectors/confluence/connector.py index dccab042d6..3116bfb35d 100644 --- a/backend/onyx/connectors/confluence/connector.py +++ b/backend/onyx/connectors/confluence/connector.py @@ -11,13 +11,12 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource -from onyx.connectors.confluence.onyx_confluence import attachment_to_content -from onyx.connectors.confluence.onyx_confluence import ( - extract_text_from_confluence_html, -) +from onyx.connectors.confluence.onyx_confluence import extract_text_from_confluence_html from onyx.connectors.confluence.onyx_confluence import OnyxConfluence from onyx.connectors.confluence.utils import build_confluence_document_id +from onyx.connectors.confluence.utils import convert_attachment_to_content from onyx.connectors.confluence.utils import datetime_from_string +from onyx.connectors.confluence.utils import process_attachment from onyx.connectors.confluence.utils import validate_attachment_filetype from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError @@ -36,28 +35,26 @@ from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import Section from onyx.connectors.models import SlimDocument +from onyx.connectors.vision_enabled_connector import VisionEnabledConnector from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() - # Potential Improvements -# 1. Include attachments, etc -# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost - +# 1. 
Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost _COMMENT_EXPANSION_FIELDS = ["body.storage.value"] _PAGE_EXPANSION_FIELDS = [ "body.storage.value", "version", "space", "metadata.labels", + "history.lastUpdated", ] _ATTACHMENT_EXPANSION_FIELDS = [ "version", "space", "metadata.labels", ] - _RESTRICTIONS_EXPANSION_FIELDS = [ "space", "restrictions.read.restrictions.user", @@ -87,7 +84,11 @@ _FULL_EXTENSION_FILTER_STRING = "".join( class ConfluenceConnector( - LoadConnector, PollConnector, SlimConnector, CredentialsConnector + LoadConnector, + PollConnector, + SlimConnector, + CredentialsConnector, + VisionEnabledConnector, ): def __init__( self, @@ -105,13 +106,24 @@ class ConfluenceConnector( labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP, timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET, ) -> None: + self.wiki_base = wiki_base + self.is_cloud = is_cloud + self.space = space + self.page_id = page_id + self.index_recursively = index_recursively + self.cql_query = cql_query self.batch_size = batch_size self.continue_on_failure = continue_on_failure - self.is_cloud = is_cloud + self.labels_to_skip = labels_to_skip + self.timezone_offset = timezone_offset + self._confluence_client: OnyxConfluence | None = None + self._fetched_titles: set[str] = set() + + # Initialize vision LLM using the mixin + self.initialize_vision_llm() # Remove trailing slash from wiki_base if present self.wiki_base = wiki_base.rstrip("/") - """ If nothing is provided, we default to fetching all pages Only one or none of the following options should be specified so @@ -153,8 +165,6 @@ class ConfluenceConnector( "max_backoff_seconds": 60, } - self._confluence_client: OnyxConfluence | None = None - @property def confluence_client(self) -> OnyxConfluence: if self._confluence_client is None: @@ -184,7 +194,6 @@ class ConfluenceConnector( end: SecondsSinceUnixEpoch | None = None, ) -> str: page_query = self.base_cql_page_query + self.cql_label_filter - # Add time filters if start: formatted_start_time = datetime.fromtimestamp( @@ -196,7 +205,6 @@ class ConfluenceConnector( "%Y-%m-%d %H:%M" ) page_query += f" and lastmodified <= '{formatted_end_time}'" - return page_query def _construct_attachment_query(self, confluence_page_id: str) -> str: @@ -207,11 +215,10 @@ class ConfluenceConnector( def _get_comment_string_for_page_id(self, page_id: str) -> str: comment_string = "" - comment_cql = f"type=comment and container='{page_id}'" comment_cql += self.cql_label_filter - expand = ",".join(_COMMENT_EXPANSION_FIELDS) + for comment in self.confluence_client.paginated_cql_retrieval( cql=comment_cql, expand=expand, @@ -222,123 +229,177 @@ class ConfluenceConnector( confluence_object=comment, fetched_titles=set(), ) - return comment_string - def _convert_object_to_document( - self, - confluence_object: dict[str, Any], - parent_content_id: str | None = None, - ) -> Document | None: + def _convert_page_to_document(self, page: dict[str, Any]) -> Document | None: """ - Takes in a confluence object, extracts all metadata, and converts it into a document. - If its a page, it extracts the text, adds the comments for the document text. - If its an attachment, it just downloads the attachment and converts that into a document. - - parent_content_id: if the object is an attachment, specifies the content id that - the attachment is attached to + Converts a Confluence page to a Document object. + Includes the page content, comments, and attachments. 
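+        Illustrative output shape (values are placeholders, not real data):
+
+            Document(
+                id=build_confluence_document_id(...),
+                sections=[
+                    Section(text="<page text>", link=page_url),
+                    Section(text="<comments>", link=f"{page_url}#comments"),
+                    Section(
+                        text="<attachment text or image summary>",
+                        link=f"{page_url}#attachment-<id>",
+                        image_file_name="<name in PGFileStore>",
+                    ),
+                ],
+                source=DocumentSource.CONFLUENCE,
+                ...
+            )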
""" - # The url and the id are the same - object_url = build_confluence_document_id( - self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud - ) + try: + # Extract basic page information + page_id = page["id"] + page_title = page["title"] + page_url = f"{self.wiki_base}/wiki{page['_links']['webui']}" - object_text = None - # Extract text from page - if confluence_object["type"] == "page": - object_text = extract_text_from_confluence_html( - confluence_client=self.confluence_client, - confluence_object=confluence_object, - fetched_titles={confluence_object.get("title", "")}, - ) - # Add comments to text - object_text += self._get_comment_string_for_page_id(confluence_object["id"]) - elif confluence_object["type"] == "attachment": - object_text = attachment_to_content( - confluence_client=self.confluence_client, - attachment=confluence_object, - parent_content_id=parent_content_id, + # Get the page content + page_content = extract_text_from_confluence_html( + self.confluence_client, page, self._fetched_titles ) - if object_text is None: - # This only happens for attachments that are not parseable + # Create the main section for the page content + sections = [Section(text=page_content, link=page_url)] + + # Process comments if available + comment_text = self._get_comment_string_for_page_id(page_id) + if comment_text: + sections.append(Section(text=comment_text, link=f"{page_url}#comments")) + + # Process attachments + if "children" in page and "attachment" in page["children"]: + attachments = self.confluence_client.get_attachments_for_page( + page_id, expand="metadata" + ) + + for attachment in attachments.get("results", []): + # Process each attachment + result = process_attachment( + self.confluence_client, + attachment, + page_title, + self.image_analysis_llm, + ) + + if result.text: + # Create a section for the attachment text + attachment_section = Section( + text=result.text, + link=f"{page_url}#attachment-{attachment['id']}", + image_file_name=result.file_name, + ) + sections.append(attachment_section) + elif result.error: + logger.warning( + f"Error processing attachment '{attachment.get('title')}': {result.error}" + ) + + # Extract metadata + metadata = {} + if "space" in page: + metadata["space"] = page["space"].get("name", "") + + # Extract labels + labels = [] + if "metadata" in page and "labels" in page["metadata"]: + for label in page["metadata"]["labels"].get("results", []): + labels.append(label.get("name", "")) + if labels: + metadata["labels"] = labels + + # Extract owners + primary_owners = [] + if "version" in page and "by" in page["version"]: + author = page["version"]["by"] + display_name = author.get("displayName", "Unknown") + primary_owners.append(BasicExpertInfo(display_name=display_name)) + + # Create the document + return Document( + id=build_confluence_document_id(self.wiki_base, page_id, self.is_cloud), + sections=sections, + source=DocumentSource.CONFLUENCE, + semantic_identifier=page_title, + metadata=metadata, + doc_updated_at=datetime_from_string(page["version"]["when"]), + primary_owners=primary_owners if primary_owners else None, + ) + except Exception as e: + logger.error(f"Error converting page {page.get('id', 'unknown')}: {e}") + if not self.continue_on_failure: + raise return None - # Get space name - doc_metadata: dict[str, str | list[str]] = { - "Wiki Space Name": confluence_object["space"]["name"] - } - - # Get labels - label_dicts = ( - confluence_object.get("metadata", {}).get("labels", {}).get("results", []) - ) - page_labels = 
[label.get("name") for label in label_dicts if label.get("name")] - if page_labels: - doc_metadata["labels"] = page_labels - - # Get last modified and author email - version_dict = confluence_object.get("version", {}) - last_modified = ( - datetime_from_string(version_dict.get("when")) - if version_dict.get("when") - else None - ) - author_email = version_dict.get("by", {}).get("email") - - title = confluence_object.get("title", "Untitled Document") - - return Document( - id=object_url, - sections=[Section(link=object_url, text=object_text)], - source=DocumentSource.CONFLUENCE, - semantic_identifier=title, - doc_updated_at=last_modified, - primary_owners=( - [BasicExpertInfo(email=author_email)] if author_email else None - ), - metadata=doc_metadata, - ) - def _fetch_document_batches( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> GenerateDocumentsOutput: + """ + Yields batches of Documents. For each page: + - Create a Document with 1 Section for the page text/comments + - Then fetch attachments. For each attachment: + - Attempt to convert it with convert_attachment_to_content(...) + - If successful, create a new Section with the extracted text or summary. + """ doc_batch: list[Document] = [] - confluence_page_ids: list[str] = [] page_query = self._construct_page_query(start, end) logger.debug(f"page_query: {page_query}") - # Fetch pages as Documents + for page in self.confluence_client.paginated_cql_retrieval( cql=page_query, expand=",".join(_PAGE_EXPANSION_FIELDS), limit=self.batch_size, ): - logger.debug(f"_fetch_document_batches: {page['id']}") - confluence_page_ids.append(page["id"]) - doc = self._convert_object_to_document(page) - if doc is not None: - doc_batch.append(doc) - if len(doc_batch) >= self.batch_size: - yield doc_batch - doc_batch = [] + # Build doc from page + doc = self._convert_page_to_document(page) + if not doc: + continue + + # Now get attachments for that page: + attachment_query = self._construct_attachment_query(page["id"]) + # We'll use the page's XML to provide context if we summarize an image + confluence_xml = page.get("body", {}).get("storage", {}).get("value", "") - # Fetch attachments as Documents - for confluence_page_id in confluence_page_ids: - attachment_query = self._construct_attachment_query(confluence_page_id) - # TODO: maybe should add time filter as well? 
             for attachment in self.confluence_client.paginated_cql_retrieval(
                 cql=attachment_query,
                 expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
             ):
-                doc = self._convert_object_to_document(attachment, confluence_page_id)
-                if doc is not None:
-                    doc_batch.append(doc)
-                    if len(doc_batch) >= self.batch_size:
-                        yield doc_batch
-                        doc_batch = []
+                if not validate_attachment_filetype(
+                    attachment, self.image_analysis_llm
+                ):
+                    continue
+
+                # Attempt to get textual content or image summarization:
+                try:
+                    logger.info(f"Processing attachment: {attachment['title']}")
+                    response = convert_attachment_to_content(
+                        confluence_client=self.confluence_client,
+                        attachment=attachment,
+                        page_context=confluence_xml,
+                        llm=self.image_analysis_llm,
+                    )
+                    if response is None:
+                        continue
+
+                    content_text, file_storage_name = response
+
+                    object_url = build_confluence_document_id(
+                        self.wiki_base, page["_links"]["webui"], self.is_cloud
+                    )
+
+                    if content_text:
+                        doc.sections.append(
+                            Section(
+                                text=content_text,
+                                link=object_url,
+                                image_file_name=file_storage_name,
+                            )
+                        )
+                except Exception as e:
+                    logger.error(
+                        f"Failed to extract/summarize attachment {attachment['title']}",
+                        exc_info=e,
+                    )
+                    if not self.continue_on_failure:
+                        raise
+
+            doc_batch.append(doc)
+
+            if len(doc_batch) >= self.batch_size:
+                yield doc_batch
+                doc_batch = []
 
         if doc_batch:
             yield doc_batch
@@ -359,55 +420,63 @@ class ConfluenceConnector(
         end: SecondsSinceUnixEpoch | None = None,
         callback: IndexingHeartbeatInterface | None = None,
     ) -> GenerateSlimDocumentOutput:
+        """
+        Return 'slim' docs (IDs + minimal permission data).
+        Does not fetch actual text. Used primarily for incremental permission sync.
+        """
         doc_metadata_list: list[SlimDocument] = []
-
         restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
 
+        # Query pages
         page_query = self.base_cql_page_query + self.cql_label_filter
         for page in self.confluence_client.cql_paginate_all_expansions(
             cql=page_query,
             expand=restrictions_expand,
             limit=_SLIM_DOC_BATCH_SIZE,
         ):
-            # If the page has restrictions, add them to the perm_sync_data
-            # These will be used by doc_sync.py to sync permissions
            page_restrictions = page.get("restrictions")
            page_space_key = page.get("space", {}).get("key")
            page_ancestors = page.get("ancestors", [])
+
            page_perm_sync_data = {
                "restrictions": page_restrictions or {},
                "space_key": page_space_key,
-                "ancestors": page_ancestors or [],
+                "ancestors": page_ancestors,
            }

            doc_metadata_list.append(
                SlimDocument(
                    id=build_confluence_document_id(
-                        self.wiki_base,
-                        page["_links"]["webui"],
-                        self.is_cloud,
+                        self.wiki_base, page["_links"]["webui"], self.is_cloud
                    ),
                    perm_sync_data=page_perm_sync_data,
                )
            )
+
+            # Query attachments for each page
            attachment_query = self._construct_attachment_query(page["id"])
            for attachment in self.confluence_client.cql_paginate_all_expansions(
                cql=attachment_query,
                expand=restrictions_expand,
                limit=_SLIM_DOC_BATCH_SIZE,
            ):
-                if not validate_attachment_filetype(attachment):
+                # Attachments skipped here (unsupported types, or images when no
+                # vision LLM is configured) are also skipped by permission sync
+                if not validate_attachment_filetype(
+                    attachment, self.image_analysis_llm
+                ):
                    continue
-                attachment_restrictions = attachment.get("restrictions")
+
+                attachment_restrictions = attachment.get("restrictions", {})
                if not attachment_restrictions:
-                    attachment_restrictions = page_restrictions
+                    attachment_restrictions = page_restrictions or {}

                attachment_space_key = attachment.get("space", {}).get("key")
                if not attachment_space_key:
                    attachment_space_key = page_space_key

                attachment_perm_sync_data = {
-                    "restrictions": attachment_restrictions or {},
+                    "restrictions": attachment_restrictions,
                     "space_key": attachment_space_key,
                 }

@@ -421,16 +490,16 @@ class ConfluenceConnector(
                         perm_sync_data=attachment_perm_sync_data,
                     )
                 )
+
             if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
                 yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
                 doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
+            if callback and callback.should_stop():
+                raise RuntimeError(
+                    "retrieve_all_slim_documents: Stop signal detected"
+                )
             if callback:
-                if callback.should_stop():
-                    raise RuntimeError(
-                        "retrieve_all_slim_documents: Stop signal detected"
-                    )
-
                 callback.progress("retrieve_all_slim_documents", 1)

         yield doc_metadata_list
diff --git a/backend/onyx/connectors/confluence/utils.py b/backend/onyx/connectors/confluence/utils.py
index 801e24d4af..9bf1c82d0a 100644
--- a/backend/onyx/connectors/confluence/utils.py
+++ b/backend/onyx/connectors/confluence/utils.py
@@ -1,9 +1,12 @@
+import io
 import math
 import time
 from collections.abc import Callable
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
+from io import BytesIO
+from pathlib import Path
 from typing import Any
 from typing import cast
 from typing import TYPE_CHECKING
@@ -12,14 +15,28 @@ from urllib.parse import parse_qs
 from urllib.parse import quote
 from urllib.parse import urlparse
 
-import bs4
 import requests
 from pydantic import BaseModel
+from sqlalchemy.orm import Session
 
-from onyx.utils.logger import setup_logger
+from onyx.configs.app_configs import (
+    CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
+)
+from onyx.configs.constants import FileOrigin
 
 if TYPE_CHECKING:
-    pass
+    from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
+
+from onyx.db.engine import get_session_with_current_tenant
+from onyx.db.models import PGFileStore
+from onyx.db.pg_file_store import create_populate_lobj
+from onyx.db.pg_file_store import save_bytes_to_pgfilestore
+from onyx.db.pg_file_store import upsert_pgfilestore
+from onyx.file_processing.extract_file_text import extract_file_text
+from onyx.file_processing.file_validation import is_valid_image_type
+from onyx.file_processing.image_utils import store_image_and_create_section
+from onyx.llm.interfaces import LLM
+from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
 
@@ -35,15 +52,229 @@ class TokenResponse(BaseModel):
     scope: str
 
 
-def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
-    return attachment["metadata"]["mediaType"] not in [
-        "image/jpeg",
-        "image/png",
-        "image/gif",
-        "image/svg+xml",
-        "video/mp4",
-        "video/quicktime",
-    ]
+def validate_attachment_filetype(
+    attachment: dict[str, Any], llm: LLM | None = None
+) -> bool:
+    """
+    Validates whether the attachment is a supported file type.
+    If an LLM is provided, also accepts images that the LLM can process.
+    """
+    media_type = attachment.get("metadata", {}).get("mediaType", "")
+
+    if media_type.startswith("image/"):
+        return llm is not None and is_valid_image_type(media_type)
+
+    # For non-image files, check if we support the extension
+    title = attachment.get("title", "")
+    extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
+    return extension in ["pdf", "doc", "docx", "txt", "md", "rtf"]
+
+
+class AttachmentProcessingResult(BaseModel):
+    """
+    A container for results after processing a Confluence attachment.
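+
+    Field semantics ('text' may be set even when 'error' is, e.g. when text was
+    extracted but the file could not be stored):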
+ 'text' is the textual content of the attachment. + 'file_name' is the final file name used in PGFileStore to store the content. + 'error' holds an exception or string if something failed. + """ + + text: str | None + file_name: str | None + error: str | None = None + + +def _download_attachment( + confluence_client: "OnyxConfluence", attachment: dict[str, Any] +) -> bytes | None: + """ + Retrieves the raw bytes of an attachment from Confluence. Returns None on error. + """ + download_link = confluence_client.url + attachment["_links"]["download"] + resp = confluence_client._session.get(download_link) + if resp.status_code != 200: + logger.warning( + f"Failed to fetch {download_link} with status code {resp.status_code}" + ) + return None + return resp.content + + +def process_attachment( + confluence_client: "OnyxConfluence", + attachment: dict[str, Any], + page_context: str, + llm: LLM | None, +) -> AttachmentProcessingResult: + """ + Processes a Confluence attachment. If it's a document, extracts text, + or if it's an image and an LLM is available, summarizes it. Returns a structured result. + """ + try: + # Get the media type from the attachment metadata + media_type = attachment.get("metadata", {}).get("mediaType", "") + + # Validate the attachment type + if not validate_attachment_filetype(attachment, llm): + return AttachmentProcessingResult( + text=None, + file_name=None, + error=f"Unsupported file type: {media_type}", + ) + + # Download the attachment + raw_bytes = _download_attachment(confluence_client, attachment) + if raw_bytes is None: + return AttachmentProcessingResult( + text=None, file_name=None, error="Failed to download attachment" + ) + + # Process image attachments with LLM if available + if media_type.startswith("image/") and llm: + return _process_image_attachment( + confluence_client, attachment, page_context, llm, raw_bytes, media_type + ) + + # Process document attachments + try: + text = extract_file_text( + file=BytesIO(raw_bytes), + file_name=attachment["title"], + ) + + # Skip if the text is too long + if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD: + return AttachmentProcessingResult( + text=None, + file_name=None, + error=f"Attachment text too long: {len(text)} chars", + ) + + return AttachmentProcessingResult(text=text, file_name=None, error=None) + except Exception as e: + return AttachmentProcessingResult( + text=None, file_name=None, error=f"Failed to extract text: {e}" + ) + + except Exception as e: + return AttachmentProcessingResult( + text=None, file_name=None, error=f"Failed to process attachment: {e}" + ) + + +def _process_image_attachment( + confluence_client: "OnyxConfluence", + attachment: dict[str, Any], + page_context: str, + llm: LLM, + raw_bytes: bytes, + media_type: str, +) -> AttachmentProcessingResult: + """Process an image attachment by saving it and generating a summary.""" + try: + # Use the standardized image storage and section creation + with get_session_with_current_tenant() as db_session: + section, file_name = store_image_and_create_section( + db_session=db_session, + image_data=raw_bytes, + file_name=Path(attachment["id"]).name, + display_name=attachment["title"], + media_type=media_type, + llm=llm, + file_origin=FileOrigin.CONNECTOR, + ) + + return AttachmentProcessingResult( + text=section.text, file_name=file_name, error=None + ) + except Exception as e: + msg = f"Image summarization failed for {attachment['title']}: {e}" + logger.error(msg, exc_info=e) + return AttachmentProcessingResult(text=None, 
file_name=None, error=msg) + + +def _process_text_attachment( + attachment: dict[str, Any], + raw_bytes: bytes, + media_type: str, +) -> AttachmentProcessingResult: + """Process a text-based attachment by extracting its content.""" + try: + extracted_text = extract_file_text( + io.BytesIO(raw_bytes), + file_name=attachment["title"], + break_on_unprocessable=False, + ) + except Exception as e: + msg = f"Failed to extract text for '{attachment['title']}': {e}" + logger.error(msg, exc_info=e) + return AttachmentProcessingResult(text=None, file_name=None, error=msg) + + # Check length constraints + if extracted_text is None or len(extracted_text) == 0: + msg = f"No text extracted for {attachment['title']}" + logger.warning(msg) + return AttachmentProcessingResult(text=None, file_name=None, error=msg) + + if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD: + msg = ( + f"Skipping attachment {attachment['title']} due to char count " + f"({len(extracted_text)} > {CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD})" + ) + logger.warning(msg) + return AttachmentProcessingResult(text=None, file_name=None, error=msg) + + # Save the attachment + try: + with get_session_with_current_tenant() as db_session: + saved_record = save_bytes_to_pgfilestore( + db_session=db_session, + raw_bytes=raw_bytes, + media_type=media_type, + identifier=attachment["id"], + display_name=attachment["title"], + ) + except Exception as e: + msg = f"Failed to save attachment '{attachment['title']}' to PG: {e}" + logger.error(msg, exc_info=e) + return AttachmentProcessingResult( + text=extracted_text, file_name=None, error=msg + ) + + return AttachmentProcessingResult( + text=extracted_text, file_name=saved_record.file_name, error=None + ) + + +def convert_attachment_to_content( + confluence_client: "OnyxConfluence", + attachment: dict[str, Any], + page_context: str, + llm: LLM | None, +) -> tuple[str | None, str | None] | None: + """ + Facade function which: + 1. Validates attachment type + 2. Extracts or summarizes content + 3. 
Returns (content_text, stored_file_name) or None if we should skip it + """ + media_type = attachment["metadata"]["mediaType"] + # Quick check for unsupported types: + if media_type.startswith("video/") or media_type == "application/gliffy+json": + logger.warning( + f"Skipping unsupported attachment type: '{media_type}' for {attachment['title']}" + ) + return None + + result = process_attachment(confluence_client, attachment, page_context, llm) + if result.error is not None: + logger.warning( + f"Attachment {attachment['title']} encountered error: {result.error}" + ) + return None + + # Return the text and the file name + return result.text, result.file_name def build_confluence_document_id( @@ -64,23 +295,6 @@ def build_confluence_document_id( return f"{base_url}{content_url}" -def _extract_referenced_attachment_names(page_text: str) -> list[str]: - """Parse a Confluence html page to generate a list of current - attachments in use - - Args: - text (str): The page content - - Returns: - list[str]: List of filenames currently in use by the page text - """ - referenced_attachment_filenames = [] - soup = bs4.BeautifulSoup(page_text, "html.parser") - for attachment in soup.findAll("ri:attachment"): - referenced_attachment_filenames.append(attachment.attrs["ri:filename"]) - return referenced_attachment_filenames - - def datetime_from_string(datetime_string: str) -> datetime: datetime_object = datetime.fromisoformat(datetime_string) @@ -252,3 +466,37 @@ def update_param_in_path(path: str, param: str, value: str) -> str: + "?" + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items()) ) + + +def attachment_to_file_record( + confluence_client: "OnyxConfluence", + attachment: dict[str, Any], + db_session: Session, +) -> tuple[PGFileStore, bytes]: + """Save an attachment to the file store and return the file record.""" + download_link = _attachment_to_download_link(confluence_client, attachment) + image_data = confluence_client.get( + download_link, absolute=True, not_json_response=True + ) + + # Save image to file store + file_name = f"confluence_attachment_{attachment['id']}" + lobj_oid = create_populate_lobj(BytesIO(image_data), db_session) + pgfilestore = upsert_pgfilestore( + file_name=file_name, + display_name=attachment["title"], + file_origin=FileOrigin.OTHER, + file_type=attachment["metadata"]["mediaType"], + lobj_oid=lobj_oid, + db_session=db_session, + commit=True, + ) + + return pgfilestore, image_data + + +def _attachment_to_download_link( + confluence_client: "OnyxConfluence", attachment: dict[str, Any] +) -> str: + """Extracts the download link to images.""" + return confluence_client.url + attachment["_links"]["download"] diff --git a/backend/onyx/connectors/file/connector.py b/backend/onyx/connectors/file/connector.py index 7514950566..83c0426f4e 100644 --- a/backend/onyx/connectors/file/connector.py +++ b/backend/onyx/connectors/file/connector.py @@ -10,22 +10,23 @@ from sqlalchemy.orm import Session from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource +from onyx.configs.constants import FileOrigin from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import Document from onyx.connectors.models import Section +from onyx.connectors.vision_enabled_connector import VisionEnabledConnector from 
onyx.db.engine import get_session_with_current_tenant -from onyx.file_processing.extract_file_text import detect_encoding -from onyx.file_processing.extract_file_text import extract_file_text +from onyx.db.pg_file_store import get_pgfilestore_by_file_name +from onyx.file_processing.extract_file_text import extract_text_and_images from onyx.file_processing.extract_file_text import get_file_ext -from onyx.file_processing.extract_file_text import is_text_file_extension from onyx.file_processing.extract_file_text import is_valid_file_ext from onyx.file_processing.extract_file_text import load_files_from_zip -from onyx.file_processing.extract_file_text import read_pdf_file -from onyx.file_processing.extract_file_text import read_text_file +from onyx.file_processing.image_utils import store_image_and_create_section from onyx.file_store.file_store import get_default_file_store +from onyx.llm.interfaces import LLM from onyx.utils.logger import setup_logger logger = setup_logger() @@ -35,81 +36,115 @@ def _read_files_and_metadata( file_name: str, db_session: Session, ) -> Iterator[tuple[str, IO, dict[str, Any]]]: - """Reads the file into IO, in the case of a zip file, yields each individual - file contained within, also includes the metadata dict if packaged in the zip""" + """ + Reads the file from Postgres. If the file is a .zip, yields subfiles. + """ extension = get_file_ext(file_name) metadata: dict[str, Any] = {} directory_path = os.path.dirname(file_name) + # Read file from Postgres store file_content = get_default_file_store(db_session).read_file(file_name, mode="b") + # If it's a zip, expand it if extension == ".zip": - for file_info, file, metadata in load_files_from_zip( + for file_info, subfile, metadata in load_files_from_zip( file_content, ignore_dirs=True ): - yield os.path.join(directory_path, file_info.filename), file, metadata + yield os.path.join(directory_path, file_info.filename), subfile, metadata elif is_valid_file_ext(extension): yield file_name, file_content, metadata else: logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") +def _create_image_section( + llm: LLM | None, + image_data: bytes, + db_session: Session, + parent_file_name: str, + display_name: str, + idx: int = 0, +) -> tuple[Section, str | None]: + """ + Create a Section object for a single image and store the image in PGFileStore. + If summarization is enabled and we have an LLM, summarize the image. + + Returns: + tuple: (Section object, file_name in PGFileStore or None if storage failed) + """ + # Create a unique file name for the embedded image + file_name = f"{parent_file_name}_embedded_{idx}" + + # Use the standardized utility to store the image and create a section + return store_image_and_create_section( + db_session=db_session, + image_data=image_data, + file_name=file_name, + display_name=display_name, + llm=llm, + file_origin=FileOrigin.OTHER, + ) + + def _process_file( file_name: str, file: IO[Any], - metadata: dict[str, Any] | None = None, - pdf_pass: str | None = None, + metadata: dict[str, Any] | None, + pdf_pass: str | None, + db_session: Session, + llm: LLM | None, ) -> list[Document]: + """ + Processes a single file, returning a list of Documents (typically one). + Also handles embedded images if 'EMBEDDED_IMAGE_EXTRACTION_ENABLED' is true. 
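+    (Note: the flag added in app_configs by this patch is ENABLE_IMAGE_EXTRACTION;
+    'EMBEDDED_IMAGE_EXTRACTION_ENABLED' above appears to be an older name.)
+
+    Example metadata dict (illustrative; every key is optional):
+
+        {
+            "document_id": "abc-123",
+            "link": "https://example.com/original-doc",
+            "primary_owners": ["Alice Smith"],
+            "time_updated": "2025-03-05T09:44:18+00:00",
+        }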
+ """ extension = get_file_ext(file_name) - if not is_valid_file_ext(extension): - logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") + + # Fetch the DB record so we know the ID for internal URL + pg_record = get_pgfilestore_by_file_name(file_name=file_name, db_session=db_session) + if not pg_record: + logger.warning(f"No file record found for '{file_name}' in PG; skipping.") return [] - file_metadata: dict[str, Any] = {} - - if is_text_file_extension(file_name): - encoding = detect_encoding(file) - file_content_raw, file_metadata = read_text_file( - file, encoding=encoding, ignore_onyx_metadata=False + if not is_valid_file_ext(extension): + logger.warning( + f"Skipping file '{file_name}' with unrecognized extension '{extension}'" ) + return [] - # Using the PDF reader function directly to pass in password cleanly - elif extension == ".pdf" and pdf_pass is not None: - file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass) + # Prepare doc metadata + if metadata is None: + metadata = {} + file_display_name = metadata.get("file_display_name") or os.path.basename(file_name) - else: - file_content_raw = extract_file_text( - file=file, - file_name=file_name, - break_on_unprocessable=True, - ) - - all_metadata = {**metadata, **file_metadata} if metadata else file_metadata - - # add a prefix to avoid conflicts with other connectors - doc_id = f"FILE_CONNECTOR__{file_name}" - if metadata: - doc_id = metadata.get("document_id") or doc_id - - # If this is set, we will show this in the UI as the "name" of the file - file_display_name = all_metadata.get("file_display_name") or os.path.basename( - file_name - ) - title = ( - all_metadata["title"] or "" if "title" in all_metadata else file_display_name - ) - - time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc)) + # Timestamps + current_datetime = datetime.now(timezone.utc) + time_updated = metadata.get("time_updated", current_datetime) if isinstance(time_updated, str): time_updated = time_str_to_utc(time_updated) - dt_str = all_metadata.get("doc_updated_at") + dt_str = metadata.get("doc_updated_at") final_time_updated = time_str_to_utc(dt_str) if dt_str else time_updated - # Metadata tags separate from the Onyx specific fields + # Collect owners + p_owner_names = metadata.get("primary_owners") + s_owner_names = metadata.get("secondary_owners") + p_owners = ( + [BasicExpertInfo(display_name=name) for name in p_owner_names] + if p_owner_names + else None + ) + s_owners = ( + [BasicExpertInfo(display_name=name) for name in s_owner_names] + if s_owner_names + else None + ) + + # Additional tags we store as doc metadata metadata_tags = { k: v - for k, v in all_metadata.items() + for k, v in metadata.items() if k not in [ "document_id", @@ -122,77 +157,142 @@ def _process_file( "file_display_name", "title", "connector_type", + "pdf_password", ] } - source_type_str = all_metadata.get("connector_type") - source_type = DocumentSource(source_type_str) if source_type_str else None - - p_owner_names = all_metadata.get("primary_owners") - s_owner_names = all_metadata.get("secondary_owners") - p_owners = ( - [BasicExpertInfo(display_name=name) for name in p_owner_names] - if p_owner_names - else None - ) - s_owners = ( - [BasicExpertInfo(display_name=name) for name in s_owner_names] - if s_owner_names - else None + source_type_str = metadata.get("connector_type") + source_type = ( + DocumentSource(source_type_str) if source_type_str else DocumentSource.FILE ) + doc_id = metadata.get("document_id") 
or f"FILE_CONNECTOR__{file_name}" + title = metadata.get("title") or file_display_name + + # 1) If the file itself is an image, handle that scenario quickly + IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"} + if extension in IMAGE_EXTENSIONS: + # Summarize or produce empty doc + image_data = file.read() + image_section, _ = _create_image_section( + llm, image_data, db_session, pg_record.file_name, title + ) + return [ + Document( + id=doc_id, + sections=[image_section], + source=source_type, + semantic_identifier=file_display_name, + title=title, + doc_updated_at=final_time_updated, + primary_owners=p_owners, + secondary_owners=s_owners, + metadata=metadata_tags, + ) + ] + + # 2) Otherwise: text-based approach. Possibly with embedded images if enabled. + # (For example .docx with inline images). + file.seek(0) + text_content = "" + embedded_images: list[tuple[bytes, str]] = [] + + text_content, embedded_images = extract_text_and_images( + file=file, + file_name=file_name, + pdf_pass=pdf_pass, + ) + + # Build sections: first the text as a single Section + sections = [] + link_in_meta = metadata.get("link") + if text_content.strip(): + sections.append(Section(link=link_in_meta, text=text_content.strip())) + + # Then any extracted images from docx, etc. + for idx, (img_data, img_name) in enumerate(embedded_images, start=1): + # Store each embedded image as a separate file in PGFileStore + # and create a section with the image summary + image_section, _ = _create_image_section( + llm, + img_data, + db_session, + pg_record.file_name, + f"{title} - image {idx}", + idx, + ) + sections.append(image_section) return [ Document( id=doc_id, - sections=[ - Section(link=all_metadata.get("link"), text=file_content_raw.strip()) - ], - source=source_type or DocumentSource.FILE, + sections=sections, + source=source_type, semantic_identifier=file_display_name, title=title, doc_updated_at=final_time_updated, primary_owners=p_owners, secondary_owners=s_owners, - # currently metadata just houses tags, other stuff like owners / updated at have dedicated fields metadata=metadata_tags, ) ] -class LocalFileConnector(LoadConnector): +class LocalFileConnector(LoadConnector, VisionEnabledConnector): + """ + Connector that reads files from Postgres and yields Documents, including + optional embedded image extraction. + """ + def __init__( self, file_locations: list[Path | str], batch_size: int = INDEX_BATCH_SIZE, ) -> None: - self.file_locations = [Path(file_location) for file_location in file_locations] + self.file_locations = [str(loc) for loc in file_locations] self.batch_size = batch_size self.pdf_pass: str | None = None + # Initialize vision LLM using the mixin + self.initialize_vision_llm() + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.pdf_pass = credentials.get("pdf_password") + return None def load_from_state(self) -> GenerateDocumentsOutput: + """ + Iterates over each file path, fetches from Postgres, tries to parse text + or images, and yields Document batches. 
+ """ documents: list[Document] = [] with get_session_with_current_tenant() as db_session: for file_path in self.file_locations: current_datetime = datetime.now(timezone.utc) - files = _read_files_and_metadata( - file_name=str(file_path), db_session=db_session + + files_iter = _read_files_and_metadata( + file_name=file_path, + db_session=db_session, ) - for file_name, file, metadata in files: + for actual_file_name, file, metadata in files_iter: metadata["time_updated"] = metadata.get( "time_updated", current_datetime ) - documents.extend( - _process_file(file_name, file, metadata, self.pdf_pass) + new_docs = _process_file( + file_name=actual_file_name, + file=file, + metadata=metadata, + pdf_pass=self.pdf_pass, + db_session=db_session, + llm=self.image_analysis_llm, ) + documents.extend(new_docs) if len(documents) >= self.batch_size: yield documents + documents = [] if documents: @@ -201,7 +301,7 @@ class LocalFileConnector(LoadConnector): if __name__ == "__main__": connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]]) - connector.load_credentials({"pdf_password": os.environ["PDF_PASSWORD"]}) - - document_batches = connector.load_from_state() - print(next(document_batches)) + connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")}) + doc_batches = connector.load_from_state() + for batch in doc_batches: + print("BATCH:", batch) diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py index 542190b998..82d445a9e7 100644 --- a/backend/onyx/connectors/google_drive/connector.py +++ b/backend/onyx/connectors/google_drive/connector.py @@ -4,14 +4,12 @@ from concurrent.futures import as_completed from concurrent.futures import ThreadPoolExecutor from functools import partial from typing import Any -from typing import cast from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore from googleapiclient.errors import HttpError # type: ignore from onyx.configs.app_configs import INDEX_BATCH_SIZE -from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError @@ -36,7 +34,6 @@ from onyx.connectors.google_utils.shared_constants import ( ) from onyx.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS -from onyx.connectors.google_utils.shared_constants import SCOPE_DOC_URL from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE from onyx.connectors.google_utils.shared_constants import USER_FIELDS from onyx.connectors.interfaces import GenerateDocumentsOutput @@ -46,7 +43,9 @@ from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector from onyx.connectors.models import ConnectorMissingCredentialError +from onyx.connectors.vision_enabled_connector import VisionEnabledConnector from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface +from onyx.llm.interfaces import LLM from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import retry_builder @@ -66,7 +65,10 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]: def _convert_single_file( - creds: Any, 
primary_admin_email: str, file: dict[str, Any] + creds: Any, + primary_admin_email: str, + file: dict[str, Any], + image_analysis_llm: LLM | None, ) -> Any: user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email user_drive_service = get_drive_service(creds, user_email=user_email) @@ -75,11 +77,14 @@ def _convert_single_file( file=file, drive_service=user_drive_service, docs_service=docs_service, + image_analysis_llm=image_analysis_llm, # pass the LLM so doc_conversion can summarize images ) def _process_files_batch( - files: list[GoogleDriveFileType], convert_func: Callable, batch_size: int + files: list[GoogleDriveFileType], + convert_func: Callable[[GoogleDriveFileType], Any], + batch_size: int, ) -> GenerateDocumentsOutput: doc_batch = [] with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor: @@ -111,7 +116,9 @@ def _clean_requested_drive_ids( return valid_requested_drive_ids, filtered_folder_ids -class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): +class GoogleDriveConnector( + LoadConnector, PollConnector, SlimConnector, VisionEnabledConnector +): def __init__( self, include_shared_drives: bool = False, @@ -129,23 +136,23 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): continue_on_failure: bool | None = None, ) -> None: # Check for old input parameters - if ( - folder_paths is not None - or include_shared is not None - or follow_shortcuts is not None - or only_org_public is not None - or continue_on_failure is not None - ): - logger.exception( - "Google Drive connector received old input parameters. " - "Please visit the docs for help with the new setup: " - f"{SCOPE_DOC_URL}" + if folder_paths is not None: + logger.warning( + "The 'folder_paths' parameter is deprecated. Use 'shared_folder_urls' instead." ) - raise ConnectorValidationError( - "Google Drive connector received old input parameters. " - "Please visit the docs for help with the new setup: " - f"{SCOPE_DOC_URL}" + if include_shared is not None: + logger.warning( + "The 'include_shared' parameter is deprecated. Use 'include_files_shared_with_me' instead." 
) + if follow_shortcuts is not None: + logger.warning("The 'follow_shortcuts' parameter is deprecated.") + if only_org_public is not None: + logger.warning("The 'only_org_public' parameter is deprecated.") + if continue_on_failure is not None: + logger.warning("The 'continue_on_failure' parameter is deprecated.") + + # Initialize vision LLM using the mixin + self.initialize_vision_llm() if ( not include_shared_drives @@ -237,6 +244,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): credentials=credentials, source=DocumentSource.GOOGLE_DRIVE, ) + return new_creds_dict def _update_traversed_parent_ids(self, folder_id: str) -> None: @@ -523,37 +531,53 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): end: SecondsSinceUnixEpoch | None = None, ) -> GenerateDocumentsOutput: # Create a larger process pool for file conversion - convert_func = partial( - _convert_single_file, self.creds, self.primary_admin_email - ) - - # Process files in larger batches - LARGE_BATCH_SIZE = self.batch_size * 4 - files_to_process = [] - # Gather the files into batches to be processed in parallel - for file in self._fetch_drive_items(is_slim=False, start=start, end=end): - if ( - file.get("size") - and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES - ): - logger.warning( - f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes" - ) - continue - - files_to_process.append(file) - if len(files_to_process) >= LARGE_BATCH_SIZE: - yield from _process_files_batch( - files_to_process, convert_func, self.batch_size - ) - files_to_process = [] - - # Process any remaining files - if files_to_process: - yield from _process_files_batch( - files_to_process, convert_func, self.batch_size + with ThreadPoolExecutor(max_workers=8) as executor: + # Prepare a partial function with the credentials and admin email + convert_func = partial( + _convert_single_file, + self.creds, + self.primary_admin_email, + image_analysis_llm=self.image_analysis_llm, # Use the mixin's LLM ) + # Fetch files in batches + files_batch: list[GoogleDriveFileType] = [] + for file in self._fetch_drive_items(is_slim=False, start=start, end=end): + files_batch.append(file) + + if len(files_batch) >= self.batch_size: + # Process the batch + futures = [ + executor.submit(convert_func, file) for file in files_batch + ] + documents = [] + for future in as_completed(futures): + try: + doc = future.result() + if doc is not None: + documents.append(doc) + except Exception as e: + logger.error(f"Error converting file: {e}") + + if documents: + yield documents + files_batch = [] + + # Process any remaining files + if files_batch: + futures = [executor.submit(convert_func, file) for file in files_batch] + documents = [] + for future in as_completed(futures): + try: + doc = future.result() + if doc is not None: + documents.append(doc) + except Exception as e: + logger.error(f"Error converting file: {e}") + + if documents: + yield documents + def load_from_state(self) -> GenerateDocumentsOutput: try: yield from self._extract_docs_from_google_drive() diff --git a/backend/onyx/connectors/google_drive/doc_conversion.py b/backend/onyx/connectors/google_drive/doc_conversion.py index a7cbb709df..6a447aa4fd 100644 --- a/backend/onyx/connectors/google_drive/doc_conversion.py +++ b/backend/onyx/connectors/google_drive/doc_conversion.py @@ -9,7 +9,7 @@ from googleapiclient.errors import HttpError # type: ignore from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from 
onyx.configs.constants import DocumentSource -from onyx.configs.constants import IGNORE_FOR_QA +from onyx.configs.constants import FileOrigin from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE from onyx.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT @@ -21,32 +21,88 @@ from onyx.connectors.google_utils.resources import GoogleDriveService from onyx.connectors.models import Document from onyx.connectors.models import Section from onyx.connectors.models import SlimDocument -from onyx.file_processing.extract_file_text import docx_to_text +from onyx.db.engine import get_session_with_current_tenant +from onyx.file_processing.extract_file_text import docx_to_text_and_images from onyx.file_processing.extract_file_text import pptx_to_text from onyx.file_processing.extract_file_text import read_pdf_file +from onyx.file_processing.file_validation import is_valid_image_type +from onyx.file_processing.image_summarization import summarize_image_with_error_handling +from onyx.file_processing.image_utils import store_image_and_create_section from onyx.file_processing.unstructured import get_unstructured_api_key from onyx.file_processing.unstructured import unstructured_to_text +from onyx.llm.interfaces import LLM from onyx.utils.logger import setup_logger logger = setup_logger() -# these errors don't represent a failure in the connector, but simply files -# that can't / shouldn't be indexed -ERRORS_TO_CONTINUE_ON = [ - "cannotExportFile", - "exportSizeLimitExceeded", - "cannotDownloadFile", -] +def _summarize_drive_image( + image_data: bytes, image_name: str, image_analysis_llm: LLM | None +) -> str: + """ + Summarize the given image using the provided LLM. + """ + if not image_analysis_llm: + return "" + + return ( + summarize_image_with_error_handling( + llm=image_analysis_llm, + image_data=image_data, + context_name=image_name, + ) + or "" + ) + + +def is_gdrive_image_mime_type(mime_type: str) -> bool: + """ + Return True if the mime_type is a common image type in GDrive. + (e.g. 'image/png', 'image/jpeg') + """ + return is_valid_image_type(mime_type) def _extract_sections_basic( - file: dict[str, str], service: GoogleDriveService + file: dict[str, str], + service: GoogleDriveService, + image_analysis_llm: LLM | None = None, ) -> list[Section]: + """ + Extends the existing logic to handle either a docx with embedded images + or standalone images (PNG, JPG, etc). 
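+
+    For a standalone image, the result is a single Section whose text is the
+    LLM summary (or empty if no vision LLM is configured) and whose
+    image_file_name points at the stored copy, roughly:
+
+        [Section(link=link, text="<summary or ''>", image_file_name="<stored name>")]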
+ """ mime_type = file["mimeType"] link = file["webViewLink"] + file_name = file.get("name", file["id"]) supported_file_types = set(item.value for item in GDriveMimeType) + # 1) If the file is an image, retrieve the raw bytes, optionally summarize + if is_gdrive_image_mime_type(mime_type): + try: + response = service.files().get_media(fileId=file["id"]).execute() + + with get_session_with_current_tenant() as db_session: + section, _ = store_image_and_create_section( + db_session=db_session, + image_data=response, + file_name=file["id"], + display_name=file_name, + media_type=mime_type, + llm=image_analysis_llm, + file_origin=FileOrigin.CONNECTOR, + ) + return [section] + except Exception as e: + logger.warning(f"Failed to fetch or summarize image: {e}") + return [ + Section( + link=link, + text="", + image_file_name=link, + ) + ] + if mime_type not in supported_file_types: # Unsupported file types can still have a title, finding this way is still useful return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)] @@ -185,45 +241,63 @@ def _extract_sections_basic( GDriveMimeType.PLAIN_TEXT.value, GDriveMimeType.MARKDOWN.value, ]: - return [ - Section( - link=link, - text=service.files() - .get_media(fileId=file["id"]) - .execute() - .decode("utf-8"), - ) - ] + text_data = ( + service.files().get_media(fileId=file["id"]).execute().decode("utf-8") + ) + return [Section(link=link, text=text_data)] + # --------------------------- # Word, PowerPoint, PDF files - if mime_type in [ + elif mime_type in [ GDriveMimeType.WORD_DOC.value, GDriveMimeType.POWERPOINT.value, GDriveMimeType.PDF.value, ]: - response = service.files().get_media(fileId=file["id"]).execute() + response_bytes = service.files().get_media(fileId=file["id"]).execute() + + # Optionally use Unstructured if get_unstructured_api_key(): - return [ - Section( - link=link, - text=unstructured_to_text( - file=io.BytesIO(response), - file_name=file.get("name", file["id"]), - ), - ) - ] + text = unstructured_to_text( + file=io.BytesIO(response_bytes), + file_name=file_name, + ) + return [Section(link=link, text=text)] if mime_type == GDriveMimeType.WORD_DOC.value: - return [ - Section(link=link, text=docx_to_text(file=io.BytesIO(response))) - ] + # Use docx_to_text_and_images to get text plus embedded images + text, embedded_images = docx_to_text_and_images( + file=io.BytesIO(response_bytes), + ) + sections = [] + if text.strip(): + sections.append(Section(link=link, text=text.strip())) + + # Process each embedded image using the standardized function + with get_session_with_current_tenant() as db_session: + for idx, (img_data, img_name) in enumerate( + embedded_images, start=1 + ): + # Create a unique identifier for the embedded image + embedded_id = f"{file['id']}_embedded_{idx}" + + section, _ = store_image_and_create_section( + db_session=db_session, + image_data=img_data, + file_name=embedded_id, + display_name=img_name or f"{file_name} - image {idx}", + llm=image_analysis_llm, + file_origin=FileOrigin.CONNECTOR, + ) + sections.append(section) + return sections + elif mime_type == GDriveMimeType.PDF.value: - text, _ = read_pdf_file(file=io.BytesIO(response)) + text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_bytes)) return [Section(link=link, text=text)] + elif mime_type == GDriveMimeType.POWERPOINT.value: - return [ - Section(link=link, text=pptx_to_text(file=io.BytesIO(response))) - ] + text_data = pptx_to_text(io.BytesIO(response_bytes)) + return [Section(link=link, text=text_data)] # Catch-all case, should not happen 
since there should be specific handling
     # for each of the supported file types
@@ -231,7 +305,8 @@
             logger.error(error_message)
             raise ValueError(error_message)
 
-    except Exception:
+    except Exception as e:
+        logger.exception(f"Error extracting sections from file: {e}")
         return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
 
 
@@ -239,74 +314,62 @@
 def convert_drive_item_to_document(
     file: GoogleDriveFileType,
     drive_service: GoogleDriveService,
     docs_service: GoogleDocsService,
+    image_analysis_llm: LLM | None,
 ) -> Document | None:
+    """
+    Convert a Google Drive file into a Document object. The optional
+    `image_analysis_llm` is passed through to `_extract_sections_basic`
+    for image summarization.
+    """
     try:
-        # Skip files that are shortcuts
-        if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
-            logger.info("Ignoring Drive Shortcut Filetype")
-            return None
-        # Skip files that are folders
-        if file.get("mimeType") == DRIVE_FOLDER_TYPE:
-            logger.info("Ignoring Drive Folder Filetype")
+        # skip shortcuts or folders
+        if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
+            logger.info("Skipping shortcut/folder.")
             return None
 
+        # If it's a Google Doc, use the advanced structure-aware parsing
         sections: list[Section] = []
-
-        # Special handling for Google Docs to preserve structure, link
-        # to headers
         if file.get("mimeType") == GDriveMimeType.DOC.value:
             try:
+                # get_document_sections preserves structure and header links
                 sections = get_document_sections(docs_service, file["id"])
             except Exception as e:
                 logger.warning(
-                    f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
-                    " Falling back to basic extraction."
+                    f"Failed to pull Google Doc sections from '{file['name']}': {e}. "
+                    "Falling back to basic extraction."
                 )
-        # NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
+
+        # If it's not a Google Doc, or the advanced parse failed, fall back
+        # to basic extraction
        if not sections:
-            try:
-                # For all other file types just extract the text
-                sections = _extract_sections_basic(file, drive_service)
+            sections = _extract_sections_basic(file, drive_service, image_analysis_llm)
 
-            except HttpError as e:
-                reason = e.error_details[0]["reason"] if e.error_details else e.reason
-                message = e.error_details[0]["message"] if e.error_details else e.reason
-                if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
-                    logger.warning(
-                        f"Could not export file '{file['name']}' due to '{message}', skipping..."
- ) - return None - - raise if not sections: return None + doc_id = file["webViewLink"] + updated_time = datetime.fromisoformat(file["modifiedTime"]).astimezone( + timezone.utc + ) + return Document( - id=file["webViewLink"], + id=doc_id, sections=sections, source=DocumentSource.GOOGLE_DRIVE, semantic_identifier=file["name"], - doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone( - timezone.utc - ), - metadata={} - if any(section.text for section in sections) - else {IGNORE_FOR_QA: "True"}, + doc_updated_at=updated_time, + metadata={}, # or any metadata from 'file' additional_info=file.get("id"), ) - except Exception as e: - if not CONTINUE_ON_CONNECTOR_FAILURE: - raise e - logger.exception("Ran into exception when pulling a file from Google Drive") + except Exception as e: + logger.exception(f"Error converting file '{file.get('name')}' to Document: {e}") + if not CONTINUE_ON_CONNECTOR_FAILURE: + raise return None def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None: - # Skip files that are folders or shortcuts if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]: return None - return SlimDocument( id=file["webViewLink"], perm_sync_data={ diff --git a/backend/onyx/connectors/models.py b/backend/onyx/connectors/models.py index 1fd53f0ace..6211970f71 100644 --- a/backend/onyx/connectors/models.py +++ b/backend/onyx/connectors/models.py @@ -28,7 +28,8 @@ class ConnectorMissingCredentialError(PermissionError): class Section(BaseModel): text: str - link: str | None + link: str | None = None + image_file_name: str | None = None class BasicExpertInfo(BaseModel): diff --git a/backend/onyx/connectors/vision_enabled_connector.py b/backend/onyx/connectors/vision_enabled_connector.py new file mode 100644 index 0000000000..385c703568 --- /dev/null +++ b/backend/onyx/connectors/vision_enabled_connector.py @@ -0,0 +1,45 @@ +""" +Mixin for connectors that need vision capabilities. +""" +from onyx.configs.app_configs import ENABLE_INDEXING_TIME_IMAGE_ANALYSIS +from onyx.llm.factory import get_default_llm_with_vision +from onyx.llm.interfaces import LLM +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +class VisionEnabledConnector: + """ + Mixin for connectors that need vision capabilities. + + This mixin provides a standard way to initialize a vision-capable LLM + for image analysis during indexing. + + Usage: + class MyConnector(LoadConnector, VisionEnabledConnector): + def __init__(self, ...): + super().__init__(...) + self.initialize_vision_llm() + """ + + def initialize_vision_llm(self) -> None: + """ + Initialize a vision-capable LLM if enabled by configuration. + + Sets self.image_analysis_llm to the LLM instance or None if disabled. + """ + self.image_analysis_llm: LLM | None = None + if ENABLE_INDEXING_TIME_IMAGE_ANALYSIS: + try: + self.image_analysis_llm = get_default_llm_with_vision() + if self.image_analysis_llm is None: + logger.warning( + "No LLM with vision found; image summarization will be disabled" + ) + except Exception as e: + logger.warning( + f"Failed to initialize vision LLM due to an error: {str(e)}. " + "Image summarization will be disabled." 
+ ) + self.image_analysis_llm = None diff --git a/backend/onyx/connectors/web/connector.py b/backend/onyx/connectors/web/connector.py index 115ff6fdf4..f92ed3710f 100644 --- a/backend/onyx/connectors/web/connector.py +++ b/backend/onyx/connectors/web/connector.py @@ -333,7 +333,7 @@ class WebConnector(LoadConnector): if initial_url.split(".")[-1] == "pdf": # PDF files are not checked for links response = requests.get(initial_url) - page_text, metadata = read_pdf_file( + page_text, metadata, images = read_pdf_file( file=io.BytesIO(response.content) ) last_modified = response.headers.get("Last-Modified") diff --git a/backend/onyx/context/search/postprocessing/postprocessing.py b/backend/onyx/context/search/postprocessing/postprocessing.py index 34f8a18d92..9d3a07d933 100644 --- a/backend/onyx/context/search/postprocessing/postprocessing.py +++ b/backend/onyx/context/search/postprocessing/postprocessing.py @@ -1,11 +1,16 @@ +import base64 from collections.abc import Callable from collections.abc import Iterator from typing import cast import numpy +from langchain_core.messages import BaseMessage +from langchain_core.messages import HumanMessage +from langchain_core.messages import SystemMessage from onyx.chat.models import SectionRelevancePiece from onyx.configs.app_configs import BLURB_SIZE +from onyx.configs.app_configs import ENABLE_SEARCH_TIME_IMAGE_ANALYSIS from onyx.configs.constants import RETURN_SEPARATOR from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MIN @@ -18,11 +23,15 @@ from onyx.context.search.models import MAX_METRICS_CONTENT from onyx.context.search.models import RerankingDetails from onyx.context.search.models import RerankMetricsContainer from onyx.context.search.models import SearchQuery +from onyx.db.engine import get_session_with_current_tenant from onyx.document_index.document_index_utils import ( translate_boost_count_to_multiplier, ) +from onyx.file_store.file_store import get_default_file_store from onyx.llm.interfaces import LLM +from onyx.llm.utils import message_to_string from onyx.natural_language_processing.search_nlp_models import RerankingModel +from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import FunctionCall @@ -30,6 +39,124 @@ from onyx.utils.threadpool_concurrency import run_functions_in_parallel from onyx.utils.timing import log_function_time +def update_image_sections_with_query( + sections: list[InferenceSection], + query: str, + llm: LLM, +) -> None: + """ + For each chunk in each section that has an image URL, call an LLM to produce + a new 'content' string that directly addresses the user's query about that image. + This implementation uses parallel processing for efficiency. 
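+
+    Chunks without an image_file_name are left untouched; matching chunks get
+    their content rewritten in place. Minimal call sketch (assumes reranked or
+    retrieved sections and a vision-capable LLM):
+
+        update_image_sections_with_query(
+            sections=retrieved_sections,
+            query="What does the architecture diagram show?",
+            llm=vision_llm,
+        )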
+ """ + logger = setup_logger() + logger.debug(f"Starting image section update with query: {query}") + + chunks_with_images = [] + for section in sections: + for chunk in section.chunks: + if chunk.image_file_name: + chunks_with_images.append(chunk) + + if not chunks_with_images: + logger.debug("No images to process in the sections") + return # No images to process + + logger.info(f"Found {len(chunks_with_images)} chunks with images to process") + + def process_image_chunk(chunk: InferenceChunk) -> tuple[str, str]: + try: + logger.debug( + f"Processing image chunk with ID: {chunk.unique_id}, image: {chunk.image_file_name}" + ) + with get_session_with_current_tenant() as db_session: + file_record = get_default_file_store(db_session).read_file( + cast(str, chunk.image_file_name), mode="b" + ) + if not file_record: + logger.error(f"Image file not found: {chunk.image_file_name}") + raise Exception("File not found") + file_content = file_record.read() + image_base64 = base64.b64encode(file_content).decode() + logger.debug( + f"Successfully loaded image data for {chunk.image_file_name}" + ) + + messages: list[BaseMessage] = [ + SystemMessage(content=IMAGE_ANALYSIS_SYSTEM_PROMPT), + HumanMessage( + content=[ + { + "type": "text", + "text": ( + f"The user's question is: '{query}'. " + "Please analyze the following image in that context:\n" + ), + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}", + }, + }, + ] + ), + ] + + raw_response = llm.invoke(messages) + + answer_text = message_to_string(raw_response).strip() + return ( + chunk.unique_id, + answer_text if answer_text else "No relevant info found.", + ) + + except Exception: + logger.exception( + f"Error updating image section with query source image url: {chunk.image_file_name}" + ) + return chunk.unique_id, "Error analyzing image." + + image_processing_tasks = [ + FunctionCall(process_image_chunk, (chunk,)) for chunk in chunks_with_images + ] + + logger.info( + f"Starting parallel processing of {len(image_processing_tasks)} image tasks" + ) + image_processing_results = run_functions_in_parallel(image_processing_tasks) + logger.info( + f"Completed parallel processing with {len(image_processing_results)} results" + ) + + # Create a mapping of chunk IDs to their processed content + chunk_id_to_content = {} + success_count = 0 + for task_id, result in image_processing_results.items(): + if result: + chunk_id, content = result + chunk_id_to_content[chunk_id] = content + success_count += 1 + else: + logger.error(f"Task {task_id} failed to return a valid result") + + logger.info( + f"Successfully processed {success_count}/{len(image_processing_results)} images" + ) + + # Update the chunks with the processed content + updated_count = 0 + for section in sections: + for chunk in section.chunks: + if chunk.unique_id in chunk_id_to_content: + chunk.content = chunk_id_to_content[chunk.unique_id] + updated_count += 1 + + logger.info( + f"Updated content for {updated_count} chunks with image analysis results" + ) + + logger = setup_logger() @@ -286,6 +413,10 @@ def search_postprocessing( # NOTE: if we don't rerank, we can return the chunks immediately # since we know this is the final order. 
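        # (This pass is gated by ENABLE_SEARCH_TIME_IMAGE_ANALYSIS; when it is
        # disabled, image chunks keep the summaries generated at indexing time
        # instead of query-specific rewrites.)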
# This way the user experience isn't delayed by the LLM step + if ENABLE_SEARCH_TIME_IMAGE_ANALYSIS: + update_image_sections_with_query( + retrieved_sections, search_query.query, llm + ) _log_top_section_links(search_query.search_type.value, retrieved_sections) yield retrieved_sections sections_yielded = True @@ -323,6 +454,13 @@ def search_postprocessing( ) else: _log_top_section_links(search_query.search_type.value, reranked_sections) + + # Add the image processing step here + if ENABLE_SEARCH_TIME_IMAGE_ANALYSIS: + update_image_sections_with_query( + reranked_sections, search_query.query, llm + ) + yield reranked_sections llm_selected_section_ids = ( diff --git a/backend/onyx/db/pg_file_store.py b/backend/onyx/db/pg_file_store.py index 469379cbe2..5d31358b49 100644 --- a/backend/onyx/db/pg_file_store.py +++ b/backend/onyx/db/pg_file_store.py @@ -148,3 +148,28 @@ def upsert_pgfilestore( db_session.commit() return pgfilestore + + +def save_bytes_to_pgfilestore( + db_session: Session, + raw_bytes: bytes, + media_type: str, + identifier: str, + display_name: str, + file_origin: FileOrigin = FileOrigin.OTHER, +) -> PGFileStore: + """ + Saves raw bytes to PGFileStore and returns the resulting record. + """ + file_name = f"{file_origin.name.lower()}_{identifier}" + lobj_oid = create_populate_lobj(BytesIO(raw_bytes), db_session) + pgfilestore = upsert_pgfilestore( + file_name=file_name, + display_name=display_name, + file_origin=file_origin, + file_type=media_type, + lobj_oid=lobj_oid, + db_session=db_session, + commit=True, + ) + return pgfilestore diff --git a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd index f846c32fca..d5d5220f85 100644 --- a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd @@ -55,6 +55,9 @@ schema DANSWER_CHUNK_NAME { field blurb type string { indexing: summary | attribute } + field image_file_name type string { + indexing: summary | attribute + } # https://docs.vespa.ai/en/attributes.html potential enum store for speed, but probably not worth it field source_type type string { indexing: summary | attribute diff --git a/backend/onyx/document_index/vespa/chunk_retrieval.py b/backend/onyx/document_index/vespa/chunk_retrieval.py index 5f3dff5c8e..8c7ee69637 100644 --- a/backend/onyx/document_index/vespa/chunk_retrieval.py +++ b/backend/onyx/document_index/vespa/chunk_retrieval.py @@ -31,6 +31,7 @@ from onyx.document_index.vespa_constants import DOC_UPDATED_AT from onyx.document_index.vespa_constants import DOCUMENT_ID from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT from onyx.document_index.vespa_constants import HIDDEN +from onyx.document_index.vespa_constants import IMAGE_FILE_NAME from onyx.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS from onyx.document_index.vespa_constants import MAX_ID_SEARCH_QUERY_SIZE from onyx.document_index.vespa_constants import MAX_OR_CONDITIONS @@ -130,6 +131,7 @@ def _vespa_hit_to_inference_chunk( section_continuation=fields[SECTION_CONTINUATION], document_id=fields[DOCUMENT_ID], source_type=fields[SOURCE_TYPE], + image_file_name=fields.get(IMAGE_FILE_NAME), title=fields.get(TITLE), semantic_identifier=fields[SEMANTIC_IDENTIFIER], boost=fields.get(BOOST, 1), @@ -211,6 +213,7 @@ def _get_chunks_via_visit_api( # Check if the response contains any documents response_data = response.json() + if "documents" in 
response_data: for document in response_data["documents"]: if filters.access_control_list: diff --git a/backend/onyx/document_index/vespa/indexing_utils.py b/backend/onyx/document_index/vespa/indexing_utils.py index 2802b59f58..81fc2a0d4c 100644 --- a/backend/onyx/document_index/vespa/indexing_utils.py +++ b/backend/onyx/document_index/vespa/indexing_utils.py @@ -32,6 +32,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT from onyx.document_index.vespa_constants import DOCUMENT_SETS from onyx.document_index.vespa_constants import EMBEDDINGS +from onyx.document_index.vespa_constants import IMAGE_FILE_NAME from onyx.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS from onyx.document_index.vespa_constants import METADATA from onyx.document_index.vespa_constants import METADATA_LIST @@ -198,13 +199,13 @@ def _index_vespa_chunk( # which only calls VespaIndex.update ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()}, DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets}, + IMAGE_FILE_NAME: chunk.image_file_name, BOOST: chunk.boost, } if multitenant: if chunk.tenant_id: vespa_document_fields[TENANT_ID] = chunk.tenant_id - vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}" logger.debug(f'Indexing to URL "{vespa_url}"') res = http_client.post( diff --git a/backend/onyx/document_index/vespa_constants.py b/backend/onyx/document_index/vespa_constants.py index 82bb591983..15f889f3cd 100644 --- a/backend/onyx/document_index/vespa_constants.py +++ b/backend/onyx/document_index/vespa_constants.py @@ -77,6 +77,7 @@ PRIMARY_OWNERS = "primary_owners" SECONDARY_OWNERS = "secondary_owners" RECENCY_BIAS = "recency_bias" HIDDEN = "hidden" +IMAGE_FILE_NAME = "image_file_name" # Specific to Vespa, needed for highlighting matching keywords / section CONTENT_SUMMARY = "content_summary" @@ -94,6 +95,7 @@ YQL_BASE = ( f"{SEMANTIC_IDENTIFIER}, " f"{TITLE}, " f"{SECTION_CONTINUATION}, " + f"{IMAGE_FILE_NAME}, " f"{BOOST}, " f"{HIDDEN}, " f"{DOC_UPDATED_AT}, " diff --git a/backend/onyx/file_processing/extract_file_text.py b/backend/onyx/file_processing/extract_file_text.py index 349073a686..235cbcec62 100644 --- a/backend/onyx/file_processing/extract_file_text.py +++ b/backend/onyx/file_processing/extract_file_text.py @@ -9,15 +9,17 @@ from email.parser import Parser as EmailParser from io import BytesIO from pathlib import Path from typing import Any -from typing import Dict from typing import IO +from typing import List +from typing import Tuple import chardet import docx # type: ignore import openpyxl # type: ignore import pptx # type: ignore -from docx import Document +from docx import Document as DocxDocument from fastapi import UploadFile +from PIL import Image from pypdf import PdfReader from pypdf.errors import PdfStreamError @@ -31,10 +33,8 @@ from onyx.utils.logger import setup_logger logger = setup_logger() - TEXT_SECTION_SEPARATOR = "\n\n" - PLAIN_TEXT_FILE_EXTENSIONS = [ ".txt", ".md", @@ -49,7 +49,6 @@ PLAIN_TEXT_FILE_EXTENSIONS = [ ".yaml", ] - VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [ ".pdf", ".docx", @@ -58,6 +57,16 @@ VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [ ".eml", ".epub", ".html", + ".png", + ".jpg", + ".jpeg", + ".webp", +] + +IMAGE_MEDIA_TYPES = [ + "image/png", + "image/jpeg", + "image/webp", ] @@ -67,11 +76,13 @@ def is_text_file_extension(file_name: str) -> bool: def 
get_file_ext(file_path_or_name: str | Path) -> str:
     _, extension = os.path.splitext(file_path_or_name)
-    # standardize all extensions to be lowercase so that checks against
-    # VALID_FILE_EXTENSIONS and similar will work as intended
     return extension.lower()
 
 
+def is_valid_media_type(media_type: str) -> bool:
+    return media_type in IMAGE_MEDIA_TYPES
+
+
 def is_valid_file_ext(ext: str) -> bool:
     return ext in VALID_FILE_EXTENSIONS
 
@@ -79,17 +90,18 @@ def is_valid_file_ext(ext: str) -> bool:
 def is_text_file(file: IO[bytes]) -> bool:
     """
     checks if the first 1024 bytes only contain printable or whitespace characters
-    if it does, then we say its a plaintext file
+    if it does, then we say it's a plaintext file
     """
     raw_data = file.read(1024)
+    file.seek(0)
     text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})
     return all(c in text_chars for c in raw_data)
 
 
 def detect_encoding(file: IO[bytes]) -> str:
     raw_data = file.read(50000)
-    encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
     file.seek(0)
+    encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
     return encoding
 
 
@@ -99,14 +111,14 @@ def is_macos_resource_fork_file(file_name: str) -> bool:
     )
 
 
-# To include additional metadata in the search index, add a .onyx_metadata.json file
-# to the zip file. This file should contain a list of objects with the following format:
-# [{ "filename": "file1.txt", "link": "https://example.com/file1.txt" }]
 def load_files_from_zip(
     zip_file_io: IO,
     ignore_macos_resource_fork_files: bool = True,
     ignore_dirs: bool = True,
 ) -> Iterator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]]]:
+    """
+    If there's a .onyx_metadata.json in the zip, attaches that metadata to each subfile.
+    """
     with zipfile.ZipFile(zip_file_io, "r") as zip_file:
         zip_metadata = {}
         try:
@@ -118,24 +130,31 @@ def load_files_from_zip(
                 # convert list of dicts to dict of dicts
                 zip_metadata = {d["filename"]: d for d in zip_metadata}
         except json.JSONDecodeError:
-            logger.warn(f"Unable to load {DANSWER_METADATA_FILENAME}")
+            logger.warning(f"Unable to load {DANSWER_METADATA_FILENAME}")
         except KeyError:
             logger.info(f"No {DANSWER_METADATA_FILENAME} file")
 
         for file_info in zip_file.infolist():
-            with zip_file.open(file_info.filename, "r") as file:
-                if ignore_dirs and file_info.is_dir():
-                    continue
+            if ignore_dirs and file_info.is_dir():
+                continue
 
-                if (
-                    ignore_macos_resource_fork_files
-                    and is_macos_resource_fork_file(file_info.filename)
-                ) or file_info.filename == DANSWER_METADATA_FILENAME:
-                    continue
-                yield file_info, file, zip_metadata.get(file_info.filename, {})
+            if (
+                ignore_macos_resource_fork_files
+                and is_macos_resource_fork_file(file_info.filename)
+            ) or file_info.filename == DANSWER_METADATA_FILENAME:
+                continue
+
+            with zip_file.open(file_info.filename, "r") as subfile:
+                yield file_info, subfile, zip_metadata.get(file_info.filename, {})
 
 
 def _extract_onyx_metadata(line: str) -> dict | None:
+    """
+    Example: first line has:
+        <!-- DANSWER_METADATA={"title":"..."} -->
+    or
+        #DANSWER_METADATA={"title":"..."}
+    """
     html_comment_pattern = r"<!--\s*DANSWER_METADATA=\{(.*?)\}\s*-->"
     hashtag_pattern = r"#DANSWER_METADATA=\{(.*?)\}"
 
@@ -161,9 +180,13 @@ def read_text_file(
     errors: str = "replace",
     ignore_onyx_metadata: bool = True,
 ) -> tuple[str, dict]:
+    """
+    For plain text files. Optionally extracts Onyx metadata from the first line.
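+
+    Usage sketch (hypothetical path):
+
+        with open("notes.md", "rb") as f:
+            content, metadata = read_text_file(f, ignore_onyx_metadata=False)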
+ """ metadata = {} file_content_raw = "" for ind, line in enumerate(file): + # decode try: line = line.decode(encoding) if isinstance(line, bytes) else line except UnicodeDecodeError: @@ -173,131 +196,132 @@ def read_text_file( else line ) - if ind == 0: - metadata_or_none = ( - None if ignore_onyx_metadata else _extract_onyx_metadata(line) - ) - if metadata_or_none is not None: - metadata = metadata_or_none - else: - file_content_raw += line - else: - file_content_raw += line + # optionally parse metadata in the first line + if ind == 0 and not ignore_onyx_metadata: + potential_meta = _extract_onyx_metadata(line) + if potential_meta is not None: + metadata = potential_meta + continue + + file_content_raw += line return file_content_raw, metadata def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str: - """Extract text from a PDF file.""" - # Return only the extracted text from read_pdf_file - text, _ = read_pdf_file(file, pdf_pass) + """ + Extract text from a PDF. For embedded images, a more complex approach is needed. + This is a minimal approach returning text only. + """ + text, _, _ = read_pdf_file(file, pdf_pass) return text def read_pdf_file( - file: IO[Any], - pdf_pass: str | None = None, -) -> tuple[str, dict]: - metadata: Dict[str, Any] = {} + file: IO[Any], pdf_pass: str | None = None, extract_images: bool = False +) -> tuple[str, dict, list[tuple[bytes, str]]]: + """ + Returns the text, basic PDF metadata, and optionally extracted images. + """ + metadata: dict[str, Any] = {} + extracted_images: list[tuple[bytes, str]] = [] try: pdf_reader = PdfReader(file) - # If marked as encrypted and a password is provided, try to decrypt if pdf_reader.is_encrypted and pdf_pass is not None: decrypt_success = False - if pdf_pass is not None: - try: - decrypt_success = pdf_reader.decrypt(pdf_pass) != 0 - except Exception: - logger.error("Unable to decrypt pdf") + try: + decrypt_success = pdf_reader.decrypt(pdf_pass) != 0 + except Exception: + logger.error("Unable to decrypt pdf") if not decrypt_success: - # By user request, keep files that are unreadable just so they - # can be discoverable by title. - return "", metadata + return "", metadata, [] elif pdf_reader.is_encrypted: - logger.warning("No Password available to decrypt pdf, returning empty") - return "", metadata + logger.warning("No Password for an encrypted PDF, returning empty text.") + return "", metadata, [] - # Extract metadata from the PDF, removing leading '/' from keys if present - # This standardizes the metadata keys for consistency - metadata = {} + # Basic PDF metadata if pdf_reader.metadata is not None: for key, value in pdf_reader.metadata.items(): clean_key = key.lstrip("/") if isinstance(value, str) and value.strip(): metadata[clean_key] = value - elif isinstance(value, list) and all( isinstance(item, str) for item in value ): metadata[clean_key] = ", ".join(value) - return ( - TEXT_SECTION_SEPARATOR.join( - page.extract_text() for page in pdf_reader.pages - ), - metadata, + text = TEXT_SECTION_SEPARATOR.join( + page.extract_text() for page in pdf_reader.pages ) + + if extract_images: + for page_num, page in enumerate(pdf_reader.pages): + for image_file_object in page.images: + image = Image.open(io.BytesIO(image_file_object.data)) + img_byte_arr = io.BytesIO() + image.save(img_byte_arr, format=image.format) + img_bytes = img_byte_arr.getvalue() + + image_name = ( + f"page_{page_num + 1}_image_{image_file_object.name}." 
+                        f"{image.format.lower() if image.format else 'png'}"
+                    )
+                    extracted_images.append((img_bytes, image_name))
+
+        return text, metadata, extracted_images
+
     except PdfStreamError:
-        logger.exception("PDF file is not a valid PDF")
+        logger.exception("Invalid PDF file")
     except Exception:
         logger.exception("Failed to read PDF")
 
-    # File is still discoverable by title
-    # but the contents are not included as they cannot be parsed
-    return "", metadata
+    return "", metadata, []
 
 
-def docx_to_text(file: IO[Any]) -> str:
-    def is_simple_table(table: docx.table.Table) -> bool:
-        for row in table.rows:
-            # No omitted cells
-            if row.grid_cols_before > 0 or row.grid_cols_after > 0:
-                return False
-
-            # No nested tables
-            if any(cell.tables for cell in row.cells):
-                return False
-
-        return True
-
-    def extract_cell_text(cell: docx.table._Cell) -> str:
-        cell_paragraphs = [para.text.strip() for para in cell.paragraphs]
-        return " ".join(p for p in cell_paragraphs if p) or "N/A"
-
+def docx_to_text_and_images(
+    file: IO[Any],
+) -> Tuple[str, List[Tuple[bytes, str]]]:
+    """
+    Extract text from a docx, along with any inline images found in the
+    document's relationships.
+    Return (text_content, [(image_bytes, image_name), ...]).
+    """
     paragraphs = []
+    embedded_images: List[Tuple[bytes, str]] = []
+
     doc = docx.Document(file)
-    for item in doc.iter_inner_content():
-        if isinstance(item, docx.text.paragraph.Paragraph):
-            paragraphs.append(item.text)
-        elif isinstance(item, docx.table.Table):
-            if not item.rows or not is_simple_table(item):
-                continue
+    # Grab text from paragraphs
+    for paragraph in doc.paragraphs:
+        paragraphs.append(paragraph.text)
 
-            # Every row is a new line, joined with a single newline
-            table_content = "\n".join(
-                [
-                    ",\t".join(extract_cell_text(cell) for cell in row.cells)
-                    for row in item.rows
-                ]
-            )
-            paragraphs.append(table_content)
+    # python-docx has already consumed the stream above. For very large
+    # documents, a more robust approach would buffer the file and re-open it
+    # from memory; here we read images from the parsed document's
+    # relationships instead.
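+    # Sketch of that more robust variant (an assumption, not what this
+    # function currently does):
+    #
+    #     data = file.read()                      # buffer the stream once
+    #     doc = docx.Document(io.BytesIO(data))   # parse from memory
+    #     # io.BytesIO(data) can be re-opened for any later passes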
- # Docx already has good spacing between paragraphs - return "\n".join(paragraphs) + for rel_id, rel in doc.part.rels.items(): + if "image" in rel.reltype: + # image is typically in rel.target_part.blob + image_bytes = rel.target_part.blob + image_name = rel.target_part.partname + # store + embedded_images.append((image_bytes, os.path.basename(str(image_name)))) + + text_content = "\n".join(paragraphs) + return text_content, embedded_images def pptx_to_text(file: IO[Any]) -> str: presentation = pptx.Presentation(file) text_content = [] for slide_number, slide in enumerate(presentation.slides, start=1): - extracted_text = f"\nSlide {slide_number}:\n" + slide_text = f"\nSlide {slide_number}:\n" for shape in slide.shapes: if hasattr(shape, "text"): - extracted_text += shape.text + "\n" - text_content.append(extracted_text) + slide_text += shape.text + "\n" + text_content.append(slide_text) return TEXT_SECTION_SEPARATOR.join(text_content) @@ -305,18 +329,21 @@ def xlsx_to_text(file: IO[Any]) -> str: workbook = openpyxl.load_workbook(file, read_only=True) text_content = [] for sheet in workbook.worksheets: - sheet_string = "\n".join( - ",".join(map(str, row)) - for row in sheet.iter_rows(min_row=1, values_only=True) - ) - text_content.append(sheet_string) + rows = [] + for row in sheet.iter_rows(min_row=1, values_only=True): + row_str = ",".join(str(cell) if cell is not None else "" for cell in row) + rows.append(row_str) + sheet_str = "\n".join(rows) + text_content.append(sheet_str) return TEXT_SECTION_SEPARATOR.join(text_content) def eml_to_text(file: IO[Any]) -> str: - text_file = io.TextIOWrapper(file, encoding=detect_encoding(file)) + encoding = detect_encoding(file) + text_file = io.TextIOWrapper(file, encoding=encoding) parser = EmailParser() message = parser.parse(text_file) + text_content = [] for part in message.walk(): if part.get_content_type().startswith("text/plain"): @@ -342,8 +369,8 @@ def epub_to_text(file: IO[Any]) -> str: def file_io_to_text(file: IO[Any]) -> str: encoding = detect_encoding(file) - file_content_raw, _ = read_text_file(file, encoding=encoding) - return file_content_raw + file_content, _ = read_text_file(file, encoding=encoding) + return file_content def extract_file_text( @@ -352,9 +379,13 @@ def extract_file_text( break_on_unprocessable: bool = True, extension: str | None = None, ) -> str: + """ + Legacy function that returns *only text*, ignoring embedded images. + For backward-compatibility in code that only wants text. + """ extension_to_function: dict[str, Callable[[IO[Any]], str]] = { ".pdf": pdf_to_text, - ".docx": docx_to_text, + ".docx": lambda f: docx_to_text_and_images(f)[0], # no images ".pptx": pptx_to_text, ".xlsx": xlsx_to_text, ".eml": eml_to_text, @@ -368,24 +399,23 @@ def extract_file_text( return unstructured_to_text(file, file_name) except Exception as unstructured_error: logger.error( - f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing." + f"Failed to process with Unstructured: {str(unstructured_error)}. " + "Falling back to normal processing." 
)
 
-    # Fall through to normal processing
-
-    final_extension: str
-    if file_name or extension:
-        if extension is not None:
-            final_extension = extension
-        elif file_name is not None:
-            final_extension = get_file_ext(file_name)
+        if extension is None:
+            extension = get_file_ext(file_name)
 
-        if is_valid_file_ext(final_extension):
-            return extension_to_function.get(final_extension, file_io_to_text)(file)
+        if is_valid_file_ext(extension):
+            func = extension_to_function.get(extension, file_io_to_text)
+            file.seek(0)
+            return func(file)
 
-    # Either the file somehow has no name or the extension is not one that we recognize
+        # If unknown extension, maybe it's a text file
+        file.seek(0)
         if is_text_file(file):
             return file_io_to_text(file)
 
-        raise ValueError("Unknown file extension and unknown text encoding")
+        raise ValueError("Unknown file extension or not recognized as text data")
 
     except Exception as e:
         if break_on_unprocessable:
@@ -396,20 +426,93 @@
         return ""
 
 
+def extract_text_and_images(
+    file: IO[Any],
+    file_name: str,
+    pdf_pass: str | None = None,
+) -> Tuple[str, List[Tuple[bytes, str]]]:
+    """
+    Extracts text plus any embedded images from a file.
+    Returns (text_content, [(embedded_img_bytes, embedded_img_name), ...]).
+    """
+
+    try:
+        # Use the Unstructured API when a key is configured
+        if get_unstructured_api_key():
+            # Unstructured returns text only; embedded images are skipped
+            file.seek(0)
+            text_content = unstructured_to_text(file, file_name)
+            return (text_content, [])
+
+        extension = get_file_ext(file_name)
+
+        # docx: extract text plus embedded images
+        if extension == ".docx":
+            file.seek(0)
+            text_content, images = docx_to_text_and_images(file)
+            return (text_content, images)
+
+        # PDF: extract text and, where present, embedded images as well
+        if extension == ".pdf":
+            file.seek(0)
+            text_content, _, images = read_pdf_file(file, pdf_pass, extract_images=True)
+            return (text_content, images)
+
+        # PPTX, XLSX, EML, etc.: text only for now; embedded-image handling
+        # could mirror the docx path if needed.
+        if extension == ".pptx":
+            file.seek(0)
+            return (pptx_to_text(file), [])
+
+        if extension == ".xlsx":
+            file.seek(0)
+            return (xlsx_to_text(file), [])
+
+        if extension == ".eml":
+            file.seek(0)
+            return (eml_to_text(file), [])
+
+        if extension == ".epub":
+            file.seek(0)
+            return (epub_to_text(file), [])
+
+        if extension == ".html":
+            file.seek(0)
+            return (parse_html_page_basic(file), [])
+
+        # If we reach here and it's a recognized text extension
+        if is_text_file_extension(file_name):
+            file.seek(0)
+            encoding = detect_encoding(file)
+            text_content_raw, _ = read_text_file(
+                file, encoding=encoding, ignore_onyx_metadata=False
+            )
+            return (text_content_raw, [])
+
+        # Image files and other unrecognized types: return empty text
+        # (raw image bytes are handled separately by the connectors)
+        file.seek(0)
+        return ("", [])
+
+    except Exception as e:
+        logger.exception(f"Failed to extract text/images from {file_name}: {e}")
+        return ("", [])
+
+
 def convert_docx_to_txt(
     file: UploadFile, file_store: FileStore, file_path: str
 ) -> None:
+    """
+    Helper to convert docx to a .txt file in the same filestore.
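+
+    Usage sketch (hypothetical names; `file_store` would typically come from
+    get_default_file_store):
+
+        convert_docx_to_txt(upload, file_store, "user_files/report.docx")
+        # writes "user_files/report.txt", per docx_to_txt_filename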
+ """ file.file.seek(0) docx_content = file.file.read() - doc = Document(BytesIO(docx_content)) + doc = DocxDocument(BytesIO(docx_content)) # Extract text from the document - full_text = [] - for para in doc.paragraphs: - full_text.append(para.text) - - # Join the extracted text - text_content = "\n".join(full_text) + all_paras = [p.text for p in doc.paragraphs] + text_content = "\n".join(all_paras) txt_file_path = docx_to_txt_filename(file_path) file_store.save_file( @@ -422,7 +525,4 @@ def convert_docx_to_txt( def docx_to_txt_filename(file_path: str) -> str: - """ - Convert a .docx file path to its corresponding .txt file path. - """ return file_path.rsplit(".", 1)[0] + ".txt" diff --git a/backend/onyx/file_processing/file_validation.py b/backend/onyx/file_processing/file_validation.py new file mode 100644 index 0000000000..fa4df5a429 --- /dev/null +++ b/backend/onyx/file_processing/file_validation.py @@ -0,0 +1,46 @@ +""" +Centralized file type validation utilities. +""" +# Standard image MIME types supported by most vision LLMs +IMAGE_MIME_TYPES = [ + "image/png", + "image/jpeg", + "image/jpg", + "image/webp", +] + +# Image types that should be excluded from processing +EXCLUDED_IMAGE_TYPES = [ + "image/bmp", + "image/tiff", + "image/gif", + "image/svg+xml", +] + + +def is_valid_image_type(mime_type: str) -> bool: + """ + Check if mime_type is a valid image type. + + Args: + mime_type: The MIME type to check + + Returns: + True if the MIME type is a valid image type, False otherwise + """ + if not mime_type: + return False + return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES + + +def is_supported_by_vision_llm(mime_type: str) -> bool: + """ + Check if this image type can be processed by vision LLMs. + + Args: + mime_type: The MIME type to check + + Returns: + True if the MIME type is supported by vision LLMs, False otherwise + """ + return mime_type in IMAGE_MIME_TYPES diff --git a/backend/onyx/file_processing/image_summarization.py b/backend/onyx/file_processing/image_summarization.py new file mode 100644 index 0000000000..b81da25ecf --- /dev/null +++ b/backend/onyx/file_processing/image_summarization.py @@ -0,0 +1,129 @@ +import base64 +from io import BytesIO + +from langchain_core.messages import BaseMessage +from langchain_core.messages import HumanMessage +from langchain_core.messages import SystemMessage +from PIL import Image + +from onyx.llm.interfaces import LLM +from onyx.llm.utils import message_to_string +from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT +from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def prepare_image_bytes(image_data: bytes) -> str: + """Prepare image bytes for summarization. + Resizes image if it's larger than 20MB. Encodes image as a base64 string.""" + image_data = _resize_image_if_needed(image_data) + + # encode image (base64) + encoded_image = _encode_image_for_llm_prompt(image_data) + + return encoded_image + + +def summarize_image_pipeline( + llm: LLM, + image_data: bytes, + query: str | None = None, + system_prompt: str | None = None, +) -> str: + """Pipeline to generate a summary of an image. + Resizes images if it is bigger than 20MB. Encodes image as a base64 string. 
+ And finally uses the Default LLM to generate a textual summary of the image.""" + # resize image if it's bigger than 20MB + encoded_image = prepare_image_bytes(image_data) + + summary = _summarize_image( + encoded_image, + llm, + query, + system_prompt, + ) + + return summary + + +def summarize_image_with_error_handling( + llm: LLM | None, + image_data: bytes, + context_name: str, + system_prompt: str = IMAGE_SUMMARIZATION_SYSTEM_PROMPT, + user_prompt_template: str = IMAGE_SUMMARIZATION_USER_PROMPT, +) -> str | None: + """Wrapper function that handles error cases and configuration consistently. + + Args: + llm: The LLM with vision capabilities to use for summarization + image_data: The raw image bytes + context_name: Name or title of the image for context + system_prompt: System prompt to use for the LLM + user_prompt_template: Template for the user prompt, should contain {title} placeholder + + Returns: + The image summary text, or None if summarization failed or is disabled + """ + if llm is None: + return None + + user_prompt = user_prompt_template.format(title=context_name) + return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt) + + +def _summarize_image( + encoded_image: str, + llm: LLM, + query: str | None = None, + system_prompt: str | None = None, +) -> str: + """Use default LLM (if it is multimodal) to generate a summary of an image.""" + + messages: list[BaseMessage] = [] + + if system_prompt: + messages.append(SystemMessage(content=system_prompt)) + + messages.append( + HumanMessage( + content=[ + {"type": "text", "text": query}, + {"type": "image_url", "image_url": {"url": encoded_image}}, + ], + ), + ) + + try: + return message_to_string(llm.invoke(messages)) + + except Exception as e: + raise ValueError(f"Summarization failed. 
Messages: {messages}") from e + + +def _encode_image_for_llm_prompt(image_data: bytes) -> str: + """Getting the base64 string.""" + base64_encoded_data = base64.b64encode(image_data).decode("utf-8") + + return f"data:image/jpeg;base64,{base64_encoded_data}" + + +def _resize_image_if_needed(image_data: bytes, max_size_mb: int = 20) -> bytes: + """Resize image if it's larger than the specified max size in MB.""" + max_size_bytes = max_size_mb * 1024 * 1024 + + if len(image_data) > max_size_bytes: + with Image.open(BytesIO(image_data)) as img: + # Reduce dimensions for better size reduction + img.thumbnail((1024, 1024), Image.Resampling.LANCZOS) + output = BytesIO() + + # Save with lower quality for compression + img.save(output, format="JPEG", quality=85) + resized_data = output.getvalue() + + return resized_data + + return image_data diff --git a/backend/onyx/file_processing/image_utils.py b/backend/onyx/file_processing/image_utils.py new file mode 100644 index 0000000000..3ae48df891 --- /dev/null +++ b/backend/onyx/file_processing/image_utils.py @@ -0,0 +1,70 @@ +from typing import Tuple + +from sqlalchemy.orm import Session + +from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE +from onyx.configs.constants import FileOrigin +from onyx.connectors.models import Section +from onyx.db.pg_file_store import save_bytes_to_pgfilestore +from onyx.file_processing.image_summarization import summarize_image_with_error_handling +from onyx.llm.interfaces import LLM +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def store_image_and_create_section( + db_session: Session, + image_data: bytes, + file_name: str, + display_name: str, + media_type: str = "image/unknown", + llm: LLM | None = None, + file_origin: FileOrigin = FileOrigin.OTHER, +) -> Tuple[Section, str | None]: + """ + Stores an image in PGFileStore and creates a Section object with optional summarization. + + Args: + db_session: Database session + image_data: Raw image bytes + file_name: Base identifier for the file + display_name: Human-readable name for the image + media_type: MIME type of the image + llm: Optional LLM with vision capabilities for summarization + file_origin: Origin of the file (e.g., CONFLUENCE, GOOGLE_DRIVE, etc.) 
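+
+    Usage sketch (hypothetical values):
+
+        section, stored_name = store_image_and_create_section(
+            db_session=db_session,
+            image_data=png_bytes,
+            file_name="doc123_embedded_1",
+            display_name="report.docx - image 1",
+            media_type="image/png",
+            llm=None,  # skip summarization
+            file_origin=FileOrigin.CONNECTOR,
+        )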
+ + Returns: + Tuple containing: + - Section object with image reference and optional summary text + - The file_name in PGFileStore or None if storage failed + """ + # Storage logic + stored_file_name = None + try: + pgfilestore = save_bytes_to_pgfilestore( + db_session=db_session, + raw_bytes=image_data, + media_type=media_type, + identifier=file_name, + display_name=display_name, + file_origin=file_origin, + ) + stored_file_name = pgfilestore.file_name + except Exception as e: + logger.error(f"Failed to store image: {e}") + if not CONTINUE_ON_CONNECTOR_FAILURE: + raise + return Section(text=""), None + + # Summarization logic + summary_text = "" + if llm: + summary_text = ( + summarize_image_with_error_handling(llm, image_data, display_name) or "" + ) + + return ( + Section(text=summary_text, image_file_name=stored_file_name), + stored_file_name, + ) diff --git a/backend/onyx/indexing/chunker.py b/backend/onyx/indexing/chunker.py index e42dafb293..91c8505920 100644 --- a/backend/onyx/indexing/chunker.py +++ b/backend/onyx/indexing/chunker.py @@ -23,12 +23,9 @@ from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT CHUNK_OVERLAP = 0 # Fairly arbitrary numbers but the general concept is we don't want the title/metadata to # overwhelm the actual contents of the chunk -# For example in a rare case, this could be 128 tokens for the 512 chunk and title prefix -# could be another 128 tokens leaving 256 for the actual contents MAX_METADATA_PERCENTAGE = 0.25 CHUNK_MIN_CONTENT = 256 - logger = setup_logger() @@ -36,16 +33,8 @@ def _get_metadata_suffix_for_document_index( metadata: dict[str, str | list[str]], include_separator: bool = False ) -> tuple[str, str]: """ - Returns the metadata as a natural language string representation with all of the keys and values for the vector embedding - and a string of all of the values for the keyword search - - For example, if we have the following metadata: - { - "author": "John Doe", - "space": "Engineering" - } - The vector embedding string should include the relation between the key and value wheres as for keyword we only want John Doe - and Engineering. The keys are repeat and much more noisy. + Returns the metadata as a natural language string representation with all of the keys and values + for the vector embedding and a string of all of the values for the keyword search. """ if not metadata: return "", "" @@ -74,12 +63,17 @@ def _get_metadata_suffix_for_document_index( def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwareChunk: + """ + Combines multiple DocAwareChunks into one large chunk (for “multipass” mode), + appending the content and adjusting source_links accordingly. + """ merged_chunk = DocAwareChunk( source_document=chunks[0].source_document, chunk_id=chunks[0].chunk_id, blurb=chunks[0].blurb, content=chunks[0].content, source_links=chunks[0].source_links or {}, + image_file_name=None, section_continuation=(chunks[0].chunk_id > 0), title_prefix=chunks[0].title_prefix, metadata_suffix_semantic=chunks[0].metadata_suffix_semantic, @@ -103,6 +97,9 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar def generate_large_chunks(chunks: list[DocAwareChunk]) -> list[DocAwareChunk]: + """ + Generates larger “grouped” chunks by combining sets of smaller chunks. 
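+
+    Assuming LARGE_CHUNK_RATIO = 4, chunks [0..7] yield large chunks covering
+    [0-3] and [4-7]. Typical use (as in the Chunker below):
+
+        large_chunks = generate_large_chunks(normal_chunks)
+        normal_chunks.extend(large_chunks)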
+ """ large_chunks = [] for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO)): chunk_group = chunks[i : i + LARGE_CHUNK_RATIO] @@ -172,23 +169,60 @@ class Chunker: while start < total_tokens: end = min(start + content_token_limit, total_tokens) token_chunk = tokens[start:end] - # Join the tokens to reconstruct the text chunk_text = " ".join(token_chunk) chunks.append(chunk_text) start = end return chunks def _extract_blurb(self, text: str) -> str: + """ + Extract a short blurb from the text (first chunk of size `blurb_size`). + """ texts = self.blurb_splitter.split_text(text) if not texts: return "" return texts[0] def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None: + """ + For “multipass” mode: additional sub-chunks (mini-chunks) for use in certain embeddings. + """ if self.mini_chunk_splitter and chunk_text.strip(): return self.mini_chunk_splitter.split_text(chunk_text) return None + # ADDED: extra param image_url to store in the chunk + def _create_chunk( + self, + document: Document, + chunks_list: list[DocAwareChunk], + text: str, + links: dict[int, str], + is_continuation: bool = False, + title_prefix: str = "", + metadata_suffix_semantic: str = "", + metadata_suffix_keyword: str = "", + image_file_name: str | None = None, + ) -> None: + """ + Helper to create a new DocAwareChunk, append it to chunks_list. + """ + new_chunk = DocAwareChunk( + source_document=document, + chunk_id=len(chunks_list), + blurb=self._extract_blurb(text), + content=text, + source_links=links or {0: ""}, + image_file_name=image_file_name, + section_continuation=is_continuation, + title_prefix=title_prefix, + metadata_suffix_semantic=metadata_suffix_semantic, + metadata_suffix_keyword=metadata_suffix_keyword, + mini_chunk_texts=self._get_mini_chunk_texts(text), + large_chunk_id=None, + ) + chunks_list.append(new_chunk) + def _chunk_document( self, document: Document, @@ -198,122 +232,156 @@ class Chunker: content_token_limit: int, ) -> list[DocAwareChunk]: """ - Loops through sections of the document, adds metadata and converts them into chunks. + Loops through sections of the document, converting them into one or more chunks. + If a section has an image_link, we treat it as a dedicated chunk. """ + chunks: list[DocAwareChunk] = [] link_offsets: dict[int, str] = {} chunk_text = "" - def _create_chunk( - text: str, - links: dict[int, str], - is_continuation: bool = False, - ) -> DocAwareChunk: - return DocAwareChunk( - source_document=document, - chunk_id=len(chunks), - blurb=self._extract_blurb(text), - content=text, - source_links=links or {0: ""}, - section_continuation=is_continuation, - title_prefix=title_prefix, - metadata_suffix_semantic=metadata_suffix_semantic, - metadata_suffix_keyword=metadata_suffix_keyword, - mini_chunk_texts=self._get_mini_chunk_texts(text), - large_chunk_id=None, - ) - - section_link_text: str - for section_idx, section in enumerate(document.sections): section_text = clean_text(section.text) section_link_text = section.link or "" - # If there is no useful content, not even the title, just drop it + # ADDED: if the Section has an image link + image_url = section.image_file_name + + # If there is no useful content, skip if not section_text and (not document.title or section_idx > 0): - # If a section is empty and the document has no title, we can just drop it. We return a list of - # DocAwareChunks where each one contains the necessary information needed down the line for indexing. 
- # There is no concern about dropping whole documents from this list, it should not cause any indexing failures. logger.warning( - f"Skipping section {section.text} from document " - f"{document.semantic_identifier} due to empty text after cleaning " - f"with link {section_link_text}" + f"Skipping empty or irrelevant section in doc " + f"{document.semantic_identifier}, link={section_link_text}" ) continue + # CASE 1: If this is an image section, force a separate chunk + if image_url: + # First, if we have any partially built text chunk, finalize it + if chunk_text.strip(): + self._create_chunk( + document, + chunks, + chunk_text, + link_offsets, + is_continuation=False, + title_prefix=title_prefix, + metadata_suffix_semantic=metadata_suffix_semantic, + metadata_suffix_keyword=metadata_suffix_keyword, + ) + chunk_text = "" + link_offsets = {} + + # Create a chunk specifically for this image + # (If the section has text describing the image, use that as content) + self._create_chunk( + document, + chunks, + section_text, + links={0: section_link_text} + if section_link_text + else {}, # No text offsets needed for images + image_file_name=image_url, + title_prefix=title_prefix, + metadata_suffix_semantic=metadata_suffix_semantic, + metadata_suffix_keyword=metadata_suffix_keyword, + ) + # Continue to next section + continue + + # CASE 2: Normal text section section_token_count = len(self.tokenizer.tokenize(section_text)) - # Large sections are considered self-contained/unique - # Therefore, they start a new chunk and are not concatenated - # at the end by other sections + # If the section is large on its own, split it separately if section_token_count > content_token_limit: - if chunk_text: - chunks.append(_create_chunk(chunk_text, link_offsets)) - link_offsets = {} + if chunk_text.strip(): + self._create_chunk( + document, + chunks, + chunk_text, + link_offsets, + False, + title_prefix, + metadata_suffix_semantic, + metadata_suffix_keyword, + ) chunk_text = "" + link_offsets = {} split_texts = self.chunk_splitter.split_text(section_text) - for i, split_text in enumerate(split_texts): + # If even the split_text is bigger than strict limit, further split if ( STRICT_CHUNK_TOKEN_LIMIT - and - # Tokenizer only runs if STRICT_CHUNK_TOKEN_LIMIT is true - len(self.tokenizer.tokenize(split_text)) > content_token_limit + and len(self.tokenizer.tokenize(split_text)) + > content_token_limit ): - # If STRICT_CHUNK_TOKEN_LIMIT is true, manually check - # the token count of each split text to ensure it is - # not larger than the content_token_limit smaller_chunks = self._split_oversized_chunk( split_text, content_token_limit ) - for i, small_chunk in enumerate(smaller_chunks): - chunks.append( - _create_chunk( - text=small_chunk, - links={0: section_link_text}, - is_continuation=(i != 0), - ) + for j, small_chunk in enumerate(smaller_chunks): + self._create_chunk( + document, + chunks, + small_chunk, + {0: section_link_text}, + is_continuation=(j != 0), + title_prefix=title_prefix, + metadata_suffix_semantic=metadata_suffix_semantic, + metadata_suffix_keyword=metadata_suffix_keyword, ) else: - chunks.append( - _create_chunk( - text=split_text, - links={0: section_link_text}, - is_continuation=(i != 0), - ) + self._create_chunk( + document, + chunks, + split_text, + {0: section_link_text}, + is_continuation=(i != 0), + title_prefix=title_prefix, + metadata_suffix_semantic=metadata_suffix_semantic, + metadata_suffix_keyword=metadata_suffix_keyword, ) - continue + # If we can still fit this section into the 
current chunk, do so current_token_count = len(self.tokenizer.tokenize(chunk_text)) current_offset = len(shared_precompare_cleanup(chunk_text)) - # In the case where the whole section is shorter than a chunk, either add - # to chunk or start a new one next_section_tokens = ( len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count ) + if next_section_tokens + current_token_count <= content_token_limit: if chunk_text: chunk_text += SECTION_SEPARATOR chunk_text += section_text link_offsets[current_offset] = section_link_text else: - chunks.append(_create_chunk(chunk_text, link_offsets)) + # finalize the existing chunk + self._create_chunk( + document, + chunks, + chunk_text, + link_offsets, + False, + title_prefix, + metadata_suffix_semantic, + metadata_suffix_keyword, + ) + # start a new chunk link_offsets = {0: section_link_text} chunk_text = section_text - # Once we hit the end, if we're still in the process of building a chunk, add what we have. - # If there is only whitespace left then don't include it. If there are no chunks at all - # from the doc, we can just create a single chunk with the title. + # finalize any leftover text chunk if chunk_text.strip() or not chunks: - chunks.append( - _create_chunk( - chunk_text, - link_offsets or {0: section_link_text}, - ) + self._create_chunk( + document, + chunks, + chunk_text, + link_offsets or {0: ""}, # safe default + False, + title_prefix, + metadata_suffix_semantic, + metadata_suffix_keyword, ) - - # If the chunk does not have any useable content, it will not be indexed return chunks def _handle_single_document(self, document: Document) -> list[DocAwareChunk]: @@ -321,10 +389,12 @@ class Chunker: if document.source == DocumentSource.GMAIL: logger.debug(f"Chunking {document.semantic_identifier}") + # Title prep title = self._extract_blurb(document.get_title_for_document_index() or "") title_prefix = title + RETURN_SEPARATOR if title else "" title_tokens = len(self.tokenizer.tokenize(title_prefix)) + # Metadata prep metadata_suffix_semantic = "" metadata_suffix_keyword = "" metadata_tokens = 0 @@ -337,19 +407,20 @@ class Chunker: ) metadata_tokens = len(self.tokenizer.tokenize(metadata_suffix_semantic)) + # If metadata is too large, skip it in the semantic content if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE: - # Note: we can keep the keyword suffix even if the semantic suffix is too long to fit in the model - # context, there is no limit for the keyword component metadata_suffix_semantic = "" metadata_tokens = 0 + # Adjust content token limit to accommodate title + metadata content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens - # If there is not enough context remaining then just index the chunk with no prefix/suffix if content_token_limit <= CHUNK_MIN_CONTENT: + # Not enough space left, so revert to full chunk without the prefix content_token_limit = self.chunk_token_limit title_prefix = "" metadata_suffix_semantic = "" + # Chunk the document normal_chunks = self._chunk_document( document, title_prefix, @@ -358,6 +429,7 @@ class Chunker: content_token_limit, ) + # Optional “multipass” large chunk creation if self.enable_multipass and self.enable_large_chunks: large_chunks = generate_large_chunks(normal_chunks) normal_chunks.extend(large_chunks) @@ -371,9 +443,8 @@ class Chunker: """ final_chunks: list[DocAwareChunk] = [] for document in documents: - if self.callback: - if self.callback.should_stop(): - raise RuntimeError("Chunker.chunk: Stop signal detected") + if 
self.callback and self.callback.should_stop(): + raise RuntimeError("Chunker.chunk: Stop signal detected") chunks = self._handle_single_document(document) final_chunks.extend(chunks) diff --git a/backend/onyx/indexing/models.py b/backend/onyx/indexing/models.py index cffbdaa9bb..5dffe1b089 100644 --- a/backend/onyx/indexing/models.py +++ b/backend/onyx/indexing/models.py @@ -29,6 +29,7 @@ class BaseChunk(BaseModel): content: str # Holds the link and the offsets into the raw Chunk text source_links: dict[int, str] | None + image_file_name: str | None # True if this Chunk's start is not at the start of a Section section_continuation: bool diff --git a/backend/onyx/llm/factory.py b/backend/onyx/llm/factory.py index dff83e07f3..4c8a5f0936 100644 --- a/backend/onyx/llm/factory.py +++ b/backend/onyx/llm/factory.py @@ -6,12 +6,14 @@ from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_TEMPERATURE from onyx.db.engine import get_session_context_manager from onyx.db.llm import fetch_default_provider +from onyx.db.llm import fetch_existing_llm_providers from onyx.db.llm import fetch_provider from onyx.db.models import Persona from onyx.llm.chat_llm import DefaultMultiLLM from onyx.llm.exceptions import GenAIDisabledException from onyx.llm.interfaces import LLM from onyx.llm.override_models import LLMOverride +from onyx.llm.utils import model_supports_image_input from onyx.utils.headers import build_llm_extra_headers from onyx.utils.logger import setup_logger from onyx.utils.long_term_log import LongTermLogger @@ -86,6 +88,48 @@ def get_llms_for_persona( return _create_llm(model), _create_llm(fast_model) +def get_default_llm_with_vision( + timeout: int | None = None, + temperature: float | None = None, + additional_headers: dict[str, str] | None = None, + long_term_logger: LongTermLogger | None = None, +) -> LLM | None: + if DISABLE_GENERATIVE_AI: + raise GenAIDisabledException() + + with get_session_context_manager() as db_session: + llm_providers = fetch_existing_llm_providers(db_session) + + if not llm_providers: + return None + + for provider in llm_providers: + model_name = provider.default_model_name + fast_model_name = ( + provider.fast_default_model_name or provider.default_model_name + ) + + if not model_name or not fast_model_name: + continue + + if model_supports_image_input(model_name, provider.provider): + return get_llm( + provider=provider.provider, + model=model_name, + deployment_name=provider.deployment_name, + api_key=provider.api_key, + api_base=provider.api_base, + api_version=provider.api_version, + custom_config=provider.custom_config, + timeout=timeout, + temperature=temperature, + additional_headers=additional_headers, + long_term_logger=long_term_logger, + ) + + raise ValueError("No LLM provider found that supports image input") + + def get_default_llms( timeout: int | None = None, temperature: float | None = None, diff --git a/backend/onyx/prompts/image_analysis.py b/backend/onyx/prompts/image_analysis.py new file mode 100644 index 0000000000..290f80526b --- /dev/null +++ b/backend/onyx/prompts/image_analysis.py @@ -0,0 +1,22 @@ +# Used for creating embeddings of images for vector search +IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """ +You are an assistant for summarizing images for retrieval. +Summarize the content of the following image and be as precise as possible. +The summary will be embedded and used to retrieve the original image. 
diff --git a/backend/onyx/prompts/image_analysis.py b/backend/onyx/prompts/image_analysis.py
new file mode 100644
index 0000000000..290f80526b
--- /dev/null
+++ b/backend/onyx/prompts/image_analysis.py
@@ -0,0 +1,22 @@
+# Used for creating embeddings of images for vector search
+IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
+You are an assistant for summarizing images for retrieval.
+Summarize the content of the following image and be as precise as possible.
+The summary will be embedded and used to retrieve the original image.
+Therefore, write a concise summary of the image that is optimized for retrieval.
+"""
+
+# Prompt for generating image descriptions with filename context
+IMAGE_SUMMARIZATION_USER_PROMPT = """
+The image has the file name '{title}'.
+Describe precisely and concisely what the image shows.
+"""
+
+
+# Used for analyzing images in response to user queries at search time
+IMAGE_ANALYSIS_SYSTEM_PROMPT = (
+    "You are an AI assistant specialized in describing images.\n"
+    "You will receive a user question plus an image URL. Provide a concise textual answer.\n"
+    "Focus on aspects of the image that are relevant to the user's question.\n"
+    "Be specific and detailed about visual elements that directly address the query.\n"
+)
diff --git a/backend/onyx/seeding/load_docs.py b/backend/onyx/seeding/load_docs.py
index 38ad523451..b415c2920d 100644
--- a/backend/onyx/seeding/load_docs.py
+++ b/backend/onyx/seeding/load_docs.py
@@ -55,7 +55,11 @@ def _create_indexable_chunks(
         # The section is not really used past this point since we have already done the other processing
         # for the chunking and embedding.
         sections=[
-            Section(text=preprocessed_doc["content"], link=preprocessed_doc["url"])
+            Section(
+                text=preprocessed_doc["content"],
+                link=preprocessed_doc["url"],
+                image_file_name=None,
+            )
         ],
         source=DocumentSource.WEB,
         semantic_identifier=preprocessed_doc["title"],
@@ -93,6 +97,7 @@ def _create_indexable_chunks(
         document_sets=set(),
         boost=DEFAULT_BOOST,
         large_chunk_id=None,
+        image_file_name=None,
     )
     chunks.append(chunk)
 
diff --git a/backend/onyx/utils/error_handling.py b/backend/onyx/utils/error_handling.py
new file mode 100644
index 0000000000..c39709bd55
--- /dev/null
+++ b/backend/onyx/utils/error_handling.py
@@ -0,0 +1,23 @@
+"""
+Standardized error handling utilities.
+"""
+from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def handle_connector_error(e: Exception, context: str) -> None:
+    """
+    Standard error handling for connectors.
+
+    Args:
+        e: The exception that was raised
+        context: A description of where the error occurred
+
+    Raises:
+        The original exception if CONTINUE_ON_CONNECTOR_FAILURE is False
+    """
+    logger.error(f"Error in {context}: {e}", exc_info=e)
+    if not CONTINUE_ON_CONNECTOR_FAILURE:
+        raise
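Reviewer note: `handle_connector_error` centralizes the continue-or-raise decision that connectors previously hand-rolled. A minimal usage sketch, assuming a hypothetical per-attachment `convert` step (the helper below is illustrative only, not part of this PR):

```python
# Illustrative connector loop: per-item failures are logged and skipped when
# CONTINUE_ON_CONNECTOR_FAILURE is true, and re-raised otherwise.
from onyx.utils.error_handling import handle_connector_error


def convert(attachment: dict) -> str:
    # Stand-in for real per-attachment extraction work.
    return attachment["text"]


def process_attachments(attachments: list[dict]) -> list[str]:
    results: list[str] = []
    for attachment in attachments:
        try:
            results.append(convert(attachment))
        except Exception as e:
            handle_connector_error(e, f"processing attachment {attachment.get('id')}")
    return results
```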
diff --git a/backend/scripts/debugging/onyx_vespa.py b/backend/scripts/debugging/onyx_vespa.py
index 954072feb3..f1c7eab4a8 100644
--- a/backend/scripts/debugging/onyx_vespa.py
+++ b/backend/scripts/debugging/onyx_vespa.py
@@ -207,7 +207,7 @@ def query_vespa(
     yql: str, tenant_id: Optional[str] = None, limit: int = 10
 ) -> List[Dict[str, Any]]:
     # Perform a Vespa query using YQL syntax.
-    filters = IndexFilters(tenant_id=tenant_id, access_control_list=[])
+    filters = IndexFilters(tenant_id=None, access_control_list=[])
     filter_string = build_vespa_filters(filters, remove_trailing_and=True)
     full_yql = yql.strip()
     if filter_string:
@@ -472,9 +472,7 @@ def get_document_acls(
     print("-" * 80)
 
 
-def get_current_chunk_count(
-    document_id: str, index_name: str, tenant_id: str
-) -> int | None:
+def get_current_chunk_count(document_id: str) -> int | None:
     with get_session_with_current_tenant() as session:
         return (
             session.query(Document.chunk_count)
@@ -486,7 +484,7 @@ def get_current_chunk_count(
 def get_number_of_chunks_we_think_exist(
     document_id: str, index_name: str, tenant_id: str
 ) -> int:
-    current_chunk_count = get_current_chunk_count(document_id, index_name, tenant_id)
+    current_chunk_count = get_current_chunk_count(document_id)
     print(f"Current chunk count: {current_chunk_count}")
 
     doc_info = VespaIndex.enrich_basic_chunk_info(
@@ -636,6 +634,7 @@ def delete_where(
     Removes visited documents in `cluster` where the given selection is true,
     using Vespa's 'delete where' endpoint.
 
+    :param index_name: The index name, typically taken from your schema
     :param selection: The selection string, e.g., "true" or "foo contains 'bar'"
     :param cluster: The name of the cluster where documents reside
@@ -799,7 +798,7 @@ def main() -> None:
     args = parser.parse_args()
 
     vespa_debug = VespaDebugging(args.tenant_id)
-    CURRENT_TENANT_ID_CONTEXTVAR.set(args.tenant_id)
+    CURRENT_TENANT_ID_CONTEXTVAR.set(args.tenant_id or "public")
     if args.action == "delete-all-documents":
         if not args.tenant_id:
             parser.error("--tenant-id is required for delete-all-documents action")
diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py
index 4353fc4f4a..b94b413e28 100644
--- a/backend/scripts/query_time_check/seed_dummy_docs.py
+++ b/backend/scripts/query_time_check/seed_dummy_docs.py
@@ -71,6 +71,7 @@ def generate_dummy_chunk(
         title_embedding=generate_random_embedding(embedding_dim),
         large_chunk_id=None,
         large_chunk_reference_ids=[],
+        image_file_name=None,
     )
 
     document_set_names = []
diff --git a/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py b/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py
index 8b7a668210..162db23c5b 100644
--- a/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py
+++ b/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py
@@ -31,6 +31,7 @@ def create_test_chunk(
         metadata={},
         match_highlights=[],
         updated_at=datetime.now(),
+        image_file_name=None,
     )
 
diff --git a/backend/tests/unit/onyx/chat/conftest.py b/backend/tests/unit/onyx/chat/conftest.py
index 69b835c564..8837b1eb2e 100644
--- a/backend/tests/unit/onyx/chat/conftest.py
+++ b/backend/tests/unit/onyx/chat/conftest.py
@@ -80,6 +80,7 @@ def mock_inference_sections() -> list[InferenceSection]:
                 updated_at=datetime(2023, 1, 1),
                 source_links={0: "https://example.com/doc1"},
                 match_highlights=[],
+                image_file_name=None,
             ),
             chunks=MagicMock(),
         ),
@@ -102,6 +103,7 @@ def mock_inference_sections() -> list[InferenceSection]:
                 updated_at=datetime(2023, 1, 2),
                 source_links={0: "https://example.com/doc2"},
                 match_highlights=[],
+                image_file_name=None,
             ),
             chunks=MagicMock(),
         ),
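Reviewer note: the new `image_file_name` field had to be threaded through every `InferenceChunk` construction in these fixtures (here and in the test files below). A shared defaults factory would localize the next schema change to one place. A sketch under the assumption that each test still supplies its remaining required fields; `inference_chunk_defaults` is a suggestion, not part of this PR:

```python
# Suggested test helper: keep InferenceChunk keyword defaults in one place so
# a new field (like image_file_name in this PR) means a one-line fixture change.
from datetime import datetime
from typing import Any


def inference_chunk_defaults(**overrides: Any) -> dict[str, Any]:
    defaults: dict[str, Any] = {
        "metadata": {},
        "match_highlights": [],
        "updated_at": datetime(2023, 1, 1),
        "image_file_name": None,  # new in this PR
    }
    defaults.update(overrides)
    return defaults


# usage: InferenceChunk(document_id="test doc 1", ..., **inference_chunk_defaults())
```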
diff --git a/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py b/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py
index b0a2cb6921..1bc7cc12e0 100644
--- a/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py
+++ b/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py
@@ -150,6 +150,7 @@ def test_fuzzy_match_quotes_to_docs() -> None:
         metadata={},
         match_highlights=[],
         updated_at=None,
+        image_file_name=None,
     )
     test_chunk_1 = InferenceChunk(
         document_id="test doc 1",
@@ -168,6 +169,7 @@ def test_fuzzy_match_quotes_to_docs() -> None:
         metadata={},
         match_highlights=[],
         updated_at=None,
+        image_file_name=None,
     )
 
     test_quotes = [
diff --git a/backend/tests/unit/onyx/chat/test_prune_and_merge.py b/backend/tests/unit/onyx/chat/test_prune_and_merge.py
index bcc471748b..a4037b3d0a 100644
--- a/backend/tests/unit/onyx/chat/test_prune_and_merge.py
+++ b/backend/tests/unit/onyx/chat/test_prune_and_merge.py
@@ -37,6 +37,7 @@ def create_inference_chunk(
         metadata={},
         match_highlights=[],
         updated_at=None,
+        image_file_name=None,
     )
 
 
diff --git a/backend/tests/unit/onyx/indexing/test_embedder.py b/backend/tests/unit/onyx/indexing/test_embedder.py
index 0c7d6b43f5..7585f7a1d1 100644
--- a/backend/tests/unit/onyx/indexing/test_embedder.py
+++ b/backend/tests/unit/onyx/indexing/test_embedder.py
@@ -62,6 +62,7 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
             mini_chunk_texts=None,
             large_chunk_reference_ids=[],
             large_chunk_id=None,
+            image_file_name=None,
         )
     ]
 
diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx
index 6056a26467..3071af72a1 100644
--- a/web/src/app/chat/ChatPage.tsx
+++ b/web/src/app/chat/ChatPage.tsx
@@ -204,7 +204,6 @@ export function ChatPage({
   const [documentSidebarVisible, setDocumentSidebarVisible] = useState(false);
   const [proSearchEnabled, setProSearchEnabled] = useState(proSearchToggled);
-  const [streamingAllowed, setStreamingAllowed] = useState(false);
 
   const toggleProSearch = () => {
     Cookies.set(
       PRO_SEARCH_TOGGLED_COOKIE_NAME,
diff --git a/web/src/components/chat/TextView.tsx b/web/src/components/chat/TextView.tsx
index a70da405f4..4cd3c04206 100644
--- a/web/src/components/chat/TextView.tsx
+++ b/web/src/components/chat/TextView.tsx
@@ -41,6 +41,15 @@ export default function TextView({
     return markdownFormats.some((format) => mimeType.startsWith(format));
   };
 
+  const isImageFormat = (mimeType: string) => {
+    const imageFormats = [
+      "image/png",
+      "image/jpeg",
+      "image/gif",
+      "image/svg+xml",
+    ];
+    return imageFormats.some((format) => mimeType.startsWith(format));
+  };
 
   // Detect if a given MIME type can be rendered in an