add emails to retry with on 403 (#4565)

* add emails to retry with on 403

* attempted fix for connector test

* CW comments

* connector test fix

* test fixes and continue on 403

* fix tests

* fix tests

* fix concurrency tests

* fix integration tests with LLMProvider eager loading
Evan Lohn 2025-04-21 16:27:31 -07:00 committed by GitHub
parent f3d5303d93
commit eb569bf79d
8 changed files with 186 additions and 43 deletions


@@ -40,6 +40,7 @@ from onyx.connectors.google_drive.models import RetrievedDriveFile
from onyx.connectors.google_drive.models import StageCompletion
from onyx.connectors.google_utils.google_auth import get_google_creds
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.google_utils import get_file_owners
from onyx.connectors.google_utils.google_utils import GoogleFields
from onyx.connectors.google_utils.resources import get_admin_service
from onyx.connectors.google_utils.resources import get_drive_service
@@ -949,7 +950,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
(
convert_func,
(
file.user_email,
[file.user_email, self.primary_admin_email]
+ get_file_owners(file.drive_file),
file.drive_file,
),
)
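
For illustration only (addresses invented): the args tuple above now carries a priority-ordered list of candidate emails, with the account that listed the file first, then the primary admin, then the file owners. De-duplication happens later in convert_drive_item_to_document.

# Hypothetical values showing the retry order assembled above:
emails = ["lister@acme.com", "admin@acme.com"] + ["admin@acme.com", "owner@acme.com"]
# convert_drive_item_to_document de-duplicates while preserving order, so it
# tries "lister@acme.com", then "admin@acme.com", then "owner@acme.com".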


@@ -43,6 +43,8 @@ logger = setup_logger()
SMART_CHIP_CHAR = "\ue907"
WEB_VIEW_LINK_KEY = "webViewLink"
MAX_RETRIEVER_EMAILS = 20
# Mapping of Google Drive mime types to export formats
GOOGLE_MIME_TYPES_TO_EXPORT = {
GDriveMimeType.DOC.value: "text/plain",
@@ -297,12 +299,66 @@ def align_basic_advanced(
return new_sections
# We used to always get the user email from the file owners when available,
# but this was causing issues with shared folders where the owner was not included in the service account.
# Now we use the email of the account that successfully listed the file. Leaving this in case we end up
# wanting to retry with file owners and/or admin email at some point.
# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
def convert_drive_item_to_document(
creds: Any,
allow_images: bool,
size_threshold: int,
retriever_emails: list[str],
file: GoogleDriveFileType,
) -> Document | ConnectorFailure | None:
"""
Attempt to convert a drive item to a document with each retriever email
in order. returns upon a successful retrieval or a non-403 error.
We used to always get the user email from the file owners when available,
but this was causing issues with shared folders where the owner was not included in the service account
now we use the email of the account that successfully listed the file. There are cases where a
user that can list a file cannot download it, so we retry with file owners and admin email.
"""
first_error = None
doc_or_failure = None
retriever_emails = retriever_emails[:MAX_RETRIEVER_EMAILS]
# use seen instead of list(set()) to avoid re-ordering the retriever emails
seen = set()
for retriever_email in retriever_emails:
if retriever_email in seen:
continue
seen.add(retriever_email)
doc_or_failure = _convert_drive_item_to_document(
creds, allow_images, size_threshold, retriever_email, file
)
if (
doc_or_failure is None
or isinstance(doc_or_failure, Document)
or not (
isinstance(doc_or_failure.exception, HttpError)
and doc_or_failure.exception.status_code == 403
)
):
return doc_or_failure
if first_error is None:
first_error = doc_or_failure
else:
first_error.failure_message += f"\n\n{doc_or_failure.failure_message}"
if (
first_error
and isinstance(first_error.exception, HttpError)
and first_error.exception.status_code == 403
):
# This SHOULD happen very rarely, and we don't want to break the indexing process when
# a high volume of 403s occurs early. We leave a verbose log to help investigate.
        logger.error(
            f"Skipping file id: {file.get('id')} name: {file.get('name')} due to 403 error. "
            f"Attempted to retrieve with {retriever_emails}, "
            f"got the following errors: {first_error.failure_message}"
        )
return None
return first_error
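
A minimal usage sketch of the retry wrapper above (creds, file, and the downstream handlers are assumptions, not part of this diff):

result = convert_drive_item_to_document(
    creds=creds,
    allow_images=False,
    size_threshold=10 * 1024 * 1024,
    retriever_emails=[lister_email, admin_email] + get_file_owners(file),
    file=file,
)
if result is None:
    pass  # skipped, including persistent 403s (logged and swallowed)
elif isinstance(result, Document):
    handle_document(result)  # hypothetical downstream indexing step
else:
    handle_failure(result)  # ConnectorFailure with aggregated error messages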
def _convert_drive_item_to_document(
creds: Any,
allow_images: bool,
size_threshold: int,
@@ -392,7 +448,9 @@ def convert_drive_item_to_document(
)
except Exception as e:
file_name = file.get("name")
error_str = f"Error converting file '{file_name}' to Document: {e}"
error_str = (
f"Error converting file '{file_name}' to Document as {retriever_email}: {e}"
)
if isinstance(e, HttpError) and e.status_code == 403:
logger.warning(
f"Uncommon permissions error while downloading file. User "


@@ -97,6 +97,17 @@ def _execute_with_retry(request: Any) -> Any:
raise Exception(f"Failed to execute request after {max_attempts} attempts")
def get_file_owners(file: GoogleDriveFileType) -> list[str]:
"""
Get the owners of a file if the attribute is present.
"""
return [
owner.get("emailAddress")
for owner in file.get("owners", [])
if owner.get("emailAddress")
]
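
A quick example of the helper (dict shape follows the Drive API file resource; files in shared drives typically have no "owners" field, so the result is empty for them):

file = {"id": "abc123", "owners": [{"emailAddress": "owner@acme.com"}, {}]}
get_file_owners(file)  # -> ["owner@acme.com"]
get_file_owners({"id": "shared-drive-file"})  # -> []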
def execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str | None = None,


@@ -73,10 +73,8 @@ def upsert_llm_provider(
llm_provider_upsert_request: LLMProviderUpsertRequest,
db_session: Session,
) -> LLMProviderView:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.name == llm_provider_upsert_request.name
)
existing_llm_provider = fetch_existing_llm_provider(
llm_provider_upsert_request.name, db_session
)
if not existing_llm_provider:
@@ -155,8 +153,13 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
def fetch_existing_llm_providers(
db_session: Session,
only_public: bool = False,
) -> list[LLMProviderModel]:
stmt = select(LLMProviderModel)
stmt = select(LLMProviderModel).options(
selectinload(LLMProviderModel.model_configurations)
)
if only_public:
stmt = stmt.where(LLMProviderModel.is_public == True) # noqa: E712
return list(db_session.scalars(stmt).all())
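
Hedged sketch of the two call shapes this refactor enables (session assumed):

providers = fetch_existing_llm_providers(db_session)  # all providers
public_only = fetch_existing_llm_providers(db_session, only_public=True)
# Both now return providers with model_configurations already loaded.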
@@ -164,7 +167,9 @@ def fetch_existing_llm_provider(
provider_name: str, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
select(LLMProviderModel)
.where(LLMProviderModel.name == provider_name)
.options(selectinload(LLMProviderModel.model_configurations))
)
return provider_model
@@ -174,21 +179,18 @@ def fetch_existing_llm_providers_for_user(
db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
# if user is anonymous
if not user:
if AUTH_TYPE != AuthType.DISABLED:
# User is anonymous
return list(
db_session.scalars(
select(LLMProviderModel).where(
LLMProviderModel.is_public == True # noqa: E712
)
).all()
)
else:
# If auth is disabled, user has access to all providers
return fetch_existing_llm_providers(db_session)
# Only fetch public providers if auth is turned on
return fetch_existing_llm_providers(
db_session, only_public=AUTH_TYPE != AuthType.DISABLED
)
stmt = select(LLMProviderModel).distinct()
stmt = (
select(LLMProviderModel)
.options(selectinload(LLMProviderModel.model_configurations))
.distinct()
)
user_groups_select = select(User__UserGroup.user_group_id).where(
User__UserGroup.user_id == user.id
)
@@ -217,9 +219,9 @@ def fetch_embedding_provider(
def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_provider == True # noqa: E712
)
select(LLMProviderModel)
.where(LLMProviderModel.is_default_provider == True) # noqa: E712
.options(selectinload(LLMProviderModel.model_configurations))
)
if not provider_model:
return None
@@ -228,9 +230,9 @@ def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
)
select(LLMProviderModel)
.where(LLMProviderModel.is_default_vision_provider == True) # noqa: E712
.options(selectinload(LLMProviderModel.model_configurations))
)
if not provider_model:
return None
@@ -240,9 +242,7 @@ def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None
def fetch_llm_provider_view(
db_session: Session, provider_name: str
) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
)
provider_model = fetch_existing_llm_provider(provider_name, db_session)
if not provider_model:
return None
return LLMProviderView.from_model(provider_model)
@@ -254,11 +254,7 @@ def fetch_max_input_tokens(
model_name: str,
output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
) -> int:
llm_provider = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.provider == provider_name)
.options(selectinload(LLMProviderModel.model_configurations))
)
llm_provider = fetch_existing_llm_provider(provider_name, db_session)
if not llm_provider:
raise RuntimeError(f"No LLM Provider with the name {provider_name}")
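
The selectinload change threaded through this file eagerly loads model_configurations with the provider; a sketch of the failure mode it avoids (session lifecycle assumed, not from this diff):

with Session(engine) as db_session:
    provider = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.name == "my-provider")
    )
# Without eager loading, touching the lazy relationship after the session
# closes raises sqlalchemy.orm.exc.DetachedInstanceError:
provider.model_configurations
# With .options(selectinload(LLMProviderModel.model_configurations)), the
# collection is fetched up front via one extra SELECT ... WHERE ... IN (...).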


@@ -145,6 +145,30 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
with self.lock:
return collections.abc.ValuesView(self)
@overload
def atomic_get_set(
self, key: KT, value_callback: Callable[[VT], VT], default: VT
) -> tuple[VT, VT]: ...
@overload
def atomic_get_set(
self, key: KT, value_callback: Callable[[VT | _T], VT], default: VT | _T
) -> tuple[VT | _T, VT]: ...
def atomic_get_set(
self, key: KT, value_callback: Callable[[Any], VT], default: Any = None
) -> tuple[Any, VT]:
"""Replace a value from the dict with a function applied to the previous value, atomically.
Returns:
A tuple of the previous value and the new value.
"""
with self.lock:
val = self._dict.get(key, default)
new_val = value_callback(val)
self._dict[key] = new_val
return val, new_val
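
Sketch of the new helper in use (constructor signature assumed):

d: ThreadSafeDict[str, int] = ThreadSafeDict()
prev, new = d.atomic_get_set("hits", lambda v: v + 1, 0)
# prev == 0, new == 1; the read, callback, and write all run under the
# dict's lock, so concurrent updates cannot interleave and lose increments.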
class CallableProtocol(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


@@ -2,9 +2,11 @@ import time
from collections.abc import Sequence
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
from tests.daily.connectors.utils import load_everything_from_checkpoint_connector
ALL_FILES = list(range(0, 60))
SHARED_DRIVE_FILES = list(range(20, 25))
@@ -26,6 +28,8 @@ FOLDER_2_2_FILE_IDS = list(range(55, 60))
SECTIONS_FILE_IDS = [61]
FOLDER_3_FILE_IDS = list(range(62, 65))
DONWLOAD_REVOKED_FILE_ID = 21
PUBLIC_FOLDER_RANGE = FOLDER_1_2_FILE_IDS
PUBLIC_FILE_IDS = list(range(55, 57))
PUBLIC_RANGE = PUBLIC_FOLDER_RANGE + PUBLIC_FILE_IDS
@@ -234,3 +238,13 @@ def load_all_docs(connector: GoogleDriveConnector) -> list[Document]:
0,
time.time(),
)
def load_all_docs_with_failures(
connector: GoogleDriveConnector,
) -> list[Document | ConnectorFailure]:
return load_everything_from_checkpoint_connector(
connector,
0,
time.time(),
)


@@ -3,21 +3,54 @@ from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
DONWLOAD_REVOKED_FILE_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
from tests.daily.connectors.google_drive.consts_and_utils import (
load_all_docs_with_failures,
)
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS
def _check_for_error(
retrieved_docs_failures: list[Document | ConnectorFailure],
expected_file_ids: list[int],
) -> list[Document]:
retrieved_docs = [
doc for doc in retrieved_docs_failures if isinstance(doc, Document)
]
retrieved_failures = [
failure
for failure in retrieved_docs_failures
if isinstance(failure, ConnectorFailure)
]
assert len(retrieved_failures) <= 1
    # Current behavior is to fail silently on 403s; this branch is left in place
    # for when we revert that behavior, i.e. once all 403s are fixed.
if len(retrieved_failures) == 1:
fail_msg = retrieved_failures[0].failure_message
assert "HttpError 403" in fail_msg
assert f"file_{DONWLOAD_REVOKED_FILE_ID}.txt" in fail_msg
expected_file_ids.remove(DONWLOAD_REVOKED_FILE_ID)
return retrieved_docs
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
@@ -36,7 +69,7 @@ def test_all(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
retrieved_docs_failures = load_all_docs_with_failures(connector)
expected_file_ids = (
# These are the files from my drive
@@ -50,6 +83,9 @@ def test_all(
+ ADMIN_FOLDER_3_FILE_IDS
+ list(range(0, 2))
)
retrieved_docs = _check_for_error(retrieved_docs_failures, expected_file_ids)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
@@ -74,7 +110,7 @@ def test_shared_drives_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
retrieved_docs_failures = load_all_docs_with_failures(connector)
expected_file_ids = (
# These are the files from shared drives
@@ -83,6 +119,8 @@ def test_shared_drives_only(
+ FOLDER_1_1_FILE_IDS
+ FOLDER_1_2_FILE_IDS
)
retrieved_docs = _check_for_error(retrieved_docs_failures, expected_file_ids)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,


@@ -197,7 +197,7 @@ def test_thread_safe_dict_concurrent_access() -> None:
for i in range(iterations):
key = str(i % 5) # Use 5 different keys
# Get current value or 0 if not exists, increment, then store
d[key] = d.get(key, 0) + 1
d.atomic_get_set(key, lambda x: x + 1, 0)
# Create and start threads
threads = []
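
For context on the fix above: d[key] = d.get(key, 0) + 1 is a non-atomic read-modify-write, so two threads can read the same value and one increment is lost. A standalone sketch of the corrected pattern (thread and iteration counts invented):

import threading

d: ThreadSafeDict[str, int] = ThreadSafeDict()

def worker() -> None:
    for _ in range(10_000):
        d.atomic_get_set("0", lambda v: v + 1, 0)

threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert d["0"] == 80_000  # holds because each increment runs under the lock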