diff --git a/backend/onyx/connectors/github/connector.py b/backend/onyx/connectors/github/connector.py
index 86cf0cd7d..74dc63f4d 100644
--- a/backend/onyx/connectors/github/connector.py
+++ b/backend/onyx/connectors/github/connector.py
@@ -127,7 +127,6 @@ class SerializedRepository(BaseModel):
         )


-# Repository.Repository(self.github_client._Github__requester, repos[0]._headers, repos[0].raw_data, completed=True)
 class GithubConnectorStage(Enum):
     START = "start"
     PRS = "prs"
@@ -270,6 +269,8 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
             )
             checkpoint.stage = GithubConnectorStage.PRS
             checkpoint.curr_page = 0
+            # save checkpoint with repo ids retrieved
+            return checkpoint

         assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
         repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
@@ -281,54 +282,51 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
         )

         doc_batch: list[Document] = []
-        while True:
-            pr_batch = _get_batch_rate_limited(
-                pull_requests, checkpoint.curr_page, self.github_client
-            )
-            checkpoint.curr_page += 1
-            done_with_prs = False
-            for pr in pr_batch:
-                # we iterate backwards in time, so at this point we stop processing prs
-                if (
-                    start is not None
-                    and pr.updated_at
-                    and pr.updated_at.replace(tzinfo=timezone.utc) < start
-                ):
-                    yield from doc_batch
-                    done_with_prs = True
-                    break
-                # Skip PRs updated after the end date
-                if (
-                    end is not None
-                    and pr.updated_at
-                    and pr.updated_at.replace(tzinfo=timezone.utc) > end
-                ):
-                    continue
-                try:
-                    doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
-                except Exception as e:
-                    error_msg = f"Error converting PR to document: {e}"
-                    logger.error(error_msg)
-                    yield ConnectorFailure(
-                        failed_document=DocumentFailure(
-                            document_id=str(pr.id), document_link=pr.html_url
-                        ),
-                        failure_message=error_msg,
-                        exception=e,
-                    )
-                    continue
-
-            # if we went past the start date during the loop or there are no more
-            # prs to get, we move on to issues
-            if done_with_prs or len(pr_batch) == 0:
-                checkpoint.stage = GithubConnectorStage.ISSUES
-                checkpoint.curr_page = 0
-                break
-
-            # if we found any PRs on the page, yield them and return the checkpoint
-            if len(doc_batch) > 0:
+        pr_batch = _get_batch_rate_limited(
+            pull_requests, checkpoint.curr_page, self.github_client
+        )
+        checkpoint.curr_page += 1
+        done_with_prs = False
+        for pr in pr_batch:
+            # we iterate backwards in time, so at this point we stop processing prs
+            if (
+                start is not None
+                and pr.updated_at
+                and pr.updated_at.replace(tzinfo=timezone.utc) < start
+            ):
                 yield from doc_batch
-                return checkpoint
+                done_with_prs = True
+                break
+            # Skip PRs updated after the end date
+            if (
+                end is not None
+                and pr.updated_at
+                and pr.updated_at.replace(tzinfo=timezone.utc) > end
+            ):
+                continue
+            try:
+                doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
+            except Exception as e:
+                error_msg = f"Error converting PR to document: {e}"
+                logger.exception(error_msg)
+                yield ConnectorFailure(
+                    failed_document=DocumentFailure(
+                        document_id=str(pr.id), document_link=pr.html_url
+                    ),
+                    failure_message=error_msg,
+                    exception=e,
+                )
+                continue
+
+        # if we found any PRs on the page, yield any associated documents and return the checkpoint
+        if not done_with_prs and len(pr_batch) > 0:
+            yield from doc_batch
+            return checkpoint
+
+        # if we went past the start date during the loop or there are no more
+        # prs to get, we move on to issues
+        checkpoint.stage = GithubConnectorStage.ISSUES
+        checkpoint.curr_page = 0

         checkpoint.stage = GithubConnectorStage.ISSUES

@@ -339,58 +337,55 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
         )

         doc_batch = []
-        while True:
-            issue_batch = _get_batch_rate_limited(
-                issues, checkpoint.curr_page, self.github_client
-            )
-            checkpoint.curr_page += 1
-            done_with_issues = False
-            for issue in cast(list[Issue], issue_batch):
-                # we iterate backwards in time, so at this point we stop processing prs
-                if (
-                    start is not None
-                    and issue.updated_at.replace(tzinfo=timezone.utc) < start
-                ):
-                    yield from doc_batch
-                    done_with_issues = True
-                    break
-                # Skip PRs updated after the end date
-                if (
-                    end is not None
-                    and issue.updated_at.replace(tzinfo=timezone.utc) > end
-                ):
-                    continue
-
-                if issue.pull_request is not None:
-                    # PRs are handled separately
-                    continue
-
-                try:
-                    doc_batch.append(_convert_issue_to_document(issue))
-                except Exception as e:
-                    error_msg = f"Error converting issue to document: {e}"
-                    logger.error(error_msg)
-                    yield ConnectorFailure(
-                        failed_document=DocumentFailure(
-                            document_id=str(issue.id),
-                            document_link=issue.html_url,
-                        ),
-                        failure_message=error_msg,
-                        exception=e,
-                    )
-                    continue
-
-            # if we went past the start date during the loop or there are no more
-            # issues to get, we move on to the next repo
-            if done_with_issues or len(issue_batch) == 0:
-                checkpoint.stage = GithubConnectorStage.PRS
-                checkpoint.curr_page = 0
-                break
-
-            # if we found any issues on the page, yield them and return the checkpoint
-            if len(doc_batch) > 0:
+        issue_batch = _get_batch_rate_limited(
+            issues, checkpoint.curr_page, self.github_client
+        )
+        checkpoint.curr_page += 1
+        done_with_issues = False
+        for issue in cast(list[Issue], issue_batch):
+            # we iterate backwards in time, so at this point we stop processing issues
+            if (
+                start is not None
+                and issue.updated_at.replace(tzinfo=timezone.utc) < start
+            ):
                 yield from doc_batch
-                return checkpoint
+                done_with_issues = True
+                break
+            # Skip issues updated after the end date
+            if (
+                end is not None
+                and issue.updated_at.replace(tzinfo=timezone.utc) > end
+            ):
+                continue
+
+            if issue.pull_request is not None:
+                # PRs are handled separately
+                continue
+
+            try:
+                doc_batch.append(_convert_issue_to_document(issue))
+            except Exception as e:
+                error_msg = f"Error converting issue to document: {e}"
+                logger.exception(error_msg)
+                yield ConnectorFailure(
+                    failed_document=DocumentFailure(
+                        document_id=str(issue.id),
+                        document_link=issue.html_url,
+                    ),
+                    failure_message=error_msg,
+                    exception=e,
+                )
+                continue
+
+        # if we found any issues on the page, yield them and return the checkpoint
+        if not done_with_issues and len(issue_batch) > 0:
+            yield from doc_batch
+            return checkpoint
+
+        # if we went past the start date during the loop or there are no more
+        # issues to get, we move on to the next repo
+        checkpoint.stage = GithubConnectorStage.PRS
+        checkpoint.curr_page = 0

         checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
         if checkpoint.cached_repo_ids:
diff --git a/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py b/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
index 043813f45..5c797a5dc 100644
--- a/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
+++ b/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
@@ -177,10 +177,14 @@ def test_load_from_checkpoint_happy_path(
     )

     # Check that we got all documents and final has_more=False
-    assert len(outputs) == 3
+    assert len(outputs) == 4
+
+    repo_batch = outputs[0]
+    assert len(repo_batch.items) == 0
+    assert repo_batch.next_checkpoint.has_more is True

     # Check first batch (PRs)
-    first_batch = outputs[0]
+    first_batch = outputs[1]
     assert len(first_batch.items) == 2
     assert isinstance(first_batch.items[0], Document)
     assert first_batch.items[0].id == "https://github.com/test-org/test-repo/pull/1"
@@ -189,7 +193,7 @@
     assert first_batch.next_checkpoint.curr_page == 1

     # Check second batch (Issues)
-    second_batch = outputs[1]
+    second_batch = outputs[2]
     assert len(second_batch.items) == 2
     assert isinstance(second_batch.items[0], Document)
     assert (
@@ -202,7 +206,7 @@
     assert second_batch.next_checkpoint.has_more

     # Check third batch (finished checkpoint)
-    third_batch = outputs[2]
+    third_batch = outputs[3]
     assert len(third_batch.items) == 0
     assert third_batch.next_checkpoint.has_more is False

@@ -249,10 +253,10 @@
     assert mock_sleep.call_count == 1

     # Check that we got the document after rate limit was handled
-    assert len(outputs) >= 1
-    assert len(outputs[0].items) == 1
-    assert isinstance(outputs[0].items[0], Document)
-    assert outputs[0].items[0].id == "https://github.com/test-org/test-repo/pull/1"
+    assert len(outputs) >= 2
+    assert len(outputs[1].items) == 1
+    assert isinstance(outputs[1].items[0], Document)
+    assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"

     assert outputs[-1].next_checkpoint.has_more is False

@@ -283,9 +287,9 @@
     )

     # Check that we got no documents
-    assert len(outputs) == 1
-    assert len(outputs[0].items) == 0
-    assert not outputs[0].next_checkpoint.has_more
+    assert len(outputs) == 2
+    assert len(outputs[-1].items) == 0
+    assert not outputs[-1].next_checkpoint.has_more


 def test_load_from_checkpoint_with_prs_only(
@@ -324,8 +328,8 @@
     )

     # Check that we only got PRs
-    assert len(outputs) >= 1
-    assert len(outputs[0].items) == 2
+    assert len(outputs) >= 2
+    assert len(outputs[1].items) == 2
     assert all(
-        isinstance(doc, Document) and "pull" in doc.id for doc in outputs[0].items
+        isinstance(doc, Document) and "pull" in doc.id for doc in outputs[1].items
     )  # All documents should be PRs
@@ -369,12 +373,12 @@
     )

     # Check that we only got issues
-    assert len(outputs) >= 1
-    assert len(outputs[0].items) == 2
+    assert len(outputs) >= 2
+    assert len(outputs[1].items) == 2
     assert all(
-        isinstance(doc, Document) and "issues" in doc.id for doc in outputs[0].items
+        isinstance(doc, Document) and "issues" in doc.id for doc in outputs[1].items
     )  # All documents should be issues
-    assert outputs[0].next_checkpoint.has_more
+    assert outputs[1].next_checkpoint.has_more

     assert outputs[-1].next_checkpoint.has_more is False
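
--
Reviewer note, not part of the patch: below is a minimal, self-contained sketch of the
control-flow pattern this diff moves to, replacing the unbounded `while True:` pagination
loops with "process one page per call, then return the checkpoint". All names here
(Checkpoint, fetch_page, run_connector) are invented for illustration and are not the
onyx API; the real connector yields Document/ConnectorFailure objects and the runner
persists a GithubConnectorCheckpoint between calls.

from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class Checkpoint:
    # simplified stand-in for GithubConnectorCheckpoint
    stage: str = "start"
    curr_page: int = 0
    has_more: bool = True


def fetch_page(pages: list[list[str]], page: int) -> list[str]:
    # stand-in for _get_batch_rate_limited: one paginated API request
    return pages[page] if page < len(pages) else []


def load_from_checkpoint(
    pages: list[list[str]], checkpoint: Checkpoint
) -> Generator[str, None, Checkpoint]:
    if checkpoint.stage == "start":
        # mirrors the new early `return checkpoint`: persist the freshly
        # cached state before any paging work happens
        checkpoint.stage = "paging"
        return checkpoint

    # exactly one page per invocation -- no `while True:` loop, so a crash
    # between calls loses at most one page of work
    batch = fetch_page(pages, checkpoint.curr_page)
    checkpoint.curr_page += 1
    yield from batch

    if not batch:
        checkpoint.has_more = False  # no more pages: signal completion
    return checkpoint


def run_connector(pages: list[list[str]]) -> None:
    # models the runner: re-invoke the generator until the checkpoint is done
    checkpoint = Checkpoint()
    while checkpoint.has_more:
        gen = load_from_checkpoint(pages, checkpoint)
        docs: list[str] = []
        while True:
            try:
                docs.append(next(gen))
            except StopIteration as stop:
                checkpoint = stop.value  # the generator's return value
                break
        print(f"batch of {len(docs)} docs, next page={checkpoint.curr_page}")


if __name__ == "__main__":
    run_connector([["pr-1", "pr-2"], ["pr-3"]])

Under this model the runner sees an extra leading empty batch (the START-stage checkpoint
save), which is why the tests above gain one output and shift their indices from
outputs[0] to outputs[1].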