From fb79a9e700c5d7ab48dccd21b34af920b2e29da6 Mon Sep 17 00:00:00 2001
From: evan-danswer
Date: Fri, 21 Mar 2025 18:48:05 -0700
Subject: [PATCH] Checkpointed GitHub connector (#4307)

* WIP github checkpointing

* first draft of github checkpointing

* nit

* CW comments

* github basic connector test

* connector test env var

* secrets cant start with GITHUB_

* unit tests and bug fix

* connector failures

* address CW comments

* validation fix

* validation fix

* remove prints

* fixed tests

* 100 items per page
---
 .github/workflows/pr-python-connector-tests.yml  |   2 +
 backend/onyx/background/indexing/run_indexing.py |   2 +-
 backend/onyx/configs/app_configs.py              |   5 +-
 backend/onyx/connectors/github/connector.py      | 320 +++++++++----
 backend/requirements/default.txt                 |   2 +-
 .../connectors/github/test_github_basic.py       |  54 +++
 .../github/test_github_checkpointing.py          | 441 ++++++++++++++++++
 7 files changed, 730 insertions(+), 96 deletions(-)
 create mode 100644 backend/tests/daily/connectors/github/test_github_basic.py
 create mode 100644 backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py

diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml
index aa740aa8d9a..8c419b21c07 100644
--- a/.github/workflows/pr-python-connector-tests.yml
+++ b/.github/workflows/pr-python-connector-tests.yml
@@ -45,6 +45,8 @@ env:
   SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
   SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
   SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
+  # Github
+  ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
   # Gitbook
   GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
   GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py
index 7802e3d20c9..815ca811544 100644
--- a/backend/onyx/background/indexing/run_indexing.py
+++ b/backend/onyx/background/indexing/run_indexing.py
@@ -435,7 +435,7 @@ def _run_indexing(
 
         while checkpoint.has_more:
             logger.info(
-                f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
+                f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
             )
             for document_batch, failure, next_checkpoint in connector_runner.run(
                 checkpoint
diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py
index 14da1066448..7594f4996d9 100644
--- a/backend/onyx/configs/app_configs.py
+++ b/backend/onyx/configs/app_configs.py
@@ -157,10 +157,7 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
 VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
 
 # Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
-try:
-    INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
-except ValueError:
-    INDEX_BATCH_SIZE = 16
+INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)
 
 MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
diff --git a/backend/onyx/connectors/github/connector.py b/backend/onyx/connectors/github/connector.py
index 06461351769..757d3d04a0c 100644
--- a/backend/onyx/connectors/github/connector.py
+++ b/backend/onyx/connectors/github/connector.py
@@ -1,8 +1,10 @@
+import copy
 import time
-from collections.abc import Iterator
+from collections.abc import Generator
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
+from enum import Enum
 from typing import Any
 from typing import cast
 
@@ -13,26 +15,30 @@ from github.GithubException import GithubException
 from github.Issue import Issue
 from github.PaginatedList import PaginatedList
 from github.PullRequest import PullRequest
+from github.Requester import Requester
+from pydantic import BaseModel
+from typing_extensions import override
 
 from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
-from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.constants import DocumentSource
 from onyx.connectors.exceptions import ConnectorValidationError
 from onyx.connectors.exceptions import CredentialExpiredError
 from onyx.connectors.exceptions import InsufficientPermissionsError
 from onyx.connectors.exceptions import UnexpectedValidationError
-from onyx.connectors.interfaces import GenerateDocumentsOutput
-from onyx.connectors.interfaces import LoadConnector
-from onyx.connectors.interfaces import PollConnector
+from onyx.connectors.interfaces import CheckpointConnector
+from onyx.connectors.interfaces import CheckpointOutput
+from onyx.connectors.interfaces import ConnectorCheckpoint
+from onyx.connectors.interfaces import ConnectorFailure
 from onyx.connectors.interfaces import SecondsSinceUnixEpoch
 from onyx.connectors.models import ConnectorMissingCredentialError
 from onyx.connectors.models import Document
+from onyx.connectors.models import DocumentFailure
 from onyx.connectors.models import TextSection
-from onyx.utils.batching import batch_generator
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
 
+ITEMS_PER_PAGE = 100
+
 _MAX_NUM_RATE_LIMIT_RETRIES = 5
 
@@ -48,7 +54,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
 
 def _get_batch_rate_limited(
     git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
-) -> list[Any]:
+) -> list[PullRequest | Issue]:
     if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
         raise RuntimeError(
             "Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
@@ -69,21 +75,6 @@ def _get_batch_rate_limited(
     )
 
 
-def _batch_github_objects(
-    git_objs: PaginatedList, github_client: Github, batch_size: int
-) -> Iterator[list[Any]]:
-    page_num = 0
-    while True:
-        batch = _get_batch_rate_limited(git_objs, page_num, github_client)
-        page_num += 1
-
-        if not batch:
-            break
-
-        for mini_batch in batch_generator(batch, batch_size=batch_size):
-            yield mini_batch
-
-
 def _convert_pr_to_document(pull_request: PullRequest) -> Document:
     return Document(
         id=pull_request.html_url,
@@ -95,7 +86,9 @@
         # updated_at is UTC time but is timezone unaware, explicitly add UTC
         # as there is logic in indexing to prevent wrong timestamped docs
         # due to local time discrepancies with UTC
-        doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
+        doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
+        if pull_request.updated_at
+        else None,
         metadata={
             "merged": str(pull_request.merged),
             "state": pull_request.state,
@@ -122,31 +115,58 @@ def _convert_issue_to_document(issue: Issue) -> Document:
     )
 
 
-class GithubConnector(LoadConnector, PollConnector):
+class SerializedRepository(BaseModel):
+    # id is part of the raw_data as well, just pulled out for convenience
+    id: int
+    headers: dict[str, str | int]
+    raw_data: dict[str, Any]
+
+    def to_Repository(self, requester: Requester) -> Repository.Repository:
+        return Repository.Repository(
+            requester, self.headers, self.raw_data, completed=True
+        )
+
+
+class GithubConnectorStage(Enum):
+    START = "start"
+    PRS = "prs"
+    ISSUES = "issues"
+
+
+class GithubConnectorCheckpoint(ConnectorCheckpoint):
+    stage: GithubConnectorStage
+    curr_page: int
+
+    cached_repo_ids: list[int] | None = None
+    cached_repo: SerializedRepository | None = None
+
+
+class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
     def __init__(
         self,
         repo_owner: str,
         repositories: str | None = None,
-        batch_size: int = INDEX_BATCH_SIZE,
         state_filter: str = "all",
         include_prs: bool = True,
        include_issues: bool = False,
     ) -> None:
         self.repo_owner = repo_owner
         self.repositories = repositories
-        self.batch_size = batch_size
         self.state_filter = state_filter
         self.include_prs = include_prs
         self.include_issues = include_issues
         self.github_client: Github | None = None
 
     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
+        # defaults to 30 items per page, can be set to as high as 100
         self.github_client = (
             Github(
-                credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
+                credentials["github_access_token"],
+                base_url=GITHUB_CONNECTOR_BASE_URL,
+                per_page=ITEMS_PER_PAGE,
             )
             if GITHUB_CONNECTOR_BASE_URL
-            else Github(credentials["github_access_token"])
+            else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
         )
         return None
 
@@ -217,85 +237,193 @@ class GithubConnector(LoadConnector, PollConnector):
         return self._get_all_repos(github_client, attempt_num + 1)
 
     def _fetch_from_github(
-        self, start: datetime | None = None, end: datetime | None = None
-    ) -> GenerateDocumentsOutput:
+        self,
+        checkpoint: GithubConnectorCheckpoint,
+        start: datetime | None = None,
+        end: datetime | None = None,
+    ) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
         if self.github_client is None:
             raise ConnectorMissingCredentialError("GitHub")
 
-        repos = []
-        if self.repositories:
-            if "," in self.repositories:
-                # Multiple repositories specified
-                repos = self._get_github_repos(self.github_client)
+        checkpoint = copy.deepcopy(checkpoint)
+
+        # First run of the connector, fetch all repos and store in checkpoint
+        if checkpoint.cached_repo_ids is None:
+            repos = []
+            if self.repositories:
+                if "," in self.repositories:
+                    # Multiple repositories specified
+                    repos = self._get_github_repos(self.github_client)
+                else:
+                    # Single repository (backward compatibility)
+                    repos = [self._get_github_repo(self.github_client)]
             else:
-                # Single repository (backward compatibility)
-                repos = [self._get_github_repo(self.github_client)]
-        else:
-            # All repositories
-            repos = self._get_all_repos(self.github_client)
+                # All repositories
+                repos = self._get_all_repos(self.github_client)
+            if not repos:
+                checkpoint.has_more = False
+                return checkpoint
 
-        for repo in repos:
-            if self.include_prs:
-                logger.info(f"Fetching PRs for repo: {repo.name}")
-                pull_requests = repo.get_pulls(
-                    state=self.state_filter, sort="updated", direction="desc"
-                )
+            checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
+            checkpoint.cached_repo = SerializedRepository(
+                id=checkpoint.cached_repo_ids[0],
+                headers=repos[0].raw_headers,
+                raw_data=repos[0].raw_data,
+            )
+            checkpoint.stage = GithubConnectorStage.PRS
+            checkpoint.curr_page = 0
+            # save checkpoint with repo ids retrieved
+            return checkpoint
 
-                for pr_batch in _batch_github_objects(
-                    pull_requests, self.github_client, self.batch_size
+        assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
+        repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
+
+        if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
+            logger.info(f"Fetching PRs for repo: {repo.name}")
+            pull_requests = repo.get_pulls(
+                state=self.state_filter, sort="updated", direction="desc"
+            )
+
+            doc_batch: list[Document] = []
+            pr_batch = _get_batch_rate_limited(
+                pull_requests, checkpoint.curr_page, self.github_client
+            )
+            checkpoint.curr_page += 1
+            done_with_prs = False
+            for pr in pr_batch:
+                # we iterate backwards in time, so at this point we stop processing prs
+                if (
+                    start is not None
+                    and pr.updated_at
+                    and pr.updated_at.replace(tzinfo=timezone.utc) < start
                 ):
-                    doc_batch: list[Document] = []
-                    for pr in pr_batch:
-                        if start is not None and pr.updated_at < start:
-                            yield doc_batch
-                            break
-                        if end is not None and pr.updated_at > end:
-                            continue
-                        doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
-                    yield doc_batch
-
-            if self.include_issues:
-                logger.info(f"Fetching issues for repo: {repo.name}")
-                issues = repo.get_issues(
-                    state=self.state_filter, sort="updated", direction="desc"
-                )
-
-                for issue_batch in _batch_github_objects(
-                    issues, self.github_client, self.batch_size
+                    yield from doc_batch
+                    done_with_prs = True
+                    break
+                # Skip PRs updated after the end date
+                if (
+                    end is not None
+                    and pr.updated_at
+                    and pr.updated_at.replace(tzinfo=timezone.utc) > end
                 ):
-                    doc_batch = []
-                    for issue in issue_batch:
-                        issue = cast(Issue, issue)
-                        if start is not None and issue.updated_at < start:
-                            yield doc_batch
-                            break
-                        if end is not None and issue.updated_at > end:
-                            continue
-                        if issue.pull_request is not None:
-                            # PRs are handled separately
-                            continue
-                        doc_batch.append(_convert_issue_to_document(issue))
-                    yield doc_batch
+                    continue
+                try:
+                    doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
+                except Exception as e:
+                    error_msg = f"Error converting PR to document: {e}"
+                    logger.exception(error_msg)
+                    yield ConnectorFailure(
+                        failed_document=DocumentFailure(
+                            document_id=str(pr.id), document_link=pr.html_url
+                        ),
+                        failure_message=error_msg,
+                        exception=e,
+                    )
+                    continue
 
-    def load_from_state(self) -> GenerateDocumentsOutput:
-        return self._fetch_from_github()
+            # if we found any PRs on the page, yield any associated documents and return the checkpoint
+            if not done_with_prs and len(pr_batch) > 0:
+                yield from doc_batch
+                return checkpoint
 
-    def poll_source(
-        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
-    ) -> GenerateDocumentsOutput:
-        start_datetime = datetime.utcfromtimestamp(start)
-        end_datetime = datetime.utcfromtimestamp(end)
+            # if we went past the start date during the loop or there are no more
+            # prs to get, we move on to issues
+            checkpoint.stage = GithubConnectorStage.ISSUES
+            checkpoint.curr_page = 0
+
+        checkpoint.stage = GithubConnectorStage.ISSUES
+
+        if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
+            logger.info(f"Fetching issues for repo: {repo.name}")
+            issues = repo.get_issues(
+                state=self.state_filter, sort="updated", direction="desc"
+            )
+
+            doc_batch = []
+            issue_batch = _get_batch_rate_limited(
+                issues, checkpoint.curr_page, self.github_client
+            )
+            checkpoint.curr_page += 1
+            done_with_issues = False
+            for issue in cast(list[Issue], issue_batch):
+                # we iterate backwards in time, so at this point we stop processing issues
+                if (
+                    start is not None
+                    and issue.updated_at.replace(tzinfo=timezone.utc) < start
+                ):
+                    yield from doc_batch
+                    done_with_issues = True
+                    break
+                # Skip issues updated after the end date
+                if (
+                    end is not None
+                    and issue.updated_at.replace(tzinfo=timezone.utc) > end
+                ):
+                    continue
+
+                if issue.pull_request is not None:
+                    # PRs are handled separately
+                    continue
+
+                try:
+                    doc_batch.append(_convert_issue_to_document(issue))
+                except Exception as e:
+                    error_msg = f"Error converting issue to document: {e}"
+                    logger.exception(error_msg)
+                    yield ConnectorFailure(
+                        failed_document=DocumentFailure(
+                            document_id=str(issue.id),
+                            document_link=issue.html_url,
+                        ),
+                        failure_message=error_msg,
+                        exception=e,
+                    )
+                    continue
+
+            # if we found any issues on the page, yield them and return the checkpoint
+            if not done_with_issues and len(issue_batch) > 0:
+                yield from doc_batch
+                return checkpoint
+
+            # if we went past the start date during the loop or there are no more
+            # issues to get, we move on to the next repo
+            checkpoint.stage = GithubConnectorStage.PRS
+            checkpoint.curr_page = 0
+
+        checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
+        if checkpoint.cached_repo_ids:
+            next_id = checkpoint.cached_repo_ids.pop()
+            next_repo = self.github_client.get_repo(next_id)
+            checkpoint.cached_repo = SerializedRepository(
+                id=next_id,
+                headers=next_repo.raw_headers,
+                raw_data=next_repo.raw_data,
+            )
+
+        return checkpoint
+
+    @override
+    def load_from_checkpoint(
+        self,
+        start: SecondsSinceUnixEpoch,
+        end: SecondsSinceUnixEpoch,
+        checkpoint: GithubConnectorCheckpoint,
+    ) -> CheckpointOutput[GithubConnectorCheckpoint]:
+        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
+        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
 
         # Move start time back by 3 hours, since some Issues/PRs are getting dropped
         # Could be due to delayed processing on GitHub side
         # The non-updated issues since last poll will be shortcut-ed and not embedded
         adjusted_start_datetime = start_datetime - timedelta(hours=3)
 
-        epoch = datetime.utcfromtimestamp(0)
+        epoch = datetime.fromtimestamp(0, tz=timezone.utc)
         if adjusted_start_datetime < epoch:
             adjusted_start_datetime = epoch
 
-        return self._fetch_from_github(adjusted_start_datetime, end_datetime)
+        return self._fetch_from_github(
+            checkpoint, start=adjusted_start_datetime, end=end_datetime
+        )
 
     def validate_connector_settings(self) -> None:
         if self.github_client is None:
@@ -397,6 +525,16 @@ class GithubConnector(LoadConnector, PollConnector):
             f"Unexpected error during GitHub settings validation: {exc}"
         )
 
+    def validate_checkpoint_json(
+        self, checkpoint_json: str
+    ) -> GithubConnectorCheckpoint:
+        return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
+
+    def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
+        return GithubConnectorCheckpoint(
+            stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
+        )
+
 
 if __name__ == "__main__":
     import os
@@ -406,7 +544,9 @@ if __name__ == "__main__":
         repositories=os.environ["REPOSITORIES"],
     )
     connector.load_credentials(
-        {"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
+        {"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
+    )
+    document_batches = connector.load_from_checkpoint(
+        0, time.time(), connector.build_dummy_checkpoint()
     )
-    document_batches = connector.load_from_state()
     print(next(document_batches))
diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index b014d340f7a..98cc9c33f30 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -56,7 +56,7 @@ puremagic==1.28
 pyairtable==3.0.1
 pycryptodome==3.19.1
 pydantic==2.8.2
-PyGithub==1.58.2
+PyGithub==2.5.0
 python-dateutil==2.8.2
 python-gitlab==3.9.0
 python-pptx==0.6.23
diff --git a/backend/tests/daily/connectors/github/test_github_basic.py b/backend/tests/daily/connectors/github/test_github_basic.py
new file mode 100644
index 00000000000..235352cc5ef
--- /dev/null
+++ b/backend/tests/daily/connectors/github/test_github_basic.py
@@ -0,0 +1,54 @@
+import os
+import time
+
+import pytest
+
+from onyx.configs.constants import DocumentSource
+from onyx.connectors.github.connector import GithubConnector
+from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
+
+
+@pytest.fixture
+def github_connector() -> GithubConnector:
+    connector = GithubConnector(
+        repo_owner="onyx-dot-app",
+        repositories="documentation",
+        include_prs=True,
+        include_issues=True,
+    )
+    connector.load_credentials(
+        {
+            "github_access_token": os.environ["ACCESS_TOKEN_GITHUB"],
+        }
+    )
+    return connector
+
+
+def test_github_connector_basic(github_connector: GithubConnector) -> None:
+    docs = load_all_docs_from_checkpoint_connector(
+        connector=github_connector,
+        start=0,
+        end=time.time(),
+    )
+    assert len(docs) > 0  # We expect at least one PR to exist
+
+    # Test the first document's structure
+    doc = docs[0]
+
+    # Verify basic document properties
+    assert doc.source == DocumentSource.GITHUB
+    assert doc.secondary_owners is None
+    assert doc.from_ingestion_api is False
+    assert doc.additional_info is None
+
+    # Verify GitHub-specific properties
+    assert "github.com" in doc.id  # Should be a GitHub URL
+    assert doc.metadata is not None
+    assert "state" in doc.metadata
+    assert "merged" in doc.metadata
+
+    # Verify sections
+    assert len(doc.sections) == 1
+    section = doc.sections[0]
+    assert section.link == doc.id  # Section link should match document ID
+    assert isinstance(section.text, str)  # Should have some text content
diff --git a/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py b/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
new file mode 100644
index 00000000000..5c797a5dc33
--- /dev/null
+++ b/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
@@ -0,0 +1,441 @@
+import time
+from collections.abc import Callable
+from collections.abc import Generator
+from datetime import datetime
+from datetime import timezone
+from typing import cast
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+from github import Github
+from github import GithubException
+from github import RateLimitExceededException
+from github.Issue import Issue
+from github.PullRequest import PullRequest
+from github.RateLimit import RateLimit
+from github.Repository import Repository
+from github.Requester import Requester
+
+from onyx.connectors.exceptions import ConnectorValidationError
+from onyx.connectors.exceptions import CredentialExpiredError
+from onyx.connectors.exceptions import InsufficientPermissionsError
+from onyx.connectors.github.connector import GithubConnector
+from onyx.connectors.github.connector import SerializedRepository
+from onyx.connectors.models import Document
+from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
+
+
+@pytest.fixture
+def repo_owner() -> str:
+    return "test-org"
+
+
+@pytest.fixture
+def repositories() -> str:
+    return "test-repo"
+
+
+@pytest.fixture
+def mock_github_client() -> MagicMock:
+    """Create a mock GitHub client with proper typing"""
+    mock = MagicMock(spec=Github)
+    # Add proper return typing for get_repo method
+    mock.get_repo = MagicMock(return_value=MagicMock(spec=Repository))
+    # Add proper return typing for get_organization method
+    mock.get_organization = MagicMock()
+    # Add proper return typing for get_user method
+    mock.get_user = MagicMock()
+    # Add proper return typing for get_rate_limit method
+    mock.get_rate_limit = MagicMock(return_value=MagicMock(spec=RateLimit))
+    # Add requester for repository deserialization
+    mock.requester = MagicMock(spec=Requester)
+    return mock
+
+
+@pytest.fixture
+def github_connector(
+    repo_owner: str, repositories: str, mock_github_client: MagicMock
+) -> Generator[GithubConnector, None, None]:
+    connector = GithubConnector(
+        repo_owner=repo_owner,
+        repositories=repositories,
+        include_prs=True,
+        include_issues=True,
+    )
+    connector.github_client = mock_github_client
+    yield connector
+
+
+@pytest.fixture
+def create_mock_pr() -> Callable[..., MagicMock]:
+    def _create_mock_pr(
+        number: int = 1,
+        title: str = "Test PR",
+        body: str = "Test Description",
+        state: str = "open",
+        merged: bool = False,
+        updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
+    ) -> MagicMock:
+        """Helper to create a mock PullRequest object"""
+        mock_pr = MagicMock(spec=PullRequest)
+        mock_pr.number = number
+        mock_pr.title = title
+        mock_pr.body = body
+        mock_pr.state = state
+        mock_pr.merged = merged
+        mock_pr.updated_at = updated_at
+        mock_pr.html_url = f"https://github.com/test-org/test-repo/pull/{number}"
+        return mock_pr
+
+    return _create_mock_pr
+
+
+@pytest.fixture
+def create_mock_issue() -> Callable[..., MagicMock]:
+    def _create_mock_issue(
+        number: int = 1,
+        title: str = "Test Issue",
+        body: str = "Test Description",
+        state: str = "open",
+        updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
+    ) -> MagicMock:
+        """Helper to create a mock Issue object"""
+        mock_issue = MagicMock(spec=Issue)
+        mock_issue.number = number
+        mock_issue.title = title
+        mock_issue.body = body
+        mock_issue.state = state
+        mock_issue.updated_at = updated_at
+        mock_issue.html_url = f"https://github.com/test-org/test-repo/issues/{number}"
+        mock_issue.pull_request = None  # Not a PR
+        return mock_issue
+
+    return _create_mock_issue
+
+
+@pytest.fixture
+def create_mock_repo() -> Callable[..., MagicMock]:
+    def _create_mock_repo(
+        name: str = "test-repo",
+        id: int = 1,
+    ) -> MagicMock:
+        """Helper to create a mock Repository object"""
+        mock_repo = MagicMock(spec=Repository)
+        mock_repo.name = name
+        mock_repo.id = id
+        mock_repo.raw_headers = {"status": "200 OK", "content-type": "application/json"}
+        mock_repo.raw_data = {
+            "id": str(id),
+            "name": name,
+            "full_name": f"test-org/{name}",
+            "private": str(False),
+            "description": "Test repository",
+        }
+        return mock_repo
+
+    return _create_mock_repo
+
+
+def test_load_from_checkpoint_happy_path(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+    create_mock_issue: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint - happy path"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PRs and issues
+    mock_pr1 = create_mock_pr(number=1, title="PR 1")
+    mock_pr2 = create_mock_pr(number=2, title="PR 2")
+    mock_issue1 = create_mock_issue(number=1, title="Issue 1")
+    mock_issue2 = create_mock_issue(number=2, title="Issue 2")
+
+    # Mock get_pulls and get_issues methods
+    mock_repo.get_pulls.return_value = MagicMock()
+    mock_repo.get_pulls.return_value.get_page.side_effect = [
+        [mock_pr1, mock_pr2],
+        [],
+    ]
+    mock_repo.get_issues.return_value = MagicMock()
+    mock_repo.get_issues.return_value.get_page.side_effect = [
+        [mock_issue1, mock_issue2],
+        [],
+    ]
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(
+            github_connector, 0, end_time
+        )
+
+        # Check that we got all documents and final has_more=False
+        assert len(outputs) == 4
+
+        repo_batch = outputs[0]
+        assert len(repo_batch.items) == 0
+        assert repo_batch.next_checkpoint.has_more is True
+
+        # Check first batch (PRs)
+        first_batch = outputs[1]
+        assert len(first_batch.items) == 2
+        assert isinstance(first_batch.items[0], Document)
+        assert first_batch.items[0].id == "https://github.com/test-org/test-repo/pull/1"
+        assert isinstance(first_batch.items[1], Document)
+        assert first_batch.items[1].id == "https://github.com/test-org/test-repo/pull/2"
+        assert first_batch.next_checkpoint.curr_page == 1
+
+        # Check second batch (Issues)
+        second_batch = outputs[2]
+        assert len(second_batch.items) == 2
+        assert isinstance(second_batch.items[0], Document)
+        assert (
+            second_batch.items[0].id == "https://github.com/test-org/test-repo/issues/1"
+        )
+        assert isinstance(second_batch.items[1], Document)
+        assert (
+            second_batch.items[1].id == "https://github.com/test-org/test-repo/issues/2"
+        )
+        assert second_batch.next_checkpoint.has_more
+
+        # Check third batch (finished checkpoint)
+        third_batch = outputs[3]
+        assert len(third_batch.items) == 0
+        assert third_batch.next_checkpoint.has_more is False
+
+
+def test_load_from_checkpoint_with_rate_limit(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint with rate limit handling"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PR
+    mock_pr = create_mock_pr()
+
+    # Mock get_pulls to raise RateLimitExceededException on first call
+    mock_repo.get_pulls.return_value = MagicMock()
+    mock_repo.get_pulls.return_value.get_page.side_effect = [
+        RateLimitExceededException(403, {"message": "Rate limit exceeded"}, {}),
+        [mock_pr],
+        [],
+    ]
+
+    # Mock rate limit reset time
+    mock_rate_limit = MagicMock(spec=RateLimit)
+    mock_rate_limit.core.reset = datetime.now(timezone.utc)
+    github_connector.github_client.get_rate_limit.return_value = mock_rate_limit
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        with patch(
+            "onyx.connectors.github.connector._sleep_after_rate_limit_exception"
+        ) as mock_sleep:
+            outputs = load_everything_from_checkpoint_connector(
+                github_connector, 0, end_time
+            )
+
+        assert mock_sleep.call_count == 1
+
+        # Check that we got the document after rate limit was handled
+        assert len(outputs) >= 2
+        assert len(outputs[1].items) == 1
+        assert isinstance(outputs[1].items[0], Document)
+        assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"
+
+        assert outputs[-1].next_checkpoint.has_more is False
+
+
+def test_load_from_checkpoint_with_empty_repo(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint with an empty repository"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Mock get_pulls and get_issues to return empty lists
+    mock_repo.get_pulls.return_value = MagicMock()
+    mock_repo.get_pulls.return_value.get_page.return_value = []
+    mock_repo.get_issues.return_value = MagicMock()
+    mock_repo.get_issues.return_value.get_page.return_value = []
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(
+            github_connector, 0, end_time
+        )
+
+        # Check that we got no documents
+        assert len(outputs) == 2
+        assert len(outputs[-1].items) == 0
+        assert not outputs[-1].next_checkpoint.has_more
+
+
+def test_load_from_checkpoint_with_prs_only(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint with only PRs enabled"""
+    # Configure connector to only include PRs
+    github_connector.include_prs = True
+    github_connector.include_issues = False
+
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PRs
+    mock_pr1 = create_mock_pr(number=1, title="PR 1")
+    mock_pr2 = create_mock_pr(number=2, title="PR 2")
+
+    # Mock get_pulls method
+    mock_repo.get_pulls.return_value = MagicMock()
+    mock_repo.get_pulls.return_value.get_page.side_effect = [
+        [mock_pr1, mock_pr2],
+        [],
+    ]
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(
+            github_connector, 0, end_time
+        )
+
+        # Check that we only got PRs
+        assert len(outputs) >= 2
+        assert len(outputs[1].items) == 2
+        assert all(
+            isinstance(doc, Document) and "pull" in doc.id for doc in outputs[1].items
+        )  # All documents should be PRs
+
+        assert outputs[-1].next_checkpoint.has_more is False
+
+
+def test_load_from_checkpoint_with_issues_only(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_issue: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint with only issues enabled"""
+    # Configure connector to only include issues
+    github_connector.include_prs = False
+    github_connector.include_issues = True
+
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked issues
+    mock_issue1 = create_mock_issue(number=1, title="Issue 1")
+    mock_issue2 = create_mock_issue(number=2, title="Issue 2")
+
+    # Mock get_issues method
+    mock_repo.get_issues.return_value = MagicMock()
+    mock_repo.get_issues.return_value.get_page.side_effect = [
+        [mock_issue1, mock_issue2],
+        [],
+    ]
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(
+            github_connector, 0, end_time
+        )
+
+        # Check that we only got issues
+        assert len(outputs) >= 2
+        assert len(outputs[1].items) == 2
+        assert all(
+            isinstance(doc, Document) and "issues" in doc.id for doc in outputs[1].items
+        )  # All documents should be issues
+        assert outputs[1].next_checkpoint.has_more
+
+        assert outputs[-1].next_checkpoint.has_more is False
+
+
+@pytest.mark.parametrize(
+    "status_code,expected_exception,expected_message",
+    [
+        (
+            401,
+            CredentialExpiredError,
+            "GitHub credential appears to be invalid or expired",
+        ),
+        (
+            403,
+            InsufficientPermissionsError,
+            "Your GitHub token does not have sufficient permissions",
+        ),
+        (
+            404,
+            ConnectorValidationError,
+            "GitHub repository not found",
+        ),
+    ],
+)
+def test_validate_connector_settings_errors(
+    github_connector: GithubConnector,
+    status_code: int,
+    expected_exception: type[Exception],
+    expected_message: str,
+) -> None:
+    """Test validation with various error scenarios"""
+    error = GithubException(status=status_code, data={}, headers={})
+
+    github_client = cast(Github, github_connector.github_client)
+    get_repo_mock = cast(MagicMock, github_client.get_repo)
+    get_repo_mock.side_effect = error
+
+    with pytest.raises(expected_exception) as excinfo:
+        github_connector.validate_connector_settings()
+    assert expected_message in str(excinfo.value)
+
+
+def test_validate_connector_settings_success(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+) -> None:
+    """Test successful validation"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Mock get_contents to simulate successful access
+    mock_repo.get_contents.return_value = MagicMock()
+
+    github_connector.validate_connector_settings()
+    github_connector.github_client.get_repo.assert_called_once_with(
+        f"{github_connector.repo_owner}/{github_connector.repositories}"
+    )
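
Usage sketch (illustrative, not part of the patch): driving the new checkpoint protocol by hand mirrors what _run_indexing does -- exhaust the generator returned by load_from_checkpoint, take the next checkpoint from its return value, and loop while has_more is set. A minimal sketch, assuming a GithubConnector with credentials already loaded; run_to_completion is a hypothetical helper, not part of the connector API:

    import time

    from onyx.connectors.github.connector import GithubConnector
    from onyx.connectors.models import Document


    def run_to_completion(connector: GithubConnector) -> list[Document]:
        # hypothetical helper; start from the connector-provided dummy checkpoint
        checkpoint = connector.build_dummy_checkpoint()
        docs: list[Document] = []
        while checkpoint.has_more:
            gen = connector.load_from_checkpoint(0, time.time(), checkpoint)
            # exhaust the generator; it returns the next checkpoint via StopIteration
            while True:
                try:
                    item = next(gen)
                    if isinstance(item, Document):
                        docs.append(item)
                    # non-Document items are ConnectorFailure records, skipped here
                except StopIteration as e:
                    checkpoint = e.value
                    break
        return docs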