diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py
index a9f1f9469..325537d5a 100644
--- a/backend/onyx/connectors/google_drive/connector.py
+++ b/backend/onyx/connectors/google_drive/connector.py
@@ -2,8 +2,6 @@ import copy
 import threading
 from collections.abc import Callable
 from collections.abc import Iterator
-from concurrent.futures import as_completed
-from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from functools import partial
 from typing import Any
@@ -64,6 +62,7 @@ from onyx.utils.lazy import lazy_eval
 from onyx.utils.logger import setup_logger
 from onyx.utils.retry_wrapper import retry_builder
 from onyx.utils.threadpool_concurrency import parallel_yield
+from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 from onyx.utils.threadpool_concurrency import ThreadSafeDict
 
 logger = setup_logger()
@@ -899,115 +898,114 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
         end: SecondsSinceUnixEpoch | None = None,
     ) -> Iterator[list[Document | ConnectorFailure]]:
         try:
-            # Create a larger process pool for file conversion
-            with ThreadPoolExecutor(max_workers=8) as executor:
-                # Prepare a partial function with the credentials and admin email
-                convert_func = partial(
-                    _convert_single_file,
-                    self.creds,
-                    self.primary_admin_email,
-                    self.allow_images,
+            # Prepare a partial function with the credentials and admin email
+            convert_func = partial(
+                _convert_single_file,
+                self.creds,
+                self.primary_admin_email,
+                self.allow_images,
+            )
+
+            # Fetch files in batches
+            batches_complete = 0
+            files_batch: list[GoogleDriveFileType] = []
+            func_with_args: list[
+                tuple[
+                    Callable[..., Document | ConnectorFailure | None], tuple[Any, ...]
+                ]
+            ] = []
+            for retrieved_file in self._fetch_drive_items(
+                is_slim=False,
+                checkpoint=checkpoint,
+                start=start,
+                end=end,
+            ):
+                if retrieved_file.error is not None:
+                    failure_stage = retrieved_file.completion_stage.value
+                    failure_message = (
+                        f"retrieval failure during stage: {failure_stage},"
+                    )
+                    failure_message += f"user: {retrieved_file.user_email},"
+                    failure_message += (
+                        f"parent drive/folder: {retrieved_file.parent_id},"
+                    )
+                    failure_message += f"error: {retrieved_file.error}"
+                    logger.error(failure_message)
+                    yield [
+                        ConnectorFailure(
+                            failed_entity=EntityFailure(
+                                entity_id=failure_stage,
+                            ),
+                            failure_message=failure_message,
+                            exception=retrieved_file.error,
+                        )
+                    ]
+                    continue
+                files_batch.append(retrieved_file.drive_file)
+
+                if len(files_batch) < self.batch_size:
+                    continue
+
+                # Process the batch using run_functions_tuples_in_parallel
+                func_with_args = [(convert_func, (file,)) for file in files_batch]
+                results = run_functions_tuples_in_parallel(
+                    func_with_args, max_workers=8
                 )
 
-                # Fetch files in batches
-                batches_complete = 0
-                files_batch: list[GoogleDriveFileType] = []
-                for retrieved_file in self._fetch_drive_items(
-                    is_slim=False,
-                    checkpoint=checkpoint,
-                    start=start,
-                    end=end,
-                ):
-                    if retrieved_file.error is not None:
-                        failure_stage = retrieved_file.completion_stage.value
-                        failure_message = (
-                            f"retrieval failure during stage: {failure_stage},"
-                        )
-                        failure_message += f"user: {retrieved_file.user_email},"
-                        failure_message += (
-                            f"parent drive/folder: {retrieved_file.parent_id},"
-                        )
-                        failure_message += f"error: {retrieved_file.error}"
-                        logger.error(failure_message)
+                documents = []
+                for idx, (result, exception) in enumerate(results):
+                    if exception:
+                        error_str = f"Error converting file: {exception}"
+                        logger.error(error_str)
                         yield [
                             ConnectorFailure(
-                                failed_entity=EntityFailure(
-                                    entity_id=failure_stage,
+                                failed_document=DocumentFailure(
+                                    document_id=files_batch[idx]["id"],
+                                    document_link=files_batch[idx]["webViewLink"],
                                 ),
-                                failure_message=failure_message,
-                                exception=retrieved_file.error,
+                                failure_message=error_str,
+                                exception=exception,
                             )
                         ]
-                        continue
-                    files_batch.append(retrieved_file.drive_file)
+                    elif result is not None:
+                        documents.append(result)
 
-                    if len(files_batch) < self.batch_size:
-                        continue
+                if documents:
+                    yield documents
+                batches_complete += 1
+                files_batch = []
 
-                    # Process the batch
-                    futures = [
-                        executor.submit(convert_func, file) for file in files_batch
-                    ]
-                    documents = []
-                    for future in as_completed(futures):
-                        try:
-                            doc = future.result()
-                            if doc is not None:
-                                documents.append(doc)
-                        except Exception as e:
-                            error_str = f"Error converting file: {e}"
-                            logger.error(error_str)
-                            yield [
-                                ConnectorFailure(
-                                    failed_document=DocumentFailure(
-                                        document_id=retrieved_file.drive_file["id"],
-                                        document_link=retrieved_file.drive_file[
-                                            "webViewLink"
-                                        ],
-                                    ),
-                                    failure_message=error_str,
-                                    exception=e,
-                                )
-                            ]
+                if batches_complete > BATCHES_PER_CHECKPOINT:
+                    checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
+                    return  # create a new checkpoint
 
-                    if documents:
-                        yield documents
-                    batches_complete += 1
-                    files_batch = []
+            # Process any remaining files
+            if files_batch:
+                func_with_args = [(convert_func, (file,)) for file in files_batch]
+                results = run_functions_tuples_in_parallel(
+                    func_with_args, max_workers=8
+                )
 
-                    if batches_complete > BATCHES_PER_CHECKPOINT:
-                        checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
-                        return  # create a new checkpoint
+                documents = []
+                for idx, (result, exception) in enumerate(results):
+                    if exception:
+                        error_str = f"Error converting file: {exception}"
+                        logger.error(error_str)
+                        yield [
+                            ConnectorFailure(
+                                failed_document=DocumentFailure(
+                                    document_id=files_batch[idx]["id"],
+                                    document_link=files_batch[idx]["webViewLink"],
+                                ),
+                                failure_message=error_str,
+                                exception=exception,
+                            )
+                        ]
+                    elif result is not None:
+                        documents.append(result)
 
-                # Process any remaining files
-                if files_batch:
-                    futures = [
-                        executor.submit(convert_func, file) for file in files_batch
-                    ]
-                    documents = []
-                    for future in as_completed(futures):
-                        try:
-                            doc = future.result()
-                            if doc is not None:
-                                documents.append(doc)
-                        except Exception as e:
-                            error_str = f"Error converting file: {e}"
-                            logger.error(error_str)
-                            yield [
-                                ConnectorFailure(
-                                    failed_document=DocumentFailure(
-                                        document_id=retrieved_file.drive_file["id"],
-                                        document_link=retrieved_file.drive_file[
-                                            "webViewLink"
-                                        ],
-                                    ),
-                                    failure_message=error_str,
-                                    exception=e,
-                                )
-                            ]
-
-                    if documents:
-                        yield documents
+                if documents:
+                    yield documents
         except Exception as e:
             logger.exception(f"Error extracting documents from Google Drive: {e}")
             raise e
@@ -1067,9 +1065,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
                         raise RuntimeError(
                             "_extract_slim_docs_from_google_drive: Stop signal detected"
                         )
-
                     callback.progress("_extract_slim_docs_from_google_drive", 1)
-
         yield slim_batch
 
     def retrieve_all_slim_documents(
diff --git a/backend/onyx/utils/threadpool_concurrency.py b/backend/onyx/utils/threadpool_concurrency.py
index 67d90b02f..fd39fa04d 100644
--- a/backend/onyx/utils/threadpool_concurrency.py
+++ b/backend/onyx/utils/threadpool_concurrency.py
@@ -6,14 +6,17 @@ import uuid
 from collections.abc import Callable
 from collections.abc import Iterator
 from collections.abc import MutableMapping
+from collections.abc import Sequence
 from concurrent.futures import as_completed
 from concurrent.futures import FIRST_COMPLETED
 from concurrent.futures import Future
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import wait
 from typing import Any
+from typing import cast
 from typing import Generic
 from typing import overload
+from typing import Protocol
 from typing import TypeVar
 
 from pydantic import GetCoreSchemaHandler
@@ -145,13 +148,20 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
         return collections.abc.ValuesView(self)
 
 
+class CallableProtocol(Protocol):
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        ...
+
+
 def run_functions_tuples_in_parallel(
-    functions_with_args: list[tuple[Callable, tuple]],
+    functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
     allow_failures: bool = False,
     max_workers: int | None = None,
 ) -> list[Any]:
     """
     Executes multiple functions in parallel and returns a list of the results for each function.
+    This function preserves contextvars across threads, which is important for maintaining
+    context like tenant IDs in database sessions.
 
     Args:
         functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
@@ -159,7 +169,7 @@ def run_functions_tuples_in_parallel(
         max_workers: Max number of worker threads
 
     Returns:
-        dict: A dictionary mapping function names to their results or error messages.
+        list: A list of results from each function, in the same order as the input functions.
     """
     workers = (
         min(max_workers, len(functions_with_args))
@@ -186,7 +196,7 @@ def run_functions_tuples_in_parallel(
                 results.append((index, future.result()))
             except Exception as e:
                 logger.exception(f"Function at index {index} failed due to {e}")
-                results.append((index, None))
+                results.append((index, None))  # type: ignore
 
                 if not allow_failures:
                     raise
@@ -288,7 +298,7 @@ def run_with_timeout(
     if task.is_alive():
         task.end()
 
-    return task.result
+    return task.result  # type: ignore
 
 
 # NOTE: this function should really only be used when run_functions_tuples_in_parallel is
@@ -304,9 +314,9 @@ def run_in_background(
     """
     context = contextvars.copy_context()
     # Timeout not used in the non-blocking case
-    task = TimeoutThread(-1, context.run, func, *args, **kwargs)
+    task = TimeoutThread(-1, context.run, func, *args, **kwargs)  # type: ignore
     task.start()
-    return task
+    return cast(TimeoutThread[R], task)
 
 
 def wait_on_background(task: TimeoutThread[R]) -> R: