From eb569bf79d2dd17ae2c7265af974053a328b709c Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Mon, 21 Apr 2025 16:27:31 -0700 Subject: [PATCH] add emails to retry with on 403 (#4565) * add emails to retry with on 403 * attempted fix for connector test * CW comments * connector test fix * test fixes and continue on 403 * fix tests * fix tests * fix concurrency tests * fix integration tests with llmprovider eager loading --- .../onyx/connectors/google_drive/connector.py | 4 +- .../connectors/google_drive/doc_conversion.py | 70 +++++++++++++++++-- .../connectors/google_utils/google_utils.py | 11 +++ backend/onyx/db/llm.py | 62 ++++++++-------- backend/onyx/utils/threadpool_concurrency.py | 24 +++++++ .../google_drive/consts_and_utils.py | 14 ++++ .../google_drive/test_user_1_oauth.py | 42 ++++++++++- .../onyx/utils/test_threadpool_concurrency.py | 2 +- 8 files changed, 186 insertions(+), 43 deletions(-) diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py index d74cda80c..533b66710 100644 --- a/backend/onyx/connectors/google_drive/connector.py +++ b/backend/onyx/connectors/google_drive/connector.py @@ -40,6 +40,7 @@ from onyx.connectors.google_drive.models import RetrievedDriveFile from onyx.connectors.google_drive.models import StageCompletion from onyx.connectors.google_utils.google_auth import get_google_creds from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval +from onyx.connectors.google_utils.google_utils import get_file_owners from onyx.connectors.google_utils.google_utils import GoogleFields from onyx.connectors.google_utils.resources import get_admin_service from onyx.connectors.google_utils.resources import get_drive_service @@ -949,7 +950,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck ( convert_func, ( - file.user_email, + [file.user_email, self.primary_admin_email] + + get_file_owners(file.drive_file), file.drive_file, ), ) diff 
--git a/backend/onyx/connectors/google_drive/doc_conversion.py b/backend/onyx/connectors/google_drive/doc_conversion.py index 246a08a5d..3308e181e 100644 --- a/backend/onyx/connectors/google_drive/doc_conversion.py +++ b/backend/onyx/connectors/google_drive/doc_conversion.py @@ -43,6 +43,8 @@ logger = setup_logger() SMART_CHIP_CHAR = "\ue907" WEB_VIEW_LINK_KEY = "webViewLink" +MAX_RETRIEVER_EMAILS = 20 + # Mapping of Google Drive mime types to export formats GOOGLE_MIME_TYPES_TO_EXPORT = { GDriveMimeType.DOC.value: "text/plain", @@ -297,12 +299,66 @@ def align_basic_advanced( return new_sections -# We used to always get the user email from the file owners when available, -# but this was causing issues with shared folders where the owner was not included in the service account -# now we use the email of the account that successfully listed the file. Leaving this in case we end up -# wanting to retry with file owners and/or admin email at some point. -# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email def convert_drive_item_to_document( + creds: Any, + allow_images: bool, + size_threshold: int, + retriever_emails: list[str], + file: GoogleDriveFileType, +) -> Document | ConnectorFailure | None: + """ + Attempt to convert a drive item to a document with each retriever email + in order. Returns upon a successful retrieval or a non-403 error. + + We used to always get the user email from the file owners when available, + but this was causing issues with shared folders where the owner was not included in the service account. + Now we use the email of the account that successfully listed the file. There are cases where a + user that can list a file cannot download it, so we retry with file owners and admin email. 
+ """ + first_error = None + doc_or_failure = None + retriever_emails = retriever_emails[:MAX_RETRIEVER_EMAILS] + # use seen instead of list(set()) to avoid re-ordering the retriever emails + seen = set() + for retriever_email in retriever_emails: + if retriever_email in seen: + continue + seen.add(retriever_email) + doc_or_failure = _convert_drive_item_to_document( + creds, allow_images, size_threshold, retriever_email, file + ) + if ( + doc_or_failure is None + or isinstance(doc_or_failure, Document) + or not ( + isinstance(doc_or_failure.exception, HttpError) + and doc_or_failure.exception.status_code == 403 + ) + ): + return doc_or_failure + + if first_error is None: + first_error = doc_or_failure + else: + first_error.failure_message += f"\n\n{doc_or_failure.failure_message}" + + if ( + first_error + and isinstance(first_error.exception, HttpError) + and first_error.exception.status_code == 403 + ): + # This SHOULD happen very rarely, and we don't want to break the indexing process when + # a high volume of 403s occurs early. We leave a verbose log to help investigate. + logger.error( + f"Skipping file id: {file.get('id')} name: {file.get('name')} due to 403 error. " + f"Attempted to retrieve with {retriever_emails}, " + f"got the following errors: {first_error.failure_message}" + ) + return None + return first_error + + +def _convert_drive_item_to_document( creds: Any, allow_images: bool, size_threshold: int, @@ -392,7 +448,9 @@ def convert_drive_item_to_document( ) except Exception as e: file_name = file.get("name") - error_str = f"Error converting file '{file_name}' to Document: {e}" + error_str = ( + f"Error converting file '{file_name}' to Document as {retriever_email}: {e}" + ) if isinstance(e, HttpError) and e.status_code == 403: logger.warning( f"Uncommon permissions error while downloading file. 
User " diff --git a/backend/onyx/connectors/google_utils/google_utils.py b/backend/onyx/connectors/google_utils/google_utils.py index 4ad6cbe7a..12985e39e 100644 --- a/backend/onyx/connectors/google_utils/google_utils.py +++ b/backend/onyx/connectors/google_utils/google_utils.py @@ -97,6 +97,17 @@ def _execute_with_retry(request: Any) -> Any: raise Exception(f"Failed to execute request after {max_attempts} attempts") +def get_file_owners(file: GoogleDriveFileType) -> list[str]: + """ + Get the owners of a file if the attribute is present. + """ + return [ + owner.get("emailAddress") + for owner in file.get("owners", []) + if owner.get("emailAddress") + ] + + def execute_paginated_retrieval( retrieval_function: Callable, list_key: str | None = None, diff --git a/backend/onyx/db/llm.py b/backend/onyx/db/llm.py index 91537c493..085ca4494 100644 --- a/backend/onyx/db/llm.py +++ b/backend/onyx/db/llm.py @@ -73,10 +73,8 @@ def upsert_llm_provider( llm_provider_upsert_request: LLMProviderUpsertRequest, db_session: Session, ) -> LLMProviderView: - existing_llm_provider = db_session.scalar( - select(LLMProviderModel).where( - LLMProviderModel.name == llm_provider_upsert_request.name - ) + existing_llm_provider = fetch_existing_llm_provider( + llm_provider_upsert_request.name, db_session ) if not existing_llm_provider: @@ -155,8 +153,13 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM def fetch_existing_llm_providers( db_session: Session, + only_public: bool = False, ) -> list[LLMProviderModel]: - stmt = select(LLMProviderModel) + stmt = select(LLMProviderModel).options( + selectinload(LLMProviderModel.model_configurations) + ) + if only_public: + stmt = stmt.where(LLMProviderModel.is_public == True) # noqa: E712 return list(db_session.scalars(stmt).all()) @@ -164,7 +167,9 @@ def fetch_existing_llm_provider( provider_name: str, db_session: Session ) -> LLMProviderModel | None: provider_model = db_session.scalar( - 
select(LLMProviderModel).where(LLMProviderModel.name == provider_name) + select(LLMProviderModel) + .where(LLMProviderModel.name == provider_name) + .options(selectinload(LLMProviderModel.model_configurations)) ) return provider_model @@ -174,21 +179,18 @@ def fetch_existing_llm_providers_for_user( db_session: Session, user: User | None = None, ) -> list[LLMProviderModel]: + # if user is anonymous if not user: - if AUTH_TYPE != AuthType.DISABLED: - # User is anonymous - return list( - db_session.scalars( - select(LLMProviderModel).where( - LLMProviderModel.is_public == True # noqa: E712 - ) - ).all() - ) - else: - # If auth is disabled, user has access to all providers - return fetch_existing_llm_providers(db_session) + # Only fetch public providers if auth is turned on + return fetch_existing_llm_providers( + db_session, only_public=AUTH_TYPE != AuthType.DISABLED + ) - stmt = select(LLMProviderModel).distinct() + stmt = ( + select(LLMProviderModel) + .options(selectinload(LLMProviderModel.model_configurations)) + .distinct() + ) user_groups_select = select(User__UserGroup.user_group_id).where( User__UserGroup.user_id == user.id ) @@ -217,9 +219,9 @@ def fetch_embedding_provider( def fetch_default_provider(db_session: Session) -> LLMProviderView | None: provider_model = db_session.scalar( - select(LLMProviderModel).where( - LLMProviderModel.is_default_provider == True # noqa: E712 - ) + select(LLMProviderModel) + .where(LLMProviderModel.is_default_provider == True) # noqa: E712 + .options(selectinload(LLMProviderModel.model_configurations)) ) if not provider_model: return None @@ -228,9 +230,9 @@ def fetch_default_provider(db_session: Session) -> LLMProviderView | None: def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None: provider_model = db_session.scalar( - select(LLMProviderModel).where( - LLMProviderModel.is_default_vision_provider == True # noqa: E712 - ) + select(LLMProviderModel) + 
.where(LLMProviderModel.is_default_vision_provider == True) # noqa: E712 + .options(selectinload(LLMProviderModel.model_configurations)) ) if not provider_model: return None @@ -240,9 +242,7 @@ def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None def fetch_llm_provider_view( db_session: Session, provider_name: str ) -> LLMProviderView | None: - provider_model = db_session.scalar( - select(LLMProviderModel).where(LLMProviderModel.name == provider_name) - ) + provider_model = fetch_existing_llm_provider(provider_name, db_session) if not provider_model: return None return LLMProviderView.from_model(provider_model) @@ -254,11 +254,7 @@ def fetch_max_input_tokens( model_name: str, output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS, ) -> int: - llm_provider = db_session.scalar( - select(LLMProviderModel) - .where(LLMProviderModel.provider == provider_name) - .options(selectinload(LLMProviderModel.model_configurations)) - ) + llm_provider = fetch_existing_llm_provider(provider_name, db_session) if not llm_provider: raise RuntimeError(f"No LLM Provider with the name {provider_name}") diff --git a/backend/onyx/utils/threadpool_concurrency.py b/backend/onyx/utils/threadpool_concurrency.py index 429480edc..0271bc640 100644 --- a/backend/onyx/utils/threadpool_concurrency.py +++ b/backend/onyx/utils/threadpool_concurrency.py @@ -145,6 +145,30 @@ class ThreadSafeDict(MutableMapping[KT, VT]): with self.lock: return collections.abc.ValuesView(self) + @overload + def atomic_get_set( + self, key: KT, value_callback: Callable[[VT], VT], default: VT + ) -> tuple[VT, VT]: ... + + @overload + def atomic_get_set( + self, key: KT, value_callback: Callable[[VT | _T], VT], default: VT | _T + ) -> tuple[VT | _T, VT]: ... + + def atomic_get_set( + self, key: KT, value_callback: Callable[[Any], VT], default: Any = None + ) -> tuple[Any, VT]: + """Replace a value from the dict with a function applied to the previous value, atomically. 
+ + Returns: + A tuple of the previous value and the new value. + """ + with self.lock: + val = self._dict.get(key, default) + new_val = value_callback(val) + self._dict[key] = new_val + return val, new_val + class CallableProtocol(Protocol): def __call__(self, *args: Any, **kwargs: Any) -> Any: ... diff --git a/backend/tests/daily/connectors/google_drive/consts_and_utils.py b/backend/tests/daily/connectors/google_drive/consts_and_utils.py index 678d958da..40e93afae 100644 --- a/backend/tests/daily/connectors/google_drive/consts_and_utils.py +++ b/backend/tests/daily/connectors/google_drive/consts_and_utils.py @@ -2,9 +2,11 @@ import time from collections.abc import Sequence from onyx.connectors.google_drive.connector import GoogleDriveConnector +from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document from onyx.connectors.models import TextSection from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector +from tests.daily.connectors.utils import load_everything_from_checkpoint_connector ALL_FILES = list(range(0, 60)) SHARED_DRIVE_FILES = list(range(20, 25)) @@ -26,6 +28,8 @@ FOLDER_2_2_FILE_IDS = list(range(55, 60)) SECTIONS_FILE_IDS = [61] FOLDER_3_FILE_IDS = list(range(62, 65)) +DONWLOAD_REVOKED_FILE_ID = 21 + PUBLIC_FOLDER_RANGE = FOLDER_1_2_FILE_IDS PUBLIC_FILE_IDS = list(range(55, 57)) PUBLIC_RANGE = PUBLIC_FOLDER_RANGE + PUBLIC_FILE_IDS @@ -234,3 +238,13 @@ def load_all_docs(connector: GoogleDriveConnector) -> list[Document]: 0, time.time(), ) + + +def load_all_docs_with_failures( + connector: GoogleDriveConnector, +) -> list[Document | ConnectorFailure]: + return load_everything_from_checkpoint_connector( + connector, + 0, + time.time(), + ) diff --git a/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py b/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py index 7c4eb91ed..7728e4255 100644 --- a/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py +++ 
b/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py @@ -3,21 +3,54 @@ from unittest.mock import MagicMock from unittest.mock import patch from onyx.connectors.google_drive.connector import GoogleDriveConnector +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import Document from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import ( assert_expected_docs_in_retrieved_docs, ) +from tests.daily.connectors.google_drive.consts_and_utils import ( + DONWLOAD_REVOKED_FILE_ID, +) from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs +from tests.daily.connectors.google_drive.consts_and_utils import ( + load_all_docs_with_failures, +) from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS +def _check_for_error( + retrieved_docs_failures: list[Document | ConnectorFailure], + expected_file_ids: list[int], +) -> list[Document]: + retrieved_docs = [ + doc for doc in retrieved_docs_failures if isinstance(doc, Document) + ] + retrieved_failures = [ + failure + for failure in retrieved_docs_failures + if isinstance(failure, ConnectorFailure) + ] + assert len(retrieved_failures) <= 1 + + # current behavior is to fail silently for 403s; leaving this here for when we revert + # if all 403s get fixed + if 
len(retrieved_failures) == 1: + fail_msg = retrieved_failures[0].failure_message + assert "HttpError 403" in fail_msg + assert f"file_{DONWLOAD_REVOKED_FILE_ID}.txt" in fail_msg + + expected_file_ids.remove(DONWLOAD_REVOKED_FILE_ID) + return retrieved_docs + + @patch( "onyx.file_processing.extract_file_text.get_unstructured_api_key", return_value=None, @@ -36,7 +69,7 @@ def test_all( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs = load_all_docs(connector) + retrieved_docs_failures = load_all_docs_with_failures(connector) expected_file_ids = ( # These are the files from my drive @@ -50,6 +83,9 @@ def test_all( + ADMIN_FOLDER_3_FILE_IDS + list(range(0, 2)) ) + + retrieved_docs = _check_for_error(retrieved_docs_failures, expected_file_ids) + assert_expected_docs_in_retrieved_docs( retrieved_docs=retrieved_docs, expected_file_ids=expected_file_ids, @@ -74,7 +110,7 @@ def test_shared_drives_only( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs = load_all_docs(connector) + retrieved_docs_failures = load_all_docs_with_failures(connector) expected_file_ids = ( # These are the files from shared drives @@ -83,6 +119,8 @@ def test_shared_drives_only( + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS ) + + retrieved_docs = _check_for_error(retrieved_docs_failures, expected_file_ids) assert_expected_docs_in_retrieved_docs( retrieved_docs=retrieved_docs, expected_file_ids=expected_file_ids, diff --git a/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py b/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py index a4bfee8a8..602aa3865 100644 --- a/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py +++ b/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py @@ -197,7 +197,7 @@ def test_thread_safe_dict_concurrent_access() -> None: for i in range(iterations): key = str(i % 5) # Use 5 different keys # Get current value or 0 if not exists, increment, then store - d[key] = d.get(key, 0) + 1 + d.atomic_get_set(key, lambda 
x: x + 1, 0) # Create and start threads threads = []