Respect contextvars when parallelizing for Google Drive (#4291)

* k

* k

* fix typing
This commit is contained in:
pablonyx
2025-03-25 10:40:12 -07:00
committed by GitHub
parent 431316929c
commit 2024776c19
6 changed files with 256 additions and 114 deletions

View File

@@ -68,6 +68,8 @@ from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
logger = setup_logger()
@@ -875,6 +877,21 @@ def monitor_ccpair_permissions_taskset(
f"remaining={remaining} "
f"initial={initial}"
)
# Add telemetry for permission syncing progress
optional_telemetry(
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
data={
"cc_pair_id": cc_pair_id,
"id": payload.id if payload else None,
"total_docs": initial if initial is not None else 0,
"remaining_docs": remaining,
"synced_docs": (initial - remaining) if initial is not None else 0,
"is_complete": remaining == 0,
},
tenant_id=tenant_id,
)
if remaining > 0:
return

View File

@@ -56,9 +56,12 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.redis.redis_connector import RedisConnector
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
@@ -570,6 +573,22 @@ def _run_indexing(
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# Add telemetry for indexing progress
optional_telemetry(
record_type=RecordType.INDEXING_PROGRESS,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_num": batch_num,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
memory_tracer.increment_and_maybe_trace()
# `make sure the checkpoints aren't getting too large`at some regular interval
@@ -585,6 +604,30 @@ def _run_indexing(
checkpoint=checkpoint,
)
# Add telemetry for completed indexing
redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
redis_connector_index = redis_connector.new_index(
index_attempt_start.search_settings_id
)
final_progress = redis_connector_index.get_progress() or 0
optional_telemetry(
record_type=RecordType.INDEXING_COMPLETE,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_count": batch_num,
"time_elapsed_seconds": time.monotonic() - start_time,
"source": ctx.source.value,
"redis_progress": final_progress,
},
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "

View File

@@ -2,8 +2,6 @@ import copy
import threading
from collections.abc import Callable
from collections.abc import Iterator
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from functools import partial
from typing import Any
@@ -65,6 +63,7 @@ from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
from onyx.utils.threadpool_concurrency import parallel_yield
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import ThreadSafeDict
logger = setup_logger()
@@ -904,116 +903,114 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[list[Document | ConnectorFailure]]:
try:
# Create a larger process pool for file conversion
with ThreadPoolExecutor(max_workers=8) as executor:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
func_with_args: list[
tuple[
Callable[..., Document | ConnectorFailure | None], tuple[Any, ...]
]
] = []
for retrieved_file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
if retrieved_file.error is not None:
failure_stage = retrieved_file.completion_stage.value
failure_message = (
f"retrieval failure during stage: {failure_stage},"
)
failure_message += f"user: {retrieved_file.user_email},"
failure_message += (
f"parent drive/folder: {retrieved_file.parent_id},"
)
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield [
ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
]
continue
files_batch.append(retrieved_file.drive_file)
if len(files_batch) < self.batch_size:
continue
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [(convert_func, (file,)) for file in files_batch]
results = run_functions_tuples_in_parallel(
func_with_args, max_workers=8
)
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
for retrieved_file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
if retrieved_file.error is not None:
failure_stage = retrieved_file.completion_stage.value
failure_message = (
f"retrieval failure during stage: {failure_stage},"
)
failure_message += f"user: {retrieved_file.user_email},"
failure_message += (
f"parent drive/folder: {retrieved_file.parent_id},"
)
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
documents = []
for idx, (result, exception) in enumerate(results):
if exception:
error_str = f"Error converting file: {exception}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
failed_document=DocumentFailure(
document_id=files_batch[idx]["id"],
document_link=files_batch[idx]["webViewLink"],
),
failure_message=failure_message,
exception=retrieved_file.error,
failure_message=error_str,
exception=exception,
)
]
continue
files_batch.append(retrieved_file.drive_file)
elif result is not None:
documents.append(result)
if len(files_batch) < self.batch_size:
continue
if documents:
yield documents
batches_complete += 1
files_batch = []
# Process the batch
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
error_str = f"Error converting file: {e}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=retrieved_file.drive_file["id"],
document_link=retrieved_file.drive_file[
"webViewLink"
],
),
failure_message=error_str,
exception=e,
)
]
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
return # create a new checkpoint
if documents:
yield documents
batches_complete += 1
files_batch = []
# Process any remaining files
if files_batch:
func_with_args = [(convert_func, (file,)) for file in files_batch]
results = run_functions_tuples_in_parallel(
func_with_args, max_workers=8
)
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
return # create a new checkpoint
documents = []
for idx, (result, exception) in enumerate(results):
if exception:
error_str = f"Error converting file: {exception}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=files_batch[idx]["id"],
document_link=files_batch[idx]["webViewLink"],
),
failure_message=error_str,
exception=exception,
)
]
elif result is not None:
documents.append(result)
# Process any remaining files
if files_batch:
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
error_str = f"Error converting file: {e}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=retrieved_file.drive_file["id"],
document_link=retrieved_file.drive_file[
"webViewLink"
],
),
failure_message=error_str,
exception=e,
)
]
if documents:
yield documents
if documents:
yield documents
except Exception as e:
logger.exception(f"Error extracting documents from Google Drive: {e}")
raise e
@@ -1073,9 +1070,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
raise RuntimeError(
"_extract_slim_docs_from_google_drive: Stop signal detected"
)
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch
def retrieve_all_slim_documents(

View File

@@ -8,23 +8,31 @@ from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.server.documents.models import ConnectorCredentialPair
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
# Comment out unused imports that cause mypy errors
# from onyx.auth.models import UserRole
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
# from onyx.db.engine import async_query_for_dms
logger = setup_logger()
@@ -201,6 +209,17 @@ def mark_attempt_in_progress(
attempt.status = IndexingStatus.IN_PROGRESS
attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt.id,
"status": IndexingStatus.IN_PROGRESS.value,
"cc_pair_id": index_attempt.connector_credential_pair_id,
"search_settings_id": index_attempt.search_settings_id,
},
)
except Exception:
db_session.rollback()
raise
@@ -219,6 +238,19 @@ def mark_attempt_succeeded(
attempt.status = IndexingStatus.SUCCESS
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.SUCCESS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -237,6 +269,19 @@ def mark_attempt_partially_succeeded(
attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.COMPLETED_WITH_ERRORS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -259,6 +304,20 @@ def mark_attempt_canceled(
attempt.status = IndexingStatus.CANCELED
attempt.error_msg = reason
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.CANCELED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -283,6 +342,20 @@ def mark_attempt_failed(
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.FAILED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": failure_reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -434,7 +507,7 @@ def get_latest_index_attempts_parallel(
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,

View File

@@ -36,6 +36,10 @@ class RecordType(str, Enum):
LATENCY = "latency"
FAILURE = "failure"
METRIC = "metric"
INDEXING_PROGRESS = "indexing_progress"
INDEXING_COMPLETE = "indexing_complete"
PERMISSION_SYNC_PROGRESS = "permission_sync_progress"
INDEX_ATTEMPT_STATUS = "index_attempt_status"
def _get_or_generate_customer_id_mt(tenant_id: str) -> str:

View File

@@ -6,14 +6,17 @@ import uuid
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import MutableMapping
from collections.abc import Sequence
from concurrent.futures import as_completed
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
from typing import Any
from typing import cast
from typing import Generic
from typing import overload
from typing import Protocol
from typing import TypeVar
from pydantic import GetCoreSchemaHandler
@@ -145,13 +148,20 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
return collections.abc.ValuesView(self)
class CallableProtocol(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
allow_failures: bool = False,
max_workers: int | None = None,
) -> list[Any]:
"""
Executes multiple functions in parallel and returns a list of the results for each function.
This function preserves contextvars across threads, which is important for maintaining
context like tenant IDs in database sessions.
Args:
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
@@ -159,7 +169,7 @@ def run_functions_tuples_in_parallel(
max_workers: Max number of worker threads
Returns:
dict: A dictionary mapping function names to their results or error messages.
list: A list of results from each function, in the same order as the input functions.
"""
workers = (
min(max_workers, len(functions_with_args))
@@ -186,7 +196,7 @@ def run_functions_tuples_in_parallel(
results.append((index, future.result()))
except Exception as e:
logger.exception(f"Function at index {index} failed due to {e}")
results.append((index, None))
results.append((index, None)) # type: ignore
if not allow_failures:
raise
@@ -288,7 +298,7 @@ def run_with_timeout(
if task.is_alive():
task.end()
return task.result
return task.result # type: ignore
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
@@ -304,9 +314,9 @@ def run_in_background(
"""
context = contextvars.copy_context()
# Timeout not used in the non-blocking case
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
task = TimeoutThread(-1, context.run, func, *args, **kwargs) # type: ignore
task.start()
return task
return cast(TimeoutThread[R], task)
def wait_on_background(task: TimeoutThread[R]) -> R: