Added Slim connector for Jira (#3181)

* Added Slim connector for Jira

* fixed testing

* more cleanup of Jira connector

* cleanup
This commit is contained in:
hagen-danswer 2024-11-21 09:00:20 -08:00 committed by GitHub
parent 70207b4b39
commit 100b4a0d16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 307 additions and 231 deletions

View File

@ -1,8 +1,8 @@
import os import os
from collections.abc import Iterable
from datetime import datetime from datetime import datetime
from datetime import timezone from datetime import timezone
from typing import Any from typing import Any
from urllib.parse import urlparse
from jira import JIRA from jira import JIRA
from jira.resources import Issue from jira.resources import Issue
@ -12,129 +12,93 @@ from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.danswer_jira.utils import best_effort_basic_expert_info
from danswer.connectors.danswer_jira.utils import best_effort_get_field_from_issue
from danswer.connectors.danswer_jira.utils import build_jira_client
from danswer.connectors.danswer_jira.utils import build_jira_url
from danswer.connectors.danswer_jira.utils import extract_jira_project
from danswer.connectors.danswer_jira.utils import extract_text_from_adf
from danswer.connectors.danswer_jira.utils import get_comment_strs
from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import ConnectorMissingCredentialError from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document from danswer.connectors.models import Document
from danswer.connectors.models import Section from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2" JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
_JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
def extract_jira_project(url: str) -> tuple[str, str]: def _paginate_jql_search(
parsed_url = urlparse(url) jira_client: JIRA,
jira_base = parsed_url.scheme + "://" + parsed_url.netloc jql: str,
max_results: int,
fields: str | None = None,
) -> Iterable[Issue]:
start = 0
while True:
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
# Split the path by '/' and find the position of 'projects' to get the project name for issue in issues:
split_path = parsed_url.path.split("/") if isinstance(issue, Issue):
if PROJECT_URL_PAT in split_path: yield issue
project_pos = split_path.index(PROJECT_URL_PAT) else:
if len(split_path) > project_pos + 1: raise Exception(f"Found Jira object not of type Issue: {issue}")
jira_project = split_path[project_pos + 1]
else:
raise ValueError("No project name found in the URL")
else:
raise ValueError("'projects' not found in the URL")
return jira_base, jira_project if len(issues) < max_results:
break
start += max_results
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
    """Return `field` from the issue's typed `fields` object when present,
    otherwise from the raw API payload dict; None if it cannot be found."""
    if hasattr(jira_issue.fields, field):
        return getattr(jira_issue.fields, field)
    try:
        # Raw-payload fallback; any lookup failure means "field not present".
        return jira_issue.raw["fields"][field]
    except Exception:
        return None
def _get_comment_strs(
    jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
    """Collect the text bodies of an issue's comments.

    Comments whose author's email appears in `comment_email_blacklist` are
    skipped; comments that fail to parse are logged and skipped.
    """
    comment_strs = []
    for comment in jira.fields.comment.comments:
        try:
            # API v2 serves plain-text bodies; other versions serve ADF docs.
            body_text = (
                comment.body
                if JIRA_API_VERSION == "2"
                else extract_text_from_adf(comment.raw["body"])
            )
            if (
                hasattr(comment, "author")
                and hasattr(comment.author, "emailAddress")
                and comment.author.emailAddress in comment_email_blacklist
            ):
                continue  # Skip adding comment if author's email is in blacklist
            comment_strs.append(body_text)
        except Exception as e:
            logger.error(f"Failed to process comment due to an error: {e}")
            continue
    return comment_strs
def fetch_jira_issues_batch( def fetch_jira_issues_batch(
jql: str,
start_index: int,
jira_client: JIRA, jira_client: JIRA,
batch_size: int = INDEX_BATCH_SIZE, jql: str,
batch_size: int,
comment_email_blacklist: tuple[str, ...] = (), comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None, labels_to_skip: set[str] | None = None,
) -> tuple[list[Document], int]: ) -> Iterable[Document]:
doc_batch = [] for issue in _paginate_jql_search(
jira_client=jira_client,
batch = jira_client.search_issues( jql=jql,
jql, max_results=batch_size,
startAt=start_index, ):
maxResults=batch_size, if labels_to_skip:
) if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
for jira in batch: f"Skipping {issue.key} because it has a label to skip. Found "
if type(jira) != Issue: f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
logger.warning(f"Found Jira object not of type Issue {jira}") )
continue continue
if labels_to_skip and any(
label in jira.fields.labels for label in labels_to_skip
):
logger.info(
f"Skipping {jira.key} because it has a label to skip. Found "
f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
description = ( description = (
jira.fields.description issue.fields.description
if JIRA_API_VERSION == "2" if JIRA_API_VERSION == "2"
else extract_text_from_adf(jira.raw["fields"]["description"]) else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
) )
comments = _get_comment_strs(jira, comment_email_blacklist)
ticket_content = f"{description}\n" + "\n".join( ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment] [f"Comment: {comment}" for comment in comments if comment]
) )
@ -142,66 +106,53 @@ def fetch_jira_issues_batch(
# Check ticket size # Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE: if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info( logger.info(
f"Skipping {jira.key} because it exceeds the maximum size of " f"Skipping {issue.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes." f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
) )
continue continue
page_url = f"{jira_client.client_info()}/browse/{jira.key}" page_url = f"{jira_client.client_info()}/browse/{issue.key}"
people = set() people = set()
try: try:
people.add( creator = best_effort_get_field_from_issue(issue, "creator")
BasicExpertInfo( if basic_expert_info := best_effort_basic_expert_info(creator):
display_name=jira.fields.creator.displayName, people.add(basic_expert_info)
email=jira.fields.creator.emailAddress,
)
)
except Exception: except Exception:
# Author should exist but if not, doesn't matter # Author should exist but if not, doesn't matter
pass pass
try: try:
people.add( assignee = best_effort_get_field_from_issue(issue, "assignee")
BasicExpertInfo( if basic_expert_info := best_effort_basic_expert_info(assignee):
display_name=jira.fields.assignee.displayName, # type: ignore people.add(basic_expert_info)
email=jira.fields.assignee.emailAddress, # type: ignore
)
)
except Exception: except Exception:
# Author should exist but if not, doesn't matter # Author should exist but if not, doesn't matter
pass pass
metadata_dict = {} metadata_dict = {}
priority = best_effort_get_field_from_issue(jira, "priority") if priority := best_effort_get_field_from_issue(issue, "priority"):
if priority:
metadata_dict["priority"] = priority.name metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status") if status := best_effort_get_field_from_issue(issue, "status"):
if status:
metadata_dict["status"] = status.name metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution") if resolution := best_effort_get_field_from_issue(issue, "resolution"):
if resolution:
metadata_dict["resolution"] = resolution.name metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels") if labels := best_effort_get_field_from_issue(issue, "labels"):
if labels:
metadata_dict["label"] = labels metadata_dict["label"] = labels
doc_batch.append( yield Document(
Document( id=page_url,
id=page_url, sections=[Section(link=page_url, text=ticket_content)],
sections=[Section(link=page_url, text=ticket_content)], source=DocumentSource.JIRA,
source=DocumentSource.JIRA, semantic_identifier=issue.fields.summary,
semantic_identifier=jira.fields.summary, doc_updated_at=time_str_to_utc(issue.fields.updated),
doc_updated_at=time_str_to_utc(jira.fields.updated), primary_owners=list(people) or None,
primary_owners=list(people) or None, # TODO add secondary_owners (commenters) if needed
# TODO add secondary_owners (commenters) if needed metadata=metadata_dict,
metadata=metadata_dict,
)
) )
return doc_batch, len(batch)
class JiraConnector(LoadConnector, PollConnector): class JiraConnector(LoadConnector, PollConnector, SlimConnector):
def __init__( def __init__(
self, self,
jira_project_url: str, jira_project_url: str,
@ -213,8 +164,8 @@ class JiraConnector(LoadConnector, PollConnector):
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP, labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
) -> None: ) -> None:
self.batch_size = batch_size self.batch_size = batch_size
self.jira_base, self.jira_project = extract_jira_project(jira_project_url) self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
self.jira_client: JIRA | None = None self._jira_client: JIRA | None = None
self._comment_email_blacklist = comment_email_blacklist or [] self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip) self.labels_to_skip = set(labels_to_skip)
@ -223,54 +174,45 @@ class JiraConnector(LoadConnector, PollConnector):
def comment_email_blacklist(self) -> tuple: def comment_email_blacklist(self) -> tuple:
return tuple(email.strip() for email in self._comment_email_blacklist) return tuple(email.strip() for email in self._comment_email_blacklist)
@property
def jira_client(self) -> JIRA:
    """The authenticated JIRA client.

    Raises ConnectorMissingCredentialError if load_credentials() has not
    been called to populate `self._jira_client` yet.
    """
    if self._jira_client is None:
        raise ConnectorMissingCredentialError("Jira")
    return self._jira_client
