Checkpointed Jira connector

Weves 2025-03-15 13:54:49 -07:00
parent 64ff5df083
commit 4c6f4e7158
6 changed files with 750 additions and 165 deletions
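The diff below moves the Jira connector from the batch-generator LoadConnector/PollConnector interfaces to CheckpointConnector: each load_from_checkpoint call indexes one page of issues, yields Document or ConnectorFailure objects, and returns a ConnectorCheckpoint whose checkpoint_content records the JQL offset, so an interrupted run can resume where it stopped. A minimal sketch of the driving loop, using only names that appear in this diff (connector, start, end, and handle are stand-ins; the production runner wiring is not part of this commit view):

checkpoint = ConnectorCheckpoint(checkpoint_content={}, has_more=True)  # dummy starting checkpoint
while checkpoint.has_more:
    gen = connector.load_from_checkpoint(start, end, checkpoint)
    while True:
        try:
            item = next(gen)  # a Document or a ConnectorFailure
        except StopIteration as stop:
            checkpoint = stop.value  # the generator returns the next checkpoint
            break
        handle(item)  # stand-in for whatever indexes the yielded item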

View File

@@ -1,4 +1,5 @@
import os
from collections.abc import Generator
from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
@@ -15,14 +16,15 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
@@ -42,121 +44,108 @@ _JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
def _paginate_jql_search(
def _perform_jql_search(
jira_client: JIRA,
jql: str,
start: int,
max_results: int,
fields: str | None = None,
) -> Iterable[Issue]:
start = 0
while True:
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
for issue in issues:
if isinstance(issue, Issue):
yield issue
else:
raise Exception(f"Found Jira object not of type Issue: {issue}")
if len(issues) < max_results:
break
start += max_results
for issue in issues:
if isinstance(issue, Issue):
yield issue
else:
raise RuntimeError(f"Found Jira object not of type Issue: {issue}")
def fetch_jira_issues_batch(
def process_jira_issue(
jira_client: JIRA,
jql: str,
batch_size: int,
issue: Issue,
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
) -> Iterable[Document]:
for issue in _paginate_jql_search(
jira_client=jira_client,
jql=jql,
max_results=batch_size,
):
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
f"Skipping {issue.key} because it has a label to skip. Found "
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
description = (
issue.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
) -> Document | None:
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
f"Skipping {issue.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
f"Skipping {issue.key} because it has a label to skip. Found "
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
return None
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
description = (
issue.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
people = set()
try:
creator = best_effort_get_field_from_issue(issue, "creator")
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
try:
assignee = best_effort_get_field_from_issue(issue, "assignee")
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
metadata_dict = {}
if priority := best_effort_get_field_from_issue(issue, "priority"):
metadata_dict["priority"] = priority.name
if status := best_effort_get_field_from_issue(issue, "status"):
metadata_dict["status"] = status.name
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
metadata_dict["resolution"] = resolution.name
if labels := best_effort_get_field_from_issue(issue, "labels"):
metadata_dict["label"] = labels
yield Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {issue.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
return None
page_url = build_jira_url(jira_client, issue.key)
people = set()
try:
creator = best_effort_get_field_from_issue(issue, "creator")
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
try:
assignee = best_effort_get_field_from_issue(issue, "assignee")
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
except Exception:
# Assignee may not be set; if missing, it doesn't matter
pass
metadata_dict = {}
if priority := best_effort_get_field_from_issue(issue, "priority"):
metadata_dict["priority"] = priority.name
if status := best_effort_get_field_from_issue(issue, "status"):
metadata_dict["status"] = status.name
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
metadata_dict["resolution"] = resolution.name
if labels := best_effort_get_field_from_issue(issue, "labels"):
metadata_dict["label"] = labels
return Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
metadata=metadata_dict,
)
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
class JiraConnector(CheckpointConnector, SlimConnector):
def __init__(
self,
jira_base_url: str,
@@ -200,33 +189,10 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _get_jql_query(self) -> str:
"""Get the JQL query based on whether a specific project is set"""
if self.jira_project:
return f"project = {self.quoted_jira_project}"
return "" # Empty string means all accessible projects
def load_from_state(self) -> GenerateDocumentsOutput:
jql = self._get_jql_query()
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
yield document_batch
def poll_source(
def _get_jql_query(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
) -> str:
"""Get the JQL query based on whether a specific project is set and time range"""
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M"
)
@@ -234,38 +200,74 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
"%Y-%m-%d %H:%M"
)
base_jql = self._get_jql_query()
jql = (
f"{base_jql} AND " if base_jql else ""
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
base_jql = f"project = {self.quoted_jira_project}" if self.jira_project else ""
time_jql = f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
document_batch = []
for doc in fetch_jira_issues_batch(
return f"{base_jql} AND {time_jql}" if base_jql else time_jql
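# Example of the combined query returned above (illustrative values):
# project = "TEST" AND updated >= '2023-01-01 00:00' AND updated <= '2023-01-02 00:00'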
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
jql = self._get_jql_query(start, end)
# Get the current offset from checkpoint or start at 0
starting_offset = checkpoint.checkpoint_content.get("offset", 0)
current_offset = starting_offset
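# e.g. a resumed run that has already indexed three full pages of 50 issues
# arrives here with checkpoint_content == {"offset": 150} (illustrative numbers)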
for issue in _perform_jql_search(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
start=current_offset,
max_results=_JIRA_FULL_PAGE_SIZE,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
issue_key = issue.key
try:
if document := process_jira_issue(
jira_client=self.jira_client,
issue=issue,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
yield document
yield document_batch
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=issue_key,
document_link=build_jira_url(self.jira_client, issue_key),
),
failure_message=f"Failed to process Jira issue: {str(e)}",
exception=e,
)
current_offset += 1
# Update checkpoint
checkpoint = ConnectorCheckpoint(
checkpoint_content={
"offset": current_offset,
},
# if we didn't retrieve a full batch, we're done
has_more=current_offset - starting_offset == _JIRA_FULL_PAGE_SIZE,
)
return checkpoint
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
jql = self._get_jql_query()
) -> Generator[list[SlimDocument], None, None]:
jql = self._get_jql_query(start or 0, end or float("inf"))
slim_doc_batch = []
for issue in _paginate_jql_search(
for issue in _perform_jql_search(
jira_client=self.jira_client,
jql=jql,
start=0,
max_results=_JIRA_SLIM_PAGE_SIZE,
fields="key",
):
@@ -350,5 +352,7 @@ if __name__ == "__main__":
"jira_api_token": os.environ["JIRA_API_TOKEN"],
}
)
document_batches = connector.load_from_state()
document_batches = connector.load_from_checkpoint(
0, float("inf"), ConnectorCheckpoint(checkpoint_content={}, has_more=True)
)
print(next(document_batches))
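For reference, with the _JIRA_FULL_PAGE_SIZE of 50 defined above, a polling window containing 120 matching issues would produce checkpoints like these across three successive load_from_checkpoint calls (illustrative numbers, not taken from this commit's tests):

ConnectorCheckpoint(checkpoint_content={"offset": 50}, has_more=True)    # full page, keep paging
ConnectorCheckpoint(checkpoint_content={"offset": 100}, has_more=True)   # full page, keep paging
ConnectorCheckpoint(checkpoint_content={"offset": 120}, has_more=False)  # short page, window done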

View File

@@ -5,6 +5,7 @@ import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.connector import JiraConnector
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
@pytest.fixture
@@ -24,15 +25,13 @@ def jira_connector() -> JiraConnector:
def test_jira_connector_basic(jira_connector: JiraConnector) -> None:
doc_batch_generator = jira_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 1
doc = doc_batch[0]
docs = load_all_docs_from_checkpoint_connector(
connector=jira_connector,
start=0,
end=time.time(),
)
assert len(docs) == 1
doc = docs[0]
assert doc.id == "https://danswerai.atlassian.net/browse/AS-2"
assert doc.semantic_identifier == "AS-2: test123small"

View File

@@ -0,0 +1,65 @@
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
_ITERATION_LIMIT = 100_000
def load_all_docs_from_checkpoint_connector(
connector: CheckpointConnector,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> list[Document]:
num_iterations = 0
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
documents: list[Document] = []
while checkpoint.has_more:
doc_batch_generator = CheckpointOutputWrapper()(
connector.load_from_checkpoint(start, end, checkpoint)
)
for document, failure, next_checkpoint in doc_batch_generator:
if failure is not None:
raise RuntimeError(f"Failed to load documents: {failure}")
if document is not None:
documents.append(document)
if next_checkpoint is not None:
checkpoint = next_checkpoint
num_iterations += 1
if num_iterations > _ITERATION_LIMIT:
raise RuntimeError("Too many iterations. Infinite loop?")
return documents
def load_everything_from_checkpoint_connector(
connector: CheckpointConnector,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> list[Document | ConnectorFailure]:
"""Like load_all_docs_from_checkpoint_connector but returns both documents and failures"""
num_iterations = 0
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
outputs: list[Document | ConnectorFailure] = []
while checkpoint.has_more:
doc_batch_generator = CheckpointOutputWrapper()(
connector.load_from_checkpoint(start, end, checkpoint)
)
for document, failure, next_checkpoint in doc_batch_generator:
if failure is not None:
outputs.append(failure)
if document is not None:
outputs.append(document)
if next_checkpoint is not None:
checkpoint = next_checkpoint
num_iterations += 1
if num_iterations > _ITERATION_LIMIT:
raise RuntimeError("Too many iterations. Infinite loop?")
return outputs
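A hedged usage sketch for the helpers above (the connector value and the time window are stand-ins; this snippet is not part of the commit): load_everything_from_checkpoint_connector returns documents and failures interleaved in arrival order, so callers typically split them by type.

outputs = load_everything_from_checkpoint_connector(connector, 0, 1_700_000_000)  # arbitrary end timestamp
docs = [item for item in outputs if isinstance(item, Document)]
failures = [item for item in outputs if isinstance(item, ConnectorFailure)]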

View File

@@ -0,0 +1,440 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from jira import JIRA
from jira import JIRAError
from jira.resources import Issue
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.onyx_jira.connector import JiraConnector
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
PAGE_SIZE = 2
@pytest.fixture
def jira_base_url() -> str:
return "https://jira.example.com"
@pytest.fixture
def project_key() -> str:
return "TEST"
@pytest.fixture
def mock_jira_client() -> MagicMock:
"""Create a mock JIRA client with proper typing"""
mock = MagicMock(spec=JIRA)
# Add proper return typing for search_issues method
mock.search_issues = MagicMock()
# Add proper return typing for project method
mock.project = MagicMock()
# Add proper return typing for projects method
mock.projects = MagicMock()
return mock
@pytest.fixture
def jira_connector(
jira_base_url: str, project_key: str, mock_jira_client: MagicMock
) -> Generator[JiraConnector, None, None]:
connector = JiraConnector(
jira_base_url=jira_base_url,
project_key=project_key,
comment_email_blacklist=["blacklist@example.com"],
labels_to_skip=["secret", "sensitive"],
)
connector._jira_client = mock_jira_client
connector._jira_client.client_info.return_value = jira_base_url
with patch("onyx.connectors.onyx_jira.connector._JIRA_FULL_PAGE_SIZE", 2):
yield connector
@pytest.fixture
def create_mock_issue() -> Callable[..., MagicMock]:
def _create_mock_issue(
key: str = "TEST-123",
summary: str = "Test Issue",
updated: str = "2023-01-01T12:00:00.000+0000",
description: str = "Test Description",
labels: list[str] | None = None,
) -> MagicMock:
"""Helper to create a mock Issue object"""
mock_issue = MagicMock(spec=Issue)
# Create fields attribute first
mock_issue.fields = MagicMock()
mock_issue.key = key
mock_issue.fields.summary = summary
mock_issue.fields.updated = updated
mock_issue.fields.description = description
mock_issue.fields.labels = labels or []
# Set up creator and assignee for testing owner extraction
mock_issue.fields.creator = MagicMock()
mock_issue.fields.creator.displayName = "Test Creator"
mock_issue.fields.creator.emailAddress = "creator@example.com"
mock_issue.fields.assignee = MagicMock()
mock_issue.fields.assignee.displayName = "Test Assignee"
mock_issue.fields.assignee.emailAddress = "assignee@example.com"
# Set up priority, status, and resolution
mock_issue.fields.priority = MagicMock()
mock_issue.fields.priority.name = "High"
mock_issue.fields.status = MagicMock()
mock_issue.fields.status.name = "In Progress"
mock_issue.fields.resolution = MagicMock()
mock_issue.fields.resolution.name = "Fixed"
# Add raw field for accessing through API version check
mock_issue.raw = {"fields": {"description": description}}
return mock_issue
return _create_mock_issue
def test_load_credentials(jira_connector: JiraConnector) -> None:
"""Test loading credentials"""
with patch(
"onyx.connectors.onyx_jira.connector.build_jira_client"
) as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
credentials = {
"jira_user_email": "user@example.com",
"jira_api_token": "token123",
}
result = jira_connector.load_credentials(credentials)
mock_build_client.assert_called_once_with(
credentials=credentials, jira_base=jira_connector.jira_base
)
assert result is None
assert jira_connector._jira_client == mock_build_client.return_value
def test_get_jql_query_with_project(jira_connector: JiraConnector) -> None:
"""Test JQL query generation with project specified"""
start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp()
query = jira_connector._get_jql_query(start, end)
# Check that the project part and time part are both in the query
assert f'project = "{jira_connector.jira_project}"' in query
assert "updated >= '2023-01-01 00:00'" in query
assert "updated <= '2023-01-02 00:00'" in query
assert " AND " in query
def test_get_jql_query_without_project(jira_base_url: str) -> None:
"""Test JQL query generation without project specified"""
# Create connector without project key
connector = JiraConnector(jira_base_url=jira_base_url)
start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp()
query = connector._get_jql_query(start, end)
# Check that only time part is in the query
assert "project =" not in query
assert "updated >= '2023-01-01 00:00'" in query
assert "updated <= '2023-01-02 00:00'" in query
def test_load_from_checkpoint_happy_path(
jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
) -> None:
"""Test loading from checkpoint - happy path"""
# Set up mocked issues
mock_issue1 = create_mock_issue(key="TEST-1", summary="Issue 1")
mock_issue2 = create_mock_issue(key="TEST-2", summary="Issue 2")
mock_issue3 = create_mock_issue(key="TEST-3", summary="Issue 3")
# Only mock the search_issues method
jira_client = cast(JIRA, jira_connector._jira_client)
search_issues_mock = cast(MagicMock, jira_client.search_issues)
search_issues_mock.side_effect = [
[mock_issue1, mock_issue2],
[mock_issue3],
[],
]
# Call load_from_checkpoint
end_time = time.time()
outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)
# Check that the documents were returned
assert len(outputs) == 2
checkpoint_output1 = outputs[0]
assert len(checkpoint_output1.items) == 2
document1 = checkpoint_output1.items[0]
assert isinstance(document1, Document)
assert document1.id == "https://jira.example.com/browse/TEST-1"
document2 = checkpoint_output1.items[1]
assert isinstance(document2, Document)
assert document2.id == "https://jira.example.com/browse/TEST-2"
assert checkpoint_output1.next_checkpoint == ConnectorCheckpoint(
checkpoint_content={
"offset": 2,
},
has_more=True,
)
checkpoint_output2 = outputs[1]
assert len(checkpoint_output2.items) == 1
document3 = checkpoint_output2.items[0]
assert isinstance(document3, Document)
assert document3.id == "https://jira.example.com/browse/TEST-3"
assert checkpoint_output2.next_checkpoint == ConnectorCheckpoint(
checkpoint_content={
"offset": 3,
},
has_more=False,
)
# Check that search_issues was called with the right parameters
assert search_issues_mock.call_count == 2
args, kwargs = search_issues_mock.call_args_list[0]
assert kwargs["startAt"] == 0
assert kwargs["maxResults"] == PAGE_SIZE
args, kwargs = search_issues_mock.call_args_list[1]
assert kwargs["startAt"] == 2
assert kwargs["maxResults"] == PAGE_SIZE
def test_load_from_checkpoint_with_issue_processing_error(
jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
) -> None:
"""Test loading from checkpoint with a mix of successful and failed issue processing across multiple batches"""
# Set up mocked issues for first batch
mock_issue1 = create_mock_issue(key="TEST-1", summary="Issue 1")
mock_issue2 = create_mock_issue(key="TEST-2", summary="Issue 2")
# Set up mocked issues for second batch
mock_issue3 = create_mock_issue(key="TEST-3", summary="Issue 3")
mock_issue4 = create_mock_issue(key="TEST-4", summary="Issue 4")
# Mock search_issues to return our mock issues in batches
jira_client = cast(JIRA, jira_connector._jira_client)
search_issues_mock = cast(MagicMock, jira_client.search_issues)
search_issues_mock.side_effect = [
[mock_issue1, mock_issue2], # First batch
[mock_issue3, mock_issue4], # Second batch
[], # Empty batch to indicate end
]
# Mock process_jira_issue to succeed for some issues and fail for others
def mock_process_side_effect(
jira_client: JIRA, issue: Issue, *args: Any, **kwargs: Any
) -> Document | None:
if issue.key in ["TEST-1", "TEST-3"]:
return Document(
id=f"https://jira.example.com/browse/{issue.key}",
sections=[],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",
metadata={},
)
else:
raise Exception(f"Processing error for {issue.key}")
with patch(
"onyx.connectors.onyx_jira.connector.process_jira_issue"
) as mock_process:
mock_process.side_effect = mock_process_side_effect
# Call load_from_checkpoint
end_time = time.time()
outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)
assert len(outputs) == 3
# Check first batch
first_batch = outputs[0]
assert len(first_batch.items) == 2
# First item should be successful
assert isinstance(first_batch.items[0], Document)
assert first_batch.items[0].id == "https://jira.example.com/browse/TEST-1"
# Second item should be a failure
assert isinstance(first_batch.items[1], ConnectorFailure)
assert first_batch.items[1].failed_document is not None
assert first_batch.items[1].failed_document.document_id == "TEST-2"
assert "Failed to process Jira issue" in first_batch.items[1].failure_message
# Check checkpoint indicates more items (full batch)
assert first_batch.next_checkpoint.has_more is True
assert first_batch.next_checkpoint.checkpoint_content == {"offset": 2}
# Check second batch
second_batch = outputs[1]
assert len(second_batch.items) == 2
# First item should be successful
assert isinstance(second_batch.items[0], Document)
assert second_batch.items[0].id == "https://jira.example.com/browse/TEST-3"
# Second item should be a failure
assert isinstance(second_batch.items[1], ConnectorFailure)
assert second_batch.items[1].failed_document is not None
assert second_batch.items[1].failed_document.document_id == "TEST-4"
assert "Failed to process Jira issue" in second_batch.items[1].failure_message
# Check checkpoint indicates more items
assert second_batch.next_checkpoint.has_more is True
assert second_batch.next_checkpoint.checkpoint_content == {"offset": 4}
# Check third, empty batch
third_batch = outputs[2]
assert len(third_batch.items) == 0
assert third_batch.next_checkpoint.has_more is False
assert third_batch.next_checkpoint.checkpoint_content == {"offset": 4}
def test_load_from_checkpoint_with_skipped_issue(
jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
) -> None:
"""Test loading from checkpoint with an issue that should be skipped due to labels"""
LABEL_TO_SKIP = "secret"
jira_connector.labels_to_skip = {LABEL_TO_SKIP}
# Set up mocked issue with a label to skip
mock_issue = create_mock_issue(
key="TEST-1", summary="Issue 1", labels=[LABEL_TO_SKIP]
)
# Mock search_issues to return our mock issue
jira_client = cast(JIRA, jira_connector._jira_client)
search_issues_mock = cast(MagicMock, jira_client.search_issues)
search_issues_mock.return_value = [mock_issue]
# Call load_from_checkpoint
end_time = time.time()
outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)
assert len(outputs) == 1
checkpoint_output = outputs[0]
# Check that no documents were returned
assert len(checkpoint_output.items) == 0
def test_retrieve_all_slim_documents(
jira_connector: JiraConnector, create_mock_issue: Any
) -> None:
"""Test retrieving all slim documents"""
# Set up mocked issues
mock_issue1 = create_mock_issue(key="TEST-1")
mock_issue2 = create_mock_issue(key="TEST-2")
# Mock search_issues to return our mock issues
jira_client = cast(JIRA, jira_connector._jira_client)
search_issues_mock = cast(MagicMock, jira_client.search_issues)
search_issues_mock.return_value = [mock_issue1, mock_issue2]
# Mock best_effort_get_field_from_issue to return the keys
with patch(
"onyx.connectors.onyx_jira.connector.best_effort_get_field_from_issue"
) as mock_field:
mock_field.side_effect = ["TEST-1", "TEST-2"]
# Mock build_jira_url to return URLs
with patch("onyx.connectors.onyx_jira.connector.build_jira_url") as mock_url:
mock_url.side_effect = [
"https://jira.example.com/browse/TEST-1",
"https://jira.example.com/browse/TEST-2",
]
# Call retrieve_all_slim_documents
batches = list(jira_connector.retrieve_all_slim_documents(0, 100))
# Check that a batch with 2 documents was returned
assert len(batches) == 1
assert len(batches[0]) == 2
assert isinstance(batches[0][0], SlimDocument)
assert batches[0][0].id == "https://jira.example.com/browse/TEST-1"
assert batches[0][1].id == "https://jira.example.com/browse/TEST-2"
# Check that search_issues was called with the right parameters
search_issues_mock.assert_called_once()
args, kwargs = search_issues_mock.call_args
assert kwargs["fields"] == "key"
@pytest.mark.parametrize(
"status_code,expected_exception,expected_message",
[
(
401,
CredentialExpiredError,
"Jira credential appears to be expired or invalid",
),
(
403,
InsufficientPermissionsError,
"Your Jira token does not have sufficient permissions",
),
(404, ConnectorValidationError, "Jira project not found"),
(
429,
ConnectorValidationError,
"Validation failed due to Jira rate-limits being exceeded",
),
],
)
def test_validate_connector_settings_errors(
jira_connector: JiraConnector,
status_code: int,
expected_exception: type[Exception],
expected_message: str,
) -> None:
"""Test validation with various error scenarios"""
error = JIRAError(status_code=status_code)
jira_client = cast(JIRA, jira_connector._jira_client)
project_mock = cast(MagicMock, jira_client.project)
project_mock.side_effect = error
with pytest.raises(expected_exception) as excinfo:
jira_connector.validate_connector_settings()
assert expected_message in str(excinfo.value)
def test_validate_connector_settings_with_project_success(
jira_connector: JiraConnector,
) -> None:
"""Test successful validation with project specified"""
jira_client = cast(JIRA, jira_connector._jira_client)
project_mock = cast(MagicMock, jira_client.project)
project_mock.return_value = MagicMock()
jira_connector.validate_connector_settings()
project_mock.assert_called_once_with(jira_connector.jira_project)
def test_validate_connector_settings_without_project_success(
jira_base_url: str,
) -> None:
"""Test successful validation without project specified"""
connector = JiraConnector(jira_base_url=jira_base_url)
connector._jira_client = MagicMock()
connector._jira_client.projects.return_value = [MagicMock()]
connector.validate_connector_settings()
connector._jira_client.projects.assert_called_once()

View File

@@ -7,7 +7,8 @@ import pytest
from jira.resources import Issue
from pytest_mock import MockFixture
from onyx.connectors.onyx_jira.connector import fetch_jira_issues_batch
from onyx.connectors.onyx_jira.connector import _perform_jql_search
from onyx.connectors.onyx_jira.connector import process_jira_issue
@pytest.fixture
@@ -79,14 +80,22 @@ def test_fetch_jira_issues_batch_small_ticket(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
# First get the issues via pagination
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
assert len(issues) == 1
# Then process each issue
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
docs = [doc for doc in docs if doc is not None] # Filter out None values
assert len(docs) == 1
assert docs[0].id.endswith("/SMALL-1")
assert docs[0].sections[0].text is not None
assert "Small description" in docs[0].sections[0].text
assert "Small comment 1" in docs[0].sections[0].text
assert "Small comment 2" in docs[0].sections[0].text
doc = docs[0]
assert doc is not None # Type assertion for mypy
assert doc.id.endswith("/SMALL-1")
assert doc.sections[0].text is not None
assert "Small description" in doc.sections[0].text
assert "Small comment 1" in doc.sections[0].text
assert "Small comment 2" in doc.sections[0].text
def test_fetch_jira_issues_batch_large_ticket(
@@ -96,7 +105,13 @@ def test_fetch_jira_issues_batch_large_ticket(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_large]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
# First get the issues via pagination
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
assert len(issues) == 1
# Then process each issue
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
docs = [doc for doc in docs if doc is not None] # Filter out None values
assert len(docs) == 0 # The large ticket should be skipped
@@ -109,10 +124,18 @@ def test_fetch_jira_issues_batch_mixed_tickets(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
# First get the issues via pagination
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
assert len(issues) == 2
# Then process each issue
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
docs = [doc for doc in docs if doc is not None] # Filter out None values
assert len(docs) == 1 # Only the small ticket should be included
assert docs[0].id.endswith("/SMALL-1")
doc = docs[0]
assert doc is not None # Type assertion for mypy
assert doc.id.endswith("/SMALL-1")
@patch("onyx.connectors.onyx_jira.connector.JIRA_CONNECTOR_MAX_TICKET_SIZE", 50)
@@ -124,6 +147,12 @@ def test_fetch_jira_issues_batch_custom_size_limit(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
# First get the issues via pagination
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
assert len(issues) == 2
# Then process each issue
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
docs = [doc for doc in docs if doc is not None] # Filter out None values
assert len(docs) == 0 # Both tickets should be skipped due to the low size limit

View File

@@ -0,0 +1,48 @@
from pydantic import BaseModel
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
_ITERATION_LIMIT = 100_000
class SingleConnectorCallOutput(BaseModel):
items: list[Document | ConnectorFailure]
next_checkpoint: ConnectorCheckpoint
def load_everything_from_checkpoint_connector(
connector: CheckpointConnector,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> list[SingleConnectorCallOutput]:
num_iterations = 0
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
outputs: list[SingleConnectorCallOutput] = []
while checkpoint.has_more:
items: list[Document | ConnectorFailure] = []
doc_batch_generator = CheckpointOutputWrapper()(
connector.load_from_checkpoint(start, end, checkpoint)
)
for document, failure, next_checkpoint in doc_batch_generator:
if failure is not None:
items.append(failure)
if document is not None:
items.append(document)
if next_checkpoint is not None:
checkpoint = next_checkpoint
outputs.append(
SingleConnectorCallOutput(items=items, next_checkpoint=checkpoint)
)
num_iterations += 1
if num_iterations > _ITERATION_LIMIT:
raise RuntimeError("Too many iterations. Infinite loop?")
return outputs