github cursor pagination (#4642)

* v1 of cursor pagination

* mypy

* unit tests

* CW comments
Evan Lohn 2025-04-30 19:09:20 -07:00 committed by GitHub
parent a6cc1c84dc
commit 6436b60763
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 393 additions and 39 deletions

onyx/connectors/github/connector.py

@@ -1,5 +1,6 @@
import copy
import time
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
@@ -39,6 +40,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
ITEMS_PER_PAGE = 100
CURSOR_LOG_FREQUENCY = 100
_MAX_NUM_RATE_LIMIT_RETRIES = 5
@@ -52,26 +54,138 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
time.sleep(sleep_time.seconds)
# Cases
# X (from start) standard run, no fallback to cursor-based pagination
# X (from start) standard run errors, fallback to cursor-based pagination
# X error in the middle of a page
# X no errors: run to completion
# X (from checkpoint) standard run, no fallback to cursor-based pagination
# X (from checkpoint) continue from cursor-based pagination
# - retrying
# - no retrying
# things to check:
# checkpoint state on return
# checkpoint progress (no infinite loop)
def _paginate_until_error(
git_objs: Callable[[], PaginatedList[PullRequest | Issue]],
cursor_url: str | None,
prev_num_objs: int,
cursor_url_callback: Callable[[str | None, int], None],
retrying: bool = False,
) -> Generator[PullRequest | Issue, None, None]:
num_objs = prev_num_objs
pag_list = git_objs()
if cursor_url:
pag_list.__nextUrl = cursor_url
elif retrying:
# if we are retrying, we want to skip the objects retrieved
# over previous calls. Unfortunately, this WILL retrieve all
# pages before the one we are resuming from, so we really
# don't want this case to be hit often
logger.warning(
"Retrying from a previous cursor-based pagination call. "
"This will retrieve all pages before the one we are resuming from, "
"which may take a while and consume many API calls."
)
pag_list = pag_list[prev_num_objs:]
num_objs = 0
try:
# this for loop handles cursor-based pagination
for issue_or_pr in pag_list:
num_objs += 1
yield issue_or_pr
# used to store the current cursor url in the checkpoint. This value
# is updated during iteration over pag_list.
cursor_url_callback(pag_list.__nextUrl, num_objs)
if num_objs % CURSOR_LOG_FREQUENCY == 0:
logger.info(
f"Retrieved {num_objs} objects with current cursor url: {pag_list.__nextUrl}"
)
except Exception as e:
logger.exception(f"Error during cursor-based pagination: {e}")
if num_objs - prev_num_objs > 0:
raise
if pag_list.__nextUrl is not None and not retrying:
logger.info(
"Assuming that this error is due to cursor "
"expiration because no objects were retrieved. "
"Retrying from the first page."
)
yield from _paginate_until_error(
git_objs, None, prev_num_objs, cursor_url_callback, retrying=True
)
return
# for no cursor url or if we reach this point after a retry, raise the error
raise
def _get_batch_rate_limited(
# We pass in a callable so that git_objs produces a fresh PaginatedList each
# time it's called; reusing the PaginatedList left over from a partial
# offset-based pagination call for cursor-based pagination would carry over its state.
git_objs: Callable[[], PaginatedList],
page_num: int,
cursor_url: str | None,
prev_num_objs: int,
cursor_url_callback: Callable[[str | None, int], None],
github_client: Github,
attempt_num: int = 0,
) -> Generator[PullRequest | Issue, None, None]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
)
try:
if cursor_url:
# when this is set, we are resuming from an earlier
# cursor-based pagination call.
yield from _paginate_until_error(
git_objs, cursor_url, prev_num_objs, cursor_url_callback
)
return
objs = list(git_objs().get_page(page_num))
# fetch all data here to disable lazy loading later
# this is needed to capture the rate limit exception here (if one occurs)
for obj in objs:
if hasattr(obj, "raw_data"):
getattr(obj, "raw_data")
yield from objs
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
yield from _get_batch_rate_limited(
git_objs,
page_num,
cursor_url,
prev_num_objs,
cursor_url_callback,
github_client,
attempt_num + 1,
)
except GithubException as e:
if not (
e.status == 422
and (
"cursor" in (e.message or "")
or "cursor" in (e.data or {}).get("message", "")
)
):
raise
# Fallback to a cursor-based pagination strategy
# This can happen for "large datasets," but there's no documentation
# On the error on the web as far as we can tell.
# Error message:
# "Pagination with the page parameter is not supported for large datasets,
# please use cursor based pagination (after/before)"
yield from _paginate_until_error(
git_objs, cursor_url, prev_num_objs, cursor_url_callback
)
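
Taken together, _paginate_until_error and _get_batch_rate_limited implement a try-offset-first, fall-back-to-cursor strategy, with the callback keeping the checkpoint current as the cursor advances. Below is a minimal, self-contained sketch of that control flow under stand-in names; fetch_page, iter_with_cursor, and save_progress are hypothetical, not the connector's real helpers or the PyGithub API.

# Sketch of the offset -> cursor fallback pattern (hypothetical stand-ins).
from collections.abc import Callable, Generator, Iterable


def fetch_with_fallback(
    fetch_page: Callable[[int], list[dict]],           # offset-based page fetch
    iter_with_cursor: Callable[[], Iterable[dict]],    # cursor-based iteration
    page_num: int,
    save_progress: Callable[[str | None, int], None],  # checkpoint callback
) -> Generator[dict, None, None]:
    try:
        # happy path: plain page-number pagination
        yield from fetch_page(page_num)
        return
    except ValueError as e:  # stand-in for GithubException(422, "... cursor ...")
        if "cursor" not in str(e):
            raise
    # fallback: walk the whole collection with cursors, persisting progress
    # through the callback so an interrupted run can resume
    num_objs = 0
    for obj in iter_with_cursor():
        num_objs += 1
        save_progress(obj.get("next_cursor"), num_objs)
        yield obj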
@@ -142,6 +256,18 @@ class GithubConnectorCheckpoint(ConnectorCheckpoint):
cached_repo_ids: list[int] | None = None
cached_repo: SerializedRepository | None = None
# Used for the fallback cursor-based pagination strategy
num_retrieved: int
cursor_url: str | None = None
def reset(self) -> None:
"""
Resets curr_page, num_retrieved, and cursor_url to their initial values (0, 0, None)
"""
self.curr_page = 0
self.num_retrieved = 0
self.cursor_url = None
class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
def __init__(
@@ -230,6 +356,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
try:
org = github_client.get_organization(self.repo_owner)
return list(org.get_repos())
except GithubException:
# If not an org, try as a user
user = github_client.get_user(self.repo_owner)
@@ -266,11 +393,12 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
checkpoint.has_more = False
return checkpoint
curr_repo = repos.pop()
checkpoint.cached_repo_ids = [repo.id for repo in repos]
checkpoint.cached_repo = SerializedRepository(
id=curr_repo.id,
headers=curr_repo.raw_headers,
raw_data=curr_repo.raw_data,
)
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
@@ -299,26 +427,41 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
repo_id = checkpoint.cached_repo.id
repo = self.github_client.get_repo(repo_id)
def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None:
checkpoint.cursor_url = cursor_url
checkpoint.num_retrieved = num_objs
# TODO: all PRs are also issues, so we should be able to _only_ get issues
# and then filter appropriately whenever include_issues is True
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logger.info(f"Fetching PRs for repo: {repo.name}")
def pull_requests_func() -> PaginatedList[PullRequest]:
return repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
pr_batch = _get_batch_rate_limited(
pull_requests_func,
checkpoint.curr_page,
checkpoint.cursor_url,
checkpoint.num_retrieved,
cursor_url_callback,
self.github_client,
)
checkpoint.curr_page += 1 # NOTE: not used for cursor-based fallback
done_with_prs = False
num_prs = 0
pr = None
for pr in pr_batch:
num_prs += 1
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_prs = True
break
# Skip PRs updated after the end date
@@ -329,7 +472,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
):
continue
try:
yield _convert_pr_to_document(cast(PullRequest, pr))
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
@@ -342,37 +485,57 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
)
continue
# If we reach this point with a cursor url in the checkpoint, we were using
# the fallback cursor-based pagination strategy. That strategy tries to get all
# PRs, so having cursor_url set means we are done with PRs. However, we need to
# return AFTER the checkpoint reset to avoid infinite loops.
# if we found any PRs on the page and there are more PRs to get, return the checkpoint.
# In offset mode, while indexing without time constraints, the pr batch
# will be empty when we're done.
if num_prs > 0 and not done_with_prs and not checkpoint.cursor_url:
return checkpoint
# if we went past the start date during the loop or there are no more
# prs to get, we move on to issues
checkpoint.stage = GithubConnectorStage.ISSUES
if checkpoint.cursor_url:
# save the checkpoint after changing stage; next run will continue from issues
checkpoint.reset()
return checkpoint
checkpoint.reset()
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
logger.info(f"Fetching issues for repo: {repo.name}")
def issues_func() -> PaginatedList[Issue]:
return repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
issue_batch = list(
_get_batch_rate_limited(
issues_func,
checkpoint.curr_page,
checkpoint.cursor_url,
checkpoint.num_retrieved,
cursor_url_callback,
self.github_client,
)
)
checkpoint.curr_page += 1
done_with_issues = False
num_issues = 0
for issue in issue_batch:
num_issues += 1
issue = cast(Issue, issue)
# we iterate backwards in time, so at this point we stop processing issues
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_issues = True
break
# Skip issues updated after the end date
@@ -384,10 +547,11 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
if issue.pull_request is not None:
# PRs are handled separately
# TODO: but they shouldn't always be
continue
try:
yield _convert_issue_to_document(issue)
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
@@ -401,17 +565,17 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
)
continue
# if we found any issues on the page, and we're not done, return the checkpoint.
# don't return if we're using cursor-based pagination to avoid infinite loops
if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url:
return checkpoint
# if we went past the start date during the loop or there are no more
# issues to get, we move on to the next repo
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.reset()
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0
if checkpoint.cached_repo_ids:
next_id = checkpoint.cached_repo_ids.pop()
next_repo = self.github_client.get_repo(next_id)
@@ -553,7 +717,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint(
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0
)
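
The new checkpoint fields drive the resume behavior above: a run that fell back to cursor pagination persists cursor_url and num_retrieved, and the next run hands them back to _get_batch_rate_limited instead of requesting a page number. A short sketch of that lifecycle with illustrative values (import paths assumed from the diff's own references):

# Sketch of the checkpoint lifecycle across runs (values are illustrative).
from onyx.connectors.github.connector import (
    GithubConnectorCheckpoint,
    GithubConnectorStage,
)

checkpoint = GithubConnectorCheckpoint(
    stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0
)

# A run hits the 422 "large datasets" error and falls back to cursor pagination;
# cursor_url_callback records progress into the checkpoint during iteration:
checkpoint.cursor_url = "https://api.github.com/repos/org/repo/pulls?cursor=abc123"
checkpoint.num_retrieved = 250

# The next run sees cursor_url set and resumes _paginate_until_error from it
# rather than calling get_page(curr_page). Once the stage finishes:
checkpoint.reset()  # curr_page=0, num_retrieved=0, cursor_url=None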

GitHub connector unit tests

@@ -9,8 +9,8 @@ from unittest.mock import patch
import pytest
from github import Github
from github import RateLimitExceededException
from github.GithubException import GithubException
from github.Issue import Issue
from github.PullRequest import PullRequest
from github.RateLimit import RateLimit
@@ -24,6 +24,9 @@ from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.github.connector import SerializedRepository
from onyx.connectors.models import Document
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
from tests.unit.onyx.connectors.utils import (
load_everything_from_checkpoint_connector_from_checkpoint,
)
@pytest.fixture
@@ -439,3 +442,190 @@ def test_validate_connector_settings_success(
github_connector.github_client.get_repo.assert_called_once_with(
f"{github_connector.repo_owner}/{github_connector.repositories}"
)
def test_load_from_checkpoint_with_cursor_fallback(
github_connector: GithubConnector,
mock_github_client: MagicMock,
create_mock_repo: Callable[..., MagicMock],
create_mock_pr: Callable[..., MagicMock],
) -> None:
"""Test loading from checkpoint with fallback to cursor-based pagination"""
# Set up mocked repo
mock_repo = create_mock_repo()
github_connector.github_client = mock_github_client
mock_github_client.get_repo.return_value = mock_repo
# Set up mocked PRs
mock_pr1 = create_mock_pr(number=1, title="PR 1")
mock_pr2 = create_mock_pr(number=2, title="PR 2")
# Create a mock paginated list that will raise the 422 error on get_page
mock_paginated_list = MagicMock()
mock_paginated_list.get_page.side_effect = [
GithubException(
422,
{
"message": "Pagination with the page parameter is not supported for large datasets. Use cursor"
},
{},
),
]
# Create a new mock for cursor-based pagination
mock_cursor_paginated_list = MagicMock()
mock_cursor_paginated_list.__nextUrl = (
"https://api.github.com/repos/test-org/test-repo/pulls?cursor=abc123"
)
mock_cursor_paginated_list.__iter__.return_value = iter([mock_pr1, mock_pr2])
mock_repo.get_pulls.side_effect = [
mock_paginated_list,
mock_cursor_paginated_list,
]
# Mock SerializedRepository.to_Repository to return our mock repo
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
# Call load_from_checkpoint
end_time = time.time()
outputs = load_everything_from_checkpoint_connector(
github_connector, 0, end_time
)
# Check that we got the documents via cursor-based pagination
assert len(outputs) >= 2
assert len(outputs[1].items) == 2
assert isinstance(outputs[1].items[0], Document)
assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"
assert isinstance(outputs[1].items[1], Document)
assert outputs[1].items[1].id == "https://github.com/test-org/test-repo/pull/2"
# Verify cursor URL is not set in checkpoint since pagination succeeded without failures
assert outputs[1].next_checkpoint.cursor_url is None
assert outputs[1].next_checkpoint.num_retrieved == 0
def test_load_from_checkpoint_resume_cursor_pagination(
github_connector: GithubConnector,
mock_github_client: MagicMock,
create_mock_repo: Callable[..., MagicMock],
create_mock_pr: Callable[..., MagicMock],
) -> None:
"""Test resuming from a checkpoint that was using cursor-based pagination"""
# Set up mocked repo
mock_repo = create_mock_repo()
github_connector.github_client = mock_github_client
mock_github_client.get_repo.return_value = mock_repo
# Set up mocked PRs
mock_pr3 = create_mock_pr(number=3, title="PR 3")
mock_pr4 = create_mock_pr(number=4, title="PR 4")
# Create a checkpoint that was using cursor-based pagination
checkpoint = github_connector.build_dummy_checkpoint()
checkpoint.cursor_url = (
"https://api.github.com/repos/test-org/test-repo/pulls?cursor=abc123"
)
checkpoint.num_retrieved = 2
# Mock get_pulls to use cursor-based pagination
mock_paginated_list = MagicMock()
mock_paginated_list.__nextUrl = (
"https://api.github.com/repos/test-org/test-repo/pulls?cursor=def456"
)
mock_paginated_list.__iter__.return_value = iter([mock_pr3, mock_pr4])
mock_repo.get_pulls.return_value = mock_paginated_list
# Mock SerializedRepository.to_Repository to return our mock repo
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
# Call load_from_checkpoint with the checkpoint
end_time = time.time()
outputs = load_everything_from_checkpoint_connector_from_checkpoint(
github_connector, 0, end_time, checkpoint
)
# Check that we got the documents via cursor-based pagination
assert len(outputs) >= 2
assert len(outputs[1].items) == 2
assert isinstance(outputs[1].items[0], Document)
assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/3"
assert isinstance(outputs[1].items[1], Document)
assert outputs[1].items[1].id == "https://github.com/test-org/test-repo/pull/4"
# Verify cursor URL was cleared in the checkpoint since the run completed
assert outputs[1].next_checkpoint.cursor_url is None
assert outputs[1].next_checkpoint.num_retrieved == 0
def test_load_from_checkpoint_cursor_expiration(
github_connector: GithubConnector,
mock_github_client: MagicMock,
create_mock_repo: Callable[..., MagicMock],
create_mock_pr: Callable[..., MagicMock],
) -> None:
"""Test handling of cursor expiration during cursor-based pagination"""
# Set up mocked repo
mock_repo = create_mock_repo()
github_connector.github_client = mock_github_client
mock_github_client.get_repo.return_value = mock_repo
# Set up mocked PRs
mock_pr4 = create_mock_pr(number=4, title="PR 4")
# Create a checkpoint with an expired cursor
checkpoint = github_connector.build_dummy_checkpoint()
checkpoint.cursor_url = (
"https://api.github.com/repos/test-org/test-repo/pulls?cursor=expired"
)
checkpoint.num_retrieved = 3 # We've already retrieved 3 items
# Mock get_pulls to simulate cursor expiration by raising an error before any results
mock_paginated_list = MagicMock()
mock_paginated_list.__nextUrl = (
"https://api.github.com/repos/test-org/test-repo/pulls?cursor=expired"
)
mock_paginated_list.__iter__.side_effect = GithubException(
422, {"message": "Cursor expired"}, {}
)
# Create a new mock for successful retrieval after retry
mock_retry_paginated_list = MagicMock()
mock_retry_paginated_list.__nextUrl = None
# Create an iterator that will yield the remaining PR
def retry_iterator() -> Generator[PullRequest, None, None]:
yield mock_pr4
# Create a mock for the _Slice object that will be returned by pag_list[prev_num_objs:]
mock_slice = MagicMock()
mock_slice.__iter__.return_value = retry_iterator()
# Set up the slice behavior for the retry paginated list
mock_retry_paginated_list.__getitem__.return_value = mock_slice
# Set up the side effect for get_pulls to return our mocks
mock_repo.get_pulls.side_effect = [
mock_paginated_list,
mock_retry_paginated_list,
]
# Mock SerializedRepository.to_Repository to return our mock repo
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
# Call load_from_checkpoint with the checkpoint
end_time = time.time()
outputs = load_everything_from_checkpoint_connector_from_checkpoint(
github_connector, 0, end_time, checkpoint
)
# Check that we got the remaining document after retrying from the beginning
assert len(outputs) >= 2
assert len(outputs[1].items) == 1
assert isinstance(outputs[1].items[0], Document)
assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/4"
# Verify cursor URL was cleared in checkpoint
assert outputs[1].next_checkpoint.cursor_url is None
assert outputs[1].next_checkpoint.num_retrieved == 0
# Verify that the slice was called with the correct argument
mock_retry_paginated_list.__getitem__.assert_called_once_with(slice(3, None))
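
One detail worth noting about the final assertion: slicing a mock, as _paginate_until_error does with pag_list[prev_num_objs:], is routed through __getitem__ with a slice object, which is exactly what the test verifies. In isolation:

# Slicing goes through __getitem__ with a slice object, so the mock records it.
from unittest.mock import MagicMock

pag_list = MagicMock()
pag_list[3:]  # equivalent to pag_list.__getitem__(slice(3, None))
pag_list.__getitem__.assert_called_once_with(slice(3, None))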