@property
def quoted_jira_project(self) -> str:
    """The project name wrapped in double quotes for safe use in JQL."""
    # Quote the project name to handle reserved words
    return f'"{self._jira_project}"'
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
api_token = credentials["jira_api_token"] self._jira_client = build_jira_client(
# if user provide an email we assume it's cloud credentials=credentials,
if "jira_user_email" in credentials: jira_base=self.jira_base,
email = credentials["jira_user_email"] )
self.jira_client = JIRA(
basic_auth=(email, api_token),
server=self.jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
else:
self.jira_client = JIRA(
token_auth=api_token,
server=self.jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
return None return None
def load_from_state(self) -> GenerateDocumentsOutput: def load_from_state(self) -> GenerateDocumentsOutput:
if self.jira_client is None: jql = f"project = {self.quoted_jira_project}"
raise ConnectorMissingCredentialError("Jira")
# Quote the project name to handle reserved words document_batch = []
quoted_project = f'"{self.jira_project}"' for doc in fetch_jira_issues_batch(
start_ind = 0 jira_client=self.jira_client,
while True: jql=jql,
doc_batch, fetched_batch_size = fetch_jira_issues_batch( batch_size=_JIRA_FULL_PAGE_SIZE,
jql=f"project = {quoted_project}", comment_email_blacklist=self.comment_email_blacklist,
start_index=start_ind, labels_to_skip=self.labels_to_skip,
jira_client=self.jira_client, ):
batch_size=self.batch_size, document_batch.append(doc)
comment_email_blacklist=self.comment_email_blacklist, if len(document_batch) >= self.batch_size:
labels_to_skip=self.labels_to_skip, yield document_batch
) document_batch = []
if doc_batch: yield document_batch
yield doc_batch
start_ind += fetched_batch_size
if fetched_batch_size < self.batch_size:
break
def poll_source( def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput: ) -> GenerateDocumentsOutput:
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime( start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M" "%Y-%m-%d %H:%M"
) )
@ -278,31 +220,54 @@ class JiraConnector(LoadConnector, PollConnector):
"%Y-%m-%d %H:%M" "%Y-%m-%d %H:%M"
) )
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
jql = ( jql = (
f"project = {quoted_project} AND " f"project = {self.quoted_jira_project} AND "
f"updated >= '{start_date_str}' AND " f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'" f"updated <= '{end_date_str}'"
) )
start_ind = 0 document_batch = []
while True: for doc in fetch_jira_issues_batch(
doc_batch, fetched_batch_size = fetch_jira_issues_batch( jira_client=self.jira_client,
jql=jql, jql=jql,
start_index=start_ind, batch_size=_JIRA_FULL_PAGE_SIZE,
jira_client=self.jira_client, comment_email_blacklist=self.comment_email_blacklist,
batch_size=self.batch_size, labels_to_skip=self.labels_to_skip,
comment_email_blacklist=self.comment_email_blacklist, ):
labels_to_skip=self.labels_to_skip, document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
yield document_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
jql = f"project = {self.quoted_jira_project}"
slim_doc_batch = []
for issue in _paginate_jql_search(
jira_client=self.jira_client,
jql=jql,
max_results=_JIRA_SLIM_PAGE_SIZE,
fields="key",
):
issue_key = best_effort_get_field_from_issue(issue, "key")
id = build_jira_url(self.jira_client, issue_key)
slim_doc_batch.append(
SlimDocument(
id=id,
perm_sync_data=None,
)
) )
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
yield slim_doc_batch
slim_doc_batch = []
if doc_batch: yield slim_doc_batch
yield doc_batch
start_ind += fetched_batch_size
if fetched_batch_size < self.batch_size:
break
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,17 +1,136 @@
"""Module with custom fields processing functions""" """Module with custom fields processing functions"""
import os
from typing import Any from typing import Any
from typing import List from typing import List
from urllib.parse import urlparse
from jira import JIRA from jira import JIRA
from jira.resources import CustomFieldOption from jira.resources import CustomFieldOption
from jira.resources import Issue from jira.resources import Issue
from jira.resources import User from jira.resources import User
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
def best_effort_basic_expert_info(obj: Any) -> "BasicExpertInfo | None":
    """Best-effort conversion of a Jira user-like value to BasicExpertInfo.

    `obj` may be a jira resource object (attribute access), a raw API dict
    (key access), or None (callers pass the possibly-None result of
    best_effort_get_field_from_issue).

    Returns:
        BasicExpertInfo, or None when neither a display name nor an email
        can be found (or `obj` is None).
    """
    if obj is None:
        # Previously this crashed with AttributeError (None has no .get);
        # callers only swallowed it via a broad try/except.
        return None

    # Try attribute access first (both naming styles), then dict-style access.
    display_name = getattr(obj, "display_name", None)
    if display_name is None:
        display_name = getattr(obj, "displayName", None)
    if display_name is None and hasattr(obj, "get"):
        display_name = obj.get("displayName")

    email = getattr(obj, "emailAddress", None)
    if email is None and hasattr(obj, "get"):
        email = obj.get("emailAddress")

    if not email and not display_name:
        return None
    return BasicExpertInfo(display_name=display_name, email=email)
def best_effort_get_field_from_issue(jira_issue: "Issue", field: str) -> Any:
    """Fetch `field` from a Jira issue.

    Tries the typed `fields` attribute object first, then falls back to the
    raw API payload; returns None when the field cannot be found either way.
    """
    typed_fields = jira_issue.fields
    if hasattr(typed_fields, field):
        return getattr(typed_fields, field)

    # Raw-payload fallback; swallow any lookup failure and report "absent".
    try:
        return jira_issue.raw["fields"][field]
    except Exception:
        return None
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
    """Return the browse URL for an issue, i.e. "<server>/browse/<key>".

    NOTE(review): assumes jira_client.client_info() returns the server base
    URL with no trailing slash — confirm against the jira library docs.
    """
    return f"{jira_client.client_info()}/browse/{issue_key}"
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
    """Construct an authenticated JIRA client for `jira_base`.

    A "jira_user_email" credential signals a Jira Cloud deployment and
    selects basic auth (email + API token); otherwise token auth is used.
    """
    api_token = credentials["jira_api_token"]
    client_options = {"rest_api_version": JIRA_API_VERSION}

    if "jira_user_email" in credentials:
        # An email in the credentials implies a cloud deployment.
        return JIRA(
            basic_auth=(credentials["jira_user_email"], api_token),
            server=jira_base,
            options=client_options,
        )

    return JIRA(
        token_auth=api_token,
        server=jira_base,
        options=client_options,
    )
def extract_jira_project(url: str) -> tuple[str, str]:
    """Split a Jira project URL into (base URL, project name).

    The project name is taken as the path segment immediately following
    the "projects" segment.

    Raises:
        ValueError: when "projects" is missing from the path, or when no
            segment follows it.
    """
    parsed = urlparse(url)
    base = f"{parsed.scheme}://{parsed.netloc}"

    segments = parsed.path.split("/")
    if "projects" not in segments:
        raise ValueError("'projects' not found in the URL")

    name_pos = segments.index("projects") + 1
    if name_pos >= len(segments):
        raise ValueError("No project name found in the URL")

    return base, segments[name_pos]
def get_comment_strs(
    issue: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
    """Return the text bodies of an issue's comments.

    Comments authored by an email on `comment_email_blacklist` are dropped;
    comments that fail to parse are logged and skipped.
    """
    texts: list[str] = []
    for comment in issue.fields.comment.comments:
        try:
            # API v2 serves plain-text bodies; other versions serve ADF docs.
            if JIRA_API_VERSION == "2":
                body_text = comment.body
            else:
                body_text = extract_text_from_adf(comment.raw["body"])

            author_email = getattr(
                getattr(comment, "author", None), "emailAddress", None
            )
            if author_email is not None and author_email in comment_email_blacklist:
                continue  # drop comments from blacklisted authors

            texts.append(body_text)
        except Exception as e:
            logger.error(f"Failed to process comment due to an error: {e}")
            continue
    return texts
class CustomFieldExtractor: class CustomFieldExtractor:
@staticmethod @staticmethod
def _process_custom_field_value(value: Any) -> str: def _process_custom_field_value(value: Any) -> str:

View File

@ -1,4 +1,3 @@
from collections.abc import Callable
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -18,49 +17,48 @@ def mock_jira_client() -> MagicMock:
@pytest.fixture @pytest.fixture
def mock_issue_small() -> MagicMock: def mock_issue_small() -> MagicMock:
issue = MagicMock() issue = MagicMock(spec=Issue)
issue.key = "SMALL-1" fields = MagicMock()
issue.fields.description = "Small description" fields.description = "Small description"
issue.fields.comment.comments = [ fields.comment = MagicMock()
fields.comment.comments = [
MagicMock(body="Small comment 1"), MagicMock(body="Small comment 1"),
MagicMock(body="Small comment 2"), MagicMock(body="Small comment 2"),
] ]
issue.fields.creator.displayName = "John Doe" fields.creator = MagicMock()
issue.fields.creator.emailAddress = "john@example.com" fields.creator.displayName = "John Doe"
issue.fields.summary = "Small Issue" fields.creator.emailAddress = "john@example.com"
issue.fields.updated = "2023-01-01T00:00:00+0000" fields.summary = "Small Issue"
issue.fields.labels = [] fields.updated = "2023-01-01T00:00:00+0000"
fields.labels = []
issue.fields = fields
issue.key = "SMALL-1"
return issue return issue
@pytest.fixture @pytest.fixture
def mock_issue_large() -> MagicMock: def mock_issue_large() -> MagicMock:
# This will be larger than 100KB issue = MagicMock(spec=Issue)
issue = MagicMock() fields = MagicMock()
issue.key = "LARGE-1" fields.description = "a" * 99_000
issue.fields.description = "a" * 99_000 fields.comment = MagicMock()
issue.fields.comment.comments = [ fields.comment.comments = [
MagicMock(body="Large comment " * 1000), MagicMock(body="Large comment " * 1000),
MagicMock(body="Another large comment " * 1000), MagicMock(body="Another large comment " * 1000),
] ]
issue.fields.creator.displayName = "Jane Doe" fields.creator = MagicMock()
issue.fields.creator.emailAddress = "jane@example.com" fields.creator.displayName = "Jane Doe"
issue.fields.summary = "Large Issue" fields.creator.emailAddress = "jane@example.com"
issue.fields.updated = "2023-01-02T00:00:00+0000" fields.summary = "Large Issue"
issue.fields.labels = [] fields.updated = "2023-01-02T00:00:00+0000"
fields.labels = []
issue.fields = fields
issue.key = "LARGE-1"
return issue return issue
@pytest.fixture
def patched_type() -> Callable[[Any], type]:
def _patched_type(obj: Any) -> type:
if isinstance(obj, MagicMock):
return Issue
return type(obj)
return _patched_type
@pytest.fixture @pytest.fixture
def mock_jira_api_version() -> Generator[Any, Any, Any]: def mock_jira_api_version() -> Generator[Any, Any, Any]:
with patch("danswer.connectors.danswer_jira.connector.JIRA_API_VERSION", "2"): with patch("danswer.connectors.danswer_jira.connector.JIRA_API_VERSION", "2"):
@ -69,11 +67,9 @@ def mock_jira_api_version() -> Generator[Any, Any, Any]:
@pytest.fixture @pytest.fixture
def patched_environment( def patched_environment(
patched_type: type,
mock_jira_api_version: MockFixture, mock_jira_api_version: MockFixture,
) -> Generator[Any, Any, Any]: ) -> Generator[Any, Any, Any]:
with patch("danswer.connectors.danswer_jira.connector.type", patched_type): yield
yield
def test_fetch_jira_issues_batch_small_ticket( def test_fetch_jira_issues_batch_small_ticket(
@ -83,9 +79,8 @@ def test_fetch_jira_issues_batch_small_ticket(
) -> None: ) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small] mock_jira_client.search_issues.return_value = [mock_issue_small]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client) docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
assert count == 1
assert len(docs) == 1 assert len(docs) == 1
assert docs[0].id.endswith("/SMALL-1") assert docs[0].id.endswith("/SMALL-1")
assert "Small description" in docs[0].sections[0].text assert "Small description" in docs[0].sections[0].text
@ -100,9 +95,8 @@ def test_fetch_jira_issues_batch_large_ticket(
) -> None: ) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_large] mock_jira_client.search_issues.return_value = [mock_issue_large]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client) docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
assert count == 1
assert len(docs) == 0 # The large ticket should be skipped assert len(docs) == 0 # The large ticket should be skipped
@ -114,9 +108,8 @@ def test_fetch_jira_issues_batch_mixed_tickets(
) -> None: ) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large] mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client) docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
assert count == 2
assert len(docs) == 1 # Only the small ticket should be included assert len(docs) == 1 # Only the small ticket should be included
assert docs[0].id.endswith("/SMALL-1") assert docs[0].id.endswith("/SMALL-1")
@ -130,7 +123,6 @@ def test_fetch_jira_issues_batch_custom_size_limit(
) -> None: ) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large] mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client) docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
assert count == 2
assert len(docs) == 0 # Both tickets should be skipped due to the low size limit assert len(docs) == 0 # Both tickets should be skipped due to the low size limit