WIP rebased

Evan Lohn 2025-03-12 16:38:10 -07:00
parent 463340b8a1
commit ad5136941d
20 changed files with 1128 additions and 366 deletions

View File

@ -1,3 +1,4 @@
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from io import BytesIO
@ -71,6 +72,7 @@ def get_latest_valid_checkpoint(
search_settings_id: int,
window_start: datetime,
window_end: datetime,
build_dummy_checkpoint: Callable[[], ConnectorCheckpoint],
) -> ConnectorCheckpoint:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
@ -105,7 +107,7 @@ def get_latest_valid_checkpoint(
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return ConnectorCheckpoint.build_dummy_checkpoint()
return build_dummy_checkpoint()
# assumes the latest checkpoint is the furthest along. This is only untrue
# if something else has gone wrong.
@ -113,7 +115,7 @@ def get_latest_valid_checkpoint(
checkpoint_candidates[0] if checkpoint_candidates else None
)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = build_dummy_checkpoint()
if latest_valid_checkpoint_candidate:
try:
previous_checkpoint = load_checkpoint(
@ -193,7 +195,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
"""Check if the checkpoint content size exceeds the limit (200MB)"""
content_size = deep_getsizeof(checkpoint.checkpoint_content)
content_size = deep_getsizeof(checkpoint.model_dump())
if content_size > 200_000_000: # 200MB in bytes
raise ValueError(
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"

View File

@ -24,7 +24,6 @@ from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
@ -405,7 +404,7 @@ def _run_indexing(
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling.
if index_attempt.from_beginning:
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
checkpoint = get_latest_valid_checkpoint(
db_session=db_session_temp,
@ -413,6 +412,7 @@ def _run_indexing(
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
build_dummy_checkpoint=connector_runner.connector.build_dummy_checkpoint,
)
unresolved_errors = get_index_attempt_errors_for_cc_pair(

View File

@ -132,7 +132,7 @@ class ConnectorRunner:
)
else:
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
finished_checkpoint = self.connector.build_dummy_checkpoint()
finished_checkpoint.has_more = False
if isinstance(self.connector, PollConnector):

View File

@ -1,13 +1,17 @@
import copy
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from typing_extensions import override
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@ -23,12 +27,16 @@ from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import DriveRetrievalStage
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.google_auth import get_google_creds
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.google_utils import GoogleFields
from onyx.connectors.google_utils.resources import get_admin_service
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.google_utils.resources import get_google_docs_service
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
@ -36,21 +44,27 @@ from onyx.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_S
from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from onyx.connectors.google_utils.shared_constants import USER_FIELDS
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import EntityFailure
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
from onyx.utils.threadpool_concurrency import parallel_yield
from onyx.utils.threadpool_concurrency import ThreadSafeDict
logger = setup_logger()
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
# All file retrievals could be batched and made at once
BATCHES_PER_CHECKPOINT = 10
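
A minimal sketch of the batching idea in the TODO above, not part of this commit. The helper name is hypothetical; it assumes the standard google-api-python-client batch API (new_batch_http_request / batch.add / batch.execute) and the Drive batch limit of 100 sub-requests.

def fetch_files_batched(
    service: GoogleDriveService,
    file_ids: list[str],
    fields: str = "id, name, modifiedTime",
) -> dict[str, dict]:
    results: dict[str, dict] = {}

    def _collect(request_id: str, response: dict, exception: Exception | None) -> None:
        # invoked once per sub-request when the batch executes
        if exception is None:
            results[request_id] = response
        else:
            logger.warning(f"Batch sub-request {request_id} failed: {exception}")

    for i in range(0, len(file_ids), 100):
        batch = service.new_batch_http_request(callback=_collect)
        for file_id in file_ids[i : i + 100]:
            batch.add(
                service.files().get(fileId=file_id, fields=fields),
                request_id=file_id,
            )
        batch.execute()
    return results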
def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
if not string:
@ -66,10 +80,16 @@ def _convert_single_file(
creds: Any,
primary_admin_email: str,
file: dict[str, Any],
) -> Any:
) -> Document | ConnectorFailure | None:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
docs_service = get_google_docs_service(creds, user_email=user_email)
# Only construct these services when needed
user_drive_service = lazy_eval(
lambda: get_drive_service(creds, user_email=user_email)
)
docs_service = lazy_eval(
lambda: get_google_docs_service(creds, user_email=user_email)
)
return convert_drive_item_to_document(
file=file,
drive_service=user_drive_service,
@ -77,23 +97,6 @@ def _convert_single_file(
)
def _process_files_batch(
files: list[GoogleDriveFileType],
convert_func: Callable[[GoogleDriveFileType], Any],
batch_size: int,
) -> GenerateDocumentsOutput:
doc_batch = []
with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor:
for doc in executor.map(convert_func, files):
if doc:
doc_batch.append(doc)
if len(doc_batch) >= batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def _clean_requested_drive_ids(
requested_drive_ids: set[str],
requested_folder_ids: set[str],
@ -112,7 +115,7 @@ def _clean_requested_drive_ids(
return valid_requested_drive_ids, filtered_folder_ids
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpoint]):
def __init__(
self,
include_shared_drives: bool = False,
@ -145,13 +148,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if continue_on_failure is not None:
logger.warning("The 'continue_on_failure' parameter is deprecated.")
if (
not include_shared_drives
and not include_my_drives
and not include_files_shared_with_me
and not shared_folder_urls
and not my_drive_emails
and not shared_drive_urls
if not any(
(
include_shared_drives,
include_my_drives,
include_files_shared_with_me,
shared_folder_urls,
my_drive_emails,
shared_drive_urls,
)
):
raise ConnectorValidationError(
"Nothing to index. Please specify at least one of the following: "
@ -221,15 +226,12 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
)
return self._creds
# TODO: ensure returned new_creds_dict is actually persisted when this is called?
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
try:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
except KeyError:
raise ValueError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
raise ValueError("Credentials json missing primary admin key")
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
@ -238,6 +240,25 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
return new_creds_dict
def _checkpoint_yield(
self,
drive_files: Iterator[GoogleDriveFileType],
checkpoint: GoogleDriveCheckpoint,
key: Callable[
[GoogleDriveCheckpoint], str
] = lambda check: check.curr_completion_key,
) -> Iterator[GoogleDriveFileType]:
"""
Wraps a file iterator with a checkpoint to record all the files that have been retrieved.
The key function extracts the completion_map key under which the latest completion time is recorded;
it defaults to curr_completion_key, which works when that key is set before a synchronous (non-parallel) retrieval.
"""
for drive_file in drive_files:
checkpoint.completion_map[key(checkpoint)] = datetime.fromisoformat(
drive_file[GoogleFields.MODIFIED_TIME.value]
).timestamp()
yield drive_file
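
A short usage sketch, not part of this commit: wrapping a per-drive iterator with a drive-specific completion key, similar to the key lambdas used in the service-account flow below. The helper and its arguments are hypothetical.

def wrap_drive_files_for_checkpointing(
    connector: "GoogleDriveConnector",
    checkpoint: GoogleDriveCheckpoint,
    drive_files: Iterator[GoogleDriveFileType],
    user_email: str,
    drive_id: str,
) -> Iterator[GoogleDriveFileType]:
    # records the latest modifiedTime seen under the "<email>@<drive_id>" completion key
    return connector._checkpoint_yield(
        drive_files=drive_files,
        checkpoint=checkpoint,
        key=lambda check: f"{user_email}@{drive_id}",
    )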
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._retrieved_ids.add(folder_id)
@ -286,7 +307,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if not all_drive_ids:
logger.warning(
"No drives found even though we are indexing shared drives was requested."
"No drives found even though indexing shared drives was requested."
)
return all_drive_ids
@ -295,6 +316,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
self,
user_email: str,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
filtered_drive_ids: set[str],
filtered_folder_ids: set[str],
start: SecondsSinceUnixEpoch | None = None,
@ -330,8 +352,11 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
checkpoint=checkpoint,
start=start,
end=end,
key=lambda check: user_email
+ "@my_drive", # completion map keyed by user email
)
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
@ -341,15 +366,24 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
service=drive_service,
drive_id=drive_id,
is_slim=is_slim,
checkpoint=checkpoint,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
key=lambda check: user_email
+ "@"
+ drive_id, # completion map keyed by drive id
)
# I believe there may be some duplication here,
# i.e. if two users have access to the same folder
# and are retrieving in parallel.
remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
yield from crawl_folders_for_files(
is_slim=is_slim,
checkpoint=checkpoint,
service=drive_service,
parent_id=folder_id,
traversed_parent_ids=self._retrieved_ids,
@ -361,25 +395,27 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def _manage_service_account_retrieval(
self,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
all_org_emails: list[str] = self._get_all_user_emails()
if checkpoint.completion_stage == DriveRetrievalStage.START:
checkpoint.completion_stage = DriveRetrievalStage.USER_EMAILS
all_drive_ids: set[str] = self.get_all_drive_ids()
if checkpoint.completion_stage == DriveRetrievalStage.USER_EMAILS:
all_org_emails: list[str] = self._get_all_user_emails()
if not is_slim:
checkpoint.user_emails = all_org_emails
checkpoint.completion_stage = DriveRetrievalStage.DRIVE_IDS
else:
assert checkpoint.user_emails is not None, "user emails not set"
all_org_emails = checkpoint.user_emails
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
drive_ids_to_retrieve, folder_ids_to_retrieve = self._determine_retrieval_ids(
checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
)
# checkpoint - we've found all users and drives, now time to actually start
# we've found all users and drives, now time to actually start
# fetching stuff
logger.info(f"Found {len(all_org_emails)} users to impersonate")
logger.debug(f"Users: {all_org_emails}")
@ -388,24 +424,19 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
logger.debug(f"Folders: {folder_ids_to_retrieve}")
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
executor.submit(
self._impersonate_user_for_retrieval,
email,
is_slim,
drive_ids_to_retrieve,
folder_ids_to_retrieve,
start,
end,
): email
for email in all_org_emails
}
# Yield results as they complete
for future in as_completed(future_to_email):
yield from future.result()
user_retrieval_gens = [
self._impersonate_user_for_retrieval(
email,
is_slim,
checkpoint,
drive_ids_to_retrieve,
folder_ids_to_retrieve,
start,
end,
)
for email in all_org_emails
]
yield from parallel_yield(user_retrieval_gens, max_workers=10)
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
@ -415,74 +446,123 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)
def _manage_oauth_retrieval(
def _determine_retrieval_ids(
self,
checkpoint: GoogleDriveCheckpoint,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, self.primary_admin_email)
if self.include_files_shared_with_me or self.include_my_drives:
logger.info(
f"Getting shared files/my drive files for OAuth "
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
f"include_my_drives={self.include_my_drives}, "
f"include_shared_drives={self.include_shared_drives}."
f"Using '{self.primary_admin_email}' as the account."
)
yield from get_all_files_for_oauth(
service=drive_service,
include_files_shared_with_me=self.include_files_shared_with_me,
include_my_drives=self.include_my_drives,
include_shared_drives=self.include_shared_drives,
is_slim=is_slim,
start=start,
end=end,
)
all_requested = (
self.include_files_shared_with_me
and self.include_my_drives
and self.include_shared_drives
)
if all_requested:
# If all 3 are true, we already yielded from get_all_files_for_oauth
return
next_stage: DriveRetrievalStage,
) -> tuple[set[str], set[str]]:
all_drive_ids = self.get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
if checkpoint.completion_stage == DriveRetrievalStage.DRIVE_IDS:
if self._requested_shared_drive_ids or self._requested_folder_ids:
(
drive_ids_to_retrieve,
folder_ids_to_retrieve,
) = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
if not is_slim:
checkpoint.drive_ids_to_retrieve = list(drive_ids_to_retrieve)
checkpoint.folder_ids_to_retrieve = list(folder_ids_to_retrieve)
checkpoint.completion_stage = next_stage
else:
assert (
checkpoint.drive_ids_to_retrieve is not None
), "drive ids to retrieve not set"
assert (
checkpoint.folder_ids_to_retrieve is not None
), "folder ids to retrieve not set"
# When loading from a checkpoint, load the previously cached drive and folder ids
drive_ids_to_retrieve = set(checkpoint.drive_ids_to_retrieve)
folder_ids_to_retrieve = set(checkpoint.folder_ids_to_retrieve)
return drive_ids_to_retrieve, folder_ids_to_retrieve
def _oauth_retrieval_all_files(
self,
is_slim: bool,
drive_service: GoogleDriveService,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
if not self.include_files_shared_with_me and not self.include_my_drives:
return
logger.info(
f"Getting shared files/my drive files for OAuth "
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
f"include_my_drives={self.include_my_drives}, "
f"include_shared_drives={self.include_shared_drives}."
f"Using '{self.primary_admin_email}' as the account."
)
yield from get_all_files_for_oauth(
service=drive_service,
include_files_shared_with_me=self.include_files_shared_with_me,
include_my_drives=self.include_my_drives,
include_shared_drives=self.include_shared_drives,
is_slim=is_slim,
checkpoint=checkpoint,
start=start,
end=end,
)
def _oauth_retrieval_drives(
self,
is_slim: bool,
drive_service: GoogleDriveService,
drive_ids_to_retrieve: set[str],
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
for drive_id in drive_ids_to_retrieve:
logger.info(
f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'"
)
checkpoint.curr_completion_key = drive_id
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
is_slim=is_slim,
checkpoint=checkpoint,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
def _oauth_retrieval_folders(
self,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
drive_service: GoogleDriveService,
drive_ids_to_retrieve: set[str],
folder_ids_to_retrieve: set[str],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
# Even if no folders were requested, we still check if any drives were requested
# that could be folders.
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
# the times stored in the completion_map aren't used due to the crawling behavior
# instead, the traversed_parent_ids are used to determine what we have left to retrieve
checkpoint.curr_completion_key = checkpoint.completion_stage
for folder_id in remaining_folders:
logger.info(
f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'"
)
yield from crawl_folders_for_files(
is_slim=is_slim,
checkpoint=checkpoint,
service=drive_service,
parent_id=folder_id,
traversed_parent_ids=self._retrieved_ids,
@ -499,46 +579,163 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)
def _fetch_drive_items(
def _checkpointed_oauth_retrieval(
self,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_files = self._manage_oauth_retrieval(
is_slim=is_slim,
checkpoint=checkpoint,
start=start,
end=end,
)
if is_slim:
return drive_files
return self._checkpoint_yield(
drive_files=drive_files,
checkpoint=checkpoint,
)
def _manage_oauth_retrieval(
self,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
if checkpoint.completion_stage == DriveRetrievalStage.START:
checkpoint.completion_stage = DriveRetrievalStage.OAUTH_FILES
drive_service = get_drive_service(self.creds, self.primary_admin_email)
if checkpoint.completion_stage == DriveRetrievalStage.OAUTH_FILES:
checkpoint.curr_completion_key = checkpoint.completion_stage
yield from self._oauth_retrieval_all_files(
drive_service=drive_service,
is_slim=is_slim,
checkpoint=checkpoint,
start=start,
end=end,
)
checkpoint.completion_stage = DriveRetrievalStage.DRIVE_IDS
all_requested = (
self.include_files_shared_with_me
and self.include_my_drives
and self.include_shared_drives
)
if all_requested:
# If all 3 are true, we already yielded from get_all_files_for_oauth
checkpoint.completion_stage = DriveRetrievalStage.DONE
return
drive_ids_to_retrieve, folder_ids_to_retrieve = self._determine_retrieval_ids(
checkpoint, is_slim, DriveRetrievalStage.SHARED_DRIVE_FILES
)
if checkpoint.completion_stage == DriveRetrievalStage.SHARED_DRIVE_FILES:
yield from self._oauth_retrieval_drives(
is_slim=is_slim,
drive_service=drive_service,
drive_ids_to_retrieve=drive_ids_to_retrieve,
checkpoint=checkpoint,
start=start,
end=end,
)
checkpoint.completion_stage = DriveRetrievalStage.FOLDER_FILES
if checkpoint.completion_stage == DriveRetrievalStage.FOLDER_FILES:
checkpoint.curr_completion_key = checkpoint.completion_stage
yield from self._oauth_retrieval_folders(
is_slim=is_slim,
drive_service=drive_service,
drive_ids_to_retrieve=drive_ids_to_retrieve,
folder_ids_to_retrieve=folder_ids_to_retrieve,
checkpoint=checkpoint,
start=start,
end=end,
)
checkpoint.completion_stage = DriveRetrievalStage.DONE
def _fetch_drive_items(
self,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
assert checkpoint is not None, "Must provide checkpoint for full retrieval"
retrieval_method = (
self._manage_service_account_retrieval
if isinstance(self.creds, ServiceAccountCredentials)
else self._manage_oauth_retrieval
else self._checkpointed_oauth_retrieval
)
drive_files = retrieval_method(
return retrieval_method(
is_slim=is_slim,
checkpoint=checkpoint,
start=start,
end=end,
)
return drive_files
def _extract_docs_from_google_drive(
self,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
# Create a larger process pool for file conversion
with ThreadPoolExecutor(max_workers=8) as executor:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
)
) -> Iterator[list[Document | ConnectorFailure]]:
try:
# Create a larger thread pool for file conversion
with ThreadPoolExecutor(max_workers=8) as executor:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
)
# Fetch files in batches
files_batch: list[GoogleDriveFileType] = []
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
files_batch.append(file)
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
for file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
files_batch.append(file)
if len(files_batch) >= self.batch_size:
# Process the batch
if len(files_batch) >= self.batch_size:
# Process the batch
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
logger.error(f"Error converting file: {e}")
if documents:
yield documents
batches_complete += 1
files_batch = []
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_ids = list(self._retrieved_ids)
return # create a new checkpoint
# Process any remaining files
if files_batch:
futures = [
executor.submit(convert_func, file) for file in files_batch
]
@ -553,40 +750,52 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if documents:
yield documents
files_batch = []
except Exception as e:
logger.exception(f"Error extracting documents from Google Drive: {e}")
if MISSING_SCOPES_ERROR_STR in str(e):
raise e
yield [
ConnectorFailure(
failed_entity=EntityFailure(
entity_id=checkpoint.curr_completion_key,
missed_time_range=(
datetime.fromtimestamp(start or 0),
datetime.fromtimestamp(end or 0),
),
),
failure_message=f"Error extracting documents from Google Drive: {e}",
exception=e,
)
]
# Process any remaining files
if files_batch:
futures = [executor.submit(convert_func, file) for file in files_batch]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
logger.error(f"Error converting file: {e}")
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GoogleDriveCheckpoint,
) -> Generator[Document | ConnectorFailure, None, GoogleDriveCheckpoint]:
"""
Entrypoint for the connector; first run is with an empty checkpoint.
"""
if self._creds is None or self._primary_admin_email is None:
raise RuntimeError(
"Credentials missing, should not call this method before calling load_credentials"
)
if documents:
yield documents
def load_from_state(self) -> GenerateDocumentsOutput:
checkpoint = copy.deepcopy(checkpoint)
self._retrieved_ids = set(checkpoint.retrieved_ids)
try:
yield from self._extract_docs_from_google_drive()
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
try:
yield from self._extract_docs_from_google_drive(start, end)
for doc_list in self._extract_docs_from_google_drive(
checkpoint, start, end
):
yield from doc_list
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
if checkpoint.completion_stage == DriveRetrievalStage.DONE:
checkpoint.has_more = False
return checkpoint
def _extract_slim_docs_from_google_drive(
self,
@ -596,6 +805,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
) -> GenerateSlimDocumentOutput:
slim_batch = []
for file in self._fetch_drive_items(
checkpoint=self.build_dummy_checkpoint(),
is_slim=True,
start=start,
end=end,
@ -676,3 +886,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
raise ConnectorValidationError(
f"Unexpected error during Google Drive validation: {e}"
)
@override
def build_dummy_checkpoint(self) -> GoogleDriveCheckpoint:
return GoogleDriveCheckpoint(
prev_run_doc_ids=[],
retrieved_ids=[],
completion_stage=DriveRetrievalStage.START,
curr_completion_key="",
completion_map=ThreadSafeDict(),
has_more=True,
)

View File

@ -1,4 +1,5 @@
import io
from collections.abc import Callable
from datetime import datetime
from typing import cast
@ -13,7 +14,9 @@ from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_drive.section_extraction import get_document_sections
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
@ -202,26 +205,26 @@ def _extract_sections_basic(
def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
) -> Document | None:
drive_service: Callable[[], GoogleDriveService],
docs_service: Callable[[], GoogleDocsService],
) -> Document | ConnectorFailure | None:
"""
Main entry point for converting a Google Drive file => Document object.
"""
try:
# skip shortcuts or folders
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
logger.info("Skipping shortcut/folder.")
return None
# If it's a Google Doc, we might do advanced parsing
sections: list[TextSection | ImageSection] = []
# Try to get sections using the advanced method first
# If it's a Google Doc, we might do advanced parsing
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
# get_document_sections is the advanced approach for Google Docs
doc_sections = get_document_sections(
docs_service=docs_service, doc_id=file.get("id", "")
docs_service=docs_service(), doc_id=file.get("id", "")
)
if doc_sections:
sections = cast(list[TextSection | ImageSection], doc_sections)
@ -232,7 +235,7 @@ def convert_drive_item_to_document(
# If we don't have sections yet, use the basic extraction method
if not sections:
sections = _extract_sections_basic(file, drive_service)
sections = _extract_sections_basic(file, drive_service())
# If we still don't have any sections, skip this file
if not sections:
@ -257,8 +260,19 @@ def convert_drive_item_to_document(
),
)
except Exception as e:
logger.error(f"Error converting file {file.get('name')}: {e}")
return None
error_str = f"Error converting file '{file.get('name')}' to Document: {e}"
logger.exception(error_str)
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=sections[0].link
if sections
else None, # TODO: see if this is the best way to get a link
),
failed_entity=None,
failure_message=error_str,
exception=e,
)
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:

View File

@ -1,14 +1,17 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from typing import Any
from googleapiclient.discovery import Resource # type: ignore
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.google_utils import GoogleFields
from onyx.connectors.google_utils.google_utils import ORDER_BY_KEY
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.logger import setup_logger
@ -25,6 +28,21 @@ SLIM_FILE_FIELDS = (
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
def _get_kwargs_and_start(
checkpoint: GoogleDriveCheckpoint,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
key: Callable[
[GoogleDriveCheckpoint], str
] = lambda check: check.curr_completion_key,
) -> tuple[dict, SecondsSinceUnixEpoch | None]:
kwargs = {}
if not is_slim:
start = checkpoint.completion_map.get(key(checkpoint), start)
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
return kwargs, start
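
A small usage sketch with hypothetical values, assuming DriveRetrievalStage and ThreadSafeDict are imported from onyx.connectors.google_drive.models and onyx.utils.threadpool_concurrency: on a resumed run, the completion_map entry for the active key overrides the caller-supplied start, so retrieval picks up from the last modifiedTime recorded for that drive.

checkpoint = GoogleDriveCheckpoint(
    prev_run_doc_ids=[],
    retrieved_ids=[],
    completion_stage=DriveRetrievalStage.SHARED_DRIVE_FILES,
    curr_completion_key="some-drive-id",
    completion_map=ThreadSafeDict({"some-drive-id": 1_700_000_000.0}),
    has_more=True,
)
kwargs, start = _get_kwargs_and_start(checkpoint, is_slim=False, start=0)
# start is now 1_700_000_000.0 and kwargs orders results by modifiedTime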
def _generate_time_range_filter(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
@ -65,11 +83,13 @@ def _get_folders_in_parent(
def _get_files_in_parent(
service: Resource,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
parent_id: str,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
is_slim: bool = False,
) -> Iterator[GoogleDriveFileType]:
kwargs, start = _get_kwargs_and_start(checkpoint, is_slim, start)
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
@ -83,11 +103,14 @@ def _get_files_in_parent(
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
**kwargs,
):
yield file
def crawl_folders_for_files(
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
service: Resource,
parent_id: str,
traversed_parent_ids: set[str],
@ -98,22 +121,25 @@ def crawl_folders_for_files(
"""
This function starts crawling from any folder. It is slower though.
"""
if parent_id in traversed_parent_ids:
logger.info(f"Skipping subfolder since already traversed: {parent_id}")
return
logger.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
if parent_id not in traversed_parent_ids:
logger.info("Parent id not in traversed parent ids, getting files")
found_files = False
for file in _get_files_in_parent(
service=service,
is_slim=is_slim,
checkpoint=checkpoint,
start=start,
end=end,
parent_id=parent_id,
):
found_files = True
yield file
found_files = False
for file in _get_files_in_parent(
service=service,
start=start,
end=end,
parent_id=parent_id,
):
found_files = True
yield file
if found_files:
update_traversed_ids_func(parent_id)
if found_files:
update_traversed_ids_func(parent_id)
else:
logger.info(f"Skipping subfolder files since already traversed: {parent_id}")
for subfolder in _get_folders_in_parent(
service=service,
@ -121,6 +147,8 @@ def crawl_folders_for_files(
):
logger.info("Fetching all files in subfolder: " + subfolder["name"])
yield from crawl_folders_for_files(
is_slim=is_slim,
checkpoint=checkpoint,
service=service,
parent_id=subfolder["id"],
traversed_parent_ids=traversed_parent_ids,
@ -133,11 +161,17 @@ def crawl_folders_for_files(
def get_files_in_shared_drive(
service: Resource,
drive_id: str,
is_slim: bool = False,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
key: Callable[
[GoogleDriveCheckpoint], str
] = lambda check: check.curr_completion_key,
) -> Iterator[GoogleDriveFileType]:
kwargs, start = _get_kwargs_and_start(checkpoint, is_slim, start, key)
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
@ -173,16 +207,22 @@ def get_files_in_shared_drive(
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
**kwargs,
)
def get_all_files_in_my_drive(
service: Any,
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool = False,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
key: Callable[
[GoogleDriveCheckpoint], str
] = lambda check: check.curr_completion_key,
) -> Iterator[GoogleDriveFileType]:
kwargs, start = _get_kwargs_and_start(checkpoint, is_slim, start, key)
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
@ -196,7 +236,7 @@ def get_all_files_in_my_drive(
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=folder_query,
):
update_traversed_ids_func(file["id"])
update_traversed_ids_func(file[GoogleFields.ID])
found_folders = True
if found_folders:
update_traversed_ids_func(get_root_folder_id(service))
@ -212,19 +252,23 @@ def get_all_files_in_my_drive(
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
**kwargs,
)
def get_all_files_for_oauth(
service: Any,
service: GoogleDriveService,
include_files_shared_with_me: bool,
include_my_drives: bool,
# One of the above 2 should be true
include_shared_drives: bool,
is_slim: bool = False,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
kwargs, start = _get_kwargs_and_start(checkpoint, is_slim, start)
should_get_all = (
include_shared_drives and include_my_drives and include_files_shared_with_me
)
@ -243,11 +287,13 @@ def get_all_files_for_oauth(
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=False,
corpora=corpora,
includeItemsFromAllDrives=should_get_all,
supportsAllDrives=should_get_all,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
**kwargs,
)
@ -255,4 +301,8 @@ def get_all_files_for_oauth(
def get_root_folder_id(service: Resource) -> str:
# we don't paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return service.files().get(fileId="root", fields="id").execute()["id"]
return (
service.files()
.get(fileId="root", fields=GoogleFields.ID)
.execute()[GoogleFields.ID]
)

View File

@ -1,6 +1,10 @@
from enum import Enum
from typing import Any
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.threadpool_concurrency import ThreadSafeDict
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
@ -20,3 +24,75 @@ class GDriveMimeType(str, Enum):
GoogleDriveFileType = dict[str, Any]
TOKEN_EXPIRATION_TIME = 3600 # 1 hour
# These correspond to the major stages of retrieval for Google Drive.
# The stages for the oauth flow are:
# get_all_files_for_oauth(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# crawl_folders_for_files()
#
# The stages for the service account flow are roughly:
# get_all_user_emails(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# Then for each user:
# get_files_in_my_drive()
# get_files_in_shared_drive()
# crawl_folders_for_files()
class DriveRetrievalStage(str, Enum):
START = "start"
DONE = "done"
# OAuth specific stages
OAUTH_FILES = "oauth_files"
# Service account specific stages
USER_EMAILS = "user_emails"
MY_DRIVE_FILES = "my_drive_files"
# Used for both oauth and service account flows
DRIVE_IDS = "drive_ids"
SHARED_DRIVE_FILES = "shared_drive_files"
FOLDER_FILES = "folder_files"
class GoogleDriveCheckpoint(ConnectorCheckpoint):
# The doc ids that were completed in the previous run
prev_run_doc_ids: list[str]
# Checkpoint version of _retrieved_ids
retrieved_ids: list[str]
# Describes the point in the retrieval+indexing process that the
# checkpoint is at. When this is set to a given stage, the connector
# will have already yielded at least 1 file or error from that stage.
# The DONE stage is used to signal that has_more should become False.
completion_stage: DriveRetrievalStage
# The key into completion_map that is currently being processed.
# For stages that directly make a big (paginated) api call, this
# will be the stage itself. For stages with multiple sub-stages,
# this will be the id of the sub-stage. For example, when processing
# shared drives, it will be the id of the shared drive.
curr_completion_key: str
# The latest timestamp of a file that has been retrieved per completion key.
# See curr_completion_key for more details on completion keys.
completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch] = ThreadSafeDict()
# cached version of the drive and folder ids to retrieve
drive_ids_to_retrieve: list[str] | None = None
folder_ids_to_retrieve: list[str] | None = None
# cached user emails
user_emails: list[str] | None = None
# @field_serializer("completion_map")
# def serialize_completion_map(
# self, completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch], _info: Any
# ) -> dict[str, SecondsSinceUnixEpoch]:
# return completion_map._dict
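
An illustrative ordering, derived from _manage_oauth_retrieval in the connector rather than defined anywhere in the code, of how completion_stage advances in the OAuth flow (when all three include_* flags are set, the flow jumps from OAUTH_FILES straight to DONE):

OAUTH_STAGE_ORDER = [
    DriveRetrievalStage.START,
    DriveRetrievalStage.OAUTH_FILES,         # get_all_files_for_oauth
    DriveRetrievalStage.DRIVE_IDS,           # drive/folder ids cached on the checkpoint
    DriveRetrievalStage.SHARED_DRIVE_FILES,  # get_files_in_shared_drive per drive id
    DriveRetrievalStage.FOLDER_FILES,        # crawl_folders_for_files
    DriveRetrievalStage.DONE,                # flips has_more to False
]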

View File

@ -4,6 +4,7 @@ from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from enum import Enum
from typing import Any
from googleapiclient.errors import HttpError # type: ignore
@ -20,6 +21,20 @@ logger = setup_logger()
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
PAGE_TOKEN_KEY = "pageToken"
ORDER_BY_KEY = "orderBy"
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
class GoogleFields(str, Enum):
ID = "id"
CREATED_TIME = "createdTime"
MODIFIED_TIME = "modifiedTime"
NAME = "name"
SIZE = "size"
PARENTS = "parents"
def _execute_with_retry(request: Any) -> Any:
max_attempts = 10
@ -90,11 +105,11 @@ def execute_paginated_retrieval(
retrieval_function: The specific list function to call (e.g., service.files().list)
**kwargs: Arguments to pass to the list function
"""
next_page_token = ""
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
while next_page_token is not None:
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs["pageToken"] = next_page_token
request_kwargs[PAGE_TOKEN_KEY] = next_page_token
try:
results = retrieval_function(**request_kwargs).execute()
@ -117,7 +132,7 @@ def execute_paginated_retrieval(
logger.exception("Error executing request:")
raise e
next_page_token = results.get("nextPageToken")
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
if list_key:
for item in results.get(list_key, []):
yield item
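
A hedged usage sketch with a hypothetical query and saved token: because execute_paginated_retrieval now seeds next_page_token from kwargs, a caller can resume pagination by passing a previously persisted pageToken.

def resume_file_listing(
    drive_service: Any, saved_page_token: str
) -> Iterator[dict[str, Any]]:
    yield from execute_paginated_retrieval(
        retrieval_function=drive_service.files().list,
        list_key="files",
        corpora="user",
        fields="nextPageToken, files(id, name, modifiedTime)",
        q="trashed = false",
        pageToken=saved_page_token,  # read via PAGE_TOKEN_KEY before the first request
    )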

View File

@ -4,9 +4,11 @@ from collections.abc import Iterator
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeAlias
from typing import TypeVar
from pydantic import BaseModel
from typing_extensions import override
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
@ -19,7 +21,6 @@ SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Iterator[list[Document]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]
class BaseConnector(abc.ABC):
@ -57,6 +58,9 @@ class BaseConnector(abc.ABC):
Default is a no-op (always successful).
"""
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
return ConnectorCheckpoint(has_more=True)
# Large set update or reindex, generally pulling a complete state or from a savestate file
class LoadConnector(BaseConnector):
@ -74,6 +78,8 @@ class PollConnector(BaseConnector):
raise NotImplementedError
# Slim connectors can retrieve just the ids and
# permission syncing information for connected documents
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(
@ -186,14 +192,21 @@ class EventConnector(BaseConnector):
raise NotImplementedError
class CheckpointConnector(BaseConnector):
CT = TypeVar("CT", bound=ConnectorCheckpoint)
# TODO: find a reasonable way to parameterize the return type of the generator
CheckpointOutput: TypeAlias = Generator[
Document | ConnectorFailure, None, ConnectorCheckpoint
]
class CheckpointConnector(BaseConnector, Generic[CT]):
@abc.abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
checkpoint: CT,
) -> Generator[Document | ConnectorFailure, None, CT]:
"""Yields back documents or failures. Final return is the new checkpoint.
Final return can be accessed via either:
@ -214,3 +227,10 @@ class CheckpointConnector(BaseConnector):
```
"""
raise NotImplementedError
# Ideally return type should be CT, but that's not possible if
# we want to override build_dummy_checkpoint and have BaseConnector
# return a base ConnectorCheckpoint
@override
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
raise NotImplementedError
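
A sketch of one way to drive a CheckpointConnector and capture the generator's final return value; this is illustrative and not necessarily how the indexing pipeline consumes connectors.

def run_until_exhausted(
    connector: CheckpointConnector[CT],
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> tuple[list[Document | ConnectorFailure], ConnectorCheckpoint]:
    checkpoint = connector.build_dummy_checkpoint()
    outputs: list[Document | ConnectorFailure] = []
    while checkpoint.has_more:
        gen = connector.load_from_checkpoint(start, end, checkpoint)
        try:
            while True:
                outputs.append(next(gen))
        except StopIteration as e:
            # a generator's return value travels on StopIteration.value
            checkpoint = e.value
    return outputs, checkpoint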

View File

@ -1,4 +1,3 @@
import json
from datetime import datetime
from enum import Enum
from typing import Any
@ -232,21 +231,16 @@ class IndexAttemptMetadata(BaseModel):
class ConnectorCheckpoint(BaseModel):
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
checkpoint_content: dict
has_more: bool
@classmethod
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000
content_str = json.dumps(self.checkpoint_content)
content_str = self.model_dump_json()
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
return content_str
class DocumentFailure(BaseModel):

View File

@ -10,10 +10,11 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypedDict
from pydantic import BaseModel
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from typing_extensions import override
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@ -23,7 +24,6 @@ from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
@ -56,8 +56,8 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType]
class SlackCheckpointContent(TypedDict):
channel_ids: list[str]
class SlackCheckpoint(ConnectorCheckpoint):
channel_ids: list[str] | None
channel_completion_map: dict[str, str]
current_channel: ChannelType | None
seen_thread_ts: list[str]
@ -435,6 +435,12 @@ def _get_all_doc_ids(
yield channel_metadata_list
class ProcessedSlackMessage(BaseModel):
doc: Document | None
thread_ts: str | None
failure: ConnectorFailure | None
def _process_message(
message: MessageType,
client: WebClient,
@ -443,7 +449,7 @@ def _process_message(
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
) -> ProcessedSlackMessage:
thread_ts = message.get("thread_ts")
try:
# causes random failures for testing checkpointing / continue on failure
@ -460,13 +466,13 @@ def _process_message(
seen_thread_ts=seen_thread_ts,
msg_filter_func=msg_filter_func,
)
return (doc, thread_ts, None)
return ProcessedSlackMessage(doc=doc, thread_ts=thread_ts, failure=None)
except Exception as e:
logger.exception(f"Error processing message {message['ts']}")
return (
None,
thread_ts,
ConnectorFailure(
return ProcessedSlackMessage(
doc=None,
thread_ts=thread_ts,
failure=ConnectorFailure(
failed_document=DocumentFailure(
document_id=_build_doc_id(
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
@ -479,7 +485,7 @@ def _process_message(
)
class SlackConnector(SlimConnector, CheckpointConnector):
class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
MAX_WORKERS = 2
def __init__(
@ -525,8 +531,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
checkpoint: SlackCheckpoint,
) -> Generator[Document | ConnectorFailure, None, SlackCheckpoint]:
"""Rough outline:
Step 1: Get all channels, yield back Checkpoint.
@ -542,49 +548,36 @@ class SlackConnector(SlimConnector, CheckpointConnector):
if self.client is None or self.text_cleaner is None:
raise ConnectorMissingCredentialError("Slack")
checkpoint_content = cast(
SlackCheckpointContent,
(
copy.deepcopy(checkpoint.checkpoint_content)
or {
"channel_ids": None,
"channel_completion_map": {},
"current_channel": None,
"seen_thread_ts": [],
}
),
)
checkpoint = cast(SlackCheckpoint, copy.deepcopy(checkpoint))
# if this is the very first time we've called this, need to
# get all relevant channels and save them into the checkpoint
if checkpoint_content["channel_ids"] is None:
if checkpoint.channel_ids is None:
raw_channels = get_channels(self.client)
filtered_channels = filter_channels(
raw_channels, self.channels, self.channel_regex_enabled
)
checkpoint.channel_ids = [c["id"] for c in filtered_channels]
if len(filtered_channels) == 0:
checkpoint.has_more = False
return checkpoint
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
checkpoint_content["current_channel"] = filtered_channels[0]
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=True,
)
checkpoint.current_channel = filtered_channels[0]
checkpoint.has_more = True
return checkpoint
final_channel_ids = checkpoint_content["channel_ids"]
channel = checkpoint_content["current_channel"]
final_channel_ids = checkpoint.channel_ids
channel = checkpoint.current_channel
if channel is None:
raise ValueError("current_channel key not found in checkpoint")
raise ValueError("current_channel key not set in checkpoint")
channel_id = channel["id"]
if channel_id not in final_channel_ids:
raise ValueError(f"Channel {channel_id} not found in checkpoint")
oldest = str(start) if start else None
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
latest = checkpoint.channel_completion_map.get(channel_id, str(end))
seen_thread_ts = set(checkpoint.seen_thread_ts)
try:
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
@ -596,7 +589,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
# Process messages in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=SlackConnector.MAX_WORKERS) as executor:
futures: list[Future] = []
futures: list[Future[ProcessedSlackMessage]] = []
for message in message_batch:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
@ -614,7 +607,10 @@ class SlackConnector(SlimConnector, CheckpointConnector):
)
for future in as_completed(futures):
doc, thread_ts, failures = future.result()
processed_slack_message = future.result()
doc = processed_slack_message.doc
thread_ts = processed_slack_message.thread_ts
failure = processed_slack_message.failure
if doc:
# handle race conditions here since this is single
# threaded. Multi-threaded _process_message reads from this
@ -624,36 +620,31 @@ class SlackConnector(SlimConnector, CheckpointConnector):
if thread_ts not in seen_thread_ts:
yield doc
if thread_ts:
seen_thread_ts.add(thread_ts)
elif failures:
for failure in failures:
yield failure
assert thread_ts, "found non-None doc with None thread_ts"
seen_thread_ts.add(thread_ts)
elif failure:
yield failure
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
checkpoint.seen_thread_ts = list(seen_thread_ts)
checkpoint.channel_completion_map[channel["id"]] = new_latest
if has_more_in_channel:
checkpoint_content["current_channel"] = channel
checkpoint.current_channel = channel
else:
new_channel_id = next(
(
channel_id
for channel_id in final_channel_ids
if channel_id
not in checkpoint_content["channel_completion_map"]
if channel_id not in checkpoint.channel_completion_map
),
None,
)
if new_channel_id:
new_channel = _get_channel_by_id(self.client, new_channel_id)
checkpoint_content["current_channel"] = new_channel
checkpoint.current_channel = new_channel
else:
checkpoint_content["current_channel"] = None
checkpoint.current_channel = None
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=checkpoint_content["current_channel"] is not None,
)
checkpoint.has_more = checkpoint.current_channel is not None
return checkpoint
except Exception as e:
@ -753,6 +744,16 @@ class SlackConnector(SlimConnector, CheckpointConnector):
f"Unexpected error during Slack settings validation: {e}"
)
@override
def build_dummy_checkpoint(self) -> SlackCheckpoint:
return SlackCheckpoint(
channel_ids=None,
channel_completion_map={},
current_channel=None,
seen_thread_ts=[],
has_more=True,
)
if __name__ == "__main__":
import os
@ -767,9 +768,11 @@ if __name__ == "__main__":
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = connector.build_dummy_checkpoint()
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
gen = connector.load_from_checkpoint(
one_day_ago, current, cast(SlackCheckpoint, checkpoint)
)
try:
for document_or_failure in gen:
if isinstance(document_or_failure, Document):

View File

@ -0,0 +1,13 @@
from collections.abc import Callable
from functools import lru_cache
from typing import TypeVar
R = TypeVar("R")
def lazy_eval(func: Callable[[], R]) -> Callable[[], R]:
@lru_cache(maxsize=1)
def lazy_func() -> R:
return func()
return lazy_func
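
A brief self-contained usage sketch: the wrapped callable runs at most once, which is how the Drive connector defers building per-user services until a file actually needs them.

import time

def _expensive_setup() -> float:
    # stand-in for constructing an API client
    time.sleep(0.1)
    return time.time()

get_client = lazy_eval(_expensive_setup)
# nothing has run yet; the first call does the work, later calls reuse the cached result
assert get_client() == get_client()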

View File

@ -1,18 +1,136 @@
import collections.abc
import contextvars
import copy
import threading
import uuid
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import MutableMapping
from concurrent.futures import as_completed
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
from typing import Any
from typing import Generic
from typing import TypeVar
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from onyx.utils.logger import setup_logger
logger = setup_logger()
R = TypeVar("R")
KT = TypeVar("KT") # Key type
VT = TypeVar("VT") # Value type
class ThreadSafeDict(MutableMapping[KT, VT]):
"""
A thread-safe dictionary implementation that uses a lock to ensure thread safety.
Implements the MutableMapping interface to provide a complete dictionary-like interface.
Example usage:
# Create a thread-safe dictionary
safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
# Basic operations
safe_dict["key"] = 1
value = safe_dict["key"]
del safe_dict["key"]
# Bulk operations (atomic)
safe_dict.update({"key1": 1, "key2": 2})
# Safe iteration
with safe_dict.lock:
for key, value in safe_dict.items():
# This block is atomic
pass
"""
def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
self._dict: dict[KT, VT] = input_dict or {}
self.lock = threading.Lock()
def __getitem__(self, key: KT) -> VT:
with self.lock:
return self._dict[key]
def __setitem__(self, key: KT, value: VT) -> None:
with self.lock:
self._dict[key] = value
def __delitem__(self, key: KT) -> None:
with self.lock:
del self._dict[key]
def __iter__(self) -> Iterator[KT]:
# Return a snapshot of keys to avoid potential modification during iteration
with self.lock:
return iter(list(self._dict.keys()))
def __len__(self) -> int:
with self.lock:
return len(self._dict)
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return handler(dict[KT, VT])
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
return ThreadSafeDict(copy.deepcopy(self._dict))
def clear(self) -> None:
"""Remove all items from the dictionary atomically."""
with self.lock:
self._dict.clear()
def copy(self) -> dict[KT, VT]:
"""Return a shallow copy of the dictionary atomically."""
with self.lock:
return self._dict.copy()
def get(self, key: KT, default: Any = None) -> Any:
"""Get a value with a default, atomically."""
with self.lock:
return self._dict.get(key, default)
def pop(self, key: KT, default: Any = None) -> Any:
"""Remove and return a value with optional default, atomically."""
with self.lock:
if default is None:
return self._dict.pop(key)
return self._dict.pop(key, default)
def setdefault(self, key: KT, default: VT) -> VT:
"""Set a default value if key is missing, atomically."""
with self.lock:
return self._dict.setdefault(key, default)
def update(self, *args: Any, **kwargs: VT) -> None:
"""Update the dictionary atomically from another mapping or from kwargs."""
with self.lock:
self._dict.update(*args, **kwargs)
def items(self) -> collections.abc.ItemsView[KT, VT]:
"""Return a view of (key, value) pairs atomically."""
with self.lock:
return collections.abc.ItemsView(self)
def keys(self) -> collections.abc.KeysView[KT]:
"""Return a view of keys atomically."""
with self.lock:
return collections.abc.KeysView(self)
def values(self) -> collections.abc.ValuesView[VT]:
"""Return a view of values atomically."""
with self.lock:
return collections.abc.ValuesView(self)
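# Usage sketch (hypothetical example, not part of this change): sharing one
# ThreadSafeDict across worker threads. setdefault is atomic, so each bucket is
# created exactly once; appending to the shared list is safe under CPython's GIL.
def _thread_safe_dict_demo() -> dict[str, list[int]]:
    buckets: ThreadSafeDict[str, list[int]] = ThreadSafeDict()

    def record(n: int) -> None:
        buckets.setdefault(str(n % 2), []).append(n)

    with ThreadPoolExecutor(max_workers=4) as pool:
        list(pool.map(record, range(8)))

    return buckets.copy()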
def run_functions_tuples_in_parallel(
@@ -190,3 +308,30 @@ def wait_on_background(task: TimeoutThread[R]) -> R:
raise task.exception
return task.result
def _next_or_none(ind: int, g: Iterator[R]) -> tuple[int, R | None]:
    """Pull one item from g, returning (ind, item), or (ind, None) once g is exhausted."""
    return ind, next(g, None)


def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
    """Interleave items from multiple iterators, yielding each item as soon as a
    worker thread produces it. None is used internally to signal exhaustion, so
    the input iterators must not yield None themselves."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # In-flight "fetch next item" futures. Only the keys matter; the mapped
        # index is bookkeeping, since each future's result carries the generator index.
        future_to_index: dict[Future[tuple[int, R | None]], int] = {
            executor.submit(_next_or_none, i, g): i for i, g in enumerate(gens)
        }

        next_ind = len(gens)
        while future_to_index:
            done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
            for future in done:
                ind, result = future.result()
                if result is not None:
                    yield result
                    # Schedule the next pull from the same generator.
                    del future_to_index[future]
                    future_to_index[
                        executor.submit(_next_or_none, ind, gens[ind])
                    ] = next_ind
                    next_ind += 1
                else:
                    # Generator exhausted; stop polling it.
                    del future_to_index[future]
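# Usage sketch (hypothetical example, not part of this change): interleave a few
# slow generators; items arrive in completion order rather than generator order.
def _slow_range(start: int, delay: float) -> Iterator[int]:
    import time

    for i in range(start, start + 3):
        time.sleep(delay)
        yield i


def _parallel_yield_demo() -> None:
    for item in parallel_yield([_slow_range(0, 0.2), _slow_range(10, 0.1)]):
        logger.info(item)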

View File

@@ -1,5 +1,7 @@
import time
from collections.abc import Sequence
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
@@ -202,3 +204,13 @@ def assert_retrieved_docs_match_expected(
retrieved=valid_retrieved_texts,
)
assert expected_file_texts == valid_retrieved_texts
def load_all_docs(connector: GoogleDriveConnector) -> list[Document]:
retrieved_docs: list[Document] = []
for doc in connector.load_from_checkpoint(
0, time.time(), connector.build_dummy_checkpoint()
):
assert isinstance(doc, Document), f"Should not fail with {type(doc)} {doc}"
retrieved_docs.append(doc)
return retrieved_docs

View File

@@ -1,10 +1,8 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
@@ -23,6 +21,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
@@ -47,9 +46,7 @@ def test_include_all(
my_drive_emails=None,
shared_drive_urls=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should get everything in shared and admin's My Drive with oauth
expected_file_ids = (
@@ -89,9 +86,7 @@ def test_include_shared_drives_only(
my_drive_emails=None,
shared_drive_urls=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should only get shared drives
expected_file_ids = (
@@ -129,9 +124,7 @@ def test_include_my_drives_only(
my_drive_emails=None,
shared_drive_urls=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should only get primary_admins My Drive because we are impersonating them
expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
@@ -160,9 +153,7 @@ def test_drive_one_only(
my_drive_emails=None,
shared_drive_urls=",".join([str(url) for url in drive_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
SHARED_DRIVE_1_FILE_IDS
@@ -196,9 +187,7 @@ def test_folder_and_shared_drive(
my_drive_emails=None,
shared_drive_urls=",".join([str(url) for url in drive_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
SHARED_DRIVE_1_FILE_IDS
@@ -243,9 +232,7 @@ def test_folders_only(
my_drive_emails=None,
shared_drive_urls=",".join([str(url) for url in shared_drive_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
FOLDER_1_1_FILE_IDS
@@ -281,9 +268,7 @@ def test_personal_folders_only(
my_drive_emails=None,
shared_drive_urls=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
assert_retrieved_docs_match_expected(

View File

@@ -1,11 +1,10 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_URL
@@ -37,9 +36,7 @@ def test_google_drive_sections(
my_drive_emails=None,
)
for connector in [oauth_connector, service_acct_connector]:
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Verify we got the 1 doc with sections
assert len(retrieved_docs) == 1

View File

@@ -1,10 +1,8 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
@@ -23,6 +21,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
@@ -52,9 +51,7 @@ def test_include_all(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should get everything
expected_file_ids = (
@@ -97,9 +94,7 @@ def test_include_shared_drives_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should only get shared drives
expected_file_ids = (
@@ -137,9 +132,7 @@ def test_include_my_drives_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should only get everyone's My Drives
expected_file_ids = (
@@ -174,9 +167,7 @@ def test_drive_one_only(
shared_drive_urls=",".join([str(url) for url in urls]),
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# We ignore shared_drive_urls if include_shared_drives is False
expected_file_ids = (
@@ -211,9 +202,7 @@ def test_folder_and_shared_drive(
shared_folder_urls=",".join([str(url) for url in folder_urls]),
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should get everything except for the top level files in drive 2
expected_file_ids = (
@@ -259,9 +248,7 @@ def test_folders_only(
shared_folder_urls=",".join([str(url) for url in folder_urls]),
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
FOLDER_1_1_FILE_IDS
@@ -298,9 +285,7 @@ def test_specific_emails(
shared_drive_urls=None,
my_drive_emails=",".join([str(email) for email in my_drive_emails]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS
assert_retrieved_docs_match_expected(
@@ -330,9 +315,7 @@ def get_specific_folders_in_my_drive(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
assert_retrieved_docs_match_expected(

View File

@@ -1,10 +1,8 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_retrieved_docs_match_expected,
@@ -14,6 +12,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS
@@ -37,9 +36,7 @@ def test_all(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
# These are the files from my drive
@@ -77,9 +74,7 @@ def test_shared_drives_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
# These are the files from shared drives
@@ -112,9 +107,7 @@ def test_shared_with_me_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
# These are the files shared with me from admin
@@ -145,9 +138,7 @@ def test_my_drive_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# These are the files from my drive
expected_file_ids = TEST_USER_1_FILE_IDS
@@ -175,9 +166,7 @@ def test_shared_my_drive_folder(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
# this is a folder from admin's drive that is shared with me
@@ -207,9 +196,7 @@ def test_shared_drive_folder(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS
assert_retrieved_docs_match_expected(

View File

@@ -54,9 +54,9 @@ def test_mock_connector_basic_flow(
json=[
{
"documents": [test_doc.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],
@@ -128,9 +128,9 @@ def test_mock_connector_with_failures(
json=[
{
"documents": [doc1.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [doc2_failure.model_dump(mode="json")],
}
],
@@ -208,9 +208,9 @@ def test_mock_connector_failure_recovery(
json=[
{
"documents": [doc1.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [
doc2_failure.model_dump(mode="json"),
ConnectorFailure(
@@ -292,9 +292,9 @@ def test_mock_connector_failure_recovery(
doc1.model_dump(mode="json"),
doc2.model_dump(mode="json"),
],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],
@@ -372,24 +372,24 @@ def test_mock_connector_checkpoint_recovery(
json=[
{
"documents": [doc.model_dump(mode="json") for doc in docs_batch_1],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=True
).model_dump(mode="json"),
"checkpoint": ConnectorCheckpoint(has_more=True).model_dump(
mode="json"
),
"failures": [],
},
{
"documents": [doc2.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=True
).model_dump(mode="json"),
"checkpoint": ConnectorCheckpoint(has_more=True).model_dump(
mode="json"
),
"failures": [],
},
{
"documents": [],
# should never hit this, unhandled exception happens first
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
"unhandled_exception": "Simulated unhandled error",
},
@@ -463,9 +463,9 @@ def test_mock_connector_checkpoint_recovery(
json=[
{
"documents": [doc3.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],

View File

@@ -1,10 +1,16 @@
import contextvars
import threading
import time
from collections.abc import Generator
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
import pytest
from onyx.utils.threadpool_concurrency import parallel_yield
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.threadpool_concurrency import ThreadSafeDict
from onyx.utils.threadpool_concurrency import wait_on_background
# Create a context variable for testing
@@ -148,3 +154,242 @@ def test_multiple_background_tasks() -> None:
# Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
assert 0.2 <= elapsed < 0.4 # Allow some buffer for test environment variations
def test_thread_safe_dict_basic_operations() -> None:
"""Test basic operations of ThreadSafeDict"""
d = ThreadSafeDict[str, int]()
# Test setting and getting
d["a"] = 1
assert d["a"] == 1
# Test get with default
assert d.get("a", None) == 1
assert d.get("b", 2) == 2
# Test deletion
del d["a"]
assert "a" not in d
# Test length
d["x"] = 10
d["y"] = 20
assert len(d) == 2
# Test iteration
keys = sorted(d.keys())
assert keys == ["x", "y"]
# Test items and values
assert dict(d.items()) == {"x": 10, "y": 20}
assert sorted(d.values()) == [10, 20]
def test_thread_safe_dict_concurrent_access() -> None:
"""Test ThreadSafeDict with concurrent access from multiple threads"""
d = ThreadSafeDict[str, int]()
num_threads = 10
iterations = 1000
def increment_values() -> None:
for i in range(iterations):
key = str(i % 5) # Use 5 different keys
            # Read the current value (defaulting to 0 for a missing key), increment, then store
            current = d.get(key, 0)
            if current is not None:  # defensive: with a default of 0, get() should never return None
                d[key] = current + 1
            else:
                d[key] = 1
# Create and start threads
threads = []
for _ in range(num_threads):
t = threading.Thread(target=increment_values)
threads.append(t)
t.start()
# Wait for all threads to complete
for t in threads:
t.join()
# Verify results
# Each key should have been incremented (num_threads * iterations) / 5 times
expected_value = (num_threads * iterations) // 5
for i in range(5):
assert d[str(i)] == expected_value
def test_thread_safe_dict_bulk_operations() -> None:
"""Test bulk operations of ThreadSafeDict"""
d = ThreadSafeDict[str, int]()
# Test update with dict
d.update({"a": 1, "b": 2})
assert dict(d.items()) == {"a": 1, "b": 2}
# Test update with kwargs
d.update(c=3, d=4)
assert dict(d.items()) == {"a": 1, "b": 2, "c": 3, "d": 4}
# Test clear
d.clear()
assert len(d) == 0
def test_thread_safe_dict_concurrent_bulk_operations() -> None:
"""Test ThreadSafeDict with concurrent bulk operations"""
d = ThreadSafeDict[str, int]()
num_threads = 5
def bulk_update(start: int) -> None:
# Each thread updates with its own range of numbers
updates = {str(i): i for i in range(start, start + 20)}
d.update(updates)
time.sleep(0.01) # Add some delay to increase chance of thread overlap
# Run updates concurrently
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [executor.submit(bulk_update, i * 20) for i in range(num_threads)]
for future in futures:
future.result()
# Verify results
assert len(d) == num_threads * 20
# Verify all numbers from 0 to (num_threads * 20) are present
for i in range(num_threads * 20):
assert d[str(i)] == i
def test_thread_safe_dict_atomic_operations() -> None:
"""Test atomic operations with ThreadSafeDict's lock"""
d = ThreadSafeDict[str, list[int]]()
d["numbers"] = []
def append_numbers(start: int) -> None:
numbers = d["numbers"]
with d.lock:
for i in range(start, start + 5):
numbers.append(i)
time.sleep(0.001) # Add delay to increase chance of thread overlap
d["numbers"] = numbers
# Run concurrent append operations
threads = []
for i in range(4): # 4 threads, each adding 5 numbers
t = threading.Thread(target=append_numbers, args=(i * 5,))
threads.append(t)
t.start()
for t in threads:
t.join()
# Verify results
numbers = d["numbers"]
assert len(numbers) == 20 # 4 threads * 5 numbers each
assert sorted(numbers) == list(range(20)) # All numbers 0-19 should be present
def test_parallel_yield_basic() -> None:
"""Test that parallel_yield correctly yields values from multiple generators."""
def make_gen(values: list[int], delay: float) -> Generator[int, None, None]:
for v in values:
time.sleep(delay)
yield v
# Create generators with different delays
gen1 = make_gen([1, 4, 7], 0.1) # Slower generator
gen2 = make_gen([2, 5, 8], 0.05) # Faster generator
gen3 = make_gen([3, 6, 9], 0.15) # Slowest generator
# Collect results with timestamps
results: list[tuple[float, int]] = []
start_time = time.time()
for value in parallel_yield([gen1, gen2, gen3]):
results.append((time.time() - start_time, value))
# Verify all values were yielded
assert sorted(v for _, v in results) == list(range(1, 10))
# Verify that faster generators yielded earlier
# Group results by generator (values 1,4,7 are gen1, 2,5,8 are gen2, 3,6,9 are gen3)
gen1_times = [t for t, v in results if v in (1, 4, 7)]
gen2_times = [t for t, v in results if v in (2, 5, 8)]
gen3_times = [t for t, v in results if v in (3, 6, 9)]
# Average times for each generator
avg_gen1 = sum(gen1_times) / len(gen1_times)
avg_gen2 = sum(gen2_times) / len(gen2_times)
avg_gen3 = sum(gen3_times) / len(gen3_times)
# Verify gen2 (fastest) has lowest average time
assert avg_gen2 < avg_gen1
assert avg_gen2 < avg_gen3
def test_parallel_yield_empty_generators() -> None:
"""Test parallel_yield with empty generators."""
def empty_gen() -> Iterator[int]:
if False:
yield 0 # Makes this a generator function
gens = [empty_gen() for _ in range(3)]
results = list(parallel_yield(gens))
assert len(results) == 0
def test_parallel_yield_different_lengths() -> None:
"""Test parallel_yield with generators of different lengths."""
def make_gen(count: int) -> Iterator[int]:
for i in range(count):
yield i
time.sleep(0.01) # Small delay to ensure concurrent execution
gens = [
make_gen(1), # Yields: [0]
make_gen(3), # Yields: [0, 1, 2]
make_gen(2), # Yields: [0, 1]
]
results = list(parallel_yield(gens))
assert len(results) == 6 # Total number of items from all generators
assert sorted(results) == [0, 0, 0, 1, 1, 2]
def test_parallel_yield_exception_handling() -> None:
"""Test parallel_yield handles exceptions in generators properly."""
def failing_gen() -> Iterator[int]:
yield 1
raise ValueError("Generator failure")
def normal_gen() -> Iterator[int]:
yield 2
yield 3
gens = [failing_gen(), normal_gen()]
with pytest.raises(ValueError, match="Generator failure"):
list(parallel_yield(gens))
def test_parallel_yield_non_blocking() -> None:
"""Test parallel_yield with non-blocking generators (simple ranges)."""
def range_gen(start: int, end: int) -> Iterator[int]:
for i in range(start, end):
yield i
# Create three overlapping ranges
gens = [range_gen(0, 100), range_gen(100, 200), range_gen(200, 300)]
results = list(parallel_yield(gens))
# Verify no values are missing
assert len(results) == 300 # Should have all values from 0 to 299
assert sorted(results) == list(range(300))