Google Drive Improvements (#3057)

* Google Drive Improvements

* mypy

* should work!

* variable cleanup

* final fixes
hagen-danswer 2024-11-06 18:07:35 -08:00 committed by GitHub
parent 07a1b49b4f
commit 2758ffd9d5
22 changed files with 798 additions and 521 deletions


@@ -277,16 +277,16 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
             fields=THREAD_LIST_FIELDS,
             q=query,
         ):
-            full_thread = add_retries(
-                lambda: gmail_service.users()
-                .threads()
-                .get(
-                    userId=user_email,
-                    id=thread["id"],
-                    fields=THREAD_FIELDS,
-                )
-                .execute()
-            )()
+            full_threads = execute_paginated_retrieval(
+                retrieval_function=gmail_service.users().threads().get,
+                list_key=None,
+                userId=user_email,
+                fields=THREAD_FIELDS,
+                id=thread["id"],
+            )
+            # full_threads is an iterator containing a single thread
+            # so we need to convert it to a list and grab the first element
+            full_thread = list(full_threads)[0]
             doc = thread_to_document(full_thread)
             if doc is None:
                 continue


@@ -1,4 +1,8 @@
+from collections.abc import Callable
 from collections.abc import Iterator
+from concurrent.futures import as_completed
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from typing import Any

 from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
@@ -6,11 +10,12 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentials
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.connectors.google_drive.doc_conversion import build_slim_document
 from danswer.connectors.google_drive.doc_conversion import (
     convert_drive_item_to_document,
 )
 from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
-from danswer.connectors.google_drive.file_retrieval import get_files_in_my_drive
+from danswer.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
 from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
 from danswer.connectors.google_drive.models import GoogleDriveFileType
 from danswer.connectors.google_utils.google_auth import get_google_creds
@@ -32,10 +37,11 @@ from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.interfaces import SecondsSinceUnixEpoch
 from danswer.connectors.interfaces import SlimConnector
-from danswer.connectors.models import SlimDocument
 from danswer.utils.logger import setup_logger

 logger = setup_logger()

+# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
+# All file retrievals could be batched and made at once

 def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
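The TODO above points at the google-api-python-client batch utility. A minimal sketch of what batched file retrieval could look like, assuming a built drive_service; the callback signature (request_id, response, exception) comes from that batch API, and the file IDs are hypothetical:

def _on_file(request_id, response, exception):
    # Called once per sub-request when the batch executes.
    if exception is not None:
        print(f"request {request_id} failed: {exception}")
    else:
        print(response["id"], response["name"])

batch = drive_service.new_batch_http_request(callback=_on_file)
for file_id in ["id-1", "id-2"]:  # hypothetical IDs
    batch.add(drive_service.files().get(fileId=file_id, fields="id, name"))
batch.execute()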
@@ -48,6 +54,34 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
     return [url.split("/")[-1] for url in urls]


+def _convert_single_file(
+    creds: Any, primary_admin_email: str, file: dict[str, Any]
+) -> Any:
+    user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
+    user_drive_service = get_drive_service(creds, user_email=user_email)
+    docs_service = get_google_docs_service(creds, user_email=user_email)
+    return convert_drive_item_to_document(
+        file=file,
+        drive_service=user_drive_service,
+        docs_service=docs_service,
+    )
+
+
+def _process_files_batch(
+    files: list[GoogleDriveFileType], convert_func: Callable, batch_size: int
+) -> GenerateDocumentsOutput:
+    doc_batch = []
+    with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor:
+        for doc in executor.map(convert_func, files):
+            if doc:
+                doc_batch.append(doc)
+                if len(doc_batch) >= batch_size:
+                    yield doc_batch
+                    doc_batch = []
+
+    if doc_batch:
+        yield doc_batch
+
+
 class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
     def __init__(
         self,
@ -97,19 +131,23 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
self.include_shared_drives = include_shared_drives self.include_shared_drives = include_shared_drives
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls) shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
self.shared_drive_ids = _extract_ids_from_urls(shared_drive_url_list) self._requested_shared_drive_ids = set(
_extract_ids_from_urls(shared_drive_url_list)
)
self.include_my_drives = include_my_drives self.include_my_drives = include_my_drives
self.my_drive_emails = _extract_str_list_from_comma_str(my_drive_emails) self._requested_my_drive_emails = set(
_extract_str_list_from_comma_str(my_drive_emails)
)
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls) shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list) self._requested_folder_ids = set(_extract_ids_from_urls(shared_folder_url_list))
self._primary_admin_email: str | None = None self._primary_admin_email: str | None = None
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._TRAVERSED_PARENT_IDS: set[str] = set() self._retrieved_ids: set[str] = set()
@property @property
def primary_admin_email(self) -> str: def primary_admin_email(self) -> str:
@@ -141,9 +179,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
         )
         return self._creds

-    def _update_traversed_parent_ids(self, folder_id: str) -> None:
-        self._TRAVERSED_PARENT_IDS.add(folder_id)
-
     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
         primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
         self._primary_admin_email = primary_admin_email
@@ -154,125 +189,167 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
         )
         return new_creds_dict

-    def _get_all_user_emails(self) -> list[str]:
+    def _update_traversed_parent_ids(self, folder_id: str) -> None:
+        self._retrieved_ids.add(folder_id)
+
+    def _get_all_user_emails(self, admins_only: bool) -> list[str]:
         admin_service = get_admin_service(
             creds=self.creds,
             user_email=self.primary_admin_email,
         )
+        query = "isAdmin=true" if admins_only else "isAdmin=false"
         emails = []
         for user in execute_paginated_retrieval(
             retrieval_function=admin_service.users().list,
             list_key="users",
             fields=USER_FIELDS,
             domain=self.google_domain,
+            query=query,
         ):
             if email := user.get("primaryEmail"):
                 emails.append(email)
         return emails

+    def _get_all_drive_ids(self) -> set[str]:
+        primary_drive_service = get_drive_service(
+            creds=self.creds,
+            user_email=self.primary_admin_email,
+        )
+        all_drive_ids = set()
+        for drive in execute_paginated_retrieval(
+            retrieval_function=primary_drive_service.drives().list,
+            list_key="drives",
+            useDomainAdminAccess=True,
+            fields="drives(id)",
+        ):
+            all_drive_ids.add(drive["id"])
+        return all_drive_ids
+
+    def _initialize_all_class_variables(self) -> None:
+        # Get all user emails
+        # Get admins first because they are more likely to have access to the most files
+        user_emails = [self.primary_admin_email]
+        for admins_only in [True, False]:
+            for email in self._get_all_user_emails(admins_only=admins_only):
+                if email not in user_emails:
+                    user_emails.append(email)
+        self._all_org_emails = user_emails
+
+        self._all_drive_ids: set[str] = self._get_all_drive_ids()
+        # remove drive ids from the folder ids because they are queried differently
+        self._requested_folder_ids -= self._all_drive_ids
+
+        # Remove drive_ids that are not in the all_drive_ids and check them as folders instead
+        invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
+        if invalid_drive_ids:
+            logger.warning(
+                f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
+            )
+            logger.warning("Checking for folder access instead...")
+            self._requested_folder_ids.update(invalid_drive_ids)
+
+        if not self.include_shared_drives:
+            self._requested_shared_drive_ids = set()
+        elif not self._requested_shared_drive_ids:
+            self._requested_shared_drive_ids = self._all_drive_ids
+
+    def _impersonate_user_for_retrieval(
+        self,
+        user_email: str,
+        is_slim: bool,
+        start: SecondsSinceUnixEpoch | None = None,
+        end: SecondsSinceUnixEpoch | None = None,
+    ) -> Iterator[GoogleDriveFileType]:
+        drive_service = get_drive_service(self.creds, user_email)
+
+        if self.include_my_drives and (
+            not self._requested_my_drive_emails
+            or user_email in self._requested_my_drive_emails
+        ):
+            yield from get_all_files_in_my_drive(
+                service=drive_service,
+                update_traversed_ids_func=self._update_traversed_parent_ids,
+                is_slim=is_slim,
+                start=start,
+                end=end,
+            )
+
+        remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
+        for drive_id in remaining_drive_ids:
+            yield from get_files_in_shared_drive(
+                service=drive_service,
+                drive_id=drive_id,
+                is_slim=is_slim,
+                update_traversed_ids_func=self._update_traversed_parent_ids,
+                start=start,
+                end=end,
+            )
+
+        remaining_folders = self._requested_folder_ids - self._retrieved_ids
+        for folder_id in remaining_folders:
+            yield from crawl_folders_for_files(
+                service=drive_service,
+                parent_id=folder_id,
+                traversed_parent_ids=self._retrieved_ids,
+                update_traversed_ids_func=self._update_traversed_parent_ids,
+                start=start,
+                end=end,
+            )
+
     def _fetch_drive_items(
         self,
         is_slim: bool,
         start: SecondsSinceUnixEpoch | None = None,
         end: SecondsSinceUnixEpoch | None = None,
     ) -> Iterator[GoogleDriveFileType]:
-        primary_drive_service = get_drive_service(
-            creds=self.creds,
-            user_email=self.primary_admin_email,
-        )
-
-        if self.include_shared_drives:
-            shared_drive_urls = self.shared_drive_ids
-            if not shared_drive_urls:
-                # if no parent ids are specified, get all shared drives using the admin account
-                for drive in execute_paginated_retrieval(
-                    retrieval_function=primary_drive_service.drives().list,
-                    list_key="drives",
-                    useDomainAdminAccess=True,
-                    fields="drives(id)",
-                ):
-                    shared_drive_urls.append(drive["id"])
-
-            # For each shared drive, retrieve all files
-            for shared_drive_id in shared_drive_urls:
-                for file in get_files_in_shared_drive(
-                    service=primary_drive_service,
-                    drive_id=shared_drive_id,
-                    is_slim=is_slim,
-                    cache_folders=bool(self.shared_folder_ids),
-                    update_traversed_ids_func=self._update_traversed_parent_ids,
-                    start=start,
-                    end=end,
-                ):
-                    yield file
-
-        if self.shared_folder_ids:
-            # Crawl all the shared parent ids for files
-            for folder_id in self.shared_folder_ids:
-                yield from crawl_folders_for_files(
-                    service=primary_drive_service,
-                    parent_id=folder_id,
-                    personal_drive=False,
-                    traversed_parent_ids=self._TRAVERSED_PARENT_IDS,
-                    update_traversed_ids_func=self._update_traversed_parent_ids,
-                    start=start,
-                    end=end,
-                )
-
-        all_user_emails = []
-        # get all personal docs from each users' personal drive
-        if self.include_my_drives:
-            if isinstance(self.creds, ServiceAccountCredentials):
-                all_user_emails = self.my_drive_emails or []
-
-                # If using service account and no emails specified, fetch all users
-                if not all_user_emails:
-                    all_user_emails = self._get_all_user_emails()
-            elif self.primary_admin_email:
-                # If using OAuth, only fetch the primary admin email
-                all_user_emails = [self.primary_admin_email]
-
-        for email in all_user_emails:
-            logger.info(f"Fetching personal files for user: {email}")
-            user_drive_service = get_drive_service(self.creds, user_email=email)
-            yield from get_files_in_my_drive(
-                service=user_drive_service,
-                email=email,
-                is_slim=is_slim,
-                start=start,
-                end=end,
-            )
+        self._initialize_all_class_variables()
+
+        # Process users in parallel using ThreadPoolExecutor
+        with ThreadPoolExecutor(max_workers=10) as executor:
+            future_to_email = {
+                executor.submit(
+                    self._impersonate_user_for_retrieval, email, is_slim, start, end
+                ): email
+                for email in self._all_org_emails
+            }
+
+            # Yield results as they complete
+            for future in as_completed(future_to_email):
+                yield from future.result()
+
+        remaining_folders = self._requested_folder_ids - self._retrieved_ids
+        if remaining_folders:
+            logger.warning(
+                f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
+            )

     def _extract_docs_from_google_drive(
         self,
         start: SecondsSinceUnixEpoch | None = None,
         end: SecondsSinceUnixEpoch | None = None,
     ) -> GenerateDocumentsOutput:
-        doc_batch = []
-        for file in self._fetch_drive_items(
-            is_slim=False,
-            start=start,
-            end=end,
-        ):
-            user_email = (
-                file.get("owners", [{}])[0].get("emailAddress")
-                or self.primary_admin_email
-            )
-            user_drive_service = get_drive_service(self.creds, user_email=user_email)
-            docs_service = get_google_docs_service(self.creds, user_email=user_email)
-            if doc := convert_drive_item_to_document(
-                file=file,
-                drive_service=user_drive_service,
-                docs_service=docs_service,
-            ):
-                doc_batch.append(doc)
-            if len(doc_batch) >= self.batch_size:
-                yield doc_batch
-                doc_batch = []
-
-        yield doc_batch
+        # Create a larger process pool for file conversion
+        convert_func = partial(
+            _convert_single_file, self.creds, self.primary_admin_email
+        )
+
+        # Process files in larger batches
+        LARGE_BATCH_SIZE = self.batch_size * 4
+        files_to_process = []
+        # Gather the files into batches to be processed in parallel
+        for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
+            files_to_process.append(file)
+            if len(files_to_process) >= LARGE_BATCH_SIZE:
+                yield from _process_files_batch(
+                    files_to_process, convert_func, self.batch_size
+                )
+                files_to_process = []
+
+        # Process any remaining files
+        if files_to_process:
+            yield from _process_files_batch(
+                files_to_process, convert_func, self.batch_size
+            )

     def load_from_state(self) -> GenerateDocumentsOutput:
         try:
@@ -303,18 +380,8 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
             start=start,
             end=end,
         ):
-            slim_batch.append(
-                SlimDocument(
-                    id=file["webViewLink"],
-                    perm_sync_data={
-                        "doc_id": file.get("id"),
-                        "permissions": file.get("permissions", []),
-                        "permission_ids": file.get("permissionIds", []),
-                        "name": file.get("name"),
-                        "owner_email": file.get("owners", [{}])[0].get("emailAddress"),
-                    },
-                )
-            )
+            if doc := build_slim_document(file):
+                slim_batch.append(doc)
             if len(slim_batch) >= SLIM_BATCH_SIZE:
                 yield slim_batch
                 slim_batch = []
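The new _fetch_drive_items fans retrieval out across all org emails with a thread pool and yields each user's files as their future completes. One caveat worth noting: _impersonate_user_for_retrieval is a generator function, so executor.submit returns a future whose result is a generator that only runs when consumed; the API calls therefore happen while iterating future.result(), not on the pool threads. A minimal sketch of the pattern with a worker that materializes its results eagerly (all names hypothetical):

from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed

def _fetch_for_user(email: str) -> list[str]:
    # Hypothetical worker; returning a list forces the fetch to run
    # on the pool thread instead of lazily in the consumer.
    return [f"{email}/file-{i}" for i in range(3)]

def fan_out(emails: list[str]) -> Iterator[str]:
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = {executor.submit(_fetch_for_user, e): e for e in emails}
        for future in as_completed(futures):
            yield from future.result()

for path in fan_out(["admin@example.com", "user@example.com"]):
    print(path)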


@@ -7,6 +7,7 @@ from googleapiclient.errors import HttpError  # type: ignore
 from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import IGNORE_FOR_QA
+from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
 from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
 from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
 from danswer.connectors.google_drive.models import GDriveMimeType
@@ -16,6 +17,7 @@ from danswer.connectors.google_utils.resources import GoogleDocsService
 from danswer.connectors.google_utils.resources import GoogleDriveService
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
+from danswer.connectors.models import SlimDocument
 from danswer.file_processing.extract_file_text import docx_to_text
 from danswer.file_processing.extract_file_text import pptx_to_text
 from danswer.file_processing.extract_file_text import read_pdf_file
@@ -25,6 +27,7 @@ from danswer.utils.logger import setup_logger

 logger = setup_logger()

+
 # these errors don't represent a failure in the connector, but simply files
 # that can't / shouldn't be indexed
 ERRORS_TO_CONTINUE_ON = [
@@ -120,6 +123,10 @@ def convert_drive_item_to_document(
         if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
             logger.info("Ignoring Drive Shortcut Filetype")
             return None
+        # Skip files that are folders
+        if file.get("mimeType") == DRIVE_FOLDER_TYPE:
+            logger.info("Ignoring Drive Folder Filetype")
+            return None

         sections: list[Section] = []
@@ -133,7 +140,6 @@ def convert_drive_item_to_document(
                 f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
                 " Falling back to basic extraction."
             )
-
     # NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
     if not sections:
         try:
@@ -150,7 +156,6 @@ def convert_drive_item_to_document(
                     return None
                 raise
-
     if not sections:
         return None
@@ -173,3 +178,20 @@ def convert_drive_item_to_document(
         logger.exception("Ran into exception when pulling a file from Google Drive")

     return None
+
+
+def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
+    # Skip files that are folders or shortcuts
+    if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
+        return None
+
+    return SlimDocument(
+        id=file["webViewLink"],
+        perm_sync_data={
+            "doc_id": file.get("id"),
+            "permissions": file.get("permissions", []),
+            "permission_ids": file.get("permissionIds", []),
+            "name": file.get("name"),
+            "owner_email": file.get("owners", [{}])[0].get("emailAddress"),
+        },
+    )


@@ -1,6 +1,7 @@
 from collections.abc import Callable
 from collections.abc import Iterator
 from datetime import datetime
+from typing import Any

 from googleapiclient.discovery import Resource  # type: ignore
@@ -41,7 +42,6 @@ def _generate_time_range_filter(
 def _get_folders_in_parent(
     service: Resource,
     parent_id: str | None = None,
-    personal_drive: bool = False,
 ) -> Iterator[GoogleDriveFileType]:
     # Follow shortcuts to folders
     query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
@@ -53,9 +53,10 @@ def _get_folders_in_parent(
     for file in execute_paginated_retrieval(
         retrieval_function=service.files().list,
         list_key="files",
-        corpora="user" if personal_drive else "allDrives",
-        supportsAllDrives=not personal_drive,
-        includeItemsFromAllDrives=not personal_drive,
+        continue_on_404_or_403=True,
+        corpora="allDrives",
+        supportsAllDrives=True,
+        includeItemsFromAllDrives=True,
         fields=FOLDER_FIELDS,
         q=query,
     ):
@@ -65,7 +66,6 @@ def _get_folders_in_parent(
 def _get_files_in_parent(
     service: Resource,
     parent_id: str,
-    personal_drive: bool,
     start: SecondsSinceUnixEpoch | None = None,
     end: SecondsSinceUnixEpoch | None = None,
     is_slim: bool = False,
@@ -77,9 +77,10 @@ def _get_files_in_parent(
     for file in execute_paginated_retrieval(
         retrieval_function=service.files().list,
         list_key="files",
-        corpora="user" if personal_drive else "allDrives",
-        supportsAllDrives=not personal_drive,
-        includeItemsFromAllDrives=not personal_drive,
+        continue_on_404_or_403=True,
+        corpora="allDrives",
+        supportsAllDrives=True,
+        includeItemsFromAllDrives=True,
         fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
         q=query,
     ):
@@ -89,7 +90,6 @@ def _get_files_in_parent(
 def crawl_folders_for_files(
     service: Resource,
     parent_id: str,
-    personal_drive: bool,
     traversed_parent_ids: set[str],
     update_traversed_ids_func: Callable[[str], None],
     start: SecondsSinceUnixEpoch | None = None,
@@ -99,29 +99,30 @@ def crawl_folders_for_files(
     This function starts crawling from any folder. It is slower though.
     """
     if parent_id in traversed_parent_ids:
-        print(f"Skipping subfolder since already traversed: {parent_id}")
+        logger.info(f"Skipping subfolder since already traversed: {parent_id}")
         return

-    update_traversed_ids_func(parent_id)
-
-    yield from _get_files_in_parent(
+    found_files = False
+    for file in _get_files_in_parent(
         service=service,
-        personal_drive=personal_drive,
         start=start,
         end=end,
         parent_id=parent_id,
-    )
+    ):
+        found_files = True
+        yield file
+
+    if found_files:
+        update_traversed_ids_func(parent_id)

     for subfolder in _get_folders_in_parent(
         service=service,
         parent_id=parent_id,
-        personal_drive=personal_drive,
     ):
         logger.info("Fetching all files in subfolder: " + subfolder["name"])
         yield from crawl_folders_for_files(
             service=service,
             parent_id=subfolder["id"],
-            personal_drive=personal_drive,
             traversed_parent_ids=traversed_parent_ids,
             update_traversed_ids_func=update_traversed_ids_func,
             start=start,
@@ -133,55 +134,59 @@ def get_files_in_shared_drive(
     service: Resource,
     drive_id: str,
     is_slim: bool = False,
-    cache_folders: bool = True,
     update_traversed_ids_func: Callable[[str], None] = lambda _: None,
     start: SecondsSinceUnixEpoch | None = None,
     end: SecondsSinceUnixEpoch | None = None,
 ) -> Iterator[GoogleDriveFileType]:
     # If we know we are going to folder crawl later, we can cache the folders here
-    if cache_folders:
-        # Get all folders being queried and add them to the traversed set
-        query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
-        query += " and trashed = false"
-        for file in execute_paginated_retrieval(
-            retrieval_function=service.files().list,
-            list_key="files",
-            corpora="drive",
-            driveId=drive_id,
-            supportsAllDrives=True,
-            includeItemsFromAllDrives=True,
-            fields="nextPageToken, files(id)",
-            q=query,
-        ):
-            update_traversed_ids_func(file["id"])
+    # Get all folders being queried and add them to the traversed set
+    query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
+    query += " and trashed = false"
+    found_folders = False
+    for file in execute_paginated_retrieval(
+        retrieval_function=service.files().list,
+        list_key="files",
+        continue_on_404_or_403=True,
+        corpora="drive",
+        driveId=drive_id,
+        supportsAllDrives=True,
+        includeItemsFromAllDrives=True,
+        fields="nextPageToken, files(id)",
+        q=query,
+    ):
+        update_traversed_ids_func(file["id"])
+        found_folders = True
+    if found_folders:
+        update_traversed_ids_func(drive_id)

     # Get all files in the shared drive
     query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
     query += " and trashed = false"
     query += _generate_time_range_filter(start, end)
-    for file in execute_paginated_retrieval(
+    yield from execute_paginated_retrieval(
         retrieval_function=service.files().list,
         list_key="files",
+        continue_on_404_or_403=True,
         corpora="drive",
         driveId=drive_id,
         supportsAllDrives=True,
         includeItemsFromAllDrives=True,
         fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
         q=query,
-    ):
-        yield file
+    )


-def get_files_in_my_drive(
-    service: Resource,
-    email: str,
+def get_all_files_in_my_drive(
+    service: Any,
+    update_traversed_ids_func: Callable,
     is_slim: bool = False,
     start: SecondsSinceUnixEpoch | None = None,
     end: SecondsSinceUnixEpoch | None = None,
 ) -> Iterator[GoogleDriveFileType]:
-    query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{email}' in owners"
-    query += " and trashed = false"
-    query += _generate_time_range_filter(start, end)
+    # If we know we are going to folder crawl later, we can cache the folders here
+    # Get all folders being queried and add them to the traversed set
+    query = "trashed = false and 'me' in owners"
+    found_folders = False
     for file in execute_paginated_retrieval(
         retrieval_function=service.files().list,
         list_key="files",
@@ -189,7 +194,25 @@ def get_files_in_my_drive(
         fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
         q=query,
     ):
-        yield file
+        update_traversed_ids_func(file["id"])
+        found_folders = True
+    if found_folders:
+        update_traversed_ids_func(get_root_folder_id(service))
+
+    # Then get the files
+    query = "trashed = false and 'me' in owners"
+    query += _generate_time_range_filter(start, end)
+    fields = "files(id, name, mimeType, webViewLink, modifiedTime, createdTime)"
+    if not is_slim:
+        fields += ", files(permissions, permissionIds, owners)"
+    yield from execute_paginated_retrieval(
+        retrieval_function=service.files().list,
+        list_key="files",
+        corpora="user",
+        fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
+        q=query,
+    )


 # Just in case we need to get the root folder id
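The retrieval helpers above compose their filters in Drive's query language. A minimal sketch of the kind of query string involved, assuming the standard folder MIME type; the exact timestamp format produced by _generate_time_range_filter is not shown in this diff, so RFC 3339 is assumed:

from datetime import datetime, timezone

DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"  # standard Drive MIME type

def build_file_query(start: float | None = None, end: float | None = None) -> str:
    # Non-folder, non-trashed files, optionally bounded by modifiedTime.
    query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and trashed = false"
    if start is not None:
        start_iso = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
        query += f" and modifiedTime >= '{start_iso}'"
    if end is not None:
        end_iso = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
        query += f" and modifiedTime <= '{end_iso}'"
    return query

print(build_file_query(start=0))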


@@ -6,7 +6,6 @@ from urllib.parse import urlparse

 from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
 from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
-from googleapiclient.discovery import build  # type: ignore
 from sqlalchemy.orm import Session

 from danswer.configs.app_configs import WEB_DOMAIN
@@ -16,6 +15,8 @@ from danswer.configs.constants import KV_GMAIL_CRED_KEY
 from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
 from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
 from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
+from danswer.connectors.google_utils.resources import get_drive_service
+from danswer.connectors.google_utils.resources import get_gmail_service
 from danswer.connectors.google_utils.shared_constants import (
     DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
 )
@@ -45,8 +46,40 @@ from danswer.utils.logger import setup_logger

 logger = setup_logger()


-def _build_frontend_google_drive_redirect() -> str:
-    return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
+def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
+    if source == DocumentSource.GOOGLE_DRIVE:
+        return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
+    elif source == DocumentSource.GMAIL:
+        return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback"
+    else:
+        raise ValueError(f"Unsupported source: {source}")
+
+
+def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) -> str:
+    if source == DocumentSource.GOOGLE_DRIVE:
+        drive_service = get_drive_service(creds)
+        user_info = (
+            drive_service.about()
+            .get(
+                fields="user(emailAddress)",
+            )
+            .execute()
+        )
+        email = user_info.get("user", {}).get("emailAddress")
+    elif source == DocumentSource.GMAIL:
+        gmail_service = get_gmail_service(creds)
+        user_info = (
+            gmail_service.users()
+            .getProfile(
+                userId="me",
+                fields="emailAddress",
+            )
+            .execute()
+        )
+        email = user_info.get("emailAddress")
+    else:
+        raise ValueError(f"Unsupported source: {source}")
+    return email


 def verify_csrf(credential_id: int, state: str) -> None:
@@ -67,8 +100,8 @@ def update_credential_access_tokens(
     app_credentials = get_google_app_cred(source)
     flow = InstalledAppFlow.from_client_config(
         app_credentials.model_dump(),
-        scopes=GOOGLE_SCOPES,
-        redirect_uri=_build_frontend_google_drive_redirect(),
+        scopes=GOOGLE_SCOPES[source],
+        redirect_uri=_build_frontend_google_drive_redirect(source),
     )
     flow.fetch_token(code=auth_code)
     creds = flow.credentials
@@ -77,15 +110,7 @@ def update_credential_access_tokens(
     # Get user email from Google API so we know who
     # the primary admin is for this connector
     try:
-        admin_service = build("drive", "v3", credentials=creds)
-        user_info = (
-            admin_service.about()
-            .get(
-                fields="user(emailAddress)",
-            )
-            .execute()
-        )
-        email = user_info.get("user", {}).get("emailAddress")
+        email = _get_current_oauth_user(creds, source)
     except Exception as e:
         if MISSING_SCOPES_ERROR_STR in str(e):
             raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
@@ -120,13 +145,18 @@ def build_service_account_creds(
     )


-def get_auth_url(credential_id: int) -> str:
-    creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
+def get_auth_url(credential_id: int, source: DocumentSource) -> str:
+    if source == DocumentSource.GOOGLE_DRIVE:
+        creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
+    elif source == DocumentSource.GMAIL:
+        creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
+    else:
+        raise ValueError(f"Unsupported source: {source}")
     credential_json = json.loads(creds_str)
     flow = InstalledAppFlow.from_client_config(
         credential_json,
-        scopes=GOOGLE_SCOPES,
-        redirect_uri=_build_frontend_google_drive_redirect(),
+        scopes=GOOGLE_SCOPES[source],
+        redirect_uri=_build_frontend_google_drive_redirect(source),
     )
     auth_url, _ = flow.authorization_url(prompt="consent")
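The change from scopes=GOOGLE_SCOPES to scopes=GOOGLE_SCOPES[source] implies GOOGLE_SCOPES became a mapping keyed by DocumentSource. A sketch of the assumed shape; the actual scope lists are not shown in this diff:

from danswer.configs.constants import DocumentSource  # import path taken from the diff

# Assumed shape only; the real lists live elsewhere in the codebase.
GOOGLE_SCOPES: dict[DocumentSource, list[str]] = {
    DocumentSource.GOOGLE_DRIVE: [
        "https://www.googleapis.com/auth/drive.readonly",
    ],
    DocumentSource.GMAIL: [
        "https://www.googleapis.com/auth/gmail.readonly",
    ],
}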


@@ -23,7 +23,7 @@ add_retries = retry_builder(tries=50, max_delay=30)

 def _execute_with_retry(request: Any) -> Any:
     max_attempts = 10
-    attempt = 0
+    attempt = 1

     while attempt < max_attempts:
         # Note for reasons unknown, the Google API will sometimes return a 429
@@ -81,7 +81,8 @@ def _execute_with_retry(request: Any) -> Any:

 def execute_paginated_retrieval(
     retrieval_function: Callable,
-    list_key: str,
+    list_key: str | None = None,
+    continue_on_404_or_403: bool = False,
     **kwargs: Any,
 ) -> Iterator[GoogleDriveFileType]:
     """Execute a paginated retrieval from Google Drive API
@@ -95,8 +96,30 @@ def execute_paginated_retrieval(
         if next_page_token:
             request_kwargs["pageToken"] = next_page_token

-        results = add_retries(lambda: retrieval_function(**request_kwargs).execute())()
+        try:
+            results = retrieval_function(**request_kwargs).execute()
+        except HttpError as e:
+            if e.resp.status >= 500:
+                results = add_retries(
+                    lambda: retrieval_function(**request_kwargs).execute()
+                )()
+            elif e.resp.status == 404 or e.resp.status == 403:
+                if continue_on_404_or_403:
+                    logger.warning(f"Error executing request: {e}")
+                    results = {}
+                else:
+                    raise e
+            elif e.resp.status == 429:
+                results = _execute_with_retry(
+                    lambda: retrieval_function(**request_kwargs).execute()
+                )
+            else:
+                logger.exception("Error executing request:")
+                raise e

         next_page_token = results.get("nextPageToken")
-        for item in results.get(list_key, []):
-            yield item
+        if list_key:
+            for item in results.get(list_key, []):
+                yield item
+        else:
+            yield results


@@ -51,13 +51,13 @@ def get_drive_service(

 def get_admin_service(
     creds: ServiceAccountCredentials | OAuthCredentials,
-    user_email: str,
+    user_email: str | None = None,
 ) -> AdminService:
     return _get_google_service("admin", "directory_v1", creds, user_email)


 def get_gmail_service(
     creds: ServiceAccountCredentials | OAuthCredentials,
-    user_email: str,
+    user_email: str | None = None,
 ) -> GmailService:
     return _get_google_service("gmail", "v1", creds, user_email)


@@ -31,7 +31,7 @@ MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"

 # Documentation and error messages
 SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview"
 ONYX_SCOPE_INSTRUCTIONS = (
-    "You have upgraded Danswer without updating the Google Drive scopes. "
+    "You have upgraded Danswer without updating the Google Auth scopes. "
     f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
 )


@@ -10,7 +10,6 @@ from sqlalchemy.sql.expression import or_

 from danswer.auth.schemas import UserRole
 from danswer.configs.constants import DocumentSource
-from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
 from danswer.connectors.google_utils.shared_constants import (
     DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
 )
@@ -422,23 +421,15 @@ def cleanup_google_drive_credentials(db_session: Session) -> None:
     db_session.commit()


-def delete_gmail_service_account_credentials(
-    user: User | None, db_session: Session
+def delete_service_account_credentials(
+    user: User | None, db_session: Session, source: DocumentSource
 ) -> None:
     credentials = fetch_credentials(db_session=db_session, user=user)
     for credential in credentials:
-        if credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY):
-            db_session.delete(credential)
-
-    db_session.commit()
-
-
-def delete_google_drive_service_account_credentials(
-    user: User | None, db_session: Session
-) -> None:
-    credentials = fetch_credentials(db_session=db_session, user=user)
-    for credential in credentials:
-        if credential.credential_json.get(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY):
+        if (
+            credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
+            and credential.source == source
+        ):
             db_session.delete(credential)

     db_session.commit()


@@ -67,8 +67,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.credentials import cleanup_gmail_credentials
 from danswer.db.credentials import cleanup_google_drive_credentials
 from danswer.db.credentials import create_credential
-from danswer.db.credentials import delete_gmail_service_account_credentials
-from danswer.db.credentials import delete_google_drive_service_account_credentials
+from danswer.db.credentials import delete_service_account_credentials
 from danswer.db.credentials import fetch_credential_by_id
 from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
 from danswer.db.document import get_document_counts_for_cc_pairs
@@ -309,13 +308,13 @@ def upsert_service_account_credential(
     try:
         credential_base = build_service_account_creds(
             DocumentSource.GOOGLE_DRIVE,
-            primary_admin_email=service_account_credential_request.google_drive_primary_admin,
+            primary_admin_email=service_account_credential_request.google_primary_admin,
         )
     except KvKeyNotFoundError as e:
         raise HTTPException(status_code=400, detail=str(e))

     # first delete all existing service account credentials
-    delete_google_drive_service_account_credentials(user, db_session)
+    delete_service_account_credentials(user, db_session, DocumentSource.GOOGLE_DRIVE)

     # `user=None` since this credential is not a personal credential
     credential = create_credential(
         credential_data=credential_base, user=user, db_session=db_session
@@ -335,13 +334,13 @@ def upsert_gmail_service_account_credential(
     try:
         credential_base = build_service_account_creds(
             DocumentSource.GMAIL,
-            primary_admin_email=service_account_credential_request.gmail_primary_admin,
+            primary_admin_email=service_account_credential_request.google_primary_admin,
         )
     except KvKeyNotFoundError as e:
         raise HTTPException(status_code=400, detail=str(e))

     # first delete all existing service account credentials
-    delete_gmail_service_account_credentials(user, db_session)
+    delete_service_account_credentials(user, db_session, DocumentSource.GMAIL)

     # `user=None` since this credential is not a personal credential
     credential = create_credential(
         credential_data=credential_base, user=user, db_session=db_session
@@ -894,7 +893,7 @@ def gmail_auth(
         httponly=True,
         max_age=600,
     )
-    return AuthUrl(auth_url=get_auth_url(int(credential_id)))
+    return AuthUrl(auth_url=get_auth_url(int(credential_id), DocumentSource.GMAIL))


 @router.get("/connector/google-drive/authorize/{credential_id}")
@@ -908,7 +907,9 @@ def google_drive_auth(
         httponly=True,
         max_age=600,
     )
-    return AuthUrl(auth_url=get_auth_url(int(credential_id)))
+    return AuthUrl(
+        auth_url=get_auth_url(int(credential_id), DocumentSource.GOOGLE_DRIVE)
+    )


 @router.get("/connector/gmail/callback")
@@ -925,12 +926,10 @@ def gmail_callback(
     )
     credential_id = int(credential_id_cookie)
     verify_csrf(credential_id, callback.state)
-    if (
-        update_credential_access_tokens(
-            callback.code, credential_id, user, db_session, DocumentSource.GMAIL
-        )
-        is None
-    ):
+    credentials: Credentials | None = update_credential_access_tokens(
+        callback.code, credential_id, user, db_session, DocumentSource.GMAIL
+    )
+    if credentials is None:
         raise HTTPException(
             status_code=500, detail="Unable to fetch Gmail access tokens"
         )


@@ -4,7 +4,6 @@ from uuid import UUID

 from pydantic import BaseModel
 from pydantic import Field
-from pydantic import model_validator

 from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
 from danswer.configs.constants import DocumentSource
@@ -377,18 +376,7 @@ class GoogleServiceAccountKey(BaseModel):

 class GoogleServiceAccountCredentialRequest(BaseModel):
-    google_drive_primary_admin: str | None = None  # email of user to impersonate
-    gmail_primary_admin: str | None = None  # email of user to impersonate
-
-    @model_validator(mode="after")
-    def check_user_delegation(self) -> "GoogleServiceAccountCredentialRequest":
-        if (self.google_drive_primary_admin is None) == (
-            self.gmail_primary_admin is None
-        ):
-            raise ValueError(
-                "Exactly one of google_drive_primary_admin or gmail_primary_admin must be set"
-            )
-        return self
+    google_primary_admin: str | None = None  # email of user to impersonate


 class FileUploadResponse(BaseModel):


@@ -6,128 +6,99 @@ ALL_FILES = list(range(0, 60))
 SHARED_DRIVE_FILES = list(range(20, 25))

-_ADMIN_FILE_IDS = list(range(0, 5))
-_TEST_USER_1_FILE_IDS = list(range(5, 10))
-_TEST_USER_2_FILE_IDS = list(range(10, 15))
-_TEST_USER_3_FILE_IDS = list(range(15, 20))
-_SHARED_DRIVE_1_FILE_IDS = list(range(20, 25))
-_FOLDER_1_FILE_IDS = list(range(25, 30))
-_FOLDER_1_1_FILE_IDS = list(range(30, 35))
-_FOLDER_1_2_FILE_IDS = list(range(35, 40))
-_SHARED_DRIVE_2_FILE_IDS = list(range(40, 45))
-_FOLDER_2_FILE_IDS = list(range(45, 50))
-_FOLDER_2_1_FILE_IDS = list(range(50, 55))
-_FOLDER_2_2_FILE_IDS = list(range(55, 60))
-_SECTIONS_FILE_IDS = [61]
+ADMIN_FILE_IDS = list(range(0, 5))
+ADMIN_FOLDER_3_FILE_IDS = list(range(65, 70))
+TEST_USER_1_FILE_IDS = list(range(5, 10))
+TEST_USER_2_FILE_IDS = list(range(10, 15))
+TEST_USER_3_FILE_IDS = list(range(15, 20))
+SHARED_DRIVE_1_FILE_IDS = list(range(20, 25))
+FOLDER_1_FILE_IDS = list(range(25, 30))
+FOLDER_1_1_FILE_IDS = list(range(30, 35))
+FOLDER_1_2_FILE_IDS = list(range(35, 40))
+SHARED_DRIVE_2_FILE_IDS = list(range(40, 45))
+FOLDER_2_FILE_IDS = list(range(45, 50))
+FOLDER_2_1_FILE_IDS = list(range(50, 55))
+FOLDER_2_2_FILE_IDS = list(range(55, 60))
+SECTIONS_FILE_IDS = [61]

-_PUBLIC_FOLDER_RANGE = _FOLDER_1_2_FILE_IDS
-_PUBLIC_FILE_IDS = list(range(55, 57))
-PUBLIC_RANGE = _PUBLIC_FOLDER_RANGE + _PUBLIC_FILE_IDS
+PUBLIC_FOLDER_RANGE = FOLDER_1_2_FILE_IDS
+PUBLIC_FILE_IDS = list(range(55, 57))
+PUBLIC_RANGE = PUBLIC_FOLDER_RANGE + PUBLIC_FILE_IDS

-_SHARED_DRIVE_1_URL = "https://drive.google.com/drive/folders/0AC_OJ4BkMd4kUk9PVA"
+SHARED_DRIVE_1_URL = "https://drive.google.com/drive/folders/0AC_OJ4BkMd4kUk9PVA"
 # Group 1 is given access to this folder
-_FOLDER_1_URL = (
+FOLDER_1_URL = (
     "https://drive.google.com/drive/folders/1d3I7U3vUZMDziF1OQqYRkB8Jp2s_GWUn"
 )
-_FOLDER_1_1_URL = (
+FOLDER_1_1_URL = (
     "https://drive.google.com/drive/folders/1aR33-zwzl_mnRAwH55GgtWTE-4A4yWWI"
 )
-_FOLDER_1_2_URL = (
+FOLDER_1_2_URL = (
     "https://drive.google.com/drive/folders/1IO0X55VhvLXf4mdxzHxuKf4wxrDBB6jq"
 )
-_SHARED_DRIVE_2_URL = "https://drive.google.com/drive/folders/0ABKspIh7P4f4Uk9PVA"
-_FOLDER_2_URL = (
+SHARED_DRIVE_2_URL = "https://drive.google.com/drive/folders/0ABKspIh7P4f4Uk9PVA"
+FOLDER_2_URL = (
     "https://drive.google.com/drive/folders/1lNpCJ1teu8Se0louwL0oOHK9nEalskof"
 )
-_FOLDER_2_1_URL = (
+FOLDER_2_1_URL = (
     "https://drive.google.com/drive/folders/1XeDOMWwxTDiVr9Ig2gKum3Zq_Wivv6zY"
 )
-_FOLDER_2_2_URL = (
+FOLDER_2_2_URL = (
     "https://drive.google.com/drive/folders/1RKlsexA8h7NHvBAWRbU27MJotic7KXe3"
 )
+FOLDER_3_URL = (
+    "https://drive.google.com/drive/folders/1LHibIEXfpUmqZ-XjBea44SocA91Nkveu"
+)
+SECTIONS_FOLDER_URL = (
+    "https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
+)

-_ADMIN_EMAIL = "admin@onyx-test.com"
-_TEST_USER_1_EMAIL = "test_user_1@onyx-test.com"
-_TEST_USER_2_EMAIL = "test_user_2@onyx-test.com"
-_TEST_USER_3_EMAIL = "test_user_3@onyx-test.com"
+ADMIN_EMAIL = "admin@onyx-test.com"
+TEST_USER_1_EMAIL = "test_user_1@onyx-test.com"
+TEST_USER_2_EMAIL = "test_user_2@onyx-test.com"
+TEST_USER_3_EMAIL = "test_user_3@onyx-test.com"

-# Dictionary for ranges
-DRIVE_ID_MAPPING: dict[str, list[int]] = {
-    "ADMIN": _ADMIN_FILE_IDS,
-    "TEST_USER_1": _TEST_USER_1_FILE_IDS,
-    "TEST_USER_2": _TEST_USER_2_FILE_IDS,
-    "TEST_USER_3": _TEST_USER_3_FILE_IDS,
-    "SHARED_DRIVE_1": _SHARED_DRIVE_1_FILE_IDS,
-    "FOLDER_1": _FOLDER_1_FILE_IDS,
-    "FOLDER_1_1": _FOLDER_1_1_FILE_IDS,
-    "FOLDER_1_2": _FOLDER_1_2_FILE_IDS,
-    "SHARED_DRIVE_2": _SHARED_DRIVE_2_FILE_IDS,
-    "FOLDER_2": _FOLDER_2_FILE_IDS,
-    "FOLDER_2_1": _FOLDER_2_1_FILE_IDS,
-    "FOLDER_2_2": _FOLDER_2_2_FILE_IDS,
-    "SECTIONS": _SECTIONS_FILE_IDS,
-}
-
-# Dictionary for emails
-EMAIL_MAPPING: dict[str, str] = {
-    "ADMIN": _ADMIN_EMAIL,
-    "TEST_USER_1": _TEST_USER_1_EMAIL,
-    "TEST_USER_2": _TEST_USER_2_EMAIL,
-    "TEST_USER_3": _TEST_USER_3_EMAIL,
-}
-
-# Dictionary for URLs
-URL_MAPPING: dict[str, str] = {
-    "SHARED_DRIVE_1": _SHARED_DRIVE_1_URL,
-    "FOLDER_1": _FOLDER_1_URL,
-    "FOLDER_1_1": _FOLDER_1_1_URL,
-    "FOLDER_1_2": _FOLDER_1_2_URL,
-    "SHARED_DRIVE_2": _SHARED_DRIVE_2_URL,
-    "FOLDER_2": _FOLDER_2_URL,
-    "FOLDER_2_1": _FOLDER_2_1_URL,
-    "FOLDER_2_2": _FOLDER_2_2_URL,
-}
-
 # Dictionary for access permissions
 # All users have access to their own My Drive as well as public files
 ACCESS_MAPPING: dict[str, list[int]] = {
     # Admin has access to everything in shared
-    "ADMIN": (
-        _ADMIN_FILE_IDS
-        + _SHARED_DRIVE_1_FILE_IDS
-        + _FOLDER_1_FILE_IDS
-        + _FOLDER_1_1_FILE_IDS
-        + _FOLDER_1_2_FILE_IDS
-        + _SHARED_DRIVE_2_FILE_IDS
-        + _FOLDER_2_FILE_IDS
-        + _FOLDER_2_1_FILE_IDS
-        + _FOLDER_2_2_FILE_IDS
-        + _SECTIONS_FILE_IDS
+    ADMIN_EMAIL: (
+        ADMIN_FILE_IDS
+        + ADMIN_FOLDER_3_FILE_IDS
+        + SHARED_DRIVE_1_FILE_IDS
+        + FOLDER_1_FILE_IDS
+        + FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
+        + SHARED_DRIVE_2_FILE_IDS
+        + FOLDER_2_FILE_IDS
+        + FOLDER_2_1_FILE_IDS
+        + FOLDER_2_2_FILE_IDS
+        + SECTIONS_FILE_IDS
     ),
     # This user has access to drive 1
     # This user has redundant access to folder 1 because of group access
     # This user has been given individual access to files in Admin's My Drive
-    "TEST_USER_1": (
-        _TEST_USER_1_FILE_IDS
-        + _SHARED_DRIVE_1_FILE_IDS
-        + _FOLDER_1_FILE_IDS
-        + _FOLDER_1_1_FILE_IDS
-        + _FOLDER_1_2_FILE_IDS
+    TEST_USER_1_EMAIL: (
+        TEST_USER_1_FILE_IDS
+        + SHARED_DRIVE_1_FILE_IDS
+        + FOLDER_1_FILE_IDS
+        + FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
         + list(range(0, 2))
     ),
     # Group 1 includes this user, giving access to folder 1
     # This user has also been given access to folder 2-1
     # This user has also been given individual access to files in folder 2
-    "TEST_USER_2": (
-        _TEST_USER_2_FILE_IDS
-        + _FOLDER_1_FILE_IDS
-        + _FOLDER_1_1_FILE_IDS
-        + _FOLDER_1_2_FILE_IDS
-        + _FOLDER_2_1_FILE_IDS
+    TEST_USER_2_EMAIL: (
+        TEST_USER_2_FILE_IDS
+        + FOLDER_1_FILE_IDS
+        + FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
+        + FOLDER_2_1_FILE_IDS
         + list(range(45, 47))
     ),
     # This user can only see his own files and public files
-    "TEST_USER_3": _TEST_USER_3_FILE_IDS,
+    TEST_USER_3_EMAIL: TEST_USER_3_FILE_IDS,
 }

 SPECIAL_FILE_ID_TO_CONTENT_MAP: dict[int, str] = {


@ -5,12 +5,29 @@ from unittest.mock import patch
from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.models import Document from danswer.connectors.models import Document
from tests.daily.connectors.google_drive.helpers import ( from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_retrieved_docs_match_expected, assert_retrieved_docs_match_expected,
) )
from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.helpers import URL_MAPPING from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_3_EMAIL
@patch( @patch(
@ -32,16 +49,17 @@ def test_include_all(
# Should get everything in shared and admin's My Drive with oauth # Should get everything in shared and admin's My Drive with oauth
expected_file_ids = ( expected_file_ids = (
DRIVE_ID_MAPPING["ADMIN"] ADMIN_FILE_IDS
+ DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + ADMIN_FOLDER_3_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_1"] + SHARED_DRIVE_1_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_1_1"] + FOLDER_1_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_1_2"] + FOLDER_1_1_FILE_IDS
+ DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + FOLDER_1_2_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_2"] + SHARED_DRIVE_2_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_2_1"] + FOLDER_2_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_2_2"] + FOLDER_2_1_FILE_IDS
+ DRIVE_ID_MAPPING["SECTIONS"] + FOLDER_2_2_FILE_IDS
+ SECTIONS_FILE_IDS
) )
assert_retrieved_docs_match_expected( assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs, retrieved_docs=retrieved_docs,
@ -68,15 +86,15 @@ def test_include_shared_drives_only(
# Should only get shared drives # Should only get shared drives
expected_file_ids = ( expected_file_ids = (
DRIVE_ID_MAPPING["SHARED_DRIVE_1"] SHARED_DRIVE_1_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_1"] + FOLDER_1_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_1_1"] + FOLDER_1_1_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_1_2"] + FOLDER_1_2_FILE_IDS
+ DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + SHARED_DRIVE_2_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_2"] + FOLDER_2_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_2_1"] + FOLDER_2_1_FILE_IDS
+ DRIVE_ID_MAPPING["FOLDER_2_2"] + FOLDER_2_2_FILE_IDS
+ DRIVE_ID_MAPPING["SECTIONS"] + SECTIONS_FILE_IDS
) )
assert_retrieved_docs_match_expected( assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs, retrieved_docs=retrieved_docs,
@@ -101,8 +119,8 @@ def test_include_my_drives_only(
     for doc_batch in connector.poll_source(0, time.time()):
         retrieved_docs.extend(doc_batch)
-    # Should only get everyone's My Drives
-    expected_file_ids = list(range(0, 5))  # Admin's My Drive only
+    # Should only get the primary admin's My Drive because we are impersonating them
+    expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
         expected_file_ids=expected_file_ids,
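Every test in these files funnels through assert_retrieved_docs_match_expected from consts_and_utils. Its implementation is not part of this diff; a minimal sketch of the comparison it presumably performs, assuming documents expose their templated file name via semantic_identifier:

from danswer.connectors.models import Document

# Hedged sketch: the real helper may also validate sections and metadata.
# file_name_template is the module-level constant sketched earlier.
def assert_retrieved_docs_match_expected(
    retrieved_docs: list[Document], expected_file_ids: list[int]
) -> None:
    expected_names = {file_name_template.format(i) for i in expected_file_ids}
    retrieved_names = {doc.semantic_identifier for doc in retrieved_docs}
    # Print the symmetric difference first so assertion failures are readable.
    print("missing:", expected_names - retrieved_names)
    print("unexpected:", retrieved_names - expected_names)
    assert expected_names == retrieved_names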
@@ -118,9 +136,7 @@ def test_drive_one_only(
     google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
 ) -> None:
     print("\n\nRunning test_drive_one_only")
-    drive_urls = [
-        URL_MAPPING["SHARED_DRIVE_1"],
-    ]
+    drive_urls = [SHARED_DRIVE_1_URL]
     connector = google_drive_oauth_connector_factory(
         include_shared_drives=True,
         include_my_drives=False,
@@ -130,8 +146,12 @@ def test_drive_one_only(
     for doc_batch in connector.poll_source(0, time.time()):
         retrieved_docs.extend(doc_batch)
-    # We ignore shared_drive_urls if include_shared_drives is False
-    expected_file_ids = list(range(20, 40))  # Shared Drive 1 and its folders
+    expected_file_ids = (
+        SHARED_DRIVE_1_FILE_IDS
+        + FOLDER_1_FILE_IDS
+        + FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
+    )
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
         expected_file_ids=expected_file_ids,
@@ -147,8 +167,8 @@ def test_folder_and_shared_drive(
     google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
 ) -> None:
     print("\n\nRunning test_folder_and_shared_drive")
-    drive_urls = [URL_MAPPING["SHARED_DRIVE_1"]]
-    folder_urls = [URL_MAPPING["FOLDER_2"]]
+    drive_urls = [SHARED_DRIVE_1_URL]
+    folder_urls = [FOLDER_2_URL]
     connector = google_drive_oauth_connector_factory(
         include_shared_drives=True,
         include_my_drives=True,
@@ -159,11 +179,16 @@ def test_folder_and_shared_drive(
     for doc_batch in connector.poll_source(0, time.time()):
         retrieved_docs.extend(doc_batch)
-    # Should
     expected_file_ids = (
-        list(range(0, 5))  # Admin's My Drive
-        + list(range(20, 40))  # Shared Drive 1 and its folders
-        + list(range(45, 60))  # Folder 2 and its subfolders
+        ADMIN_FILE_IDS
+        + ADMIN_FOLDER_3_FILE_IDS
+        + SHARED_DRIVE_1_FILE_IDS
+        + FOLDER_1_FILE_IDS
+        + FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
+        + FOLDER_2_FILE_IDS
+        + FOLDER_2_1_FILE_IDS
+        + FOLDER_2_2_FILE_IDS
     )
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
@@ -181,23 +206,32 @@ def test_folders_only(
 ) -> None:
     print("\n\nRunning test_folders_only")
     folder_urls = [
-        URL_MAPPING["FOLDER_1_1"],
-        URL_MAPPING["FOLDER_1_2"],
-        URL_MAPPING["FOLDER_2_1"],
-        URL_MAPPING["FOLDER_2_2"],
+        FOLDER_1_2_URL,
+        FOLDER_2_1_URL,
+        FOLDER_2_2_URL,
+        FOLDER_3_URL,
     ]
+    # This should get converted to a drive request and spit out a warning in the logs
+    shared_drive_urls = [
+        FOLDER_1_1_URL,
+    ]
     connector = google_drive_oauth_connector_factory(
         include_shared_drives=False,
         include_my_drives=False,
+        shared_drive_urls=",".join([str(url) for url in shared_drive_urls]),
         shared_folder_urls=",".join([str(url) for url in folder_urls]),
     )
     retrieved_docs: list[Document] = []
     for doc_batch in connector.poll_source(0, time.time()):
         retrieved_docs.extend(doc_batch)
-    expected_file_ids = list(range(30, 40)) + list(  # Folders 1_1 and 1_2
-        range(50, 60)
-    )  # Folders 2_1 and 2_2
+    expected_file_ids = (
+        FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
+        + FOLDER_2_1_FILE_IDS
+        + FOLDER_2_2_FILE_IDS
+        + ADMIN_FOLDER_3_FILE_IDS
+    )
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
         expected_file_ids=expected_file_ids,
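The new comment about a folder URL passed via shared_drive_urls being "converted to a drive request" relies on both URL shapes carrying the resource id as their last path segment. A sketch of the likely id extraction, under the assumption that the connector parses URLs this way; the actual parsing and warning logic is not part of this diff.

# Hedged sketch: https://drive.google.com/drive/folders/<id> and
# .../drives/<id> both end in the resource id, so a folder URL supplied as
# a shared drive can still be resolved, with a warning logged about the
# mismatch.
def _extract_drive_or_folder_id(url: str) -> str:
    return url.rstrip("/").split("/")[-1].split("?")[0]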
@@ -214,8 +248,8 @@ def test_specific_emails(
 ) -> None:
     print("\n\nRunning test_specific_emails")
     my_drive_emails = [
-        EMAIL_MAPPING["TEST_USER_1"],
-        EMAIL_MAPPING["TEST_USER_3"],
+        TEST_USER_1_EMAIL,
+        TEST_USER_3_EMAIL,
     ]
     connector = google_drive_oauth_connector_factory(
         include_shared_drives=False,
@@ -228,7 +262,35 @@ def test_specific_emails(
     # No matter who is specified, when using oauth, if include_my_drives is True,
     # we will get all the files from the admin's My Drive
-    expected_file_ids = DRIVE_ID_MAPPING["ADMIN"]
+    expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
+    assert_retrieved_docs_match_expected(
+        retrieved_docs=retrieved_docs,
+        expected_file_ids=expected_file_ids,
+    )
+
+
+@patch(
+    "danswer.file_processing.extract_file_text.get_unstructured_api_key",
+    return_value=None,
+)
+def test_personal_folders_only(
+    mock_get_api_key: MagicMock,
+    google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
+) -> None:
+    print("\n\nRunning test_personal_folders_only")
+    folder_urls = [
+        FOLDER_3_URL,
+    ]
+    connector = google_drive_oauth_connector_factory(
+        include_shared_drives=False,
+        include_my_drives=False,
+        shared_folder_urls=",".join([str(url) for url in folder_urls]),
+    )
+    retrieved_docs: list[Document] = []
+    for doc_batch in connector.poll_source(0, time.time()):
+        retrieved_docs.extend(doc_batch)
+    expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
         expected_file_ids=expected_file_ids,
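The google_drive_oauth_connector_factory fixture itself is outside this diff. A plausible sketch of what it does, using the unified credential keys this commit introduces; the environment variable names and token payload shape are assumptions.

import os
from typing import Any

# Hedged sketch of the oauth connector factory fixture used above.
def google_drive_oauth_connector_factory(
    **connector_kwargs: Any,
) -> GoogleDriveConnector:
    connector = GoogleDriveConnector(**connector_kwargs)
    connector.load_credentials(
        {
            "google_tokens": os.environ["GOOGLE_OAUTH_TOKENS"],  # assumption: serialized token JSON
            "google_primary_admin": os.environ["GOOGLE_PRIMARY_ADMIN"],  # assumption
        }
    )
    return connector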
View File
@@ -5,11 +5,7 @@ from unittest.mock import patch
 from danswer.connectors.google_drive.connector import GoogleDriveConnector
 from danswer.connectors.models import Document
+from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_URL
 
-SECTIONS_FOLDER_URL = (
-    "https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
-)
 
 @patch(
View File
@@ -5,12 +5,32 @@ from unittest.mock import patch
 from danswer.connectors.google_drive.connector import GoogleDriveConnector
 from danswer.connectors.models import Document
-from tests.daily.connectors.google_drive.helpers import (
+from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import (
     assert_retrieved_docs_match_expected,
 )
-from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING
-from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING
-from tests.daily.connectors.google_drive.helpers import URL_MAPPING
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
+from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
+from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_2_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_3_EMAIL
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_3_FILE_IDS
 
 
 @patch(
@@ -32,19 +52,20 @@ def test_include_all(
     # Should get everything
     expected_file_ids = (
-        DRIVE_ID_MAPPING["ADMIN"]
-        + DRIVE_ID_MAPPING["TEST_USER_1"]
-        + DRIVE_ID_MAPPING["TEST_USER_2"]
-        + DRIVE_ID_MAPPING["TEST_USER_3"]
-        + DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
-        + DRIVE_ID_MAPPING["FOLDER_1"]
-        + DRIVE_ID_MAPPING["FOLDER_1_1"]
-        + DRIVE_ID_MAPPING["FOLDER_1_2"]
-        + DRIVE_ID_MAPPING["SHARED_DRIVE_2"]
-        + DRIVE_ID_MAPPING["FOLDER_2"]
-        + DRIVE_ID_MAPPING["FOLDER_2_1"]
-        + DRIVE_ID_MAPPING["FOLDER_2_2"]
-        + DRIVE_ID_MAPPING["SECTIONS"]
+        ADMIN_FILE_IDS
+        + ADMIN_FOLDER_3_FILE_IDS
+        + TEST_USER_1_FILE_IDS
+        + TEST_USER_2_FILE_IDS
+        + TEST_USER_3_FILE_IDS
+        + SHARED_DRIVE_1_FILE_IDS
+        + FOLDER_1_FILE_IDS
+        + FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
+        + SHARED_DRIVE_2_FILE_IDS
+        + FOLDER_2_FILE_IDS
+        + FOLDER_2_1_FILE_IDS
+        + FOLDER_2_2_FILE_IDS
+        + SECTIONS_FILE_IDS
     )
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
@@ -71,15 +92,15 @@ def test_include_shared_drives_only(
     # Should only get shared drives
     expected_file_ids = (
-        DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
-        + DRIVE_ID_MAPPING["FOLDER_1"]
-        + DRIVE_ID_MAPPING["FOLDER_1_1"]
-        + DRIVE_ID_MAPPING["FOLDER_1_2"]
-        + DRIVE_ID_MAPPING["SHARED_DRIVE_2"]
-        + DRIVE_ID_MAPPING["FOLDER_2"]
-        + DRIVE_ID_MAPPING["FOLDER_2_1"]
-        + DRIVE_ID_MAPPING["FOLDER_2_2"]
-        + DRIVE_ID_MAPPING["SECTIONS"]
+        SHARED_DRIVE_1_FILE_IDS
+        + FOLDER_1_FILE_IDS
+        + FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
+        + SHARED_DRIVE_2_FILE_IDS
+        + FOLDER_2_FILE_IDS
+        + FOLDER_2_1_FILE_IDS
+        + FOLDER_2_2_FILE_IDS
+        + SECTIONS_FILE_IDS
     )
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
@@ -105,7 +126,13 @@ def test_include_my_drives_only(
         retrieved_docs.extend(doc_batch)
     # Should only get everyone's My Drives
-    expected_file_ids = list(range(0, 20))  # All My Drives
+    expected_file_ids = (
+        ADMIN_FILE_IDS
+        + ADMIN_FOLDER_3_FILE_IDS
+        + TEST_USER_1_FILE_IDS
+        + TEST_USER_2_FILE_IDS
+        + TEST_USER_3_FILE_IDS
+    )
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
         expected_file_ids=expected_file_ids,
@@ -121,7 +148,7 @@ def test_drive_one_only(
     google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
 ) -> None:
     print("\n\nRunning test_drive_one_only")
-    urls = [URL_MAPPING["SHARED_DRIVE_1"]]
+    urls = [SHARED_DRIVE_1_URL]
     connector = google_drive_service_acct_connector_factory(
         include_shared_drives=True,
         include_my_drives=False,
@@ -132,7 +159,12 @@ def test_drive_one_only(
         retrieved_docs.extend(doc_batch)
     # We ignore shared_drive_urls if include_shared_drives is False
-    expected_file_ids = list(range(20, 40))  # Shared Drive 1 and its folders
+    expected_file_ids = (
+        SHARED_DRIVE_1_FILE_IDS
+        + FOLDER_1_FILE_IDS
+        + FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
+    )
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
         expected_file_ids=expected_file_ids,
@@ -148,10 +180,8 @@ def test_folder_and_shared_drive(
     google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
 ) -> None:
     print("\n\nRunning test_folder_and_shared_drive")
-    drive_urls = [
-        URL_MAPPING["SHARED_DRIVE_1"],
-    ]
-    folder_urls = [URL_MAPPING["FOLDER_2"]]
+    drive_urls = [SHARED_DRIVE_1_URL]
+    folder_urls = [FOLDER_2_URL]
     connector = google_drive_service_acct_connector_factory(
         include_shared_drives=True,
         include_my_drives=True,
@@ -162,11 +192,20 @@ def test_folder_and_shared_drive(
     for doc_batch in connector.poll_source(0, time.time()):
         retrieved_docs.extend(doc_batch)
-    # Should
+    # Should get everything except for the top level files in drive 2
     expected_file_ids = (
-        list(range(0, 20))  # All My Drives
-        + list(range(20, 40))  # Shared Drive 1 and its folders
-        + list(range(45, 60))  # Folder 2 and its subfolders
+        ADMIN_FILE_IDS
+        + ADMIN_FOLDER_3_FILE_IDS
+        + TEST_USER_1_FILE_IDS
+        + TEST_USER_2_FILE_IDS
+        + TEST_USER_3_FILE_IDS
+        + SHARED_DRIVE_1_FILE_IDS
+        + FOLDER_1_FILE_IDS
+        + FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
+        + FOLDER_2_FILE_IDS
+        + FOLDER_2_1_FILE_IDS
+        + FOLDER_2_2_FILE_IDS
     )
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
@@ -184,23 +223,32 @@ def test_folders_only(
 ) -> None:
     print("\n\nRunning test_folders_only")
     folder_urls = [
-        URL_MAPPING["FOLDER_1_1"],
-        URL_MAPPING["FOLDER_1_2"],
-        URL_MAPPING["FOLDER_2_1"],
-        URL_MAPPING["FOLDER_2_2"],
+        FOLDER_1_2_URL,
+        FOLDER_2_1_URL,
+        FOLDER_2_2_URL,
+        FOLDER_3_URL,
     ]
+    # This should get converted to a drive request and spit out a warning in the logs
+    shared_drive_urls = [
+        FOLDER_1_1_URL,
+    ]
     connector = google_drive_service_acct_connector_factory(
         include_shared_drives=False,
         include_my_drives=False,
+        shared_drive_urls=",".join([str(url) for url in shared_drive_urls]),
         shared_folder_urls=",".join([str(url) for url in folder_urls]),
     )
     retrieved_docs: list[Document] = []
     for doc_batch in connector.poll_source(0, time.time()):
         retrieved_docs.extend(doc_batch)
-    expected_file_ids = list(range(30, 40)) + list(  # Folders 1_1 and 1_2
-        range(50, 60)
-    )  # Folders 2_1 and 2_2
+    expected_file_ids = (
+        FOLDER_1_1_FILE_IDS
+        + FOLDER_1_2_FILE_IDS
+        + FOLDER_2_1_FILE_IDS
+        + FOLDER_2_2_FILE_IDS
+        + ADMIN_FOLDER_3_FILE_IDS
+    )
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
         expected_file_ids=expected_file_ids,
@@ -217,8 +265,8 @@ def test_specific_emails(
 ) -> None:
     print("\n\nRunning test_specific_emails")
     my_drive_emails = [
-        EMAIL_MAPPING["TEST_USER_1"],
-        EMAIL_MAPPING["TEST_USER_3"],
+        TEST_USER_1_EMAIL,
+        TEST_USER_3_EMAIL,
     ]
     connector = google_drive_service_acct_connector_factory(
         include_shared_drives=False,
@@ -229,9 +277,64 @@ def test_specific_emails(
     for doc_batch in connector.poll_source(0, time.time()):
         retrieved_docs.extend(doc_batch)
-    expected_file_ids = list(range(5, 10)) + list(
-        range(15, 20)
-    )  # TEST_USER_1 and TEST_USER_3 My Drives
+    expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS
+    assert_retrieved_docs_match_expected(
+        retrieved_docs=retrieved_docs,
+        expected_file_ids=expected_file_ids,
+    )
+
+
+@patch(
+    "danswer.file_processing.extract_file_text.get_unstructured_api_key",
+    return_value=None,
+)
+def get_specific_folders_in_my_drive(
+    mock_get_api_key: MagicMock,
+    google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
+) -> None:
+    print("\n\nRunning get_specific_folders_in_my_drive")
+    my_drive_emails = [
+        TEST_USER_1_EMAIL,
+        TEST_USER_3_EMAIL,
+    ]
+    connector = google_drive_service_acct_connector_factory(
+        include_shared_drives=False,
+        include_my_drives=True,
+        my_drive_emails=",".join([str(email) for email in my_drive_emails]),
+    )
+    retrieved_docs: list[Document] = []
+    for doc_batch in connector.poll_source(0, time.time()):
+        retrieved_docs.extend(doc_batch)
+    expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS
+    assert_retrieved_docs_match_expected(
+        retrieved_docs=retrieved_docs,
+        expected_file_ids=expected_file_ids,
+    )
+
+
+@patch(
+    "danswer.file_processing.extract_file_text.get_unstructured_api_key",
+    return_value=None,
+)
+def test_personal_folders_only(
+    mock_get_api_key: MagicMock,
+    google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
+) -> None:
+    print("\n\nRunning test_personal_folders_only")
+    folder_urls = [
+        FOLDER_3_URL,
+    ]
+    connector = google_drive_oauth_connector_factory(
+        include_shared_drives=False,
+        include_my_drives=False,
+        shared_folder_urls=",".join([str(url) for url in folder_urls]),
+    )
+    retrieved_docs: list[Document] = []
+    for doc_batch in connector.poll_source(0, time.time()):
+        retrieved_docs.extend(doc_batch)
+    expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
     assert_retrieved_docs_match_expected(
         retrieved_docs=retrieved_docs,
         expected_file_ids=expected_file_ids,
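As with the oauth variant, the google_drive_service_acct_connector_factory fixture is not shown here. A plausible sketch using the unified google_service_account_key / google_primary_admin keys this commit introduces; the environment variable names are assumptions.

import os
from typing import Any

# Hedged sketch of the service account connector factory fixture used above.
def google_drive_service_acct_connector_factory(
    **connector_kwargs: Any,
) -> GoogleDriveConnector:
    connector = GoogleDriveConnector(**connector_kwargs)
    connector.load_credentials(
        {
            "google_service_account_key": os.environ["FULL_SERVICE_ACCOUNT_KEY"],  # assumption
            "google_primary_admin": os.environ["GOOGLE_PRIMARY_ADMIN"],  # assumption
        }
    )
    return connector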
View File
@@ -10,10 +10,28 @@ from danswer.connectors.google_utils.resources import get_admin_service
 from ee.danswer.external_permissions.google_drive.doc_sync import (
     _get_permissions_from_slim_doc,
 )
-from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING
-from tests.daily.connectors.google_drive.helpers import file_name_template
-from tests.daily.connectors.google_drive.helpers import print_discrepencies
-from tests.daily.connectors.google_drive.helpers import PUBLIC_RANGE
+from tests.daily.connectors.google_drive.consts_and_utils import ACCESS_MAPPING
+from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
+from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import file_name_template
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import print_discrepencies
+from tests.daily.connectors.google_drive.consts_and_utils import PUBLIC_RANGE
+from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_2_EMAIL
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_2_FILE_IDS
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_3_EMAIL
+from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_3_FILE_IDS
 
 
 def get_keys_available_to_user_from_access_map(
@@ -113,72 +131,71 @@ def test_all_permissions(
     )
 
     access_map: dict[str, ExternalAccess] = {}
+    found_file_names = set()
     for slim_doc_batch in google_drive_connector.retrieve_all_slim_documents(
         0, time.time()
     ):
         for slim_doc in slim_doc_batch:
-            access_map[
-                (slim_doc.perm_sync_data or {})["name"]
-            ] = _get_permissions_from_slim_doc(
+            name = (slim_doc.perm_sync_data or {})["name"]
+            access_map[name] = _get_permissions_from_slim_doc(
                 google_drive_connector=google_drive_connector,
                 slim_doc=slim_doc,
             )
+            found_file_names.add(name)
 
     for file_name, external_access in access_map.items():
         print(file_name, external_access)
 
     expected_file_range = (
-        list(range(0, 5))  # Admin's My Drive
-        + list(range(5, 10))  # TEST_USER_1's My Drive
-        + list(range(10, 15))  # TEST_USER_2's My Drive
-        + list(range(15, 20))  # TEST_USER_3's My Drive
-        + list(range(20, 25))  # Shared Drive 1
-        + list(range(25, 30))  # Folder 1
-        + list(range(30, 35))  # Folder 1_1
-        + list(range(35, 40))  # Folder 1_2
-        + list(range(40, 45))  # Shared Drive 2
-        + list(range(45, 50))  # Folder 2
-        + list(range(50, 55))  # Folder 2_1
-        + list(range(55, 60))  # Folder 2_2
-        + [61]  # Sections
+        ADMIN_FILE_IDS  # Admin's My Drive
+        + ADMIN_FOLDER_3_FILE_IDS  # Admin's Folder 3
+        + TEST_USER_1_FILE_IDS  # TEST_USER_1's My Drive
+        + TEST_USER_2_FILE_IDS  # TEST_USER_2's My Drive
+        + TEST_USER_3_FILE_IDS  # TEST_USER_3's My Drive
+        + SHARED_DRIVE_1_FILE_IDS  # Shared Drive 1
+        + FOLDER_1_FILE_IDS  # Folder 1
+        + FOLDER_1_1_FILE_IDS  # Folder 1_1
+        + FOLDER_1_2_FILE_IDS  # Folder 1_2
+        + SHARED_DRIVE_2_FILE_IDS  # Shared Drive 2
+        + FOLDER_2_FILE_IDS  # Folder 2
+        + FOLDER_2_1_FILE_IDS  # Folder 2_1
+        + FOLDER_2_2_FILE_IDS  # Folder 2_2
+        + SECTIONS_FILE_IDS  # Sections
     )
+    expected_file_names = {
+        file_name_template.format(file_id) for file_id in expected_file_range
+    }
 
     # Should get everything
-    assert len(access_map) == len(expected_file_range)
+    print_discrepencies(expected_file_names, found_file_names)
+    assert expected_file_names == found_file_names
 
     group_map = get_group_map(google_drive_connector)
     print("groups:\n", group_map)
 
     assert_correct_access_for_user(
-        user_email=EMAIL_MAPPING["ADMIN"],
-        expected_access_ids=list(range(0, 5))  # Admin's My Drive
-        + list(range(20, 60))  # All shared drive content
-        + [61],  # Sections
+        user_email=ADMIN_EMAIL,
+        expected_access_ids=ACCESS_MAPPING[ADMIN_EMAIL],
         group_map=group_map,
         retrieved_access_map=access_map,
     )
     assert_correct_access_for_user(
-        user_email=EMAIL_MAPPING["TEST_USER_1"],
-        expected_access_ids=list(range(5, 10))  # TEST_USER_1's My Drive
-        + list(range(20, 40))  # Shared Drive 1 and its folders
-        + list(range(0, 2)),  # Access to some of Admin's files
+        user_email=TEST_USER_1_EMAIL,
+        expected_access_ids=ACCESS_MAPPING[TEST_USER_1_EMAIL],
         group_map=group_map,
         retrieved_access_map=access_map,
    )
     assert_correct_access_for_user(
-        user_email=EMAIL_MAPPING["TEST_USER_2"],
-        expected_access_ids=list(range(10, 15))  # TEST_USER_2's My Drive
-        + list(range(25, 40))  # Folder 1 and its subfolders
-        + list(range(50, 55))  # Folder 2_1
-        + list(range(45, 47)),  # Some files in Folder 2
+        user_email=TEST_USER_2_EMAIL,
+        expected_access_ids=ACCESS_MAPPING[TEST_USER_2_EMAIL],
         group_map=group_map,
         retrieved_access_map=access_map,
     )
     assert_correct_access_for_user(
-        user_email=EMAIL_MAPPING["TEST_USER_3"],
-        expected_access_ids=list(range(15, 20)),  # TEST_USER_3's My Drive only
+        user_email=TEST_USER_3_EMAIL,
+        expected_access_ids=ACCESS_MAPPING[TEST_USER_3_EMAIL],
         group_map=group_map,
         retrieved_access_map=access_map,
     )
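print_discrepencies (spelling as in the repo) only appears here as a call with the expected and found name sets; a minimal sketch consistent with that call site:

# Hedged sketch: surfaces the two-sided difference before the assert fires.
def print_discrepencies(expected: set[str], retrieved: set[str]) -> None:
    if expected != retrieved:
        print("missing from retrieved:", expected - retrieved)
        print("unexpected in retrieved:", retrieved - expected)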
View File
@@ -360,22 +360,12 @@ export const DriveAuthSection = ({
   if (serviceAccountKeyData?.service_account_email) {
     return (
       <div>
-        <p className="text-sm mb-6">
-          When using a Google Drive Service Account, you must specify the email
-          of the primary admin that you would like the service account to
-          impersonate.
-          <br />
-          <br />
-          Ideally, this account should be an owner/admin of the Google
-          Organization that owns the Google Drive(s) you want to index.
-        </p>
         <Formik
           initialValues={{
-            google_drive_primary_admin: user?.email || "",
+            google_primary_admin: user?.email || "",
           }}
           validationSchema={Yup.object().shape({
-            google_drive_primary_admin: Yup.string().required(
+            google_primary_admin: Yup.string().required(
               "User email is required"
             ),
           })}
@@ -389,7 +379,7 @@ export const DriveAuthSection = ({
                 "Content-Type": "application/json",
               },
               body: JSON.stringify({
-                google_drive_primary_admin: values.google_drive_primary_admin,
+                google_primary_admin: values.google_primary_admin,
               }),
             }
           );
@@ -412,9 +402,9 @@ export const DriveAuthSection = ({
         {({ isSubmitting }) => (
           <Form>
             <TextFormField
-              name="google_drive_primary_admin"
+              name="google_primary_admin"
               label="Primary Admin Email:"
-              subtext="Enter the email of the user whose Google Drive access you want to delegate to the service account."
+              subtext="Enter the email of an admin/owner of the Google Organization that owns the Google Drive(s) you want to index."
             />
             <div className="flex">
               <TremorButton type="submit" disabled={isSubmitting}>
View File
@@ -101,12 +101,14 @@ const GDriveMain = ({}: {}) => {
     | Credential<GoogleDriveCredentialJson>
     | undefined = credentialsData.find(
       (credential) =>
-        credential.credential_json?.google_drive_tokens && credential.admin_public
+        credential.credential_json?.google_tokens &&
+        credential.admin_public &&
+        credential.source === "google_drive"
     );
   const googleDriveServiceAccountCredential:
     | Credential<GoogleDriveServiceAccountCredentialJson>
     | undefined = credentialsData.find(
-      (credential) => credential.credential_json?.google_drive_service_account_key
+      (credential) => credential.credential_json?.google_service_account_key
     );
   const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus<
View File
@@ -276,10 +276,12 @@ export const GmailJsonUploadSection = ({
         >
           here
         </a>{" "}
-        to setup a google OAuth App in your company workspace.
+        to either (1) setup a google OAuth App in your company workspace or (2)
+        create a Service Account.
         <br />
         <br />
-        Download the credentials JSON and upload it here.
+        Download the credentials JSON if choosing option (1) or the Service
+        Account key JSON if choosing option (2), and upload it here.
       </p>
       <DriveJsonUpload setPopup={setPopup} />
     </div>
@@ -344,23 +346,13 @@ export const GmailAuthSection = ({
   if (serviceAccountKeyData?.service_account_email) {
     return (
       <div>
-        <p className="text-sm mb-2">
-          When using a Gmail Service Account, you must specify the email of the
-          primary admin that you would like the service account to impersonate.
-          <br />
-          <br />
-          For this connector to index all users Gmail, the primary admin email
-          should be an owner/admin of the Google Organization that being
-          indexed.
-        </p>
         <CardSection>
           <Formik
             initialValues={{
-              gmail_primary_admin: user?.email || "",
+              google_primary_admin: user?.email || "",
             }}
             validationSchema={Yup.object().shape({
-              gmail_primary_admin: Yup.string().required(),
+              google_primary_admin: Yup.string().required(),
             })}
             onSubmit={async (values, formikHelpers) => {
               formikHelpers.setSubmitting(true);
@@ -373,7 +365,7 @@ export const GmailAuthSection = ({
                   "Content-Type": "application/json",
                 },
                 body: JSON.stringify({
-                  gmail_primary_admin: values.gmail_primary_admin,
+                  google_primary_admin: values.google_primary_admin,
                 }),
               }
             );
@@ -396,7 +388,7 @@ export const GmailAuthSection = ({
           {({ isSubmitting }) => (
             <Form>
               <TextFormField
-                name="gmail_primary_admin"
+                name="google_primary_admin"
                 label="Primary Admin Email:"
                 subtext="You must provide an admin/owner account to retrieve all org emails."
               />
@@ -457,8 +449,8 @@ export const GmailAuthSection = ({
   // case where no keys have been uploaded in step 1
   return (
     <p className="text-sm">
-      Please upload a OAuth Client Credential JSON in Step 1 before moving onto
-      Step 2.
+      Please upload an OAuth or Service Account Credential JSON in Step 1 before
+      moving on to Step 2.
     </p>
   );
 };
View File
@@ -109,12 +109,16 @@ export const GmailMain = () => {
   const gmailPublicCredential: Credential<GmailCredentialJson> | undefined =
     credentialsData.find(
       (credential) =>
-        credential.credential_json?.gmail_tokens && credential.admin_public
+        (credential.credential_json?.google_service_account_key ||
+          credential.credential_json?.google_tokens) &&
+        credential.admin_public
     );
   const gmailServiceAccountCredential:
     | Credential<GmailServiceAccountCredentialJson>
     | undefined = credentialsData.find(
-      (credential) => credential.credential_json?.gmail_service_account_key
+      (credential) =>
+        credential.credential_json?.google_service_account_key &&
+        credential.source === "gmail"
     );
   const gmailConnectorIndexingStatuses: ConnectorIndexingStatus<
     GmailConfig,
View File
@@ -23,7 +23,7 @@ export const useGmailCredentials = (connector: string) => {
   const gmailPublicCredential: Credential<GmailCredentialJson> | undefined =
     credentialsData?.find(
       (credential) =>
-        credential.credential_json?.google_service_account_key &&
+        credential.credential_json?.google_tokens &&
         credential.admin_public &&
         credential.source === connector
     );

@@ -31,7 +31,10 @@ export const useGmailCredentials = (connector: string) => {
   const gmailServiceAccountCredential:
     | Credential<GmailServiceAccountCredentialJson>
     | undefined = credentialsData?.find(
-      (credential) => credential.credential_json?.gmail_service_account_key
+      (credential) =>
+        credential.credential_json?.google_service_account_key &&
+        credential.admin_public &&
+        credential.source === connector
     );

   const liveGmailCredential =

@@ -49,7 +52,7 @@ export const useGoogleDriveCredentials = (connector: string) => {
     | Credential<GoogleDriveCredentialJson>
     | undefined = credentialsData?.find(
       (credential) =>
-        credential.credential_json?.google_service_account_key &&
+        credential.credential_json?.google_tokens &&
         credential.admin_public &&
         credential.source === connector
     );

@@ -57,7 +60,10 @@ export const useGoogleDriveCredentials = (connector: string) => {
   const googleDriveServiceAccountCredential:
     | Credential<GoogleDriveServiceAccountCredentialJson>
     | undefined = credentialsData?.find(
-      (credential) => credential.credential_json?.google_drive_service_account_key
+      (credential) =>
+        credential.credential_json?.google_service_account_key &&
+        credential.admin_public &&
+        credential.source === connector
     );

   const liveGDriveCredential =
View File
@@ -53,23 +53,23 @@ export interface SlackCredentialJson {
 }

 export interface GmailCredentialJson {
-  gmail_tokens: string;
-  gmail_primary_admin: string;
+  google_tokens: string;
+  google_primary_admin: string;
 }

 export interface GoogleDriveCredentialJson {
-  google_drive_tokens: string;
-  google_drive_primary_admin: string;
+  google_tokens: string;
+  google_primary_admin: string;
 }

 export interface GmailServiceAccountCredentialJson {
-  gmail_service_account_key: string;
-  gmail_primary_admin: string;
+  google_service_account_key: string;
+  google_primary_admin: string;
 }

 export interface GoogleDriveServiceAccountCredentialJson {
-  google_drive_service_account_key: string;
-  google_drive_primary_admin: string;
+  google_service_account_key: string;
+  google_primary_admin: string;
 }

 export interface SlabCredentialJson {

@@ -301,8 +301,8 @@ export const credentialTemplates: Record<ValidSources, any> = {
   ingestion_api: null,

   // NOTE: These are Special Cases
-  google_drive: { google_drive_tokens: "" } as GoogleDriveCredentialJson,
-  gmail: { gmail_tokens: "" } as GmailCredentialJson,
+  google_drive: { google_tokens: "" } as GoogleDriveCredentialJson,
+  gmail: { google_tokens: "" } as GmailCredentialJson,
 };

 export const credentialDisplayNames: Record<string, string> = {

@@ -332,19 +332,10 @@ export const credentialDisplayNames: Record<string, string> = {
   // Slack
   slack_bot_token: "Slack Bot Token",

-  // Gmail
-  gmail_tokens: "Gmail Tokens",
-
-  // Google Drive
-  google_drive_tokens: "Google Drive Tokens",
-
-  // Gmail Service Account
-  gmail_service_account_key: "Gmail Service Account Key",
-  gmail_primary_admin: "Gmail Primary Admin",
-
-  // Google Drive Service Account
-  google_drive_service_account_key: "Google Drive Service Account Key",
-  google_drive_primary_admin: "Google Drive Delegated User",
+  // Gmail and Google Drive
+  google_tokens: "Google Oauth Tokens",
+  google_service_account_key: "Google Service Account Key",
+  google_primary_admin: "Primary Admin Email",

   // Slab
   slab_bot_token: "Slab Bot Token",