mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-23 19:35:54 +02:00
Checkpointed GitHub connector (#4307)
* WIP github checkpointing * first draft of github checkpointing * nit * CW comments * github basic connector test * connector test env var * secrets cant start with GITHUB_ * unit tests and bug fix * connector failures * address CW comments * validation fix * validation fix * remove prints * fixed tests * 100 items per page
This commit is contained in:
@@ -45,6 +45,8 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
# Gitbook
|
||||
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
|
||||
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
|
||||
|
@@ -435,7 +435,7 @@ def _run_indexing(
|
||||
|
||||
while checkpoint.has_more:
|
||||
logger.info(
|
||||
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
|
||||
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
|
@@ -157,10 +157,7 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
|
||||
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
|
||||
|
||||
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
|
||||
try:
|
||||
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)
|
||||
|
||||
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
|
||||
|
||||
|
@@ -1,8 +1,10 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -13,26 +15,30 @@ from github.GithubException import GithubException
|
||||
from github.Issue import Issue
|
||||
from github.PaginatedList import PaginatedList
|
||||
from github.PullRequest import PullRequest
|
||||
from github.Requester import Requester
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ITEMS_PER_PAGE = 100
|
||||
|
||||
_MAX_NUM_RATE_LIMIT_RETRIES = 5
|
||||
|
||||
@@ -48,7 +54,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
|
||||
def _get_batch_rate_limited(
|
||||
git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Any]:
|
||||
) -> list[PullRequest | Issue]:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
|
||||
@@ -69,21 +75,6 @@ def _get_batch_rate_limited(
|
||||
)
|
||||
|
||||
|
||||
def _batch_github_objects(
|
||||
git_objs: PaginatedList, github_client: Github, batch_size: int
|
||||
) -> Iterator[list[Any]]:
|
||||
page_num = 0
|
||||
while True:
|
||||
batch = _get_batch_rate_limited(git_objs, page_num, github_client)
|
||||
page_num += 1
|
||||
|
||||
if not batch:
|
||||
break
|
||||
|
||||
for mini_batch in batch_generator(batch, batch_size=batch_size):
|
||||
yield mini_batch
|
||||
|
||||
|
||||
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
return Document(
|
||||
id=pull_request.html_url,
|
||||
@@ -95,7 +86,9 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
# updated_at is UTC time but is timezone unaware, explicitly add UTC
|
||||
# as there is logic in indexing to prevent wrong timestamped docs
|
||||
# due to local time discrepancies with UTC
|
||||
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
|
||||
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
|
||||
if pull_request.updated_at
|
||||
else None,
|
||||
metadata={
|
||||
"merged": str(pull_request.merged),
|
||||
"state": pull_request.state,
|
||||
@@ -122,31 +115,58 @@ def _convert_issue_to_document(issue: Issue) -> Document:
|
||||
)
|
||||
|
||||
|
||||
class GithubConnector(LoadConnector, PollConnector):
|
||||
class SerializedRepository(BaseModel):
|
||||
# id is part of the raw_data as well, just pulled out for convenience
|
||||
id: int
|
||||
headers: dict[str, str | int]
|
||||
raw_data: dict[str, Any]
|
||||
|
||||
def to_Repository(self, requester: Requester) -> Repository.Repository:
|
||||
return Repository.Repository(
|
||||
requester, self.headers, self.raw_data, completed=True
|
||||
)
|
||||
|
||||
|
||||
class GithubConnectorStage(Enum):
|
||||
START = "start"
|
||||
PRS = "prs"
|
||||
ISSUES = "issues"
|
||||
|
||||
|
||||
class GithubConnectorCheckpoint(ConnectorCheckpoint):
|
||||
stage: GithubConnectorStage
|
||||
curr_page: int
|
||||
|
||||
cached_repo_ids: list[int] | None = None
|
||||
cached_repo: SerializedRepository | None = None
|
||||
|
||||
|
||||
class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
repositories: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
state_filter: str = "all",
|
||||
include_prs: bool = True,
|
||||
include_issues: bool = False,
|
||||
) -> None:
|
||||
self.repo_owner = repo_owner
|
||||
self.repositories = repositories
|
||||
self.batch_size = batch_size
|
||||
self.state_filter = state_filter
|
||||
self.include_prs = include_prs
|
||||
self.include_issues = include_issues
|
||||
self.github_client: Github | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# defaults to 30 items per page, can be set to as high as 100
|
||||
self.github_client = (
|
||||
Github(
|
||||
credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
|
||||
credentials["github_access_token"],
|
||||
base_url=GITHUB_CONNECTOR_BASE_URL,
|
||||
per_page=ITEMS_PER_PAGE,
|
||||
)
|
||||
if GITHUB_CONNECTOR_BASE_URL
|
||||
else Github(credentials["github_access_token"])
|
||||
else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -217,85 +237,193 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
return self._get_all_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _fetch_from_github(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
self,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
|
||||
# First run of the connector, fetch all repos and store in checkpoint
|
||||
if checkpoint.cached_repo_ids is None:
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
else:
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
if not repos:
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
for repo in repos:
|
||||
if self.include_prs:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
|
||||
checkpoint.cached_repo = SerializedRepository(
|
||||
id=checkpoint.cached_repo_ids[0],
|
||||
headers=repos[0].raw_headers,
|
||||
raw_data=repos[0].raw_data,
|
||||
)
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.curr_page = 0
|
||||
# save checkpoint with repo ids retrieved
|
||||
return checkpoint
|
||||
|
||||
for pr_batch in _batch_github_objects(
|
||||
pull_requests, self.github_client, self.batch_size
|
||||
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
|
||||
repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
|
||||
|
||||
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
pr_batch = _get_batch_rate_limited(
|
||||
pull_requests, checkpoint.curr_page, self.github_client
|
||||
)
|
||||
checkpoint.curr_page += 1
|
||||
done_with_prs = False
|
||||
for pr in pr_batch:
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
doc_batch: list[Document] = []
|
||||
for pr in pr_batch:
|
||||
if start is not None and pr.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and pr.updated_at > end:
|
||||
continue
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
yield doc_batch
|
||||
|
||||
if self.include_issues:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
for issue_batch in _batch_github_objects(
|
||||
issues, self.github_client, self.batch_size
|
||||
yield from doc_batch
|
||||
done_with_prs = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
doc_batch = []
|
||||
for issue in issue_batch:
|
||||
issue = cast(Issue, issue)
|
||||
if start is not None and issue.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and issue.updated_at > end:
|
||||
continue
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield doc_batch
|
||||
continue
|
||||
try:
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting PR to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(pr.id), document_link=pr.html_url
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_github()
|
||||
# if we found any PRs on the page, yield any associated documents and return the checkpoint
|
||||
if not done_with_prs and len(pr_batch) > 0:
|
||||
yield from doc_batch
|
||||
return checkpoint
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
# if we went past the start date during the loop or there are no more
|
||||
# prs to get, we move on to issues
|
||||
checkpoint.stage = GithubConnectorStage.ISSUES
|
||||
checkpoint.curr_page = 0
|
||||
|
||||
checkpoint.stage = GithubConnectorStage.ISSUES
|
||||
|
||||
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
doc_batch = []
|
||||
issue_batch = _get_batch_rate_limited(
|
||||
issues, checkpoint.curr_page, self.github_client
|
||||
)
|
||||
checkpoint.curr_page += 1
|
||||
done_with_issues = False
|
||||
for issue in cast(list[Issue], issue_batch):
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
yield from doc_batch
|
||||
done_with_issues = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
continue
|
||||
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
|
||||
try:
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting issue to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(issue.id),
|
||||
document_link=issue.html_url,
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
# if we found any issues on the page, yield them and return the checkpoint
|
||||
if not done_with_issues and len(issue_batch) > 0:
|
||||
yield from doc_batch
|
||||
return checkpoint
|
||||
|
||||
# if we went past the start date during the loop or there are no more
|
||||
# issues to get, we move on to the next repo
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.curr_page = 0
|
||||
|
||||
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
|
||||
if checkpoint.cached_repo_ids:
|
||||
next_id = checkpoint.cached_repo_ids.pop()
|
||||
next_repo = self.github_client.get_repo(next_id)
|
||||
checkpoint.cached_repo = SerializedRepository(
|
||||
id=next_id,
|
||||
headers=next_repo.raw_headers,
|
||||
raw_data=next_repo.raw_data,
|
||||
)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
) -> CheckpointOutput[GithubConnectorCheckpoint]:
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
# Move start time back by 3 hours, since some Issues/PRs are getting dropped
|
||||
# Could be due to delayed processing on GitHub side
|
||||
# The non-updated issues since last poll will be shortcut-ed and not embedded
|
||||
adjusted_start_datetime = start_datetime - timedelta(hours=3)
|
||||
|
||||
epoch = datetime.utcfromtimestamp(0)
|
||||
epoch = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
if adjusted_start_datetime < epoch:
|
||||
adjusted_start_datetime = epoch
|
||||
|
||||
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
|
||||
return self._fetch_from_github(
|
||||
checkpoint, start=adjusted_start_datetime, end=end_datetime
|
||||
)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.github_client is None:
|
||||
@@ -397,6 +525,16 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
f"Unexpected error during GitHub settings validation: {exc}"
|
||||
)
|
||||
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint(
|
||||
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
@@ -406,7 +544,9 @@ if __name__ == "__main__":
|
||||
repositories=os.environ["REPOSITORIES"],
|
||||
)
|
||||
connector.load_credentials(
|
||||
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
|
||||
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
|
||||
)
|
||||
document_batches = connector.load_from_checkpoint(
|
||||
0, time.time(), connector.build_dummy_checkpoint()
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
|
@@ -56,7 +56,7 @@ puremagic==1.28
|
||||
pyairtable==3.0.1
|
||||
pycryptodome==3.19.1
|
||||
pydantic==2.8.2
|
||||
PyGithub==1.58.2
|
||||
PyGithub==2.5.0
|
||||
python-dateutil==2.8.2
|
||||
python-gitlab==3.9.0
|
||||
python-pptx==0.6.23
|
||||
|
54
backend/tests/daily/connectors/github/test_github_basic.py
Normal file
54
backend/tests/daily/connectors/github/test_github_basic.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_connector() -> GithubConnector:
|
||||
connector = GithubConnector(
|
||||
repo_owner="onyx-dot-app",
|
||||
repositories="documentation",
|
||||
include_prs=True,
|
||||
include_issues=True,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"],
|
||||
}
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
def test_github_connector_basic(github_connector: GithubConnector) -> None:
|
||||
docs = load_all_docs_from_checkpoint_connector(
|
||||
connector=github_connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
assert len(docs) > 0 # We expect at least one PR to exist
|
||||
|
||||
# Test the first document's structure
|
||||
doc = docs[0]
|
||||
|
||||
# Verify basic document properties
|
||||
assert doc.source == DocumentSource.GITHUB
|
||||
assert doc.secondary_owners is None
|
||||
assert doc.from_ingestion_api is False
|
||||
assert doc.additional_info is None
|
||||
|
||||
# Verify GitHub-specific properties
|
||||
assert "github.com" in doc.id # Should be a GitHub URL
|
||||
assert doc.metadata is not None
|
||||
assert "state" in doc.metadata
|
||||
assert "merged" in doc.metadata
|
||||
|
||||
# Verify sections
|
||||
assert len(doc.sections) == 1
|
||||
section = doc.sections[0]
|
||||
assert section.link == doc.id # Section link should match document ID
|
||||
assert isinstance(section.text, str) # Should have some text content
|
@@ -0,0 +1,441 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from github import Github
|
||||
from github import GithubException
|
||||
from github import RateLimitExceededException
|
||||
from github.Issue import Issue
|
||||
from github.PullRequest import PullRequest
|
||||
from github.RateLimit import RateLimit
|
||||
from github.Repository import Repository
|
||||
from github.Requester import Requester
|
||||
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.github.connector import SerializedRepository
|
||||
from onyx.connectors.models import Document
|
||||
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repo_owner() -> str:
|
||||
return "test-org"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repositories() -> str:
|
||||
return "test-repo"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_github_client() -> MagicMock:
|
||||
"""Create a mock GitHub client with proper typing"""
|
||||
mock = MagicMock(spec=Github)
|
||||
# Add proper return typing for get_repo method
|
||||
mock.get_repo = MagicMock(return_value=MagicMock(spec=Repository))
|
||||
# Add proper return typing for get_organization method
|
||||
mock.get_organization = MagicMock()
|
||||
# Add proper return typing for get_user method
|
||||
mock.get_user = MagicMock()
|
||||
# Add proper return typing for get_rate_limit method
|
||||
mock.get_rate_limit = MagicMock(return_value=MagicMock(spec=RateLimit))
|
||||
# Add requester for repository deserialization
|
||||
mock.requester = MagicMock(spec=Requester)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_connector(
|
||||
repo_owner: str, repositories: str, mock_github_client: MagicMock
|
||||
) -> Generator[GithubConnector, None, None]:
|
||||
connector = GithubConnector(
|
||||
repo_owner=repo_owner,
|
||||
repositories=repositories,
|
||||
include_prs=True,
|
||||
include_issues=True,
|
||||
)
|
||||
connector.github_client = mock_github_client
|
||||
yield connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_pr() -> Callable[..., MagicMock]:
|
||||
def _create_mock_pr(
|
||||
number: int = 1,
|
||||
title: str = "Test PR",
|
||||
body: str = "Test Description",
|
||||
state: str = "open",
|
||||
merged: bool = False,
|
||||
updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
|
||||
) -> MagicMock:
|
||||
"""Helper to create a mock PullRequest object"""
|
||||
mock_pr = MagicMock(spec=PullRequest)
|
||||
mock_pr.number = number
|
||||
mock_pr.title = title
|
||||
mock_pr.body = body
|
||||
mock_pr.state = state
|
||||
mock_pr.merged = merged
|
||||
mock_pr.updated_at = updated_at
|
||||
mock_pr.html_url = f"https://github.com/test-org/test-repo/pull/{number}"
|
||||
return mock_pr
|
||||
|
||||
return _create_mock_pr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_issue() -> Callable[..., MagicMock]:
|
||||
def _create_mock_issue(
|
||||
number: int = 1,
|
||||
title: str = "Test Issue",
|
||||
body: str = "Test Description",
|
||||
state: str = "open",
|
||||
updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
|
||||
) -> MagicMock:
|
||||
"""Helper to create a mock Issue object"""
|
||||
mock_issue = MagicMock(spec=Issue)
|
||||
mock_issue.number = number
|
||||
mock_issue.title = title
|
||||
mock_issue.body = body
|
||||
mock_issue.state = state
|
||||
mock_issue.updated_at = updated_at
|
||||
mock_issue.html_url = f"https://github.com/test-org/test-repo/issues/{number}"
|
||||
mock_issue.pull_request = None # Not a PR
|
||||
return mock_issue
|
||||
|
||||
return _create_mock_issue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_repo() -> Callable[..., MagicMock]:
|
||||
def _create_mock_repo(
|
||||
name: str = "test-repo",
|
||||
id: int = 1,
|
||||
) -> MagicMock:
|
||||
"""Helper to create a mock Repository object"""
|
||||
mock_repo = MagicMock(spec=Repository)
|
||||
mock_repo.name = name
|
||||
mock_repo.id = id
|
||||
mock_repo.raw_headers = {"status": "200 OK", "content-type": "application/json"}
|
||||
mock_repo.raw_data = {
|
||||
"id": str(id),
|
||||
"name": name,
|
||||
"full_name": f"test-org/{name}",
|
||||
"private": str(False),
|
||||
"description": "Test repository",
|
||||
}
|
||||
return mock_repo
|
||||
|
||||
return _create_mock_repo
|
||||
|
||||
|
||||
def test_load_from_checkpoint_happy_path(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_pr: Callable[..., MagicMock],
|
||||
create_mock_issue: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint - happy path"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked PRs and issues
|
||||
mock_pr1 = create_mock_pr(number=1, title="PR 1")
|
||||
mock_pr2 = create_mock_pr(number=2, title="PR 2")
|
||||
mock_issue1 = create_mock_issue(number=1, title="Issue 1")
|
||||
mock_issue2 = create_mock_issue(number=2, title="Issue 2")
|
||||
|
||||
# Mock get_pulls and get_issues methods
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.side_effect = [
|
||||
[mock_pr1, mock_pr2],
|
||||
[],
|
||||
]
|
||||
mock_repo.get_issues.return_value = MagicMock()
|
||||
mock_repo.get_issues.return_value.get_page.side_effect = [
|
||||
[mock_issue1, mock_issue2],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we got all documents and final has_more=False
|
||||
assert len(outputs) == 4
|
||||
|
||||
repo_batch = outputs[0]
|
||||
assert len(repo_batch.items) == 0
|
||||
assert repo_batch.next_checkpoint.has_more is True
|
||||
|
||||
# Check first batch (PRs)
|
||||
first_batch = outputs[1]
|
||||
assert len(first_batch.items) == 2
|
||||
assert isinstance(first_batch.items[0], Document)
|
||||
assert first_batch.items[0].id == "https://github.com/test-org/test-repo/pull/1"
|
||||
assert isinstance(first_batch.items[1], Document)
|
||||
assert first_batch.items[1].id == "https://github.com/test-org/test-repo/pull/2"
|
||||
assert first_batch.next_checkpoint.curr_page == 1
|
||||
|
||||
# Check second batch (Issues)
|
||||
second_batch = outputs[2]
|
||||
assert len(second_batch.items) == 2
|
||||
assert isinstance(second_batch.items[0], Document)
|
||||
assert (
|
||||
second_batch.items[0].id == "https://github.com/test-org/test-repo/issues/1"
|
||||
)
|
||||
assert isinstance(second_batch.items[1], Document)
|
||||
assert (
|
||||
second_batch.items[1].id == "https://github.com/test-org/test-repo/issues/2"
|
||||
)
|
||||
assert second_batch.next_checkpoint.has_more
|
||||
|
||||
# Check third batch (finished checkpoint)
|
||||
third_batch = outputs[3]
|
||||
assert len(third_batch.items) == 0
|
||||
assert third_batch.next_checkpoint.has_more is False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_rate_limit(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_pr: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with rate limit handling"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked PR
|
||||
mock_pr = create_mock_pr()
|
||||
|
||||
# Mock get_pulls to raise RateLimitExceededException on first call
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.side_effect = [
|
||||
RateLimitExceededException(403, {"message": "Rate limit exceeded"}, {}),
|
||||
[mock_pr],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock rate limit reset time
|
||||
mock_rate_limit = MagicMock(spec=RateLimit)
|
||||
mock_rate_limit.core.reset = datetime.now(timezone.utc)
|
||||
github_connector.github_client.get_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
with patch(
|
||||
"onyx.connectors.github.connector._sleep_after_rate_limit_exception"
|
||||
) as mock_sleep:
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
assert mock_sleep.call_count == 1
|
||||
|
||||
# Check that we got the document after rate limit was handled
|
||||
assert len(outputs) >= 2
|
||||
assert len(outputs[1].items) == 1
|
||||
assert isinstance(outputs[1].items[0], Document)
|
||||
assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"
|
||||
|
||||
assert outputs[-1].next_checkpoint.has_more is False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_empty_repo(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with an empty repository"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Mock get_pulls and get_issues to return empty lists
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.return_value = []
|
||||
mock_repo.get_issues.return_value = MagicMock()
|
||||
mock_repo.get_issues.return_value.get_page.return_value = []
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we got no documents
|
||||
assert len(outputs) == 2
|
||||
assert len(outputs[-1].items) == 0
|
||||
assert not outputs[-1].next_checkpoint.has_more
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_prs_only(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_pr: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with only PRs enabled"""
|
||||
# Configure connector to only include PRs
|
||||
github_connector.include_prs = True
|
||||
github_connector.include_issues = False
|
||||
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked PRs
|
||||
mock_pr1 = create_mock_pr(number=1, title="PR 1")
|
||||
mock_pr2 = create_mock_pr(number=2, title="PR 2")
|
||||
|
||||
# Mock get_pulls method
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.side_effect = [
|
||||
[mock_pr1, mock_pr2],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we only got PRs
|
||||
assert len(outputs) >= 2
|
||||
assert len(outputs[1].items) == 2
|
||||
assert all(
|
||||
isinstance(doc, Document) and "pull" in doc.id for doc in outputs[0].items
|
||||
) # All documents should be PRs
|
||||
|
||||
assert outputs[-1].next_checkpoint.has_more is False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_issues_only(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_issue: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with only issues enabled"""
|
||||
# Configure connector to only include issues
|
||||
github_connector.include_prs = False
|
||||
github_connector.include_issues = True
|
||||
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked issues
|
||||
mock_issue1 = create_mock_issue(number=1, title="Issue 1")
|
||||
mock_issue2 = create_mock_issue(number=2, title="Issue 2")
|
||||
|
||||
# Mock get_issues method
|
||||
mock_repo.get_issues.return_value = MagicMock()
|
||||
mock_repo.get_issues.return_value.get_page.side_effect = [
|
||||
[mock_issue1, mock_issue2],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we only got issues
|
||||
assert len(outputs) >= 2
|
||||
assert len(outputs[1].items) == 2
|
||||
assert all(
|
||||
isinstance(doc, Document) and "issues" in doc.id for doc in outputs[0].items
|
||||
) # All documents should be issues
|
||||
assert outputs[1].next_checkpoint.has_more
|
||||
|
||||
assert outputs[-1].next_checkpoint.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code,expected_exception,expected_message",
|
||||
[
|
||||
(
|
||||
401,
|
||||
CredentialExpiredError,
|
||||
"GitHub credential appears to be invalid or expired",
|
||||
),
|
||||
(
|
||||
403,
|
||||
InsufficientPermissionsError,
|
||||
"Your GitHub token does not have sufficient permissions",
|
||||
),
|
||||
(
|
||||
404,
|
||||
ConnectorValidationError,
|
||||
"GitHub repository not found",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_connector_settings_errors(
|
||||
github_connector: GithubConnector,
|
||||
status_code: int,
|
||||
expected_exception: type[Exception],
|
||||
expected_message: str,
|
||||
) -> None:
|
||||
"""Test validation with various error scenarios"""
|
||||
error = GithubException(status=status_code, data={}, headers={})
|
||||
|
||||
github_client = cast(Github, github_connector.github_client)
|
||||
get_repo_mock = cast(MagicMock, github_client.get_repo)
|
||||
get_repo_mock.side_effect = error
|
||||
|
||||
with pytest.raises(expected_exception) as excinfo:
|
||||
github_connector.validate_connector_settings()
|
||||
assert expected_message in str(excinfo.value)
|
||||
|
||||
|
||||
def test_validate_connector_settings_success(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test successful validation"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Mock get_contents to simulate successful access
|
||||
mock_repo.get_contents.return_value = MagicMock()
|
||||
|
||||
github_connector.validate_connector_settings()
|
||||
github_connector.github_client.get_repo.assert_called_once_with(
|
||||
f"{github_connector.repo_owner}/{github_connector.repositories}"
|
||||
)
|
Reference in New Issue
Block a user