checkpointed confluence (#4473)

* checkpointed confluence

* confluence checkpointing tested

* fixed integration tests

* attempt to fix connector test flakiness

* fix rebase
evan-danswer 2025-04-14 16:59:53 -07:00 committed by GitHub
parent 742041d97a
commit ae9f8c3071
19 changed files with 580 additions and 126 deletions
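Note (orientation only, not part of the diff): this PR renames CheckpointConnector to CheckpointedConnector and moves the Confluence connector onto the checkpointed interface, where load_from_checkpoint yields Document or ConnectorFailure items and returns the next checkpoint as the generator's return value. A minimal driver loop under those assumptions (the helper name below is hypothetical, not from this PR) could look like:

    from onyx.connectors.interfaces import CheckpointedConnector
    from onyx.connectors.interfaces import SecondsSinceUnixEpoch
    from onyx.connectors.models import Document

    def drain_checkpointed_connector(
        connector: CheckpointedConnector,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
    ) -> list[Document]:
        """Collect all Documents by resuming from each returned checkpoint until has_more is False."""
        checkpoint = connector.build_dummy_checkpoint()
        docs: list[Document] = []
        while checkpoint.has_more:
            generator = connector.load_from_checkpoint(start, end, checkpoint)
            while True:
                try:
                    item = next(generator)
                except StopIteration as stop:
                    # the generator's return value is the next checkpoint to persist
                    checkpoint = stop.value
                    break
                if isinstance(item, Document):
                    docs.append(item)
                # ConnectorFailure items would be collected or logged here
        return docs

The test helpers touched in this PR (load_all_docs_from_checkpoint_connector, load_everything_from_checkpoint_connector) play this driver role in the updated tests.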

View File

@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.db.engine import get_db_current_time
from onyx.db.index_attempt import get_index_attempt
@ -61,7 +61,7 @@ def load_checkpoint(
try:
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
if isinstance(connector, CheckpointConnector):
if isinstance(connector, CheckpointedConnector):
return connector.validate_checkpoint_json(checkpoint_data)
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
except RuntimeError:

View File

@ -1,3 +1,4 @@
import copy
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@ -5,6 +6,7 @@ from typing import Any
from urllib.parse import quote
from requests.exceptions import HTTPError
from typing_extensions import override
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
@ -22,17 +24,19 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
@ -68,9 +72,16 @@ _SLIM_DOC_BATCH_SIZE = 5000
ONE_HOUR = 3600
def _should_propagate_error(e: Exception) -> bool:
return "field 'updated' is invalid" in str(e)
class ConfluenceCheckpoint(ConnectorCheckpoint):
last_updated: SecondsSinceUnixEpoch
class ConfluenceConnector(
LoadConnector,
PollConnector,
CheckpointedConnector[ConfluenceCheckpoint],
SlimConnector,
CredentialsConnector,
):
@ -211,6 +222,8 @@ class ConfluenceConnector(
"%Y-%m-%d %H:%M"
)
page_query += f" and lastmodified <= '{formatted_end_time}'"
page_query += " order by lastmodified asc"
return page_query
def _construct_attachment_query(self, confluence_page_id: str) -> str:
@ -236,11 +249,14 @@ class ConfluenceConnector(
)
return comment_string
def _convert_page_to_document(self, page: dict[str, Any]) -> Document | None:
def _convert_page_to_document(
self, page: dict[str, Any]
) -> Document | ConnectorFailure:
"""
Converts a Confluence page to a Document object.
Includes the page content, comments, and attachments.
"""
page_id = page_url = ""
try:
# Extract basic page information
page_id = page["id"]
@ -336,15 +352,90 @@ class ConfluenceConnector(
)
except Exception as e:
logger.error(f"Error converting page {page.get('id', 'unknown')}: {e}")
if not self.continue_on_failure:
if _should_propagate_error(e):
raise
return None
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=page_id,
document_link=page_url,
),
failure_message=f"Error converting page {page.get('id', 'unknown')}: {e}",
exception=e,
)
def _fetch_page_attachments(
self, page: dict[str, Any], doc: Document
) -> Document | ConnectorFailure:
attachment_query = self._construct_attachment_query(page["id"])
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment,
):
logger.info(f"Skipping attachment: {attachment['title']}")
continue
logger.info(f"Processing attachment: {attachment['title']}")
# Attempt to get textual content or image summarization:
object_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
try:
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
allow_images=self.allow_images,
)
if response is None:
continue
content_text, file_storage_name = response
if content_text:
doc.sections.append(
TextSection(
text=content_text,
link=object_url,
)
)
elif file_storage_name:
doc.sections.append(
ImageSection(
link=object_url,
image_file_name=file_storage_name,
)
)
except Exception as e:
logger.error(
f"Failed to extract/summarize attachment {attachment['title']}",
exc_info=e,
)
if not self.continue_on_failure:
if _should_propagate_error(e):
raise
# TODO: should we remove continue_on_failure entirely now that we have checkpointing?
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc.id,
document_link=object_url,
),
failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {doc.id}",
exception=e,
)
return doc
def _fetch_document_batches(
self,
checkpoint: ConfluenceCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
) -> CheckpointOutput[ConfluenceCheckpoint]:
"""
Yields batches of Documents. For each page:
- Create a Document with 1 Section for the page text/comments
@ -352,9 +443,12 @@ class ConfluenceConnector(
- Attempt to convert it with convert_attachment_to_content(...)
- If successful, create a new Section with the extracted text or summary.
"""
doc_batch: list[Document] = []
doc_count = 0
page_query = self._construct_page_query(start, end)
checkpoint = copy.deepcopy(checkpoint)
# use "start" when last_updated is 0
page_query = self._construct_page_query(checkpoint.last_updated or start, end)
logger.debug(f"page_query: {page_query}")
for page in self.confluence_client.paginated_cql_retrieval(
@ -363,94 +457,61 @@ class ConfluenceConnector(
limit=self.batch_size,
):
# Build doc from page
doc = self._convert_page_to_document(page)
if not doc:
doc_or_failure = self._convert_page_to_document(page)
if isinstance(doc_or_failure, ConnectorFailure):
yield doc_or_failure
continue
checkpoint.last_updated = datetime_from_string(
page["version"]["when"]
).timestamp()
# Now get attachments for that page:
attachment_query = self._construct_attachment_query(page["id"])
# We'll use the page's XML to provide context if we summarize an image
page.get("body", {}).get("storage", {}).get("value", "")
doc_or_failure = self._fetch_page_attachments(page, doc_or_failure)
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment,
):
logger.info(f"Skipping attachment: {attachment['title']}")
continue
if isinstance(doc_or_failure, ConnectorFailure):
yield doc_or_failure
continue
logger.info(f"Processing attachment: {attachment['title']}")
# yield completed document
doc_count += 1
yield doc_or_failure
# Attempt to get textual content or image summarization:
try:
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
allow_images=self.allow_images,
)
if response is None:
continue
# create checkpoint after enough documents have been processed
if doc_count >= self.batch_size:
return checkpoint
content_text, file_storage_name = response
object_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
if content_text:
doc.sections.append(
TextSection(
text=content_text,
link=object_url,
)
)
elif file_storage_name:
doc.sections.append(
ImageSection(
link=object_url,
image_file_name=file_storage_name,
)
)
except Exception as e:
logger.error(
f"Failed to extract/summarize attachment {attachment['title']}",
exc_info=e,
)
if not self.continue_on_failure:
raise
checkpoint.has_more = False
return checkpoint
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_document_batches()
def poll_source(
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConfluenceCheckpoint,
) -> CheckpointOutput[ConfluenceCheckpoint]:
try:
return self._fetch_document_batches(start, end)
return self._fetch_document_batches(checkpoint, start, end)
except Exception as e:
if "field 'updated' is invalid" in str(e) and start is not None:
if _should_propagate_error(e) and start is not None:
logger.warning(
"Confluence says we provided an invalid 'updated' field. This may indicate"
"a real issue, but can also appear during edge cases like daylight"
f"savings time changes. Retrying with a 1 hour offset. Error: {e}"
)
return self._fetch_document_batches(start - ONE_HOUR, end)
return self._fetch_document_batches(checkpoint, start - ONE_HOUR, end)
raise
@override
def build_dummy_checkpoint(self) -> ConfluenceCheckpoint:
return ConfluenceCheckpoint(last_updated=0, has_more=True)
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> ConfluenceCheckpoint:
return ConfluenceCheckpoint.model_validate_json(checkpoint_json)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,

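Context (not part of the diff): ConfluenceCheckpoint carries a single last_updated cursor. Resuming folds that cursor into the CQL time window, the new "order by lastmodified asc" clause keeps results oldest-first so the cursor only moves forward, and a checkpoint is returned once batch_size documents have been emitted. A rough reconstruction of the resume-time query shape, simplified and not the connector's exact code (the real method starts from the connector's configured base CQL and applies its timezone offset):

    from datetime import datetime
    from datetime import timezone

    def build_resume_cql(space: str, last_updated: float, end: float) -> str:
        # base "type=page" clause assumed here for illustration
        fmt = "%Y-%m-%d %H:%M"
        start_str = datetime.fromtimestamp(last_updated, tz=timezone.utc).strftime(fmt)
        end_str = datetime.fromtimestamp(end, tz=timezone.utc).strftime(fmt)
        return (
            f"type=page and space='{space}'"
            f" and lastmodified >= '{start_str}'"
            f" and lastmodified <= '{end_str}'"
            " order by lastmodified asc"
        )

    # build_resume_cql("TEST", 1672574400.0, 1672617600.0) ->
    # "type=page and space='TEST' and lastmodified >= '2023-01-01 12:00'
    #  and lastmodified <= '2023-01-02 00:00' order by lastmodified asc"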
View File

@ -6,7 +6,7 @@ from typing import Generic
from typing import TypeVar
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@ -95,9 +95,9 @@ class ConnectorRunner(Generic[CT]):
]:
"""Adds additional exception logging to the connector."""
try:
if isinstance(self.connector, CheckpointConnector):
if isinstance(self.connector, CheckpointedConnector):
if self.time_range is None:
raise ValueError("time_range is required for CheckpointConnector")
raise ValueError("time_range is required for CheckpointedConnector")
start = time.monotonic()
checkpoint_connector_generator = self.connector.load_from_checkpoint(

View File

@ -34,7 +34,7 @@ from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.highspot.connector import HighspotConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
@ -148,7 +148,7 @@ def identify_connector_class(
# all connectors should be checkpoint connectors
and (
not issubclass(connector, PollConnector)
and not issubclass(connector, CheckpointConnector)
and not issubclass(connector, CheckpointedConnector)
)
),
(

View File

@ -25,7 +25,7 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
@ -143,7 +143,7 @@ class GithubConnectorCheckpoint(ConnectorCheckpoint):
cached_repo: SerializedRepository | None = None
class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
def __init__(
self,
repo_owner: str,

View File

@ -52,7 +52,7 @@ from onyx.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_S
from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from onyx.connectors.google_utils.shared_constants import USER_FIELDS
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@ -165,7 +165,7 @@ class DriveIdStatus(Enum):
FINISHED = "finished"
class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpoint]):
class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheckpoint]):
def __init__(
self,
include_shared_drives: bool = False,

View File

@ -201,7 +201,7 @@ class EventConnector(BaseConnector):
CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
class CheckpointConnector(BaseConnector[CT]):
class CheckpointedConnector(BaseConnector[CT]):
@abc.abstractmethod
def load_from_checkpoint(
self,

View File

@ -4,7 +4,7 @@ import httpx
from pydantic import BaseModel
from typing_extensions import override
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
@ -27,7 +27,7 @@ class SingleConnectorYield(BaseModel):
unhandled_exception: str | None = None
class MockConnector(CheckpointConnector[MockConnectorCheckpoint]):
class MockConnector(CheckpointedConnector[MockConnectorCheckpoint]):
def __init__(
self,
mock_server_host: str,

View File

@ -16,7 +16,7 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@ -150,7 +150,7 @@ class JiraConnectorCheckpoint(ConnectorCheckpoint):
offset: int | None = None
class JiraConnector(CheckpointConnector[JiraConnectorCheckpoint], SlimConnector):
class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnector):
def __init__(
self,
jira_base_url: str,

View File

@ -26,7 +26,7 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
@ -501,7 +501,7 @@ def _process_message(
class SlackConnector(
SlimConnector, CredentialsConnector, CheckpointConnector[SlackCheckpoint]
SlimConnector, CredentialsConnector, CheckpointedConnector[SlackCheckpoint]
):
FAST_TIMEOUT = 1

View File

@ -401,6 +401,7 @@ class WebConnector(LoadConnector):
mintlify_cleanup: bool = True, # Mostly ok to apply to other websites as well
batch_size: int = INDEX_BATCH_SIZE,
scroll_before_scraping: bool = False,
add_randomness: bool = True,
**kwargs: Any,
) -> None:
self.mintlify_cleanup = mintlify_cleanup
@ -408,7 +409,7 @@ class WebConnector(LoadConnector):
self.recursive = False
self.scroll_before_scraping = scroll_before_scraping
self.web_connector_type = web_connector_type
self.add_randomness = add_randomness
if web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value:
self.recursive = True
self.to_visit_list = [_ensure_valid_url(base_url)]
@ -540,8 +541,11 @@ class WebConnector(LoadConnector):
page = context.new_page()
# Add random mouse movements and scrolling to mimic human behavior
page.mouse.move(random.randint(100, 700), random.randint(100, 500))
if self.add_randomness:
# Add random mouse movements and scrolling to mimic human behavior
page.mouse.move(
random.randint(100, 700), random.randint(100, 500)
)
# Can't use wait_until="networkidle" because it interferes with the scrolling behavior
page_response = page.goto(

View File

@ -17,7 +17,7 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
@ -353,7 +353,9 @@ class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
cached_content_tags: dict[str, str] | None
class ZendeskConnector(SlimConnector, CheckpointConnector[ZendeskConnectorCheckpoint]):
class ZendeskConnector(
SlimConnector, CheckpointedConnector[ZendeskConnectorCheckpoint]
):
def __init__(
self,
content_type: str = "articles",

View File

@ -11,6 +11,7 @@ from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.utils import AttachmentProcessingResult
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from onyx.connectors.models import Document
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
@pytest.fixture
@ -43,11 +44,9 @@ def test_confluence_connector_basic(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
confluence_connector.set_allow_images(False)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
doc_batch = load_all_docs_from_checkpoint_connector(
confluence_connector, 0, time.time()
)
assert len(doc_batch) == 2
@ -105,11 +104,9 @@ def test_confluence_connector_skip_images(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
confluence_connector.set_allow_images(False)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
doc_batch = load_all_docs_from_checkpoint_connector(
confluence_connector, 0, time.time()
)
assert len(doc_batch) == 8
assert sum(len(doc.sections) for doc in doc_batch) == 8
@ -144,11 +141,9 @@ def test_confluence_connector_allow_images(
) -> None:
confluence_connector.set_allow_images(True)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
doc_batch = load_all_docs_from_checkpoint_connector(
confluence_connector, 0, time.time()
)
assert len(doc_batch) == 8
assert sum(len(doc.sections) for doc in doc_batch) == 12
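Not part of the diff: the three updated daily tests above share the same pattern; a representative (hypothetical) test body using the new helper looks like this, with confluence_connector being the fixture defined earlier in that file:

    import time

    from onyx.connectors.confluence.connector import ConfluenceConnector
    from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector

    def test_confluence_connector_example(confluence_connector: ConfluenceConnector) -> None:
        # hypothetical test name, for illustration only
        confluence_connector.set_allow_images(False)
        docs = load_all_docs_from_checkpoint_connector(confluence_connector, 0, time.time())
        assert len(docs) > 0
        assert all(doc.id for doc in docs)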

View File

@ -1,10 +1,12 @@
import os
import time
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
@pytest.fixture
@ -34,8 +36,10 @@ def test_confluence_connector_permissions(
) -> None:
# Get all doc IDs from the full connector
all_full_doc_ids = set()
for doc_batch in confluence_connector.load_from_state():
all_full_doc_ids.update([doc.id for doc in doc_batch])
doc_batch = load_all_docs_from_checkpoint_connector(
confluence_connector, 0, time.time()
)
all_full_doc_ids.update([doc.id for doc in doc_batch])
# Get all doc IDs from the slim connector
all_slim_doc_ids = set()

View File

@ -2,7 +2,7 @@ from typing import cast
from typing import TypeVar
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@ -14,7 +14,7 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
def load_all_docs_from_checkpoint_connector(
connector: CheckpointConnector[CT],
connector: CheckpointedConnector[CT],
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> list[Document]:
@ -42,7 +42,7 @@ def load_all_docs_from_checkpoint_connector(
def load_everything_from_checkpoint_connector(
connector: CheckpointConnector[CT],
connector: CheckpointedConnector[CT],
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> list[Document | ConnectorFailure]:

View File

@ -18,6 +18,7 @@ def web_connector(request: pytest.FixtureRequest) -> WebConnector:
base_url="https://quotes.toscrape.com/scroll",
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
scroll_before_scraping=scroll_before_scraping,
add_randomness=False,
)
return connector

View File

@ -2,6 +2,7 @@ import os
from datetime import datetime
from datetime import timezone
from onyx.connectors.models import InputType
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.user import UserManager
@ -56,6 +57,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
connector_specific_config=config,
credential_json=credential,
user_performing_action=admin_user,
input_type=InputType.POLL,
)
CCPairManager.wait_for_indexing_completion(
@ -69,6 +71,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
connector_specific_config=config,
credential_json=credential,
user_performing_action=admin_user,
input_type=InputType.POLL,
)
CCPairManager.wait_for_indexing_completion(
@ -115,6 +118,7 @@ def test_connector_pause_while_indexing(reset: None) -> None:
connector_specific_config=config,
credential_json=credential,
user_performing_action=admin_user,
input_type=InputType.POLL,
)
CCPairManager.wait_for_indexing_in_progress(

View File

@ -0,0 +1,383 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from requests.exceptions import HTTPError
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceCheckpoint
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import SlimDocument
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
PAGE_SIZE = 2
@pytest.fixture
def confluence_base_url() -> str:
return "https://example.atlassian.net/wiki"
@pytest.fixture
def space_key() -> str:
return "TEST"
@pytest.fixture
def mock_confluence_client() -> MagicMock:
"""Create a mock Confluence client with proper typing"""
mock = MagicMock(spec=OnyxConfluence)
# Initialize with empty results for common methods
mock.paginated_cql_retrieval.return_value = []
mock.get_all_spaces = MagicMock()
mock.get_all_spaces.return_value = {"results": []}
return mock
@pytest.fixture
def confluence_connector(
confluence_base_url: str, space_key: str, mock_confluence_client: MagicMock
) -> Generator[ConfluenceConnector, None, None]:
"""Create a Confluence connector with a mock client"""
connector = ConfluenceConnector(
wiki_base=confluence_base_url,
space=space_key,
is_cloud=True,
labels_to_skip=["secret", "sensitive"],
timezone_offset=0.0,
batch_size=2,
)
# Initialize the client directly
connector._confluence_client = mock_confluence_client
with patch("onyx.connectors.confluence.connector._SLIM_DOC_BATCH_SIZE", 2):
yield connector
@pytest.fixture
def create_mock_page() -> Callable[..., dict[str, Any]]:
def _create_mock_page(
id: str = "123",
title: str = "Test Page",
updated: str = "2023-01-01T12:00:00.000+0000",
content: str = "Test Content",
labels: list[str] | None = None,
) -> dict[str, Any]:
"""Helper to create a mock Confluence page object"""
return {
"id": id,
"title": title,
"version": {"when": updated},
"body": {"storage": {"value": content}},
"metadata": {
"labels": {"results": [{"name": label} for label in (labels or [])]}
},
"space": {"key": "TEST"},
"_links": {"webui": f"/spaces/TEST/pages/{id}"},
}
return _create_mock_page
def test_get_cql_query_with_space(confluence_connector: ConfluenceConnector) -> None:
"""Test CQL query generation with space specified"""
start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp()
query = confluence_connector._construct_page_query(start, end)
# Check that the space part and time part are both in the query
assert f"space='{confluence_connector.space}'" in query
assert "lastmodified >= '2023-01-01 00:00'" in query
assert "lastmodified <= '2023-01-02 00:00'" in query
assert " and " in query.lower()
def test_get_cql_query_without_space(confluence_base_url: str) -> None:
"""Test CQL query generation without space specified"""
# Create connector without space key
connector = ConfluenceConnector(wiki_base=confluence_base_url, is_cloud=True)
start = datetime(2023, 1, 1, tzinfo=connector.timezone).timestamp()
end = datetime(2023, 1, 2, tzinfo=connector.timezone).timestamp()
query = connector._construct_page_query(start, end)
# Check that only time part is in the query
assert "space=" not in query
assert "lastmodified >= '2023-01-01 00:00'" in query
assert "lastmodified <= '2023-01-02 00:00'" in query
def test_load_from_checkpoint_happy_path(
confluence_connector: ConfluenceConnector,
create_mock_page: Callable[..., dict[str, Any]],
) -> None:
"""Test loading from checkpoint - happy path"""
# Set up mocked pages
first_updated = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc)
last_updated = datetime(2023, 1, 3, 12, 0, tzinfo=timezone.utc)
mock_page1 = create_mock_page(
id="1", title="Page 1", updated=first_updated.isoformat()
)
mock_page2 = create_mock_page(
id="2", title="Page 2", updated=first_updated.isoformat()
)
mock_page3 = create_mock_page(
id="3", title="Page 3", updated=last_updated.isoformat()
)
# Mock paginated_cql_retrieval to return our mock pages
confluence_client = confluence_connector._confluence_client
assert confluence_client is not None, "bad test setup"
paginated_cql_mock = cast(MagicMock, confluence_client.paginated_cql_retrieval)
paginated_cql_mock.side_effect = [
[mock_page1, mock_page2],
[], # comments
[], # attachments
[], # comments
[], # attachments
[mock_page3],
[], # comments
[], # attachments
]
# Call load_from_checkpoint
end_time = time.time()
outputs = load_everything_from_checkpoint_connector(
confluence_connector, 0, end_time
)
# Check that the documents were returned
assert len(outputs) == 2
checkpoint_output1 = outputs[0]
assert len(checkpoint_output1.items) == 2
document1 = checkpoint_output1.items[0]
assert isinstance(document1, Document)
assert document1.id == f"{confluence_connector.wiki_base}/spaces/TEST/pages/1"
document2 = checkpoint_output1.items[1]
assert isinstance(document2, Document)
assert document2.id == f"{confluence_connector.wiki_base}/spaces/TEST/pages/2"
assert checkpoint_output1.next_checkpoint == ConfluenceCheckpoint(
last_updated=first_updated.timestamp(),
has_more=True,
)
checkpoint_output2 = outputs[1]
assert len(checkpoint_output2.items) == 1
document3 = checkpoint_output2.items[0]
assert isinstance(document3, Document)
assert document3.id == f"{confluence_connector.wiki_base}/spaces/TEST/pages/3"
assert checkpoint_output2.next_checkpoint == ConfluenceCheckpoint(
last_updated=last_updated.timestamp(),
has_more=False,
)
def test_load_from_checkpoint_with_page_processing_error(
confluence_connector: ConfluenceConnector,
create_mock_page: Callable[..., dict[str, Any]],
) -> None:
"""Test loading from checkpoint with a mix of successful and failed page processing"""
# Set up mocked pages
mock_page1 = create_mock_page(id="1", title="Page 1")
mock_page2 = create_mock_page(id="2", title="Page 2")
# Mock paginated_cql_retrieval to return our mock pages
confluence_client = confluence_connector._confluence_client
assert confluence_client is not None, "bad test setup"
paginated_cql_mock = cast(MagicMock, confluence_client.paginated_cql_retrieval)
paginated_cql_mock.return_value = [mock_page1, mock_page2]
# Mock _convert_page_to_document to fail for the second page
def mock_convert_side_effect(page: dict[str, Any]) -> Document | ConnectorFailure:
if page["id"] == "1":
return Document(
id=f"{confluence_connector.wiki_base}/spaces/TEST/pages/1",
sections=[],
source=DocumentSource.CONFLUENCE,
semantic_identifier="Page 1",
metadata={},
)
else:
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=page["id"],
document_link=f"{confluence_connector.wiki_base}/spaces/TEST/pages/{page['id']}",
),
failure_message="Failed to process Confluence page",
exception=Exception("Test error"),
)
with patch(
"onyx.connectors.confluence.connector.ConfluenceConnector._convert_page_to_document",
side_effect=mock_convert_side_effect,
):
# Call load_from_checkpoint
end_time = time.time()
outputs = load_everything_from_checkpoint_connector(
confluence_connector, 0, end_time
)
assert len(outputs) == 1
checkpoint_output = outputs[0]
assert len(checkpoint_output.items) == 2
# First item should be successful
assert isinstance(checkpoint_output.items[0], Document)
assert (
checkpoint_output.items[0].id
== f"{confluence_connector.wiki_base}/spaces/TEST/pages/1"
)
# Second item should be a failure
assert isinstance(checkpoint_output.items[1], ConnectorFailure)
assert (
"Failed to process Confluence page"
in checkpoint_output.items[1].failure_message
)
def test_retrieve_all_slim_documents(
confluence_connector: ConfluenceConnector,
create_mock_page: Callable[..., dict[str, Any]],
) -> None:
"""Test retrieving all slim documents"""
# Set up mocked pages
mock_page1 = create_mock_page(id="1")
mock_page2 = create_mock_page(id="2")
# Mock paginated_cql_retrieval to return our mock pages
confluence_client = confluence_connector._confluence_client
assert confluence_client is not None, "bad test setup"
paginated_cql_mock = cast(MagicMock, confluence_client.cql_paginate_all_expansions)
paginated_cql_mock.side_effect = [[mock_page1, mock_page2], [], []]
# Call retrieve_all_slim_documents
batches = list(confluence_connector.retrieve_all_slim_documents(0, 100))
assert paginated_cql_mock.call_count == 3
# Check that a batch with 2 documents was returned
assert len(batches) == 1
assert len(batches[0]) == 2
assert isinstance(batches[0][0], SlimDocument)
assert batches[0][0].id == f"{confluence_connector.wiki_base}/spaces/TEST/pages/1"
assert batches[0][1].id == f"{confluence_connector.wiki_base}/spaces/TEST/pages/2"
@pytest.mark.parametrize(
"status_code,expected_exception,expected_message",
[
(
401,
CredentialExpiredError,
"Invalid or expired Confluence credentials",
),
(
403,
InsufficientPermissionsError,
"Insufficient permissions to access Confluence resources",
),
(404, UnexpectedValidationError, "Unexpected Confluence error"),
],
)
def test_validate_connector_settings_errors(
confluence_connector: ConfluenceConnector,
status_code: int,
expected_exception: type[Exception],
expected_message: str,
) -> None:
"""Test validation with various error scenarios"""
error = HTTPError(response=MagicMock(status_code=status_code))
confluence_client = MagicMock()
confluence_connector._low_timeout_confluence_client = confluence_client
get_all_spaces_mock = cast(MagicMock, confluence_client.get_all_spaces)
get_all_spaces_mock.side_effect = error
with pytest.raises(expected_exception) as excinfo:
confluence_connector.validate_connector_settings()
assert expected_message in str(excinfo.value)
def test_validate_connector_settings_success(
confluence_connector: ConfluenceConnector,
) -> None:
"""Test successful validation"""
confluence_client = MagicMock()
confluence_connector._low_timeout_confluence_client = confluence_client
get_all_spaces_mock = cast(MagicMock, confluence_client.get_all_spaces)
get_all_spaces_mock.return_value = {"results": [{"key": "TEST"}]}
confluence_connector.validate_connector_settings()
get_all_spaces_mock.assert_called_once()
def test_checkpoint_progress(
confluence_connector: ConfluenceConnector,
create_mock_page: Callable[..., dict[str, Any]],
) -> None:
"""Test that the checkpoint's last_updated field is properly updated after processing pages"""
# Set up mocked pages with different timestamps
earlier_timestamp = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc)
later_timestamp = datetime(2023, 1, 2, 12, 0, tzinfo=timezone.utc)
mock_page1 = create_mock_page(
id="1", title="Page 1", updated=earlier_timestamp.isoformat()
)
mock_page2 = create_mock_page(
id="2", title="Page 2", updated=later_timestamp.isoformat()
)
# Mock paginated_cql_retrieval to return our mock pages
confluence_client = confluence_connector._confluence_client
assert confluence_client is not None, "bad test setup"
paginated_cql_mock = cast(MagicMock, confluence_client.paginated_cql_retrieval)
paginated_cql_mock.side_effect = [
[mock_page1, mock_page2], # Return both pages
[], # No comments for page 1
[], # No attachments for page 1
[], # No comments for page 2
[], # No attachments for page 2
[], # No more pages
]
# Call load_from_checkpoint
end_time = datetime(2023, 1, 3, tzinfo=timezone.utc).timestamp()
outputs = load_everything_from_checkpoint_connector(
confluence_connector, 0, end_time
)
last_checkpoint = outputs[-1].next_checkpoint
assert last_checkpoint == ConfluenceCheckpoint(
last_updated=later_timestamp.timestamp(),
has_more=False,
)
# Convert the expected timestamp to epoch seconds
expected_timestamp = datetime(2023, 1, 2, 12, 0, tzinfo=timezone.utc).timestamp()
# The checkpoint's last_updated should be set to the latest page's timestamp
assert last_checkpoint.last_updated == expected_timestamp
assert not last_checkpoint.has_more # No more pages to process
assert len(outputs) == 2
# Verify we got both documents
assert len(outputs[0].items) == 2
assert isinstance(outputs[0].items[0], Document)
assert isinstance(outputs[0].items[1], Document)

View File

@ -5,7 +5,7 @@ from typing import TypeVar
from pydantic import BaseModel
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@ -23,7 +23,7 @@ class SingleConnectorCallOutput(BaseModel, Generic[CT]):
def load_everything_from_checkpoint_connector(
connector: CheckpointConnector[CT],
connector: CheckpointedConnector[CT],
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> list[SingleConnectorCallOutput[CT]]: