fixed tests

This commit is contained in:
Evan Lohn 2025-03-21 17:43:00 -07:00
parent 10810d4a20
commit 7dedaac090
2 changed files with 114 additions and 115 deletions

View File

@ -127,7 +127,6 @@ class SerializedRepository(BaseModel):
)
# Repository.Repository(self.github_client._Github__requester, repos[0]._headers, repos[0].raw_data, completed=True)
class GithubConnectorStage(Enum):
START = "start"
PRS = "prs"
@ -270,6 +269,8 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
)
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
# save checkpoint with repo ids retrieved
return checkpoint
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
@ -281,54 +282,51 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
)
doc_batch: list[Document] = []
while True:
pr_batch = _get_batch_rate_limited(
pull_requests, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_prs = False
for pr in pr_batch:
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
):
yield from doc_batch
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
try:
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.error(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
# if we went past the start date during the loop or there are no more
# prs to get, we move on to issues
if done_with_prs or len(pr_batch) == 0:
checkpoint.stage = GithubConnectorStage.ISSUES
checkpoint.curr_page = 0
break
# if we found any PRs on the page, yield them and return the checkpoint
if len(doc_batch) > 0:
pr_batch = _get_batch_rate_limited(
pull_requests, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_prs = False
for pr in pr_batch:
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
):
yield from doc_batch
return checkpoint
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
try:
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
# if we found any PRs on the page, yield any associated documents and return the checkpoint
if not done_with_prs and len(pr_batch) > 0:
yield from doc_batch
return checkpoint
# if we went past the start date during the loop or there are no more
# prs to get, we move on to issues
checkpoint.stage = GithubConnectorStage.ISSUES
checkpoint.curr_page = 0
checkpoint.stage = GithubConnectorStage.ISSUES
@ -339,58 +337,55 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
)
doc_batch = []
while True:
issue_batch = _get_batch_rate_limited(
issues, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_issues = False
for issue in cast(list[Issue], issue_batch):
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
yield from doc_batch
done_with_issues = True
break
# Skip PRs updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
try:
doc_batch.append(_convert_issue_to_document(issue))
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.error(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
continue
# if we went past the start date during the loop or there are no more
# issues to get, we move on to the next repo
if done_with_issues or len(issue_batch) == 0:
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
break
# if we found any issues on the page, yield them and return the checkpoint
if len(doc_batch) > 0:
issue_batch = _get_batch_rate_limited(
issues, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_issues = False
for issue in cast(list[Issue], issue_batch):
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
yield from doc_batch
return checkpoint
done_with_issues = True
break
# Skip PRs updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
try:
doc_batch.append(_convert_issue_to_document(issue))
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
continue
# if we found any issues on the page, yield them and return the checkpoint
if not done_with_issues and len(issue_batch) > 0:
yield from doc_batch
return checkpoint
# if we went past the start date during the loop or there are no more
# issues to get, we move on to the next repo
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
if checkpoint.cached_repo_ids:

View File

@ -177,10 +177,14 @@ def test_load_from_checkpoint_happy_path(
)
# Check that we got all documents and final has_more=False
assert len(outputs) == 3
assert len(outputs) == 4
repo_batch = outputs[0]
assert len(repo_batch.items) == 0
assert repo_batch.next_checkpoint.has_more is True
# Check first batch (PRs)
first_batch = outputs[0]
first_batch = outputs[1]
assert len(first_batch.items) == 2
assert isinstance(first_batch.items[0], Document)
assert first_batch.items[0].id == "https://github.com/test-org/test-repo/pull/1"
@ -189,7 +193,7 @@ def test_load_from_checkpoint_happy_path(
assert first_batch.next_checkpoint.curr_page == 1
# Check second batch (Issues)
second_batch = outputs[1]
second_batch = outputs[2]
assert len(second_batch.items) == 2
assert isinstance(second_batch.items[0], Document)
assert (
@ -202,7 +206,7 @@ def test_load_from_checkpoint_happy_path(
assert second_batch.next_checkpoint.has_more
# Check third batch (finished checkpoint)
third_batch = outputs[2]
third_batch = outputs[3]
assert len(third_batch.items) == 0
assert third_batch.next_checkpoint.has_more is False
@ -249,10 +253,10 @@ def test_load_from_checkpoint_with_rate_limit(
assert mock_sleep.call_count == 1
# Check that we got the document after rate limit was handled
assert len(outputs) >= 1
assert len(outputs[0].items) == 1
assert isinstance(outputs[0].items[0], Document)
assert outputs[0].items[0].id == "https://github.com/test-org/test-repo/pull/1"
assert len(outputs) >= 2
assert len(outputs[1].items) == 1
assert isinstance(outputs[1].items[0], Document)
assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"
assert outputs[-1].next_checkpoint.has_more is False
@ -283,9 +287,9 @@ def test_load_from_checkpoint_with_empty_repo(
)
# Check that we got no documents
assert len(outputs) == 1
assert len(outputs[0].items) == 0
assert not outputs[0].next_checkpoint.has_more
assert len(outputs) == 2
assert len(outputs[-1].items) == 0
assert not outputs[-1].next_checkpoint.has_more
def test_load_from_checkpoint_with_prs_only(
@ -324,8 +328,8 @@ def test_load_from_checkpoint_with_prs_only(
)
# Check that we only got PRs
assert len(outputs) >= 1
assert len(outputs[0].items) == 2
assert len(outputs) >= 2
assert len(outputs[1].items) == 2
assert all(
isinstance(doc, Document) and "pull" in doc.id for doc in outputs[0].items
) # All documents should be PRs
@ -369,12 +373,12 @@ def test_load_from_checkpoint_with_issues_only(
)
# Check that we only got issues
assert len(outputs) >= 1
assert len(outputs[0].items) == 2
assert len(outputs) >= 2
assert len(outputs[1].items) == 2
assert all(
isinstance(doc, Document) and "issues" in doc.id for doc in outputs[0].items
) # All documents should be issues
assert outputs[0].next_checkpoint.has_more
assert outputs[1].next_checkpoint.has_more
assert outputs[-1].next_checkpoint.has_more is False