From 4c6f4e715816620df3b7b461fd1a06a2a4016488 Mon Sep 17 00:00:00 2001 From: Weves Date: Sat, 15 Mar 2025 13:54:49 -0700 Subject: [PATCH] Checkpointed Jira connector --- .../onyx/connectors/onyx_jira/connector.py | 294 ++++++------ .../daily/connectors/jira/test_jira_basic.py | 17 +- backend/tests/daily/connectors/utils.py | 65 +++ .../jira/test_jira_checkpointing.py | 440 ++++++++++++++++++ .../jira/test_large_ticket_handling.py | 51 +- backend/tests/unit/onyx/connectors/utils.py | 48 ++ 6 files changed, 750 insertions(+), 165 deletions(-) create mode 100644 backend/tests/daily/connectors/utils.py create mode 100644 backend/tests/unit/onyx/connectors/jira/test_jira_checkpointing.py create mode 100644 backend/tests/unit/onyx/connectors/utils.py diff --git a/backend/onyx/connectors/onyx_jira/connector.py b/backend/onyx/connectors/onyx_jira/connector.py index 30caf3ea5..c097e64ee 100644 --- a/backend/onyx/connectors/onyx_jira/connector.py +++ b/backend/onyx/connectors/onyx_jira/connector.py @@ -1,4 +1,5 @@ import os +from collections.abc import Generator from collections.abc import Iterable from datetime import datetime from datetime import timezone @@ -15,14 +16,15 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError -from onyx.connectors.interfaces import GenerateDocumentsOutput -from onyx.connectors.interfaces import GenerateSlimDocumentOutput -from onyx.connectors.interfaces import LoadConnector -from onyx.connectors.interfaces import PollConnector +from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document +from onyx.connectors.models import DocumentFailure from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info @@ -42,121 +44,108 @@ _JIRA_SLIM_PAGE_SIZE = 500 _JIRA_FULL_PAGE_SIZE = 50 -def _paginate_jql_search( +def _perform_jql_search( jira_client: JIRA, jql: str, + start: int, max_results: int, fields: str | None = None, ) -> Iterable[Issue]: - start = 0 - while True: - logger.debug( - f"Fetching Jira issues with JQL: {jql}, " - f"starting at {start}, max results: {max_results}" - ) - issues = jira_client.search_issues( - jql_str=jql, - startAt=start, - maxResults=max_results, - fields=fields, - ) + logger.debug( + f"Fetching Jira issues with JQL: {jql}, " + f"starting at {start}, max results: {max_results}" + ) + issues = jira_client.search_issues( + jql_str=jql, + startAt=start, + maxResults=max_results, + fields=fields, + ) - for issue in issues: - if isinstance(issue, Issue): - yield issue - else: - raise Exception(f"Found Jira object not of type Issue: {issue}") - - if len(issues) < max_results: - break - - start += max_results + for issue in issues: + if isinstance(issue, Issue): + yield issue + else: + raise RuntimeError(f"Found Jira object not of type Issue: {issue}") -def fetch_jira_issues_batch( +def process_jira_issue( jira_client: JIRA, - jql: str, - 
batch_size: int, + issue: Issue, comment_email_blacklist: tuple[str, ...] = (), labels_to_skip: set[str] | None = None, -) -> Iterable[Document]: - for issue in _paginate_jql_search( - jira_client=jira_client, - jql=jql, - max_results=batch_size, - ): - if labels_to_skip: - if any(label in issue.fields.labels for label in labels_to_skip): - logger.info( - f"Skipping {issue.key} because it has a label to skip. Found " - f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}." - ) - continue - - description = ( - issue.fields.description - if JIRA_API_VERSION == "2" - else extract_text_from_adf(issue.raw["fields"]["description"]) - ) - comments = get_comment_strs( - issue=issue, - comment_email_blacklist=comment_email_blacklist, - ) - ticket_content = f"{description}\n" + "\n".join( - [f"Comment: {comment}" for comment in comments if comment] - ) - - # Check ticket size - if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE: +) -> Document | None: + if labels_to_skip: + if any(label in issue.fields.labels for label in labels_to_skip): logger.info( - f"Skipping {issue.key} because it exceeds the maximum size of " - f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes." + f"Skipping {issue.key} because it has a label to skip. Found " + f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}." ) - continue + return None - page_url = f"{jira_client.client_info()}/browse/{issue.key}" + description = ( + issue.fields.description + if JIRA_API_VERSION == "2" + else extract_text_from_adf(issue.raw["fields"]["description"]) + ) + comments = get_comment_strs( + issue=issue, + comment_email_blacklist=comment_email_blacklist, + ) + ticket_content = f"{description}\n" + "\n".join( + [f"Comment: {comment}" for comment in comments if comment] + ) - people = set() - try: - creator = best_effort_get_field_from_issue(issue, "creator") - if basic_expert_info := best_effort_basic_expert_info(creator): - people.add(basic_expert_info) - except Exception: - # Author should exist but if not, doesn't matter - pass - - try: - assignee = best_effort_get_field_from_issue(issue, "assignee") - if basic_expert_info := best_effort_basic_expert_info(assignee): - people.add(basic_expert_info) - except Exception: - # Author should exist but if not, doesn't matter - pass - - metadata_dict = {} - if priority := best_effort_get_field_from_issue(issue, "priority"): - metadata_dict["priority"] = priority.name - if status := best_effort_get_field_from_issue(issue, "status"): - metadata_dict["status"] = status.name - if resolution := best_effort_get_field_from_issue(issue, "resolution"): - metadata_dict["resolution"] = resolution.name - if labels := best_effort_get_field_from_issue(issue, "labels"): - metadata_dict["label"] = labels - - yield Document( - id=page_url, - sections=[TextSection(link=page_url, text=ticket_content)], - source=DocumentSource.JIRA, - semantic_identifier=f"{issue.key}: {issue.fields.summary}", - title=f"{issue.key} {issue.fields.summary}", - doc_updated_at=time_str_to_utc(issue.fields.updated), - primary_owners=list(people) or None, - # TODO add secondary_owners (commenters) if needed - metadata=metadata_dict, + # Check ticket size + if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE: + logger.info( + f"Skipping {issue.key} because it exceeds the maximum size of " + f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes." 
) + return None + + page_url = build_jira_url(jira_client, issue.key) + + people = set() + try: + creator = best_effort_get_field_from_issue(issue, "creator") + if basic_expert_info := best_effort_basic_expert_info(creator): + people.add(basic_expert_info) + except Exception: + # Author should exist but if not, doesn't matter + pass + + try: + assignee = best_effort_get_field_from_issue(issue, "assignee") + if basic_expert_info := best_effort_basic_expert_info(assignee): + people.add(basic_expert_info) + except Exception: + # Author should exist but if not, doesn't matter + pass + + metadata_dict = {} + if priority := best_effort_get_field_from_issue(issue, "priority"): + metadata_dict["priority"] = priority.name + if status := best_effort_get_field_from_issue(issue, "status"): + metadata_dict["status"] = status.name + if resolution := best_effort_get_field_from_issue(issue, "resolution"): + metadata_dict["resolution"] = resolution.name + if labels := best_effort_get_field_from_issue(issue, "labels"): + metadata_dict["label"] = labels + + return Document( + id=page_url, + sections=[TextSection(link=page_url, text=ticket_content)], + source=DocumentSource.JIRA, + semantic_identifier=f"{issue.key}: {issue.fields.summary}", + title=f"{issue.key} {issue.fields.summary}", + doc_updated_at=time_str_to_utc(issue.fields.updated), + primary_owners=list(people) or None, + metadata=metadata_dict, + ) -class JiraConnector(LoadConnector, PollConnector, SlimConnector): +class JiraConnector(CheckpointConnector, SlimConnector): def __init__( self, jira_base_url: str, @@ -200,33 +189,10 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector): ) return None - def _get_jql_query(self) -> str: - """Get the JQL query based on whether a specific project is set""" - if self.jira_project: - return f"project = {self.quoted_jira_project}" - return "" # Empty string means all accessible projects - - def load_from_state(self) -> GenerateDocumentsOutput: - jql = self._get_jql_query() - - document_batch = [] - for doc in fetch_jira_issues_batch( - jira_client=self.jira_client, - jql=jql, - batch_size=_JIRA_FULL_PAGE_SIZE, - comment_email_blacklist=self.comment_email_blacklist, - labels_to_skip=self.labels_to_skip, - ): - document_batch.append(doc) - if len(document_batch) >= self.batch_size: - yield document_batch - document_batch = [] - - yield document_batch - - def poll_source( + def _get_jql_query( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> GenerateDocumentsOutput: + ) -> str: + """Get the JQL query based on whether a specific project is set and time range""" start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime( "%Y-%m-%d %H:%M" ) @@ -234,38 +200,74 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector): "%Y-%m-%d %H:%M" ) - base_jql = self._get_jql_query() - jql = ( - f"{base_jql} AND " if base_jql else "" - ) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'" + base_jql = f"project = {self.quoted_jira_project}" if self.jira_project else "" + time_jql = f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'" - document_batch = [] - for doc in fetch_jira_issues_batch( + return f"{base_jql} AND {time_jql}" if base_jql else time_jql + + def load_from_checkpoint( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: ConnectorCheckpoint, + ) -> CheckpointOutput: + jql = self._get_jql_query(start, end) + + # Get the current offset from checkpoint or start at 0 + starting_offset = 
checkpoint.checkpoint_content.get("offset", 0) + current_offset = starting_offset + + for issue in _perform_jql_search( jira_client=self.jira_client, jql=jql, - batch_size=_JIRA_FULL_PAGE_SIZE, - comment_email_blacklist=self.comment_email_blacklist, - labels_to_skip=self.labels_to_skip, + start=current_offset, + max_results=_JIRA_FULL_PAGE_SIZE, ): - document_batch.append(doc) - if len(document_batch) >= self.batch_size: - yield document_batch - document_batch = [] + issue_key = issue.key + try: + if document := process_jira_issue( + jira_client=self.jira_client, + issue=issue, + comment_email_blacklist=self.comment_email_blacklist, + labels_to_skip=self.labels_to_skip, + ): + yield document - yield document_batch + except Exception as e: + yield ConnectorFailure( + failed_document=DocumentFailure( + document_id=issue_key, + document_link=build_jira_url(self.jira_client, issue_key), + ), + failure_message=f"Failed to process Jira issue: {str(e)}", + exception=e, + ) + + current_offset += 1 + + # Update checkpoint + checkpoint = ConnectorCheckpoint( + checkpoint_content={ + "offset": current_offset, + }, + # if we didn't retrieve a full batch, we're done + has_more=current_offset - starting_offset == _JIRA_FULL_PAGE_SIZE, + ) + return checkpoint def retrieve_all_slim_documents( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, - ) -> GenerateSlimDocumentOutput: - jql = self._get_jql_query() + ) -> Generator[list[SlimDocument], None, None]: + jql = self._get_jql_query(start or 0, end or float("inf")) slim_doc_batch = [] - for issue in _paginate_jql_search( + for issue in _perform_jql_search( jira_client=self.jira_client, jql=jql, + start=0, max_results=_JIRA_SLIM_PAGE_SIZE, fields="key", ): @@ -350,5 +352,7 @@ if __name__ == "__main__": "jira_api_token": os.environ["JIRA_API_TOKEN"], } ) - document_batches = connector.load_from_state() + document_batches = connector.load_from_checkpoint( + 0, float("inf"), ConnectorCheckpoint(checkpoint_content={}, has_more=True) + ) print(next(document_batches)) diff --git a/backend/tests/daily/connectors/jira/test_jira_basic.py b/backend/tests/daily/connectors/jira/test_jira_basic.py index cf7d14fbd..885d4f2ca 100644 --- a/backend/tests/daily/connectors/jira/test_jira_basic.py +++ b/backend/tests/daily/connectors/jira/test_jira_basic.py @@ -5,6 +5,7 @@ import pytest from onyx.configs.constants import DocumentSource from onyx.connectors.onyx_jira.connector import JiraConnector +from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector @pytest.fixture @@ -24,15 +25,13 @@ def jira_connector() -> JiraConnector: def test_jira_connector_basic(jira_connector: JiraConnector) -> None: - doc_batch_generator = jira_connector.poll_source(0, time.time()) - - doc_batch = next(doc_batch_generator) - with pytest.raises(StopIteration): - next(doc_batch_generator) - - assert len(doc_batch) == 1 - - doc = doc_batch[0] + docs = load_all_docs_from_checkpoint_connector( + connector=jira_connector, + start=0, + end=time.time(), + ) + assert len(docs) == 1 + doc = docs[0] assert doc.id == "https://danswerai.atlassian.net/browse/AS-2" assert doc.semantic_identifier == "AS-2: test123small" diff --git a/backend/tests/daily/connectors/utils.py b/backend/tests/daily/connectors/utils.py new file mode 100644 index 000000000..1de2c0fad --- /dev/null +++ b/backend/tests/daily/connectors/utils.py @@ -0,0 +1,65 @@ +from onyx.connectors.connector_runner import 
CheckpointOutputWrapper +from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import SecondsSinceUnixEpoch +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import Document + +_ITERATION_LIMIT = 100_000 + + +def load_all_docs_from_checkpoint_connector( + connector: CheckpointConnector, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, +) -> list[Document]: + num_iterations = 0 + + checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() + documents: list[Document] = [] + while checkpoint.has_more: + doc_batch_generator = CheckpointOutputWrapper()( + connector.load_from_checkpoint(start, end, checkpoint) + ) + for document, failure, next_checkpoint in doc_batch_generator: + if failure is not None: + raise RuntimeError(f"Failed to load documents: {failure}") + if document is not None: + documents.append(document) + if next_checkpoint is not None: + checkpoint = next_checkpoint + + num_iterations += 1 + if num_iterations > _ITERATION_LIMIT: + raise RuntimeError("Too many iterations. Infinite loop?") + + return documents + + +def load_everything_from_checkpoint_connector( + connector: CheckpointConnector, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, +) -> list[Document | ConnectorFailure]: + """Like load_all_docs_from_checkpoint_connector but returns both documents and failures""" + num_iterations = 0 + + checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() + outputs: list[Document | ConnectorFailure] = [] + while checkpoint.has_more: + doc_batch_generator = CheckpointOutputWrapper()( + connector.load_from_checkpoint(start, end, checkpoint) + ) + for document, failure, next_checkpoint in doc_batch_generator: + if failure is not None: + outputs.append(failure) + if document is not None: + outputs.append(document) + if next_checkpoint is not None: + checkpoint = next_checkpoint + + num_iterations += 1 + if num_iterations > _ITERATION_LIMIT: + raise RuntimeError("Too many iterations. 
Infinite loop?") + + return outputs diff --git a/backend/tests/unit/onyx/connectors/jira/test_jira_checkpointing.py b/backend/tests/unit/onyx/connectors/jira/test_jira_checkpointing.py new file mode 100644 index 000000000..6211b7909 --- /dev/null +++ b/backend/tests/unit/onyx/connectors/jira/test_jira_checkpointing.py @@ -0,0 +1,440 @@ +import time +from collections.abc import Callable +from collections.abc import Generator +from datetime import datetime +from datetime import timezone +from typing import Any +from typing import cast +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from jira import JIRA +from jira import JIRAError +from jira.resources import Issue + +from onyx.configs.constants import DocumentSource +from onyx.connectors.exceptions import ConnectorValidationError +from onyx.connectors.exceptions import CredentialExpiredError +from onyx.connectors.exceptions import InsufficientPermissionsError +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import Document +from onyx.connectors.models import SlimDocument +from onyx.connectors.onyx_jira.connector import JiraConnector +from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector + +PAGE_SIZE = 2 + + +@pytest.fixture +def jira_base_url() -> str: + return "https://jira.example.com" + + +@pytest.fixture +def project_key() -> str: + return "TEST" + + +@pytest.fixture +def mock_jira_client() -> MagicMock: + """Create a mock JIRA client with proper typing""" + mock = MagicMock(spec=JIRA) + # Add proper return typing for search_issues method + mock.search_issues = MagicMock() + # Add proper return typing for project method + mock.project = MagicMock() + # Add proper return typing for projects method + mock.projects = MagicMock() + return mock + + +@pytest.fixture +def jira_connector( + jira_base_url: str, project_key: str, mock_jira_client: MagicMock +) -> Generator[JiraConnector, None, None]: + connector = JiraConnector( + jira_base_url=jira_base_url, + project_key=project_key, + comment_email_blacklist=["blacklist@example.com"], + labels_to_skip=["secret", "sensitive"], + ) + connector._jira_client = mock_jira_client + connector._jira_client.client_info.return_value = jira_base_url + with patch("onyx.connectors.onyx_jira.connector._JIRA_FULL_PAGE_SIZE", 2): + yield connector + + +@pytest.fixture +def create_mock_issue() -> Callable[..., MagicMock]: + def _create_mock_issue( + key: str = "TEST-123", + summary: str = "Test Issue", + updated: str = "2023-01-01T12:00:00.000+0000", + description: str = "Test Description", + labels: list[str] | None = None, + ) -> MagicMock: + """Helper to create a mock Issue object""" + mock_issue = MagicMock(spec=Issue) + # Create fields attribute first + mock_issue.fields = MagicMock() + mock_issue.key = key + mock_issue.fields.summary = summary + mock_issue.fields.updated = updated + mock_issue.fields.description = description + mock_issue.fields.labels = labels or [] + + # Set up creator and assignee for testing owner extraction + mock_issue.fields.creator = MagicMock() + mock_issue.fields.creator.displayName = "Test Creator" + mock_issue.fields.creator.emailAddress = "creator@example.com" + + mock_issue.fields.assignee = MagicMock() + mock_issue.fields.assignee.displayName = "Test Assignee" + mock_issue.fields.assignee.emailAddress = "assignee@example.com" + + # Set up priority, status, and resolution + mock_issue.fields.priority = MagicMock() + 
mock_issue.fields.priority.name = "High" + + mock_issue.fields.status = MagicMock() + mock_issue.fields.status.name = "In Progress" + + mock_issue.fields.resolution = MagicMock() + mock_issue.fields.resolution.name = "Fixed" + + # Add raw field for accessing through API version check + mock_issue.raw = {"fields": {"description": description}} + + return mock_issue + + return _create_mock_issue + + +def test_load_credentials(jira_connector: JiraConnector) -> None: + """Test loading credentials""" + with patch( + "onyx.connectors.onyx_jira.connector.build_jira_client" + ) as mock_build_client: + mock_build_client.return_value = jira_connector._jira_client + credentials = { + "jira_user_email": "user@example.com", + "jira_api_token": "token123", + } + + result = jira_connector.load_credentials(credentials) + + mock_build_client.assert_called_once_with( + credentials=credentials, jira_base=jira_connector.jira_base + ) + assert result is None + assert jira_connector._jira_client == mock_build_client.return_value + + +def test_get_jql_query_with_project(jira_connector: JiraConnector) -> None: + """Test JQL query generation with project specified""" + start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp() + end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp() + + query = jira_connector._get_jql_query(start, end) + + # Check that the project part and time part are both in the query + assert f'project = "{jira_connector.jira_project}"' in query + assert "updated >= '2023-01-01 00:00'" in query + assert "updated <= '2023-01-02 00:00'" in query + assert " AND " in query + + +def test_get_jql_query_without_project(jira_base_url: str) -> None: + """Test JQL query generation without project specified""" + # Create connector without project key + connector = JiraConnector(jira_base_url=jira_base_url) + + start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp() + end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp() + + query = connector._get_jql_query(start, end) + + # Check that only time part is in the query + assert "project =" not in query + assert "updated >= '2023-01-01 00:00'" in query + assert "updated <= '2023-01-02 00:00'" in query + + +def test_load_from_checkpoint_happy_path( + jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock] +) -> None: + """Test loading from checkpoint - happy path""" + # Set up mocked issues + mock_issue1 = create_mock_issue(key="TEST-1", summary="Issue 1") + mock_issue2 = create_mock_issue(key="TEST-2", summary="Issue 2") + mock_issue3 = create_mock_issue(key="TEST-3", summary="Issue 3") + + # Only mock the search_issues method + jira_client = cast(JIRA, jira_connector._jira_client) + search_issues_mock = cast(MagicMock, jira_client.search_issues) + search_issues_mock.side_effect = [ + [mock_issue1, mock_issue2], + [mock_issue3], + [], + ] + + # Call load_from_checkpoint + end_time = time.time() + outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time) + + # Check that the documents were returned + assert len(outputs) == 2 + + checkpoint_output1 = outputs[0] + assert len(checkpoint_output1.items) == 2 + document1 = checkpoint_output1.items[0] + assert isinstance(document1, Document) + assert document1.id == "https://jira.example.com/browse/TEST-1" + document2 = checkpoint_output1.items[1] + assert isinstance(document2, Document) + assert document2.id == "https://jira.example.com/browse/TEST-2" + assert checkpoint_output1.next_checkpoint == ConnectorCheckpoint( + checkpoint_content={ + "offset": 2, 
+        },
+        has_more=True,
+    )
+
+    checkpoint_output2 = outputs[1]
+    assert len(checkpoint_output2.items) == 1
+    document3 = checkpoint_output2.items[0]
+    assert isinstance(document3, Document)
+    assert document3.id == "https://jira.example.com/browse/TEST-3"
+    assert checkpoint_output2.next_checkpoint == ConnectorCheckpoint(
+        checkpoint_content={
+            "offset": 3,
+        },
+        has_more=False,
+    )
+
+    # Check that search_issues was called with the right parameters
+    assert search_issues_mock.call_count == 2
+    args, kwargs = search_issues_mock.call_args_list[0]
+    assert kwargs["startAt"] == 0
+    assert kwargs["maxResults"] == PAGE_SIZE
+
+    args, kwargs = search_issues_mock.call_args_list[1]
+    assert kwargs["startAt"] == 2
+    assert kwargs["maxResults"] == PAGE_SIZE
+
+
+def test_load_from_checkpoint_with_issue_processing_error(
+    jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
+) -> None:
+    """Test loading from checkpoint with a mix of successful and failed issue processing across multiple batches"""
+    # Set up mocked issues for first batch
+    mock_issue1 = create_mock_issue(key="TEST-1", summary="Issue 1")
+    mock_issue2 = create_mock_issue(key="TEST-2", summary="Issue 2")
+    # Set up mocked issues for second batch
+    mock_issue3 = create_mock_issue(key="TEST-3", summary="Issue 3")
+    mock_issue4 = create_mock_issue(key="TEST-4", summary="Issue 4")
+
+    # Mock search_issues to return our mock issues in batches
+    jira_client = cast(JIRA, jira_connector._jira_client)
+    search_issues_mock = cast(MagicMock, jira_client.search_issues)
+    search_issues_mock.side_effect = [
+        [mock_issue1, mock_issue2],  # First batch
+        [mock_issue3, mock_issue4],  # Second batch
+        [],  # Empty batch to indicate end
+    ]
+
+    # Mock process_jira_issue to succeed for some issues and fail for others
+    def mock_process_side_effect(
+        jira_client: JIRA, issue: Issue, *args: Any, **kwargs: Any
+    ) -> Document | None:
+        if issue.key in ["TEST-1", "TEST-3"]:
+            return Document(
+                id=f"https://jira.example.com/browse/{issue.key}",
+                sections=[],
+                source=DocumentSource.JIRA,
+                semantic_identifier=f"{issue.key}: {issue.fields.summary}",
+                title=f"{issue.key} {issue.fields.summary}",
+                metadata={},
+            )
+        else:
+            raise Exception(f"Processing error for {issue.key}")
+
+    with patch(
+        "onyx.connectors.onyx_jira.connector.process_jira_issue"
+    ) as mock_process:
+        mock_process.side_effect = mock_process_side_effect
+
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)
+
+        assert len(outputs) == 3
+
+        # Check first batch
+        first_batch = outputs[0]
+        assert len(first_batch.items) == 2
+        # First item should be successful
+        assert isinstance(first_batch.items[0], Document)
+        assert first_batch.items[0].id == "https://jira.example.com/browse/TEST-1"
+        # Second item should be a failure
+        assert isinstance(first_batch.items[1], ConnectorFailure)
+        assert first_batch.items[1].failed_document is not None
+        assert first_batch.items[1].failed_document.document_id == "TEST-2"
+        assert "Failed to process Jira issue" in first_batch.items[1].failure_message
+        # Check checkpoint indicates more items (full batch)
+        assert first_batch.next_checkpoint.has_more is True
+        assert first_batch.next_checkpoint.checkpoint_content == {"offset": 2}
+
+        # Check second batch
+        second_batch = outputs[1]
+        assert len(second_batch.items) == 2
+        # First item should be successful
+        assert isinstance(second_batch.items[0], Document)
+        assert second_batch.items[0].id == "https://jira.example.com/browse/TEST-3"
+        # Second item should be a failure
+        assert isinstance(second_batch.items[1], ConnectorFailure)
+        assert second_batch.items[1].failed_document is not None
+        assert second_batch.items[1].failed_document.document_id == "TEST-4"
+        assert "Failed to process Jira issue" in second_batch.items[1].failure_message
+        # Check checkpoint indicates more items
+        assert second_batch.next_checkpoint.has_more is True
+        assert second_batch.next_checkpoint.checkpoint_content == {"offset": 4}
+
+        # Check third, empty batch
+        third_batch = outputs[2]
+        assert len(third_batch.items) == 0
+        assert third_batch.next_checkpoint.has_more is False
+        assert third_batch.next_checkpoint.checkpoint_content == {"offset": 4}
+
+
+def test_load_from_checkpoint_with_skipped_issue(
+    jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
+) -> None:
+    """Test loading from checkpoint with an issue that should be skipped due to labels"""
+    LABEL_TO_SKIP = "secret"
+    jira_connector.labels_to_skip = {LABEL_TO_SKIP}
+
+    # Set up mocked issue with a label to skip
+    mock_issue = create_mock_issue(
+        key="TEST-1", summary="Issue 1", labels=[LABEL_TO_SKIP]
+    )
+
+    # Mock search_issues to return our mock issue
+    jira_client = cast(JIRA, jira_connector._jira_client)
+    search_issues_mock = cast(MagicMock, jira_client.search_issues)
+    search_issues_mock.return_value = [mock_issue]
+
+    # Call load_from_checkpoint
+    end_time = time.time()
+    outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)
+
+    assert len(outputs) == 1
+    checkpoint_output = outputs[0]
+    # Check that no documents were returned
+    assert len(checkpoint_output.items) == 0
+
+
+def test_retrieve_all_slim_documents(
+    jira_connector: JiraConnector, create_mock_issue: Any
+) -> None:
+    """Test retrieving all slim documents"""
+    # Set up mocked issues
+    mock_issue1 = create_mock_issue(key="TEST-1")
+    mock_issue2 = create_mock_issue(key="TEST-2")
+
+    # Mock search_issues to return our mock issues
+    jira_client = cast(JIRA, jira_connector._jira_client)
+    search_issues_mock = cast(MagicMock, jira_client.search_issues)
+    search_issues_mock.return_value = [mock_issue1, mock_issue2]
+
+    # Mock best_effort_get_field_from_issue to return the keys
+    with patch(
+        "onyx.connectors.onyx_jira.connector.best_effort_get_field_from_issue"
+    ) as mock_field:
+        mock_field.side_effect = ["TEST-1", "TEST-2"]
+
+        # Mock build_jira_url to return URLs
+        with patch("onyx.connectors.onyx_jira.connector.build_jira_url") as mock_url:
+            mock_url.side_effect = [
+                "https://jira.example.com/browse/TEST-1",
+                "https://jira.example.com/browse/TEST-2",
+            ]
+
+            # Call retrieve_all_slim_documents
+            batches = list(jira_connector.retrieve_all_slim_documents(0, 100))
+
+            # Check that a batch with 2 documents was returned
+            assert len(batches) == 1
+            assert len(batches[0]) == 2
+            assert isinstance(batches[0][0], SlimDocument)
+            assert batches[0][0].id == "https://jira.example.com/browse/TEST-1"
+            assert batches[0][1].id == "https://jira.example.com/browse/TEST-2"
+
+            # Check that search_issues was called with the right parameters
+            search_issues_mock.assert_called_once()
+            args, kwargs = search_issues_mock.call_args
+            assert kwargs["fields"] == "key"
+
+
+@pytest.mark.parametrize(
+    "status_code,expected_exception,expected_message",
+    [
+        (
+            401,
+            CredentialExpiredError,
+            "Jira credential appears to be expired or invalid",
+        ),
+        (
+            403,
+            InsufficientPermissionsError,
+            "Your Jira token does not have sufficient 
permissions", + ), + (404, ConnectorValidationError, "Jira project not found"), + ( + 429, + ConnectorValidationError, + "Validation failed due to Jira rate-limits being exceeded", + ), + ], +) +def test_validate_connector_settings_errors( + jira_connector: JiraConnector, + status_code: int, + expected_exception: type[Exception], + expected_message: str, +) -> None: + """Test validation with various error scenarios""" + error = JIRAError(status_code=status_code) + + jira_client = cast(JIRA, jira_connector._jira_client) + project_mock = cast(MagicMock, jira_client.project) + project_mock.side_effect = error + + with pytest.raises(expected_exception) as excinfo: + jira_connector.validate_connector_settings() + assert expected_message in str(excinfo.value) + + +def test_validate_connector_settings_with_project_success( + jira_connector: JiraConnector, +) -> None: + """Test successful validation with project specified""" + jira_client = cast(JIRA, jira_connector._jira_client) + project_mock = cast(MagicMock, jira_client.project) + project_mock.return_value = MagicMock() + jira_connector.validate_connector_settings() + project_mock.assert_called_once_with(jira_connector.jira_project) + + +def test_validate_connector_settings_without_project_success( + jira_base_url: str, +) -> None: + """Test successful validation without project specified""" + connector = JiraConnector(jira_base_url=jira_base_url) + connector._jira_client = MagicMock() + connector._jira_client.projects.return_value = [MagicMock()] + + connector.validate_connector_settings() + connector._jira_client.projects.assert_called_once() diff --git a/backend/tests/unit/onyx/connectors/jira/test_large_ticket_handling.py b/backend/tests/unit/onyx/connectors/jira/test_large_ticket_handling.py index c8ae925ce..badb72aef 100644 --- a/backend/tests/unit/onyx/connectors/jira/test_large_ticket_handling.py +++ b/backend/tests/unit/onyx/connectors/jira/test_large_ticket_handling.py @@ -7,7 +7,8 @@ import pytest from jira.resources import Issue from pytest_mock import MockFixture -from onyx.connectors.onyx_jira.connector import fetch_jira_issues_batch +from onyx.connectors.onyx_jira.connector import _perform_jql_search +from onyx.connectors.onyx_jira.connector import process_jira_issue @pytest.fixture @@ -79,14 +80,22 @@ def test_fetch_jira_issues_batch_small_ticket( ) -> None: mock_jira_client.search_issues.return_value = [mock_issue_small] - docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50)) + # First get the issues via pagination + issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50)) + assert len(issues) == 1 + + # Then process each issue + docs = [process_jira_issue(mock_jira_client, issue) for issue in issues] + docs = [doc for doc in docs if doc is not None] # Filter out None values assert len(docs) == 1 - assert docs[0].id.endswith("/SMALL-1") - assert docs[0].sections[0].text is not None - assert "Small description" in docs[0].sections[0].text - assert "Small comment 1" in docs[0].sections[0].text - assert "Small comment 2" in docs[0].sections[0].text + doc = docs[0] + assert doc is not None # Type assertion for mypy + assert doc.id.endswith("/SMALL-1") + assert doc.sections[0].text is not None + assert "Small description" in doc.sections[0].text + assert "Small comment 1" in doc.sections[0].text + assert "Small comment 2" in doc.sections[0].text def test_fetch_jira_issues_batch_large_ticket( @@ -96,7 +105,13 @@ def test_fetch_jira_issues_batch_large_ticket( ) -> None: 
mock_jira_client.search_issues.return_value = [mock_issue_large] - docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50)) + # First get the issues via pagination + issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50)) + assert len(issues) == 1 + + # Then process each issue + docs = [process_jira_issue(mock_jira_client, issue) for issue in issues] + docs = [doc for doc in docs if doc is not None] # Filter out None values assert len(docs) == 0 # The large ticket should be skipped @@ -109,10 +124,18 @@ def test_fetch_jira_issues_batch_mixed_tickets( ) -> None: mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large] - docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50)) + # First get the issues via pagination + issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50)) + assert len(issues) == 2 + + # Then process each issue + docs = [process_jira_issue(mock_jira_client, issue) for issue in issues] + docs = [doc for doc in docs if doc is not None] # Filter out None values assert len(docs) == 1 # Only the small ticket should be included - assert docs[0].id.endswith("/SMALL-1") + doc = docs[0] + assert doc is not None # Type assertion for mypy + assert doc.id.endswith("/SMALL-1") @patch("onyx.connectors.onyx_jira.connector.JIRA_CONNECTOR_MAX_TICKET_SIZE", 50) @@ -124,6 +147,12 @@ def test_fetch_jira_issues_batch_custom_size_limit( ) -> None: mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large] - docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50)) + # First get the issues via pagination + issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50)) + assert len(issues) == 2 + + # Then process each issue + docs = [process_jira_issue(mock_jira_client, issue) for issue in issues] + docs = [doc for doc in docs if doc is not None] # Filter out None values assert len(docs) == 0 # Both tickets should be skipped due to the low size limit diff --git a/backend/tests/unit/onyx/connectors/utils.py b/backend/tests/unit/onyx/connectors/utils.py new file mode 100644 index 000000000..a4d59e0e8 --- /dev/null +++ b/backend/tests/unit/onyx/connectors/utils.py @@ -0,0 +1,48 @@ +from pydantic import BaseModel + +from onyx.connectors.connector_runner import CheckpointOutputWrapper +from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import SecondsSinceUnixEpoch +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import Document + +_ITERATION_LIMIT = 100_000 + + +class SingleConnectorCallOutput(BaseModel): + items: list[Document | ConnectorFailure] + next_checkpoint: ConnectorCheckpoint + + +def load_everything_from_checkpoint_connector( + connector: CheckpointConnector, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, +) -> list[SingleConnectorCallOutput]: + num_iterations = 0 + + checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() + outputs: list[SingleConnectorCallOutput] = [] + while checkpoint.has_more: + items: list[Document | ConnectorFailure] = [] + doc_batch_generator = CheckpointOutputWrapper()( + connector.load_from_checkpoint(start, end, checkpoint) + ) + for document, failure, next_checkpoint in doc_batch_generator: + if failure is not None: + items.append(failure) + if document is not None: + items.append(document) + if next_checkpoint is not None: + checkpoint = next_checkpoint 
+ + outputs.append( + SingleConnectorCallOutput(items=items, next_checkpoint=checkpoint) + ) + + num_iterations += 1 + if num_iterations > _ITERATION_LIMIT: + raise RuntimeError("Too many iterations. Infinite loop?") + + return outputs
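
---

Notes (illustrative sketches, not part of the diff):

The checkpointing contract introduced above is small: load_from_checkpoint
yields documents (or failures) and then *returns* a ConnectorCheckpoint whose
has_more flag tells the runner whether to call again with the saved offset. A
self-contained sketch of the same offset scheme, with a stubbed search standing
in for jira_client.search_issues (SimpleCheckpoint, fake_search, and PAGE_SIZE
are made up for the sketch, not Onyx APIs):

    from collections.abc import Generator
    from dataclasses import dataclass, field
    from typing import Any

    PAGE_SIZE = 2
    ISSUES = ["TEST-1", "TEST-2", "TEST-3"]

    @dataclass
    class SimpleCheckpoint:
        checkpoint_content: dict[str, Any] = field(default_factory=dict)
        has_more: bool = True

    def fake_search(start: int, max_results: int) -> list[str]:
        # Stands in for jira_client.search_issues(startAt=..., maxResults=...)
        return ISSUES[start : start + max_results]

    def load_from_checkpoint(
        checkpoint: SimpleCheckpoint,
    ) -> Generator[str, None, SimpleCheckpoint]:
        starting_offset = checkpoint.checkpoint_content.get("offset", 0)
        current_offset = starting_offset
        for issue in fake_search(start=current_offset, max_results=PAGE_SIZE):
            yield issue
            current_offset += 1
        # A partial page means the search is exhausted.
        return SimpleCheckpoint(
            checkpoint_content={"offset": current_offset},
            has_more=current_offset - starting_offset == PAGE_SIZE,
        )

    checkpoint = SimpleCheckpoint()
    while checkpoint.has_more:
        gen = load_from_checkpoint(checkpoint)
        while True:
            try:
                print(next(gen))          # TEST-1, TEST-2, then TEST-3
            except StopIteration as stop:
                checkpoint = stop.value   # the generator's return value
                break

One consequence of the "full page means has_more" heuristic: when the last
page is exactly full, has_more stays True and one extra, empty call is needed
to terminate. The error-handling test above exercises this, which is why its
search_issues side_effect list ends with an empty third batch.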
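The test helpers drain load_from_checkpoint through CheckpointOutputWrapper,
which flattens "yields items, returns a checkpoint" into uniform
(document, failure, next_checkpoint) triples. The real class lives in
onyx.connectors.connector_runner; the stand-in below only illustrates the
shape the helpers rely on, not the actual implementation:

    from collections.abc import Generator, Iterator
    from typing import Any

    def wrap_checkpoint_output(
        gen: Generator[Any, None, Any],
    ) -> Iterator[tuple[Any | None, Any | None, Any | None]]:
        # Hypothetical reimplementation of the CheckpointOutputWrapper pattern.
        while True:
            try:
                item = next(gen)
            except StopIteration as stop:
                # The generator's return value is the next checkpoint.
                yield None, None, stop.value
                return
            if isinstance(item, Exception):  # stand-in for ConnectorFailure
                yield None, item, None
            else:
                yield item, None, None

In this sketch exactly one triple carries a non-None checkpoint, and it is
always the final one, which is why the helpers can simply overwrite their
local checkpoint whenever next_checkpoint is not None.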
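_get_jql_query now folds the time window into every query, rendering epoch
seconds as Jira-style 'YYYY-MM-DD HH:MM' strings in UTC. The same construction
extracted for illustration (build_jql is a made-up name; the project quoting
mirrors what the unit tests assert):

    from datetime import datetime, timezone

    def build_jql(project: str | None, start: float, end: float) -> str:
        fmt = "%Y-%m-%d %H:%M"
        start_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(fmt)
        end_str = datetime.fromtimestamp(end, tz=timezone.utc).strftime(fmt)
        time_jql = f"updated >= '{start_str}' AND updated <= '{end_str}'"
        return f'project = "{project}" AND {time_jql}' if project else time_jql

    assert build_jql("TEST", 1672531200, 1672617600) == (
        "project = \"TEST\" AND updated >= '2023-01-01 00:00'"
        " AND updated <= '2023-01-02 00:00'"
    )

Caveat: datetime.fromtimestamp raises on float("inf"), so call sites that
default a missing end bound to infinity (retrieve_all_slim_documents uses
`end or float("inf")`, and the __main__ block passes float("inf") directly)
need a finite fallback such as the current time.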
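Per-issue failure isolation is the other behavioral change: one bad issue
becomes a ConnectorFailure in the output stream instead of aborting the whole
run, and the offset still advances past it so a resumed run does not retry it
forever. The pattern in miniature (process and the dict shapes are
illustrative only):

    from collections.abc import Iterator

    def process(issue_key: str) -> dict[str, str]:
        if issue_key == "TEST-2":
            raise ValueError(f"Processing error for {issue_key}")
        return {"id": f"https://jira.example.com/browse/{issue_key}"}

    def run(issue_keys: list[str]) -> Iterator[dict[str, str]]:
        for key in issue_keys:
            try:
                yield process(key)
            except Exception as e:
                # Mirrors the ConnectorFailure branch in load_from_checkpoint.
                yield {
                    "failed_document_id": key,
                    "failure_message": f"Failed to process Jira issue: {e}",
                }

    # One document, one failure record, one document:
    print(list(run(["TEST-1", "TEST-2", "TEST-3"])))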
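Finally, the size guard that test_large_ticket_handling.py covers measures the
rendered ticket in UTF-8 bytes, not characters, so multi-byte text trips the
limit earlier than len() alone would suggest. In isolation (the 100 KiB figure
is an assumption for the sketch, not necessarily the configured default):

    MAX_TICKET_SIZE_BYTES = 100 * 1024  # assumed limit for the sketch

    def exceeds_size_limit(ticket_content: str) -> bool:
        # UTF-8 byte length, matching the check in process_jira_issue.
        return len(ticket_content.encode("utf-8")) > MAX_TICKET_SIZE_BYTES

    assert not exceeds_size_limit("x" * MAX_TICKET_SIZE_BYTES)
    assert exceeds_size_limit("é" * MAX_TICKET_SIZE_BYTES)  # 2 bytes per char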