mirror of https://github.com/danswer-ai/danswer.git
add emails to retry with on 403 (#4565)
* add emails to retry with on 403
* attempted fix for connector test
* CW comments
* connector test fix
* test fixes and continue on 403
* fix tests
* fix tests
* fix concurrency tests
* fix integration tests with llmprovider eager loading
This commit is contained in:
parent f3d5303d93
commit eb569bf79d
@@ -40,6 +40,7 @@ from onyx.connectors.google_drive.models import RetrievedDriveFile
 from onyx.connectors.google_drive.models import StageCompletion
 from onyx.connectors.google_utils.google_auth import get_google_creds
 from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
+from onyx.connectors.google_utils.google_utils import get_file_owners
 from onyx.connectors.google_utils.google_utils import GoogleFields
 from onyx.connectors.google_utils.resources import get_admin_service
 from onyx.connectors.google_utils.resources import get_drive_service
@@ -949,7 +950,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
             (
                 convert_func,
                 (
-                    file.user_email,
+                    [file.user_email, self.primary_admin_email]
+                    + get_file_owners(file.drive_file),
                     file.drive_file,
                 ),
             )
@@ -43,6 +43,8 @@ logger = setup_logger()
 SMART_CHIP_CHAR = "\ue907"
 WEB_VIEW_LINK_KEY = "webViewLink"
 
+MAX_RETRIEVER_EMAILS = 20
+
 # Mapping of Google Drive mime types to export formats
 GOOGLE_MIME_TYPES_TO_EXPORT = {
     GDriveMimeType.DOC.value: "text/plain",
@@ -297,12 +299,66 @@ def align_basic_advanced(
     return new_sections
 
 
-# We used to always get the user email from the file owners when available,
-# but this was causing issues with shared folders where the owner was not included in the service account
-# now we use the email of the account that successfully listed the file. Leaving this in case we end up
-# wanting to retry with file owners and/or admin email at some point.
-# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
 def convert_drive_item_to_document(
     creds: Any,
     allow_images: bool,
     size_threshold: int,
+    retriever_emails: list[str],
+    file: GoogleDriveFileType,
+) -> Document | ConnectorFailure | None:
+    """
+    Attempt to convert a drive item to a document with each retriever email
+    in order. returns upon a successful retrieval or a non-403 error.
+
+    We used to always get the user email from the file owners when available,
+    but this was causing issues with shared folders where the owner was not included in the service account
+    now we use the email of the account that successfully listed the file. There are cases where a
+    user that can list a file cannot download it, so we retry with file owners and admin email.
+    """
+    first_error = None
+    doc_or_failure = None
+    retriever_emails = retriever_emails[:MAX_RETRIEVER_EMAILS]
+    # use seen instead of list(set()) to avoid re-ordering the retriever emails
+    seen = set()
+    for retriever_email in retriever_emails:
+        if retriever_email in seen:
+            continue
+        seen.add(retriever_email)
+        doc_or_failure = _convert_drive_item_to_document(
+            creds, allow_images, size_threshold, retriever_email, file
+        )
+        if (
+            doc_or_failure is None
+            or isinstance(doc_or_failure, Document)
+            or not (
+                isinstance(doc_or_failure.exception, HttpError)
+                and doc_or_failure.exception.status_code == 403
+            )
+        ):
+            return doc_or_failure
+
+        if first_error is None:
+            first_error = doc_or_failure
+        else:
+            first_error.failure_message += f"\n\n{doc_or_failure.failure_message}"
+
+    if (
+        first_error
+        and isinstance(first_error.exception, HttpError)
+        and first_error.exception.status_code == 403
+    ):
+        # This SHOULD happen very rarely, and we don't want to break the indexing process when
+        # a high volume of 403s occurs early. We leave a verbose log to help investigate.
+        logger.error(
+            f"Skipping file id: {file.get('id')} name: {file.get('name')} due to 403 error."
+            f"Attempted to retrieve with {retriever_emails},"
+            f"got the following errors: {first_error.failure_message}"
+        )
+        return None
+    return first_error
+
+
+def _convert_drive_item_to_document(
+    creds: Any,
+    allow_images: bool,
+    size_threshold: int,
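
The wrapper above is the core of the change: try each candidate email in order, return on the first success or non-403 error, and otherwise accumulate the 403 messages into a single failure. Below is a standalone sketch of the same retry pattern, using stand-in types (Failure, a fetch callback) rather than the connector's real Document/HttpError machinery:

from dataclasses import dataclass
from typing import Callable


@dataclass
class Failure:
    status_code: int
    failure_message: str


def fetch_with_retries(
    emails: list[str], fetch: Callable[[str], object]
) -> object | None:
    # Mirrors convert_drive_item_to_document: dedup while preserving order,
    # stop on the first success or non-403 outcome.
    first_error = None
    seen = set()
    for email in emails:
        if email in seen:
            continue
        seen.add(email)
        result = fetch(email)
        if not (isinstance(result, Failure) and result.status_code == 403):
            return result
        if first_error is None:
            first_error = result
        else:
            first_error.failure_message += f"\n\n{result.failure_message}"
    return first_error


# Only the owner can download; the lister and admin both get 403s.
responses = {
    "lister@x.com": Failure(403, "denied for lister"),
    "admin@x.com": Failure(403, "denied for admin"),
}
print(fetch_with_retries(
    ["lister@x.com", "admin@x.com", "lister@x.com", "owner@x.com"],
    lambda email: responses.get(email, "document contents"),
))  # prints "document contents"; the duplicate lister email is tried only once

The seen-set dedup matters because the caller builds the list as [file.user_email, self.primary_admin_email] + get_file_owners(file.drive_file), which can easily repeat an address.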
@@ -392,7 +448,9 @@ def convert_drive_item_to_document(
         )
     except Exception as e:
         file_name = file.get("name")
-        error_str = f"Error converting file '{file_name}' to Document: {e}"
+        error_str = (
+            f"Error converting file '{file_name}' to Document as {retriever_email}: {e}"
+        )
         if isinstance(e, HttpError) and e.status_code == 403:
             logger.warning(
                 f"Uncommon permissions error while downloading file. User "
@@ -97,6 +97,17 @@ def _execute_with_retry(request: Any) -> Any:
     raise Exception(f"Failed to execute request after {max_attempts} attempts")
 
 
+def get_file_owners(file: GoogleDriveFileType) -> list[str]:
+    """
+    Get the owners of a file if the attribute is present.
+    """
+    return [
+        owner.get("emailAddress")
+        for owner in file.get("owners", [])
+        if owner.get("emailAddress")
+    ]
+
+
 def execute_paginated_retrieval(
     retrieval_function: Callable,
     list_key: str | None = None,
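
For reference, a small illustration (with a made-up dict shaped like a Drive v3 files.list item, not real API output) of what get_file_owners extracts; owner records without an emailAddress are skipped:

drive_file = {
    "id": "abc123",
    "owners": [
        {"emailAddress": "owner@example.com"},
        {"displayName": "owner record without an email field"},
    ],
}
owners = [
    owner.get("emailAddress")
    for owner in drive_file.get("owners", [])
    if owner.get("emailAddress")
]
assert owners == ["owner@example.com"]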
@@ -73,10 +73,8 @@ def upsert_llm_provider(
     llm_provider_upsert_request: LLMProviderUpsertRequest,
     db_session: Session,
 ) -> LLMProviderView:
-    existing_llm_provider = db_session.scalar(
-        select(LLMProviderModel).where(
-            LLMProviderModel.name == llm_provider_upsert_request.name
-        )
-    )
+    existing_llm_provider = fetch_existing_llm_provider(
+        llm_provider_upsert_request.name, db_session
+    )
 
     if not existing_llm_provider:
@@ -155,8 +153,13 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
 
 def fetch_existing_llm_providers(
     db_session: Session,
+    only_public: bool = False,
 ) -> list[LLMProviderModel]:
-    stmt = select(LLMProviderModel)
+    stmt = select(LLMProviderModel).options(
+        selectinload(LLMProviderModel.model_configurations)
+    )
+    if only_public:
+        stmt = stmt.where(LLMProviderModel.is_public == True)  # noqa: E712
     return list(db_session.scalars(stmt).all())
 
 
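
The selectinload additions in this and the following hunks are the "llmprovider eager loading" fix from the commit message: the model_configurations relationship is populated while the session is still open, so provider objects remain usable afterwards. A self-contained sketch with toy models and in-memory SQLite (not the real Onyx schema):

from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
    selectinload,
)


class Base(DeclarativeBase):
    pass


class Provider(Base):
    __tablename__ = "provider"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]
    model_configurations: Mapped[list["ModelConfig"]] = relationship()


class ModelConfig(Base):
    __tablename__ = "model_config"
    id: Mapped[int] = mapped_column(primary_key=True)
    provider_id: Mapped[int] = mapped_column(ForeignKey("provider.id"))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Provider(name="openai", model_configurations=[ModelConfig()]))
    session.commit()

with Session(engine) as session:
    provider = session.scalar(
        select(Provider).options(selectinload(Provider.model_configurations))
    )

# Works because the relationship was loaded eagerly inside the session; without
# .options(selectinload(...)) this access would raise DetachedInstanceError.
print(provider.name, len(provider.model_configurations))

This is presumably why the integration tests needed the change: objects returned by these fetch helpers are read after the session that produced them has closed.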
@@ -164,7 +167,9 @@ def fetch_existing_llm_provider(
     provider_name: str, db_session: Session
 ) -> LLMProviderModel | None:
     provider_model = db_session.scalar(
-        select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
+        select(LLMProviderModel)
+        .where(LLMProviderModel.name == provider_name)
+        .options(selectinload(LLMProviderModel.model_configurations))
     )
 
     return provider_model
@@ -174,21 +179,18 @@ def fetch_existing_llm_providers_for_user(
     db_session: Session,
     user: User | None = None,
 ) -> list[LLMProviderModel]:
+    # if user is anonymous
     if not user:
-        if AUTH_TYPE != AuthType.DISABLED:
-            # User is anonymous
-            return list(
-                db_session.scalars(
-                    select(LLMProviderModel).where(
-                        LLMProviderModel.is_public == True  # noqa: E712
-                    )
-                ).all()
-            )
-        else:
-            # If auth is disabled, user has access to all providers
-            return fetch_existing_llm_providers(db_session)
+        # Only fetch public providers if auth is turned on
+        return fetch_existing_llm_providers(
+            db_session, only_public=AUTH_TYPE != AuthType.DISABLED
+        )
 
-    stmt = select(LLMProviderModel).distinct()
+    stmt = (
+        select(LLMProviderModel)
+        .options(selectinload(LLMProviderModel.model_configurations))
+        .distinct()
+    )
     user_groups_select = select(User__UserGroup.user_group_id).where(
         User__UserGroup.user_id == user.id
     )
@@ -217,9 +219,9 @@ def fetch_embedding_provider(
 
 def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
     provider_model = db_session.scalar(
-        select(LLMProviderModel).where(
-            LLMProviderModel.is_default_provider == True  # noqa: E712
-        )
+        select(LLMProviderModel)
+        .where(LLMProviderModel.is_default_provider == True)  # noqa: E712
+        .options(selectinload(LLMProviderModel.model_configurations))
     )
     if not provider_model:
         return None
@@ -228,9 +230,9 @@ def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
 
 def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
     provider_model = db_session.scalar(
-        select(LLMProviderModel).where(
-            LLMProviderModel.is_default_vision_provider == True  # noqa: E712
-        )
+        select(LLMProviderModel)
+        .where(LLMProviderModel.is_default_vision_provider == True)  # noqa: E712
+        .options(selectinload(LLMProviderModel.model_configurations))
     )
     if not provider_model:
         return None
@@ -240,9 +242,7 @@ def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None
 def fetch_llm_provider_view(
     db_session: Session, provider_name: str
 ) -> LLMProviderView | None:
-    provider_model = db_session.scalar(
-        select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
-    )
+    provider_model = fetch_existing_llm_provider(provider_name, db_session)
     if not provider_model:
         return None
     return LLMProviderView.from_model(provider_model)
@@ -254,11 +254,7 @@ def fetch_max_input_tokens(
     model_name: str,
     output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
 ) -> int:
-    llm_provider = db_session.scalar(
-        select(LLMProviderModel)
-        .where(LLMProviderModel.provider == provider_name)
-        .options(selectinload(LLMProviderModel.model_configurations))
-    )
+    llm_provider = fetch_existing_llm_provider(provider_name, db_session)
     if not llm_provider:
         raise RuntimeError(f"No LLM Provider with the name {provider_name}")
 
@@ -145,6 +145,30 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
         with self.lock:
             return collections.abc.ValuesView(self)
 
+    @overload
+    def atomic_get_set(
+        self, key: KT, value_callback: Callable[[VT], VT], default: VT
+    ) -> tuple[VT, VT]: ...
+
+    @overload
+    def atomic_get_set(
+        self, key: KT, value_callback: Callable[[VT | _T], VT], default: VT | _T
+    ) -> tuple[VT | _T, VT]: ...
+
+    def atomic_get_set(
+        self, key: KT, value_callback: Callable[[Any], VT], default: Any = None
+    ) -> tuple[Any, VT]:
+        """Replace a value from the dict with a function applied to the previous value, atomically.
+
+        Returns:
+            A tuple of the previous value and the new value.
+        """
+        with self.lock:
+            val = self._dict.get(key, default)
+            new_val = value_callback(val)
+            self._dict[key] = new_val
+            return val, new_val
+
 
 class CallableProtocol(Protocol):
     def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
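
A usage sketch for atomic_get_set as a thread-safe counter. The stand-in class below only reproduces the method from the diff, just enough to run the example:

import threading
from typing import Any, Callable


class CounterDict:
    # Minimal stand-in for ThreadSafeDict, enough to demonstrate atomic_get_set.
    def __init__(self) -> None:
        self._dict: dict[str, Any] = {}
        self.lock = threading.Lock()

    def atomic_get_set(
        self, key: str, value_callback: Callable[[Any], Any], default: Any = None
    ) -> tuple[Any, Any]:
        # The whole read-modify-write happens under one lock acquisition.
        with self.lock:
            val = self._dict.get(key, default)
            new_val = value_callback(val)
            self._dict[key] = new_val
            return val, new_val


d = CounterDict()


def worker() -> None:
    for _ in range(10_000):
        d.atomic_get_set("hits", lambda x: x + 1, 0)


threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(d._dict["hits"])  # 80000 every run

This is exactly the swap made in the concurrency test at the end of this diff: the old d[key] = d.get(key, 0) + 1 performs the get and the set under separate lock acquisitions, so two threads can read the same value and one increment is lost.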
@@ -2,9 +2,11 @@ import time
 from collections.abc import Sequence
 
 from onyx.connectors.google_drive.connector import GoogleDriveConnector
+from onyx.connectors.models import ConnectorFailure
 from onyx.connectors.models import Document
 from onyx.connectors.models import TextSection
 from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
+from tests.daily.connectors.utils import load_everything_from_checkpoint_connector
 
 ALL_FILES = list(range(0, 60))
 SHARED_DRIVE_FILES = list(range(20, 25))
@@ -26,6 +28,8 @@ FOLDER_2_2_FILE_IDS = list(range(55, 60))
 SECTIONS_FILE_IDS = [61]
 FOLDER_3_FILE_IDS = list(range(62, 65))
 
+DONWLOAD_REVOKED_FILE_ID = 21
+
 PUBLIC_FOLDER_RANGE = FOLDER_1_2_FILE_IDS
 PUBLIC_FILE_IDS = list(range(55, 57))
 PUBLIC_RANGE = PUBLIC_FOLDER_RANGE + PUBLIC_FILE_IDS
@@ -234,3 +238,13 @@ def load_all_docs(connector: GoogleDriveConnector) -> list[Document]:
         0,
         time.time(),
     )
+
+
+def load_all_docs_with_failures(
+    connector: GoogleDriveConnector,
+) -> list[Document | ConnectorFailure]:
+    return load_everything_from_checkpoint_connector(
+        connector,
+        0,
+        time.time(),
+    )
@@ -3,21 +3,54 @@ from unittest.mock import MagicMock
 from unittest.mock import patch
 
 from onyx.connectors.google_drive.connector import GoogleDriveConnector
+from onyx.connectors.models import ConnectorFailure
 from onyx.connectors.models import Document
 from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
 from tests.daily.connectors.google_drive.consts_and_utils import (
     assert_expected_docs_in_retrieved_docs,
 )
+from tests.daily.connectors.google_drive.consts_and_utils import (
+    DONWLOAD_REVOKED_FILE_ID,
+)
 from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
 from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
 from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
 from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL
 from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
 from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
+from tests.daily.connectors.google_drive.consts_and_utils import (
+    load_all_docs_with_failures,
+)
 from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
 from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL
 from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS
+
+
+def _check_for_error(
+    retrieved_docs_failures: list[Document | ConnectorFailure],
+    expected_file_ids: list[int],
+) -> list[Document]:
+    retrieved_docs = [
+        doc for doc in retrieved_docs_failures if isinstance(doc, Document)
+    ]
+    retrieved_failures = [
+        failure
+        for failure in retrieved_docs_failures
+        if isinstance(failure, ConnectorFailure)
+    ]
+    assert len(retrieved_failures) <= 1
+
+    # current behavior is to fail silently for 403s; leaving this here for when we revert
+    # if all 403s get fixed
+    if len(retrieved_failures) == 1:
+        fail_msg = retrieved_failures[0].failure_message
+        assert "HttpError 403" in fail_msg
+        assert f"file_{DONWLOAD_REVOKED_FILE_ID}.txt" in fail_msg
+
+        expected_file_ids.remove(DONWLOAD_REVOKED_FILE_ID)
+    return retrieved_docs
 
 
 @patch(
     "onyx.file_processing.extract_file_text.get_unstructured_api_key",
     return_value=None,
@@ -36,7 +69,7 @@ def test_all(
         shared_drive_urls=None,
         my_drive_emails=None,
     )
-    retrieved_docs = load_all_docs(connector)
+    retrieved_docs_failures = load_all_docs_with_failures(connector)
 
     expected_file_ids = (
         # These are the files from my drive
@@ -50,6 +83,9 @@ def test_all(
         + ADMIN_FOLDER_3_FILE_IDS
         + list(range(0, 2))
     )
+
+    retrieved_docs = _check_for_error(retrieved_docs_failures, expected_file_ids)
+
     assert_expected_docs_in_retrieved_docs(
         retrieved_docs=retrieved_docs,
         expected_file_ids=expected_file_ids,
@@ -74,7 +110,7 @@ def test_shared_drives_only(
         shared_drive_urls=None,
         my_drive_emails=None,
     )
-    retrieved_docs = load_all_docs(connector)
+    retrieved_docs_failures = load_all_docs_with_failures(connector)
 
     expected_file_ids = (
         # These are the files from shared drives
@@ -83,6 +119,8 @@ def test_shared_drives_only(
         + FOLDER_1_1_FILE_IDS
         + FOLDER_1_2_FILE_IDS
     )
+
+    retrieved_docs = _check_for_error(retrieved_docs_failures, expected_file_ids)
     assert_expected_docs_in_retrieved_docs(
         retrieved_docs=retrieved_docs,
         expected_file_ids=expected_file_ids,
@@ -197,7 +197,7 @@ def test_thread_safe_dict_concurrent_access() -> None:
        for i in range(iterations):
            key = str(i % 5)  # Use 5 different keys
            # Get current value or 0 if not exists, increment, then store
-           d[key] = d.get(key, 0) + 1
+           d.atomic_get_set(key, lambda x: x + 1, 0)
 
    # Create and start threads
    threads = []