mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-17 13:22:42 +01:00
Checkpointed Jira connector
This commit is contained in:
parent
64ff5df083
commit
4c6f4e7158
@ -1,4 +1,5 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
@ -15,14 +16,15 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
|
||||
@ -42,121 +44,108 @@ _JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
|
||||
|
||||
def _paginate_jql_search(
|
||||
def _perform_jql_search(
|
||||
jira_client: JIRA,
|
||||
jql: str,
|
||||
start: int,
|
||||
max_results: int,
|
||||
fields: str | None = None,
|
||||
) -> Iterable[Issue]:
|
||||
start = 0
|
||||
while True:
|
||||
logger.debug(
|
||||
f"Fetching Jira issues with JQL: {jql}, "
|
||||
f"starting at {start}, max results: {max_results}"
|
||||
)
|
||||
issues = jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
logger.debug(
|
||||
f"Fetching Jira issues with JQL: {jql}, "
|
||||
f"starting at {start}, max results: {max_results}"
|
||||
)
|
||||
issues = jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
for issue in issues:
|
||||
if isinstance(issue, Issue):
|
||||
yield issue
|
||||
else:
|
||||
raise Exception(f"Found Jira object not of type Issue: {issue}")
|
||||
|
||||
if len(issues) < max_results:
|
||||
break
|
||||
|
||||
start += max_results
|
||||
for issue in issues:
|
||||
if isinstance(issue, Issue):
|
||||
yield issue
|
||||
else:
|
||||
raise RuntimeError(f"Found Jira object not of type Issue: {issue}")
|
||||
|
||||
|
||||
def fetch_jira_issues_batch(
|
||||
def process_jira_issue(
|
||||
jira_client: JIRA,
|
||||
jql: str,
|
||||
batch_size: int,
|
||||
issue: Issue,
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
) -> Iterable[Document]:
|
||||
for issue in _paginate_jql_search(
|
||||
jira_client=jira_client,
|
||||
jql=jql,
|
||||
max_results=batch_size,
|
||||
):
|
||||
if labels_to_skip:
|
||||
if any(label in issue.fields.labels for label in labels_to_skip):
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it has a label to skip. Found "
|
||||
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
|
||||
)
|
||||
continue
|
||||
|
||||
description = (
|
||||
issue.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(issue.raw["fields"]["description"])
|
||||
)
|
||||
comments = get_comment_strs(
|
||||
issue=issue,
|
||||
comment_email_blacklist=comment_email_blacklist,
|
||||
)
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
) -> Document | None:
|
||||
if labels_to_skip:
|
||||
if any(label in issue.fields.labels for label in labels_to_skip):
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
f"Skipping {issue.key} because it has a label to skip. Found "
|
||||
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
|
||||
)
|
||||
continue
|
||||
return None
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
|
||||
description = (
|
||||
issue.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(issue.raw["fields"]["description"])
|
||||
)
|
||||
comments = get_comment_strs(
|
||||
issue=issue,
|
||||
comment_email_blacklist=comment_email_blacklist,
|
||||
)
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
people = set()
|
||||
try:
|
||||
creator = best_effort_get_field_from_issue(issue, "creator")
|
||||
if basic_expert_info := best_effort_basic_expert_info(creator):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
try:
|
||||
assignee = best_effort_get_field_from_issue(issue, "assignee")
|
||||
if basic_expert_info := best_effort_basic_expert_info(assignee):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
if priority := best_effort_get_field_from_issue(issue, "priority"):
|
||||
metadata_dict["priority"] = priority.name
|
||||
if status := best_effort_get_field_from_issue(issue, "status"):
|
||||
metadata_dict["status"] = status.name
|
||||
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
if labels := best_effort_get_field_from_issue(issue, "labels"):
|
||||
metadata_dict["label"] = labels
|
||||
|
||||
yield Document(
|
||||
id=page_url,
|
||||
sections=[TextSection(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
|
||||
title=f"{issue.key} {issue.fields.summary}",
|
||||
doc_updated_at=time_str_to_utc(issue.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
# TODO add secondary_owners (commenters) if needed
|
||||
metadata=metadata_dict,
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
)
|
||||
return None
|
||||
|
||||
page_url = build_jira_url(jira_client, issue.key)
|
||||
|
||||
people = set()
|
||||
try:
|
||||
creator = best_effort_get_field_from_issue(issue, "creator")
|
||||
if basic_expert_info := best_effort_basic_expert_info(creator):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
try:
|
||||
assignee = best_effort_get_field_from_issue(issue, "assignee")
|
||||
if basic_expert_info := best_effort_basic_expert_info(assignee):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
if priority := best_effort_get_field_from_issue(issue, "priority"):
|
||||
metadata_dict["priority"] = priority.name
|
||||
if status := best_effort_get_field_from_issue(issue, "status"):
|
||||
metadata_dict["status"] = status.name
|
||||
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
if labels := best_effort_get_field_from_issue(issue, "labels"):
|
||||
metadata_dict["label"] = labels
|
||||
|
||||
return Document(
|
||||
id=page_url,
|
||||
sections=[TextSection(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
|
||||
title=f"{issue.key} {issue.fields.summary}",
|
||||
doc_updated_at=time_str_to_utc(issue.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
|
||||
|
||||
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class JiraConnector(CheckpointConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
jira_base_url: str,
|
||||
@ -200,33 +189,10 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_jql_query(self) -> str:
|
||||
"""Get the JQL query based on whether a specific project is set"""
|
||||
if self.jira_project:
|
||||
return f"project = {self.quoted_jira_project}"
|
||||
return "" # Empty string means all accessible projects
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
jql = self._get_jql_query()
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
batch_size=_JIRA_FULL_PAGE_SIZE,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
document_batch.append(doc)
|
||||
if len(document_batch) >= self.batch_size:
|
||||
yield document_batch
|
||||
document_batch = []
|
||||
|
||||
yield document_batch
|
||||
|
||||
def poll_source(
|
||||
def _get_jql_query(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
) -> str:
|
||||
"""Get the JQL query based on whether a specific project is set and time range"""
|
||||
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
@ -234,38 +200,74 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
base_jql = self._get_jql_query()
|
||||
jql = (
|
||||
f"{base_jql} AND " if base_jql else ""
|
||||
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
|
||||
base_jql = f"project = {self.quoted_jira_project}" if self.jira_project else ""
|
||||
time_jql = f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
return f"{base_jql} AND {time_jql}" if base_jql else time_jql
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
jql = self._get_jql_query(start, end)
|
||||
|
||||
# Get the current offset from checkpoint or start at 0
|
||||
starting_offset = checkpoint.checkpoint_content.get("offset", 0)
|
||||
current_offset = starting_offset
|
||||
|
||||
for issue in _perform_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
batch_size=_JIRA_FULL_PAGE_SIZE,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
start=current_offset,
|
||||
max_results=_JIRA_FULL_PAGE_SIZE,
|
||||
):
|
||||
document_batch.append(doc)
|
||||
if len(document_batch) >= self.batch_size:
|
||||
yield document_batch
|
||||
document_batch = []
|
||||
issue_key = issue.key
|
||||
try:
|
||||
if document := process_jira_issue(
|
||||
jira_client=self.jira_client,
|
||||
issue=issue,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
yield document
|
||||
|
||||
yield document_batch
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=issue_key,
|
||||
document_link=build_jira_url(self.jira_client, issue_key),
|
||||
),
|
||||
failure_message=f"Failed to process Jira issue: {str(e)}",
|
||||
exception=e,
|
||||
)
|
||||
|
||||
current_offset += 1
|
||||
|
||||
# Update checkpoint
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content={
|
||||
"offset": current_offset,
|
||||
},
|
||||
# if we didn't retrieve a full batch, we're done
|
||||
has_more=current_offset - starting_offset == _JIRA_FULL_PAGE_SIZE,
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
jql = self._get_jql_query()
|
||||
) -> Generator[list[SlimDocument], None, None]:
|
||||
jql = self._get_jql_query(start or 0, end or float("inf"))
|
||||
|
||||
slim_doc_batch = []
|
||||
for issue in _paginate_jql_search(
|
||||
for issue in _perform_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=0,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
fields="key",
|
||||
):
|
||||
@ -350,5 +352,7 @@ if __name__ == "__main__":
|
||||
"jira_api_token": os.environ["JIRA_API_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
document_batches = connector.load_from_checkpoint(
|
||||
0, float("inf"), ConnectorCheckpoint(checkpoint_content={}, has_more=True)
|
||||
)
|
||||
print(next(document_batches))
|
||||
|
@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.onyx_jira.connector import JiraConnector
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -24,15 +25,13 @@ def jira_connector() -> JiraConnector:
|
||||
|
||||
|
||||
def test_jira_connector_basic(jira_connector: JiraConnector) -> None:
|
||||
doc_batch_generator = jira_connector.poll_source(0, time.time())
|
||||
|
||||
doc_batch = next(doc_batch_generator)
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
|
||||
assert len(doc_batch) == 1
|
||||
|
||||
doc = doc_batch[0]
|
||||
docs = load_all_docs_from_checkpoint_connector(
|
||||
connector=jira_connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
assert len(docs) == 1
|
||||
doc = docs[0]
|
||||
|
||||
assert doc.id == "https://danswerai.atlassian.net/browse/AS-2"
|
||||
assert doc.semantic_identifier == "AS-2: test123small"
|
||||
|
65
backend/tests/daily/connectors/utils.py
Normal file
65
backend/tests/daily/connectors/utils.py
Normal file
@ -0,0 +1,65 @@
|
||||
from onyx.connectors.connector_runner import CheckpointOutputWrapper
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
|
||||
_ITERATION_LIMIT = 100_000
|
||||
|
||||
|
||||
def load_all_docs_from_checkpoint_connector(
    connector: CheckpointConnector,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> list[Document]:
    """Drain a checkpointed connector and return every document it yields.

    Begins with the dummy checkpoint and keeps calling
    ``connector.load_from_checkpoint`` with the most recent checkpoint until
    that checkpoint reports ``has_more`` as false.

    Raises:
        RuntimeError: as soon as the connector reports a failure, or once
            more than ``_ITERATION_LIMIT`` connector runs have completed.
    """
    collected: list[Document] = []
    state = ConnectorCheckpoint.build_dummy_checkpoint()

    # One pass per connector run; the range bound doubles as the runaway guard.
    for _ in range(_ITERATION_LIMIT + 1):
        if not state.has_more:
            return collected

        run_output = CheckpointOutputWrapper()(
            connector.load_from_checkpoint(start, end, state)
        )
        # Each item is a (document, failure, next_checkpoint) triple.
        for doc, problem, new_state in run_output:
            if problem is not None:
                raise RuntimeError(f"Failed to load documents: {problem}")
            if doc is not None:
                collected.append(doc)
            if new_state is not None:
                state = new_state

    raise RuntimeError("Too many iterations. Infinite loop?")
|
||||
|
||||
|
||||
def load_everything_from_checkpoint_connector(
    connector: CheckpointConnector,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> list[Document | ConnectorFailure]:
    """Like load_all_docs_from_checkpoint_connector but returns both documents and failures

    Failures are appended to the result in the order they are produced
    instead of aborting the run.

    Raises:
        RuntimeError: once more than ``_ITERATION_LIMIT`` connector runs
            have completed.
    """
    gathered: list[Document | ConnectorFailure] = []
    state = ConnectorCheckpoint.build_dummy_checkpoint()

    # One pass per connector run; the range bound doubles as the runaway guard.
    for _ in range(_ITERATION_LIMIT + 1):
        if not state.has_more:
            return gathered

        run_output = CheckpointOutputWrapper()(
            connector.load_from_checkpoint(start, end, state)
        )
        # Each item is a (document, failure, next_checkpoint) triple.
        for doc, problem, new_state in run_output:
            if problem is not None:
                gathered.append(problem)
            if doc is not None:
                gathered.append(doc)
            if new_state is not None:
                state = new_state

    raise RuntimeError("Too many iterations. Infinite loop?")
|
@ -0,0 +1,440 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from jira import JIRA
|
||||
from jira import JIRAError
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.onyx_jira.connector import JiraConnector
|
||||
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
|
||||
|
||||
PAGE_SIZE = 2
|
||||
|
||||
|
||||
@pytest.fixture
def jira_base_url() -> str:
    """Base URL of the fake Jira instance used throughout these tests."""
    base_url = "https://jira.example.com"
    return base_url
|
||||
|
||||
|
||||
@pytest.fixture
def project_key() -> str:
    """Jira project key handed to the connector under test."""
    key = "TEST"
    return key
|
||||
|
||||
|
||||
@pytest.fixture
def mock_jira_client() -> MagicMock:
    """Create a mock JIRA client with proper typing"""
    client = MagicMock(spec=JIRA)
    # Replace the spec-derived query methods with fresh, unconstrained mocks.
    for method_name in ("search_issues", "project", "projects"):
        setattr(client, method_name, MagicMock())
    return client
|
||||
|
||||
|
||||
@pytest.fixture
def jira_connector(
    jira_base_url: str, project_key: str, mock_jira_client: MagicMock
) -> Generator[JiraConnector, None, None]:
    """Yield a ``JiraConnector`` backed by the mock client and a small page size.

    The connector is built with a comment blacklist and labels to skip, its
    ``_jira_client`` is replaced by ``mock_jira_client``, and the client's
    ``client_info()`` is set to return ``jira_base_url``.

    While the fixture is active, the connector module's
    ``_JIRA_FULL_PAGE_SIZE`` is patched to ``PAGE_SIZE``.
    """
    connector = JiraConnector(
        jira_base_url=jira_base_url,
        project_key=project_key,
        comment_email_blacklist=["blacklist@example.com"],
        labels_to_skip=["secret", "sensitive"],
    )
    connector._jira_client = mock_jira_client
    connector._jira_client.client_info.return_value = jira_base_url
    # Patch with the shared PAGE_SIZE constant rather than a literal: the tests
    # assert ``maxResults == PAGE_SIZE``, so the two values must not drift apart.
    with patch(
        "onyx.connectors.onyx_jira.connector._JIRA_FULL_PAGE_SIZE", PAGE_SIZE
    ):
        yield connector
|
||||
|
||||
|
||||
@pytest.fixture
def create_mock_issue() -> Callable[..., MagicMock]:
    """Provide a factory that builds ``Issue``-spec'd mocks for the tests."""

    def _named(value: str) -> MagicMock:
        # ``name`` has to be assigned after construction: passed to the
        # constructor it would name the mock itself, not set the attribute.
        holder = MagicMock()
        holder.name = value
        return holder

    def _create_mock_issue(
        key: str = "TEST-123",
        summary: str = "Test Issue",
        updated: str = "2023-01-01T12:00:00.000+0000",
        description: str = "Test Description",
        labels: list[str] | None = None,
    ) -> MagicMock:
        """Helper to create a mock Issue object"""
        fields = MagicMock()
        fields.summary = summary
        fields.updated = updated
        fields.description = description
        fields.labels = labels or []

        # People, used by the owner-extraction code path
        fields.creator = MagicMock(
            displayName="Test Creator", emailAddress="creator@example.com"
        )
        fields.assignee = MagicMock(
            displayName="Test Assignee", emailAddress="assignee@example.com"
        )

        # Metadata objects that expose a ``name`` attribute
        fields.priority = _named("High")
        fields.status = _named("In Progress")
        fields.resolution = _named("Fixed")

        issue = MagicMock(spec=Issue)
        issue.fields = fields
        issue.key = key
        # Raw payload, for code that reads the description from ``issue.raw``
        issue.raw = {"fields": {"description": description}}
        return issue

    return _create_mock_issue
|
||||
|
||||
|
||||
def test_load_credentials(jira_connector: JiraConnector) -> None:
    """Test loading credentials"""
    creds = {
        "jira_user_email": "user@example.com",
        "jira_api_token": "token123",
    }
    builder_patch = patch("onyx.connectors.onyx_jira.connector.build_jira_client")

    with builder_patch as build_client_mock:
        build_client_mock.return_value = jira_connector._jira_client

        outcome = jira_connector.load_credentials(creds)

        # The client builder must receive the credentials and the base URL,
        # and its return value must end up as the connector's client.
        build_client_mock.assert_called_once_with(
            credentials=creds, jira_base=jira_connector.jira_base
        )
        assert outcome is None
        assert jira_connector._jira_client == build_client_mock.return_value
|
||||
|
||||
|
||||
def test_get_jql_query_with_project(jira_connector: JiraConnector) -> None:
    """Test JQL query generation with project specified"""
    window_start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp()
    window_end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp()

    jql = jira_connector._get_jql_query(window_start, window_end)

    # The project clause, both time bounds, and the joining AND must all appear
    expected_fragments = (
        f'project = "{jira_connector.jira_project}"',
        "updated >= '2023-01-01 00:00'",
        "updated <= '2023-01-02 00:00'",
        " AND ",
    )
    for fragment in expected_fragments:
        assert fragment in jql
|
||||
|
||||
|
||||
def test_get_jql_query_without_project(jira_base_url: str) -> None:
    """Test JQL query generation without project specified"""
    # No project key is passed, so the query should carry only the time window
    projectless_connector = JiraConnector(jira_base_url=jira_base_url)

    window_start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp()
    window_end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp()
    jql = projectless_connector._get_jql_query(window_start, window_end)

    assert "project =" not in jql
    for time_clause in (
        "updated >= '2023-01-01 00:00'",
        "updated <= '2023-01-02 00:00'",
    ):
        assert time_clause in jql
|
||||
|
||||
|
||||
def test_load_from_checkpoint_happy_path(
    jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
) -> None:
    """Test loading from checkpoint - happy path.

    Three issues are served over two pages (a full page of two, then a short
    page of one). The test checks the documents, the checkpoint emitted after
    each page, and the paging arguments passed to ``search_issues``.
    """
    # Set up mocked issues
    mock_issue1 = create_mock_issue(key="TEST-1", summary="Issue 1")
    mock_issue2 = create_mock_issue(key="TEST-2", summary="Issue 2")
    mock_issue3 = create_mock_issue(key="TEST-3", summary="Issue 3")

    # Only mock the search_issues method
    jira_client = cast(JIRA, jira_connector._jira_client)
    search_issues_mock = cast(MagicMock, jira_client.search_issues)
    search_issues_mock.side_effect = [
        [mock_issue1, mock_issue2],
        [mock_issue3],
        [],
    ]

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)

    # Check that the documents were returned
    assert len(outputs) == 2

    checkpoint_output1 = outputs[0]
    assert len(checkpoint_output1.items) == 2
    document1 = checkpoint_output1.items[0]
    assert isinstance(document1, Document)
    assert document1.id == "https://jira.example.com/browse/TEST-1"
    document2 = checkpoint_output1.items[1]
    assert isinstance(document2, Document)
    assert document2.id == "https://jira.example.com/browse/TEST-2"
    assert checkpoint_output1.next_checkpoint == ConnectorCheckpoint(
        checkpoint_content={
            "offset": 2,
        },
        has_more=True,
    )

    checkpoint_output2 = outputs[1]
    assert len(checkpoint_output2.items) == 1
    document3 = checkpoint_output2.items[0]
    assert isinstance(document3, Document)
    assert document3.id == "https://jira.example.com/browse/TEST-3"
    assert checkpoint_output2.next_checkpoint == ConnectorCheckpoint(
        checkpoint_content={
            "offset": 3,
        },
        has_more=False,
    )

    # Check that search_issues was called with the right parameters.
    # The second checkpoint reports has_more=False, so exactly two searches
    # are expected. This was previously a bare comparison with no `assert`,
    # so the call count was never actually verified.
    assert search_issues_mock.call_count == 2

    first_call_kwargs = search_issues_mock.call_args_list[0].kwargs
    assert first_call_kwargs["startAt"] == 0
    assert first_call_kwargs["maxResults"] == PAGE_SIZE

    second_call_kwargs = search_issues_mock.call_args_list[1].kwargs
    assert second_call_kwargs["startAt"] == 2
    assert second_call_kwargs["maxResults"] == PAGE_SIZE
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_issue_processing_error(
    jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
) -> None:
    """Test loading from checkpoint with a mix of successful and failed issue processing across multiple batches"""
    # Four issues, served as two full pages followed by an empty page
    issues = [
        create_mock_issue(key=f"TEST-{number}", summary=f"Issue {number}")
        for number in range(1, 5)
    ]
    jira_client = cast(JIRA, jira_connector._jira_client)
    search_issues_mock = cast(MagicMock, jira_client.search_issues)
    search_issues_mock.side_effect = [issues[:2], issues[2:], []]

    def fake_process(
        jira_client: JIRA, issue: Issue, *args: Any, **kwargs: Any
    ) -> Document | None:
        # Odd-numbered issues succeed; every other issue blows up
        if issue.key not in ["TEST-1", "TEST-3"]:
            raise Exception(f"Processing error for {issue.key}")
        return Document(
            id=f"https://jira.example.com/browse/{issue.key}",
            sections=[],
            source=DocumentSource.JIRA,
            semantic_identifier=f"{issue.key}: {issue.fields.summary}",
            title=f"{issue.key} {issue.fields.summary}",
            metadata={},
        )

    with patch(
        "onyx.connectors.onyx_jira.connector.process_jira_issue"
    ) as process_mock:
        process_mock.side_effect = fake_process

        # Call load_from_checkpoint
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)

    assert len(outputs) == 3

    # Both full batches have the same shape: one document, then one failure,
    # and a checkpoint that says more data may follow.
    full_batches = (
        (outputs[0], "TEST-1", "TEST-2", 2),
        (outputs[1], "TEST-3", "TEST-4", 4),
    )
    for batch, ok_key, failed_key, expected_offset in full_batches:
        assert len(batch.items) == 2

        succeeded = batch.items[0]
        assert isinstance(succeeded, Document)
        assert succeeded.id == f"https://jira.example.com/browse/{ok_key}"

        failed = batch.items[1]
        assert isinstance(failed, ConnectorFailure)
        assert failed.failed_document is not None
        assert failed.failed_document.document_id == failed_key
        assert "Failed to process Jira issue" in failed.failure_message

        assert batch.next_checkpoint.has_more is True
        assert batch.next_checkpoint.checkpoint_content == {"offset": expected_offset}

    # The trailing empty batch ends the run without moving the offset
    final_batch = outputs[2]
    assert len(final_batch.items) == 0
    assert final_batch.next_checkpoint.has_more is False
    assert final_batch.next_checkpoint.checkpoint_content == {"offset": 4}
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_skipped_issue(
    jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
) -> None:
    """Test loading from checkpoint with an issue that should be skipped due to labels"""
    skip_label = "secret"
    jira_connector.labels_to_skip = {skip_label}

    # The only issue returned by the search carries the label to skip
    labelled_issue = create_mock_issue(
        key="TEST-1", summary="Issue 1", labels=[skip_label]
    )
    jira_client = cast(JIRA, jira_connector._jira_client)
    search_issues_mock = cast(MagicMock, jira_client.search_issues)
    search_issues_mock.return_value = [labelled_issue]

    # Call load_from_checkpoint
    outputs = load_everything_from_checkpoint_connector(
        jira_connector, 0, time.time()
    )

    # One checkpoint output comes back, and it holds no documents
    assert len(outputs) == 1
    only_output = outputs[0]
    assert len(only_output.items) == 0
|
||||
|
||||
|
||||
def test_retrieve_all_slim_documents(
    jira_connector: JiraConnector, create_mock_issue: Any
) -> None:
    """Slim retrieval returns a single batch of SlimDocuments whose ids are issue URLs."""
    issue_keys = ["TEST-1", "TEST-2"]
    issue_urls = [
        "https://jira.example.com/browse/TEST-1",
        "https://jira.example.com/browse/TEST-2",
    ]

    # The mocked Jira search hands back one mock issue per key, in order.
    search_mock = cast(
        MagicMock, cast(JIRA, jira_connector._jira_client).search_issues
    )
    search_mock.return_value = [create_mock_issue(key=key) for key in issue_keys]

    # Patch the field lookup and URL builder so successive calls return the
    # keys / URLs above, one per issue.
    with (
        patch(
            "onyx.connectors.onyx_jira.connector.best_effort_get_field_from_issue"
        ) as field_patch,
        patch("onyx.connectors.onyx_jira.connector.build_jira_url") as url_patch,
    ):
        field_patch.side_effect = issue_keys
        url_patch.side_effect = issue_urls

        batches = list(jira_connector.retrieve_all_slim_documents(0, 100))

    # One batch containing both slim documents, identified by their URLs.
    assert len(batches) == 1
    first_batch = batches[0]
    assert len(first_batch) == 2
    assert isinstance(first_batch[0], SlimDocument)
    assert first_batch[0].id == issue_urls[0]
    assert first_batch[1].id == issue_urls[1]

    # The search ran exactly once and requested only the "key" field.
    search_mock.assert_called_once()
    _, call_kwargs = search_mock.call_args
    assert call_kwargs["fields"] == "key"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "status_code,expected_exception,expected_message",
    [
        (
            401,
            CredentialExpiredError,
            "Jira credential appears to be expired or invalid",
        ),
        (
            403,
            InsufficientPermissionsError,
            "Your Jira token does not have sufficient permissions",
        ),
        (404, ConnectorValidationError, "Jira project not found"),
        (
            429,
            ConnectorValidationError,
            "Validation failed due to Jira rate-limits being exceeded",
        ),
    ],
)
def test_validate_connector_settings_errors(
    jira_connector: JiraConnector,
    status_code: int,
    expected_exception: type[Exception],
    expected_message: str,
) -> None:
    """A JIRAError with the given status surfaces as the expected exception and message."""
    # Make the project lookup fail with the parametrized HTTP status code.
    project_lookup = cast(
        MagicMock, cast(JIRA, jira_connector._jira_client).project
    )
    project_lookup.side_effect = JIRAError(status_code=status_code)

    with pytest.raises(expected_exception) as raised:
        jira_connector.validate_connector_settings()

    # Checked after the `raises` block: inside it this line would never run.
    assert expected_message in str(raised.value)
|
||||
|
||||
|
||||
def test_validate_connector_settings_with_project_success(
    jira_connector: JiraConnector,
) -> None:
    """Validation passes and looks up the configured project exactly once."""
    project_lookup = cast(
        MagicMock, cast(JIRA, jira_connector._jira_client).project
    )
    # Any non-raising return value counts as a successful project lookup here.
    project_lookup.return_value = MagicMock()

    jira_connector.validate_connector_settings()

    project_lookup.assert_called_once_with(jira_connector.jira_project)
|
||||
|
||||
|
||||
def test_validate_connector_settings_without_project_success(
    jira_base_url: str,
) -> None:
    """Validation without a configured project calls `projects()` exactly once."""
    # Build a connector with no project and swap in a fully mocked client.
    connector = JiraConnector(jira_base_url=jira_base_url)
    client_mock = MagicMock()
    connector._jira_client = client_mock
    client_mock.projects.return_value = [MagicMock()]

    connector.validate_connector_settings()

    client_mock.projects.assert_called_once()
|
@ -7,7 +7,8 @@ import pytest
|
||||
from jira.resources import Issue
|
||||
from pytest_mock import MockFixture
|
||||
|
||||
from onyx.connectors.onyx_jira.connector import fetch_jira_issues_batch
|
||||
from onyx.connectors.onyx_jira.connector import _perform_jql_search
|
||||
from onyx.connectors.onyx_jira.connector import process_jira_issue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -79,14 +80,22 @@ def test_fetch_jira_issues_batch_small_ticket(
|
||||
) -> None:
|
||||
mock_jira_client.search_issues.return_value = [mock_issue_small]
|
||||
|
||||
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
|
||||
# First get the issues via pagination
|
||||
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
|
||||
assert len(issues) == 1
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id.endswith("/SMALL-1")
|
||||
assert docs[0].sections[0].text is not None
|
||||
assert "Small description" in docs[0].sections[0].text
|
||||
assert "Small comment 1" in docs[0].sections[0].text
|
||||
assert "Small comment 2" in docs[0].sections[0].text
|
||||
doc = docs[0]
|
||||
assert doc is not None # Type assertion for mypy
|
||||
assert doc.id.endswith("/SMALL-1")
|
||||
assert doc.sections[0].text is not None
|
||||
assert "Small description" in doc.sections[0].text
|
||||
assert "Small comment 1" in doc.sections[0].text
|
||||
assert "Small comment 2" in doc.sections[0].text
|
||||
|
||||
|
||||
def test_fetch_jira_issues_batch_large_ticket(
|
||||
@ -96,7 +105,13 @@ def test_fetch_jira_issues_batch_large_ticket(
|
||||
) -> None:
|
||||
mock_jira_client.search_issues.return_value = [mock_issue_large]
|
||||
|
||||
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
|
||||
# First get the issues via pagination
|
||||
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
|
||||
assert len(issues) == 1
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
|
||||
assert len(docs) == 0 # The large ticket should be skipped
|
||||
|
||||
@ -109,10 +124,18 @@ def test_fetch_jira_issues_batch_mixed_tickets(
|
||||
) -> None:
|
||||
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
|
||||
|
||||
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
|
||||
# First get the issues via pagination
|
||||
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
|
||||
assert len(issues) == 2
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
|
||||
assert len(docs) == 1 # Only the small ticket should be included
|
||||
assert docs[0].id.endswith("/SMALL-1")
|
||||
doc = docs[0]
|
||||
assert doc is not None # Type assertion for mypy
|
||||
assert doc.id.endswith("/SMALL-1")
|
||||
|
||||
|
||||
@patch("onyx.connectors.onyx_jira.connector.JIRA_CONNECTOR_MAX_TICKET_SIZE", 50)
|
||||
@ -124,6 +147,12 @@ def test_fetch_jira_issues_batch_custom_size_limit(
|
||||
) -> None:
|
||||
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
|
||||
|
||||
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
|
||||
# First get the issues via pagination
|
||||
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
|
||||
assert len(issues) == 2
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
|
||||
assert len(docs) == 0 # Both tickets should be skipped due to the low size limit
|
||||
|
48
backend/tests/unit/onyx/connectors/utils.py
Normal file
48
backend/tests/unit/onyx/connectors/utils.py
Normal file
@ -0,0 +1,48 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.connector_runner import CheckpointOutputWrapper
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
|
||||
_ITERATION_LIMIT = 100_000
|
||||
|
||||
|
||||
class SingleConnectorCallOutput(BaseModel):
    """Everything gathered from one `load_from_checkpoint` call."""

    # Documents and failures collected during the call.
    items: list[Document | ConnectorFailure]
    # Checkpoint in effect once the call's output was fully consumed.
    next_checkpoint: ConnectorCheckpoint
|
||||
|
||||
|
||||
def load_everything_from_checkpoint_connector(
    connector: CheckpointConnector,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> list[SingleConnectorCallOutput]:
    """Call `connector.load_from_checkpoint` until the checkpoint has no more.

    Starts from a dummy checkpoint and, for every call, records the documents
    and failures it yielded together with the checkpoint in effect afterwards.

    Returns:
        One `SingleConnectorCallOutput` per connector call, in call order.

    Raises:
        RuntimeError: if the number of calls exceeds `_ITERATION_LIMIT`.
    """
    results: list[SingleConnectorCallOutput] = []
    current_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
    calls_made = 0

    while current_checkpoint.has_more:
        collected: list[Document | ConnectorFailure] = []
        wrapped_output = CheckpointOutputWrapper()(
            connector.load_from_checkpoint(start, end, current_checkpoint)
        )
        for document, failure, next_checkpoint in wrapped_output:
            # Keep whichever of the two is present; a failure is recorded
            # ahead of a document yielded in the same tuple.
            collected.extend(
                entry for entry in (failure, document) if entry is not None
            )
            if next_checkpoint is not None:
                current_checkpoint = next_checkpoint

        results.append(
            SingleConnectorCallOutput(
                items=collected, next_checkpoint=current_checkpoint
            )
        )

        # Guard against a connector whose checkpoint never stops reporting more.
        calls_made += 1
        if calls_made > _ITERATION_LIMIT:
            raise RuntimeError("Too many iterations. Infinite loop?")

    return results
|
Loading…
x
Reference in New Issue
Block a user