diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py index 78715f983..66003670b 100644 --- a/backend/onyx/connectors/google_drive/connector.py +++ b/backend/onyx/connectors/google_drive/connector.py @@ -28,7 +28,9 @@ from onyx.connectors.google_drive.doc_conversion import ( ) from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth -from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive +from onyx.connectors.google_drive.file_retrieval import ( + get_all_files_in_my_drive_and_shared, +) from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive from onyx.connectors.google_drive.file_retrieval import get_root_folder_id from onyx.connectors.google_drive.models import DriveRetrievalStage @@ -86,13 +88,18 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]: def _convert_single_file( creds: Any, - primary_admin_email: str, allow_images: bool, size_threshold: int, + retriever_email: str, file: dict[str, Any], ) -> Document | ConnectorFailure | None: - user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email + # We used to always get the user email from the file owners when available, + # but this was causing issues with shared folders where the owner was not included in the service account + # now we use the email of the account that successfully listed the file. Leaving this in case we end up + # wanting to retry with file owners and/or admin email at some point. + # user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email + user_email = retriever_email # Only construct these services when needed user_drive_service = lazy_eval( lambda: get_drive_service(creds, user_email=user_email) @@ -450,10 +457,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo logger.info(f"Getting all files in my drive as '{user_email}'") yield from add_retrieval_info( - get_all_files_in_my_drive( + get_all_files_in_my_drive_and_shared( service=drive_service, update_traversed_ids_func=self._update_traversed_parent_ids, is_slim=is_slim, + include_shared_with_me=self.include_files_shared_with_me, start=curr_stage.completed_until if resuming else start, end=end, ), @@ -916,20 +924,28 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo convert_func = partial( _convert_single_file, self.creds, - self.primary_admin_email, self.allow_images, self.size_threshold, ) # Fetch files in batches batches_complete = 0 - files_batch: list[GoogleDriveFileType] = [] + files_batch: list[RetrievedDriveFile] = [] def _yield_batch( - files_batch: list[GoogleDriveFileType], + files_batch: list[RetrievedDriveFile], ) -> Iterator[Document | ConnectorFailure]: nonlocal batches_complete # Process the batch using run_functions_tuples_in_parallel - func_with_args = [(convert_func, (file,)) for file in files_batch] + func_with_args = [ + ( + convert_func, + ( + file.user_email, + file.drive_file, + ), + ) + for file in files_batch + ] results = cast( list[Document | ConnectorFailure | None], run_functions_tuples_in_parallel(func_with_args, max_workers=8), @@ -967,7 +983,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo ) continue - files_batch.append(retrieved_file.drive_file) + files_batch.append(retrieved_file) if len(files_batch) < self.batch_size: continue diff --git a/backend/onyx/connectors/google_drive/doc_conversion.py b/backend/onyx/connectors/google_drive/doc_conversion.py index cbb9d1cb2..c4d015d0a 100644 --- a/backend/onyx/connectors/google_drive/doc_conversion.py +++ b/backend/onyx/connectors/google_drive/doc_conversion.py @@ -87,35 +87,17 @@ def _download_and_extract_sections_basic( mime_type = file["mimeType"] link = file.get("webViewLink", "") - try: - # skip images if not explicitly enabled - if not allow_images and is_gdrive_image_mime_type(mime_type): - return [] + # skip images if not explicitly enabled + if not allow_images and is_gdrive_image_mime_type(mime_type): + return [] - # For Google Docs, Sheets, and Slides, export as plain text - if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT: - export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type] - # Use the correct API call for exporting files - request = service.files().export_media( - fileId=file_id, mimeType=export_mime_type - ) - response_bytes = io.BytesIO() - downloader = MediaIoBaseDownload(response_bytes, request) - done = False - while not done: - _, done = downloader.next_chunk() - - response = response_bytes.getvalue() - if not response: - logger.warning(f"Failed to export {file_name} as {export_mime_type}") - return [] - - text = response.decode("utf-8") - return [TextSection(link=link, text=text)] - - # For other file types, download the file - # Use the correct API call for downloading files - request = service.files().get_media(fileId=file_id) + # For Google Docs, Sheets, and Slides, export as plain text + if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT: + export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type] + # Use the correct API call for exporting files + request = service.files().export_media( + fileId=file_id, mimeType=export_mime_type + ) response_bytes = io.BytesIO() downloader = MediaIoBaseDownload(response_bytes, request) done = False @@ -124,88 +106,100 @@ def _download_and_extract_sections_basic( response = response_bytes.getvalue() if not response: - logger.warning(f"Failed to download {file_name}") + logger.warning(f"Failed to export {file_name} as {export_mime_type}") return [] - # Process based on mime type - if mime_type == "text/plain": - text = response.decode("utf-8") - return [TextSection(link=link, text=text)] + text = response.decode("utf-8") + return [TextSection(link=link, text=text)] - elif ( - mime_type - == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ): - text, _ = docx_to_text_and_images(io.BytesIO(response)) - return [TextSection(link=link, text=text)] + # For other file types, download the file + # Use the correct API call for downloading files + request = service.files().get_media(fileId=file_id) + response_bytes = io.BytesIO() + downloader = MediaIoBaseDownload(response_bytes, request) + done = False + while not done: + _, done = downloader.next_chunk() - elif ( - mime_type - == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" - ): - text = xlsx_to_text(io.BytesIO(response)) - return [TextSection(link=link, text=text)] + response = response_bytes.getvalue() + if not response: + logger.warning(f"Failed to download {file_name}") + return [] - elif ( - mime_type - == "application/vnd.openxmlformats-officedocument.presentationml.presentation" - ): - text = pptx_to_text(io.BytesIO(response)) - return [TextSection(link=link, text=text)] + # Process based on mime type + if mime_type == "text/plain": + text = response.decode("utf-8") + return [TextSection(link=link, text=text)] - elif is_gdrive_image_mime_type(mime_type): - # For images, store them for later processing - sections: list[TextSection | ImageSection] = [] - try: - with get_session_with_current_tenant() as db_session: + elif ( + mime_type + == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ): + text, _ = docx_to_text_and_images(io.BytesIO(response)) + return [TextSection(link=link, text=text)] + + elif ( + mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ): + text = xlsx_to_text(io.BytesIO(response)) + return [TextSection(link=link, text=text)] + + elif ( + mime_type + == "application/vnd.openxmlformats-officedocument.presentationml.presentation" + ): + text = pptx_to_text(io.BytesIO(response)) + return [TextSection(link=link, text=text)] + + elif is_gdrive_image_mime_type(mime_type): + # For images, store them for later processing + sections: list[TextSection | ImageSection] = [] + try: + with get_session_with_current_tenant() as db_session: + section, embedded_id = store_image_and_create_section( + db_session=db_session, + image_data=response, + file_name=file_id, + display_name=file_name, + media_type=mime_type, + file_origin=FileOrigin.CONNECTOR, + link=link, + ) + sections.append(section) + except Exception as e: + logger.error(f"Failed to process image {file_name}: {e}") + return sections + + elif mime_type == "application/pdf": + text, _pdf_meta, images = read_pdf_file(io.BytesIO(response)) + pdf_sections: list[TextSection | ImageSection] = [ + TextSection(link=link, text=text) + ] + + # Process embedded images in the PDF + try: + with get_session_with_current_tenant() as db_session: + for idx, (img_data, img_name) in enumerate(images): section, embedded_id = store_image_and_create_section( db_session=db_session, - image_data=response, - file_name=file_id, - display_name=file_name, - media_type=mime_type, + image_data=img_data, + file_name=f"{file_id}_img_{idx}", + display_name=img_name or f"{file_name} - image {idx}", file_origin=FileOrigin.CONNECTOR, - link=link, ) - sections.append(section) - except Exception as e: - logger.error(f"Failed to process image {file_name}: {e}") - return sections + pdf_sections.append(section) + except Exception as e: + logger.error(f"Failed to process PDF images in {file_name}: {e}") + return pdf_sections - elif mime_type == "application/pdf": - text, _pdf_meta, images = read_pdf_file(io.BytesIO(response)) - pdf_sections: list[TextSection | ImageSection] = [ - TextSection(link=link, text=text) - ] - - # Process embedded images in the PDF - try: - with get_session_with_current_tenant() as db_session: - for idx, (img_data, img_name) in enumerate(images): - section, embedded_id = store_image_and_create_section( - db_session=db_session, - image_data=img_data, - file_name=f"{file_id}_img_{idx}", - display_name=img_name or f"{file_name} - image {idx}", - file_origin=FileOrigin.CONNECTOR, - ) - pdf_sections.append(section) - except Exception as e: - logger.error(f"Failed to process PDF images in {file_name}: {e}") - return pdf_sections - - else: - # For unsupported file types, try to extract text - try: - text = extract_file_text(io.BytesIO(response), file_name) - return [TextSection(link=link, text=text)] - except Exception as e: - logger.warning(f"Failed to extract text from {file_name}: {e}") - return [] - - except Exception as e: - logger.error(f"Error processing file {file_name}: {e}") - return [] + else: + # For unsupported file types, try to extract text + try: + text = extract_file_text(io.BytesIO(response), file_name) + return [TextSection(link=link, text=text)] + except Exception as e: + logger.warning(f"Failed to extract text from {file_name}: {e}") + return [] def convert_drive_item_to_document( diff --git a/backend/onyx/connectors/google_drive/file_retrieval.py b/backend/onyx/connectors/google_drive/file_retrieval.py index 1c094db52..2b432d0ae 100644 --- a/backend/onyx/connectors/google_drive/file_retrieval.py +++ b/backend/onyx/connectors/google_drive/file_retrieval.py @@ -214,10 +214,11 @@ def get_files_in_shared_drive( yield file -def get_all_files_in_my_drive( +def get_all_files_in_my_drive_and_shared( service: GoogleDriveService, update_traversed_ids_func: Callable, is_slim: bool, + include_shared_with_me: bool, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: @@ -229,7 +230,8 @@ def get_all_files_in_my_drive( # Get all folders being queried and add them to the traversed set folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'" folder_query += " and trashed = false" - folder_query += " and 'me' in owners" + if not include_shared_with_me: + folder_query += " and 'me' in owners" found_folders = False for file in execute_paginated_retrieval( retrieval_function=service.files().list, @@ -246,7 +248,8 @@ def get_all_files_in_my_drive( # Then get the files file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'" file_query += " and trashed = false" - file_query += " and 'me' in owners" + if not include_shared_with_me: + file_query += " and 'me' in owners" file_query += _generate_time_range_filter(start, end) yield from execute_paginated_retrieval( retrieval_function=service.files().list, diff --git a/backend/tests/daily/connectors/google_drive/consts_and_utils.py b/backend/tests/daily/connectors/google_drive/consts_and_utils.py index c50b7a0e4..f08a58787 100644 --- a/backend/tests/daily/connectors/google_drive/consts_and_utils.py +++ b/backend/tests/daily/connectors/google_drive/consts_and_utils.py @@ -58,6 +58,16 @@ SECTIONS_FOLDER_URL = ( "https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33" ) +EXTERNAL_SHARED_FOLDER_URL = ( + "https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS" +) +EXTERNAL_SHARED_DOCS_IN_FOLDER = [ + "https://docs.google.com/document/d/1Sywmv1-H6ENk2GcgieKou3kQHR_0te1mhIUcq8XlcdY" +] +EXTERNAL_SHARED_DOC_SINGLETON = ( + "https://docs.google.com/document/d/11kmisDfdvNcw5LYZbkdPVjTOdj-Uc5ma6Jep68xzeeA" +) + SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA" ADMIN_EMAIL = "admin@onyx-test.com" diff --git a/backend/tests/daily/connectors/google_drive/test_service_acct.py b/backend/tests/daily/connectors/google_drive/test_service_acct.py index 66a9f2781..68c055dee 100644 --- a/backend/tests/daily/connectors/google_drive/test_service_acct.py +++ b/backend/tests/daily/connectors/google_drive/test_service_acct.py @@ -1,6 +1,7 @@ from collections.abc import Callable from unittest.mock import MagicMock from unittest.mock import patch +from urllib.parse import urlparse from onyx.connectors.google_drive.connector import GoogleDriveConnector from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL @@ -9,6 +10,15 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_ from tests.daily.connectors.google_drive.consts_and_utils import ( assert_expected_docs_in_retrieved_docs, ) +from tests.daily.connectors.google_drive.consts_and_utils import ( + EXTERNAL_SHARED_DOC_SINGLETON, +) +from tests.daily.connectors.google_drive.consts_and_utils import ( + EXTERNAL_SHARED_DOCS_IN_FOLDER, +) +from tests.daily.connectors.google_drive.consts_and_utils import ( + EXTERNAL_SHARED_FOLDER_URL, +) from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS @@ -100,7 +110,8 @@ def test_include_shared_drives_only_with_size_threshold( retrieved_docs = load_all_docs(connector) - assert len(retrieved_docs) == 50 + # 2 extra files from shared drive owned by non-admin and not shared with admin + assert len(retrieved_docs) == 52 @patch( @@ -137,7 +148,8 @@ def test_include_shared_drives_only( + SECTIONS_FILE_IDS ) - assert len(retrieved_docs) == 51 + # 2 extra files from shared drive owned by non-admin and not shared with admin + assert len(retrieved_docs) == 53 assert_expected_docs_in_retrieved_docs( retrieved_docs=retrieved_docs, @@ -294,6 +306,64 @@ def test_folders_only( ) +def test_shared_folder_owned_by_external_user( + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_shared_folder_owned_by_external_user") + connector = google_drive_service_acct_connector_factory( + primary_admin_email=ADMIN_EMAIL, + include_shared_drives=False, + include_my_drives=False, + include_files_shared_with_me=False, + shared_drive_urls=None, + shared_folder_urls=EXTERNAL_SHARED_FOLDER_URL, + my_drive_emails=None, + ) + retrieved_docs = load_all_docs(connector) + + expected_docs = EXTERNAL_SHARED_DOCS_IN_FOLDER + + assert len(retrieved_docs) == len(expected_docs) # 1 for now + assert expected_docs[0] in retrieved_docs[0].id + + +def test_shared_with_me( + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_shared_with_me") + connector = google_drive_service_acct_connector_factory( + primary_admin_email=ADMIN_EMAIL, + include_shared_drives=False, + include_my_drives=True, + include_files_shared_with_me=True, + shared_drive_urls=None, + shared_folder_urls=None, + my_drive_emails=None, + ) + retrieved_docs = load_all_docs(connector) + + print(retrieved_docs) + + expected_file_ids = ( + ADMIN_FILE_IDS + + ADMIN_FOLDER_3_FILE_IDS + + TEST_USER_1_FILE_IDS + + TEST_USER_2_FILE_IDS + + TEST_USER_3_FILE_IDS + ) + assert_expected_docs_in_retrieved_docs( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + retrieved_ids = {urlparse(doc.id).path.split("/")[-2] for doc in retrieved_docs} + for id in retrieved_ids: + print(id) + + assert EXTERNAL_SHARED_DOC_SINGLETON.split("/")[-1] in retrieved_ids + assert EXTERNAL_SHARED_DOCS_IN_FOLDER[0].split("/")[-1] in retrieved_ids + + @patch( "onyx.file_processing.extract_file_text.get_unstructured_api_key", return_value=None, diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index 4524a32df..ae2661e6a 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -281,7 +281,7 @@ export default function AddConnector({ return ( [ field.name, diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index f644bc6a4..547d9669c 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -1292,7 +1292,8 @@ For example, specifying .*-support.* as a "channel" will cause the connector to }, }; export function createConnectorInitialValues( - connector: ConfigurableSources + connector: ConfigurableSources, + currentCredential: Credential | null = null ): Record & AccessTypeGroupSelectorFormType { const configuration = connectorConfigs[connector]; @@ -1307,7 +1308,16 @@ export function createConnectorInitialValues( } else if (field.type === "list") { acc[field.name] = field.default || []; } else if (field.type === "checkbox") { - acc[field.name] = field.default || false; + // Special case for include_files_shared_with_me when using service account + if ( + field.name === "include_files_shared_with_me" && + currentCredential && + !currentCredential.credential_json?.google_tokens + ) { + acc[field.name] = true; + } else { + acc[field.name] = field.default || false; + } } else if (field.default !== undefined) { acc[field.name] = field.default; }