wire off image downloading for confluence and gdrive if not enabled i… (#4305)

* wire off image downloading for confluence and gdrive if not enabled in settings

* fix partial func

* fix confluence basic test

* add test for skipping/allowing images

* review comments

* skip allow images test

* mock function using the db

* mock at the proper level

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer 2025-03-20 16:10:28 -07:00 committed by GitHub
parent 0292ca2445
commit 6d330131fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 117 additions and 25 deletions

View File

@ -114,6 +114,7 @@ class ConfluenceConnector(
self.timezone_offset = timezone_offset
self._confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
self.allow_images = False
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
@ -158,6 +159,9 @@ class ConfluenceConnector(
"max_backoff_seconds": 60,
}
def set_allow_images(self, value: bool) -> None:
self.allow_images = value
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
@ -233,7 +237,9 @@ class ConfluenceConnector(
# Extract basic page information
page_id = page["id"]
page_title = page["title"]
page_url = f"{self.wiki_base}{page['_links']['webui']}"
page_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
# Get the page content
page_content = extract_text_from_confluence_html(
@ -264,6 +270,7 @@ class ConfluenceConnector(
self.confluence_client,
attachment,
page_id,
self.allow_images,
)
if result and result.text:
@ -304,13 +311,14 @@ class ConfluenceConnector(
if "version" in page and "by" in page["version"]:
author = page["version"]["by"]
display_name = author.get("displayName", "Unknown")
primary_owners.append(BasicExpertInfo(display_name=display_name))
email = author.get("email", "unknown@domain.invalid")
primary_owners.append(
BasicExpertInfo(display_name=display_name, email=email)
)
# Create the document
return Document(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
id=page_url,
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
@ -373,6 +381,7 @@ class ConfluenceConnector(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
allow_images=self.allow_images,
)
if response is None:
continue

View File

@ -112,6 +112,7 @@ def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None,
allow_images: bool,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
@ -119,7 +120,7 @@ def process_attachment(
"""
try:
# Get the media type from the attachment metadata
media_type = attachment.get("metadata", {}).get("mediaType", "")
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
# Validate the attachment type
if not validate_attachment_filetype(attachment):
return AttachmentProcessingResult(
@ -138,7 +139,14 @@ def process_attachment(
attachment_size = attachment["extensions"]["fileSize"]
if not media_type.startswith("image/"):
if media_type.startswith("image/"):
if not allow_images:
return AttachmentProcessingResult(
text=None,
file_name=None,
error="Image downloading is not enabled",
)
else:
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {attachment_link} due to size. "
@ -294,6 +302,7 @@ def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_id: str,
allow_images: bool,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
@ -309,7 +318,7 @@ def convert_attachment_to_content(
)
return None
result = process_attachment(confluence_client, attachment, page_id)
result = process_attachment(confluence_client, attachment, page_id, allow_images)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"

View File

@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
@ -184,6 +185,8 @@ def instantiate_connector(
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
connector.set_allow_images(get_image_extraction_and_analysis_enabled())
return connector

View File

@ -86,6 +86,7 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
@ -101,6 +102,7 @@ def _convert_single_file(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
allow_images=allow_images,
)
@ -234,6 +236,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._retrieved_ids: set[str] = set()
self.allow_images = False
def set_allow_images(self, value: bool) -> None:
self.allow_images = value
@property
def primary_admin_email(self) -> str:
@ -900,6 +906,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
)
# Fetch files in batches

View File

@ -79,6 +79,7 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
def _extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
allow_images: bool,
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
@ -87,6 +88,10 @@ def _extract_sections_basic(
link = file.get("webViewLink", "")
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
@ -207,6 +212,7 @@ def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: Callable[[], GoogleDriveService],
docs_service: Callable[[], GoogleDocsService],
allow_images: bool,
) -> Document | ConnectorFailure | None:
"""
Main entry point for converting a Google Drive file => Document object.
@ -236,7 +242,7 @@ def convert_drive_item_to_document(
# If we don't have sections yet, use the basic extraction method
if not sections:
sections = _extract_sections_basic(file, drive_service())
sections = _extract_sections_basic(file, drive_service(), allow_images)
# If we still don't have any sections, skip this file
if not sections:

View File

@ -60,6 +60,10 @@ class BaseConnector(abc.ABC, Generic[CT]):
Default is a no-op (always successful).
"""
def set_allow_images(self, value: bool) -> None:
"""Implement if the underlying connector wants to skip/allow image downloading
based on the application level image analysis setting."""
def build_dummy_checkpoint(self) -> CT:
# TODO: find a way to make this work without type: ignore
return ConnectorCheckpoint(has_more=True) # type: ignore

View File

@ -1,5 +1,6 @@
import os
import time
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
@ -7,15 +8,16 @@ import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.utils import AttachmentProcessingResult
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from onyx.connectors.models import Document
@pytest.fixture
def confluence_connector() -> ConfluenceConnector:
def confluence_connector(space: str) -> ConfluenceConnector:
connector = ConfluenceConnector(
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
space=os.environ["CONFLUENCE_TEST_SPACE"],
space=space,
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
)
@ -32,14 +34,15 @@ def confluence_connector() -> ConfluenceConnector:
return connector
@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_basic(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
confluence_connector.set_allow_images(False)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
@ -50,15 +53,14 @@ def test_confluence_connector_basic(
page_within_a_page_doc: Document | None = None
page_doc: Document | None = None
txt_doc: Document | None = None
for doc in doc_batch:
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
page_doc = doc
elif ".txt" in doc.semantic_identifier:
txt_doc = doc
elif doc.semantic_identifier == "Page Within A Page":
page_within_a_page_doc = doc
else:
pass
assert page_within_a_page_doc is not None
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
@ -79,7 +81,7 @@ def test_confluence_connector_basic(
assert page_doc.metadata["labels"] == ["testlabel"]
assert page_doc.primary_owners
assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
assert len(page_doc.sections) == 1
assert len(page_doc.sections) == 2 # page text + attachment text
page_section = page_doc.sections[0]
assert page_section.text == "test123 " + page_within_a_page_text
@ -88,13 +90,65 @@ def test_confluence_connector_basic(
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
)
assert txt_doc is not None
assert txt_doc.semantic_identifier == "small-file.txt"
assert len(txt_doc.sections) == 1
assert txt_doc.sections[0].text == "small"
assert txt_doc.primary_owners
assert txt_doc.primary_owners[0].email == "chris@onyx.app"
assert (
txt_doc.sections[0].link
== "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
text_attachment_section = page_doc.sections[1]
assert text_attachment_section.text == "small"
assert text_attachment_section.link
assert text_attachment_section.link.endswith("small-file.txt")
@pytest.mark.parametrize("space", ["MI"])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_confluence_connector_skip_images(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
confluence_connector.set_allow_images(False)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 8
assert sum(len(doc.sections) for doc in doc_batch) == 8
def mock_process_image_attachment(
*args: Any, **kwargs: Any
) -> AttachmentProcessingResult:
"""We need this mock to bypass DB access happening in the connector. Which shouldn't
be done as a rule to begin with, but life is not perfect. Fix it later"""
return AttachmentProcessingResult(
text="Hi_text",
file_name="Hi_filename",
error=None,
)
@pytest.mark.parametrize("space", ["MI"])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
@patch(
"onyx.connectors.confluence.utils._process_image_attachment",
side_effect=mock_process_image_attachment,
)
def test_confluence_connector_allow_images(
mock_get_api_key: MagicMock,
mock_process_image_attachment: MagicMock,
confluence_connector: ConfluenceConnector,
) -> None:
confluence_connector.set_allow_images(True)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 8
assert sum(len(doc.sections) for doc in doc_batch) == 12