Merge 94009b6e6b862878ce535eb7c2f9ee86295f177e into 99546e4a4d60d3d9f29587c153998eeeeae62ef5

Commit 7db9a754b2, authored by pablonyx on 2025-03-23 16:26:18 -07:00, committed via GitHub.
2 changed files with 113 additions and 107 deletions


@@ -2,8 +2,6 @@ import copy
 import threading
 from collections.abc import Callable
 from collections.abc import Iterator
-from concurrent.futures import as_completed
-from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from functools import partial
 from typing import Any
@@ -64,6 +62,7 @@ from onyx.utils.lazy import lazy_eval
 from onyx.utils.logger import setup_logger
 from onyx.utils.retry_wrapper import retry_builder
 from onyx.utils.threadpool_concurrency import parallel_yield
+from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 from onyx.utils.threadpool_concurrency import ThreadSafeDict
 logger = setup_logger()
@@ -899,115 +898,114 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
         end: SecondsSinceUnixEpoch | None = None,
     ) -> Iterator[list[Document | ConnectorFailure]]:
         try:
-            # Create a larger process pool for file conversion
-            with ThreadPoolExecutor(max_workers=8) as executor:
-                # Prepare a partial function with the credentials and admin email
-                convert_func = partial(
-                    _convert_single_file,
-                    self.creds,
-                    self.primary_admin_email,
-                    self.allow_images,
-                )
-                # Fetch files in batches
-                batches_complete = 0
-                files_batch: list[GoogleDriveFileType] = []
-                for retrieved_file in self._fetch_drive_items(
-                    is_slim=False,
-                    checkpoint=checkpoint,
-                    start=start,
-                    end=end,
-                ):
-                    if retrieved_file.error is not None:
-                        failure_stage = retrieved_file.completion_stage.value
-                        failure_message = (
-                            f"retrieval failure during stage: {failure_stage},"
-                        )
-                        failure_message += f"user: {retrieved_file.user_email},"
-                        failure_message += (
-                            f"parent drive/folder: {retrieved_file.parent_id},"
-                        )
-                        failure_message += f"error: {retrieved_file.error}"
-                        logger.error(failure_message)
-                        yield [
-                            ConnectorFailure(
-                                failed_entity=EntityFailure(
-                                    entity_id=failure_stage,
-                                ),
-                                failure_message=failure_message,
-                                exception=retrieved_file.error,
-                            )
-                        ]
-                        continue
-                    files_batch.append(retrieved_file.drive_file)
-                    if len(files_batch) < self.batch_size:
-                        continue
-                    # Process the batch
-                    futures = [
-                        executor.submit(convert_func, file) for file in files_batch
-                    ]
-                    documents = []
-                    for future in as_completed(futures):
-                        try:
-                            doc = future.result()
-                            if doc is not None:
-                                documents.append(doc)
-                        except Exception as e:
-                            error_str = f"Error converting file: {e}"
-                            logger.error(error_str)
-                            yield [
-                                ConnectorFailure(
-                                    failed_document=DocumentFailure(
-                                        document_id=retrieved_file.drive_file["id"],
-                                        document_link=retrieved_file.drive_file[
-                                            "webViewLink"
-                                        ],
-                                    ),
-                                    failure_message=error_str,
-                                    exception=e,
-                                )
-                            ]
-                    if documents:
-                        yield documents
-                        batches_complete += 1
-                    files_batch = []
-                    if batches_complete > BATCHES_PER_CHECKPOINT:
-                        checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
-                        return  # create a new checkpoint
-                # Process any remaining files
-                if files_batch:
-                    futures = [
-                        executor.submit(convert_func, file) for file in files_batch
-                    ]
-                    documents = []
-                    for future in as_completed(futures):
-                        try:
-                            doc = future.result()
-                            if doc is not None:
-                                documents.append(doc)
-                        except Exception as e:
-                            error_str = f"Error converting file: {e}"
-                            logger.error(error_str)
-                            yield [
-                                ConnectorFailure(
-                                    failed_document=DocumentFailure(
-                                        document_id=retrieved_file.drive_file["id"],
-                                        document_link=retrieved_file.drive_file[
-                                            "webViewLink"
-                                        ],
-                                    ),
-                                    failure_message=error_str,
-                                    exception=e,
-                                )
-                            ]
-                    if documents:
-                        yield documents
+            # Prepare a partial function with the credentials and admin email
+            convert_func = partial(
+                _convert_single_file,
+                self.creds,
+                self.primary_admin_email,
+                self.allow_images,
+            )
+            # Fetch files in batches
+            batches_complete = 0
+            files_batch: list[GoogleDriveFileType] = []
+            func_with_args: list[
+                tuple[
+                    Callable[..., Document | ConnectorFailure | None], tuple[Any, ...]
+                ]
+            ] = []
+            for retrieved_file in self._fetch_drive_items(
+                is_slim=False,
+                checkpoint=checkpoint,
+                start=start,
+                end=end,
+            ):
+                if retrieved_file.error is not None:
+                    failure_stage = retrieved_file.completion_stage.value
+                    failure_message = (
+                        f"retrieval failure during stage: {failure_stage},"
+                    )
+                    failure_message += f"user: {retrieved_file.user_email},"
+                    failure_message += (
+                        f"parent drive/folder: {retrieved_file.parent_id},"
+                    )
+                    failure_message += f"error: {retrieved_file.error}"
+                    logger.error(failure_message)
+                    yield [
+                        ConnectorFailure(
+                            failed_entity=EntityFailure(
+                                entity_id=failure_stage,
+                            ),
+                            failure_message=failure_message,
+                            exception=retrieved_file.error,
+                        )
+                    ]
+                    continue
+                files_batch.append(retrieved_file.drive_file)
+                if len(files_batch) < self.batch_size:
+                    continue
+                # Process the batch using run_functions_tuples_in_parallel
+                func_with_args = [(convert_func, (file,)) for file in files_batch]
+                results = run_functions_tuples_in_parallel(
+                    func_with_args, max_workers=8
+                )
+                documents = []
+                for idx, (result, exception) in enumerate(results):
+                    if exception:
+                        error_str = f"Error converting file: {exception}"
+                        logger.error(error_str)
+                        yield [
+                            ConnectorFailure(
+                                failed_document=DocumentFailure(
+                                    document_id=files_batch[idx]["id"],
+                                    document_link=files_batch[idx]["webViewLink"],
+                                ),
+                                failure_message=error_str,
+                                exception=exception,
+                            )
+                        ]
+                    elif result is not None:
+                        documents.append(result)
+                if documents:
+                    yield documents
+                    batches_complete += 1
+                files_batch = []
+                if batches_complete > BATCHES_PER_CHECKPOINT:
+                    checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
+                    return  # create a new checkpoint
+            # Process any remaining files
+            if files_batch:
+                func_with_args = [(convert_func, (file,)) for file in files_batch]
+                results = run_functions_tuples_in_parallel(
+                    func_with_args, max_workers=8
+                )
+                documents = []
+                for idx, (result, exception) in enumerate(results):
+                    if exception:
+                        error_str = f"Error converting file: {exception}"
+                        logger.error(error_str)
+                        yield [
+                            ConnectorFailure(
+                                failed_document=DocumentFailure(
+                                    document_id=files_batch[idx]["id"],
+                                    document_link=files_batch[idx]["webViewLink"],
+                                ),
+                                failure_message=error_str,
+                                exception=exception,
+                            )
+                        ]
+                    elif result is not None:
+                        documents.append(result)
+                if documents:
+                    yield documents
         except Exception as e:
             logger.exception(f"Error extracting documents from Google Drive: {e}")
             raise e
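
The hunk above swaps the connector's locally managed ThreadPoolExecutor/as_completed loop for the shared run_functions_tuples_in_parallel helper: each batch of Drive files is turned into (callable, args) tuples and dispatched in a single call. A condensed sketch of that pattern, assuming the helper returns one result per input in input order (as its docstring in the threadpool_concurrency changes later in this diff describes); convert, the admin email, and the sample files below are stand-ins, not the connector's real _convert_single_file inputs:

    from functools import partial

    from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

    def convert(admin_email: str, file: dict) -> dict | None:
        # Stand-in for _convert_single_file: return None for files that yield no document.
        return {"id": file["id"], "processed_by": admin_email} if file.get("id") else None

    convert_func = partial(convert, "admin@example.com")  # hypothetical admin email
    files_batch = [{"id": "a"}, {"id": "b"}, {}]

    # Same shape as the connector's func_with_args: one (callable, positional-args) tuple per file.
    func_with_args = [(convert_func, (file,)) for file in files_batch]
    results = run_functions_tuples_in_parallel(func_with_args, max_workers=8)

    # Keep only successful conversions, mirroring the "result is not None" filter above.
    documents = [doc for doc in results if doc is not None]
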
@@ -1067,9 +1065,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
                     raise RuntimeError(
                         "_extract_slim_docs_from_google_drive: Stop signal detected"
                     )
-                callback.progress("_extract_slim_docs_from_google_drive", 1)
         yield slim_batch
     def retrieve_all_slim_documents(


@@ -6,14 +6,17 @@ import uuid
 from collections.abc import Callable
 from collections.abc import Iterator
 from collections.abc import MutableMapping
+from collections.abc import Sequence
 from concurrent.futures import as_completed
 from concurrent.futures import FIRST_COMPLETED
 from concurrent.futures import Future
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import wait
 from typing import Any
 from typing import cast
 from typing import Generic
+from typing import overload
+from typing import Protocol
 from typing import TypeVar
 from pydantic import GetCoreSchemaHandler
@@ -145,13 +148,20 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
         return collections.abc.ValuesView(self)
+class CallableProtocol(Protocol):
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        ...
 def run_functions_tuples_in_parallel(
-    functions_with_args: list[tuple[Callable, tuple]],
+    functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
     allow_failures: bool = False,
     max_workers: int | None = None,
 ) -> list[Any]:
     """
     Executes multiple functions in parallel and returns a list of the results for each function.
+    This function preserves contextvars across threads, which is important for maintaining
+    context like tenant IDs in database sessions.
     Args:
         functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
@@ -159,7 +169,7 @@ def run_functions_tuples_in_parallel(
         max_workers: Max number of worker threads
     Returns:
-        dict: A dictionary mapping function names to their results or error messages.
+        list: A list of results from each function, in the same order as the input functions.
     """
     workers = (
         min(max_workers, len(functions_with_args))
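
As documented above, the helper preserves input order, and the new Sequence[tuple[CallableProtocol, tuple[Any, ...]]] annotation accepts any callable object, including functools.partial objects (such as convert_func in the connector change) and bound methods. A minimal usage sketch under those assumptions; fetch_user and fetch_org are hypothetical functions used only for illustration:

    def fetch_user(user_id: int) -> dict:
        return {"user_id": user_id}

    def fetch_org(org_id: int, include_members: bool = False) -> dict:
        return {"org_id": org_id, "members": [] if include_members else None}

    results = run_functions_tuples_in_parallel(
        [
            (fetch_user, (1,)),       # -> results[0]
            (fetch_org, (42, True)),  # -> results[1]
        ],
        max_workers=2,
    )
    # Results come back in the same order as the input tuples,
    # regardless of which thread finishes first.
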
@@ -186,7 +196,7 @@ def run_functions_tuples_in_parallel(
                results.append((index, future.result()))
            except Exception as e:
                logger.exception(f"Function at index {index} failed due to {e}")
-                results.append((index, None))
+                results.append((index, None))  # type: ignore
                if not allow_failures:
                    raise
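
Per the hunk above, a failing function is logged, its slot is recorded as None, and the exception is re-raised unless allow_failures is set. A small sketch of that behavior, assuming the None placeholders are what the caller sees in the returned list; flaky is a hypothetical function:

    def flaky(n: int) -> int:
        if n == 0:
            raise ValueError("boom")
        return 10 // n

    results = run_functions_tuples_in_parallel(
        [(flaky, (2,)), (flaky, (0,)), (flaky, (5,))],
        allow_failures=True,  # without this, the ValueError from the second call propagates
    )
    # Expected under the assumption above: results == [5, None, 2]
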
@@ -288,7 +298,7 @@ def run_with_timeout(
     if task.is_alive():
         task.end()
-    return task.result
+    return task.result  # type: ignore
 # NOTE: this function should really only be used when run_functions_tuples_in_parallel is
@@ -304,9 +314,9 @@ def run_in_background(
     """
     context = contextvars.copy_context()
     # Timeout not used in the non-blocking case
-    task = TimeoutThread(-1, context.run, func, *args, **kwargs)
+    task = TimeoutThread(-1, context.run, func, *args, **kwargs)  # type: ignore
     task.start()
-    return task
+    return cast(TimeoutThread[R], task)
 def wait_on_background(task: TimeoutThread[R]) -> R:
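
run_in_background copies the current contextvars context into a TimeoutThread and returns the still-running task; wait_on_background, per its signature, takes that task and presumably blocks until the thread finishes before handing back the result. A minimal pairing sketch based only on the signatures shown above; slow_lookup is a hypothetical function:

    import time

    def slow_lookup(key: str) -> str:
        time.sleep(1)
        return key.upper()

    task = run_in_background(slow_lookup, "tenant-123")  # starts immediately in its own thread
    # ... do other work here while the lookup runs ...
    result = wait_on_background(task)  # assumed to block until done; result == "TENANT-123"
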