Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-08 20:08:36 +02:00)

Merge 94009b6e6b862878ce535eb7c2f9ee86295f177e into 99546e4a4d60d3d9f29587c153998eeeeae62ef5
This commit is contained in commit 7db9a754b2.
Google Drive connector:

@@ -2,8 +2,6 @@ import copy
 import threading
 from collections.abc import Callable
 from collections.abc import Iterator
-from concurrent.futures import as_completed
-from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from functools import partial
 from typing import Any
@@ -64,6 +62,7 @@ from onyx.utils.lazy import lazy_eval
 from onyx.utils.logger import setup_logger
 from onyx.utils.retry_wrapper import retry_builder
 from onyx.utils.threadpool_concurrency import parallel_yield
+from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 from onyx.utils.threadpool_concurrency import ThreadSafeDict

 logger = setup_logger()
@@ -899,115 +898,114 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpoint]):
         end: SecondsSinceUnixEpoch | None = None,
     ) -> Iterator[list[Document | ConnectorFailure]]:
         try:
-            # Create a larger process pool for file conversion
-            with ThreadPoolExecutor(max_workers=8) as executor:
-                # Prepare a partial function with the credentials and admin email
-                convert_func = partial(
-                    _convert_single_file,
-                    self.creds,
-                    self.primary_admin_email,
-                    self.allow_images,
-                )
+            # Prepare a partial function with the credentials and admin email
+            convert_func = partial(
+                _convert_single_file,
+                self.creds,
+                self.primary_admin_email,
+                self.allow_images,
+            )

-                # Fetch files in batches
-                batches_complete = 0
-                files_batch: list[GoogleDriveFileType] = []
-                for retrieved_file in self._fetch_drive_items(
-                    is_slim=False,
-                    checkpoint=checkpoint,
-                    start=start,
-                    end=end,
-                ):
-                    if retrieved_file.error is not None:
-                        failure_stage = retrieved_file.completion_stage.value
-                        failure_message = (
-                            f"retrieval failure during stage: {failure_stage},"
-                        )
-                        failure_message += f"user: {retrieved_file.user_email},"
-                        failure_message += (
-                            f"parent drive/folder: {retrieved_file.parent_id},"
-                        )
-                        failure_message += f"error: {retrieved_file.error}"
-                        logger.error(failure_message)
-                        yield [
-                            ConnectorFailure(
-                                failed_entity=EntityFailure(
-                                    entity_id=failure_stage,
-                                ),
-                                failure_message=failure_message,
-                                exception=retrieved_file.error,
-                            )
-                        ]
-                        continue
-                    files_batch.append(retrieved_file.drive_file)
+            # Fetch files in batches
+            batches_complete = 0
+            files_batch: list[GoogleDriveFileType] = []
+            func_with_args: list[
+                tuple[
+                    Callable[..., Document | ConnectorFailure | None], tuple[Any, ...]
+                ]
+            ] = []
+            for retrieved_file in self._fetch_drive_items(
+                is_slim=False,
+                checkpoint=checkpoint,
+                start=start,
+                end=end,
+            ):
+                if retrieved_file.error is not None:
+                    failure_stage = retrieved_file.completion_stage.value
+                    failure_message = (
+                        f"retrieval failure during stage: {failure_stage},"
+                    )
+                    failure_message += f"user: {retrieved_file.user_email},"
+                    failure_message += (
+                        f"parent drive/folder: {retrieved_file.parent_id},"
+                    )
+                    failure_message += f"error: {retrieved_file.error}"
+                    logger.error(failure_message)
+                    yield [
+                        ConnectorFailure(
+                            failed_entity=EntityFailure(
+                                entity_id=failure_stage,
+                            ),
+                            failure_message=failure_message,
+                            exception=retrieved_file.error,
+                        )
+                    ]
+                    continue
+                files_batch.append(retrieved_file.drive_file)

-                    if len(files_batch) < self.batch_size:
-                        continue
+                if len(files_batch) < self.batch_size:
+                    continue

-                    # Process the batch
-                    futures = [
-                        executor.submit(convert_func, file) for file in files_batch
-                    ]
-                    documents = []
-                    for future in as_completed(futures):
-                        try:
-                            doc = future.result()
-                            if doc is not None:
-                                documents.append(doc)
-                        except Exception as e:
-                            error_str = f"Error converting file: {e}"
-                            logger.error(error_str)
-                            yield [
-                                ConnectorFailure(
-                                    failed_document=DocumentFailure(
-                                        document_id=retrieved_file.drive_file["id"],
-                                        document_link=retrieved_file.drive_file[
-                                            "webViewLink"
-                                        ],
-                                    ),
-                                    failure_message=error_str,
-                                    exception=e,
-                                )
-                            ]
+                # Process the batch using run_functions_tuples_in_parallel
+                func_with_args = [(convert_func, (file,)) for file in files_batch]
+                results = run_functions_tuples_in_parallel(
+                    func_with_args, max_workers=8
+                )

-                    if documents:
-                        yield documents
-                        batches_complete += 1
-                    files_batch = []
+                documents = []
+                for idx, (result, exception) in enumerate(results):
+                    if exception:
+                        error_str = f"Error converting file: {exception}"
+                        logger.error(error_str)
+                        yield [
+                            ConnectorFailure(
+                                failed_document=DocumentFailure(
+                                    document_id=files_batch[idx]["id"],
+                                    document_link=files_batch[idx]["webViewLink"],
+                                ),
+                                failure_message=error_str,
+                                exception=exception,
+                            )
+                        ]
+                    elif result is not None:
+                        documents.append(result)

-                    if batches_complete > BATCHES_PER_CHECKPOINT:
-                        checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
-                        return  # create a new checkpoint
+                if documents:
+                    yield documents
+                    batches_complete += 1
+                files_batch = []

-                # Process any remaining files
-                if files_batch:
-                    futures = [
-                        executor.submit(convert_func, file) for file in files_batch
-                    ]
-                    documents = []
-                    for future in as_completed(futures):
-                        try:
-                            doc = future.result()
-                            if doc is not None:
-                                documents.append(doc)
-                        except Exception as e:
-                            error_str = f"Error converting file: {e}"
-                            logger.error(error_str)
-                            yield [
-                                ConnectorFailure(
-                                    failed_document=DocumentFailure(
-                                        document_id=retrieved_file.drive_file["id"],
-                                        document_link=retrieved_file.drive_file[
-                                            "webViewLink"
-                                        ],
-                                    ),
-                                    failure_message=error_str,
-                                    exception=e,
-                                )
-                            ]
+                if batches_complete > BATCHES_PER_CHECKPOINT:
+                    checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
+                    return  # create a new checkpoint

-                    if documents:
-                        yield documents
+            # Process any remaining files
+            if files_batch:
+                func_with_args = [(convert_func, (file,)) for file in files_batch]
+                results = run_functions_tuples_in_parallel(
+                    func_with_args, max_workers=8
+                )
+
+                documents = []
+                for idx, (result, exception) in enumerate(results):
+                    if exception:
+                        error_str = f"Error converting file: {exception}"
+                        logger.error(error_str)
+                        yield [
+                            ConnectorFailure(
+                                failed_document=DocumentFailure(
+                                    document_id=files_batch[idx]["id"],
+                                    document_link=files_batch[idx]["webViewLink"],
+                                ),
+                                failure_message=error_str,
+                                exception=exception,
+                            )
+                        ]
+                    elif result is not None:
+                        documents.append(result)
+
+                if documents:
+                    yield documents
         except Exception as e:
             logger.exception(f"Error extracting documents from Google Drive: {e}")
             raise e
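Editor's note on the hunk above: the substance of the change is swapping the hand-rolled ThreadPoolExecutor/as_completed loop for the run_functions_tuples_in_parallel helper. One practical difference is ordering: as_completed yields futures in completion order, so the old failure handler could not map a failed future back to its input file (note it reached for retrieved_file, the outer loop variable, rather than the file that actually failed), whereas an index-aligned helper lets the new code report against files_batch[idx] directly. Below is a minimal standalone sketch of such an ordered helper — hypothetical names, and an assumed (result, exception) return shape matching how the new connector code consumes its results, not the actual onyx implementation:

from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Any

def run_ordered_in_parallel(
    funcs_with_args: list[tuple[Callable[..., Any], tuple[Any, ...]]],
    max_workers: int = 8,
) -> list[tuple[Any, Exception | None]]:
    """Run each (func, args) pair in a thread pool, preserving input order.

    Each slot holds (result, exception); indexing into the returned list
    lines up with the input list, so a failure can be reported against the
    exact input that caused it.
    """
    results: list[tuple[Any, Exception | None]] = [(None, None)] * len(funcs_with_args)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Remember each future's input index instead of relying on
        # completion order.
        future_to_idx = {
            executor.submit(func, *args): idx
            for idx, (func, args) in enumerate(funcs_with_args)
        }
        for future, idx in future_to_idx.items():
            try:
                results[idx] = (future.result(), None)
            except Exception as e:
                results[idx] = (None, e)
    return results

if __name__ == "__main__":
    jobs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
        (pow, (2, n)) for n in range(4)
    ]
    jobs.append((int, ("not a number",)))  # one job that raises ValueError
    for idx, (result, exc) in enumerate(run_ordered_in_parallel(jobs)):
        print(idx, result, exc)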
@@ -1067,9 +1065,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpoint]):
                 raise RuntimeError(
                     "_extract_slim_docs_from_google_drive: Stop signal detected"
                 )
-
             callback.progress("_extract_slim_docs_from_google_drive", 1)
-
         yield slim_batch

     def retrieve_all_slim_documents(
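Editor's note: the progress call above sits next to a stop-signal check. A hedged sketch of the callback shape this implies (the IndexingCallback name and should_stop method are assumptions here; the real onyx interface may differ):

from typing import Protocol

class IndexingCallback(Protocol):  # hypothetical name
    def should_stop(self) -> bool: ...
    def progress(self, tag: str, amount: int) -> None: ...

def extract_slim_batches(batches: list[list[str]], callback: IndexingCallback) -> list[str]:
    out: list[str] = []
    for batch in batches:
        # Abort promptly when the indexing run is cancelled...
        if callback.should_stop():
            raise RuntimeError("_extract_slim_docs_from_google_drive: Stop signal detected")
        # ...and report liveness once per batch.
        callback.progress("_extract_slim_docs_from_google_drive", 1)
        out.extend(batch)
    return out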
onyx/utils/threadpool_concurrency.py:

@@ -6,14 +6,17 @@ import uuid
 from collections.abc import Callable
 from collections.abc import Iterator
 from collections.abc import MutableMapping
+from collections.abc import Sequence
 from concurrent.futures import as_completed
 from concurrent.futures import FIRST_COMPLETED
 from concurrent.futures import Future
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import wait
 from typing import Any
+from typing import cast
 from typing import Generic
 from typing import overload
+from typing import Protocol
 from typing import TypeVar

 from pydantic import GetCoreSchemaHandler
@@ -145,13 +148,20 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
         return collections.abc.ValuesView(self)


+class CallableProtocol(Protocol):
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        ...
+
+
 def run_functions_tuples_in_parallel(
-    functions_with_args: list[tuple[Callable, tuple]],
+    functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
     allow_failures: bool = False,
     max_workers: int | None = None,
 ) -> list[Any]:
     """
     Executes multiple functions in parallel and returns a list of the results for each function.
+    This function preserves contextvars across threads, which is important for maintaining
+    context like tenant IDs in database sessions.

     Args:
         functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
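Editor's note: one plausible motivation for this signature change is strict type checking. Bare Callable and tuple are implicitly parameterized with Any (which strict mypy flags), and a list parameter is invariant, so callers holding more precisely typed tuples would not type-check. A read-only Sequence plus a __call__ Protocol accepts them, including functools.partial objects like the connector's convert_func. An illustrative sketch (names other than CallableProtocol are made up):

from collections.abc import Sequence
from functools import partial
from typing import Any, Protocol

class CallableProtocol(Protocol):
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        ...

def run_all(
    functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
) -> list[Any]:
    # Sequence is covariant and read-only, so a caller's more precisely
    # typed list of tuples is accepted; an invariant list parameter would
    # reject it.
    return [func(*args) for func, args in functions_with_args]

def add(a: int, b: int) -> int:
    return a + b

# functools.partial objects satisfy the __call__ protocol, matching how the
# connector builds convert_func with partial(...).
jobs: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
    (add, (1, 2)),
    (partial(add, 10), (5,)),
]
print(run_all(jobs))  # [3, 15]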
@@ -159,7 +169,7 @@ def run_functions_tuples_in_parallel(
         max_workers: Max number of worker threads

     Returns:
-        dict: A dictionary mapping function names to their results or error messages.
+        list: A list of results from each function, in the same order as the input functions.
     """
     workers = (
         min(max_workers, len(functions_with_args))
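Editor's note: the new docstring states that the helper preserves contextvars across threads (for example tenant IDs in database sessions). Worker threads do not inherit the caller's context by default; a common way to preserve it — shown here as an assumption, not the exact onyx mechanism — is to run the callable inside contextvars.copy_context():

import contextvars
from concurrent.futures import ThreadPoolExecutor

tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id", default="unknown")

def current_tenant() -> str:
    return tenant_id.get()

tenant_id.set("tenant-42")

with ThreadPoolExecutor(max_workers=1) as executor:
    # A raw submission runs in the worker's own (empty) context and sees
    # the default value.
    print(executor.submit(current_tenant).result())  # unknown
    # Running inside a copy of the caller's context preserves the value.
    ctx = contextvars.copy_context()
    print(executor.submit(ctx.run, current_tenant).result())  # tenant-42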
@@ -186,7 +196,7 @@ def run_functions_tuples_in_parallel(
             results.append((index, future.result()))
         except Exception as e:
             logger.exception(f"Function at index {index} failed due to {e}")
-            results.append((index, None))
+            results.append((index, None))  # type: ignore

             if not allow_failures:
                 raise
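Editor's note: the visible logic appends a None placeholder at the failed index and re-raises unless allow_failures is set, keeping the returned list aligned with the input order. A sequential sketch of that contract (an assumed simplification; the real helper gathers these pairs from futures completing in arbitrary order):

from collections.abc import Callable
from typing import Any

def run_allowing_failures(
    funcs: list[Callable[[], Any]], allow_failures: bool = False
) -> list[Any]:
    # (index, result) pairs keep slots aligned even when entries fail.
    indexed: list[tuple[int, Any]] = []
    for index, func in enumerate(funcs):
        try:
            indexed.append((index, func()))
        except Exception:
            indexed.append((index, None))
            if not allow_failures:
                raise
    # Restore input order (a no-op here, essential in the threaded case).
    indexed.sort(key=lambda pair: pair[0])
    return [result for _, result in indexed]

print(run_allowing_failures([lambda: 1, lambda: 1 // 0, lambda: 3], allow_failures=True))
# [1, None, 3]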
@@ -288,7 +298,7 @@ def run_with_timeout(
     if task.is_alive():
         task.end()

-    return task.result
+    return task.result  # type: ignore


 # NOTE: this function should really only be used when run_functions_tuples_in_parallel is
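Editor's note: a hedged sketch of the run_with_timeout shape around this hunk, inferred from the visible lines rather than copied from onyx — run the function on a worker thread, join with a timeout, and surface its stored result:

import threading
from collections.abc import Callable
from typing import Any, TypeVar, cast

R = TypeVar("R")

def run_with_timeout(timeout: float, func: Callable[..., R], *args: Any) -> R:
    result: list[Any] = [None]  # slot written by the worker thread

    def target() -> None:
        result[0] = func(*args)

    task = threading.Thread(target=target, daemon=True)
    task.start()
    task.join(timeout)
    if task.is_alive():
        # onyx's TimeoutThread can forcibly end the thread (task.end());
        # plain threading cannot, so this sketch just abandons the daemon.
        raise TimeoutError(f"Function timed out after {timeout} seconds")
    # The stored slot is only Any/None as far as the checker knows, which is
    # why the real code needs "# type: ignore" (or a cast) on the return.
    return cast(R, result[0])

print(run_with_timeout(1.0, sum, [1, 2, 3]))  # 6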
@@ -304,9 +314,9 @@ def run_in_background(
     """
     context = contextvars.copy_context()
     # Timeout not used in the non-blocking case
-    task = TimeoutThread(-1, context.run, func, *args, **kwargs)
+    task = TimeoutThread(-1, context.run, func, *args, **kwargs)  # type: ignore
     task.start()
-    return task
+    return cast(TimeoutThread[R], task)


 def wait_on_background(task: TimeoutThread[R]) -> R:
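Editor's note: TimeoutThread itself is onyx code not shown in this diff; the stand-in below only illustrates the generic background-thread pattern the hunk is typing, and why the constructor call ends up needing cast() to carry the type parameter R through to the caller:

import threading
from collections.abc import Callable
from typing import Any, Generic, TypeVar, cast

R = TypeVar("R")

class ResultThread(threading.Thread, Generic[R]):
    """A thread that stores its callable's return value in self.result."""

    def __init__(self, func: Callable[..., R], *args: Any, **kwargs: Any) -> None:
        super().__init__(daemon=True)
        self._func = func
        self._args = args
        self._kwargs = kwargs
        self.result: R | None = None

    def run(self) -> None:
        self.result = self._func(*self._args, **self._kwargs)

def run_in_background(func: Callable[..., R], *args: Any, **kwargs: Any) -> "ResultThread[R]":
    task = ResultThread(func, *args, **kwargs)
    task.start()
    # When construction goes through a path where inference loses the type
    # parameter (as with TimeoutThread's -1 timeout sentinel), cast()
    # re-asserts it for the checker with no runtime effect.
    return cast("ResultThread[R]", task)

def wait_on_background(task: "ResultThread[R]") -> R:
    task.join()
    return cast(R, task.result)

print(wait_on_background(run_in_background(sum, [1, 2, 3])))  # 6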