From 6436b60763c2700ca11748ac63679bbcdd48e800 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Wed, 30 Apr 2025 19:09:20 -0700 Subject: [PATCH] github cursor pagination (#4642) * v1 of cursor pagination * mypy * unit tests * CW comments --- backend/onyx/connectors/github/connector.py | 240 +++++++++++++++--- .../github/test_github_checkpointing.py | 192 +++++++++++++- 2 files changed, 393 insertions(+), 39 deletions(-) diff --git a/backend/onyx/connectors/github/connector.py b/backend/onyx/connectors/github/connector.py index 504f9281f8a..cf83002ae0f 100644 --- a/backend/onyx/connectors/github/connector.py +++ b/backend/onyx/connectors/github/connector.py @@ -1,5 +1,6 @@ import copy import time +from collections.abc import Callable from collections.abc import Generator from datetime import datetime from datetime import timedelta @@ -39,6 +40,7 @@ from onyx.utils.logger import setup_logger logger = setup_logger() ITEMS_PER_PAGE = 100 +CURSOR_LOG_FREQUENCY = 100 _MAX_NUM_RATE_LIMIT_RETRIES = 5 @@ -52,26 +54,138 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None: time.sleep(sleep_time.seconds) +# Cases +# X (from start) standard run, no fallback to cursor-based pagination +# X (from start) standard run errors, fallback to cursor-based pagination +# X error in the middle of a page +# X no errors: run to completion +# X (from checkpoint) standard run, no fallback to cursor-based pagination +# X (from checkpoint) continue from cursor-based pagination +# - retrying +# - no retrying + +# things to check: +# checkpoint state on return +# checkpoint progress (no infinite loop) + + +def _paginate_until_error( + git_objs: Callable[[], PaginatedList[PullRequest | Issue]], + cursor_url: str | None, + prev_num_objs: int, + cursor_url_callback: Callable[[str | None, int], None], + retrying: bool = False, +) -> Generator[PullRequest | Issue, None, None]: + num_objs = prev_num_objs + pag_list = git_objs() + if cursor_url: + pag_list.__nextUrl = cursor_url + elif retrying: + # if we are retrying, we want to skip the objects retrieved + # over previous calls. Unfortunately, this WILL retrieve all + # pages before the one we are resuming from, so we really + # don't want this case to be hit often + logger.warning( + "Retrying from a previous cursor-based pagination call. " + "This will retrieve all pages before the one we are resuming from, " + "which may take a while and consume many API calls." + ) + pag_list = pag_list[prev_num_objs:] + num_objs = 0 + + try: + # this for loop handles cursor-based pagination + for issue_or_pr in pag_list: + num_objs += 1 + yield issue_or_pr + # used to store the current cursor url in the checkpoint. This value + # is updated during iteration over pag_list. + cursor_url_callback(pag_list.__nextUrl, num_objs) + + if num_objs % CURSOR_LOG_FREQUENCY == 0: + logger.info( + f"Retrieved {num_objs} objects with current cursor url: {pag_list.__nextUrl}" + ) + + except Exception as e: + logger.exception(f"Error during cursor-based pagination: {e}") + if num_objs - prev_num_objs > 0: + raise + + if pag_list.__nextUrl is not None and not retrying: + logger.info( + "Assuming that this error is due to cursor " + "expiration because no objects were retrieved. " + "Retrying from the first page." 
+            )
+            yield from _paginate_until_error(
+                git_objs, None, prev_num_objs, cursor_url_callback, retrying=True
+            )
+            return
+
+        # if there was no cursor url, or we reach this point after a retry,
+        # re-raise the error
+        raise
+
+
 def _get_batch_rate_limited(
-    git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
-) -> list[PullRequest | Issue]:
+    # git_objs is a callable so that each call produces a fresh PaginatedList;
+    # this avoids reusing an object left over from a partial offset-based
+    # pagination call when we fall back to cursor-based pagination.
+    git_objs: Callable[[], PaginatedList],
+    page_num: int,
+    cursor_url: str | None,
+    prev_num_objs: int,
+    cursor_url_callback: Callable[[str | None, int], None],
+    github_client: Github,
+    attempt_num: int = 0,
+) -> Generator[PullRequest | Issue, None, None]:
     if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
         raise RuntimeError(
             "Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
         )
 
+    if cursor_url:
+        # when this is set, we are resuming from an earlier
+        # cursor-based pagination call.
+        yield from _paginate_until_error(
+            git_objs, cursor_url, prev_num_objs, cursor_url_callback
+        )
+        return
+
     try:
-        objs = list(git_objs.get_page(page_num))
+        objs = list(git_objs().get_page(page_num))
         # fetch all data here to disable lazy loading later
         # this is needed to capture the rate limit exception here (if one occurs)
         for obj in objs:
             if hasattr(obj, "raw_data"):
                 getattr(obj, "raw_data")
-        return objs
+        yield from objs
     except RateLimitExceededException:
         _sleep_after_rate_limit_exception(github_client)
-        return _get_batch_rate_limited(
-            git_objs, page_num, github_client, attempt_num + 1
-        )
+        yield from _get_batch_rate_limited(
+            git_objs,
+            page_num,
+            cursor_url,
+            prev_num_objs,
+            cursor_url_callback,
+            github_client,
+            attempt_num + 1,
+        )
+    except GithubException as e:
+        if not (
+            e.status == 422
+            and (
+                "cursor" in (e.message or "")
+                or "cursor" in (e.data or {}).get("message", "")
+            )
+        ):
+            raise
+        # Fallback to a cursor-based pagination strategy.
+        # This can happen for "large datasets," but as far as we can tell the
+        # error is not documented anywhere on the web.
+ # Error message: + # "Pagination with the page parameter is not supported for large datasets, + # please use cursor based pagination (after/before)" + yield from _paginate_until_error( + git_objs, cursor_url, prev_num_objs, cursor_url_callback ) @@ -142,6 +256,18 @@ class GithubConnectorCheckpoint(ConnectorCheckpoint): cached_repo_ids: list[int] | None = None cached_repo: SerializedRepository | None = None + # Used for the fallback cursor-based pagination strategy + num_retrieved: int + cursor_url: str | None = None + + def reset(self) -> None: + """ + Resets curr_page, num_retrieved, and cursor_url to their initial values (0, 0, None) + """ + self.curr_page = 0 + self.num_retrieved = 0 + self.cursor_url = None + class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]): def __init__( @@ -230,6 +356,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]): try: org = github_client.get_organization(self.repo_owner) return list(org.get_repos()) + except GithubException: # If not an org, try as a user user = github_client.get_user(self.repo_owner) @@ -266,11 +393,12 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]): checkpoint.has_more = False return checkpoint - checkpoint.cached_repo_ids = sorted([repo.id for repo in repos]) + curr_repo = repos.pop() + checkpoint.cached_repo_ids = [repo.id for repo in repos] checkpoint.cached_repo = SerializedRepository( - id=checkpoint.cached_repo_ids[0], - headers=repos[0].raw_headers, - raw_data=repos[0].raw_data, + id=curr_repo.id, + headers=curr_repo.raw_headers, + raw_data=curr_repo.raw_data, ) checkpoint.stage = GithubConnectorStage.PRS checkpoint.curr_page = 0 @@ -299,26 +427,41 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]): repo_id = checkpoint.cached_repo.id repo = self.github_client.get_repo(repo_id) + def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None: + checkpoint.cursor_url = cursor_url + checkpoint.num_retrieved = num_objs + + # TODO: all PRs are also issues, so we should be able to _only_ get issues + # and then filter appropriately whenever include_issues is True if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS: logger.info(f"Fetching PRs for repo: {repo.name}") - pull_requests = repo.get_pulls( - state=self.state_filter, sort="updated", direction="desc" - ) - doc_batch: list[Document] = [] + def pull_requests_func() -> PaginatedList[PullRequest]: + return repo.get_pulls( + state=self.state_filter, sort="updated", direction="desc" + ) + pr_batch = _get_batch_rate_limited( - pull_requests, checkpoint.curr_page, self.github_client + pull_requests_func, + checkpoint.curr_page, + checkpoint.cursor_url, + checkpoint.num_retrieved, + cursor_url_callback, + self.github_client, ) - checkpoint.curr_page += 1 + checkpoint.curr_page += 1 # NOTE: not used for cursor-based fallback done_with_prs = False + num_prs = 0 + pr = None for pr in pr_batch: + num_prs += 1 + # we iterate backwards in time, so at this point we stop processing prs if ( start is not None and pr.updated_at and pr.updated_at.replace(tzinfo=timezone.utc) < start ): - yield from doc_batch done_with_prs = True break # Skip PRs updated after the end date @@ -329,7 +472,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]): ): continue try: - doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr))) + yield _convert_pr_to_document(cast(PullRequest, pr)) except Exception as e: error_msg = f"Error converting PR to document: {e}" 
                     logger.exception(error_msg)
@@ -342,37 +485,57 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
                     )
                     continue
 
-            # if we found any PRs on the page, yield any associated documents and return the checkpoint
-            if not done_with_prs and len(pr_batch) > 0:
-                yield from doc_batch
+            # If we reach this point with a cursor url in the checkpoint, we were using
+            # the fallback cursor-based pagination strategy. That strategy tries to fetch
+            # all PRs, so having cursor_url set means we are done with PRs. However, we
+            # need to return AFTER the checkpoint reset below to avoid infinite loops.
+
+            # if we found any PRs on the page and there are more PRs to get, return the checkpoint.
+            # In offset mode, while indexing without time constraints, the PR batch
+            # will be empty when we're done.
+            if num_prs > 0 and not done_with_prs and not checkpoint.cursor_url:
                 return checkpoint
 
+            # capture this before reset() clears cursor_url
+            using_cursor_fallback = checkpoint.cursor_url is not None
+
             # if we went past the start date during the loop or there are no more
             # prs to get, we move on to issues
             checkpoint.stage = GithubConnectorStage.ISSUES
-            checkpoint.curr_page = 0
+            checkpoint.reset()
+
+            if using_cursor_fallback:
+                # save the checkpoint after changing stage; next run will continue from issues
+                return checkpoint
 
         checkpoint.stage = GithubConnectorStage.ISSUES
 
         if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
             logger.info(f"Fetching issues for repo: {repo.name}")
 
-            issues = repo.get_issues(
-                state=self.state_filter, sort="updated", direction="desc"
-            )
-            doc_batch = []
 
-            issue_batch = _get_batch_rate_limited(
-                issues, checkpoint.curr_page, self.github_client
+            def issues_func() -> PaginatedList[Issue]:
+                return repo.get_issues(
+                    state=self.state_filter, sort="updated", direction="desc"
+                )
+
+            issue_batch = list(
+                _get_batch_rate_limited(
+                    issues_func,
+                    checkpoint.curr_page,
+                    checkpoint.cursor_url,
+                    checkpoint.num_retrieved,
+                    cursor_url_callback,
+                    self.github_client,
+                )
             )
             checkpoint.curr_page += 1
             done_with_issues = False
-            for issue in cast(list[Issue], issue_batch):
+            num_issues = 0
+            for issue in issue_batch:
+                num_issues += 1
+                issue = cast(Issue, issue)
                 # we iterate backwards in time, so at this point we stop processing prs
                 if (
                     start is not None
                     and issue.updated_at.replace(tzinfo=timezone.utc) < start
                 ):
-                    yield from doc_batch
                     done_with_issues = True
                     break
 
                 # Skip PRs updated after the end date
@@ -384,10 +547,11 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
 
                 if issue.pull_request is not None:
                     # PRs are handled separately
+                    # TODO: but they shouldn't always be
                     continue
 
                 try:
-                    doc_batch.append(_convert_issue_to_document(issue))
+                    yield _convert_issue_to_document(issue)
                 except Exception as e:
                     error_msg = f"Error converting issue to document: {e}"
                     logger.exception(error_msg)
@@ -401,17 +565,17 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
                     )
                     continue
 
-            # if we found any issues on the page, yield them and return the checkpoint
-            if not done_with_issues and len(issue_batch) > 0:
-                yield from doc_batch
+            # if we found any issues on the page, and we're not done, return the checkpoint.
+            # don't return if we're using cursor-based pagination to avoid infinite loops
+            if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url:
                 return checkpoint
 
             # if we went past the start date during the loop or there are no more
             # issues to get, we move on to the next repo
             checkpoint.stage = GithubConnectorStage.PRS
-            checkpoint.curr_page = 0
+            checkpoint.reset()
 
-        checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
+        checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0
         if checkpoint.cached_repo_ids:
             next_id = checkpoint.cached_repo_ids.pop()
             next_repo = self.github_client.get_repo(next_id)
@@ -553,7 +717,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
 
     def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
         return GithubConnectorCheckpoint(
-            stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
+            stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0
        )
 
 
diff --git a/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py b/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
index 5c797a5dc33..8863d58ee20 100644
--- a/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
+++ b/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
@@ -9,8 +9,8 @@ from unittest.mock import patch
 import pytest
 from github import Github
-from github import GithubException
 from github import RateLimitExceededException
+from github.GithubException import GithubException
 from github.Issue import Issue
 from github.PullRequest import PullRequest
 from github.RateLimit import RateLimit
@@ -24,6 +24,9 @@ from onyx.connectors.github.connector import GithubConnector
 from onyx.connectors.github.connector import SerializedRepository
 from onyx.connectors.models import Document
 from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
+from tests.unit.onyx.connectors.utils import (
+    load_everything_from_checkpoint_connector_from_checkpoint,
+)
 
 
 @pytest.fixture
@@ -439,3 +442,190 @@ def test_validate_connector_settings_success(
     github_connector.github_client.get_repo.assert_called_once_with(
         f"{github_connector.repo_owner}/{github_connector.repositories}"
     )
+
+
+def test_load_from_checkpoint_with_cursor_fallback(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint with fallback to cursor-based pagination"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PRs
+    mock_pr1 = create_mock_pr(number=1, title="PR 1")
+    mock_pr2 = create_mock_pr(number=2, title="PR 2")
+
+    # Create a mock paginated list that will raise the 422 error on get_page
+    mock_paginated_list = MagicMock()
+    mock_paginated_list.get_page.side_effect = [
+        GithubException(
+            422,
+            {
+                "message": "Pagination with the page parameter is not supported for large datasets. Use cursor"
+            },
+            {},
+        ),
+    ]
+
+    # Create a new mock for cursor-based pagination
+    mock_cursor_paginated_list = MagicMock()
+    mock_cursor_paginated_list.__nextUrl = (
+        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=abc123"
+    )
+    mock_cursor_paginated_list.__iter__.return_value = iter([mock_pr1, mock_pr2])
+
+    mock_repo.get_pulls.side_effect = [
+        mock_paginated_list,
+        mock_cursor_paginated_list,
+    ]
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(
+            github_connector, 0, end_time
+        )
+
+        # Check that we got the documents via cursor-based pagination
+        assert len(outputs) >= 2
+        assert len(outputs[1].items) == 2
+        assert isinstance(outputs[1].items[0], Document)
+        assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"
+        assert isinstance(outputs[1].items[1], Document)
+        assert outputs[1].items[1].id == "https://github.com/test-org/test-repo/pull/2"
+
+        # Verify the cursor URL is not set in the checkpoint: the cursor-based
+        # run completed, so the connector reset the cursor state before returning
+        assert outputs[1].next_checkpoint.cursor_url is None
+        assert outputs[1].next_checkpoint.num_retrieved == 0
+
+
+def test_load_from_checkpoint_resume_cursor_pagination(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+) -> None:
+    """Test resuming from a checkpoint that was using cursor-based pagination"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PRs
+    mock_pr3 = create_mock_pr(number=3, title="PR 3")
+    mock_pr4 = create_mock_pr(number=4, title="PR 4")
+
+    # Create a checkpoint that was using cursor-based pagination
+    checkpoint = github_connector.build_dummy_checkpoint()
+    checkpoint.cursor_url = (
+        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=abc123"
+    )
+    checkpoint.num_retrieved = 2
+
+    # Mock get_pulls to use cursor-based pagination
+    mock_paginated_list = MagicMock()
+    mock_paginated_list.__nextUrl = (
+        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=def456"
+    )
+    mock_paginated_list.__iter__.return_value = iter([mock_pr3, mock_pr4])
+    mock_repo.get_pulls.return_value = mock_paginated_list
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint with the checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector_from_checkpoint(
+            github_connector, 0, end_time, checkpoint
+        )
+
+        # Check that we got the documents via cursor-based pagination
+        assert len(outputs) >= 2
+        assert len(outputs[1].items) == 2
+        assert isinstance(outputs[1].items[0], Document)
+        assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/3"
+        assert isinstance(outputs[1].items[1], Document)
+        assert outputs[1].items[1].id == "https://github.com/test-org/test-repo/pull/4"
+
+        # Verify the cursor URL was cleared from the checkpoint once the
+        # cursor-based run completed
+        assert outputs[1].next_checkpoint.cursor_url is None
+        assert outputs[1].next_checkpoint.num_retrieved == 0
+
+
+def test_load_from_checkpoint_cursor_expiration(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+) -> None:
+    """Test handling of cursor expiration during cursor-based pagination"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PRs
+    mock_pr4 = create_mock_pr(number=4, title="PR 4")
+
+    # Create a checkpoint with an expired cursor
+    checkpoint = github_connector.build_dummy_checkpoint()
+    checkpoint.cursor_url = (
+        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=expired"
+    )
+    checkpoint.num_retrieved = 3  # We've already retrieved 3 items
+
+    # Mock get_pulls to simulate cursor expiration by raising an error before any results
+    mock_paginated_list = MagicMock()
+    mock_paginated_list.__nextUrl = (
+        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=expired"
+    )
+    mock_paginated_list.__iter__.side_effect = GithubException(
+        422, {"message": "Cursor expired"}, {}
+    )
+
+    # Create a new mock for successful retrieval after retry
+    mock_retry_paginated_list = MagicMock()
+    mock_retry_paginated_list.__nextUrl = None
+
+    # Create an iterator that will yield the remaining PR
+    def retry_iterator() -> Generator[PullRequest, None, None]:
+        yield mock_pr4
+
+    # Create a mock for the _Slice object that will be returned by pag_list[prev_num_objs:]
+    mock_slice = MagicMock()
+    mock_slice.__iter__.return_value = retry_iterator()
+
+    # Set up the slice behavior for the retry paginated list
+    mock_retry_paginated_list.__getitem__.return_value = mock_slice
+
+    # Set up the side effect for get_pulls to return our mocks
+    mock_repo.get_pulls.side_effect = [
+        mock_paginated_list,
+        mock_retry_paginated_list,
+    ]
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint with the checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector_from_checkpoint(
+            github_connector, 0, end_time, checkpoint
+        )
+
+        # Check that we got the remaining document after retrying from the beginning
+        assert len(outputs) >= 2
+        assert len(outputs[1].items) == 1
+        assert isinstance(outputs[1].items[0], Document)
+        assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/4"
+
+        # Verify cursor URL was cleared in checkpoint
+        assert outputs[1].next_checkpoint.cursor_url is None
+        assert outputs[1].next_checkpoint.num_retrieved == 0
+
+        # Verify that the slice was called with the correct argument
+        mock_retry_paginated_list.__getitem__.assert_called_once_with(slice(3, None))
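
Note on the fallback trigger: the connector switches to cursor-based pagination only when GitHub rejects page-parameter pagination with an HTTP 422 whose message mentions a cursor. The snippet below isolates that check as a standalone predicate and exercises it against a constructed exception; the logic is lifted from the patch's except GithubException branch, but the helper name is_cursor_pagination_required is ours, not part of the connector.

from github.GithubException import GithubException


def is_cursor_pagination_required(e: GithubException) -> bool:
    # mirrors _get_batch_rate_limited: a 422 whose message (top-level or in
    # the data payload) mentions "cursor" triggers the cursor fallback
    return e.status == 422 and (
        "cursor" in (e.message or "")
        or "cursor" in (e.data or {}).get("message", "")
    )


large_dataset_error = GithubException(
    422,
    {
        "message": "Pagination with the page parameter is not supported for "
        "large datasets, please use cursor based pagination (after/before)"
    },
    {},
)
assert is_cursor_pagination_required(large_dataset_error)
assert not is_cursor_pagination_required(
    GithubException(404, {"message": "Not Found"}, {})
)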
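A caution on pag_list.__nextUrl: if PyGithub's PaginatedList stores its next-page URL in a private __nextUrl attribute (which the attribute name used throughout the patch suggests), Python name mangling makes that attribute externally visible only as _PaginatedList__nextUrl, so module-level reads and writes of pag_list.__nextUrl touch a different, unmangled attribute. The unit tests cannot catch this because MagicMock serves any attribute name. A getattr/setattr helper over the mangled name is the usual workaround; the helper names below are ours, and the mangled attribute name is an assumption about PyGithub's internals.

from typing import Any

# assumed mangled name of PaginatedList's private next-page URL
_NEXT_URL_ATTR = "_PaginatedList__nextUrl"


def get_next_url(pag_list: Any) -> str | None:
    # returns None when the attribute is absent (e.g. other PyGithub versions)
    return getattr(pag_list, _NEXT_URL_ATTR, None)


def set_next_url(pag_list: Any, url: str | None) -> None:
    setattr(pag_list, _NEXT_URL_ATTR, url)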
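Why the retry path is expensive: the warning in _paginate_until_error ("this WILL retrieve all pages before the one we are resuming from") follows from how slicing a lazily-fetched sequence works. There is no random access, so skipping prev_num_objs items still walks every earlier page. A toy illustration with a plain generator standing in for PaginatedList (all names here are illustrative):

from collections.abc import Iterator
from itertools import islice


def fetch_all(page_size: int = 2, pages: int = 3) -> Iterator[int]:
    # stands in for a fresh PaginatedList; each page costs one API request
    for page in range(pages):
        print(f"API request for page {page}")
        yield from range(page * page_size, (page + 1) * page_size)


# resume after 4 already-processed items: pages 0 and 1 are still fetched
remaining = islice(fetch_all(), 4, None)
print(list(remaining))  # prints the page-0..2 request lines, then [4, 5]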
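For context on how the returned checkpoints are consumed: the test helpers (load_everything_from_checkpoint_connector and its *_from_checkpoint variant) repeatedly re-invoke the connector until has_more is false, carrying the checkpoint across calls. A rough sketch of that driver loop, assuming load_from_checkpoint is a generator of Documents whose return value is the next checkpoint; this is the shape the code above implies, and the real helper lives in tests/unit/onyx/connectors/utils.py:

def drain_connector(connector, start: float, end: float) -> list:
    docs = []
    checkpoint = connector.build_dummy_checkpoint()
    while checkpoint.has_more:
        gen = connector.load_from_checkpoint(start, end, checkpoint)
        while True:
            try:
                docs.append(next(gen))
            except StopIteration as stop:
                # a checkpointed generator returns the next checkpoint
                checkpoint = stop.value
                break
    return docs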