fix drive slowness (#4668)

* fix slowness

* no more silent failing for users

* nits

* no silly info transfer
Evan Lohn committed 2025-05-07 15:48:08 -07:00 (committed by GitHub)
parent ee09cb95af, commit 0eab6ab935
2 changed files with 65 additions and 24 deletions


@@ -10,6 +10,7 @@ from typing import cast
from typing import Protocol
from urllib.parse import urlparse
+ from google.auth.exceptions import RefreshError # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
@@ -72,7 +73,9 @@ logger = setup_logger()
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
# All file retrievals could be batched and made at once
- BATCHES_PER_CHECKPOINT = 10
+ BATCHES_PER_CHECKPOINT = 1
+ DRIVE_BATCH_SIZE = 80
def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
@@ -184,8 +187,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
"shared_folder_urls, or my_drive_emails"
)
- self.batch_size = batch_size
specific_requests_made = False
if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls):
specific_requests_made = True
@@ -306,14 +307,14 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
return user_emails
def get_all_drive_ids(self) -> set[str]:
- primary_drive_service = get_drive_service(
- creds=self.creds,
- user_email=self.primary_admin_email,
- )
+ return self._get_all_drives_for_user(self.primary_admin_email)
+ def _get_all_drives_for_user(self, user_email: str) -> set[str]:
+ drive_service = get_drive_service(self.creds, user_email)
is_service_account = isinstance(self.creds, ServiceAccountCredentials)
- all_drive_ids = set()
+ all_drive_ids: set[str] = set()
for drive in execute_paginated_retrieval(
- retrieval_function=primary_drive_service.drives().list,
+ retrieval_function=drive_service.drives().list,
list_key="drives",
useDomainAdminAccess=is_service_account,
fields="drives(id),nextPageToken",
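Note: execute_paginated_retrieval is the connector's own pagination helper. For orientation only, a minimal sketch of the nextPageToken loop such a helper presumably wraps might look like the following (the function name and signature here are assumptions, not the repo's implementation):

```python
from collections.abc import Callable, Iterator
from typing import Any


def paginate(
    retrieval_function: Callable[..., Any],
    list_key: str,
    **kwargs: Any,
) -> Iterator[dict[str, Any]]:
    """Yield items from a Google API list endpoint, following nextPageToken."""
    page_token: str | None = None
    while True:
        if page_token is not None:
            kwargs["pageToken"] = page_token
        # googleapiclient list() calls return a request object executed lazily
        response = retrieval_function(**kwargs).execute()
        for item in response.get(list_key, []):
            yield item
        page_token = response.get("nextPageToken")
        if page_token is None:
            return
```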
@@ -373,6 +374,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
if drive_id in self._retrieved_folder_and_drive_ids
else DriveIdStatus.AVAILABLE
)
+ logger.debug(
+ f"Drive id status: {len(drive_id_status)}, user email: {thread_id}, "
+ f"processed drive ids: {len(completion.processed_drive_ids)}"
+ )
# wake up other threads waiting for work
cv.notify_all()
@@ -423,6 +428,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
curr_stage = checkpoint.completion_map[user_email]
resuming = True
if curr_stage.stage == DriveRetrievalStage.START:
logger.info(f"Setting stage to {DriveRetrievalStage.MY_DRIVE_FILES.value}")
curr_stage.stage = DriveRetrievalStage.MY_DRIVE_FILES
resuming = False
drive_service = get_drive_service(self.creds, user_email)
@@ -430,6 +436,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
# validate that the user has access to the drive APIs by performing a simple
# request and checking for a 401
try:
logger.debug(f"Getting root folder id for user {user_email}")
# the default is ~17 mins of retries; don't do that here so we don't
# waste 17 mins every time we run into a user without access to the drive APIs
retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service)
@@ -445,14 +452,29 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
curr_stage.stage = DriveRetrievalStage.DONE
return
raise
+ except RefreshError as e:
+ logger.warning(
+ f"User '{user_email}' could not refresh their token. Error: {e}"
+ )
+ # mark this user as done so we don't try to retrieve anything for them
+ # again
+ yield RetrievedDriveFile(
+ completion_stage=DriveRetrievalStage.DONE,
+ drive_file={},
+ user_email=user_email,
+ error=e,
+ )
+ curr_stage.stage = DriveRetrievalStage.DONE
+ return
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - include_my_drives is true
# - the current user's email is in the requested emails
if curr_stage.stage == DriveRetrievalStage.MY_DRIVE_FILES:
if self.include_my_drives or user_email in self._requested_my_drive_emails:
logger.info(f"Getting all files in my drive as '{user_email}'")
logger.info(
f"Getting all files in my drive as '{user_email}. Resuming: {resuming}"
)
yield from add_retrieval_info(
get_all_files_in_my_drive_and_shared(
@@ -505,7 +527,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
for drive_id in concurrent_drive_itr(user_email):
logger.info(
f"Getting files in shared drive '{drive_id}' as '{user_email}'"
f"Getting files in shared drive '{drive_id}' as '{user_email}. Resuming: {resuming}"
)
curr_stage.completed_until = 0
curr_stage.current_folder_or_drive_id = drive_id
@@ -577,6 +599,14 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[RetrievedDriveFile]:
"""
The current implementation of the service account retrieval does some
initial setup work using the primary admin email, then runs MAX_DRIVE_WORKERS
concurrent threads, each of which impersonates a different user and retrieves
files for that user. Technically, the actual work each thread does is "yield the
next file retrieved by the user", at which point it returns to the thread pool;
see parallel_yield for more details.
"""
if checkpoint.completion_stage == DriveRetrievalStage.START:
checkpoint.completion_stage = DriveRetrievalStage.USER_EMAILS
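The docstring above refers to parallel_yield. As an illustration only (this is not the repo's parallel_yield, and MAX_DRIVE_WORKERS is an assumed value), the interleaving it describes can be sketched with a thread pool that advances each per-user generator by one item at a time:

```python
from collections.abc import Iterator
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import Any

MAX_DRIVE_WORKERS = 4  # assumed value for this sketch


def interleave(per_user_generators: list[Iterator[Any]]) -> Iterator[Any]:
    """Yield items from many generators as soon as any of them produces one."""
    sentinel = object()
    with ThreadPoolExecutor(max_workers=MAX_DRIVE_WORKERS) as pool:
        # one in-flight "fetch the next item" task per generator
        futures = {pool.submit(next, gen, sentinel): gen for gen in per_user_generators}
        while futures:
            done, _ = wait(futures, return_when=FIRST_COMPLETED)
            for future in done:
                gen = futures.pop(future)
                item = future.result()
                if item is sentinel:
                    continue  # this generator (user) is exhausted
                yield item
                # hand the generator back to the pool for its next item
                futures[pool.submit(next, gen, sentinel)] = gen
```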
@@ -602,6 +632,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
checkpoint.completion_map[email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
+ processed_drive_ids=set(),
)
# we've found all users and drives, now time to actually start
@@ -627,7 +658,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
# to the drive APIs. Without this, we could loop through these emails for
# more than 3 hours, causing a timeout and stalling progress.
email_batch_takes_us_to_completion = True
- MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = 50
+ MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = MAX_DRIVE_WORKERS
if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING:
non_completed_org_emails = non_completed_org_emails[
:MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
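The capping above keeps a single checkpointed run bounded. A hedged sketch of the intent (the helper below and its signature are illustrative, not part of the diff):

```python
MAX_DRIVE_WORKERS = 4  # assumed value for this sketch
MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = MAX_DRIVE_WORKERS


def run_once(remaining_emails: list[str], completed: set[str]) -> bool:
    """One checkpointed run: process a bounded batch of users, report whether we finished."""
    batch = remaining_emails[:MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING]
    for email in batch:
        # ... impersonate `email` and retrieve that user's files here ...
        completed.add(email)
    # False means another run, starting from a fresh checkpoint (and a fresh
    # time budget), picks up the remaining users instead of risking a timeout.
    return len(remaining_emails) <= MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
```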
@@ -871,6 +902,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
return
for file in drive_files:
+ logger.debug(
+ f"Updating checkpoint for file: {file.drive_file.get('name')}. "
+ f"Seen: {file.drive_file.get('id') in checkpoint.all_retrieved_file_ids}"
+ )
checkpoint.completion_map[file.user_email].update(
stage=file.completion_stage,
completed_until=datetime.fromisoformat(
@@ -1047,24 +1082,22 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
continue
files_batch.append(retrieved_file)
- if len(files_batch) < self.batch_size:
+ if len(files_batch) < DRIVE_BATCH_SIZE:
continue
logger.info(
f"Yielding batch of {len(files_batch)} files; num seen doc ids: {len(checkpoint.all_retrieved_file_ids)}"
)
yield from _yield_batch(files_batch)
files_batch = []
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = (
self._retrieved_folder_and_drive_ids
)
return # create a new checkpoint
logger.info(
f"Processing remaining files: {[file.drive_file.get('name') for file in files_batch]}"
)
# Process any remaining files
if files_batch:
yield from _yield_batch(files_batch)
checkpoint.retrieved_folder_and_drive_ids = (
self._retrieved_folder_and_drive_ids
)
except Exception as e:
logger.exception(f"Error extracting documents from Google Drive: {e}")
raise e
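Together with the constants changed earlier (DRIVE_BATCH_SIZE, BATCHES_PER_CHECKPOINT), the loop above yields fixed-size batches and returns early so a new checkpoint is written. A simplified, self-contained sketch of that shape (not the connector's exact code):

```python
from collections.abc import Iterable, Iterator

DRIVE_BATCH_SIZE = 80
BATCHES_PER_CHECKPOINT = 1


def yield_in_checkpointed_batches(files: Iterable[dict]) -> Iterator[list[dict]]:
    files_batch: list[dict] = []
    batches_complete = 0
    for file in files:
        files_batch.append(file)
        if len(files_batch) < DRIVE_BATCH_SIZE:
            continue
        yield files_batch
        files_batch = []
        batches_complete += 1
        if batches_complete > BATCHES_PER_CHECKPOINT:
            return  # stop early; the caller persists a checkpoint and resumes later
    # flush whatever is left once the source is exhausted
    if files_batch:
        yield files_batch
```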
@@ -1083,6 +1116,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
"Credentials missing, should not call this method before calling load_credentials"
)
+ logger.info(
+ f"Loading from checkpoint with completion stage: {checkpoint.completion_stage}, "
+ f"num retrieved ids: {len(checkpoint.all_retrieved_file_ids)}"
+ )
checkpoint = copy.deepcopy(checkpoint)
self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
try:


@@ -327,12 +327,16 @@ def convert_drive_item_to_document(
doc_or_failure = _convert_drive_item_to_document(
creds, allow_images, size_threshold, retriever_email, file
)
+ # There are a variety of permissions-based errors that occasionally occur
+ # when retrieving files. Often when these occur, there is another user
+ # that can successfully retrieve the file, so we try the next user.
if (
doc_or_failure is None
or isinstance(doc_or_failure, Document)
or not (
isinstance(doc_or_failure.exception, HttpError)
- and doc_or_failure.exception.status_code in [403, 404]
+ and doc_or_failure.exception.status_code in [401, 403, 404]
)
):
return doc_or_failure
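The comment introduced above describes falling through to another user when a permissions-style error occurs. A hedged sketch of that pattern (the callable and helper name are illustrative; only HttpError and the 401/403/404 status codes come from the diff):

```python
from collections.abc import Callable

from googleapiclient.errors import HttpError  # type: ignore

PERMISSION_STATUS_CODES = {401, 403, 404}


def first_successful_retrieval(
    retrieve_as: Callable[[str], dict],  # e.g. "download this file as the given user"
    user_emails: list[str],
) -> dict:
    """Try each user in turn; permission-style failures fall through to the next user."""
    last_error: Exception | None = None
    for email in user_emails:
        try:
            return retrieve_as(email)
        except HttpError as e:
            if e.status_code not in PERMISSION_STATUS_CODES:
                raise
            last_error = e  # this user lacks access; another user may still succeed
    raise last_error if last_error else RuntimeError("no users available to try")
```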