Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-12 05:49:36 +02:00)

Commit ad5136941d ("WIP rebased")
Parent 463340b8a1
@@ -1,3 +1,4 @@
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from io import BytesIO
@@ -71,6 +72,7 @@ def get_latest_valid_checkpoint(
search_settings_id: int,
window_start: datetime,
window_end: datetime,
build_dummy_checkpoint: Callable[[], ConnectorCheckpoint],
) -> ConnectorCheckpoint:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
@@ -105,7 +107,7 @@ def get_latest_valid_checkpoint(
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return ConnectorCheckpoint.build_dummy_checkpoint()
return build_dummy_checkpoint()

# assumes latest checkpoint is the furthest along. This only isn't true
# if something else has gone wrong.
@@ -113,7 +115,7 @@ def get_latest_valid_checkpoint(
checkpoint_candidates[0] if checkpoint_candidates else None
)

checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = build_dummy_checkpoint()
if latest_valid_checkpoint_candidate:
try:
previous_checkpoint = load_checkpoint(
@@ -193,7 +195,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:

def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
"""Check if the checkpoint content size exceeds the limit (200MB)"""
content_size = deep_getsizeof(checkpoint.checkpoint_content)
content_size = deep_getsizeof(checkpoint.model_dump())
if content_size > 200_000_000:  # 200MB in bytes
raise ValueError(
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"
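# Illustrative sketch (not part of this commit): with checkpoints now being
# arbitrary pydantic models, the size guard measures the full model dump
# instead of the removed `checkpoint_content` dict. Assumes `deep_getsizeof`
# and a pydantic checkpoint subclass as in the diff above.
def example_size_guard(checkpoint: ConnectorCheckpoint) -> None:
    # model_dump() converts the whole checkpoint (any subclass) to plain
    # Python objects so its size can be estimated recursively
    if deep_getsizeof(checkpoint.model_dump()) > 200_000_000:
        raise ValueError("checkpoint exceeds 200MB limit")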
|
||||
|
@@ -24,7 +24,6 @@ from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
@@ -405,7 +404,7 @@ def _run_indexing(
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling.
if index_attempt.from_beginning:
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
checkpoint = get_latest_valid_checkpoint(
db_session=db_session_temp,
@@ -413,6 +412,7 @@ def _run_indexing(
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
build_dummy_checkpoint=connector_runner.connector.build_dummy_checkpoint,
)

unresolved_errors = get_index_attempt_errors_for_cc_pair(
|
||||
|
@@ -132,7 +132,7 @@ class ConnectorRunner:
)

else:
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
finished_checkpoint = self.connector.build_dummy_checkpoint()
finished_checkpoint.has_more = False

if isinstance(self.connector, PollConnector):
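# Illustrative sketch (assumption, not from this commit): the runner now asks
# the connector instance for its own starting/finished checkpoint, so
# checkpointed connectors can return a richer subclass while plain connectors
# fall back to the BaseConnector default shown later in this diff.
class ExampleRunner:  # hypothetical, for illustration only
    def __init__(self, connector: BaseConnector) -> None:
        self.connector = connector

    def finish(self) -> ConnectorCheckpoint:
        finished = self.connector.build_dummy_checkpoint()
        finished.has_more = False  # mirrors the change above
        return finished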
|
||||
|
@@ -1,13 +1,17 @@
import copy
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from typing import Any

from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
from googleapiclient.errors import HttpError  # type: ignore
from typing_extensions import override

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -23,12 +27,16 @@ from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import DriveRetrievalStage
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.google_auth import get_google_creds
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.google_utils import GoogleFields
from onyx.connectors.google_utils.resources import get_admin_service
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.google_utils.resources import get_google_docs_service
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
@@ -36,21 +44,27 @@ from onyx.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_S
from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from onyx.connectors.google_utils.shared_constants import USER_FIELDS
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import EntityFailure
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
from onyx.utils.threadpool_concurrency import parallel_yield
from onyx.utils.threadpool_concurrency import ThreadSafeDict

logger = setup_logger()
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
# All file retrievals could be batched and made at once

BATCHES_PER_CHECKPOINT = 10


def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
if not string:
|
||||
@@ -66,10 +80,16 @@ def _convert_single_file(
creds: Any,
primary_admin_email: str,
file: dict[str, Any],
) -> Any:
) -> Document | ConnectorFailure | None:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
docs_service = get_google_docs_service(creds, user_email=user_email)

# Only construct these services when needed
user_drive_service = lazy_eval(
lambda: get_drive_service(creds, user_email=user_email)
)
docs_service = lazy_eval(
lambda: get_google_docs_service(creds, user_email=user_email)
)
return convert_drive_item_to_document(
file=file,
drive_service=user_drive_service,
@@ -77,23 +97,6 @@
)


def _process_files_batch(
files: list[GoogleDriveFileType],
convert_func: Callable[[GoogleDriveFileType], Any],
batch_size: int,
) -> GenerateDocumentsOutput:
doc_batch = []
with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor:
for doc in executor.map(convert_func, files):
if doc:
doc_batch.append(doc)
if len(doc_batch) >= batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch


def _clean_requested_drive_ids(
requested_drive_ids: set[str],
requested_folder_ids: set[str],
@@ -112,7 +115,7 @@ def _clean_requested_drive_ids(
return valid_requested_drive_ids, filtered_folder_ids


class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpoint]):
def __init__(
self,
include_shared_drives: bool = False,
@@ -145,13 +148,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if continue_on_failure is not None:
logger.warning("The 'continue_on_failure' parameter is deprecated.")

if (
not include_shared_drives
and not include_my_drives
and not include_files_shared_with_me
and not shared_folder_urls
and not my_drive_emails
and not shared_drive_urls
if not any(
(
include_shared_drives,
include_my_drives,
include_files_shared_with_me,
shared_folder_urls,
my_drive_emails,
shared_drive_urls,
)
):
raise ConnectorValidationError(
"Nothing to index. Please specify at least one of the following: "
|
||||
@@ -221,15 +226,12 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
)
return self._creds

# TODO: ensure returned new_creds_dict is actually persisted when this is called?
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
try:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
except KeyError:
raise ValueError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
raise ValueError("Credentials json missing primary admin key")

self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
@@ -238,6 +240,25 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):

return new_creds_dict

def _checkpoint_yield(
self,
drive_files: Iterator[GoogleDriveFileType],
checkpoint: GoogleDriveCheckpoint,
key: Callable[
[GoogleDriveCheckpoint], str
] = lambda check: check.curr_completion_key,
) -> Iterator[GoogleDriveFileType]:
"""
Wraps a file iterator with a checkpoint to record all the files that have been retrieved.
The key function is used to extract a unique key from the checkpoint to record the completion time,
defaults to the "curr completion key" which works when set before synchronous workflows.
"""
for drive_file in drive_files:
checkpoint.completion_map[key(checkpoint)] = datetime.fromisoformat(
drive_file[GoogleFields.MODIFIED_TIME.value]
).timestamp()
yield drive_file

def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._retrieved_ids.add(folder_id)
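# Illustrative sketch (not part of this commit): _checkpoint_yield records, per
# completion key, the modifiedTime of the newest file seen so far, so a later
# run can resume that key's retrieval from roughly where this one stopped.
# The inputs below are hypothetical.
files = iter([{"modifiedTime": "2024-01-01T00:00:00+00:00"}])  # fake Drive file
checkpoint = GoogleDriveCheckpoint(
    prev_run_doc_ids=[],
    retrieved_ids=[],
    completion_stage=DriveRetrievalStage.START,
    curr_completion_key="admin@example.com@my_drive",  # hypothetical key
    completion_map=ThreadSafeDict(),
    has_more=True,
)
for f in connector._checkpoint_yield(files, checkpoint):  # assumes a connector instance
    pass  # completion_map["admin@example.com@my_drive"] now holds the file's timestamp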
|
||||
|
||||
@@ -286,7 +307,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):

if not all_drive_ids:
logger.warning(
"No drives found even though we are indexing shared drives was requested."
"No drives found even though indexing shared drives was requested."
)

return all_drive_ids
@@ -295,6 +316,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
self,
user_email: str,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
filtered_drive_ids: set[str],
filtered_folder_ids: set[str],
start: SecondsSinceUnixEpoch | None = None,
@@ -330,8 +352,11 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
checkpoint=checkpoint,
start=start,
end=end,
key=lambda check: user_email
+ "@my_drive", # completion map keyed by user email
)

remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
@@ -341,15 +366,24 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
service=drive_service,
drive_id=drive_id,
is_slim=is_slim,
checkpoint=checkpoint,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
key=lambda check: user_email
+ "@"
+ drive_id, # completion map keyed by drive id
)

# I believe there may be some duplication here,
# i.e. if two users have access to the same folder
# and are retrieving in parallel.
remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
yield from crawl_folders_for_files(
is_slim=is_slim,
checkpoint=checkpoint,
service=drive_service,
parent_id=folder_id,
traversed_parent_ids=self._retrieved_ids,
@@ -361,25 +395,27 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def _manage_service_account_retrieval(
self,
is_slim: bool,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
all_org_emails: list[str] = self._get_all_user_emails()
if checkpoint.completion_stage == DriveRetrievalStage.START:
checkpoint.completion_stage = DriveRetrievalStage.USER_EMAILS

all_drive_ids: set[str] = self.get_all_drive_ids()
if checkpoint.completion_stage == DriveRetrievalStage.USER_EMAILS:
all_org_emails: list[str] = self._get_all_user_emails()
if not is_slim:
checkpoint.user_emails = all_org_emails
checkpoint.completion_stage = DriveRetrievalStage.DRIVE_IDS
else:
assert checkpoint.user_emails is not None, "user emails not set"
all_org_emails = checkpoint.user_emails

drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
drive_ids_to_retrieve, folder_ids_to_retrieve = self._determine_retrieval_ids(
checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
)

# checkpoint - we've found all users and drives, now time to actually start
# we've found all users and drives, now time to actually start
# fetching stuff
logger.info(f"Found {len(all_org_emails)} users to impersonate")
logger.debug(f"Users: {all_org_emails}")
@@ -388,24 +424,19 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
logger.debug(f"Folders: {folder_ids_to_retrieve}")

# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
executor.submit(
self._impersonate_user_for_retrieval,
email,
is_slim,
drive_ids_to_retrieve,
folder_ids_to_retrieve,
start,
end,
): email
for email in all_org_emails
}

# Yield results as they complete
for future in as_completed(future_to_email):
yield from future.result()
user_retrieval_gens = [
self._impersonate_user_for_retrieval(
email,
is_slim,
checkpoint,
drive_ids_to_retrieve,
folder_ids_to_retrieve,
start,
end,
)
for email in all_org_emails
]
yield from parallel_yield(user_retrieval_gens, max_workers=10)
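# Illustrative sketch (assumption: parallel_yield drains several generators on
# a thread pool and yields items as each generator produces them, replacing
# the earlier as_completed() pattern that only yielded once a whole user's
# retrieval had finished). Hypothetical standalone usage:
def _gen(tag: str) -> Iterator[str]:
    for i in range(3):
        yield f"{tag}-{i}"

# files from all users are interleaved as they become available
for item in parallel_yield([_gen("user_a"), _gen("user_b")], max_workers=10):
    print(item)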
|
||||
|
||||
remaining_folders = (
|
||||
drive_ids_to_retrieve | folder_ids_to_retrieve
|
||||
@ -415,74 +446,123 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
|
||||
)
|
||||
|
||||
def _manage_oauth_retrieval(
|
||||
def _determine_retrieval_ids(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
is_slim: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
drive_service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
|
||||
if self.include_files_shared_with_me or self.include_my_drives:
|
||||
logger.info(
|
||||
f"Getting shared files/my drive files for OAuth "
|
||||
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
|
||||
f"include_my_drives={self.include_my_drives}, "
|
||||
f"include_shared_drives={self.include_shared_drives}."
|
||||
f"Using '{self.primary_admin_email}' as the account."
|
||||
)
|
||||
yield from get_all_files_for_oauth(
|
||||
service=drive_service,
|
||||
include_files_shared_with_me=self.include_files_shared_with_me,
|
||||
include_my_drives=self.include_my_drives,
|
||||
include_shared_drives=self.include_shared_drives,
|
||||
is_slim=is_slim,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
all_requested = (
|
||||
self.include_files_shared_with_me
|
||||
and self.include_my_drives
|
||||
and self.include_shared_drives
|
||||
)
|
||||
if all_requested:
|
||||
# If all 3 are true, we already yielded from get_all_files_for_oauth
|
||||
return
|
||||
|
||||
next_stage: DriveRetrievalStage,
|
||||
) -> tuple[set[str], set[str]]:
|
||||
all_drive_ids = self.get_all_drive_ids()
|
||||
drive_ids_to_retrieve: set[str] = set()
|
||||
folder_ids_to_retrieve: set[str] = set()
|
||||
if self._requested_shared_drive_ids or self._requested_folder_ids:
|
||||
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
|
||||
requested_drive_ids=self._requested_shared_drive_ids,
|
||||
requested_folder_ids=self._requested_folder_ids,
|
||||
all_drive_ids_available=all_drive_ids,
|
||||
)
|
||||
elif self.include_shared_drives:
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
if checkpoint.completion_stage == DriveRetrievalStage.DRIVE_IDS:
|
||||
if self._requested_shared_drive_ids or self._requested_folder_ids:
|
||||
(
|
||||
drive_ids_to_retrieve,
|
||||
folder_ids_to_retrieve,
|
||||
) = _clean_requested_drive_ids(
|
||||
requested_drive_ids=self._requested_shared_drive_ids,
|
||||
requested_folder_ids=self._requested_folder_ids,
|
||||
all_drive_ids_available=all_drive_ids,
|
||||
)
|
||||
elif self.include_shared_drives:
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
|
||||
if not is_slim:
|
||||
checkpoint.drive_ids_to_retrieve = list(drive_ids_to_retrieve)
|
||||
checkpoint.folder_ids_to_retrieve = list(folder_ids_to_retrieve)
|
||||
checkpoint.completion_stage = next_stage
|
||||
else:
|
||||
assert (
|
||||
checkpoint.drive_ids_to_retrieve is not None
|
||||
), "drive ids to retrieve not set"
|
||||
assert (
|
||||
checkpoint.folder_ids_to_retrieve is not None
|
||||
), "folder ids to retrieve not set"
|
||||
# When loading from a checkpoint, load the previously cached drive and folder ids
|
||||
drive_ids_to_retrieve = set(checkpoint.drive_ids_to_retrieve)
|
||||
folder_ids_to_retrieve = set(checkpoint.folder_ids_to_retrieve)
|
||||
|
||||
return drive_ids_to_retrieve, folder_ids_to_retrieve
|
||||
|
||||
def _oauth_retrieval_all_files(
|
||||
self,
|
||||
is_slim: bool,
|
||||
drive_service: GoogleDriveService,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
if not self.include_files_shared_with_me and not self.include_my_drives:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Getting shared files/my drive files for OAuth "
|
||||
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
|
||||
f"include_my_drives={self.include_my_drives}, "
|
||||
f"include_shared_drives={self.include_shared_drives}."
|
||||
f"Using '{self.primary_admin_email}' as the account."
|
||||
)
|
||||
yield from get_all_files_for_oauth(
|
||||
service=drive_service,
|
||||
include_files_shared_with_me=self.include_files_shared_with_me,
|
||||
include_my_drives=self.include_my_drives,
|
||||
include_shared_drives=self.include_shared_drives,
|
||||
is_slim=is_slim,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _oauth_retrieval_drives(
|
||||
self,
|
||||
is_slim: bool,
|
||||
drive_service: GoogleDriveService,
|
||||
drive_ids_to_retrieve: set[str],
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
for drive_id in drive_ids_to_retrieve:
|
||||
logger.info(
|
||||
f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'"
|
||||
)
|
||||
checkpoint.curr_completion_key = drive_id
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
drive_id=drive_id,
|
||||
is_slim=is_slim,
|
||||
checkpoint=checkpoint,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _oauth_retrieval_folders(
|
||||
self,
|
||||
is_slim: bool,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
drive_service: GoogleDriveService,
|
||||
drive_ids_to_retrieve: set[str],
|
||||
folder_ids_to_retrieve: set[str],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
# Even if no folders were requested, we still check if any drives were requested
|
||||
# that could be folders.
|
||||
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
|
||||
|
||||
# the times stored in the completion_map aren't used due to the crawling behavior
|
||||
# instead, the traversed_parent_ids are used to determine what we have left to retrieve
|
||||
checkpoint.curr_completion_key = checkpoint.completion_stage
|
||||
for folder_id in remaining_folders:
|
||||
logger.info(
|
||||
f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'"
|
||||
)
|
||||
|
||||
yield from crawl_folders_for_files(
|
||||
is_slim=is_slim,
|
||||
checkpoint=checkpoint,
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
traversed_parent_ids=self._retrieved_ids,
|
||||
@ -499,46 +579,163 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
|
||||
)
|
||||
|
||||
def _fetch_drive_items(
|
||||
def _checkpointed_oauth_retrieval(
|
||||
self,
|
||||
is_slim: bool,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
drive_files = self._manage_oauth_retrieval(
|
||||
is_slim=is_slim,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
if is_slim:
|
||||
return drive_files
|
||||
|
||||
return self._checkpoint_yield(
|
||||
drive_files=drive_files,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
def _manage_oauth_retrieval(
|
||||
self,
|
||||
is_slim: bool,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
if checkpoint.completion_stage == DriveRetrievalStage.START:
|
||||
checkpoint.completion_stage = DriveRetrievalStage.OAUTH_FILES
|
||||
|
||||
drive_service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
|
||||
if checkpoint.completion_stage == DriveRetrievalStage.OAUTH_FILES:
|
||||
checkpoint.curr_completion_key = checkpoint.completion_stage
|
||||
yield from self._oauth_retrieval_all_files(
|
||||
drive_service=drive_service,
|
||||
is_slim=is_slim,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
checkpoint.completion_stage = DriveRetrievalStage.DRIVE_IDS
|
||||
|
||||
all_requested = (
|
||||
self.include_files_shared_with_me
|
||||
and self.include_my_drives
|
||||
and self.include_shared_drives
|
||||
)
|
||||
if all_requested:
|
||||
# If all 3 are true, we already yielded from get_all_files_for_oauth
|
||||
checkpoint.completion_stage = DriveRetrievalStage.DONE
|
||||
return
|
||||
|
||||
drive_ids_to_retrieve, folder_ids_to_retrieve = self._determine_retrieval_ids(
|
||||
checkpoint, is_slim, DriveRetrievalStage.SHARED_DRIVE_FILES
|
||||
)
|
||||
|
||||
if checkpoint.completion_stage == DriveRetrievalStage.SHARED_DRIVE_FILES:
|
||||
yield from self._oauth_retrieval_drives(
|
||||
is_slim=is_slim,
|
||||
drive_service=drive_service,
|
||||
drive_ids_to_retrieve=drive_ids_to_retrieve,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
checkpoint.completion_stage = DriveRetrievalStage.FOLDER_FILES
|
||||
|
||||
if checkpoint.completion_stage == DriveRetrievalStage.FOLDER_FILES:
|
||||
checkpoint.curr_completion_key = checkpoint.completion_stage
|
||||
yield from self._oauth_retrieval_folders(
|
||||
is_slim=is_slim,
|
||||
drive_service=drive_service,
|
||||
drive_ids_to_retrieve=drive_ids_to_retrieve,
|
||||
folder_ids_to_retrieve=folder_ids_to_retrieve,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
checkpoint.completion_stage = DriveRetrievalStage.DONE
|
||||
|
||||
def _fetch_drive_items(
|
||||
self,
|
||||
is_slim: bool,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
assert checkpoint is not None, "Must provide checkpoint for full retrieval"
|
||||
retrieval_method = (
|
||||
self._manage_service_account_retrieval
|
||||
if isinstance(self.creds, ServiceAccountCredentials)
|
||||
else self._manage_oauth_retrieval
|
||||
else self._checkpointed_oauth_retrieval
|
||||
)
|
||||
drive_files = retrieval_method(
|
||||
|
||||
return retrieval_method(
|
||||
is_slim=is_slim,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
return drive_files
|
||||
|
||||
def _extract_docs_from_google_drive(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
# Create a larger process pool for file conversion
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
# Prepare a partial function with the credentials and admin email
|
||||
convert_func = partial(
|
||||
_convert_single_file,
|
||||
self.creds,
|
||||
self.primary_admin_email,
|
||||
)
|
||||
) -> Iterator[list[Document | ConnectorFailure]]:
|
||||
try:
|
||||
# Create a larger process pool for file conversion
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
# Prepare a partial function with the credentials and admin email
|
||||
convert_func = partial(
|
||||
_convert_single_file,
|
||||
self.creds,
|
||||
self.primary_admin_email,
|
||||
)
|
||||
|
||||
# Fetch files in batches
|
||||
files_batch: list[GoogleDriveFileType] = []
|
||||
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
|
||||
files_batch.append(file)
|
||||
# Fetch files in batches
|
||||
batches_complete = 0
|
||||
files_batch: list[GoogleDriveFileType] = []
|
||||
for file in self._fetch_drive_items(
|
||||
is_slim=False,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
files_batch.append(file)
|
||||
|
||||
if len(files_batch) >= self.batch_size:
|
||||
# Process the batch
|
||||
if len(files_batch) >= self.batch_size:
|
||||
# Process the batch
|
||||
futures = [
|
||||
executor.submit(convert_func, file) for file in files_batch
|
||||
]
|
||||
documents = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
doc = future.result()
|
||||
if doc is not None:
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting file: {e}")
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
batches_complete += 1
|
||||
files_batch = []
|
||||
|
||||
if batches_complete > BATCHES_PER_CHECKPOINT:
|
||||
checkpoint.retrieved_ids = list(self._retrieved_ids)
|
||||
return # create a new checkpoint
|
||||
|
||||
# Process any remaining files
|
||||
if files_batch:
|
||||
futures = [
|
||||
executor.submit(convert_func, file) for file in files_batch
|
||||
]
|
||||
@ -553,40 +750,52 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
files_batch = []
|
||||
except Exception as e:
|
||||
logger.exception(f"Error extracting documents from Google Drive: {e}")
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise e
|
||||
yield [
|
||||
ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=checkpoint.curr_completion_key,
|
||||
missed_time_range=(
|
||||
datetime.fromtimestamp(start or 0),
|
||||
datetime.fromtimestamp(end or 0),
|
||||
),
|
||||
),
|
||||
failure_message=f"Error extracting documents from Google Drive: {e}",
|
||||
exception=e,
|
||||
)
|
||||
]
|
||||
|
||||
# Process any remaining files
|
||||
if files_batch:
|
||||
futures = [executor.submit(convert_func, file) for file in files_batch]
|
||||
documents = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
doc = future.result()
|
||||
if doc is not None:
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting file: {e}")
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, GoogleDriveCheckpoint]:
|
||||
"""
|
||||
Entrypoint for the connector; first run is with an empty checkpoint.
|
||||
"""
|
||||
if self._creds is None or self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Credentials missing, should not call this method before calling load_credentials"
|
||||
)
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
self._retrieved_ids = set(checkpoint.retrieved_ids)
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive()
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive(start, end)
|
||||
for doc_list in self._extract_docs_from_google_drive(
|
||||
checkpoint, start, end
|
||||
):
|
||||
yield from doc_list
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
if checkpoint.completion_stage == DriveRetrievalStage.DONE:
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
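# Illustrative sketch (not part of this commit): a driver loop for the new
# checkpointed entrypoint. Repeatedly calling load_from_checkpoint with the
# previously returned checkpoint resumes retrieval until has_more is False.
# The loop is hypothetical; batching and error handling are omitted.
import time

connector = GoogleDriveConnector(include_my_drives=True)
connector.load_credentials(credentials)  # assumes a credentials dict is available
checkpoint = connector.build_dummy_checkpoint()
while checkpoint.has_more:
    gen = connector.load_from_checkpoint(start=0.0, end=time.time(), checkpoint=checkpoint)
    try:
        while True:
            doc_or_failure = next(gen)  # Document or ConnectorFailure
    except StopIteration as e:
        checkpoint = e.value  # the generator's return value is the new checkpoint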
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
@ -596,6 +805,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_batch = []
|
||||
for file in self._fetch_drive_items(
|
||||
checkpoint=self.build_dummy_checkpoint(),
|
||||
is_slim=True,
|
||||
start=start,
|
||||
end=end,
|
||||
@ -676,3 +886,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected error during Google Drive validation: {e}"
|
||||
)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> GoogleDriveCheckpoint:
|
||||
return GoogleDriveCheckpoint(
|
||||
prev_run_doc_ids=[],
|
||||
retrieved_ids=[],
|
||||
completion_stage=DriveRetrievalStage.START,
|
||||
curr_completion_key="",
|
||||
completion_map=ThreadSafeDict(),
|
||||
has_more=True,
|
||||
)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
@ -13,7 +14,9 @@ from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_drive.section_extraction import get_document_sections
|
||||
from onyx.connectors.google_utils.resources import GoogleDocsService
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
@ -202,26 +205,26 @@ def _extract_sections_basic(
|
||||
|
||||
def convert_drive_item_to_document(
|
||||
file: GoogleDriveFileType,
|
||||
drive_service: GoogleDriveService,
|
||||
docs_service: GoogleDocsService,
|
||||
) -> Document | None:
|
||||
drive_service: Callable[[], GoogleDriveService],
|
||||
docs_service: Callable[[], GoogleDocsService],
|
||||
) -> Document | ConnectorFailure | None:
|
||||
"""
|
||||
Main entry point for converting a Google Drive file => Document object.
|
||||
"""
|
||||
|
||||
try:
|
||||
# skip shortcuts or folders
|
||||
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
|
||||
logger.info("Skipping shortcut/folder.")
|
||||
return None
|
||||
|
||||
# If it's a Google Doc, we might do advanced parsing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
|
||||
# Try to get sections using the advanced method first
|
||||
# If it's a Google Doc, we might do advanced parsing
|
||||
if file.get("mimeType") == GDriveMimeType.DOC.value:
|
||||
try:
|
||||
# get_document_sections is the advanced approach for Google Docs
|
||||
doc_sections = get_document_sections(
|
||||
docs_service=docs_service, doc_id=file.get("id", "")
|
||||
docs_service=docs_service(), doc_id=file.get("id", "")
|
||||
)
|
||||
if doc_sections:
|
||||
sections = cast(list[TextSection | ImageSection], doc_sections)
|
||||
@ -232,7 +235,7 @@ def convert_drive_item_to_document(
|
||||
|
||||
# If we don't have sections yet, use the basic extraction method
|
||||
if not sections:
|
||||
sections = _extract_sections_basic(file, drive_service)
|
||||
sections = _extract_sections_basic(file, drive_service())
|
||||
|
||||
# If we still don't have any sections, skip this file
|
||||
if not sections:
|
||||
@ -257,8 +260,19 @@ def convert_drive_item_to_document(
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting file {file.get('name')}: {e}")
|
||||
return None
|
||||
error_str = f"Error converting file '{file.get('name')}' to Document: {e}"
|
||||
logger.exception(error_str)
|
||||
return ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=sections[0].link
|
||||
if sections
|
||||
else None, # TODO: see if this is the best way to get a link
|
||||
),
|
||||
failed_entity=None,
|
||||
failure_message=error_str,
|
||||
exception=e,
|
||||
)
|
||||
|
||||
|
||||
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
|
||||
|
@ -1,14 +1,17 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from onyx.connectors.google_utils.google_utils import GoogleFields
|
||||
from onyx.connectors.google_utils.google_utils import ORDER_BY_KEY
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@ -25,6 +28,21 @@ SLIM_FILE_FIELDS = (
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
|
||||
def _get_kwargs_and_start(
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
is_slim: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
key: Callable[
|
||||
[GoogleDriveCheckpoint], str
|
||||
] = lambda check: check.curr_completion_key,
|
||||
) -> tuple[dict, SecondsSinceUnixEpoch | None]:
|
||||
kwargs = {}
|
||||
if not is_slim:
|
||||
start = checkpoint.completion_map.get(key(checkpoint), start)
|
||||
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
|
||||
return kwargs, start
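# Illustrative sketch (not part of this commit): for non-slim runs the listing
# is ordered by modifiedTime and the start bound is taken from the checkpoint's
# completion_map, so a resumed run re-queries only files modified after the
# last one recorded for that completion key. Hypothetical call:
kwargs, effective_start = _get_kwargs_and_start(
    checkpoint=checkpoint,  # assumes a GoogleDriveCheckpoint with completion_map populated
    is_slim=False,
    start=1700000000.0,  # fallback window start if the key has no entry yet
)
# kwargs == {"orderBy": "modifiedTime"}; effective_start may be later than 1700000000.0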
|
||||
|
||||
|
||||
def _generate_time_range_filter(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
@ -65,11 +83,13 @@ def _get_folders_in_parent(
|
||||
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
is_slim: bool,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
parent_id: str,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
is_slim: bool = False,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
kwargs, start = _get_kwargs_and_start(checkpoint, is_slim, start)
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
|
||||
query += " and trashed = false"
|
||||
query += _generate_time_range_filter(start, end)
|
||||
@ -83,11 +103,14 @@ def _get_files_in_parent(
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=query,
|
||||
**kwargs,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
def crawl_folders_for_files(
|
||||
is_slim: bool,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
traversed_parent_ids: set[str],
|
||||
@ -98,22 +121,25 @@ def crawl_folders_for_files(
|
||||
"""
|
||||
This function starts crawling from any folder. It is slower though.
|
||||
"""
|
||||
if parent_id in traversed_parent_ids:
|
||||
logger.info(f"Skipping subfolder since already traversed: {parent_id}")
|
||||
return
|
||||
logger.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
|
||||
if parent_id not in traversed_parent_ids:
|
||||
logger.info("Parent id not in traversed parent ids, getting files")
|
||||
found_files = False
|
||||
for file in _get_files_in_parent(
|
||||
service=service,
|
||||
is_slim=is_slim,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
parent_id=parent_id,
|
||||
):
|
||||
found_files = True
|
||||
yield file
|
||||
|
||||
found_files = False
|
||||
for file in _get_files_in_parent(
|
||||
service=service,
|
||||
start=start,
|
||||
end=end,
|
||||
parent_id=parent_id,
|
||||
):
|
||||
found_files = True
|
||||
yield file
|
||||
|
||||
if found_files:
|
||||
update_traversed_ids_func(parent_id)
|
||||
if found_files:
|
||||
update_traversed_ids_func(parent_id)
|
||||
else:
|
||||
logger.info(f"Skipping subfolder files since already traversed: {parent_id}")
|
||||
|
||||
for subfolder in _get_folders_in_parent(
|
||||
service=service,
|
||||
@ -121,6 +147,8 @@ def crawl_folders_for_files(
|
||||
):
|
||||
logger.info("Fetching all files in subfolder: " + subfolder["name"])
|
||||
yield from crawl_folders_for_files(
|
||||
is_slim=is_slim,
|
||||
checkpoint=checkpoint,
|
||||
service=service,
|
||||
parent_id=subfolder["id"],
|
||||
traversed_parent_ids=traversed_parent_ids,
|
||||
@ -133,11 +161,17 @@ def crawl_folders_for_files(
|
||||
def get_files_in_shared_drive(
|
||||
service: Resource,
|
||||
drive_id: str,
|
||||
is_slim: bool = False,
|
||||
is_slim: bool,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
key: Callable[
|
||||
[GoogleDriveCheckpoint], str
|
||||
] = lambda check: check.curr_completion_key,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
kwargs, start = _get_kwargs_and_start(checkpoint, is_slim, start, key)
|
||||
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
@ -173,16 +207,22 @@ def get_files_in_shared_drive(
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def get_all_files_in_my_drive(
|
||||
service: Any,
|
||||
service: GoogleDriveService,
|
||||
update_traversed_ids_func: Callable,
|
||||
is_slim: bool = False,
|
||||
is_slim: bool,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
key: Callable[
|
||||
[GoogleDriveCheckpoint], str
|
||||
] = lambda check: check.curr_completion_key,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
kwargs, start = _get_kwargs_and_start(checkpoint, is_slim, start, key)
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
@ -196,7 +236,7 @@ def get_all_files_in_my_drive(
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=folder_query,
|
||||
):
|
||||
update_traversed_ids_func(file["id"])
|
||||
update_traversed_ids_func(file[GoogleFields.ID])
|
||||
found_folders = True
|
||||
if found_folders:
|
||||
update_traversed_ids_func(get_root_folder_id(service))
|
||||
@ -212,19 +252,23 @@ def get_all_files_in_my_drive(
|
||||
corpora="user",
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def get_all_files_for_oauth(
|
||||
service: Any,
|
||||
service: GoogleDriveService,
|
||||
include_files_shared_with_me: bool,
|
||||
include_my_drives: bool,
|
||||
# One of the above 2 should be true
|
||||
include_shared_drives: bool,
|
||||
is_slim: bool = False,
|
||||
is_slim: bool,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
kwargs, start = _get_kwargs_and_start(checkpoint, is_slim, start)
|
||||
|
||||
should_get_all = (
|
||||
include_shared_drives and include_my_drives and include_files_shared_with_me
|
||||
)
|
||||
@ -243,11 +287,13 @@ def get_all_files_for_oauth(
|
||||
yield from execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
corpora=corpora,
|
||||
includeItemsFromAllDrives=should_get_all,
|
||||
supportsAllDrives=should_get_all,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -255,4 +301,8 @@ def get_all_files_for_oauth(
|
||||
def get_root_folder_id(service: Resource) -> str:
|
||||
# we dont paginate here because there is only one root folder per user
|
||||
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
|
||||
return service.files().get(fileId="root", fields="id").execute()["id"]
|
||||
return (
|
||||
service.files()
|
||||
.get(fileId="root", fields=GoogleFields.ID)
|
||||
.execute()[GoogleFields.ID]
|
||||
)
|
||||
|
@@ -1,6 +1,10 @@
from enum import Enum
from typing import Any

from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.threadpool_concurrency import ThreadSafeDict


class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
@@ -20,3 +24,75 @@ class GDriveMimeType(str, Enum):


GoogleDriveFileType = dict[str, Any]


TOKEN_EXPIRATION_TIME = 3600  # 1 hour


# These correspond to the major stages of retrieval for Google Drive.
# The stages for the oauth flow are:
# get_all_files_for_oauth(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# crawl_folders_for_files()
#
# The stages for the service account flow are roughly:
# get_all_user_emails(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# Then for each user:
# get_files_in_my_drive()
# get_files_in_shared_drive()
# crawl_folders_for_files()
class DriveRetrievalStage(str, Enum):
START = "start"
DONE = "done"
# OAuth specific stages
OAUTH_FILES = "oauth_files"

# Service account specific stages
USER_EMAILS = "user_emails"
MY_DRIVE_FILES = "my_drive_files"

# Used for both oauth and service account flows
DRIVE_IDS = "drive_ids"
SHARED_DRIVE_FILES = "shared_drive_files"
FOLDER_FILES = "folder_files"


class GoogleDriveCheckpoint(ConnectorCheckpoint):
# The doc ids that were completed in the previous run
prev_run_doc_ids: list[str]

# Checkpoint version of _retrieved_ids
retrieved_ids: list[str]

# Describes the point in the retrieval+indexing process that the
# checkpoint is at. When this is set to a given stage, the connector
# will have already yielded at least 1 file or error from that stage.
# The DONE stage is used to signal that has_more should become False.
completion_stage: DriveRetrievalStage

# The key into completion_map that is currently being processed.
# For stages that directly make a big (paginated) api call, this
# will be the stage itself. For stages with multiple sub-stages,
# this will be the id of the sub-stage. For example, when processing
# shared drives, it will be the id of the shared drive.
curr_completion_key: str

# The latest timestamp of a file that has been retrieved per completion key.
# See curr_completion_key for more details on completion keys.
completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch] = ThreadSafeDict()

# cached version of the drive and folder ids to retrieve
drive_ids_to_retrieve: list[str] | None = None
folder_ids_to_retrieve: list[str] | None = None

# cached user emails
user_emails: list[str] | None = None

# @field_serializer("completion_map")
# def serialize_completion_map(
#     self, completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch], _info: Any
# ) -> dict[str, SecondsSinceUnixEpoch]:
#     return completion_map._dict
|
||||
|
@@ -4,6 +4,7 @@ from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from enum import Enum
from typing import Any

from googleapiclient.errors import HttpError  # type: ignore
@@ -20,6 +21,20 @@ logger = setup_logger()
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)

NEXT_PAGE_TOKEN_KEY = "nextPageToken"
PAGE_TOKEN_KEY = "pageToken"
ORDER_BY_KEY = "orderBy"


# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
class GoogleFields(str, Enum):
ID = "id"
CREATED_TIME = "createdTime"
MODIFIED_TIME = "modifiedTime"
NAME = "name"
SIZE = "size"
PARENTS = "parents"


def _execute_with_retry(request: Any) -> Any:
max_attempts = 10
@@ -90,11 +105,11 @@ def execute_paginated_retrieval(
retrieval_function: The specific list function to call (e.g., service.files().list)
**kwargs: Arguments to pass to the list function
"""
next_page_token = ""
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
while next_page_token is not None:
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs["pageToken"] = next_page_token
request_kwargs[PAGE_TOKEN_KEY] = next_page_token

try:
results = retrieval_function(**request_kwargs).execute()
@@ -117,7 +132,7 @@ def execute_paginated_retrieval(
logger.exception("Error executing request:")
raise e

next_page_token = results.get("nextPageToken")
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
if list_key:
for item in results.get(list_key, []):
yield item
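# Illustrative sketch (assumption, not from this commit): because the initial
# page token is now read from kwargs, a caller could in principle resume a
# paginated listing mid-way by passing a saved token. Hypothetical usage:
for file in execute_paginated_retrieval(
    retrieval_function=drive_service.files().list,  # assumes a Drive service object
    list_key="files",
    pageToken=saved_page_token,  # "" or a token captured from an earlier run
    fields="nextPageToken, files(id, name)",
):
    print(file["id"])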
|
||||
|
@@ -4,9 +4,11 @@ from collections.abc import Iterator
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeAlias
from typing import TypeVar

from pydantic import BaseModel
from typing_extensions import override

from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
@@ -19,7 +21,6 @@ SecondsSinceUnixEpoch = float

GenerateDocumentsOutput = Iterator[list[Document]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]


class BaseConnector(abc.ABC):
@@ -57,6 +58,9 @@ class BaseConnector(abc.ABC):
Default is a no-op (always successful).
"""

def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
return ConnectorCheckpoint(has_more=True)


# Large set update or reindex, generally pulling a complete state or from a savestate file
class LoadConnector(BaseConnector):
@@ -74,6 +78,8 @@ class PollConnector(BaseConnector):
raise NotImplementedError


# Slim connectors can retrieve just the ids and
# permission syncing information for connected documents
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(
@@ -186,14 +192,21 @@ class EventConnector(BaseConnector):
raise NotImplementedError


class CheckpointConnector(BaseConnector):
CT = TypeVar("CT", bound=ConnectorCheckpoint)
# TODO: find a reasonable way to parameterize the return type of the generator
CheckpointOutput: TypeAlias = Generator[
Document | ConnectorFailure, None, ConnectorCheckpoint
]


class CheckpointConnector(BaseConnector, Generic[CT]):
@abc.abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
checkpoint: CT,
) -> Generator[Document | ConnectorFailure, None, CT]:
"""Yields back documents or failures. Final return is the new checkpoint.

Final return can be accessed via either:
@@ -214,3 +227,10 @@
```
"""
raise NotImplementedError

# Ideally return type should be CT, but that's not possible if
# we want to override build_dummy_checkpoint and have BaseConnector
# return a base ConnectorCheckpoint
@override
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
raise NotImplementedError
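# Illustrative sketch (not part of this commit): how a caller can both consume
# the yielded documents/failures and capture the generator's return value
# (the next checkpoint) from load_from_checkpoint.
def drive_to_completion(connector: CheckpointConnector[CT], start: float, end: float, checkpoint: CT) -> CT:
    gen = connector.load_from_checkpoint(start, end, checkpoint)
    while True:
        try:
            item = next(gen)  # Document | ConnectorFailure, handle as needed
        except StopIteration as e:
            return e.value  # the new checkpoint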
|
||||
|
@@ -1,4 +1,3 @@
import json
from datetime import datetime
from enum import Enum
from typing import Any
@@ -232,21 +231,16 @@ class IndexAttemptMetadata(BaseModel):

class ConnectorCheckpoint(BaseModel):
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
checkpoint_content: dict
has_more: bool

@classmethod
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)

def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000

content_str = json.dumps(self.checkpoint_content)
content_str = self.model_dump_json()
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
return content_str
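# Illustrative sketch (not part of this commit): with checkpoint_content gone,
# connector-specific state now lives on ConnectorCheckpoint subclasses such as
# the GoogleDriveCheckpoint and SlackCheckpoint defined elsewhere in this diff.
class ExampleCheckpoint(ConnectorCheckpoint):  # hypothetical subclass
    cursor: str | None = None
    seen_ids: list[str] = []

print(ExampleCheckpoint(has_more=True, cursor="abc"))  # __str__ -> model_dump_json()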
|
||||
|
||||
|
||||
class DocumentFailure(BaseModel):
|
||||
|
@ -10,10 +10,11 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@ -23,7 +24,6 @@ from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
@ -56,8 +56,8 @@ MessageType = dict[str, Any]
|
||||
ThreadType = list[MessageType]
|
||||
|
||||
|
||||
class SlackCheckpointContent(TypedDict):
|
||||
channel_ids: list[str]
|
||||
class SlackCheckpoint(ConnectorCheckpoint):
|
||||
channel_ids: list[str] | None
|
||||
channel_completion_map: dict[str, str]
|
||||
current_channel: ChannelType | None
|
||||
seen_thread_ts: list[str]
|
||||
@ -435,6 +435,12 @@ def _get_all_doc_ids(
|
||||
yield channel_metadata_list
|
||||
|
||||
|
||||
class ProcessedSlackMessage(BaseModel):
|
||||
doc: Document | None
|
||||
thread_ts: str | None
|
||||
failure: ConnectorFailure | None
|
||||
|
||||
|
||||
def _process_message(
|
||||
message: MessageType,
|
||||
client: WebClient,
|
||||
@ -443,7 +449,7 @@ def _process_message(
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
seen_thread_ts: set[str],
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
|
||||
) -> ProcessedSlackMessage:
|
||||
thread_ts = message.get("thread_ts")
|
||||
try:
|
||||
# causes random failures for testing checkpointing / continue on failure
|
||||
@ -460,13 +466,13 @@ def _process_message(
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
msg_filter_func=msg_filter_func,
|
||||
)
|
||||
return (doc, thread_ts, None)
|
||||
return ProcessedSlackMessage(doc=doc, thread_ts=thread_ts, failure=None)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing message {message['ts']}")
|
||||
return (
|
||||
None,
|
||||
thread_ts,
|
||||
ConnectorFailure(
|
||||
return ProcessedSlackMessage(
|
||||
doc=None,
|
||||
thread_ts=thread_ts,
|
||||
failure=ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=_build_doc_id(
|
||||
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
|
||||
@ -479,7 +485,7 @@ def _process_message(
|
||||
)
|
||||
|
||||
|
||||
class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
MAX_WORKERS = 2
|
||||
|
||||
def __init__(
|
||||
@ -525,8 +531,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
checkpoint: SlackCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, SlackCheckpoint]:
|
||||
"""Rough outline:
|
||||
|
||||
Step 1: Get all channels, yield back Checkpoint.
|
||||
@ -542,49 +548,36 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
if self.client is None or self.text_cleaner is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
checkpoint_content = cast(
|
||||
SlackCheckpointContent,
|
||||
(
|
||||
copy.deepcopy(checkpoint.checkpoint_content)
|
||||
or {
|
||||
"channel_ids": None,
|
||||
"channel_completion_map": {},
|
||||
"current_channel": None,
|
||||
"seen_thread_ts": [],
|
||||
}
|
||||
),
|
||||
)
|
||||
checkpoint = cast(SlackCheckpoint, copy.deepcopy(checkpoint))
|
||||
|
||||
# if this is the very first time we've called this, need to
|
||||
# get all relevant channels and save them into the checkpoint
|
||||
if checkpoint_content["channel_ids"] is None:
|
||||
if checkpoint.channel_ids is None:
|
||||
raw_channels = get_channels(self.client)
|
||||
filtered_channels = filter_channels(
|
||||
raw_channels, self.channels, self.channel_regex_enabled
|
||||
)
|
||||
checkpoint.channel_ids = [c["id"] for c in filtered_channels]
|
||||
if len(filtered_channels) == 0:
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
|
||||
checkpoint_content["current_channel"] = filtered_channels[0]
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content=checkpoint_content, # type: ignore
|
||||
has_more=True,
|
||||
)
|
||||
checkpoint.current_channel = filtered_channels[0]
|
||||
checkpoint.has_more = True
|
||||
return checkpoint
|
||||
|
||||
final_channel_ids = checkpoint_content["channel_ids"]
|
||||
channel = checkpoint_content["current_channel"]
|
||||
final_channel_ids = checkpoint.channel_ids
|
||||
channel = checkpoint.current_channel
|
||||
if channel is None:
|
||||
raise ValueError("current_channel key not found in checkpoint")
|
||||
raise ValueError("current_channel key not set in checkpoint")
|
||||
|
||||
channel_id = channel["id"]
|
||||
if channel_id not in final_channel_ids:
|
||||
raise ValueError(f"Channel {channel_id} not found in checkpoint")
|
||||
|
||||
oldest = str(start) if start else None
|
||||
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
|
||||
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
|
||||
latest = checkpoint.channel_completion_map.get(channel_id, str(end))
|
||||
seen_thread_ts = set(checkpoint.seen_thread_ts)
|
||||
try:
|
||||
logger.debug(
|
||||
f"Getting messages for channel {channel} within range {oldest} - {latest}"
|
||||
@ -596,7 +589,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
|
||||
# Process messages in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=SlackConnector.MAX_WORKERS) as executor:
|
||||
futures: list[Future] = []
|
||||
futures: list[Future[ProcessedSlackMessage]] = []
|
||||
for message in message_batch:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
@ -614,7 +607,10 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
)
|
||||
|
||||
for future in as_completed(futures):
|
||||
doc, thread_ts, failures = future.result()
|
||||
processed_slack_message = future.result()
|
||||
doc = processed_slack_message.doc
|
||||
thread_ts = processed_slack_message.thread_ts
|
||||
failure = processed_slack_message.failure
|
||||
if doc:
|
||||
# handle race conditions here since this is single
|
||||
# threaded. Multi-threaded _process_message reads from this
|
||||
@ -624,36 +620,31 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
if thread_ts not in seen_thread_ts:
|
||||
yield doc
|
||||
|
||||
if thread_ts:
|
||||
seen_thread_ts.add(thread_ts)
|
||||
elif failures:
|
||||
for failure in failures:
|
||||
yield failure
|
||||
assert thread_ts, "found non-None doc with None thread_ts"
|
||||
seen_thread_ts.add(thread_ts)
|
||||
elif failure:
|
||||
yield failure
|
||||
|
||||
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
|
||||
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
|
||||
checkpoint.seen_thread_ts = list(seen_thread_ts)
|
||||
checkpoint.channel_completion_map[channel["id"]] = new_latest
|
||||
if has_more_in_channel:
|
||||
checkpoint_content["current_channel"] = channel
|
||||
checkpoint.current_channel = channel
|
||||
else:
|
||||
new_channel_id = next(
|
||||
(
|
||||
channel_id
|
||||
for channel_id in final_channel_ids
|
||||
if channel_id
|
||||
not in checkpoint_content["channel_completion_map"]
|
||||
if channel_id not in checkpoint.channel_completion_map
|
||||
),
|
||||
None,
|
||||
)
|
||||
if new_channel_id:
|
||||
new_channel = _get_channel_by_id(self.client, new_channel_id)
|
||||
checkpoint_content["current_channel"] = new_channel
|
||||
checkpoint.current_channel = new_channel
|
||||
else:
|
||||
checkpoint_content["current_channel"] = None
|
||||
checkpoint.current_channel = None
|
||||
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content=checkpoint_content, # type: ignore
|
||||
has_more=checkpoint_content["current_channel"] is not None,
|
||||
)
|
||||
checkpoint.has_more = checkpoint.current_channel is not None
|
||||
return checkpoint
|
||||
|
||||
except Exception as e:
|
||||
@ -753,6 +744,16 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
f"Unexpected error during Slack settings validation: {e}"
|
||||
)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> SlackCheckpoint:
|
||||
return SlackCheckpoint(
|
||||
channel_ids=None,
|
||||
channel_completion_map={},
|
||||
current_channel=None,
|
||||
seen_thread_ts=[],
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
@ -767,9 +768,11 @@ if __name__ == "__main__":
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
|
||||
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
|
||||
gen = connector.load_from_checkpoint(
|
||||
one_day_ago, current, cast(SlackCheckpoint, checkpoint)
|
||||
)
|
||||
try:
|
||||
for document_or_failure in gen:
|
||||
if isinstance(document_or_failure, Document):
|
||||
|
backend/onyx/utils/lazy.py (new file)
@ -0,0 +1,13 @@
from collections.abc import Callable
from functools import lru_cache
from typing import TypeVar

R = TypeVar("R")


def lazy_eval(func: Callable[[], R]) -> Callable[[], R]:
    @lru_cache(maxsize=1)
    def lazy_func() -> R:
        return func()

    return lazy_func
|
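A small usage sketch of `lazy_eval` (the factory name below is illustrative): because the wrapper uses `lru_cache(maxsize=1)`, the wrapped callable runs at most once and later calls return the cached result.
```
def make_expensive_client() -> dict:
    # Stand-in for an expensive construction (API client, large config load, ...).
    print("building client")
    return {"ready": True}

get_client = lazy_eval(make_expensive_client)

get_client()  # prints "building client" and computes the value
get_client()  # returns the cached value; the factory is not called again
```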
@ -1,18 +1,136 @@
|
||||
import collections.abc
|
||||
import contextvars
|
||||
import copy
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import MutableMapping
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import FIRST_COMPLETED
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import wait
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
R = TypeVar("R")
|
||||
KT = TypeVar("KT") # Key type
|
||||
VT = TypeVar("VT") # Value type
|
||||
|
||||
|
||||
class ThreadSafeDict(MutableMapping[KT, VT]):
|
||||
"""
|
||||
A thread-safe dictionary implementation that uses a lock to ensure thread safety.
|
||||
Implements the MutableMapping interface to provide a complete dictionary-like interface.
|
||||
|
||||
Example usage:
|
||||
# Create a thread-safe dictionary
|
||||
safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
|
||||
|
||||
# Basic operations
|
||||
safe_dict["key"] = 1
|
||||
value = safe_dict["key"]
|
||||
del safe_dict["key"]
|
||||
|
||||
# Bulk operations (atomic)
|
||||
safe_dict.update({"key1": 1, "key2": 2})
|
||||
|
||||
# Safe iteration
|
||||
with safe_dict.lock:
|
||||
for key, value in safe_dict.items():
|
||||
# This block is atomic
|
||||
pass
|
||||
"""
|
||||
|
||||
def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
|
||||
self._dict: dict[KT, VT] = input_dict or {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def __getitem__(self, key: KT) -> VT:
|
||||
with self.lock:
|
||||
return self._dict[key]
|
||||
|
||||
def __setitem__(self, key: KT, value: VT) -> None:
|
||||
with self.lock:
|
||||
self._dict[key] = value
|
||||
|
||||
def __delitem__(self, key: KT) -> None:
|
||||
with self.lock:
|
||||
del self._dict[key]
|
||||
|
||||
def __iter__(self) -> Iterator[KT]:
|
||||
# Return a snapshot of keys to avoid potential modification during iteration
|
||||
with self.lock:
|
||||
return iter(list(self._dict.keys()))
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self.lock:
|
||||
return len(self._dict)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source_type: Any, handler: GetCoreSchemaHandler
|
||||
) -> core_schema.CoreSchema:
|
||||
return handler(dict[KT, VT])
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
|
||||
return ThreadSafeDict(copy.deepcopy(self._dict))
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all items from the dictionary atomically."""
|
||||
with self.lock:
|
||||
self._dict.clear()
|
||||
|
||||
def copy(self) -> dict[KT, VT]:
|
||||
"""Return a shallow copy of the dictionary atomically."""
|
||||
with self.lock:
|
||||
return self._dict.copy()
|
||||
|
||||
def get(self, key: KT, default: Any = None) -> Any:
|
||||
"""Get a value with a default, atomically."""
|
||||
with self.lock:
|
||||
return self._dict.get(key, default)
|
||||
|
||||
def pop(self, key: KT, default: Any = None) -> Any:
|
||||
"""Remove and return a value with optional default, atomically."""
|
||||
with self.lock:
|
||||
if default is None:
|
||||
return self._dict.pop(key)
|
||||
return self._dict.pop(key, default)
|
||||
|
||||
def setdefault(self, key: KT, default: VT) -> VT:
|
||||
"""Set a default value if key is missing, atomically."""
|
||||
with self.lock:
|
||||
return self._dict.setdefault(key, default)
|
||||
|
||||
def update(self, *args: Any, **kwargs: VT) -> None:
|
||||
"""Update the dictionary atomically from another mapping or from kwargs."""
|
||||
with self.lock:
|
||||
self._dict.update(*args, **kwargs)
|
||||
|
||||
def items(self) -> collections.abc.ItemsView[KT, VT]:
|
||||
"""Return a view of (key, value) pairs atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.ItemsView(self)
|
||||
|
||||
def keys(self) -> collections.abc.KeysView[KT]:
|
||||
"""Return a view of keys atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.KeysView(self)
|
||||
|
||||
def values(self) -> collections.abc.ValuesView[VT]:
|
||||
"""Return a view of values atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.ValuesView(self)
|
||||
|
||||
|
||||
def run_functions_tuples_in_parallel(
|
||||
@ -190,3 +308,30 @@ def wait_on_background(task: TimeoutThread[R]) -> R:
|
||||
raise task.exception
|
||||
|
||||
return task.result
|
||||
|
||||
|
||||
def _next_or_none(ind: int, g: Iterator[R]) -> tuple[int, R | None]:
    return ind, next(g, None)


def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Kick off one "fetch next item" task per generator.
        future_to_index: dict[Future[tuple[int, R | None]], int] = {
            executor.submit(_next_or_none, i, g): i for i, g in enumerate(gens)
        }

        next_ind = len(gens)
        while future_to_index:
            done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
            for future in done:
                ind, result = future.result()
                if result is not None:
                    yield result
                    del future_to_index[future]
                    # Ask the same generator for its next item.
                    future_to_index[
                        executor.submit(_next_or_none, ind, gens[ind])
                    ] = next_ind
                    next_ind += 1
                else:
                    # Generator exhausted; drop it.
                    del future_to_index[future]
|
||||
|
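A minimal usage sketch of `parallel_yield` (the generators below are illustrative, and the import path assumes the module shown above): items from several blocking iterators are interleaved roughly in completion order.
```
import time

from onyx.utils.threadpool_concurrency import parallel_yield


def slow_range(start: int, delay: float):
    for i in range(start, start + 3):
        time.sleep(delay)
        yield i


# The faster generator's values (100, 101, 102) tend to arrive first.
for value in parallel_yield([slow_range(0, 0.2), slow_range(100, 0.05)]):
    print(value)
```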
@ -1,5 +1,7 @@
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
|
||||
@ -202,3 +204,13 @@ def assert_retrieved_docs_match_expected(
|
||||
retrieved=valid_retrieved_texts,
|
||||
)
|
||||
assert expected_file_texts == valid_retrieved_texts
|
||||
|
||||
|
||||
def load_all_docs(connector: GoogleDriveConnector) -> list[Document]:
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc in connector.load_from_checkpoint(
|
||||
0, time.time(), connector.build_dummy_checkpoint()
|
||||
):
|
||||
assert isinstance(doc, Document), f"Should not fail with {type(doc)} {doc}"
|
||||
retrieved_docs.append(doc)
|
||||
return retrieved_docs
|
||||
|
@ -1,10 +1,8 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
@ -23,6 +21,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
|
||||
@ -47,9 +46,7 @@ def test_include_all(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# Should get everything in shared and admin's My Drive with oauth
|
||||
expected_file_ids = (
|
||||
@ -89,9 +86,7 @@ def test_include_shared_drives_only(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# Should only get shared drives
|
||||
expected_file_ids = (
|
||||
@ -129,9 +124,7 @@ def test_include_my_drives_only(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# Should only get primary_admins My Drive because we are impersonating them
|
||||
expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
|
||||
@ -160,9 +153,7 @@ def test_drive_one_only(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=",".join([str(url) for url in drive_urls]),
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = (
|
||||
SHARED_DRIVE_1_FILE_IDS
|
||||
@ -196,9 +187,7 @@ def test_folder_and_shared_drive(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=",".join([str(url) for url in drive_urls]),
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = (
|
||||
SHARED_DRIVE_1_FILE_IDS
|
||||
@ -243,9 +232,7 @@ def test_folders_only(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=",".join([str(url) for url in shared_drive_urls]),
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = (
|
||||
FOLDER_1_1_FILE_IDS
|
||||
@ -281,9 +268,7 @@ def test_personal_folders_only(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
|
||||
assert_retrieved_docs_match_expected(
|
||||
|
@ -1,11 +1,10 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_URL
|
||||
|
||||
|
||||
@ -37,9 +36,7 @@ def test_google_drive_sections(
|
||||
my_drive_emails=None,
|
||||
)
|
||||
for connector in [oauth_connector, service_acct_connector]:
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# Verify we got the 1 doc with sections
|
||||
assert len(retrieved_docs) == 1
|
||||
|
@ -1,10 +1,8 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
@ -23,6 +21,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
|
||||
@ -52,9 +51,7 @@ def test_include_all(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# Should get everything
|
||||
expected_file_ids = (
|
||||
@ -97,9 +94,7 @@ def test_include_shared_drives_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# Should only get shared drives
|
||||
expected_file_ids = (
|
||||
@ -137,9 +132,7 @@ def test_include_my_drives_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# Should only get everyone's My Drives
|
||||
expected_file_ids = (
|
||||
@ -174,9 +167,7 @@ def test_drive_one_only(
|
||||
shared_drive_urls=",".join([str(url) for url in urls]),
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# We ignore shared_drive_urls if include_shared_drives is False
|
||||
expected_file_ids = (
|
||||
@ -211,9 +202,7 @@ def test_folder_and_shared_drive(
|
||||
shared_folder_urls=",".join([str(url) for url in folder_urls]),
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# Should get everything except for the top level files in drive 2
|
||||
expected_file_ids = (
|
||||
@ -259,9 +248,7 @@ def test_folders_only(
|
||||
shared_folder_urls=",".join([str(url) for url in folder_urls]),
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = (
|
||||
FOLDER_1_1_FILE_IDS
|
||||
@ -298,9 +285,7 @@ def test_specific_emails(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=",".join([str(email) for email in my_drive_emails]),
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS
|
||||
assert_retrieved_docs_match_expected(
|
||||
@ -330,9 +315,7 @@ def get_specific_folders_in_my_drive(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
|
||||
assert_retrieved_docs_match_expected(
|
||||
|
@ -1,10 +1,8 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_retrieved_docs_match_expected,
|
||||
@ -14,6 +12,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS
|
||||
@ -37,9 +36,7 @@ def test_all(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = (
|
||||
# These are the files from my drive
|
||||
@ -77,9 +74,7 @@ def test_shared_drives_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = (
|
||||
# These are the files from shared drives
|
||||
@ -112,9 +107,7 @@ def test_shared_with_me_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = (
|
||||
# These are the files shared with me from admin
|
||||
@ -145,9 +138,7 @@ def test_my_drive_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# These are the files from my drive
|
||||
expected_file_ids = TEST_USER_1_FILE_IDS
|
||||
@ -175,9 +166,7 @@ def test_shared_my_drive_folder(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = (
|
||||
# this is a folder from admin's drive that is shared with me
|
||||
@ -207,9 +196,7 @@ def test_shared_drive_folder(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS
|
||||
assert_retrieved_docs_match_expected(
|
||||
|
@ -54,9 +54,9 @@ def test_mock_connector_basic_flow(
|
||||
json=[
|
||||
{
|
||||
"documents": [test_doc.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
@ -128,9 +128,9 @@ def test_mock_connector_with_failures(
|
||||
json=[
|
||||
{
|
||||
"documents": [doc1.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [doc2_failure.model_dump(mode="json")],
|
||||
}
|
||||
],
|
||||
@ -208,9 +208,9 @@ def test_mock_connector_failure_recovery(
|
||||
json=[
|
||||
{
|
||||
"documents": [doc1.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [
|
||||
doc2_failure.model_dump(mode="json"),
|
||||
ConnectorFailure(
|
||||
@ -292,9 +292,9 @@ def test_mock_connector_failure_recovery(
|
||||
doc1.model_dump(mode="json"),
|
||||
doc2.model_dump(mode="json"),
|
||||
],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
@ -372,24 +372,24 @@ def test_mock_connector_checkpoint_recovery(
|
||||
json=[
|
||||
{
|
||||
"documents": [doc.model_dump(mode="json") for doc in docs_batch_1],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=True
|
||||
).model_dump(mode="json"),
|
||||
"checkpoint": ConnectorCheckpoint(has_more=True).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
},
|
||||
{
|
||||
"documents": [doc2.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=True
|
||||
).model_dump(mode="json"),
|
||||
"checkpoint": ConnectorCheckpoint(has_more=True).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
},
|
||||
{
|
||||
"documents": [],
|
||||
# should never hit this, unhandled exception happens first
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
"unhandled_exception": "Simulated unhandled error",
|
||||
},
|
||||
@ -463,9 +463,9 @@ def test_mock_connector_checkpoint_recovery(
|
||||
json=[
|
||||
{
|
||||
"documents": [doc3.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"checkpoint": ConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
|
@ -1,10 +1,16 @@
|
||||
import contextvars
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.threadpool_concurrency import parallel_yield
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeDict
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Create a context variable for testing
|
||||
@ -148,3 +154,242 @@ def test_multiple_background_tasks() -> None:
|
||||
|
||||
# Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
|
||||
assert 0.2 <= elapsed < 0.4 # Allow some buffer for test environment variations
|
||||
|
||||
|
||||
def test_thread_safe_dict_basic_operations() -> None:
|
||||
"""Test basic operations of ThreadSafeDict"""
|
||||
d = ThreadSafeDict[str, int]()
|
||||
|
||||
# Test setting and getting
|
||||
d["a"] = 1
|
||||
assert d["a"] == 1
|
||||
|
||||
# Test get with default
|
||||
assert d.get("a", None) == 1
|
||||
assert d.get("b", 2) == 2
|
||||
|
||||
# Test deletion
|
||||
del d["a"]
|
||||
assert "a" not in d
|
||||
|
||||
# Test length
|
||||
d["x"] = 10
|
||||
d["y"] = 20
|
||||
assert len(d) == 2
|
||||
|
||||
# Test iteration
|
||||
keys = sorted(d.keys())
|
||||
assert keys == ["x", "y"]
|
||||
|
||||
# Test items and values
|
||||
assert dict(d.items()) == {"x": 10, "y": 20}
|
||||
assert sorted(d.values()) == [10, 20]
|
||||
|
||||
|
||||
def test_thread_safe_dict_concurrent_access() -> None:
|
||||
"""Test ThreadSafeDict with concurrent access from multiple threads"""
|
||||
d = ThreadSafeDict[str, int]()
|
||||
num_threads = 10
|
||||
iterations = 1000
|
||||
|
||||
def increment_values() -> None:
|
||||
for i in range(iterations):
|
||||
key = str(i % 5) # Use 5 different keys
|
||||
# Get the current value (or 0 if the key doesn't exist), increment, then store
|
||||
current = d.get(key, 0)
|
||||
if current is not None: # This check is needed since get can return None
|
||||
d[key] = current + 1
|
||||
else:
|
||||
d[key] = 1
|
||||
|
||||
# Create and start threads
|
||||
threads = []
|
||||
for _ in range(num_threads):
|
||||
t = threading.Thread(target=increment_values)
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Verify results
|
||||
# Each key should have been incremented (num_threads * iterations) / 5 times
|
||||
expected_value = (num_threads * iterations) // 5
|
||||
for i in range(5):
|
||||
assert d[str(i)] == expected_value
|
||||
|
||||
|
||||
def test_thread_safe_dict_bulk_operations() -> None:
|
||||
"""Test bulk operations of ThreadSafeDict"""
|
||||
d = ThreadSafeDict[str, int]()
|
||||
|
||||
# Test update with dict
|
||||
d.update({"a": 1, "b": 2})
|
||||
assert dict(d.items()) == {"a": 1, "b": 2}
|
||||
|
||||
# Test update with kwargs
|
||||
d.update(c=3, d=4)
|
||||
assert dict(d.items()) == {"a": 1, "b": 2, "c": 3, "d": 4}
|
||||
|
||||
# Test clear
|
||||
d.clear()
|
||||
assert len(d) == 0
|
||||
|
||||
|
||||
def test_thread_safe_dict_concurrent_bulk_operations() -> None:
|
||||
"""Test ThreadSafeDict with concurrent bulk operations"""
|
||||
d = ThreadSafeDict[str, int]()
|
||||
num_threads = 5
|
||||
|
||||
def bulk_update(start: int) -> None:
|
||||
# Each thread updates with its own range of numbers
|
||||
updates = {str(i): i for i in range(start, start + 20)}
|
||||
d.update(updates)
|
||||
time.sleep(0.01) # Add some delay to increase chance of thread overlap
|
||||
|
||||
# Run updates concurrently
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(bulk_update, i * 20) for i in range(num_threads)]
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
# Verify results
|
||||
assert len(d) == num_threads * 20
|
||||
# Verify all numbers from 0 to (num_threads * 20) are present
|
||||
for i in range(num_threads * 20):
|
||||
assert d[str(i)] == i
|
||||
|
||||
|
||||
def test_thread_safe_dict_atomic_operations() -> None:
|
||||
"""Test atomic operations with ThreadSafeDict's lock"""
|
||||
d = ThreadSafeDict[str, list[int]]()
|
||||
d["numbers"] = []
|
||||
|
||||
def append_numbers(start: int) -> None:
|
||||
numbers = d["numbers"]
|
||||
with d.lock:
|
||||
for i in range(start, start + 5):
|
||||
numbers.append(i)
|
||||
time.sleep(0.001) # Add delay to increase chance of thread overlap
|
||||
d["numbers"] = numbers
|
||||
|
||||
# Run concurrent append operations
|
||||
threads = []
|
||||
for i in range(4): # 4 threads, each adding 5 numbers
|
||||
t = threading.Thread(target=append_numbers, args=(i * 5,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Verify results
|
||||
numbers = d["numbers"]
|
||||
assert len(numbers) == 20 # 4 threads * 5 numbers each
|
||||
assert sorted(numbers) == list(range(20)) # All numbers 0-19 should be present
|
||||
|
||||
|
||||
def test_parallel_yield_basic() -> None:
|
||||
"""Test that parallel_yield correctly yields values from multiple generators."""
|
||||
|
||||
def make_gen(values: list[int], delay: float) -> Generator[int, None, None]:
|
||||
for v in values:
|
||||
time.sleep(delay)
|
||||
yield v
|
||||
|
||||
# Create generators with different delays
|
||||
gen1 = make_gen([1, 4, 7], 0.1) # Slower generator
|
||||
gen2 = make_gen([2, 5, 8], 0.05) # Faster generator
|
||||
gen3 = make_gen([3, 6, 9], 0.15) # Slowest generator
|
||||
|
||||
# Collect results with timestamps
|
||||
results: list[tuple[float, int]] = []
|
||||
start_time = time.time()
|
||||
|
||||
for value in parallel_yield([gen1, gen2, gen3]):
|
||||
results.append((time.time() - start_time, value))
|
||||
|
||||
# Verify all values were yielded
|
||||
assert sorted(v for _, v in results) == list(range(1, 10))
|
||||
|
||||
# Verify that faster generators yielded earlier
|
||||
# Group results by generator (values 1,4,7 are gen1, 2,5,8 are gen2, 3,6,9 are gen3)
|
||||
gen1_times = [t for t, v in results if v in (1, 4, 7)]
|
||||
gen2_times = [t for t, v in results if v in (2, 5, 8)]
|
||||
gen3_times = [t for t, v in results if v in (3, 6, 9)]
|
||||
|
||||
# Average times for each generator
|
||||
avg_gen1 = sum(gen1_times) / len(gen1_times)
|
||||
avg_gen2 = sum(gen2_times) / len(gen2_times)
|
||||
avg_gen3 = sum(gen3_times) / len(gen3_times)
|
||||
|
||||
# Verify gen2 (fastest) has lowest average time
|
||||
assert avg_gen2 < avg_gen1
|
||||
assert avg_gen2 < avg_gen3
|
||||
|
||||
|
||||
def test_parallel_yield_empty_generators() -> None:
|
||||
"""Test parallel_yield with empty generators."""
|
||||
|
||||
def empty_gen() -> Iterator[int]:
|
||||
if False:
|
||||
yield 0 # Makes this a generator function
|
||||
|
||||
gens = [empty_gen() for _ in range(3)]
|
||||
results = list(parallel_yield(gens))
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_parallel_yield_different_lengths() -> None:
|
||||
"""Test parallel_yield with generators of different lengths."""
|
||||
|
||||
def make_gen(count: int) -> Iterator[int]:
|
||||
for i in range(count):
|
||||
yield i
|
||||
time.sleep(0.01) # Small delay to ensure concurrent execution
|
||||
|
||||
gens = [
|
||||
make_gen(1), # Yields: [0]
|
||||
make_gen(3), # Yields: [0, 1, 2]
|
||||
make_gen(2), # Yields: [0, 1]
|
||||
]
|
||||
|
||||
results = list(parallel_yield(gens))
|
||||
assert len(results) == 6 # Total number of items from all generators
|
||||
assert sorted(results) == [0, 0, 0, 1, 1, 2]
|
||||
|
||||
|
||||
def test_parallel_yield_exception_handling() -> None:
|
||||
"""Test parallel_yield handles exceptions in generators properly."""
|
||||
|
||||
def failing_gen() -> Iterator[int]:
|
||||
yield 1
|
||||
raise ValueError("Generator failure")
|
||||
|
||||
def normal_gen() -> Iterator[int]:
|
||||
yield 2
|
||||
yield 3
|
||||
|
||||
gens = [failing_gen(), normal_gen()]
|
||||
|
||||
with pytest.raises(ValueError, match="Generator failure"):
|
||||
list(parallel_yield(gens))
|
||||
|
||||
|
||||
def test_parallel_yield_non_blocking() -> None:
|
||||
"""Test parallel_yield with non-blocking generators (simple ranges)."""
|
||||
|
||||
def range_gen(start: int, end: int) -> Iterator[int]:
|
||||
for i in range(start, end):
|
||||
yield i
|
||||
|
||||
# Create three overlapping ranges
|
||||
gens = [range_gen(0, 100), range_gen(100, 200), range_gen(200, 300)]
|
||||
|
||||
results = list(parallel_yield(gens))
|
||||
|
||||
print(results)
|
||||
# Verify no values are missing
|
||||
assert len(results) == 300 # Should have all values from 0 to 299
|
||||
assert sorted(results) == list(range(300))
|
||||
|