mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-03 03:31:09 +02:00
github cursor pagination (#4642)
* v1 of cursor pagination * mypy * unit tests * CW comments
This commit is contained in:
parent
a6cc1c84dc
commit
6436b60763
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@ -39,6 +40,7 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
ITEMS_PER_PAGE = 100
|
||||
CURSOR_LOG_FREQUENCY = 100
|
||||
|
||||
_MAX_NUM_RATE_LIMIT_RETRIES = 5
|
||||
|
||||
@ -52,26 +54,138 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
time.sleep(sleep_time.seconds)
|
||||
|
||||
|
||||
# Cases
|
||||
# X (from start) standard run, no fallback to cursor-based pagination
|
||||
# X (from start) standard run errors, fallback to cursor-based pagination
|
||||
# X error in the middle of a page
|
||||
# X no errors: run to completion
|
||||
# X (from checkpoint) standard run, no fallback to cursor-based pagination
|
||||
# X (from checkpoint) continue from cursor-based pagination
|
||||
# - retrying
|
||||
# - no retrying
|
||||
|
||||
# things to check:
|
||||
# checkpoint state on return
|
||||
# checkpoint progress (no infinite loop)
|
||||
|
||||
|
||||
# PyGithub's PaginatedList assigns ``self.__nextUrl`` inside its class body, so
# Python's private-name mangling stores the cursor under this attribute name.
_PAGINATED_LIST_NEXT_URL_ATTR = "_PaginatedList__nextUrl"


def _resolve_next_url_attr(pag_list: object) -> str:
    """Return the instance attribute name under which ``pag_list`` keeps its next-page URL.

    Writing ``obj.__nextUrl`` in a module-level function is NOT name-mangled, so
    it never touches the attribute PaginatedList itself uses. Look the real name
    up on the instance instead; objects that carry a plain ``__nextUrl``
    attribute (e.g. test doubles) are still honoured.
    """
    instance_attrs = getattr(pag_list, "__dict__", {})
    if _PAGINATED_LIST_NEXT_URL_ATTR in instance_attrs:
        return _PAGINATED_LIST_NEXT_URL_ATTR
    for attr_name in instance_attrs:
        if attr_name.endswith("__nextUrl"):
            return attr_name
    return _PAGINATED_LIST_NEXT_URL_ATTR


def _read_next_url(iterated: object, source_list: object, attr_name: str) -> str | None:
    """Read the current cursor url, preferring the object being iterated.

    A slice of a paginated list does not necessarily carry the cursor itself,
    so fall back to the list the slice was taken from.
    """
    if hasattr(iterated, attr_name):
        return getattr(iterated, attr_name)
    return getattr(source_list, attr_name, None)


def _paginate_until_error(
    git_objs: Callable[[], PaginatedList[PullRequest | Issue]],
    cursor_url: str | None,
    prev_num_objs: int,
    cursor_url_callback: Callable[[str | None, int], None],
    retrying: bool = False,
) -> Generator[PullRequest | Issue, None, None]:
    """Yield objects via cursor-based pagination, reporting progress as it goes.

    Args:
        git_objs: factory returning a fresh paginated list on every call.
        cursor_url: cursor to resume from; when falsy, iteration starts at the
            first page.
        prev_num_objs: number of objects retrieved by earlier calls.
        cursor_url_callback: called after every yielded object with the current
            cursor url and the running object count.
        retrying: True when this call is the single retry-from-the-first-page
            attempt; the first ``prev_num_objs`` objects are then skipped.

    Yields:
        Each pull request / issue produced by the paginated list.

    Raises:
        Exception: whatever the underlying iteration raised, re-raised when at
            least one object was retrieved by this call, when there is no
            cursor to blame, or when the retry itself fails.
    """
    num_objs = prev_num_objs
    source_list = git_objs()
    next_url_attr = _resolve_next_url_attr(source_list)
    pag_list = source_list
    if cursor_url:
        # Resume from the stored cursor instead of the first page.
        setattr(source_list, next_url_attr, cursor_url)
    elif retrying:
        # if we are retrying, we want to skip the objects retrieved
        # over previous calls. Unfortunately, this WILL retrieve all
        # pages before the one we are resuming from, so we really
        # don't want this case to be hit often
        logger.warning(
            "Retrying from a previous cursor-based pagination call. "
            "This will retrieve all pages before the one we are resuming from, "
            "which may take a while and consume many API calls."
        )
        pag_list = source_list[prev_num_objs:]
        # NOTE(review): from here on the callback reports the count since the
        # retry rather than the running total - confirm this is intended.
        num_objs = 0

    try:
        # this for loop handles cursor-based pagination
        for issue_or_pr in pag_list:
            num_objs += 1
            yield issue_or_pr
            # used to store the current cursor url in the checkpoint. This value
            # is updated during iteration over pag_list.
            next_url = _read_next_url(pag_list, source_list, next_url_attr)
            cursor_url_callback(next_url, num_objs)

            if num_objs % CURSOR_LOG_FREQUENCY == 0:
                logger.info(
                    f"Retrieved {num_objs} objects with current cursor url: {next_url}"
                )

    except Exception as e:
        logger.exception(f"Error during cursor-based pagination: {e}")
        # Progress was made in this call, so the cursor itself was usable:
        # surface the error and let the checkpoint resume from it later.
        if num_objs - prev_num_objs > 0:
            raise

        failed_cursor = _read_next_url(pag_list, source_list, next_url_attr)
        if failed_cursor is not None and not retrying:
            logger.info(
                "Assuming that this error is due to cursor "
                "expiration because no objects were retrieved. "
                "Retrying from the first page."
            )
            yield from _paginate_until_error(
                git_objs, None, prev_num_objs, cursor_url_callback, retrying=True
            )
            return

        # for no cursor url or if we reach this point after a retry, raise the error
        raise
|
||||
|
||||
|
||||
def _get_batch_rate_limited(
|
||||
git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
|
||||
) -> list[PullRequest | Issue]:
|
||||
# We pass in a callable because we want git_objs to produce a fresh
|
||||
# PaginatedList each time it's called to avoid using the same object for cursor-based pagination
|
||||
# from a partial offset-based pagination call.
|
||||
git_objs: Callable[[], PaginatedList],
|
||||
page_num: int,
|
||||
cursor_url: str | None,
|
||||
prev_num_objs: int,
|
||||
cursor_url_callback: Callable[[str | None, int], None],
|
||||
github_client: Github,
|
||||
attempt_num: int = 0,
|
||||
) -> Generator[PullRequest | Issue, None, None]:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
|
||||
)
|
||||
|
||||
try:
|
||||
objs = list(git_objs.get_page(page_num))
|
||||
if cursor_url:
|
||||
# when this is set, we are resuming from an earlier
|
||||
# cursor-based pagination call.
|
||||
yield from _paginate_until_error(
|
||||
git_objs, cursor_url, prev_num_objs, cursor_url_callback
|
||||
)
|
||||
return
|
||||
objs = list(git_objs().get_page(page_num))
|
||||
# fetch all data here to disable lazy loading later
|
||||
# this is needed to capture the rate limit exception here (if one occurs)
|
||||
for obj in objs:
|
||||
if hasattr(obj, "raw_data"):
|
||||
getattr(obj, "raw_data")
|
||||
return objs
|
||||
yield from objs
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return _get_batch_rate_limited(
|
||||
git_objs, page_num, github_client, attempt_num + 1
|
||||
yield from _get_batch_rate_limited(
|
||||
git_objs,
|
||||
page_num,
|
||||
cursor_url,
|
||||
prev_num_objs,
|
||||
cursor_url_callback,
|
||||
github_client,
|
||||
attempt_num + 1,
|
||||
)
|
||||
except GithubException as e:
|
||||
if not (
|
||||
e.status == 422
|
||||
and (
|
||||
"cursor" in (e.message or "")
|
||||
or "cursor" in (e.data or {}).get("message", "")
|
||||
)
|
||||
):
|
||||
raise
|
||||
# Fallback to a cursor-based pagination strategy
|
||||
# This can happen for "large datasets," but there's no documentation
|
||||
# On the error on the web as far as we can tell.
|
||||
# Error message:
|
||||
# "Pagination with the page parameter is not supported for large datasets,
|
||||
# please use cursor based pagination (after/before)"
|
||||
yield from _paginate_until_error(
|
||||
git_objs, cursor_url, prev_num_objs, cursor_url_callback
|
||||
)
|
||||
|
||||
|
||||
@ -142,6 +256,18 @@ class GithubConnectorCheckpoint(ConnectorCheckpoint):
|
||||
cached_repo_ids: list[int] | None = None
|
||||
cached_repo: SerializedRepository | None = None
|
||||
|
||||
# Used for the fallback cursor-based pagination strategy
|
||||
num_retrieved: int
|
||||
cursor_url: str | None = None
|
||||
|
||||
def reset(self) -> None:
    """
    Put the pagination-tracking fields back to their starting state:
    curr_page and num_retrieved become 0 and cursor_url becomes None.
    """
    for field_name, initial_value in (
        ("curr_page", 0),
        ("num_retrieved", 0),
        ("cursor_url", None),
    ):
        setattr(self, field_name, initial_value)
|
||||
|
||||
|
||||
class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
def __init__(
|
||||
@ -230,6 +356,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
try:
|
||||
org = github_client.get_organization(self.repo_owner)
|
||||
return list(org.get_repos())
|
||||
|
||||
except GithubException:
|
||||
# If not an org, try as a user
|
||||
user = github_client.get_user(self.repo_owner)
|
||||
@ -266,11 +393,12 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
|
||||
curr_repo = repos.pop()
|
||||
checkpoint.cached_repo_ids = [repo.id for repo in repos]
|
||||
checkpoint.cached_repo = SerializedRepository(
|
||||
id=checkpoint.cached_repo_ids[0],
|
||||
headers=repos[0].raw_headers,
|
||||
raw_data=repos[0].raw_data,
|
||||
id=curr_repo.id,
|
||||
headers=curr_repo.raw_headers,
|
||||
raw_data=curr_repo.raw_data,
|
||||
)
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.curr_page = 0
|
||||
@ -299,26 +427,41 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
repo_id = checkpoint.cached_repo.id
|
||||
repo = self.github_client.get_repo(repo_id)
|
||||
|
||||
def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None:
|
||||
checkpoint.cursor_url = cursor_url
|
||||
checkpoint.num_retrieved = num_objs
|
||||
|
||||
# TODO: all PRs are also issues, so we should be able to _only_ get issues
|
||||
# and then filter appropriately whenever include_issues is True
|
||||
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
def pull_requests_func() -> PaginatedList[PullRequest]:
|
||||
return repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
pr_batch = _get_batch_rate_limited(
|
||||
pull_requests, checkpoint.curr_page, self.github_client
|
||||
pull_requests_func,
|
||||
checkpoint.curr_page,
|
||||
checkpoint.cursor_url,
|
||||
checkpoint.num_retrieved,
|
||||
cursor_url_callback,
|
||||
self.github_client,
|
||||
)
|
||||
checkpoint.curr_page += 1
|
||||
checkpoint.curr_page += 1 # NOTE: not used for cursor-based fallback
|
||||
done_with_prs = False
|
||||
num_prs = 0
|
||||
pr = None
|
||||
for pr in pr_batch:
|
||||
num_prs += 1
|
||||
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
yield from doc_batch
|
||||
done_with_prs = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
@ -329,7 +472,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
):
|
||||
continue
|
||||
try:
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
yield _convert_pr_to_document(cast(PullRequest, pr))
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting PR to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
@ -342,37 +485,57 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
)
|
||||
continue
|
||||
|
||||
# if we found any PRs on the page, yield any associated documents and return the checkpoint
|
||||
if not done_with_prs and len(pr_batch) > 0:
|
||||
yield from doc_batch
|
||||
# If we reach this point with a cursor url in the checkpoint, we were using
|
||||
# the fallback cursor-based pagination strategy. That strategy tries to get all
|
||||
# PRs, so having cursor_url set means we are done with prs. However, we need to
|
||||
# return AFTER the checkpoint reset to avoid infinite loops.
|
||||
|
||||
# if we found any PRs on the page and there are more PRs to get, return the checkpoint.
|
||||
# In offset mode, while indexing without time constraints, the pr batch
|
||||
# will be empty when we're done.
|
||||
if num_prs > 0 and not done_with_prs and not checkpoint.cursor_url:
|
||||
return checkpoint
|
||||
|
||||
# if we went past the start date during the loop or there are no more
|
||||
# prs to get, we move on to issues
|
||||
checkpoint.stage = GithubConnectorStage.ISSUES
|
||||
checkpoint.curr_page = 0
|
||||
checkpoint.reset()
|
||||
|
||||
if checkpoint.cursor_url:
|
||||
# save the checkpoint after changing stage; next run will continue from issues
|
||||
return checkpoint
|
||||
|
||||
checkpoint.stage = GithubConnectorStage.ISSUES
|
||||
|
||||
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
doc_batch = []
|
||||
issue_batch = _get_batch_rate_limited(
|
||||
issues, checkpoint.curr_page, self.github_client
|
||||
def issues_func() -> PaginatedList[Issue]:
|
||||
return repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
issue_batch = list(
|
||||
_get_batch_rate_limited(
|
||||
issues_func,
|
||||
checkpoint.curr_page,
|
||||
checkpoint.cursor_url,
|
||||
checkpoint.num_retrieved,
|
||||
cursor_url_callback,
|
||||
self.github_client,
|
||||
)
|
||||
)
|
||||
checkpoint.curr_page += 1
|
||||
done_with_issues = False
|
||||
for issue in cast(list[Issue], issue_batch):
|
||||
num_issues = 0
|
||||
for issue in issue_batch:
|
||||
num_issues += 1
|
||||
issue = cast(Issue, issue)
|
||||
# we iterate backwards in time, so at this point we stop processing issues
|
||||
if (
|
||||
start is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
yield from doc_batch
|
||||
done_with_issues = True
|
||||
break
|
||||
# Skip issues updated after the end date
|
||||
@ -384,10 +547,11 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
# TODO: but they shouldn't always be
|
||||
continue
|
||||
|
||||
try:
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield _convert_issue_to_document(issue)
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting issue to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
@ -401,17 +565,17 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
)
|
||||
continue
|
||||
|
||||
# if we found any issues on the page, yield them and return the checkpoint
|
||||
if not done_with_issues and len(issue_batch) > 0:
|
||||
yield from doc_batch
|
||||
# if we found any issues on the page, and we're not done, return the checkpoint.
|
||||
# don't return if we're using cursor-based pagination to avoid infinite loops
|
||||
if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url:
|
||||
return checkpoint
|
||||
|
||||
# if we went past the start date during the loop or there are no more
|
||||
# issues to get, we move on to the next repo
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.curr_page = 0
|
||||
checkpoint.reset()
|
||||
|
||||
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
|
||||
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0
|
||||
if checkpoint.cached_repo_ids:
|
||||
next_id = checkpoint.cached_repo_ids.pop()
|
||||
next_repo = self.github_client.get_repo(next_id)
|
||||
@ -553,7 +717,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
|
||||
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint(
|
||||
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
|
||||
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0
|
||||
)
|
||||
|
||||
|
||||
|
@ -9,8 +9,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from github import Github
|
||||
from github import GithubException
|
||||
from github import RateLimitExceededException
|
||||
from github.GithubException import GithubException
|
||||
from github.Issue import Issue
|
||||
from github.PullRequest import PullRequest
|
||||
from github.RateLimit import RateLimit
|
||||
@ -24,6 +24,9 @@ from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.github.connector import SerializedRepository
|
||||
from onyx.connectors.models import Document
|
||||
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
|
||||
from tests.unit.onyx.connectors.utils import (
|
||||
load_everything_from_checkpoint_connector_from_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -439,3 +442,190 @@ def test_validate_connector_settings_success(
|
||||
github_connector.github_client.get_repo.assert_called_once_with(
|
||||
f"{github_connector.repo_owner}/{github_connector.repositories}"
|
||||
)
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_cursor_fallback(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_pr: Callable[..., MagicMock],
) -> None:
    """Offset pagination is rejected with a 422, so the PRs must arrive via the cursor fallback"""
    # Point the connector at a mocked client that always resolves to one repo
    repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = repo

    # The two PRs the cursor-based iteration is expected to produce
    prs = [create_mock_pr(number=n, title=f"PR {n}") for n in (1, 2)]

    # First list handed out by get_pulls: page access fails once with the 422
    offset_list = MagicMock()
    page_param_rejected = GithubException(
        422,
        {
            "message": "Pagination with the page parameter is not supported for large datasets. Use cursor"
        },
        {},
    )
    offset_list.get_page.side_effect = [page_param_rejected]

    # Second list handed out by get_pulls: iterated directly by the fallback
    cursor_list = MagicMock()
    cursor_list.__nextUrl = (
        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=abc123"
    )
    cursor_list.__iter__.return_value = iter(prs)

    repo.get_pulls.side_effect = [offset_list, cursor_list]

    # SerializedRepository.to_Repository must hand back the mocked repo
    with patch.object(SerializedRepository, "to_Repository", return_value=repo):
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(
            github_connector, 0, end_time
        )

        # Both PRs came back as documents in the second output
        assert len(outputs) >= 2
        pr_output = outputs[1]
        assert len(pr_output.items) == 2
        for item, pr_number in zip(pr_output.items, (1, 2)):
            assert isinstance(item, Document)
            assert item.id == f"https://github.com/test-org/test-repo/pull/{pr_number}"

        # The cursor state is back at its initial values in the returned checkpoint
        assert pr_output.next_checkpoint.cursor_url is None
        assert pr_output.next_checkpoint.num_retrieved == 0
|
||||
|
||||
|
||||
def test_load_from_checkpoint_resume_cursor_pagination(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_pr: Callable[..., MagicMock],
) -> None:
    """A checkpoint that already holds a cursor url continues with cursor-based pagination"""
    # Point the connector at a mocked client that always resolves to one repo
    repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = repo

    # PRs that are still outstanding after the checkpointed progress
    remaining_prs = [create_mock_pr(number=n, title=f"PR {n}") for n in (3, 4)]

    # Checkpoint left behind by an earlier cursor-based run (2 objects retrieved)
    checkpoint = github_connector.build_dummy_checkpoint()
    checkpoint.cursor_url = (
        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=abc123"
    )
    checkpoint.num_retrieved = 2

    # Every get_pulls call returns the same list, which yields the remaining PRs
    cursor_list = MagicMock()
    cursor_list.__nextUrl = (
        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=def456"
    )
    cursor_list.__iter__.return_value = iter(remaining_prs)
    repo.get_pulls.return_value = cursor_list

    # SerializedRepository.to_Repository must hand back the mocked repo
    with patch.object(SerializedRepository, "to_Repository", return_value=repo):
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector_from_checkpoint(
            github_connector, 0, end_time, checkpoint
        )

        # Both remaining PRs came back as documents in the second output
        assert len(outputs) >= 2
        pr_output = outputs[1]
        assert len(pr_output.items) == 2
        for item, pr_number in zip(pr_output.items, (3, 4)):
            assert isinstance(item, Document)
            assert item.id == f"https://github.com/test-org/test-repo/pull/{pr_number}"

        # The cursor state is back at its initial values in the returned checkpoint
        assert pr_output.next_checkpoint.cursor_url is None
        assert pr_output.next_checkpoint.num_retrieved == 0
|
||||
|
||||
|
||||
def test_load_from_checkpoint_cursor_expiration(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_pr: Callable[..., MagicMock],
) -> None:
    """An expired cursor triggers one retry from the first page that skips already-retrieved objects"""
    # Point the connector at a mocked client that always resolves to one repo
    repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = repo

    # The only PR left to retrieve once the first three are skipped
    last_pr = create_mock_pr(number=4, title="PR 4")

    # Checkpoint whose stored cursor is no longer valid; 3 objects already retrieved
    expired_cursor = (
        "https://api.github.com/repos/test-org/test-repo/pulls?cursor=expired"
    )
    checkpoint = github_connector.build_dummy_checkpoint()
    checkpoint.cursor_url = expired_cursor
    checkpoint.num_retrieved = 3

    # First list from get_pulls: iterating it fails before producing anything
    expired_list = MagicMock()
    expired_list.__nextUrl = expired_cursor
    expired_list.__iter__.side_effect = GithubException(
        422, {"message": "Cursor expired"}, {}
    )

    # Second list from get_pulls: used by the retry, which slices off the
    # already-retrieved prefix and iterates whatever the slice yields
    sliced_view = MagicMock()
    sliced_view.__iter__.return_value = iter([last_pr])

    retry_list = MagicMock()
    retry_list.__nextUrl = None
    retry_list.__getitem__.return_value = sliced_view

    repo.get_pulls.side_effect = [expired_list, retry_list]

    # SerializedRepository.to_Repository must hand back the mocked repo
    with patch.object(SerializedRepository, "to_Repository", return_value=repo):
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector_from_checkpoint(
            github_connector, 0, end_time, checkpoint
        )

        # Only the remaining PR is returned after the retry
        assert len(outputs) >= 2
        pr_output = outputs[1]
        assert len(pr_output.items) == 1
        only_item = pr_output.items[0]
        assert isinstance(only_item, Document)
        assert only_item.id == "https://github.com/test-org/test-repo/pull/4"

        # The cursor state is back at its initial values in the returned checkpoint
        assert pr_output.next_checkpoint.cursor_url is None
        assert pr_output.next_checkpoint.num_retrieved == 0

        # The retry skipped exactly the three objects retrieved earlier
        retry_list.__getitem__.assert_called_once_with(slice(3, None))
|
||||
|
Loading…
x
Reference in New Issue
Block a user