Checkpointed GitHub connector (#4307)

* WIP github checkpointing

* first draft of github checkpointing

* nit

* CW comments

* github basic connector test

* connector test env var

* secrets can't start with GITHUB_

* unit tests and bug fix

* connector failures

* address CW comments

* validation fix

* validation fix

* remove prints

* fixed tests

* 100 items per page
evan-danswer authored 2025-03-21 18:48:05 -07:00, committed by GitHub
parent 587ba11bbc
commit fb79a9e700
7 changed files with 730 additions and 96 deletions

View File

@@ -45,6 +45,8 @@ env:
  SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
  SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
+ # Github
+ ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
  # Gitbook
  GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
  GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
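As the commit message notes, GitHub Actions rejects secret names that begin with the reserved `GITHUB_` prefix, so the token is exposed as `ACCESS_TOKEN_GITHUB` rather than `GITHUB_ACCESS_TOKEN`. A minimal sketch of how test code picks it up, mirroring the integration test added later in this commit:

```python
import os

# The secret is surfaced as ACCESS_TOKEN_GITHUB because GitHub Actions
# reserves the GITHUB_ prefix for its own secret/variable names.
token = os.environ["ACCESS_TOKEN_GITHUB"]
```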

View File

@@ -435,7 +435,7 @@ def _run_indexing(
    while checkpoint.has_more:
        logger.info(
-            f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
+            f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
        )
        for document_batch, failure, next_checkpoint in connector_runner.run(
            checkpoint
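The one-line fix swaps the `DocumentSource` enum member for its string value: on recent Python versions an enum member interpolates via `Enum.__str__`, so the log line read `DocumentSource.GITHUB` rather than `github`. A small illustration with a hypothetical enum shaped like `DocumentSource`:

```python
from enum import Enum

class DocumentSource(str, Enum):
    GITHUB = "github"

source = DocumentSource.GITHUB
# On Python 3.11+ the f-string goes through Enum.__str__:
print(f"Running '{source}' connector")        # Running 'DocumentSource.GITHUB' connector
print(f"Running '{source.value}' connector")  # Running 'github' connector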

View File

@@ -157,10 +157,7 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
-try:
-    INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
-except ValueError:
-    INDEX_BATCH_SIZE = 16
+INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
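The `or 16` form covers the case the old `try/except` was guarding against in practice: an env var that is set but empty. A quick sketch of the difference, using a hypothetical `BATCH` variable:

```python
import os

os.environ["BATCH"] = ""  # set but empty, e.g. an unfilled CI template value

# int("") raises ValueError, so the old pattern needed a try/except fallback:
try:
    batch = int(os.environ.get("BATCH", 16))
except ValueError:
    batch = 16

# `or` replaces the falsy empty string with the default before int() runs:
batch = int(os.environ.get("BATCH") or 16)
assert batch == 16  # note: a non-numeric value like "abc" would still raise
```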

View File

@@ -1,8 +1,10 @@
+import copy
import time
-from collections.abc import Iterator
+from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
+from enum import Enum
from typing import Any
from typing import cast
@@ -13,26 +15,30 @@ from github.GithubException import GithubException
from github.Issue import Issue
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
+from github.Requester import Requester
+from pydantic import BaseModel
+from typing_extensions import override

from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
-from onyx.connectors.interfaces import GenerateDocumentsOutput
-from onyx.connectors.interfaces import LoadConnector
-from onyx.connectors.interfaces import PollConnector
+from onyx.connectors.interfaces import CheckpointConnector
+from onyx.connectors.interfaces import CheckpointOutput
+from onyx.connectors.interfaces import ConnectorCheckpoint
+from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
+from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import TextSection
-from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger

logger = setup_logger()

+ITEMS_PER_PAGE = 100
+
_MAX_NUM_RATE_LIMIT_RETRIES = 5
@@ -48,7 +54,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
def _get_batch_rate_limited(
    git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
-) -> list[Any]:
+) -> list[PullRequest | Issue]:
    if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
        raise RuntimeError(
            "Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
@@ -69,21 +75,6 @@ def _get_batch_rate_limited(
        )

-def _batch_github_objects(
-    git_objs: PaginatedList, github_client: Github, batch_size: int
-) -> Iterator[list[Any]]:
-    page_num = 0
-    while True:
-        batch = _get_batch_rate_limited(git_objs, page_num, github_client)
-        page_num += 1
-        if not batch:
-            break
-        for mini_batch in batch_generator(batch, batch_size=batch_size):
-            yield mini_batch

def _convert_pr_to_document(pull_request: PullRequest) -> Document:
    return Document(
        id=pull_request.html_url,
@@ -95,7 +86,9 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
        # updated_at is UTC time but is timezone unaware, explicitly add UTC
        # as there is logic in indexing to prevent wrong timestamped docs
        # due to local time discrepancies with UTC
-        doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
+        doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
+        if pull_request.updated_at
+        else None,
        metadata={
            "merged": str(pull_request.merged),
            "state": pull_request.state,
@@ -122,31 +115,58 @@ def _convert_issue_to_document(issue: Issue) -> Document:
    )

-class GithubConnector(LoadConnector, PollConnector):
+class SerializedRepository(BaseModel):
+    # id is part of the raw_data as well, just pulled out for convenience
+    id: int
+    headers: dict[str, str | int]
+    raw_data: dict[str, Any]
+
+    def to_Repository(self, requester: Requester) -> Repository.Repository:
+        return Repository.Repository(
+            requester, self.headers, self.raw_data, completed=True
+        )
+
+
+class GithubConnectorStage(Enum):
+    START = "start"
+    PRS = "prs"
+    ISSUES = "issues"
+
+
+class GithubConnectorCheckpoint(ConnectorCheckpoint):
+    stage: GithubConnectorStage
+    curr_page: int
+
+    cached_repo_ids: list[int] | None = None
+    cached_repo: SerializedRepository | None = None
+
+
+class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
    def __init__(
        self,
        repo_owner: str,
        repositories: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
        state_filter: str = "all",
        include_prs: bool = True,
        include_issues: bool = False,
    ) -> None:
        self.repo_owner = repo_owner
        self.repositories = repositories
        self.batch_size = batch_size
        self.state_filter = state_filter
        self.include_prs = include_prs
        self.include_issues = include_issues
        self.github_client: Github | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        # defaults to 30 items per page, can be set to as high as 100
        self.github_client = (
            Github(
-                credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
+                credentials["github_access_token"],
+                base_url=GITHUB_CONNECTOR_BASE_URL,
+                per_page=ITEMS_PER_PAGE,
            )
            if GITHUB_CONNECTOR_BASE_URL
-            else Github(credentials["github_access_token"])
+            else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
        )
        return None
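The design constraint behind `SerializedRepository` is that checkpoints must survive a JSON round trip between indexing attempts, while PyGithub's `Repository` is a live object bound to a `Requester`. Storing the raw REST payload and headers lets the next run rebuild the object without refetching it. A sketch of the round trip, assuming a client with valid credentials:

```python
from github import Github

gh = Github("<token>", per_page=100)
repo = gh.get_repo("onyx-dot-app/documentation")

checkpoint = GithubConnectorCheckpoint(
    stage=GithubConnectorStage.PRS,
    curr_page=0,
    has_more=True,
    cached_repo_ids=[repo.id],
    cached_repo=SerializedRepository(
        id=repo.id, headers=repo.raw_headers, raw_data=repo.raw_data
    ),
)

blob = checkpoint.model_dump_json()  # what gets persisted between runs
restored = GithubConnectorCheckpoint.model_validate_json(blob)
assert restored.cached_repo is not None
live_repo = restored.cached_repo.to_Repository(gh.requester)  # usable again
```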
@@ -217,85 +237,193 @@ class GithubConnector(LoadConnector, PollConnector):
        return self._get_all_repos(github_client, attempt_num + 1)

    def _fetch_from_github(
-        self, start: datetime | None = None, end: datetime | None = None
-    ) -> GenerateDocumentsOutput:
+        self,
+        checkpoint: GithubConnectorCheckpoint,
+        start: datetime | None = None,
+        end: datetime | None = None,
+    ) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
        if self.github_client is None:
            raise ConnectorMissingCredentialError("GitHub")

-        repos = []
-        if self.repositories:
-            if "," in self.repositories:
-                # Multiple repositories specified
-                repos = self._get_github_repos(self.github_client)
-            else:
-                # Single repository (backward compatibility)
-                repos = [self._get_github_repo(self.github_client)]
-        else:
-            # All repositories
-            repos = self._get_all_repos(self.github_client)
+        checkpoint = copy.deepcopy(checkpoint)
+
+        # First run of the connector, fetch all repos and store in checkpoint
+        if checkpoint.cached_repo_ids is None:
+            repos = []
+            if self.repositories:
+                if "," in self.repositories:
+                    # Multiple repositories specified
+                    repos = self._get_github_repos(self.github_client)
+                else:
+                    # Single repository (backward compatibility)
+                    repos = [self._get_github_repo(self.github_client)]
+            else:
+                # All repositories
+                repos = self._get_all_repos(self.github_client)
+
+            if not repos:
+                checkpoint.has_more = False
+                return checkpoint
+
+            checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
+            checkpoint.cached_repo = SerializedRepository(
+                id=checkpoint.cached_repo_ids[0],
+                headers=repos[0].raw_headers,
+                raw_data=repos[0].raw_data,
+            )
+            checkpoint.stage = GithubConnectorStage.PRS
+            checkpoint.curr_page = 0
+            # save checkpoint with repo ids retrieved
+            return checkpoint

-        for repo in repos:
-            if self.include_prs:
-                logger.info(f"Fetching PRs for repo: {repo.name}")
-                pull_requests = repo.get_pulls(
-                    state=self.state_filter, sort="updated", direction="desc"
-                )
-
-                for pr_batch in _batch_github_objects(
-                    pull_requests, self.github_client, self.batch_size
-                ):
-                    doc_batch: list[Document] = []
-                    for pr in pr_batch:
-                        if start is not None and pr.updated_at < start:
-                            yield doc_batch
-                            break
-                        if end is not None and pr.updated_at > end:
-                            continue
-                        doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
-                    yield doc_batch
+        assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
+        repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
+
+        if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
+            logger.info(f"Fetching PRs for repo: {repo.name}")
+            pull_requests = repo.get_pulls(
+                state=self.state_filter, sort="updated", direction="desc"
+            )
+
+            doc_batch: list[Document] = []
+            pr_batch = _get_batch_rate_limited(
+                pull_requests, checkpoint.curr_page, self.github_client
+            )
+            checkpoint.curr_page += 1
+            done_with_prs = False
+            for pr in pr_batch:
+                # we iterate backwards in time, so at this point we stop processing prs
+                if (
+                    start is not None
+                    and pr.updated_at
+                    and pr.updated_at.replace(tzinfo=timezone.utc) < start
+                ):
+                    yield from doc_batch
+                    done_with_prs = True
+                    break
+                # Skip PRs updated after the end date
+                if (
+                    end is not None
+                    and pr.updated_at
+                    and pr.updated_at.replace(tzinfo=timezone.utc) > end
+                ):
+                    continue
+                try:
+                    doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
+                except Exception as e:
+                    error_msg = f"Error converting PR to document: {e}"
+                    logger.exception(error_msg)
+                    yield ConnectorFailure(
+                        failed_document=DocumentFailure(
+                            document_id=str(pr.id), document_link=pr.html_url
+                        ),
+                        failure_message=error_msg,
+                        exception=e,
+                    )
+                    continue

-            if self.include_issues:
-                logger.info(f"Fetching issues for repo: {repo.name}")
-                issues = repo.get_issues(
-                    state=self.state_filter, sort="updated", direction="desc"
-                )
-
-                for issue_batch in _batch_github_objects(
-                    issues, self.github_client, self.batch_size
-                ):
-                    doc_batch = []
-                    for issue in issue_batch:
-                        issue = cast(Issue, issue)
-                        if start is not None and issue.updated_at < start:
-                            yield doc_batch
-                            break
-                        if end is not None and issue.updated_at > end:
-                            continue
-                        if issue.pull_request is not None:
-                            # PRs are handled separately
-                            continue
-                        doc_batch.append(_convert_issue_to_document(issue))
-                    yield doc_batch
+            # if we found any PRs on the page, yield any associated documents and return the checkpoint
+            if not done_with_prs and len(pr_batch) > 0:
+                yield from doc_batch
+                return checkpoint
+
+            # if we went past the start date during the loop or there are no more
+            # prs to get, we move on to issues
+            checkpoint.stage = GithubConnectorStage.ISSUES
+            checkpoint.curr_page = 0
+
+        checkpoint.stage = GithubConnectorStage.ISSUES
+
+        if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
+            logger.info(f"Fetching issues for repo: {repo.name}")
+            issues = repo.get_issues(
+                state=self.state_filter, sort="updated", direction="desc"
+            )
+
+            doc_batch = []
+            issue_batch = _get_batch_rate_limited(
+                issues, checkpoint.curr_page, self.github_client
+            )
+            checkpoint.curr_page += 1
+            done_with_issues = False
+            for issue in cast(list[Issue], issue_batch):
+                # we iterate backwards in time, so at this point we stop processing issues
+                if (
+                    start is not None
+                    and issue.updated_at.replace(tzinfo=timezone.utc) < start
+                ):
+                    yield from doc_batch
+                    done_with_issues = True
+                    break
+                # Skip issues updated after the end date
+                if (
+                    end is not None
+                    and issue.updated_at.replace(tzinfo=timezone.utc) > end
+                ):
+                    continue
+                if issue.pull_request is not None:
+                    # PRs are handled separately
+                    continue
+                try:
+                    doc_batch.append(_convert_issue_to_document(issue))
+                except Exception as e:
+                    error_msg = f"Error converting issue to document: {e}"
+                    logger.exception(error_msg)
+                    yield ConnectorFailure(
+                        failed_document=DocumentFailure(
+                            document_id=str(issue.id),
+                            document_link=issue.html_url,
+                        ),
+                        failure_message=error_msg,
+                        exception=e,
+                    )
+                    continue
+
+            # if we found any issues on the page, yield them and return the checkpoint
+            if not done_with_issues and len(issue_batch) > 0:
+                yield from doc_batch
+                return checkpoint
+
+        # if we went past the start date during the loop or there are no more
+        # issues to get, we move on to the next repo
+        checkpoint.stage = GithubConnectorStage.PRS
+        checkpoint.curr_page = 0
+
+        checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
+        if checkpoint.cached_repo_ids:
+            next_id = checkpoint.cached_repo_ids.pop()
+            next_repo = self.github_client.get_repo(next_id)
+            checkpoint.cached_repo = SerializedRepository(
+                id=next_id,
+                headers=next_repo.raw_headers,
+                raw_data=next_repo.raw_data,
+            )
+
+        return checkpoint

-    def load_from_state(self) -> GenerateDocumentsOutput:
-        return self._fetch_from_github()
-
-    def poll_source(
-        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
-    ) -> GenerateDocumentsOutput:
-        start_datetime = datetime.utcfromtimestamp(start)
-        end_datetime = datetime.utcfromtimestamp(end)
+    @override
+    def load_from_checkpoint(
+        self,
+        start: SecondsSinceUnixEpoch,
+        end: SecondsSinceUnixEpoch,
+        checkpoint: GithubConnectorCheckpoint,
+    ) -> CheckpointOutput[GithubConnectorCheckpoint]:
+        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
+        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        # Move start time back by 3 hours, since some Issues/PRs are getting dropped
        # Could be due to delayed processing on GitHub side
        # The non-updated issues since last poll will be shortcut-ed and not embedded
        adjusted_start_datetime = start_datetime - timedelta(hours=3)

-        epoch = datetime.utcfromtimestamp(0)
+        epoch = datetime.fromtimestamp(0, tz=timezone.utc)
        if adjusted_start_datetime < epoch:
            adjusted_start_datetime = epoch

-        return self._fetch_from_github(adjusted_start_datetime, end_datetime)
+        return self._fetch_from_github(
+            checkpoint, start=adjusted_start_datetime, end=end_datetime
+        )
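Putting the pieces together: each call to `load_from_checkpoint` processes at most one page of one stage for one repo, yields documents or failures, and *returns* the next checkpoint as the generator's return value. A hedged sketch of the driving loop, modeled on the `while checkpoint.has_more` loop in `_run_indexing` above; `index_all` is a hypothetical helper, not part of the connector:

```python
import time

def index_all(connector: GithubConnector) -> list[Document]:
    docs: list[Document] = []
    checkpoint = connector.build_dummy_checkpoint()
    while checkpoint.has_more:
        gen = connector.load_from_checkpoint(0, time.time(), checkpoint)
        while True:
            try:
                item = next(gen)
            except StopIteration as stop:
                checkpoint = stop.value  # the generator returns the next checkpoint
                break
            if isinstance(item, Document):
                docs.append(item)
            else:  # ConnectorFailure: recorded, but indexing continues
                print(f"failure: {item.failure_message}")
    return docs
```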
    def validate_connector_settings(self) -> None:
        if self.github_client is None:
@@ -397,6 +525,16 @@ class GithubConnector(LoadConnector, PollConnector):
f"Unexpected error during GitHub settings validation: {exc}"
)
def validate_checkpoint_json(
self, checkpoint_json: str
) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint(
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
)
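Because the checkpoint is a pydantic model, persistence and validation come essentially for free: `validate_checkpoint_json` is just `model_validate_json`, so malformed payloads fail loudly. A small sketch:

```python
import pydantic

connector = GithubConnector(repo_owner="onyx-dot-app", repositories="documentation")
cp = connector.build_dummy_checkpoint()
assert connector.validate_checkpoint_json(cp.model_dump_json()) == cp

try:
    connector.validate_checkpoint_json('{"stage": "nonsense"}')
except pydantic.ValidationError as err:
    print(err.error_count(), "validation errors")  # bad stage, missing fields
```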
if __name__ == "__main__":
    import os
@@ -406,7 +544,9 @@ if __name__ == "__main__":
        repositories=os.environ["REPOSITORIES"],
    )
    connector.load_credentials(
-        {"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
+        {"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
    )
-    document_batches = connector.load_from_state()
+    document_batches = connector.load_from_checkpoint(
+        0, time.time(), connector.build_dummy_checkpoint()
+    )
    print(next(document_batches))
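One caveat with this smoke test: on the very first call the connector only primes the repo cache and yields nothing, so `next(document_batches)` raises `StopIteration` (whose `.value` carries the new checkpoint) rather than printing a document. A sketch that handles both outcomes:

```python
try:
    print(next(document_batches))          # a Document, if one was yielded
except StopIteration as stop:
    print("checkpoint only:", stop.value)  # first call just caches repo ids
```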

View File

@@ -56,7 +56,7 @@ puremagic==1.28
pyairtable==3.0.1
pycryptodome==3.19.1
pydantic==2.8.2
-PyGithub==1.58.2
+PyGithub==2.5.0
python-dateutil==2.8.2
python-gitlab==3.9.0
python-pptx==0.6.23
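Worth noting with the 1.x to 2.x bump: PyGithub 2.0 made its returned datetimes timezone-aware (UTC), where 1.x returned naive values. The connector's `.replace(tzinfo=timezone.utc)` normalization is safe under both, since it is a no-op on values already in UTC:

```python
from datetime import datetime, timezone

naive = datetime(2023, 1, 1)                       # PyGithub 1.x style
aware = datetime(2023, 1, 1, tzinfo=timezone.utc)  # PyGithub 2.x style

# Both normalize to the same UTC-aware value.
assert naive.replace(tzinfo=timezone.utc) == aware.replace(tzinfo=timezone.utc)
```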

View File

@@ -0,0 +1,54 @@
import os
import time

import pytest

from onyx.configs.constants import DocumentSource
from onyx.connectors.github.connector import GithubConnector
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector


@pytest.fixture
def github_connector() -> GithubConnector:
    connector = GithubConnector(
        repo_owner="onyx-dot-app",
        repositories="documentation",
        include_prs=True,
        include_issues=True,
    )
    connector.load_credentials(
        {
            "github_access_token": os.environ["ACCESS_TOKEN_GITHUB"],
        }
    )
    return connector


def test_github_connector_basic(github_connector: GithubConnector) -> None:
    docs = load_all_docs_from_checkpoint_connector(
        connector=github_connector,
        start=0,
        end=time.time(),
    )
    assert len(docs) > 0  # We expect at least one PR to exist

    # Test the first document's structure
    doc = docs[0]

    # Verify basic document properties
    assert doc.source == DocumentSource.GITHUB
    assert doc.secondary_owners is None
    assert doc.from_ingestion_api is False
    assert doc.additional_info is None

    # Verify GitHub-specific properties
    assert "github.com" in doc.id  # Should be a GitHub URL
    assert doc.metadata is not None
    assert "state" in doc.metadata
    assert "merged" in doc.metadata

    # Verify sections
    assert len(doc.sections) == 1
    section = doc.sections[0]
    assert section.link == doc.id  # Section link should match document ID
    assert isinstance(section.text, str)  # Should have some text content

View File

@@ -0,0 +1,441 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from github import Github
from github import GithubException
from github import RateLimitExceededException
from github.Issue import Issue
from github.PullRequest import PullRequest
from github.RateLimit import RateLimit
from github.Repository import Repository
from github.Requester import Requester

from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.github.connector import SerializedRepository
from onyx.connectors.models import Document
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector


@pytest.fixture
def repo_owner() -> str:
    return "test-org"


@pytest.fixture
def repositories() -> str:
    return "test-repo"


@pytest.fixture
def mock_github_client() -> MagicMock:
    """Create a mock GitHub client with proper typing"""
    mock = MagicMock(spec=Github)
    # Add proper return typing for get_repo method
    mock.get_repo = MagicMock(return_value=MagicMock(spec=Repository))
    # Add proper return typing for get_organization method
    mock.get_organization = MagicMock()
    # Add proper return typing for get_user method
    mock.get_user = MagicMock()
    # Add proper return typing for get_rate_limit method
    mock.get_rate_limit = MagicMock(return_value=MagicMock(spec=RateLimit))
    # Add requester for repository deserialization
    mock.requester = MagicMock(spec=Requester)
    return mock


@pytest.fixture
def github_connector(
    repo_owner: str, repositories: str, mock_github_client: MagicMock
) -> Generator[GithubConnector, None, None]:
    connector = GithubConnector(
        repo_owner=repo_owner,
        repositories=repositories,
        include_prs=True,
        include_issues=True,
    )
    connector.github_client = mock_github_client
    yield connector


@pytest.fixture
def create_mock_pr() -> Callable[..., MagicMock]:
    def _create_mock_pr(
        number: int = 1,
        title: str = "Test PR",
        body: str = "Test Description",
        state: str = "open",
        merged: bool = False,
        updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
    ) -> MagicMock:
        """Helper to create a mock PullRequest object"""
        mock_pr = MagicMock(spec=PullRequest)
        mock_pr.number = number
        mock_pr.title = title
        mock_pr.body = body
        mock_pr.state = state
        mock_pr.merged = merged
        mock_pr.updated_at = updated_at
        mock_pr.html_url = f"https://github.com/test-org/test-repo/pull/{number}"
        return mock_pr

    return _create_mock_pr


@pytest.fixture
def create_mock_issue() -> Callable[..., MagicMock]:
    def _create_mock_issue(
        number: int = 1,
        title: str = "Test Issue",
        body: str = "Test Description",
        state: str = "open",
        updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
    ) -> MagicMock:
        """Helper to create a mock Issue object"""
        mock_issue = MagicMock(spec=Issue)
        mock_issue.number = number
        mock_issue.title = title
        mock_issue.body = body
        mock_issue.state = state
        mock_issue.updated_at = updated_at
        mock_issue.html_url = f"https://github.com/test-org/test-repo/issues/{number}"
        mock_issue.pull_request = None  # Not a PR
        return mock_issue

    return _create_mock_issue


@pytest.fixture
def create_mock_repo() -> Callable[..., MagicMock]:
    def _create_mock_repo(
        name: str = "test-repo",
        id: int = 1,
    ) -> MagicMock:
        """Helper to create a mock Repository object"""
        mock_repo = MagicMock(spec=Repository)
        mock_repo.name = name
        mock_repo.id = id
        mock_repo.raw_headers = {"status": "200 OK", "content-type": "application/json"}
        mock_repo.raw_data = {
            "id": str(id),
            "name": name,
            "full_name": f"test-org/{name}",
            "private": str(False),
            "description": "Test repository",
        }
        return mock_repo

    return _create_mock_repo
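The `raw_headers`/`raw_data` attributes on this mock are exactly the two payloads `SerializedRepository` captures, which is what lets the checkpoint tests rebuild a repository without network access. A quick sketch of that rebuild, reusing the fixture's values:

```python
serialized = SerializedRepository(
    id=1,
    headers={"status": "200 OK", "content-type": "application/json"},
    raw_data={"id": "1", "name": "test-repo", "full_name": "test-org/test-repo"},
)
rebuilt = serialized.to_Repository(MagicMock(spec=Requester))
print(rebuilt.full_name)  # attributes are hydrated straight from raw_data
```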
def test_load_from_checkpoint_happy_path(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_pr: Callable[..., MagicMock],
    create_mock_issue: Callable[..., MagicMock],
) -> None:
    """Test loading from checkpoint - happy path"""
    # Set up mocked repo
    mock_repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = mock_repo

    # Set up mocked PRs and issues
    mock_pr1 = create_mock_pr(number=1, title="PR 1")
    mock_pr2 = create_mock_pr(number=2, title="PR 2")
    mock_issue1 = create_mock_issue(number=1, title="Issue 1")
    mock_issue2 = create_mock_issue(number=2, title="Issue 2")

    # Mock get_pulls and get_issues methods
    mock_repo.get_pulls.return_value = MagicMock()
    mock_repo.get_pulls.return_value.get_page.side_effect = [
        [mock_pr1, mock_pr2],
        [],
    ]
    mock_repo.get_issues.return_value = MagicMock()
    mock_repo.get_issues.return_value.get_page.side_effect = [
        [mock_issue1, mock_issue2],
        [],
    ]

    # Mock SerializedRepository.to_Repository to return our mock repo
    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
        # Call load_from_checkpoint
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(
            github_connector, 0, end_time
        )

        # Check that we got all documents and final has_more=False
        assert len(outputs) == 4

        repo_batch = outputs[0]
        assert len(repo_batch.items) == 0
        assert repo_batch.next_checkpoint.has_more is True

        # Check first batch (PRs)
        first_batch = outputs[1]
        assert len(first_batch.items) == 2
        assert isinstance(first_batch.items[0], Document)
        assert first_batch.items[0].id == "https://github.com/test-org/test-repo/pull/1"
        assert isinstance(first_batch.items[1], Document)
        assert first_batch.items[1].id == "https://github.com/test-org/test-repo/pull/2"
        assert first_batch.next_checkpoint.curr_page == 1

        # Check second batch (Issues)
        second_batch = outputs[2]
        assert len(second_batch.items) == 2
        assert isinstance(second_batch.items[0], Document)
        assert (
            second_batch.items[0].id == "https://github.com/test-org/test-repo/issues/1"
        )
        assert isinstance(second_batch.items[1], Document)
        assert (
            second_batch.items[1].id == "https://github.com/test-org/test-repo/issues/2"
        )
        assert second_batch.next_checkpoint.has_more

        # Check third batch (finished checkpoint)
        third_batch = outputs[3]
        assert len(third_batch.items) == 0
        assert third_batch.next_checkpoint.has_more is False
def test_load_from_checkpoint_with_rate_limit(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_pr: Callable[..., MagicMock],
) -> None:
    """Test loading from checkpoint with rate limit handling"""
    # Set up mocked repo
    mock_repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = mock_repo

    # Set up mocked PR
    mock_pr = create_mock_pr()

    # Mock get_pulls to raise RateLimitExceededException on first call
    mock_repo.get_pulls.return_value = MagicMock()
    mock_repo.get_pulls.return_value.get_page.side_effect = [
        RateLimitExceededException(403, {"message": "Rate limit exceeded"}, {}),
        [mock_pr],
        [],
    ]

    # Mock rate limit reset time
    mock_rate_limit = MagicMock(spec=RateLimit)
    mock_rate_limit.core.reset = datetime.now(timezone.utc)
    github_connector.github_client.get_rate_limit.return_value = mock_rate_limit

    # Mock SerializedRepository.to_Repository to return our mock repo
    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
        # Call load_from_checkpoint
        end_time = time.time()
        with patch(
            "onyx.connectors.github.connector._sleep_after_rate_limit_exception"
        ) as mock_sleep:
            outputs = load_everything_from_checkpoint_connector(
                github_connector, 0, end_time
            )
            assert mock_sleep.call_count == 1

        # Check that we got the document after rate limit was handled
        assert len(outputs) >= 2
        assert len(outputs[1].items) == 1
        assert isinstance(outputs[1].items[0], Document)
        assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"
        assert outputs[-1].next_checkpoint.has_more is False
def test_load_from_checkpoint_with_empty_repo(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
) -> None:
    """Test loading from checkpoint with an empty repository"""
    # Set up mocked repo
    mock_repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = mock_repo

    # Mock get_pulls and get_issues to return empty lists
    mock_repo.get_pulls.return_value = MagicMock()
    mock_repo.get_pulls.return_value.get_page.return_value = []
    mock_repo.get_issues.return_value = MagicMock()
    mock_repo.get_issues.return_value.get_page.return_value = []

    # Mock SerializedRepository.to_Repository to return our mock repo
    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
        # Call load_from_checkpoint
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(
            github_connector, 0, end_time
        )

        # Check that we got no documents
        assert len(outputs) == 2
        assert len(outputs[-1].items) == 0
        assert not outputs[-1].next_checkpoint.has_more
def test_load_from_checkpoint_with_prs_only(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_pr: Callable[..., MagicMock],
) -> None:
    """Test loading from checkpoint with only PRs enabled"""
    # Configure connector to only include PRs
    github_connector.include_prs = True
    github_connector.include_issues = False

    # Set up mocked repo
    mock_repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = mock_repo

    # Set up mocked PRs
    mock_pr1 = create_mock_pr(number=1, title="PR 1")
    mock_pr2 = create_mock_pr(number=2, title="PR 2")

    # Mock get_pulls method
    mock_repo.get_pulls.return_value = MagicMock()
    mock_repo.get_pulls.return_value.get_page.side_effect = [
        [mock_pr1, mock_pr2],
        [],
    ]

    # Mock SerializedRepository.to_Repository to return our mock repo
    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
        # Call load_from_checkpoint
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(
            github_connector, 0, end_time
        )

        # Check that we only got PRs
        assert len(outputs) >= 2
        assert len(outputs[1].items) == 2
        assert all(
            isinstance(doc, Document) and "pull" in doc.id for doc in outputs[1].items
        )  # All documents should be PRs
        assert outputs[-1].next_checkpoint.has_more is False
def test_load_from_checkpoint_with_issues_only(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_issue: Callable[..., MagicMock],
) -> None:
    """Test loading from checkpoint with only issues enabled"""
    # Configure connector to only include issues
    github_connector.include_prs = False
    github_connector.include_issues = True

    # Set up mocked repo
    mock_repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = mock_repo

    # Set up mocked issues
    mock_issue1 = create_mock_issue(number=1, title="Issue 1")
    mock_issue2 = create_mock_issue(number=2, title="Issue 2")

    # Mock get_issues method
    mock_repo.get_issues.return_value = MagicMock()
    mock_repo.get_issues.return_value.get_page.side_effect = [
        [mock_issue1, mock_issue2],
        [],
    ]

    # Mock SerializedRepository.to_Repository to return our mock repo
    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
        # Call load_from_checkpoint
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(
            github_connector, 0, end_time
        )

        # Check that we only got issues
        assert len(outputs) >= 2
        assert len(outputs[1].items) == 2
        assert all(
            isinstance(doc, Document) and "issues" in doc.id
            for doc in outputs[1].items
        )  # All documents should be issues
        assert outputs[1].next_checkpoint.has_more
        assert outputs[-1].next_checkpoint.has_more is False
@pytest.mark.parametrize(
    "status_code,expected_exception,expected_message",
    [
        (
            401,
            CredentialExpiredError,
            "GitHub credential appears to be invalid or expired",
        ),
        (
            403,
            InsufficientPermissionsError,
            "Your GitHub token does not have sufficient permissions",
        ),
        (
            404,
            ConnectorValidationError,
            "GitHub repository not found",
        ),
    ],
)
def test_validate_connector_settings_errors(
    github_connector: GithubConnector,
    status_code: int,
    expected_exception: type[Exception],
    expected_message: str,
) -> None:
    """Test validation with various error scenarios"""
    error = GithubException(status=status_code, data={}, headers={})

    github_client = cast(Github, github_connector.github_client)
    get_repo_mock = cast(MagicMock, github_client.get_repo)
    get_repo_mock.side_effect = error

    with pytest.raises(expected_exception) as excinfo:
        github_connector.validate_connector_settings()
    assert expected_message in str(excinfo.value)


def test_validate_connector_settings_success(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
) -> None:
    """Test successful validation"""
    # Set up mocked repo
    mock_repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = mock_repo

    # Mock get_contents to simulate successful access
    mock_repo.get_contents.return_value = MagicMock()

    github_connector.validate_connector_settings()
    github_connector.github_client.get_repo.assert_called_once_with(
        f"{github_connector.repo_owner}/{github_connector.repositories}"
    )