Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-11 00:20:55 +02:00)

Commit 6436b60763 (parent a6cc1c84dc)

github cursor pagination (#4642)

* v1 of cursor pagination
* mypy
* unit tests
* CW comments
@@ -1,5 +1,6 @@
 import copy
 import time
+from collections.abc import Callable
 from collections.abc import Generator
 from datetime import datetime
 from datetime import timedelta
@@ -39,6 +40,7 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()

 ITEMS_PER_PAGE = 100
+CURSOR_LOG_FREQUENCY = 100

 _MAX_NUM_RATE_LIMIT_RETRIES = 5

@@ -52,26 +54,138 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
     time.sleep(sleep_time.seconds)


+# Cases
+# X (from start) standard run, no fallback to cursor-based pagination
+# X (from start) standard run errors, fallback to cursor-based pagination
+# X error in the middle of a page
+# X no errors: run to completion
+# X (from checkpoint) standard run, no fallback to cursor-based pagination
+# X (from checkpoint) continue from cursor-based pagination
+# - retrying
+# - no retrying
+
+# things to check:
+# checkpoint state on return
+# checkpoint progress (no infinite loop)
+
+
+def _paginate_until_error(
+    git_objs: Callable[[], PaginatedList[PullRequest | Issue]],
+    cursor_url: str | None,
+    prev_num_objs: int,
+    cursor_url_callback: Callable[[str | None, int], None],
+    retrying: bool = False,
+) -> Generator[PullRequest | Issue, None, None]:
+    num_objs = prev_num_objs
+    pag_list = git_objs()
+    if cursor_url:
+        pag_list.__nextUrl = cursor_url
+    elif retrying:
+        # if we are retrying, we want to skip the objects retrieved
+        # over previous calls. Unfortunately, this WILL retrieve all
+        # pages before the one we are resuming from, so we really
+        # don't want this case to be hit often
+        logger.warning(
+            "Retrying from a previous cursor-based pagination call. "
+            "This will retrieve all pages before the one we are resuming from, "
+            "which may take a while and consume many API calls."
+        )
+        pag_list = pag_list[prev_num_objs:]
+        num_objs = 0
+
+    try:
+        # this for loop handles cursor-based pagination
+        for issue_or_pr in pag_list:
+            num_objs += 1
+            yield issue_or_pr
+            # used to store the current cursor url in the checkpoint. This value
+            # is updated during iteration over pag_list.
+            cursor_url_callback(pag_list.__nextUrl, num_objs)
+
+            if num_objs % CURSOR_LOG_FREQUENCY == 0:
+                logger.info(
+                    f"Retrieved {num_objs} objects with current cursor url: {pag_list.__nextUrl}"
+                )
+
+    except Exception as e:
+        logger.exception(f"Error during cursor-based pagination: {e}")
+        if num_objs - prev_num_objs > 0:
+            raise
+
+        if pag_list.__nextUrl is not None and not retrying:
+            logger.info(
+                "Assuming that this error is due to cursor "
+                "expiration because no objects were retrieved. "
+                "Retrying from the first page."
+            )
+            yield from _paginate_until_error(
+                git_objs, None, prev_num_objs, cursor_url_callback, retrying=True
+            )
+            return
+
+        # for no cursor url or if we reach this point after a retry, raise the error
+        raise
+
+
 def _get_batch_rate_limited(
-    git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
-) -> list[PullRequest | Issue]:
+    # We pass in a callable because we want git_objs to produce a fresh
+    # PaginatedList each time it's called to avoid using the same object for cursor-based pagination
+    # from a partial offset-based pagination call.
+    git_objs: Callable[[], PaginatedList],
+    page_num: int,
+    cursor_url: str | None,
+    prev_num_objs: int,
+    cursor_url_callback: Callable[[str | None, int], None],
+    github_client: Github,
+    attempt_num: int = 0,
+) -> Generator[PullRequest | Issue, None, None]:
     if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
         raise RuntimeError(
             "Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
         )

     try:
-        objs = list(git_objs.get_page(page_num))
+        if cursor_url:
+            # when this is set, we are resuming from an earlier
+            # cursor-based pagination call.
+            yield from _paginate_until_error(
+                git_objs, cursor_url, prev_num_objs, cursor_url_callback
+            )
+            return
+
+        objs = list(git_objs().get_page(page_num))
         # fetch all data here to disable lazy loading later
         # this is needed to capture the rate limit exception here (if one occurs)
         for obj in objs:
             if hasattr(obj, "raw_data"):
                 getattr(obj, "raw_data")
-        return objs
+        yield from objs
     except RateLimitExceededException:
         _sleep_after_rate_limit_exception(github_client)
-        return _get_batch_rate_limited(
-            git_objs, page_num, github_client, attempt_num + 1
+        yield from _get_batch_rate_limited(
+            git_objs,
+            page_num,
+            cursor_url,
+            prev_num_objs,
+            cursor_url_callback,
+            github_client,
+            attempt_num + 1,
+        )
+    except GithubException as e:
+        if not (
+            e.status == 422
+            and (
+                "cursor" in (e.message or "")
+                or "cursor" in (e.data or {}).get("message", "")
+            )
+        ):
+            raise
+        # Fallback to a cursor-based pagination strategy
+        # This can happen for "large datasets," but there's no documentation
+        # on the error on the web as far as we can tell.
+        # Error message:
+        # "Pagination with the page parameter is not supported for large datasets,
+        # please use cursor based pagination (after/before)"
+        yield from _paginate_until_error(
+            git_objs, cursor_url, prev_num_objs, cursor_url_callback
         )

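The control flow above is: try offset pagination first, and only on the specific 422 "use cursor based pagination" error switch to iterating the PaginatedList, reporting the live cursor back through the callback after every yielded object. Below is a minimal runnable sketch of that shape under stated assumptions: `PageParamUnsupported` and `ToyPaginatedList` are toy stand-ins for PyGithub's exception and `PaginatedList`, not the PR's actual types.

```python
from collections.abc import Callable, Generator, Iterator


class PageParamUnsupported(Exception):
    """Stand-in for the 422 'use cursor based pagination' GithubException."""


class ToyPaginatedList:
    """Offset pagination fails for 'large datasets'; plain iteration (cursor mode) works."""

    def __init__(self, items: list[str]) -> None:
        self._items = items
        self.next_cursor: str | None = None  # stand-in for PaginatedList.__nextUrl

    def get_page(self, page_num: int) -> list[str]:
        raise PageParamUnsupported("use cursor based pagination")

    def __iter__(self) -> Iterator[str]:
        for i, item in enumerate(self._items):
            self.next_cursor = f"cursor-{i + 1}"  # advances as items are consumed
            yield item


def get_batch(
    objs_func: Callable[[], ToyPaginatedList],
    page_num: int,
    cursor_callback: Callable[[str | None, int], None],
) -> Generator[str, None, None]:
    pag_list = objs_func()  # fresh object each call, hence the Callable argument
    try:
        yield from pag_list.get_page(page_num)  # offset mode first
    except PageParamUnsupported:
        num_objs = 0
        for item in pag_list:  # cursor-based fallback
            num_objs += 1
            yield item
            cursor_callback(pag_list.next_cursor, num_objs)


state: dict[str, object] = {}
docs = list(
    get_batch(
        lambda: ToyPaginatedList(["pr1", "pr2", "pr3"]),
        0,
        lambda cur, n: state.update(cursor=cur, num=n),
    )
)
assert docs == ["pr1", "pr2", "pr3"]
assert state == {"cursor": "cursor-3", "num": 3}
```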
@@ -142,6 +256,18 @@ class GithubConnectorCheckpoint(ConnectorCheckpoint):
     cached_repo_ids: list[int] | None = None
     cached_repo: SerializedRepository | None = None

+    # Used for the fallback cursor-based pagination strategy
+    num_retrieved: int
+    cursor_url: str | None = None
+
+    def reset(self) -> None:
+        """
+        Resets curr_page, num_retrieved, and cursor_url to their initial values (0, 0, None)
+        """
+        self.curr_page = 0
+        self.num_retrieved = 0
+        self.cursor_url = None
+

 class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
     def __init__(
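These two new fields plus `reset()` are the whole persistence story for the fallback: `cursor_url`/`num_retrieved` survive across runs, and `reset()` zeroes them together with `curr_page` when a stage or repo finishes. A toy dataclass mirroring that behavior (an illustrative stand-in, not the real pydantic `ConnectorCheckpoint` model):

```python
from dataclasses import dataclass


@dataclass
class ToyCheckpoint:
    curr_page: int = 0
    num_retrieved: int = 0  # note: the real field has no default and must be passed
    cursor_url: str | None = None
    has_more: bool = True

    def reset(self) -> None:
        """Back to page 0, offset mode, nothing retrieved: ready for the next stage."""
        self.curr_page = 0
        self.num_retrieved = 0
        self.cursor_url = None


cp = ToyCheckpoint(curr_page=7, num_retrieved=350, cursor_url="https://.../pulls?cursor=abc")
cp.reset()  # e.g. moving from the PRs stage to the issues stage
assert (cp.curr_page, cp.num_retrieved, cp.cursor_url) == (0, 0, None)
```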
@@ -230,6 +356,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
         try:
             org = github_client.get_organization(self.repo_owner)
             return list(org.get_repos())
+
         except GithubException:
             # If not an org, try as a user
             user = github_client.get_user(self.repo_owner)
@@ -266,11 +393,12 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
             checkpoint.has_more = False
             return checkpoint

-        checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
+        curr_repo = repos.pop()
+        checkpoint.cached_repo_ids = [repo.id for repo in repos]
         checkpoint.cached_repo = SerializedRepository(
-            id=checkpoint.cached_repo_ids[0],
-            headers=repos[0].raw_headers,
-            raw_data=repos[0].raw_data,
+            id=curr_repo.id,
+            headers=curr_repo.raw_headers,
+            raw_data=curr_repo.raw_data,
         )
         checkpoint.stage = GithubConnectorStage.PRS
         checkpoint.curr_page = 0
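This bookkeeping change pairs with the `has_more = len(...) > 0` fix later in the diff: the current repo is popped off the list up front, so `cached_repo_ids` holds only repos still to do and `has_more` can simply test whether any remain. A tiny sketch of that queue discipline, with made-up ids standing in for real repos:

```python
repos = [101, 102, 103]  # pretend repo ids

curr_repo = repos.pop()        # 103 becomes the repo being processed
cached_repo_ids = list(repos)  # only repos still to do: [101, 102]
has_more = len(cached_repo_ids) > 0

processed = [curr_repo]
while has_more:
    curr_repo = cached_repo_ids.pop()
    processed.append(curr_repo)
    has_more = len(cached_repo_ids) > 0

assert processed == [103, 102, 101]  # every repo handled exactly once
```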
@@ -299,26 +427,41 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
             repo_id = checkpoint.cached_repo.id
             repo = self.github_client.get_repo(repo_id)

+        def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None:
+            checkpoint.cursor_url = cursor_url
+            checkpoint.num_retrieved = num_objs
+
+        # TODO: all PRs are also issues, so we should be able to _only_ get issues
+        # and then filter appropriately whenever include_issues is True
         if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
             logger.info(f"Fetching PRs for repo: {repo.name}")
-            pull_requests = repo.get_pulls(
-                state=self.state_filter, sort="updated", direction="desc"
-            )

-            doc_batch: list[Document] = []
+            def pull_requests_func() -> PaginatedList[PullRequest]:
+                return repo.get_pulls(
+                    state=self.state_filter, sort="updated", direction="desc"
+                )
+
             pr_batch = _get_batch_rate_limited(
-                pull_requests, checkpoint.curr_page, self.github_client
+                pull_requests_func,
+                checkpoint.curr_page,
+                checkpoint.cursor_url,
+                checkpoint.num_retrieved,
+                cursor_url_callback,
+                self.github_client,
             )
-            checkpoint.curr_page += 1
+            checkpoint.curr_page += 1  # NOTE: not used for cursor-based fallback
             done_with_prs = False
+            num_prs = 0
+            pr = None
             for pr in pr_batch:
+                num_prs += 1
+
                 # we iterate backwards in time, so at this point we stop processing prs
                 if (
                     start is not None
                     and pr.updated_at
                     and pr.updated_at.replace(tzinfo=timezone.utc) < start
                 ):
-                    yield from doc_batch
                     done_with_prs = True
                     break
                 # Skip PRs updated after the end date
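A note on why a callback rather than a return value: `_get_batch_rate_limited` is now a generator, so the connector cannot receive the final cursor as a return value while documents are still being yielded; instead, the closure over `checkpoint` records progress as a side effect each time iteration resumes. A small runnable sketch of that timing, using stand-in names:

```python
from collections.abc import Callable, Generator


def produce(
    cursor_callback: Callable[[str | None, int], None]
) -> Generator[str, None, None]:
    for num, doc in enumerate(["pr-a", "pr-b", "pr-c"], start=1):
        yield doc
        # runs only when the consumer asks for the *next* item,
        # i.e. after `doc` has been fully handed off
        cursor_callback(f"cursor-{num}", num)


checkpoint = {"cursor_url": None, "num_retrieved": 0}


def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None:
    checkpoint["cursor_url"] = cursor_url
    checkpoint["num_retrieved"] = num_objs


gen = produce(cursor_url_callback)
next(gen)  # "pr-a" handed off; its callback has NOT fired yet
assert checkpoint["num_retrieved"] == 0
next(gen)  # resuming fires the callback for "pr-a", then yields "pr-b"
assert checkpoint["num_retrieved"] == 1
assert checkpoint["cursor_url"] == "cursor-1"
```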
@@ -329,7 +472,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
                 ):
                     continue
                 try:
-                    doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
+                    yield _convert_pr_to_document(cast(PullRequest, pr))
                 except Exception as e:
                     error_msg = f"Error converting PR to document: {e}"
                     logger.exception(error_msg)
@@ -342,37 +485,57 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
                     )
                     continue

-            # if we found any PRs on the page, yield any associated documents and return the checkpoint
-            if not done_with_prs and len(pr_batch) > 0:
-                yield from doc_batch
+            # If we reach this point with a cursor url in the checkpoint, we were using
+            # the fallback cursor-based pagination strategy. That strategy tries to get all
+            # PRs, so having cursor_url set means we are done with prs. However, we need to
+            # return AFTER the checkpoint reset to avoid infinite loops.
+
+            # if we found any PRs on the page and there are more PRs to get, return the checkpoint.
+            # In offset mode, while indexing without time constraints, the pr batch
+            # will be empty when we're done.
+            if num_prs > 0 and not done_with_prs and not checkpoint.cursor_url:
                 return checkpoint

             # if we went past the start date during the loop or there are no more
             # prs to get, we move on to issues
             checkpoint.stage = GithubConnectorStage.ISSUES
-            checkpoint.curr_page = 0
+            checkpoint.reset()

+            if checkpoint.cursor_url:
+                # save the checkpoint after changing stage; next run will continue from issues
+                return checkpoint
+
         checkpoint.stage = GithubConnectorStage.ISSUES

         if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
             logger.info(f"Fetching issues for repo: {repo.name}")
-            issues = repo.get_issues(
-                state=self.state_filter, sort="updated", direction="desc"
-            )

-            doc_batch = []
-            issue_batch = _get_batch_rate_limited(
-                issues, checkpoint.curr_page, self.github_client
+            def issues_func() -> PaginatedList[Issue]:
+                return repo.get_issues(
+                    state=self.state_filter, sort="updated", direction="desc"
+                )
+
+            issue_batch = list(
+                _get_batch_rate_limited(
+                    issues_func,
+                    checkpoint.curr_page,
+                    checkpoint.cursor_url,
+                    checkpoint.num_retrieved,
+                    cursor_url_callback,
+                    self.github_client,
+                )
             )
             checkpoint.curr_page += 1
             done_with_issues = False
-            for issue in cast(list[Issue], issue_batch):
+            num_issues = 0
+            for issue in issue_batch:
+                num_issues += 1
+                issue = cast(Issue, issue)
                 # we iterate backwards in time, so at this point we stop processing prs
                 if (
                     start is not None
                     and issue.updated_at.replace(tzinfo=timezone.utc) < start
                 ):
-                    yield from doc_batch
                     done_with_issues = True
                     break
                 # Skip PRs updated after the end date
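The ordering constraint called out in the added comments ("return AFTER the checkpoint reset to avoid infinite loops") is easiest to see in the loop that drives checkpoints: if the connector handed back a checkpoint while the cursor state was still set and the stage unchanged, the next call would take the same resume branch again and never advance. A toy driver illustrating that shape (an assumed sketch, not the real indexing loop):

```python
def connector_step(cp: dict) -> dict:
    """Mimics one load_from_checkpoint call for a single repo."""
    if cp["stage"] == "prs":
        # cursor mode fetched every PR in one go, so clear cursor state
        # and advance the stage BEFORE handing the checkpoint back
        cp.update(cursor_url=None, num_retrieved=0, stage="issues")
    else:
        cp["has_more"] = False  # issues done; nothing left for this repo
    return cp


cp = {"stage": "prs", "cursor_url": "...?cursor=abc", "num_retrieved": 42, "has_more": True}
steps = 0
while cp["has_more"]:
    cp = connector_step(cp)
    steps += 1
    assert steps < 10, "checkpoint made no progress (would loop forever)"
assert steps == 2 and cp["stage"] == "issues"
```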
@@ -384,10 +547,11 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):

                 if issue.pull_request is not None:
                     # PRs are handled separately
+                    # TODO: but they shouldn't always be
                     continue

                 try:
-                    doc_batch.append(_convert_issue_to_document(issue))
+                    yield _convert_issue_to_document(issue)
                 except Exception as e:
                     error_msg = f"Error converting issue to document: {e}"
                     logger.exception(error_msg)
@@ -401,17 +565,17 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
                     )
                     continue

-            # if we found any issues on the page, yield them and return the checkpoint
-            if not done_with_issues and len(issue_batch) > 0:
-                yield from doc_batch
+            # if we found any issues on the page, and we're not done, return the checkpoint.
+            # don't return if we're using cursor-based pagination to avoid infinite loops
+            if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url:
                 return checkpoint

             # if we went past the start date during the loop or there are no more
             # issues to get, we move on to the next repo
             checkpoint.stage = GithubConnectorStage.PRS
-            checkpoint.curr_page = 0
+            checkpoint.reset()

-        checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
+        checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0
         if checkpoint.cached_repo_ids:
             next_id = checkpoint.cached_repo_ids.pop()
             next_repo = self.github_client.get_repo(next_id)
@@ -553,7 +717,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):

     def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
         return GithubConnectorCheckpoint(
-            stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
+            stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0
         )


@@ -9,8 +9,8 @@ from unittest.mock import patch

 import pytest
 from github import Github
-from github import GithubException
 from github import RateLimitExceededException
+from github.GithubException import GithubException
 from github.Issue import Issue
 from github.PullRequest import PullRequest
 from github.RateLimit import RateLimit
@@ -24,6 +24,9 @@ from onyx.connectors.github.connector import GithubConnector
 from onyx.connectors.github.connector import SerializedRepository
 from onyx.connectors.models import Document
 from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
+from tests.unit.onyx.connectors.utils import (
+    load_everything_from_checkpoint_connector_from_checkpoint,
+)


 @pytest.fixture
@@ -439,3 +442,190 @@ def test_validate_connector_settings_success(
     github_connector.github_client.get_repo.assert_called_once_with(
         f"{github_connector.repo_owner}/{github_connector.repositories}"
     )
+
+
+def test_load_from_checkpoint_with_cursor_fallback(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint with fallback to cursor-based pagination"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PRs
+    mock_pr1 = create_mock_pr(number=1, title="PR 1")
+    mock_pr2 = create_mock_pr(number=2, title="PR 2")
+
+    # Create a mock paginated list that will raise the 422 error on get_page
+    mock_paginated_list = MagicMock()
+    mock_paginated_list.get_page.side_effect = [
+        GithubException(
+            422,
+            {
+                "message": "Pagination with the page parameter is not supported for large datasets. Use cursor"
+            },
+            {},
+        ),
+    ]
+
+    # Create a new mock for cursor-based pagination
+    mock_cursor_paginated_list = MagicMock()
+    mock_cursor_paginated_list.__nextUrl = (
+        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=abc123"
+    )
+    mock_cursor_paginated_list.__iter__.return_value = iter([mock_pr1, mock_pr2])
+
+    mock_repo.get_pulls.side_effect = [
+        mock_paginated_list,
+        mock_cursor_paginated_list,
+    ]
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(
+            github_connector, 0, end_time
+        )
+
+    # Check that we got the documents via cursor-based pagination
+    assert len(outputs) >= 2
+    assert len(outputs[1].items) == 2
+    assert isinstance(outputs[1].items[0], Document)
+    assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"
+    assert isinstance(outputs[1].items[1], Document)
+    assert outputs[1].items[1].id == "https://github.com/test-org/test-repo/pull/2"
+
+    # Verify cursor URL is not set in checkpoint since pagination succeeded without failures
+    assert outputs[1].next_checkpoint.cursor_url is None
+    assert outputs[1].next_checkpoint.num_retrieved == 0
+
+
+def test_load_from_checkpoint_resume_cursor_pagination(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+) -> None:
+    """Test resuming from a checkpoint that was using cursor-based pagination"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PRs
+    mock_pr3 = create_mock_pr(number=3, title="PR 3")
+    mock_pr4 = create_mock_pr(number=4, title="PR 4")
+
+    # Create a checkpoint that was using cursor-based pagination
+    checkpoint = github_connector.build_dummy_checkpoint()
+    checkpoint.cursor_url = (
+        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=abc123"
+    )
+    checkpoint.num_retrieved = 2
+
+    # Mock get_pulls to use cursor-based pagination
+    mock_paginated_list = MagicMock()
+    mock_paginated_list.__nextUrl = (
+        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=def456"
+    )
+    mock_paginated_list.__iter__.return_value = iter([mock_pr3, mock_pr4])
+    mock_repo.get_pulls.return_value = mock_paginated_list
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint with the checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector_from_checkpoint(
+            github_connector, 0, end_time, checkpoint
+        )
+
+    # Check that we got the documents via cursor-based pagination
+    assert len(outputs) >= 2
+    assert len(outputs[1].items) == 2
+    assert isinstance(outputs[1].items[0], Document)
+    assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/3"
+    assert isinstance(outputs[1].items[1], Document)
+    assert outputs[1].items[1].id == "https://github.com/test-org/test-repo/pull/4"
+
+    # Verify cursor URL was cleared in the final checkpoint
+    assert outputs[1].next_checkpoint.cursor_url is None
+    assert outputs[1].next_checkpoint.num_retrieved == 0
+
+
+def test_load_from_checkpoint_cursor_expiration(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+) -> None:
+    """Test handling of cursor expiration during cursor-based pagination"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PRs
+    mock_pr4 = create_mock_pr(number=4, title="PR 4")
+
+    # Create a checkpoint with an expired cursor
+    checkpoint = github_connector.build_dummy_checkpoint()
+    checkpoint.cursor_url = (
+        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=expired"
+    )
+    checkpoint.num_retrieved = 3  # We've already retrieved 3 items
+
+    # Mock get_pulls to simulate cursor expiration by raising an error before any results
+    mock_paginated_list = MagicMock()
+    mock_paginated_list.__nextUrl = (
+        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=expired"
+    )
+    mock_paginated_list.__iter__.side_effect = GithubException(
+        422, {"message": "Cursor expired"}, {}
+    )
+
+    # Create a new mock for successful retrieval after retry
+    mock_retry_paginated_list = MagicMock()
+    mock_retry_paginated_list.__nextUrl = None
+
+    # Create an iterator that will yield the remaining PR
+    def retry_iterator() -> Generator[PullRequest, None, None]:
+        yield mock_pr4
+
+    # Create a mock for the _Slice object that will be returned by pag_list[prev_num_objs:]
+    mock_slice = MagicMock()
+    mock_slice.__iter__.return_value = retry_iterator()
+
+    # Set up the slice behavior for the retry paginated list
+    mock_retry_paginated_list.__getitem__.return_value = mock_slice
+
+    # Set up the side effect for get_pulls to return our mocks
+    mock_repo.get_pulls.side_effect = [
+        mock_paginated_list,
+        mock_retry_paginated_list,
+    ]
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint with the checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector_from_checkpoint(
+            github_connector, 0, end_time, checkpoint
+        )
+
+    # Check that we got the remaining document after retrying from the beginning
+    assert len(outputs) >= 2
+    assert len(outputs[1].items) == 1
+    assert isinstance(outputs[1].items[0], Document)
+    assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/4"
+
+    # Verify cursor URL was cleared in checkpoint
+    assert outputs[1].next_checkpoint.cursor_url is None
+    assert outputs[1].next_checkpoint.num_retrieved == 0
+
+    # Verify that the slice was called with the correct argument
+    mock_retry_paginated_list.__getitem__.assert_called_once_with(slice(3, None))
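The tests above lean on `create_mock_repo` / `create_mock_pr` fixtures that are outside this diff. A hypothetical minimal `create_mock_pr` is sketched below; the field names are assumptions inferred from the assertions (the pull/<number> document ids suggest an html_url attribute), not the project's actual fixture.

```python
from collections.abc import Callable
from typing import Any
from unittest.mock import MagicMock

import pytest


@pytest.fixture
def create_mock_pr() -> Callable[..., MagicMock]:
    def _create(number: int = 1, title: str = "Test PR", **kwargs: Any) -> MagicMock:
        pr = MagicMock()
        pr.number = number
        pr.title = title
        # assumed: the connector derives Document ids from the PR's html_url
        pr.html_url = f"https://github.com/test-org/test-repo/pull/{number}"
        pr.updated_at = None  # keeps the start/end time filters from skipping it
        for key, value in kwargs.items():
            setattr(pr, key, value)
        return pr

    return _create
```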