Mirror of https://github.com/danswer-ai/danswer.git
Checkpointed GitHub connector (#4307)
* WIP github checkpointing
* first draft of github checkpointing
* nit
* CW comments
* github basic connector test
* connector test env var
* secrets cant start with GITHUB_
* unit tests and bug fix
* connector failures
* address CW comments
* validation fix
* validation fix
* remove prints
* fixed tests
* 100 items per page
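This commit converts the GitHub connector from the load/poll interface to the checkpointed interface, so a long-running index can resume mid-repository after a failure. For orientation, the driver contract works roughly as sketched below. This is an illustrative sketch, not code from this diff: the real loop lives in `_run_indexing`, and `CheckpointConnector`/`CheckpointOutput` are defined in `onyx/connectors/interfaces.py`.

    # Sketch: how a caller drives a CheckpointConnector. load_from_checkpoint
    # yields Document or ConnectorFailure objects and *returns* the next
    # checkpoint, so the caller reads StopIteration.value and keeps calling
    # until the checkpoint reports has_more=False.
    def drain(connector, start, end):
        checkpoint = connector.build_dummy_checkpoint()
        results = []
        while checkpoint.has_more:
            gen = connector.load_from_checkpoint(start, end, checkpoint)
            try:
                while True:
                    results.append(next(gen))
            except StopIteration as stop:
                checkpoint = stop.value  # the generator's return value
        return results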
@@ -45,6 +45,8 @@ env:
   SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
   SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
   SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
+  # Github
+  ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
   # Gitbook
   GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
   GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
@@ -435,7 +435,7 @@ def _run_indexing(

     while checkpoint.has_more:
         logger.info(
-            f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
+            f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
         )
         for document_batch, failure, next_checkpoint in connector_runner.run(
             checkpoint
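The one-line fix above logs `ctx.source.value` instead of `ctx.source`: on recent Python versions, interpolating an enum member renders its qualified name rather than the underlying string. A minimal illustration (the enum here is a stand-in; the real `DocumentSource` lives in `onyx.configs.constants`):

    from enum import Enum

    class DocumentSource(Enum):  # assumed minimal shape for illustration
        GITHUB = "github"

    f"Running '{DocumentSource.GITHUB}' connector"        # "Running 'DocumentSource.GITHUB' connector"
    f"Running '{DocumentSource.GITHUB.value}' connector"  # "Running 'github' connector"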
@@ -157,10 +157,7 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
 VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")

 # Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
-try:
-    INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
-except ValueError:
-    INDEX_BATCH_SIZE = 16
+INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)

 MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))

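The config change above drops the try/except in favor of an `or` fallback. The difference: `os.environ.get("INDEX_BATCH_SIZE", 16)` returns `""` when the variable is set but empty, and `int("")` raises, while `os.environ.get(...) or 16` treats the empty string as falsy and uses the default. Note that a genuinely non-numeric value now raises at import time instead of silently defaulting to 16. A quick check:

    import os

    os.environ["INDEX_BATCH_SIZE"] = ""
    int(os.environ.get("INDEX_BATCH_SIZE") or 16)  # 16: "" is falsy, so the default applies
    # int(os.environ.get("INDEX_BATCH_SIZE", 16)) would raise ValueError: invalid literal for int()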
@@ -1,8 +1,10 @@
+import copy
 import time
-from collections.abc import Iterator
+from collections.abc import Generator
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
+from enum import Enum
 from typing import Any
 from typing import cast

@@ -13,26 +15,30 @@ from github.GithubException import GithubException
 from github.Issue import Issue
 from github.PaginatedList import PaginatedList
 from github.PullRequest import PullRequest
+from github.Requester import Requester
+from pydantic import BaseModel
+from typing_extensions import override

 from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
-from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.constants import DocumentSource
 from onyx.connectors.exceptions import ConnectorValidationError
 from onyx.connectors.exceptions import CredentialExpiredError
 from onyx.connectors.exceptions import InsufficientPermissionsError
 from onyx.connectors.exceptions import UnexpectedValidationError
-from onyx.connectors.interfaces import GenerateDocumentsOutput
-from onyx.connectors.interfaces import LoadConnector
-from onyx.connectors.interfaces import PollConnector
+from onyx.connectors.interfaces import CheckpointConnector
+from onyx.connectors.interfaces import CheckpointOutput
+from onyx.connectors.interfaces import ConnectorCheckpoint
+from onyx.connectors.interfaces import ConnectorFailure
 from onyx.connectors.interfaces import SecondsSinceUnixEpoch
 from onyx.connectors.models import ConnectorMissingCredentialError
 from onyx.connectors.models import Document
+from onyx.connectors.models import DocumentFailure
 from onyx.connectors.models import TextSection
-from onyx.utils.batching import batch_generator
 from onyx.utils.logger import setup_logger

 logger = setup_logger()

+ITEMS_PER_PAGE = 100

 _MAX_NUM_RATE_LIMIT_RETRIES = 5

@@ -48,7 +54,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:

 def _get_batch_rate_limited(
     git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
-) -> list[Any]:
+) -> list[PullRequest | Issue]:
     if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
         raise RuntimeError(
             "Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
@@ -69,21 +75,6 @@ def _get_batch_rate_limited(
     )


-def _batch_github_objects(
-    git_objs: PaginatedList, github_client: Github, batch_size: int
-) -> Iterator[list[Any]]:
-    page_num = 0
-    while True:
-        batch = _get_batch_rate_limited(git_objs, page_num, github_client)
-        page_num += 1
-
-        if not batch:
-            break
-
-        for mini_batch in batch_generator(batch, batch_size=batch_size):
-            yield mini_batch
-
-
 def _convert_pr_to_document(pull_request: PullRequest) -> Document:
     return Document(
         id=pull_request.html_url,
@@ -95,7 +86,9 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
         # updated_at is UTC time but is timezone unaware, explicitly add UTC
         # as there is logic in indexing to prevent wrong timestamped docs
         # due to local time discrepancies with UTC
-        doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
+        doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
+        if pull_request.updated_at
+        else None,
         metadata={
             "merged": str(pull_request.merged),
             "state": pull_request.state,
@@ -122,31 +115,58 @@ def _convert_issue_to_document(issue: Issue) -> Document:
     )


-class GithubConnector(LoadConnector, PollConnector):
+class SerializedRepository(BaseModel):
+    # id is part of the raw_data as well, just pulled out for convenience
+    id: int
+    headers: dict[str, str | int]
+    raw_data: dict[str, Any]
+
+    def to_Repository(self, requester: Requester) -> Repository.Repository:
+        return Repository.Repository(
+            requester, self.headers, self.raw_data, completed=True
+        )
+
+
+class GithubConnectorStage(Enum):
+    START = "start"
+    PRS = "prs"
+    ISSUES = "issues"
+
+
+class GithubConnectorCheckpoint(ConnectorCheckpoint):
+    stage: GithubConnectorStage
+    curr_page: int
+
+    cached_repo_ids: list[int] | None = None
+    cached_repo: SerializedRepository | None = None
+
+
+class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
     def __init__(
         self,
         repo_owner: str,
         repositories: str | None = None,
-        batch_size: int = INDEX_BATCH_SIZE,
         state_filter: str = "all",
         include_prs: bool = True,
         include_issues: bool = False,
     ) -> None:
         self.repo_owner = repo_owner
         self.repositories = repositories
-        self.batch_size = batch_size
         self.state_filter = state_filter
         self.include_prs = include_prs
         self.include_issues = include_issues
         self.github_client: Github | None = None

     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
+        # defaults to 30 items per page, can be set to as high as 100
         self.github_client = (
             Github(
-                credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
+                credentials["github_access_token"],
+                base_url=GITHUB_CONNECTOR_BASE_URL,
+                per_page=ITEMS_PER_PAGE,
             )
             if GITHUB_CONNECTOR_BASE_URL
-            else Github(credentials["github_access_token"])
+            else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
         )
         return None

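Because `GithubConnectorCheckpoint` is a Pydantic model, it round-trips through JSON between indexing runs (see `validate_checkpoint_json` later in this diff). An illustrative round-trip with made-up field values:

    from onyx.connectors.github.connector import (
        GithubConnectorCheckpoint,
        GithubConnectorStage,
    )

    checkpoint = GithubConnectorCheckpoint(
        stage=GithubConnectorStage.PRS,
        curr_page=3,
        cached_repo_ids=[11, 42],  # illustrative repo ids
        has_more=True,
    )
    blob = checkpoint.model_dump_json()  # what gets persisted between runs
    restored = GithubConnectorCheckpoint.model_validate_json(blob)
    assert restored.stage is GithubConnectorStage.PRS
    assert restored.curr_page == 3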
@@ -217,85 +237,193 @@ class GithubConnector(LoadConnector, PollConnector):
         return self._get_all_repos(github_client, attempt_num + 1)

     def _fetch_from_github(
-        self, start: datetime | None = None, end: datetime | None = None
-    ) -> GenerateDocumentsOutput:
+        self,
+        checkpoint: GithubConnectorCheckpoint,
+        start: datetime | None = None,
+        end: datetime | None = None,
+    ) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
         if self.github_client is None:
             raise ConnectorMissingCredentialError("GitHub")

-        repos = []
-        if self.repositories:
-            if "," in self.repositories:
-                # Multiple repositories specified
-                repos = self._get_github_repos(self.github_client)
-            else:
-                # Single repository (backward compatibility)
-                repos = [self._get_github_repo(self.github_client)]
-        else:
-            # All repositories
-            repos = self._get_all_repos(self.github_client)
-
-        for repo in repos:
-            if self.include_prs:
-                logger.info(f"Fetching PRs for repo: {repo.name}")
-                pull_requests = repo.get_pulls(
-                    state=self.state_filter, sort="updated", direction="desc"
-                )
-
-                for pr_batch in _batch_github_objects(
-                    pull_requests, self.github_client, self.batch_size
-                ):
-                    doc_batch: list[Document] = []
-                    for pr in pr_batch:
-                        if start is not None and pr.updated_at < start:
-                            yield doc_batch
-                            break
-                        if end is not None and pr.updated_at > end:
-                            continue
-                        doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
-                    yield doc_batch
-
-            if self.include_issues:
-                logger.info(f"Fetching issues for repo: {repo.name}")
-                issues = repo.get_issues(
-                    state=self.state_filter, sort="updated", direction="desc"
-                )
-
-                for issue_batch in _batch_github_objects(
-                    issues, self.github_client, self.batch_size
-                ):
-                    doc_batch = []
-                    for issue in issue_batch:
-                        issue = cast(Issue, issue)
-                        if start is not None and issue.updated_at < start:
-                            yield doc_batch
-                            break
-                        if end is not None and issue.updated_at > end:
-                            continue
-                        if issue.pull_request is not None:
-                            # PRs are handled separately
-                            continue
-                        doc_batch.append(_convert_issue_to_document(issue))
-                    yield doc_batch
-
-    def load_from_state(self) -> GenerateDocumentsOutput:
-        return self._fetch_from_github()
-
-    def poll_source(
-        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
-    ) -> GenerateDocumentsOutput:
-        start_datetime = datetime.utcfromtimestamp(start)
-        end_datetime = datetime.utcfromtimestamp(end)
+        checkpoint = copy.deepcopy(checkpoint)
+
+        # First run of the connector, fetch all repos and store in checkpoint
+        if checkpoint.cached_repo_ids is None:
+            repos = []
+            if self.repositories:
+                if "," in self.repositories:
+                    # Multiple repositories specified
+                    repos = self._get_github_repos(self.github_client)
+                else:
+                    # Single repository (backward compatibility)
+                    repos = [self._get_github_repo(self.github_client)]
+            else:
+                # All repositories
+                repos = self._get_all_repos(self.github_client)
+            if not repos:
+                checkpoint.has_more = False
+                return checkpoint
+
+            checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
+            checkpoint.cached_repo = SerializedRepository(
+                id=checkpoint.cached_repo_ids[0],
+                headers=repos[0].raw_headers,
+                raw_data=repos[0].raw_data,
+            )
+            checkpoint.stage = GithubConnectorStage.PRS
+            checkpoint.curr_page = 0
+            # save checkpoint with repo ids retrieved
+            return checkpoint
+
+        assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
+        repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
+
+        if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
+            logger.info(f"Fetching PRs for repo: {repo.name}")
+            pull_requests = repo.get_pulls(
+                state=self.state_filter, sort="updated", direction="desc"
+            )
+
+            doc_batch: list[Document] = []
+            pr_batch = _get_batch_rate_limited(
+                pull_requests, checkpoint.curr_page, self.github_client
+            )
+            checkpoint.curr_page += 1
+            done_with_prs = False
+            for pr in pr_batch:
+                # we iterate backwards in time, so at this point we stop processing prs
+                if (
+                    start is not None
+                    and pr.updated_at
+                    and pr.updated_at.replace(tzinfo=timezone.utc) < start
+                ):
+                    yield from doc_batch
+                    done_with_prs = True
+                    break
+                # Skip PRs updated after the end date
+                if (
+                    end is not None
+                    and pr.updated_at
+                    and pr.updated_at.replace(tzinfo=timezone.utc) > end
+                ):
+                    continue
+                try:
+                    doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
+                except Exception as e:
+                    error_msg = f"Error converting PR to document: {e}"
+                    logger.exception(error_msg)
+                    yield ConnectorFailure(
+                        failed_document=DocumentFailure(
+                            document_id=str(pr.id), document_link=pr.html_url
+                        ),
+                        failure_message=error_msg,
+                        exception=e,
+                    )
+                    continue
+
+            # if we found any PRs on the page, yield any associated documents and return the checkpoint
+            if not done_with_prs and len(pr_batch) > 0:
+                yield from doc_batch
+                return checkpoint
+
+            # if we went past the start date during the loop or there are no more
+            # prs to get, we move on to issues
+            checkpoint.stage = GithubConnectorStage.ISSUES
+            checkpoint.curr_page = 0
+
+        checkpoint.stage = GithubConnectorStage.ISSUES
+
+        if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
+            logger.info(f"Fetching issues for repo: {repo.name}")
+            issues = repo.get_issues(
+                state=self.state_filter, sort="updated", direction="desc"
+            )
+
+            doc_batch = []
+            issue_batch = _get_batch_rate_limited(
+                issues, checkpoint.curr_page, self.github_client
+            )
+            checkpoint.curr_page += 1
+            done_with_issues = False
+            for issue in cast(list[Issue], issue_batch):
+                # we iterate backwards in time, so at this point we stop processing prs
+                if (
+                    start is not None
+                    and issue.updated_at.replace(tzinfo=timezone.utc) < start
+                ):
+                    yield from doc_batch
+                    done_with_issues = True
+                    break
+                # Skip PRs updated after the end date
+                if (
+                    end is not None
+                    and issue.updated_at.replace(tzinfo=timezone.utc) > end
+                ):
+                    continue
+
+                if issue.pull_request is not None:
+                    # PRs are handled separately
+                    continue
+
+                try:
+                    doc_batch.append(_convert_issue_to_document(issue))
+                except Exception as e:
+                    error_msg = f"Error converting issue to document: {e}"
+                    logger.exception(error_msg)
+                    yield ConnectorFailure(
+                        failed_document=DocumentFailure(
+                            document_id=str(issue.id),
+                            document_link=issue.html_url,
+                        ),
+                        failure_message=error_msg,
+                        exception=e,
+                    )
+                    continue
+
+            # if we found any issues on the page, yield them and return the checkpoint
+            if not done_with_issues and len(issue_batch) > 0:
+                yield from doc_batch
+                return checkpoint
+
+            # if we went past the start date during the loop or there are no more
+            # issues to get, we move on to the next repo
+            checkpoint.stage = GithubConnectorStage.PRS
+            checkpoint.curr_page = 0
+
+        checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
+        if checkpoint.cached_repo_ids:
+            next_id = checkpoint.cached_repo_ids.pop()
+            next_repo = self.github_client.get_repo(next_id)
+            checkpoint.cached_repo = SerializedRepository(
+                id=next_id,
+                headers=next_repo.raw_headers,
+                raw_data=next_repo.raw_data,
+            )
+
+        return checkpoint
+
+    @override
+    def load_from_checkpoint(
+        self,
+        start: SecondsSinceUnixEpoch,
+        end: SecondsSinceUnixEpoch,
+        checkpoint: GithubConnectorCheckpoint,
+    ) -> CheckpointOutput[GithubConnectorCheckpoint]:
+        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
+        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

         # Move start time back by 3 hours, since some Issues/PRs are getting dropped
         # Could be due to delayed processing on GitHub side
         # The non-updated issues since last poll will be shortcut-ed and not embedded
         adjusted_start_datetime = start_datetime - timedelta(hours=3)

-        epoch = datetime.utcfromtimestamp(0)
+        epoch = datetime.fromtimestamp(0, tz=timezone.utc)
         if adjusted_start_datetime < epoch:
             adjusted_start_datetime = epoch

-        return self._fetch_from_github(adjusted_start_datetime, end_datetime)
+        return self._fetch_from_github(
+            checkpoint, start=adjusted_start_datetime, end=end_datetime
+        )

     def validate_connector_settings(self) -> None:
         if self.github_client is None:
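Read the new `_fetch_from_github` as a state machine: each call handles at most one API page, then returns an updated checkpoint. A simplified trace for a two-repository connector with PRs and issues enabled (ids and page counts are illustrative):

    # call 1: cached_repo_ids is None -> fetch and cache the repo list, return
    #   out: stage=PRS, curr_page=0, cached_repo_ids=[11, 42], cached_repo=repo(11)
    # call 2..k: one page of PRs per call (curr_page increments); when a page
    #   comes back empty or crosses the start date, stage flips to ISSUES
    # call k+1..m: one page of issues per call, same pattern
    # repo exhausted: stage resets to PRS, curr_page=0, and the next repo id is
    #   popped from cached_repo_ids and serialized into cached_repo
    # no repos left: has_more=False and the run ends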
@@ -397,6 +525,16 @@ class GithubConnector(LoadConnector, PollConnector):
                 f"Unexpected error during GitHub settings validation: {exc}"
             )
+
+    def validate_checkpoint_json(
+        self, checkpoint_json: str
+    ) -> GithubConnectorCheckpoint:
+        return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
+
+    def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
+        return GithubConnectorCheckpoint(
+            stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
+        )


 if __name__ == "__main__":
     import os
@@ -406,7 +544,9 @@ if __name__ == "__main__":
         repositories=os.environ["REPOSITORIES"],
     )
     connector.load_credentials(
-        {"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
+        {"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
+    )
+    document_batches = connector.load_from_checkpoint(
+        0, time.time(), connector.build_dummy_checkpoint()
     )
-    document_batches = connector.load_from_state()
     print(next(document_batches))
@@ -56,7 +56,7 @@ puremagic==1.28
 pyairtable==3.0.1
 pycryptodome==3.19.1
 pydantic==2.8.2
-PyGithub==1.58.2
+PyGithub==2.5.0
 python-dateutil==2.8.2
 python-gitlab==3.9.0
 python-pptx==0.6.23
backend/tests/daily/connectors/github/test_github_basic.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+import os
+import time
+
+import pytest
+
+from onyx.configs.constants import DocumentSource
+from onyx.connectors.github.connector import GithubConnector
+from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
+
+
+@pytest.fixture
+def github_connector() -> GithubConnector:
+    connector = GithubConnector(
+        repo_owner="onyx-dot-app",
+        repositories="documentation",
+        include_prs=True,
+        include_issues=True,
+    )
+    connector.load_credentials(
+        {
+            "github_access_token": os.environ["ACCESS_TOKEN_GITHUB"],
+        }
+    )
+    return connector
+
+
+def test_github_connector_basic(github_connector: GithubConnector) -> None:
+    docs = load_all_docs_from_checkpoint_connector(
+        connector=github_connector,
+        start=0,
+        end=time.time(),
+    )
+    assert len(docs) > 0  # We expect at least one PR to exist
+
+    # Test the first document's structure
+    doc = docs[0]
+
+    # Verify basic document properties
+    assert doc.source == DocumentSource.GITHUB
+    assert doc.secondary_owners is None
+    assert doc.from_ingestion_api is False
+    assert doc.additional_info is None
+
+    # Verify GitHub-specific properties
+    assert "github.com" in doc.id  # Should be a GitHub URL
+    assert doc.metadata is not None
+    assert "state" in doc.metadata
+    assert "merged" in doc.metadata
+
+    # Verify sections
+    assert len(doc.sections) == 1
+    section = doc.sections[0]
+    assert section.link == doc.id  # Section link should match document ID
+    assert isinstance(section.text, str)  # Should have some text content
@@ -0,0 +1,441 @@
+import time
+from collections.abc import Callable
+from collections.abc import Generator
+from datetime import datetime
+from datetime import timezone
+from typing import cast
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+from github import Github
+from github import GithubException
+from github import RateLimitExceededException
+from github.Issue import Issue
+from github.PullRequest import PullRequest
+from github.RateLimit import RateLimit
+from github.Repository import Repository
+from github.Requester import Requester
+
+from onyx.connectors.exceptions import ConnectorValidationError
+from onyx.connectors.exceptions import CredentialExpiredError
+from onyx.connectors.exceptions import InsufficientPermissionsError
+from onyx.connectors.github.connector import GithubConnector
+from onyx.connectors.github.connector import SerializedRepository
+from onyx.connectors.models import Document
+from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
+
+
+@pytest.fixture
+def repo_owner() -> str:
+    return "test-org"
+
+
+@pytest.fixture
+def repositories() -> str:
+    return "test-repo"
+
+
+@pytest.fixture
+def mock_github_client() -> MagicMock:
+    """Create a mock GitHub client with proper typing"""
+    mock = MagicMock(spec=Github)
+    # Add proper return typing for get_repo method
+    mock.get_repo = MagicMock(return_value=MagicMock(spec=Repository))
+    # Add proper return typing for get_organization method
+    mock.get_organization = MagicMock()
+    # Add proper return typing for get_user method
+    mock.get_user = MagicMock()
+    # Add proper return typing for get_rate_limit method
+    mock.get_rate_limit = MagicMock(return_value=MagicMock(spec=RateLimit))
+    # Add requester for repository deserialization
+    mock.requester = MagicMock(spec=Requester)
+    return mock
+
+
+@pytest.fixture
+def github_connector(
+    repo_owner: str, repositories: str, mock_github_client: MagicMock
+) -> Generator[GithubConnector, None, None]:
+    connector = GithubConnector(
+        repo_owner=repo_owner,
+        repositories=repositories,
+        include_prs=True,
+        include_issues=True,
+    )
+    connector.github_client = mock_github_client
+    yield connector
+
+
+@pytest.fixture
+def create_mock_pr() -> Callable[..., MagicMock]:
+    def _create_mock_pr(
+        number: int = 1,
+        title: str = "Test PR",
+        body: str = "Test Description",
+        state: str = "open",
+        merged: bool = False,
+        updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
+    ) -> MagicMock:
+        """Helper to create a mock PullRequest object"""
+        mock_pr = MagicMock(spec=PullRequest)
+        mock_pr.number = number
+        mock_pr.title = title
+        mock_pr.body = body
+        mock_pr.state = state
+        mock_pr.merged = merged
+        mock_pr.updated_at = updated_at
+        mock_pr.html_url = f"https://github.com/test-org/test-repo/pull/{number}"
+        return mock_pr
+
+    return _create_mock_pr
+
+
+@pytest.fixture
+def create_mock_issue() -> Callable[..., MagicMock]:
+    def _create_mock_issue(
+        number: int = 1,
+        title: str = "Test Issue",
+        body: str = "Test Description",
+        state: str = "open",
+        updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
+    ) -> MagicMock:
+        """Helper to create a mock Issue object"""
+        mock_issue = MagicMock(spec=Issue)
+        mock_issue.number = number
+        mock_issue.title = title
+        mock_issue.body = body
+        mock_issue.state = state
+        mock_issue.updated_at = updated_at
+        mock_issue.html_url = f"https://github.com/test-org/test-repo/issues/{number}"
+        mock_issue.pull_request = None  # Not a PR
+        return mock_issue
+
+    return _create_mock_issue
+
+
+@pytest.fixture
+def create_mock_repo() -> Callable[..., MagicMock]:
+    def _create_mock_repo(
+        name: str = "test-repo",
+        id: int = 1,
+    ) -> MagicMock:
+        """Helper to create a mock Repository object"""
+        mock_repo = MagicMock(spec=Repository)
+        mock_repo.name = name
+        mock_repo.id = id
+        mock_repo.raw_headers = {"status": "200 OK", "content-type": "application/json"}
+        mock_repo.raw_data = {
+            "id": str(id),
+            "name": name,
+            "full_name": f"test-org/{name}",
+            "private": str(False),
+            "description": "Test repository",
+        }
+        return mock_repo
+
+    return _create_mock_repo
+
+
+def test_load_from_checkpoint_happy_path(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+    create_mock_issue: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint - happy path"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PRs and issues
+    mock_pr1 = create_mock_pr(number=1, title="PR 1")
+    mock_pr2 = create_mock_pr(number=2, title="PR 2")
+    mock_issue1 = create_mock_issue(number=1, title="Issue 1")
+    mock_issue2 = create_mock_issue(number=2, title="Issue 2")
+
+    # Mock get_pulls and get_issues methods
+    mock_repo.get_pulls.return_value = MagicMock()
+    mock_repo.get_pulls.return_value.get_page.side_effect = [
+        [mock_pr1, mock_pr2],
+        [],
+    ]
+    mock_repo.get_issues.return_value = MagicMock()
+    mock_repo.get_issues.return_value.get_page.side_effect = [
+        [mock_issue1, mock_issue2],
+        [],
+    ]
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(
+            github_connector, 0, end_time
+        )
+
+    # Check that we got all documents and final has_more=False
+    assert len(outputs) == 4
+
+    repo_batch = outputs[0]
+    assert len(repo_batch.items) == 0
+    assert repo_batch.next_checkpoint.has_more is True
+
+    # Check first batch (PRs)
+    first_batch = outputs[1]
+    assert len(first_batch.items) == 2
+    assert isinstance(first_batch.items[0], Document)
+    assert first_batch.items[0].id == "https://github.com/test-org/test-repo/pull/1"
+    assert isinstance(first_batch.items[1], Document)
+    assert first_batch.items[1].id == "https://github.com/test-org/test-repo/pull/2"
+    assert first_batch.next_checkpoint.curr_page == 1
+
+    # Check second batch (Issues)
+    second_batch = outputs[2]
+    assert len(second_batch.items) == 2
+    assert isinstance(second_batch.items[0], Document)
+    assert (
+        second_batch.items[0].id == "https://github.com/test-org/test-repo/issues/1"
+    )
+    assert isinstance(second_batch.items[1], Document)
+    assert (
+        second_batch.items[1].id == "https://github.com/test-org/test-repo/issues/2"
+    )
+    assert second_batch.next_checkpoint.has_more
+
+    # Check third batch (finished checkpoint)
+    third_batch = outputs[3]
+    assert len(third_batch.items) == 0
+    assert third_batch.next_checkpoint.has_more is False
+
+
+def test_load_from_checkpoint_with_rate_limit(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint with rate limit handling"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PR
+    mock_pr = create_mock_pr()
+
+    # Mock get_pulls to raise RateLimitExceededException on first call
+    mock_repo.get_pulls.return_value = MagicMock()
+    mock_repo.get_pulls.return_value.get_page.side_effect = [
+        RateLimitExceededException(403, {"message": "Rate limit exceeded"}, {}),
+        [mock_pr],
+        [],
+    ]
+
+    # Mock rate limit reset time
+    mock_rate_limit = MagicMock(spec=RateLimit)
+    mock_rate_limit.core.reset = datetime.now(timezone.utc)
+    github_connector.github_client.get_rate_limit.return_value = mock_rate_limit
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        with patch(
+            "onyx.connectors.github.connector._sleep_after_rate_limit_exception"
+        ) as mock_sleep:
+            outputs = load_everything_from_checkpoint_connector(
+                github_connector, 0, end_time
+            )
+
+            assert mock_sleep.call_count == 1
+
+    # Check that we got the document after rate limit was handled
+    assert len(outputs) >= 2
+    assert len(outputs[1].items) == 1
+    assert isinstance(outputs[1].items[0], Document)
+    assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"
+
+    assert outputs[-1].next_checkpoint.has_more is False
+
+
+def test_load_from_checkpoint_with_empty_repo(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint with an empty repository"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Mock get_pulls and get_issues to return empty lists
+    mock_repo.get_pulls.return_value = MagicMock()
+    mock_repo.get_pulls.return_value.get_page.return_value = []
+    mock_repo.get_issues.return_value = MagicMock()
+    mock_repo.get_issues.return_value.get_page.return_value = []
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(
+            github_connector, 0, end_time
+        )
+
+    # Check that we got no documents
+    assert len(outputs) == 2
+    assert len(outputs[-1].items) == 0
+    assert not outputs[-1].next_checkpoint.has_more
+
+
+def test_load_from_checkpoint_with_prs_only(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_pr: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint with only PRs enabled"""
+    # Configure connector to only include PRs
+    github_connector.include_prs = True
+    github_connector.include_issues = False
+
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked PRs
+    mock_pr1 = create_mock_pr(number=1, title="PR 1")
+    mock_pr2 = create_mock_pr(number=2, title="PR 2")
+
+    # Mock get_pulls method
+    mock_repo.get_pulls.return_value = MagicMock()
+    mock_repo.get_pulls.return_value.get_page.side_effect = [
+        [mock_pr1, mock_pr2],
+        [],
+    ]
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(
+            github_connector, 0, end_time
+        )
+
+    # Check that we only got PRs
+    assert len(outputs) >= 2
+    assert len(outputs[1].items) == 2
+    assert all(
+        isinstance(doc, Document) and "pull" in doc.id for doc in outputs[0].items
+    )  # All documents should be PRs
+
+    assert outputs[-1].next_checkpoint.has_more is False
+
+
+def test_load_from_checkpoint_with_issues_only(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+    create_mock_issue: Callable[..., MagicMock],
+) -> None:
+    """Test loading from checkpoint with only issues enabled"""
+    # Configure connector to only include issues
+    github_connector.include_prs = False
+    github_connector.include_issues = True
+
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Set up mocked issues
+    mock_issue1 = create_mock_issue(number=1, title="Issue 1")
+    mock_issue2 = create_mock_issue(number=2, title="Issue 2")
+
+    # Mock get_issues method
+    mock_repo.get_issues.return_value = MagicMock()
+    mock_repo.get_issues.return_value.get_page.side_effect = [
+        [mock_issue1, mock_issue2],
+        [],
+    ]
+
+    # Mock SerializedRepository.to_Repository to return our mock repo
+    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
+        # Call load_from_checkpoint
+        end_time = time.time()
+        outputs = load_everything_from_checkpoint_connector(
+            github_connector, 0, end_time
+        )
+
+    # Check that we only got issues
+    assert len(outputs) >= 2
+    assert len(outputs[1].items) == 2
+    assert all(
+        isinstance(doc, Document) and "issues" in doc.id for doc in outputs[0].items
+    )  # All documents should be issues
+    assert outputs[1].next_checkpoint.has_more
+
+    assert outputs[-1].next_checkpoint.has_more is False
+
+
+@pytest.mark.parametrize(
+    "status_code,expected_exception,expected_message",
+    [
+        (
+            401,
+            CredentialExpiredError,
+            "GitHub credential appears to be invalid or expired",
+        ),
+        (
+            403,
+            InsufficientPermissionsError,
+            "Your GitHub token does not have sufficient permissions",
+        ),
+        (
+            404,
+            ConnectorValidationError,
+            "GitHub repository not found",
+        ),
+    ],
+)
+def test_validate_connector_settings_errors(
+    github_connector: GithubConnector,
+    status_code: int,
+    expected_exception: type[Exception],
+    expected_message: str,
+) -> None:
+    """Test validation with various error scenarios"""
+    error = GithubException(status=status_code, data={}, headers={})
+
+    github_client = cast(Github, github_connector.github_client)
+    get_repo_mock = cast(MagicMock, github_client.get_repo)
+    get_repo_mock.side_effect = error
+
+    with pytest.raises(expected_exception) as excinfo:
+        github_connector.validate_connector_settings()
+    assert expected_message in str(excinfo.value)
+
+
+def test_validate_connector_settings_success(
+    github_connector: GithubConnector,
+    mock_github_client: MagicMock,
+    create_mock_repo: Callable[..., MagicMock],
+) -> None:
+    """Test successful validation"""
+    # Set up mocked repo
+    mock_repo = create_mock_repo()
+    github_connector.github_client = mock_github_client
+    mock_github_client.get_repo.return_value = mock_repo
+
+    # Mock get_contents to simulate successful access
+    mock_repo.get_contents.return_value = MagicMock()
+
+    github_connector.validate_connector_settings()
+    github_connector.github_client.get_repo.assert_called_once_with(
+        f"{github_connector.repo_owner}/{github_connector.repositories}"
+    )
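The unit tests above depend on `load_everything_from_checkpoint_connector` from `tests/unit/onyx/connectors/utils`, which is not part of this diff. As implied by the assertions, it behaves roughly like the sketch below; the wrapper type and field names (`items`, `next_checkpoint`) follow the tests' usage and are otherwise assumptions:

    # Hedged sketch of the test helper implied by the assertions: it calls
    # load_from_checkpoint repeatedly, recording the items yielded and the
    # checkpoint returned on each call, until has_more goes False.
    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class SingleCallOutput:  # hypothetical name for illustration
        items: list[Any]      # Documents and ConnectorFailures yielded
        next_checkpoint: Any  # checkpoint returned by the generator

    def load_everything_from_checkpoint_connector(connector, start, end):
        checkpoint = connector.build_dummy_checkpoint()
        outputs = []
        while checkpoint.has_more:
            gen = connector.load_from_checkpoint(start, end, checkpoint)
            items: list[Any] = []
            try:
                while True:
                    items.append(next(gen))
            except StopIteration as stop:
                checkpoint = stop.value
            outputs.append(SingleCallOutput(items=items, next_checkpoint=checkpoint))
        return outputs

Under this reading, the happy-path test's four outputs are: the repo-caching call (no items), one page of PRs, one page of issues, and the final empty page that flips has_more to False.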