Refactored Google Drive Connector + Permission Syncing (#2945)

* refactoring changes

* everything working for service account

* works with service account

* combined scopes

* copy change

* oauth prep

* Works for oauth and service account credentials

* mypy

* merge fixes

* Refactor Google Drive connector

* finished backend

* auth changes

* if it's stupid but it works, it's not stupid

* npm run dev fixes

* addressed change requests

* string fix

* minor fixes and cleanup

* spacing cleanup

* Update connector.py

* everything done

* testing!

* Delete backend/tests/daily/connectors/google_drive/file_generator.py

* cleaned up

---------

Co-authored-by: Chris Weaver <25087905+Weves@users.noreply.github.com>
Authored by hagen-danswer on 2024-10-31 19:25:00 -07:00, committed by GitHub
parent b34f5862d7
commit 71d4fb98d3
35 changed files with 2036 additions and 975 deletions

View File

@@ -18,6 +18,9 @@ env:
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
jobs:
connectors-check:
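
For local runs of the daily connector tests, the same two secrets can be exported as environment variables. A minimal sketch of consuming them (the variable names are from the hunk above; everything else is illustrative):

import json
import os

# Parse the same JSON strings the CI workflow injects as secrets.
service_account_info = json.loads(os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"])
oauth_credentials_info = json.loads(os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"])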

View File

@@ -251,9 +251,6 @@ ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
# for some connectors
ENABLE_EXPENSIVE_EXPERT_CALLS = False
GOOGLE_DRIVE_INCLUDE_SHARED = False
GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False
GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False
# TODO these should be available for frontend configuration, via advanced options expandable
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(

View File

@@ -17,6 +17,7 @@ from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
@@ -249,7 +250,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")

View File

@@ -1,556 +1,305 @@
import io
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from enum import Enum
from itertools import chain
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient import discovery # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import GOOGLE_DRIVE_FOLLOW_SHORTCUTS
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.google_drive.connector_auth import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.connectors.google_drive.constants import SCOPE_DOC_URL
from danswer.connectors.google_drive.constants import SLIM_BATCH_SIZE
from danswer.connectors.google_drive.constants import USER_FIELDS
from danswer.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
)
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
from danswer.connectors.google_drive.file_retrieval import get_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.batching import batch_generator
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
if not string:
return []
return [s.strip() for s in string.split(",") if s.strip()]
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
def _extract_ids_from_urls(urls: list[str]) -> list[str]:
return [url.split("/")[-1] for url in urls]
GoogleDriveFileType = dict[str, Any]
# Google Drive APIs are quite flaky and may 500 for an
# extended period of time. We try to combat this by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
def _run_drive_file_query(
service: discovery.Resource,
query: str,
continue_on_failure: bool,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
next_page_token = ""
while next_page_token is not None:
logger.debug(f"Running Google Drive fetch with query: {query}")
results = add_retries(
lambda: (
service.files()
.list(
corpora="allDrives"
if include_shared
else "user", # needed to search through shared drives
pageSize=batch_size,
supportsAllDrives=include_shared,
includeItemsFromAllDrives=include_shared,
fields=(
"nextPageToken, files(mimeType, id, name, permissions, "
"modifiedTime, webViewLink, shortcutDetails)"
),
pageToken=next_page_token,
q=query,
)
.execute()
)
)()
next_page_token = results.get("nextPageToken")
files = results["files"]
for file in files:
if follow_shortcuts and "shortcutDetails" in file:
try:
file_shortcut_points_to = add_retries(
lambda: (
service.files()
.get(
fileId=file["shortcutDetails"]["targetId"],
supportsAllDrives=include_shared,
fields="mimeType, id, name, modifiedTime, webViewLink, permissions, shortcutDetails",
)
.execute()
)
)()
yield file_shortcut_points_to
except HttpError:
logger.error(
f"Failed to follow shortcut with details: {file['shortcutDetails']}"
)
if continue_on_failure:
continue
raise
else:
yield file
def _get_folder_id(
service: discovery.Resource,
parent_id: str,
folder_name: str,
include_shared: bool,
follow_shortcuts: bool,
) -> str | None:
"""
Get the ID of a folder given its name and the ID of its parent folder.
"""
query = f"'{parent_id}' in parents and name='{folder_name}' and "
if follow_shortcuts:
query += f"(mimeType='{DRIVE_FOLDER_TYPE}' or mimeType='{DRIVE_SHORTCUT_TYPE}')"
else:
query += f"mimeType='{DRIVE_FOLDER_TYPE}'"
# TODO: support specifying folder path in shared drive rather than just `My Drive`
results = add_retries(
lambda: (
service.files()
.list(
q=query,
spaces="drive",
fields="nextPageToken, files(id, name, shortcutDetails)",
supportsAllDrives=include_shared,
includeItemsFromAllDrives=include_shared,
)
.execute()
)
)()
items = results.get("files", [])
folder_id = None
if items:
if follow_shortcuts and "shortcutDetails" in items[0]:
folder_id = items[0]["shortcutDetails"]["targetId"]
else:
folder_id = items[0]["id"]
return folder_id
def _get_folders(
service: discovery.Resource,
continue_on_failure: bool,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
if follow_shortcuts:
query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
for file in _run_drive_file_query(
service=service,
query=query,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
):
# Need to check this since file may have been a target of a shortcut
# and not necessarily a folder
if file["mimeType"] == DRIVE_FOLDER_TYPE:
yield file
else:
pass
def _get_files(
service: discovery.Resource,
continue_on_failure: bool,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
if time_range_start is not None:
time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
query += f"and modifiedTime >= '{time_start}' "
if time_range_end is not None:
time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
query += f"and modifiedTime <= '{time_stop}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
files = _run_drive_file_query(
service=service,
query=query,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
return files
def get_all_files_batched(
service: discovery.Resource,
continue_on_failure: bool,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
# if True, will fetch files in sub-folders of the specified folder ID.
# Only applies if folder_id is specified.
traverse_subfolders: bool = True,
folder_ids_traversed: list[str] | None = None,
) -> Iterator[list[GoogleDriveFileType]]:
"""Gets all files matching the criteria specified by the args from Google Drive
in batches of size `batch_size`.
"""
found_files = _get_files(
service=service,
continue_on_failure=continue_on_failure,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=folder_id,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
yield from batch_generator(
items=found_files,
batch_size=batch_size,
pre_batch_yield=lambda batch_files: logger.debug(
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
),
)
if traverse_subfolders and folder_id is not None:
folder_ids_traversed = folder_ids_traversed or []
subfolders = _get_folders(
service=service,
folder_id=folder_id,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
for subfolder in subfolders:
if subfolder["id"] not in folder_ids_traversed:
logger.info("Fetching all files in subfolder: " + subfolder["name"])
folder_ids_traversed.append(subfolder["id"])
yield from get_all_files_batched(
service=service,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=subfolder["id"],
traverse_subfolders=traverse_subfolders,
folder_ids_traversed=folder_ids_traversed,
)
else:
logger.debug(
"Skipping subfolder since already traversed: " + subfolder["name"]
)
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
mime_type = file["mimeType"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title; finding it this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
)
if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT
class GoogleDriveConnector(LoadConnector, PollConnector):
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
# optional list of folder paths e.g. "[My Folder/My Subfolder]"
# if specified, will only index files in these folders
folder_paths: list[str] | None = None,
include_shared_drives: bool = True,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
only_org_public: bool = GOOGLE_DRIVE_ONLY_ORG_PUBLIC,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
# OLD PARAMETERS
folder_paths: list[str] | None = None,
include_shared: bool | None = None,
follow_shortcuts: bool | None = None,
only_org_public: bool | None = None,
continue_on_failure: bool | None = None,
) -> None:
self.folder_paths = folder_paths or []
# Check for old input parameters
if (
folder_paths is not None
or include_shared is not None
or follow_shortcuts is not None
or only_org_public is not None
or continue_on_failure is not None
):
logger.exception(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
raise ValueError(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
if (
not include_shared_drives
and not include_my_drives
and not shared_folder_urls
):
raise ValueError(
"At least one of include_shared_drives, include_my_drives,"
" or shared_folder_urls must be true"
)
self.batch_size = batch_size
self.include_shared = include_shared
self.follow_shortcuts = follow_shortcuts
self.only_org_public = only_org_public
self.continue_on_failure = continue_on_failure
self.include_shared_drives = include_shared_drives
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
self.shared_drive_ids = _extract_ids_from_urls(shared_drive_url_list)
self.include_my_drives = include_my_drives
self.my_drive_emails = _extract_str_list_from_comma_str(my_drive_emails)
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list)
self.primary_admin_email: str | None = None
self.google_domain: str | None = None
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
@staticmethod
def _process_folder_paths(
service: discovery.Resource,
folder_paths: list[str],
include_shared: bool,
follow_shortcuts: bool,
) -> list[str]:
"""['Folder/Sub Folder'] -> ['<FOLDER_ID>']"""
folder_ids: list[str] = []
for path in folder_paths:
folder_names = path.split("/")
parent_id = "root"
for folder_name in folder_names:
found_parent_id = _get_folder_id(
service=service,
parent_id=parent_id,
folder_name=folder_name,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
)
if found_parent_id is None:
raise ValueError(
(
f"Folder '{folder_name}' in path '{path}' "
"not found in Google Drive"
)
)
parent_id = found_parent_id
folder_ids.append(parent_id)
self._TRAVERSED_PARENT_IDS: set[str] = set()
return folder_ids
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._TRAVERSED_PARENT_IDS.add(folder_id)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going through
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds, new_creds_dict = get_google_drive_creds(credentials)
self.creds = creds
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self.google_domain = primary_admin_email.split("@")[1]
self.primary_admin_email = primary_admin_email
self.creds, new_creds_dict = get_google_drive_creds(credentials)
return new_creds_dict
def _fetch_docs_from_drive(
def get_google_resource(
self,
service_name: str = "drive",
service_version: str = "v3",
user_email: str | None = None,
) -> Resource:
if isinstance(self.creds, ServiceAccountCredentials):
creds = self.creds.with_subject(user_email or self.primary_admin_email)
service = build(service_name, service_version, credentials=creds)
elif isinstance(self.creds, OAuthCredentials):
service = build(service_name, service_version, credentials=self.creds)
else:
raise PermissionError("No credentials found")
return service
def _get_all_user_emails(self) -> list[str]:
admin_service = self.get_google_resource("admin", "directory_v1")
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
def _fetch_drive_items(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
primary_drive_service = self.get_google_resource()
if self.include_shared_drives:
shared_drive_urls = self.shared_drive_ids
if not shared_drive_urls:
# if no parent ids are specified, get all shared drives using the admin account
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
useDomainAdminAccess=True,
fields="drives(id)",
):
shared_drive_urls.append(drive["id"])
# For each shared drive, retrieve all files
for shared_drive_id in shared_drive_urls:
for file in get_files_in_shared_drive(
service=primary_drive_service,
drive_id=shared_drive_id,
is_slim=is_slim,
cache_folders=bool(self.shared_folder_ids),
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
):
yield file
if self.shared_folder_ids:
# Crawl all the shared parent ids for files
for folder_id in self.shared_folder_ids:
yield from crawl_folders_for_files(
service=primary_drive_service,
parent_id=folder_id,
personal_drive=False,
traversed_parent_ids=self._TRAVERSED_PARENT_IDS,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
all_user_emails = []
# get all personal docs from each user's personal drive
if self.include_my_drives:
if isinstance(self.creds, ServiceAccountCredentials):
all_user_emails = self.my_drive_emails or []
# If using service account and no emails specified, fetch all users
if not all_user_emails:
all_user_emails = self._get_all_user_emails()
elif self.primary_admin_email:
# If using OAuth, only fetch the primary admin email
all_user_emails = [self.primary_admin_email]
for email in all_user_emails:
logger.info(f"Fetching personal files for user: {email}")
user_drive_service = self.get_google_resource(user_email=email)
yield from get_files_in_my_drive(
service=user_drive_service,
email=email,
is_slim=is_slim,
start=start,
end=end,
)
def _extract_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
if self.creds is None:
raise PermissionError("Not logged into Google Drive")
service = discovery.build("drive", "v3", credentials=self.creds)
folder_ids: Sequence[str | None] = self._process_folder_paths(
service, self.folder_paths, self.include_shared, self.follow_shortcuts
)
if not folder_ids:
folder_ids = [None]
file_batches = chain(
*[
get_all_files_batched(
service=service,
continue_on_failure=self.continue_on_failure,
include_shared=self.include_shared,
follow_shortcuts=self.follow_shortcuts,
batch_size=self.batch_size,
time_range_start=start,
time_range_end=end,
folder_id=folder_id,
traverse_subfolders=True,
)
for folder_id in folder_ids
]
)
for files_batch in file_batches:
doc_batch = []
for file in files_batch:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
continue
if self.only_org_public:
if "permissions" not in file:
continue
if not any(
permission["type"] == "domain"
for permission in file["permissions"]
for file in self._fetch_drive_items(
is_slim=False,
start=start,
end=end,
):
continue
try:
text_contents = extract_text(file, service) or ""
except HttpError as e:
reason = (
e.error_details[0]["reason"]
if e.error_details
else e.reason
)
message = (
e.error_details[0]["message"]
if e.error_details
else e.reason
)
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
continue
raise
doc_batch.append(
Document(
id=file["webViewLink"],
sections=[
Section(link=file["webViewLink"], text=text_contents)
],
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(
file["modifiedTime"]
).astimezone(timezone.utc),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
)
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception(
"Ran into exception when pulling a file from Google Drive"
)
user_email = file.get("owners", [{}])[0].get("emailAddress")
service = self.get_google_resource(user_email=user_email)
if doc := convert_drive_item_to_document(
file=file,
service=service,
):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
yield from self._fetch_docs_from_drive()
try:
yield from self._extract_docs_from_google_drive()
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
# need to subtract 10 minutes from start time to account for modifiedTime
# propagation: if a document is modified, it takes some time for the API to
# reflect these changes; if we do not have an offset, then we may "miss" the
# update when polling
yield from self._fetch_docs_from_drive(start, end)
try:
yield from self._extract_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
if __name__ == "__main__":
import json
import os
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
if not service_account_json_path:
raise ValueError(
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
def _extract_slim_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_batch = []
for file in self._fetch_drive_items(
is_slim=True,
start=start,
end=end,
):
slim_batch.append(
SlimDocument(
id=file["webViewLink"],
perm_sync_data={
"doc_id": file.get("id"),
"permissions": file.get("permissions", []),
"permission_ids": file.get("permissionIds", []),
"name": file.get("name"),
"owner_email": file.get("owners", [{}])[0].get("emailAddress"),
},
)
with open(service_account_json_path) as f:
creds = json.load(f)
)
if len(slim_batch) >= SLIM_BATCH_SIZE:
yield slim_batch
slim_batch = []
yield slim_batch
credentials_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps(creds),
}
delegated_user = os.environ.get("GOOGLE_DRIVE_DELEGATED_USER")
if delegated_user:
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
connector = GoogleDriveConnector(include_shared=True, follow_shortcuts=True)
connector.load_credentials(credentials_dict)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print(document_batch)
break
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._extract_slim_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
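
For reference, a minimal usage sketch of the refactored connector (the constructor arguments are from this diff; the URLs, emails, and credentials_dict are placeholder assumptions):

from danswer.connectors.google_drive.connector import GoogleDriveConnector

connector = GoogleDriveConnector(
    include_shared_drives=True,
    # comma-separated; IDs are parsed from the last URL segment
    shared_drive_urls="https://drive.google.com/drive/folders/<drive_id>",
    include_my_drives=True,
    my_drive_emails="alice@example.com,bob@example.com",
)
connector.load_credentials(credentials_dict)  # OAuth or service account dict; see connector_auth below
for doc_batch in connector.load_from_state():
    for doc in doc_batch:
        print(doc.semantic_identifier)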

View File

@@ -8,24 +8,16 @@ from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from googleapiclient.discovery import build # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import BASE_SCOPES
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
@@ -36,15 +28,14 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def build_gdrive_scopes() -> list[str]:
base_scopes: list[str] = BASE_SCOPES
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
if ENTERPRISE_EDITION_ENABLED:
return base_scopes + permissions_scopes + groups_scopes
return base_scopes + permissions_scopes
GOOGLE_DRIVE_SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
]
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_drive_primary_admin"
def _build_frontend_google_drive_redirect() -> str:
@@ -52,7 +43,7 @@ def _build_frontend_google_drive_redirect() -> str:
def get_google_drive_creds_for_authorized_user(
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
token_json_str: str, scopes: list[str]
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
@@ -72,21 +63,15 @@ def get_google_drive_creds_for_authorized_user(
return None
def _get_google_drive_creds_for_service_account(
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=scopes
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def get_google_drive_creds(
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
credentials: dict[str, str], scopes: list[str] = GOOGLE_DRIVE_SCOPES
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going through
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
oauth_creds = None
service_creds = None
new_creds_dict = None
@@ -100,26 +85,27 @@ def get_google_drive_creds(
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
],
}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_creds = _get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str,
scopes=scopes,
elif KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY]
service_account_key = json.loads(service_account_key_json_str)
service_creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=scopes
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
service_creds = (
service_creds.with_subject(delegated_user_email)
if service_creds
else None
if not service_creds.valid or not service_creds.expired:
service_creds.refresh(Request())
if not service_creds.valid:
raise PermissionError(
"Unable to access Google Drive - service account credentials are invalid."
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
@@ -146,7 +132,7 @@ def get_auth_url(credential_id: int) -> str:
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=build_gdrive_scopes(),
scopes=GOOGLE_DRIVE_SCOPES,
redirect_uri=_build_frontend_google_drive_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
@@ -169,13 +155,34 @@ def update_credential_access_tokens(
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=build_gdrive_scopes(),
scopes=GOOGLE_DRIVE_SCOPES,
redirect_uri=_build_frontend_google_drive_redirect(),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
# Get user email from Google API so we know who
# the primary admin is for this connector
try:
admin_service = build("drive", "v3", credentials=creds)
user_info = (
admin_service.about()
.get(
fields="user(emailAddress)",
)
.execute()
)
email = user_info.get("user", {}).get("emailAddress")
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
@@ -184,15 +191,15 @@ def update_credential_access_tokens(
def build_service_account_creds(
source: DocumentSource,
delegated_user_email: str | None = None,
primary_admin_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key()
credential_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if delegated_user_email:
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
if primary_admin_email:
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
return CredentialBase(
credential_json=credential_dict,
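
To make the new credential shapes concrete, here is a sketch of the two payloads get_google_drive_creds now accepts (the key constants are from this diff; the values are placeholders):

oauth_credentials = {
    DB_CREDENTIALS_DICT_TOKEN_KEY: "<OAuth token JSON string>",
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com",
}
service_account_credentials = {
    KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: "<service account key JSON string>",
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com",
}
creds, maybe_refreshed_dict = get_google_drive_creds(oauth_credentials)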

View File

@@ -1,7 +1,36 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]
FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
"shortcutDetails, owners(emailAddress))"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
"permissionIds, webViewLink, owners(emailAddress))"
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
# Documentation and error messages
SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview"
ONYX_SCOPE_INSTRUCTIONS = (
"You have upgraded Danswer without updating the Google Drive scopes. "
f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
)
# Batch sizes
SLIM_BATCH_SIZE = 500

View File

@@ -0,0 +1,115 @@
import io
from datetime import datetime
from datetime import timezone
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from danswer.connectors.google_drive.constants import ERRORS_TO_CONTINUE_ON
from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from danswer.connectors.google_drive.models import GDriveMimeType
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _extract_text(file: dict[str, str], service: Resource) -> str:
mime_type = file["mimeType"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title; finding it this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
)
if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT
def convert_drive_item_to_document(
file: GoogleDriveFileType, service: Resource
) -> Document | None:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
return None
try:
text_contents = _extract_text(file, service) or ""
except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
return None
raise
return Document(
id=file["webViewLink"],
sections=[Section(link=file["webViewLink"], text=text_contents)],
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
except Exception as e:
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise e
logger.exception("Ran into exception when pulling a file from Google Drive")
return None
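
A sketch of how this helper is consumed by the connector above (assuming an authenticated Drive Resource and a file dict fetched with FILE_FIELDS):

doc = convert_drive_item_to_document(file=file, service=drive_service)
# None means the file was a shortcut, hit a skippable 403 export error,
# or failed while CONTINUE_ON_CONNECTOR_FAILURE was enabled.
if doc is not None:
    doc_batch.append(doc)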

View File

@@ -0,0 +1,192 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from googleapiclient.discovery import Resource # type: ignore
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from danswer.connectors.google_drive.constants import FILE_FIELDS
from danswer.connectors.google_drive.constants import FOLDER_FIELDS
from danswer.connectors.google_drive.constants import SLIM_FILE_FIELDS
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _generate_time_range_filter(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
time_range_filter += f" and modifiedTime >= '{time_start}'"
if end is not None:
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
time_range_filter += f" and modifiedTime <= '{time_stop}'"
return time_range_filter
def _get_folders_in_parent(
service: Resource,
parent_id: str | None = None,
personal_drive: bool = False,
) -> Iterator[GoogleDriveFileType]:
# Follow shortcuts to folders
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
query += " and trashed = false"
if parent_id:
query += f" and '{parent_id}' in parents"
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user" if personal_drive else "allDrives",
supportsAllDrives=not personal_drive,
includeItemsFromAllDrives=not personal_drive,
fields=FOLDER_FIELDS,
q=query,
):
yield file
def _get_files_in_parent(
service: Resource,
parent_id: str,
personal_drive: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
is_slim: bool = False,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user" if personal_drive else "allDrives",
supportsAllDrives=not personal_drive,
includeItemsFromAllDrives=not personal_drive,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
yield file
def crawl_folders_for_files(
service: Resource,
parent_id: str,
personal_drive: bool,
traversed_parent_ids: set[str],
update_traversed_ids_func: Callable[[str], None],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
"""
This function starts crawling from any folder. It is slower, though.
"""
if parent_id in traversed_parent_ids:
print(f"Skipping subfolder since already traversed: {parent_id}")
return
update_traversed_ids_func(parent_id)
yield from _get_files_in_parent(
service=service,
personal_drive=personal_drive,
start=start,
end=end,
parent_id=parent_id,
)
for subfolder in _get_folders_in_parent(
service=service,
parent_id=parent_id,
personal_drive=personal_drive,
):
logger.info("Fetching all files in subfolder: " + subfolder["name"])
yield from crawl_folders_for_files(
service=service,
parent_id=subfolder["id"],
personal_drive=personal_drive,
traversed_parent_ids=traversed_parent_ids,
update_traversed_ids_func=update_traversed_ids_func,
start=start,
end=end,
)
def get_files_in_shared_drive(
service: Resource,
drive_id: str,
is_slim: bool = False,
cache_folders: bool = True,
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to crawl folders later, we can cache the folders here
if cache_folders:
# Get all folders being queried and add them to the traversed set
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields="nextPageToken, files(id)",
q=query,
):
update_traversed_ids_func(file["id"])
# Get all files in the shared drive
query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
yield file
def get_files_in_my_drive(
service: Resource,
email: str,
is_slim: bool = False,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{email}' in owners"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
yield file
# Just in case we need to get the root folder id
def get_root_folder_id(service: Resource) -> str:
# we don't paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return service.files().get(fileId="root", fields="id").execute()["id"]
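
As an illustration of the query fragments built above, _generate_time_range_filter renders an RFC 3339 window that is appended to the Drive q parameter (example timestamps assumed):

fragment = _generate_time_range_filter(start=1700000000, end=1700086400)
# fragment == (
#     " and modifiedTime >= '2023-11-14T22:13:20Z'"
#     " and modifiedTime <= '2023-11-15T22:13:20Z'"
# )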

View File

@@ -0,0 +1,35 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.utils.retry_wrapper import retry_builder
# Google Drive APIs are quite flaky and may 500 for an
# extended period of time. We try to combat this by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
def execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType]:
"""Execute a paginated retrieval from Google Drive API
Args:
retrieval_function: The specific list function to call (e.g., service.files().list)
**kwargs: Arguments to pass to the list function
"""
next_page_token = ""
while next_page_token is not None:
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs["pageToken"] = next_page_token
results = add_retries(lambda: retrieval_function(**request_kwargs).execute())()
next_page_token = results.get("nextPageToken")
for item in results.get(list_key, []):
yield item
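
A minimal usage sketch (assuming an authenticated googleapiclient Resource named service; note that fields must request nextPageToken for pagination to continue):

for file in execute_paginated_retrieval(
    retrieval_function=service.files().list,
    list_key="files",
    corpora="user",
    fields="nextPageToken, files(id, name)",
    q="trashed = false",
):
    print(file["id"], file["name"])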

View File

@@ -0,0 +1,18 @@
from enum import Enum
from typing import Any
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
GoogleDriveFileType = dict[str, Any]

View File

@@ -56,7 +56,11 @@ class PollConnector(BaseConnector):
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
raise NotImplementedError
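
A sketch of calling the widened interface (the connector instance is a hypothetical SlimConnector implementation; the epoch-seconds window is optional, matching the Confluence, Salesforce, and Slack updates in this commit):

import time

end = time.time()
start = end - 86400  # last 24 hours
for slim_batch in connector.retrieve_all_slim_documents(start=start, end=end):
    for slim_doc in slim_batch:
        print(slim_doc.id)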

View File

@@ -251,7 +251,11 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
doc_metadata_list: list[SlimDocument] = []

View File

@@ -391,7 +391,11 @@ class SlackPollConnector(PollConnector, SlimConnector):
self.client = WebClient(token=bot_token)
return None
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")

View File

@@ -10,12 +10,10 @@ from sqlalchemy.sql.expression import or_
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import Credential__UserGroup
@@ -442,7 +440,7 @@ def delete_google_drive_service_account_credentials(
) -> None:
credentials = fetch_credentials(db_session=db_session, user=user)
for credential in credentials:
if credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY):
if credential.credential_json.get(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY):
db_session.delete(credential)
db_session.commit()

View File

@@ -9,6 +9,7 @@ from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
from google.oauth2.credentials import Credentials # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -35,6 +36,7 @@ from danswer.connectors.gmail.connector_auth import (
)
from danswer.connectors.gmail.connector_auth import upsert_google_app_gmail_cred
from danswer.connectors.google_drive.connector_auth import build_service_account_creds
from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.connector_auth import delete_google_app_cred
from danswer.connectors.google_drive.connector_auth import delete_service_account_key
from danswer.connectors.google_drive.connector_auth import get_auth_url
@@ -43,13 +45,13 @@ from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds_for_authorized_user,
)
from danswer.connectors.google_drive.connector_auth import get_service_account_key
from danswer.connectors.google_drive.connector_auth import GOOGLE_DRIVE_SCOPES
from danswer.connectors.google_drive.connector_auth import (
update_credential_access_tokens,
)
from danswer.connectors.google_drive.connector_auth import upsert_google_app_cred
from danswer.connectors.google_drive.connector_auth import upsert_service_account_key
from danswer.connectors.google_drive.connector_auth import verify_csrf
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.db.connector import create_connector
from danswer.db.connector import delete_connector
from danswer.db.connector import fetch_connector_by_id
@@ -294,7 +296,7 @@ def upsert_service_account_credential(
try:
credential_base = build_service_account_creds(
DocumentSource.GOOGLE_DRIVE,
delegated_user_email=service_account_credential_request.google_drive_delegated_user,
primary_admin_email=service_account_credential_request.google_drive_primary_admin,
)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -320,7 +322,7 @@ def upsert_gmail_service_account_credential(
try:
credential_base = build_service_account_creds(
DocumentSource.GMAIL,
delegated_user_email=service_account_credential_request.gmail_delegated_user,
primary_admin_email=service_account_credential_request.gmail_delegated_user,
)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -348,27 +350,14 @@ def check_drive_tokens(
return AuthStatus(authenticated=False)
token_json_str = str(db_credentials.credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
google_drive_creds = get_google_drive_creds_for_authorized_user(
token_json_str=token_json_str
token_json_str=token_json_str,
scopes=GOOGLE_DRIVE_SCOPES,
)
if google_drive_creds is None:
return AuthStatus(authenticated=False)
return AuthStatus(authenticated=True)
@router.get("/admin/connector/google-drive/authorize/{credential_id}")
def admin_google_drive_auth(
response: Response, credential_id: str, _: User = Depends(current_admin_user)
) -> AuthUrl:
# set a cookie that we can read in the callback (used for `verify_csrf`)
response.set_cookie(
key=_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME,
value=credential_id,
httponly=True,
max_age=600,
)
return AuthUrl(auth_url=get_auth_url(credential_id=int(credential_id)))
@router.post("/admin/connector/file/upload")
def upload_files(
files: list[UploadFile],
@@ -951,10 +940,11 @@ def google_drive_callback(
)
credential_id = int(credential_id_cookie)
verify_csrf(credential_id, callback.state)
if (
update_credential_access_tokens(callback.code, credential_id, user, db_session)
is None
):
credentials: Credentials | None = update_credential_access_tokens(
callback.code, credential_id, user, db_session
)
if credentials is None:
raise HTTPException(
status_code=500, detail="Unable to fetch Google Drive access tokens"
)

View File

@@ -377,16 +377,16 @@ class GoogleServiceAccountKey(BaseModel):
class GoogleServiceAccountCredentialRequest(BaseModel):
google_drive_delegated_user: str | None = None # email of user to impersonate
google_drive_primary_admin: str | None = None # email of user to impersonate
gmail_delegated_user: str | None = None # email of user to impersonate
@model_validator(mode="after")
def check_user_delegation(self) -> "GoogleServiceAccountCredentialRequest":
if (self.google_drive_delegated_user is None) == (
if (self.google_drive_primary_admin is None) == (
self.gmail_delegated_user is None
):
raise ValueError(
"Exactly one of google_drive_delegated_user or gmail_delegated_user must be set"
"Exactly one of google_drive_primary_admin or gmail_delegated_user must be set"
)
return self

View File

@@ -13,12 +13,12 @@ tasks_to_schedule = [
{
"name": "sync-external-doc-permissions",
"task": "check_sync_external_doc_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
"schedule": timedelta(seconds=30), # TODO: optimize this
},
{
"name": "sync-external-group-permissions",
"task": "check_sync_external_group_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
"schedule": timedelta(seconds=60), # TODO: optimize this
},
{
"name": "autogenerate_usage_report",

View File

@@ -1,144 +1,119 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds,
)
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import InputType
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.models import SlimDocument
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
from ee.danswer.db.document import upsert_document_external_perms__no_commit
# Google Drive APIs are quite flaky and may 500 for an
# extended period of time. Trying to combat that here by
# retrying with backoff (tries=5, delay=5s, max_delay=30s).
add_retries = retry_builder(tries=5, delay=5, max_delay=30)
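# Illustration (not part of this module): add_retries wraps a zero-argument
# callable and re-invokes it with backoff on transient failures, e.g.
#   fetch = add_retries(lambda: drive_service.files().get(fileId=file_id).execute())
#   file_meta = fetch()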
logger = setup_logger()
_PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
def _get_docs_with_additional_info(
db_session: Session,
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
) -> dict[str, Any]:
# Get all document ids that need their permissions updated
runnable_connector = instantiate_connector(
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.POLL,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
assert isinstance(runnable_connector, PollConnector)
google_drive_connector: GoogleDriveConnector,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
if cc_pair.last_time_perm_sync
else 0.0
)
cc_pair.last_time_perm_sync = current_time
doc_batch_generator = runnable_connector.poll_source(
return google_drive_connector.retrieve_all_slim_documents(
start=start_time, end=current_time.timestamp()
)
docs_with_additional_info = {
doc.id: doc.additional_info
for doc_batch in doc_batch_generator
for doc in doc_batch
}
return docs_with_additional_info
def _fetch_permissions_for_permission_ids(
google_drive_connector: GoogleDriveConnector,
permission_ids: list[str],
permission_info: dict[str, Any],
) -> list[dict[str, Any]]:
doc_id = permission_info.get("doc_id")
if not permission_info or not doc_id:
return []
# Check cache first for all permission IDs
permissions = [
_PERMISSION_ID_PERMISSION_MAP[pid]
for pid in permission_ids
if pid in _PERMISSION_ID_PERMISSION_MAP
]
def _fetch_permissions_paginated(
drive_service: Any, drive_file_id: str
) -> Iterator[dict[str, Any]]:
next_token = None
# If we found all permissions in cache, return them
if len(permissions) == len(permission_ids):
return permissions
# Get paginated permissions for the file id
while True:
try:
permissions_resp: dict[str, Any] = add_retries(
lambda: (
drive_service.permissions()
.list(
fileId=drive_file_id,
fields="permissions(emailAddress, type, domain)",
owner_email = permission_info.get("owner_email")
drive_service = google_drive_connector.get_google_resource(user_email=owner_email)
# Otherwise, fetch all permissions and update cache
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
fileId=doc_id,
fields="permissions(id, emailAddress, type, domain)",
supportsAllDrives=True,
pageToken=next_token,
)
.execute()
)
)()
except HttpError as e:
if e.resp.status == 404:
logger.warning(f"Document with id {drive_file_id} not found: {e}")
break
elif e.resp.status == 403:
logger.warning(
f"Access denied for retrieving document permissions: {e}"
)
break
else:
logger.error(f"Failed to fetch permissions: {e}")
raise
for permission in permissions_resp.get("permissions", []):
yield permission
permissions_for_doc_id = []
# Update cache and return all permissions
for permission in fetched_permissions:
permissions_for_doc_id.append(permission)
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
next_token = permissions_resp.get("nextPageToken")
if not next_token:
break
return permissions_for_doc_id
def _fetch_google_permissions_for_document_id(
db_session: Session,
drive_file_id: str,
credentials_json: dict[str, str],
company_google_domains: list[str],
def _get_permissions_from_slim_doc(
google_drive_connector: GoogleDriveConnector,
slim_doc: SlimDocument,
) -> ExternalAccess:
# Authenticate and construct service
google_drive_creds, _ = get_google_drive_creds(
credentials_json,
permission_info = slim_doc.perm_sync_data or {}
permissions_list = permission_info.get("permissions", [])
if not permissions_list:
if permission_ids := permission_info.get("permission_ids"):
permissions_list = _fetch_permissions_for_permission_ids(
google_drive_connector=google_drive_connector,
permission_ids=permission_ids,
permission_info=permission_info,
)
if not permissions_list:
logger.warning(f"No permissions found for document {slim_doc.id}")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
)
if not google_drive_creds.valid:
raise ValueError("Invalid Google Drive credentials")
drive_service = build("drive", "v3", credentials=google_drive_creds)
company_domain = google_drive_connector.google_domain
user_emails: set[str] = set()
group_emails: set[str] = set()
public = False
for permission in _fetch_permissions_paginated(drive_service, drive_file_id):
for permission in permissions_list:
permission_type = permission["type"]
if permission_type == "user":
user_emails.add(permission["emailAddress"])
elif permission_type == "group":
group_emails.add(permission["emailAddress"])
elif permission_type == "domain":
if permission["domain"] in company_google_domains:
elif permission_type == "domain" and company_domain:
if permission["domain"] == company_domain:
public = True
elif permission_type == "anyone":
public = True
batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_emails,
@@ -156,32 +131,26 @@ def gdrive_doc_sync(
it in postgres so that when it gets created later, the permissions are
already populated
"""
sync_details = cc_pair.auto_sync_options
if sync_details is None:
logger.error("Sync details not found for Google Drive")
raise ValueError("Sync details not found for Google Drive")
# Here we run the connector to grab all the ids
# this may grab ids before they are indexed but that is fine because
# we create a document in postgres to hold the permissions info
# until the indexing job has a chance to run
docs_with_additional_info = _get_docs_with_additional_info(
db_session=db_session,
cc_pair=cc_pair,
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
for doc_id, doc_additional_info in docs_with_additional_info.items():
ext_access = _fetch_google_permissions_for_document_id(
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,
)
batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session,
drive_file_id=doc_additional_info,
credentials_json=cc_pair.credential.credential_json,
company_google_domains=[
cast(dict[str, str], sync_details)["company_domain"]
],
emails=list(ext_access.external_user_emails),
)
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=doc_id,
doc_id=slim_doc.id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)
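
For orientation, the shape of perm_sync_data that _get_permissions_from_slim_doc consumes, reconstructed from the keys accessed above (values are placeholders; the exact payload is produced by the connector and isn't part of this hunk):

perm_sync_data = {
    "doc_id": "1d3I7U3vUZ...",             # Drive file id
    "owner_email": "admin@onyx-test.com",  # picks the user to impersonate
    "name": "file_0.txt",                  # used by the permission tests below
    # either fully hydrated permissions...
    "permissions": [{"type": "user", "emailAddress": "test_user_1@onyx-test.com"}],
    # ...or bare ids, resolved (and cached) via the Drive permissions API
    "permission_ids": ["12345678901234567890"],
}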

View File

@@ -1,136 +1,48 @@
from collections.abc import Iterator
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from sqlalchemy.orm import Session
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds,
)
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
logger = setup_logger()
# Google Drive APIs are quite flaky and may 500 for an
# extended period of time. Trying to combat that here by
# retrying with backoff (tries=5, delay=5s, max_delay=30s).
add_retries = retry_builder(tries=5, delay=5, max_delay=30)
def _fetch_groups_paginated(
google_drive_creds: ServiceAccountCredentials | OAuthCredentials,
identity_source: str | None = None,
customer_id: str | None = None,
) -> Iterator[dict[str, Any]]:
# Note that Google Drive does not use or update the user_cache, as the user email
# comes directly with the call to fetch the groups; therefore this is not a valid
# place to save on requests
if identity_source is None and customer_id is None:
raise ValueError(
"Either identity_source or customer_id must be provided to fetch groups"
)
cloud_identity_service = build(
"cloudidentity", "v1", credentials=google_drive_creds
)
parent = (
f"identitysources/{identity_source}"
if identity_source
else f"customers/{customer_id}"
)
while True:
try:
groups_resp: dict[str, Any] = add_retries(
lambda: (cloud_identity_service.groups().list(parent=parent).execute())
)()
for group in groups_resp.get("groups", []):
yield group
next_token = groups_resp.get("nextPageToken")
if not next_token:
break
except HttpError as e:
if e.resp.status == 404 or e.resp.status == 403:
break
logger.error(f"Error fetching groups: {e}")
raise
def _fetch_group_members_paginated(
google_drive_creds: ServiceAccountCredentials | OAuthCredentials,
group_name: str,
) -> Iterator[dict[str, Any]]:
cloud_identity_service = build(
"cloudidentity", "v1", credentials=google_drive_creds
)
next_token = None
while True:
try:
membership_info = add_retries(
lambda: (
cloud_identity_service.groups()
.memberships()
.searchTransitiveMemberships(
parent=group_name, pageToken=next_token
)
.execute()
)
)()
for member in membership_info.get("memberships", []):
yield member
next_token = membership_info.get("nextPageToken")
if not next_token:
break
except HttpError as e:
if e.resp.status == 404 or e.resp.status == 403:
break
logger.error(f"Error fetching group members: {e}")
raise
def gdrive_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> None:
sync_details = cc_pair.auto_sync_options
if sync_details is None:
logger.error("Sync details not found for Google Drive")
raise ValueError("Sync details not found for Google Drive")
google_drive_creds, _ = get_google_drive_creds(
cc_pair.credential.credential_json,
scopes=FETCH_GROUPS_SCOPES,
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
admin_service = google_drive_connector.get_google_resource("admin", "directory_v1")
danswer_groups: list[ExternalUserGroup] = []
for group in _fetch_groups_paginated(
google_drive_creds,
identity_source=sync_details.get("identity_source"),
customer_id=sync_details.get("customer_id"),
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_drive_connector.google_domain,
fields="groups(email)",
):
# The id is the group email
group_email = group["groupKey"]["id"]
group_email = group["email"]
# Gather group member emails
group_member_emails: list[str] = []
for member in _fetch_group_members_paginated(google_drive_creds, group["name"]):
member_keys = member["preferredMemberKey"]
member_emails = [member_key["id"] for member_key in member_keys]
for member_email in member_emails:
group_member_emails.append(member_email)
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.append(member["email"])
# Add group members to DB and get their IDs
group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=group_member_emails
)
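
execute_paginated_retrieval is relied on heavily here but its body isn't part of this diff; a plausible sketch consistent with the call sites (a retrieval function, the response key to iterate, and passthrough query kwargs):

from collections.abc import Callable, Iterator
from typing import Any


def execute_paginated_retrieval(
    retrieval_function: Callable,
    list_key: str,
    **kwargs: Any,
) -> Iterator[dict[str, Any]]:
    next_page_token: str | None = None
    while True:
        # only pass pageToken once the API has handed one back
        page_kwargs = (
            {**kwargs, "pageToken": next_page_token} if next_page_token else kwargs
        )
        response = retrieval_function(**page_kwargs).execute()
        yield from response.get(list_key, [])
        next_page_token = response.get("nextPageToken")
        if not next_page_token:
            break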

View File

@@ -59,6 +59,7 @@ def run_external_doc_permission_sync(
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
last_time_perm_sync = cc_pair.last_time_perm_sync
if doc_sync_func is None:
raise ValueError(
@@ -110,4 +111,5 @@ def run_external_doc_permission_sync(
logger.info(f"Successfully synced docs for {source_type}")
except Exception:
logger.exception("Error Syncing Document Permissions")
cc_pair.last_time_perm_sync = last_time_perm_sync
db_session.rollback()

View File

@@ -0,0 +1,98 @@
import json
import os
from collections.abc import Callable
import pytest
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.connector_auth import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, env_file)
try:
with open(env_path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
os.environ[key] = value.strip()
print("Successfully loaded environment variables")
except FileNotFoundError:
print(f"File {env_file} not found")
# Load environment variables at the module level
load_env_vars()
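
The loader expects plain KEY=value lines; an example .env alongside the conftest (values are placeholders):

# tests/daily/connectors/google_drive/.env
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR={"type": "service_account", "...": "..."}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR={"token": "...", "refresh_token": "..."}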
@pytest.fixture
def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector]:
def _connector_factory(
primary_admin_email: str = "admin@onyx-test.com",
include_shared_drives: bool = True,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
) -> GoogleDriveConnector:
connector = GoogleDriveConnector(
include_shared_drives=include_shared_drives,
shared_drive_urls=shared_drive_urls,
include_my_drives=include_my_drives,
my_drive_emails=my_drive_emails,
shared_folder_urls=shared_folder_urls,
)
json_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
refried_json_string = json.loads(json_string)
credentials_json = {
DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email,
}
connector.load_credentials(credentials_json)
return connector
return _connector_factory
@pytest.fixture
def google_drive_service_acct_connector_factory() -> (
Callable[..., GoogleDriveConnector]
):
def _connector_factory(
primary_admin_email: str = "admin@onyx-test.com",
include_shared_drives: bool = True,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
) -> GoogleDriveConnector:
print("Creating GoogleDriveConnector with service account credentials")
connector = GoogleDriveConnector(
include_shared_drives=include_shared_drives,
shared_drive_urls=shared_drive_urls,
include_my_drives=include_my_drives,
my_drive_emails=my_drive_emails,
shared_folder_urls=shared_folder_urls,
)
json_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
refried_json_string = json.loads(json_string)
# Load Service Account Credentials
connector.load_credentials(
{
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: refried_json_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email,
}
)
return connector
return _connector_factory

View File

@@ -0,0 +1,164 @@
from collections.abc import Sequence
from danswer.connectors.models import Document
ALL_FILES = list(range(0, 60))
SHARED_DRIVE_FILES = list(range(20, 25))
_ADMIN_FILE_IDS = list(range(0, 5))
_TEST_USER_1_FILE_IDS = list(range(5, 10))
_TEST_USER_2_FILE_IDS = list(range(10, 15))
_TEST_USER_3_FILE_IDS = list(range(15, 20))
_SHARED_DRIVE_1_FILE_IDS = list(range(20, 25))
_FOLDER_1_FILE_IDS = list(range(25, 30))
_FOLDER_1_1_FILE_IDS = list(range(30, 35))
_FOLDER_1_2_FILE_IDS = list(range(35, 40))
_SHARED_DRIVE_2_FILE_IDS = list(range(40, 45))
_FOLDER_2_FILE_IDS = list(range(45, 50))
_FOLDER_2_1_FILE_IDS = list(range(50, 55))
_FOLDER_2_2_FILE_IDS = list(range(55, 60))
_PUBLIC_FOLDER_RANGE = _FOLDER_1_2_FILE_IDS
_PUBLIC_FILE_IDS = list(range(55, 57))
PUBLIC_RANGE = _PUBLIC_FOLDER_RANGE + _PUBLIC_FILE_IDS
_SHARED_DRIVE_1_URL = "https://drive.google.com/drive/folders/0AC_OJ4BkMd4kUk9PVA"
# Group 1 is given access to this folder
_FOLDER_1_URL = (
"https://drive.google.com/drive/folders/1d3I7U3vUZMDziF1OQqYRkB8Jp2s_GWUn"
)
_FOLDER_1_1_URL = (
"https://drive.google.com/drive/folders/1aR33-zwzl_mnRAwH55GgtWTE-4A4yWWI"
)
_FOLDER_1_2_URL = (
"https://drive.google.com/drive/folders/1IO0X55VhvLXf4mdxzHxuKf4wxrDBB6jq"
)
_SHARED_DRIVE_2_URL = "https://drive.google.com/drive/folders/0ABKspIh7P4f4Uk9PVA"
_FOLDER_2_URL = (
"https://drive.google.com/drive/folders/1lNpCJ1teu8Se0louwL0oOHK9nEalskof"
)
_FOLDER_2_1_URL = (
"https://drive.google.com/drive/folders/1XeDOMWwxTDiVr9Ig2gKum3Zq_Wivv6zY"
)
_FOLDER_2_2_URL = (
"https://drive.google.com/drive/folders/1RKlsexA8h7NHvBAWRbU27MJotic7KXe3"
)
_ADMIN_EMAIL = "admin@onyx-test.com"
_TEST_USER_1_EMAIL = "test_user_1@onyx-test.com"
_TEST_USER_2_EMAIL = "test_user_2@onyx-test.com"
_TEST_USER_3_EMAIL = "test_user_3@onyx-test.com"
# Dictionary for ranges
DRIVE_ID_MAPPING: dict[str, list[int]] = {
"ADMIN": _ADMIN_FILE_IDS,
"TEST_USER_1": _TEST_USER_1_FILE_IDS,
"TEST_USER_2": _TEST_USER_2_FILE_IDS,
"TEST_USER_3": _TEST_USER_3_FILE_IDS,
"SHARED_DRIVE_1": _SHARED_DRIVE_1_FILE_IDS,
"FOLDER_1": _FOLDER_1_FILE_IDS,
"FOLDER_1_1": _FOLDER_1_1_FILE_IDS,
"FOLDER_1_2": _FOLDER_1_2_FILE_IDS,
"SHARED_DRIVE_2": _SHARED_DRIVE_2_FILE_IDS,
"FOLDER_2": _FOLDER_2_FILE_IDS,
"FOLDER_2_1": _FOLDER_2_1_FILE_IDS,
"FOLDER_2_2": _FOLDER_2_2_FILE_IDS,
}
# Dictionary for emails
EMAIL_MAPPING: dict[str, str] = {
"ADMIN": _ADMIN_EMAIL,
"TEST_USER_1": _TEST_USER_1_EMAIL,
"TEST_USER_2": _TEST_USER_2_EMAIL,
"TEST_USER_3": _TEST_USER_3_EMAIL,
}
# Dictionary for URLs
URL_MAPPING: dict[str, str] = {
"SHARED_DRIVE_1": _SHARED_DRIVE_1_URL,
"FOLDER_1": _FOLDER_1_URL,
"FOLDER_1_1": _FOLDER_1_1_URL,
"FOLDER_1_2": _FOLDER_1_2_URL,
"SHARED_DRIVE_2": _SHARED_DRIVE_2_URL,
"FOLDER_2": _FOLDER_2_URL,
"FOLDER_2_1": _FOLDER_2_1_URL,
"FOLDER_2_2": _FOLDER_2_2_URL,
}
# Dictionary for access permissions
# All users have access to their own My Drive as well as public files
ACCESS_MAPPING: dict[str, list[int]] = {
# Admin has access to everything in the shared drives
"ADMIN": (
_ADMIN_FILE_IDS
+ _SHARED_DRIVE_1_FILE_IDS
+ _FOLDER_1_FILE_IDS
+ _FOLDER_1_1_FILE_IDS
+ _FOLDER_1_2_FILE_IDS
+ _SHARED_DRIVE_2_FILE_IDS
+ _FOLDER_2_FILE_IDS
+ _FOLDER_2_1_FILE_IDS
+ _FOLDER_2_2_FILE_IDS
),
# This user has access to drive 1
# This user has redundant access to folder 1 because of group access
# This user has been given individual access to files in Admin's My Drive
"TEST_USER_1": (
_TEST_USER_1_FILE_IDS
+ _SHARED_DRIVE_1_FILE_IDS
+ _FOLDER_1_FILE_IDS
+ _FOLDER_1_1_FILE_IDS
+ _FOLDER_1_2_FILE_IDS
+ list(range(0, 2))
),
# Group 1 includes this user, giving access to folder 1
# This user has also been given access to folder 2-1
# This user has also been given individual access to files in folder 2
"TEST_USER_2": (
_TEST_USER_2_FILE_IDS
+ _FOLDER_1_FILE_IDS
+ _FOLDER_1_1_FILE_IDS
+ _FOLDER_1_2_FILE_IDS
+ _FOLDER_2_1_FILE_IDS
+ list(range(45, 47))
),
# This user can only see their own files and public files
"TEST_USER_3": _TEST_USER_3_FILE_IDS,
}
file_name_template = "file_{}.txt"
file_text_template = "This is file {}"
def print_discrepencies(expected: set[str], retrieved: set[str]) -> None:
if expected != retrieved:
print("Expected:")
print(expected)
print("Retrieved:")
print(retrieved)
print("Extra:")
print(retrieved - expected)
print("Missing:")
print(expected - retrieved)
def assert_retrieved_docs_match_expected(
retrieved_docs: list[Document], expected_file_ids: Sequence[int]
) -> None:
expected_file_names = {
file_name_template.format(file_id) for file_id in expected_file_ids
}
expected_file_texts = {
file_text_template.format(file_id) for file_id in expected_file_ids
}
retrieved_file_names = set([doc.semantic_identifier for doc in retrieved_docs])
retrieved_texts = set([doc.sections[0].text for doc in retrieved_docs])
# Check file names
print_discrepencies(expected_file_names, retrieved_file_names)
assert expected_file_names == retrieved_file_names
# Check file texts
print_discrepencies(expected_file_texts, retrieved_texts)
assert expected_file_texts == retrieved_texts

View File

@@ -0,0 +1,246 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.models import Document
from tests.daily.connectors.google_drive.helpers import (
assert_retrieved_docs_match_expected,
)
from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING
from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING
from tests.daily.connectors.google_drive.helpers import URL_MAPPING
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_include_all(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_include_all")
connector = google_drive_oauth_connector_factory(
include_shared_drives=True,
include_my_drives=True,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# With OAuth, should get everything in the shared drives plus the admin's My Drive
expected_file_ids = (
DRIVE_ID_MAPPING["ADMIN"]
+ DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
+ DRIVE_ID_MAPPING["FOLDER_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
+ DRIVE_ID_MAPPING["SHARED_DRIVE_2"]
+ DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_include_shared_drives_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_include_shared_drives_only")
connector = google_drive_oauth_connector_factory(
include_shared_drives=True,
include_my_drives=False,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# Should only get shared drives
expected_file_ids = (
DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
+ DRIVE_ID_MAPPING["FOLDER_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
+ DRIVE_ID_MAPPING["SHARED_DRIVE_2"]
+ DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_include_my_drives_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_include_my_drives_only")
connector = google_drive_oauth_connector_factory(
include_shared_drives=False,
include_my_drives=True,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# With OAuth, should only get the admin's My Drive
expected_file_ids = DRIVE_ID_MAPPING["ADMIN"]
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_drive_one_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_drive_one_only")
drive_urls = [
URL_MAPPING["SHARED_DRIVE_1"],
]
connector = google_drive_oauth_connector_factory(
include_shared_drives=True,
include_my_drives=False,
shared_drive_urls=",".join([str(url) for url in drive_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# shared_drive_urls is only honored when include_shared_drives is True,
# so this should only get drive one and the folders it contains
expected_file_ids = (
DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
+ DRIVE_ID_MAPPING["FOLDER_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_folder_and_shared_drive(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_folder_and_shared_drive")
drive_urls = [URL_MAPPING["SHARED_DRIVE_1"]]
folder_urls = [URL_MAPPING["FOLDER_2"]]
connector = google_drive_oauth_connector_factory(
include_shared_drives=True,
include_my_drives=True,
shared_drive_urls=",".join([str(url) for url in drive_urls]),
shared_folder_urls=",".join([str(url) for url in folder_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# Should get drive 1 (including its folders), folder 2 (including its
# subfolders), and the admin's My Drive
expected_file_ids = (
DRIVE_ID_MAPPING["ADMIN"]
+ DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
+ DRIVE_ID_MAPPING["FOLDER_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
+ DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_folders_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_folders_only")
folder_urls = [
URL_MAPPING["FOLDER_1_1"],
URL_MAPPING["FOLDER_1_2"],
URL_MAPPING["FOLDER_2_1"],
URL_MAPPING["FOLDER_2_2"],
]
connector = google_drive_oauth_connector_factory(
include_shared_drives=False,
include_my_drives=False,
shared_folder_urls=",".join([str(url) for url in folder_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
expected_file_ids = (
DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_specific_emails(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_specific_emails")
my_drive_emails = [
EMAIL_MAPPING["TEST_USER_1"],
EMAIL_MAPPING["TEST_USER_3"],
]
connector = google_drive_oauth_connector_factory(
include_shared_drives=False,
include_my_drives=True,
my_drive_emails=",".join([str(email) for email in my_drive_emails]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# No matter who is specified, when using oauth, if include_my_drives is True,
# we will get all the files from the admin's My Drive
expected_file_ids = DRIVE_ID_MAPPING["ADMIN"]
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)

View File

@@ -0,0 +1,257 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.models import Document
from tests.daily.connectors.google_drive.helpers import (
assert_retrieved_docs_match_expected,
)
from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING
from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING
from tests.daily.connectors.google_drive.helpers import URL_MAPPING
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_include_all(
mock_get_api_key: MagicMock,
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_include_all")
connector = google_drive_service_acct_connector_factory(
include_shared_drives=True,
include_my_drives=True,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# Should get everything
expected_file_ids = (
DRIVE_ID_MAPPING["ADMIN"]
+ DRIVE_ID_MAPPING["TEST_USER_1"]
+ DRIVE_ID_MAPPING["TEST_USER_2"]
+ DRIVE_ID_MAPPING["TEST_USER_3"]
+ DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
+ DRIVE_ID_MAPPING["FOLDER_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
+ DRIVE_ID_MAPPING["SHARED_DRIVE_2"]
+ DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_include_shared_drives_only(
mock_get_api_key: MagicMock,
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_include_shared_drives_only")
connector = google_drive_service_acct_connector_factory(
include_shared_drives=True,
include_my_drives=False,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# Should only get shared drives
expected_file_ids = (
DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
+ DRIVE_ID_MAPPING["FOLDER_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
+ DRIVE_ID_MAPPING["SHARED_DRIVE_2"]
+ DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_include_my_drives_only(
mock_get_api_key: MagicMock,
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_include_my_drives_only")
connector = google_drive_service_acct_connector_factory(
include_shared_drives=False,
include_my_drives=True,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# Should only get everyone's My Drives
expected_file_ids = (
DRIVE_ID_MAPPING["ADMIN"]
+ DRIVE_ID_MAPPING["TEST_USER_1"]
+ DRIVE_ID_MAPPING["TEST_USER_2"]
+ DRIVE_ID_MAPPING["TEST_USER_3"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_drive_one_only(
mock_get_api_key: MagicMock,
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_drive_one_only")
urls = [URL_MAPPING["SHARED_DRIVE_1"]]
connector = google_drive_service_acct_connector_factory(
include_shared_drives=True,
include_my_drives=False,
shared_drive_urls=",".join([str(url) for url in urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# shared_drive_urls is only honored when include_shared_drives is True,
# so this should only get drive one and the folders it contains
expected_file_ids = (
DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
+ DRIVE_ID_MAPPING["FOLDER_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_folder_and_shared_drive(
mock_get_api_key: MagicMock,
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_folder_and_shared_drive")
drive_urls = [
URL_MAPPING["SHARED_DRIVE_1"],
]
folder_urls = [URL_MAPPING["FOLDER_2"]]
connector = google_drive_service_acct_connector_factory(
include_shared_drives=True,
include_my_drives=True,
shared_drive_urls=",".join([str(url) for url in drive_urls]),
shared_folder_urls=",".join([str(url) for url in folder_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# Should get drive 1 (including its folders), folder 2 (including its
# subfolders), and everyone's My Drives
expected_file_ids = (
DRIVE_ID_MAPPING["ADMIN"]
+ DRIVE_ID_MAPPING["TEST_USER_1"]
+ DRIVE_ID_MAPPING["TEST_USER_2"]
+ DRIVE_ID_MAPPING["TEST_USER_3"]
+ DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
+ DRIVE_ID_MAPPING["FOLDER_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
+ DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_folders_only(
mock_get_api_key: MagicMock,
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_folders_only")
folder_urls = [
URL_MAPPING["FOLDER_1_1"],
URL_MAPPING["FOLDER_1_2"],
URL_MAPPING["FOLDER_2_1"],
URL_MAPPING["FOLDER_2_2"],
]
connector = google_drive_service_acct_connector_factory(
include_shared_drives=False,
include_my_drives=False,
shared_folder_urls=",".join([str(url) for url in folder_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
expected_file_ids = (
DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_specific_emails(
mock_get_api_key: MagicMock,
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_specific_emails")
my_drive_emails = [
EMAIL_MAPPING["TEST_USER_1"],
EMAIL_MAPPING["TEST_USER_3"],
]
connector = google_drive_service_acct_connector_factory(
include_shared_drives=False,
include_my_drives=True,
my_drive_emails=",".join([str(email) for email in my_drive_emails]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
expected_file_ids = (
DRIVE_ID_MAPPING["TEST_USER_1"] + DRIVE_ID_MAPPING["TEST_USER_3"]
)
assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)

View File

@@ -0,0 +1,174 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from danswer.access.models import ExternalAccess
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from ee.danswer.external_permissions.google_drive.doc_sync import (
_get_permissions_from_slim_doc,
)
from tests.daily.connectors.google_drive.helpers import ACCESS_MAPPING
from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING
from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING
from tests.daily.connectors.google_drive.helpers import file_name_template
from tests.daily.connectors.google_drive.helpers import print_discrepencies
from tests.daily.connectors.google_drive.helpers import PUBLIC_RANGE
def get_keys_available_to_user_from_access_map(
user_email: str,
group_map: dict[str, list[str]],
access_map: dict[str, ExternalAccess],
) -> list[str]:
"""
Extracts the names of the files available to the user from the access map
via their own email, their group memberships, or public access
"""
group_emails_for_user = []
for group_email, user_in_group_email_list in group_map.items():
if user_email in user_in_group_email_list:
group_emails_for_user.append(group_email)
accessible_file_names_for_user = []
for file_name, external_access in access_map.items():
if external_access.is_public:
accessible_file_names_for_user.append(file_name)
elif user_email in external_access.external_user_emails:
accessible_file_names_for_user.append(file_name)
elif any(
group_email in external_access.external_user_group_ids
for group_email in group_emails_for_user
):
accessible_file_names_for_user.append(file_name)
return accessible_file_names_for_user
def assert_correct_access_for_user(
user_email: str,
expected_access_ids: list[int],
group_map: dict[str, list[str]],
retrieved_access_map: dict[str, ExternalAccess],
) -> None:
"""
Compares the expected access range for the user to the keys available to the
user as retrieved from the source
"""
retrieved_keys_available_to_user = get_keys_available_to_user_from_access_map(
user_email, group_map, retrieved_access_map
)
retrieved_file_names = set(retrieved_keys_available_to_user)
# Combine public and user-specific access IDs
all_accessible_ids = expected_access_ids + PUBLIC_RANGE
expected_file_names = {file_name_template.format(i) for i in all_accessible_ids}
print_discrepencies(expected_file_names, retrieved_file_names)
assert expected_file_names == retrieved_file_names
# This function mirrors the logic in group_sync.py for the Google Drive connector
# TODO: call that code directly instead of duplicating it here
def get_group_map(google_drive_connector: GoogleDriveConnector) -> dict[str, list[str]]:
admin_service = google_drive_connector.get_google_resource("admin", "directory_v1")
group_map: dict[str, list[str]] = {}
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_drive_connector.google_domain,
fields="groups(email)",
):
# The id is the group email
group_email = group["email"]
# Gather group member emails
group_member_emails: list[str] = []
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.append(member["email"])
group_map[group_email] = group_member_emails
return group_map
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_all_permissions(
mock_get_api_key: MagicMock,
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
google_drive_connector = google_drive_service_acct_connector_factory(
include_shared_drives=True,
include_my_drives=True,
)
access_map: dict[str, ExternalAccess] = {}
for slim_doc_batch in google_drive_connector.retrieve_all_slim_documents(
0, time.time()
):
for slim_doc in slim_doc_batch:
access_map[
(slim_doc.perm_sync_data or {})["name"]
] = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,
)
for file_name, external_access in access_map.items():
print(file_name, external_access)
expected_file_range = (
DRIVE_ID_MAPPING["ADMIN"]
+ DRIVE_ID_MAPPING["TEST_USER_1"]
+ DRIVE_ID_MAPPING["TEST_USER_2"]
+ DRIVE_ID_MAPPING["TEST_USER_3"]
+ DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
+ DRIVE_ID_MAPPING["FOLDER_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_1"]
+ DRIVE_ID_MAPPING["FOLDER_1_2"]
+ DRIVE_ID_MAPPING["SHARED_DRIVE_2"]
+ DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
)
# Should get everything
assert len(access_map) == len(expected_file_range)
group_map = get_group_map(google_drive_connector)
print("groups:\n", group_map)
assert_correct_access_for_user(
user_email=EMAIL_MAPPING["ADMIN"],
expected_access_ids=ACCESS_MAPPING["ADMIN"],
group_map=group_map,
retrieved_access_map=access_map,
)
assert_correct_access_for_user(
user_email=EMAIL_MAPPING["TEST_USER_1"],
expected_access_ids=ACCESS_MAPPING["TEST_USER_1"],
group_map=group_map,
retrieved_access_map=access_map,
)
assert_correct_access_for_user(
user_email=EMAIL_MAPPING["TEST_USER_2"],
expected_access_ids=ACCESS_MAPPING["TEST_USER_2"],
group_map=group_map,
retrieved_access_map=access_map,
)
assert_correct_access_for_user(
user_email=EMAIL_MAPPING["TEST_USER_3"],
expected_access_ids=ACCESS_MAPPING["TEST_USER_3"],
group_map=group_map,
retrieved_access_map=access_map,
)
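
To make the assertions concrete, this is the kind of entry access_map holds after the loop above (placeholder values; the group email is hypothetical):

access_map["file_25.txt"] = ExternalAccess(
    external_user_emails={"admin@onyx-test.com"},
    external_user_group_ids={"group_1@onyx-test.com"},  # hypothetical group email
    is_public=False,
)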

View File

@@ -431,6 +431,12 @@ export default function AddConnector({
setSelectedFiles={setSelectedFiles}
selectedFiles={selectedFiles}
connector={connector}
currentCredential={
currentCredential ||
liveGDriveCredential ||
liveGmailCredential ||
null
}
/>
</Card>
)}

View File

@@ -4,18 +4,22 @@ import { TextArrayField } from "@/components/admin/connectors/Field";
import { useFormikContext } from "formik";
interface ListInputProps {
field: ListOption;
name: string;
label: string | ((credential: any) => string);
description: string | ((credential: any) => string);
}
const ListInput: React.FC<ListInputProps> = ({ field }) => {
const ListInput: React.FC<ListInputProps> = ({ name, label, description }) => {
const { values } = useFormikContext<any>();
return (
<TextArrayField
name={field.name}
label={field.label}
name={name}
label={typeof label === "function" ? label(null) : label}
values={values}
subtext={field.description}
placeholder={`Enter ${field.label.toLowerCase()}`}
subtext={
typeof description === "function" ? description(null) : description
}
placeholder={`Enter ${(typeof label === "function" ? label(null) : label).toLowerCase()}`}
/>
);
};

View File

@@ -13,6 +13,8 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { AccessTypeForm } from "@/components/admin/connectors/AccessTypeForm";
import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTypeGroupSelector";
import { ConfigurableSources } from "@/lib/types";
import { Credential } from "@/lib/connectors/credentials";
import CollapsibleSection from "@/app/admin/assistants/CollapsibleSection";
export interface DynamicConnectionFormProps {
config: ConnectionConfiguration;
@@ -20,19 +22,44 @@ export interface DynamicConnectionFormProps {
setSelectedFiles: Dispatch<SetStateAction<File[]>>;
values: any;
connector: ConfigurableSources;
currentCredential: Credential<any> | null;
}
const DynamicConnectionForm: FC<DynamicConnectionFormProps> = ({
config,
interface RenderFieldProps {
field: any;
values: any;
selectedFiles: File[];
setSelectedFiles: Dispatch<SetStateAction<File[]>>;
connector: ConfigurableSources;
currentCredential: Credential<any> | null;
}
const RenderField: FC<RenderFieldProps> = ({
field,
values,
selectedFiles,
setSelectedFiles,
values,
connector,
currentCredential,
}) => {
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
if (
field.visibleCondition &&
!field.visibleCondition(values, currentCredential)
) {
return null;
}
const renderField = (field: any) => (
<div key={field.name}>
const label =
typeof field.label === "function"
? field.label(currentCredential)
: field.label;
const description =
typeof field.description === "function"
? field.description(currentCredential)
: field.description;
const fieldContent = (
<>
{field.type === "file" ? (
<FileUpload
name={field.name}
@@ -42,46 +69,71 @@ const DynamicConnectionForm: FC<DynamicConnectionFormProps> = ({
) : field.type === "zip" ? (
<FileInput
name={field.name}
label={field.label}
label={label}
optional={field.optional}
description={field.description}
description={description}
/>
) : field.type === "list" ? (
<ListInput field={field} />
<ListInput name={field.name} label={label} description={description} />
) : field.type === "select" ? (
<SelectInput
name={field.name}
optional={field.optional}
description={field.description}
description={description}
options={field.options || []}
label={field.label}
label={label}
/>
) : field.type === "number" ? (
<NumberInput
label={field.label}
label={label}
optional={field.optional}
description={field.description}
description={description}
name={field.name}
/>
) : field.type === "checkbox" ? (
<AdminBooleanFormField
checked={values[field.name]}
subtext={field.description}
subtext={description}
name={field.name}
label={field.label}
label={label}
/>
) : (
<TextFormField
subtext={field.description}
subtext={description}
optional={field.optional}
type={field.type}
label={field.label}
label={label}
name={field.name}
isTextArea={true}
/>
)}
</div>
</>
);
if (
field.visibleCondition &&
field.visibleCondition(values, currentCredential)
) {
return (
<CollapsibleSection prompt={label} key={field.name}>
{fieldContent}
</CollapsibleSection>
);
} else {
return <div key={field.name}>{fieldContent}</div>;
}
};
const DynamicConnectionForm: FC<DynamicConnectionFormProps> = ({
config,
selectedFiles,
setSelectedFiles,
values,
connector,
currentCredential,
}) => {
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
return (
<>
<h2 className="text-2xl font-bold text-text-800">{config.description}</h2>
@@ -97,7 +149,20 @@ const DynamicConnectionForm: FC<DynamicConnectionFormProps> = ({
name={"name"}
/>
{config.values.map((field) => !field.hidden && renderField(field))}
{config.values.map(
(field) =>
!field.hidden && (
<RenderField
key={field.name}
field={field}
values={values}
selectedFiles={selectedFiles}
setSelectedFiles={setSelectedFiles}
connector={connector}
currentCredential={currentCredential}
/>
)
)}
<AccessTypeForm connector={connector} />
<AccessTypeGroupSelector />
@@ -108,7 +173,18 @@ const DynamicConnectionForm: FC<DynamicConnectionFormProps> = ({
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
{showAdvancedOptions && config.advanced_values.map(renderField)}
{showAdvancedOptions &&
config.advanced_values.map((field) => (
<RenderField
key={field.name}
field={field}
values={values}
selectedFiles={selectedFiles}
setSelectedFiles={setSelectedFiles}
connector={connector}
currentCredential={currentCredential}
/>
))}
</>
)}
</>

View File

@@ -10,6 +10,7 @@ import { GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants";
import Cookies from "js-cookie";
import { TextFormField } from "@/components/admin/connectors/Field";
import { Form, Formik } from "formik";
import { User } from "@/lib/types";
import { Button as TremorButton } from "@tremor/react";
import {
Credential,
@@ -157,6 +158,7 @@ export const DriveJsonUploadSection = ({
isAdmin,
}: DriveJsonUploadSectionProps) => {
const { mutate } = useSWRConfig();
const router = useRouter();
if (serviceAccountCredentialData?.service_account_email) {
return (
@@ -190,6 +192,7 @@ export const DriveJsonUploadSection = ({
message: "Successfully deleted service account key",
type: "success",
});
router.refresh();
} else {
const errorMsg = await response.text();
setPopup({
@@ -307,9 +310,10 @@ interface DriveCredentialSectionProps {
setPopup: (popupSpec: PopupSpec | null) => void;
refreshCredentials: () => void;
connectorExists: boolean;
user: User | null;
}
export const DriveOAuthSection = ({
export const DriveAuthSection = ({
googleDrivePublicCredential,
googleDriveServiceAccountCredential,
serviceAccountKeyData,
@@ -317,6 +321,7 @@ export const DriveOAuthSection = ({
setPopup,
refreshCredentials,
connectorExists,
user,
}: DriveCredentialSectionProps) => {
const router = useRouter();
@@ -356,23 +361,23 @@ export const DriveOAuthSection = ({
return (
<div>
<p className="text-sm mb-6">
When using a Google Drive Service Account, you can either have Danswer
act as the service account itself OR you can specify an account for
the service account to impersonate.
When using a Google Drive Service Account, you must specify the email
of the primary admin that you would like the service account to
impersonate.
<br />
<br />
If you want to use the service account itself, leave the{" "}
<b>&apos;User email to impersonate&apos;</b> field blank when
submitting. If you do choose this option, make sure you have shared
the documents you want to index with the service account.
Ideally, this account should be an owner/admin of the Google
Organization that owns the Google Drive(s) you want to index.
</p>
<Formik
initialValues={{
google_drive_delegated_user: "",
google_drive_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_drive_delegated_user: Yup.string().optional(),
google_drive_primary_admin: Yup.string().required(
"User email is required"
),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
@@ -384,8 +389,7 @@ export const DriveOAuthSection = ({
"Content-Type": "application/json",
},
body: JSON.stringify({
google_drive_delegated_user:
values.google_drive_delegated_user,
google_drive_primary_admin: values.google_drive_primary_admin,
}),
}
);
@@ -408,9 +412,9 @@ export const DriveOAuthSection = ({
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_drive_delegated_user"
label="[Optional] User email to impersonate:"
subtext="If left blank, Danswer will use the service account itself."
name="google_drive_primary_admin"
label="User email to impersonate:"
subtext="Enter the email of the user whose Google Drive access you want to delegate to the service account."
/>
<div className="flex">
<TremorButton type="submit" disabled={isSubmitting}>

View File

@@ -12,7 +12,7 @@ import {
useConnectorCredentialIndexingStatus,
} from "@/lib/hooks";
import { Title } from "@tremor/react";
import { DriveJsonUploadSection, DriveOAuthSection } from "./Credential";
import { DriveJsonUploadSection, DriveAuthSection } from "./Credential";
import {
Credential,
GoogleDriveCredentialJson,
@@ -22,7 +22,7 @@ import { GoogleDriveConfig } from "@/lib/connectors/connectors";
import { useUser } from "@/components/user/UserProvider";
const GDriveMain = ({}: {}) => {
const { isLoadingUser, isAdmin } = useUser();
const { isLoadingUser, isAdmin, user } = useUser();
const {
data: appCredentialData,
@@ -135,7 +135,7 @@ const GDriveMain = ({}: {}) => {
<Title className="mb-2 mt-6 ml-auto mr-auto">
Step 2: Authenticate with Danswer
</Title>
<DriveOAuthSection
<DriveAuthSection
setPopup={setPopup}
refreshCredentials={refreshCredentials}
googleDrivePublicCredential={googleDrivePublicCredential}
@@ -145,6 +145,7 @@ const GDriveMain = ({}: {}) => {
appCredentialData={appCredentialData}
serviceAccountKeyData={serviceAccountKeyData}
connectorExists={googleDriveConnectorIndexingStatuses.length > 0}
user={user}
/>
</>
)}

View File

@@ -9,6 +9,7 @@ import { useUser } from "@/components/user/UserProvider";
import { useField } from "formik";
import { AutoSyncOptions } from "./AutoSyncOptions";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { useEffect } from "react";
function isValidAutoSyncSource(
value: ConfigurableSources
@@ -28,6 +29,21 @@ export function AccessTypeForm({
const isAutoSyncSupported = isValidAutoSyncSource(connector);
const { isLoadingUser, isAdmin } = useUser();
useEffect(() => {
if (!isPaidEnterpriseEnabled) {
access_type_helpers.setValue("public");
} else if (isAutoSyncSupported) {
access_type_helpers.setValue("sync");
} else {
access_type_helpers.setValue("private");
}
}, [
isAutoSyncSupported,
isAdmin,
isPaidEnterpriseEnabled,
access_type_helpers,
]);
const options = [
{
name: "Private",
@@ -46,9 +62,9 @@ export function AccessTypeForm({
});
}
if (isAutoSyncSupported && isAdmin) {
if (isAutoSyncSupported && isAdmin && isPaidEnterpriseEnabled) {
options.push({
name: "Auto Sync",
name: "Auto Sync Permissions",
value: "sync",
description:
"We will automatically sync permissions from the source. A document will be searchable in Danswer if and only if the user performing the search has permission to access the document in the source.",
@@ -59,12 +75,13 @@ export function AccessTypeForm({
<>
{isPaidEnterpriseEnabled && isAdmin && (
<>
<div className="flex gap-x-2 items-center">
<div>
<label className="text-text-950 font-medium">Document Access</label>
</div>
<p className="text-sm text-text-500 mb-2">
<p className="text-sm text-text-500">
Control who has access to the documents indexed by this connector.
</p>
</div>
<DefaultDropdown
options={options}
selected={access_type.value}
@@ -75,11 +92,9 @@ export function AccessTypeForm({
/>
{access_type.value === "sync" && isAutoSyncSupported && (
<div>
<AutoSyncOptions
connectorType={connector as ValidAutoSyncSources}
/>
</div>
)}
</>
)}

View File

@@ -64,21 +64,6 @@ export const ConnectorTitle = ({
"Jira Project URL",
typedConnector.connector_specific_config.jira_project_url
);
} else if (connector.source === "google_drive") {
const typedConnector = connector as Connector<GoogleDriveConfig>;
if (
typedConnector.connector_specific_config?.folder_paths &&
typedConnector.connector_specific_config?.folder_paths.length > 0
) {
additionalMetadata.set(
"Folders",
typedConnector.connector_specific_config.folder_paths.join(", ")
);
}
if (!isPublic && owner) {
additionalMetadata.set("Owner", owner);
}
} else if (connector.source === "slack") {
const typedConnector = connector as Connector<SlackConfig>;
if (

View File

@@ -12,37 +12,6 @@ export const autoSyncConfigBySource: Record<
>
> = {
confluence: {},
google_drive: {
customer_id: {
label: "Google Workspace Customer ID",
subtext: (
<>
The unique identifier for your Google Workspace account. To find this,
check out the{" "}
<a
href="https://support.google.com/cloudidentity/answer/10070793"
target="_blank"
className="text-link"
>
guide from Google
</a>
.
</>
),
},
company_domain: {
label: "Google Workspace Company Domain",
subtext: (
<>
The email domain for your Google Workspace account.
<br />
<br />
For example, if your email provided through Google Workspace looks
something like chris@danswer.ai, then your company domain is{" "}
<b>danswer.ai</b>
</>
),
},
},
google_drive: {},
slack: {},
};

View File

@@ -2,6 +2,7 @@ import * as Yup from "yup";
import { IsPublicGroupSelectorFormType } from "@/components/IsPublicGroupSelector";
import { ConfigurableSources, ValidInputTypes, ValidSources } from "../types";
import { AccessTypeGroupSelectorFormType } from "@/components/admin/connectors/AccessTypeGroupSelector";
import { Credential } from "@/lib/connectors/credentials"; // Import Credential type
export function isLoadState(connector_name: string): boolean {
// TODO: centralize connector metadata like this somewhere instead of hardcoding it here
@@ -29,12 +30,18 @@ export type StringWithDescription = {
};
export interface Option {
label: string;
label: string | ((currentCredential: Credential<any> | null) => string);
name: string;
description?: string;
description?:
| string
| ((currentCredential: Credential<any> | null) => string);
query?: string;
optional?: boolean;
hidden?: boolean;
visibleCondition?: (
values: any,
currentCredential: Credential<any> | null
) => boolean;
}
export interface SelectOption extends Option {
@@ -204,38 +211,59 @@ export const connectorConfigs: Record<
description: "Configure Google Drive connector",
values: [
{
type: "list",
query: "Enter folder paths:",
label: "Folder Paths",
name: "folder_paths",
type: "checkbox",
label: "Include shared drives?",
description:
"This will allow Danswer to index everything in your shared drives.",
name: "include_shared_drives",
optional: true,
default: true,
},
{
type: "text",
description:
"Enter a comma separated list of the URLs of the shared drives to index. Leave blank to index all shared drives.",
label: "Shared Drive URLs",
name: "shared_drive_urls",
visibleCondition: (values) => values.include_shared_drives,
optional: true,
},
{
type: "checkbox",
query: "Include shared files?",
label: "Include Shared",
name: "include_shared",
optional: false,
default: false,
label: (currentCredential) =>
currentCredential?.credential_json?.google_drive_tokens
? "Include My Drive?"
: "Include Everyone's My Drive?",
description: (currentCredential) =>
currentCredential?.credential_json?.google_drive_tokens
? "This will allow Danswer to index everything in your My Drive."
: "This will allow Danswer to index everything in everyone's My Drives.",
name: "include_my_drives",
optional: true,
default: true,
},
{
type: "checkbox",
query: "Follow shortcuts?",
label: "Follow Shortcuts",
name: "follow_shortcuts",
optional: false,
default: false,
},
{
type: "checkbox",
query: "Only include organization public files?",
label: "Only Org Public",
name: "only_org_public",
optional: false,
default: false,
type: "text",
description:
"Enter a comma separated list of the emails of the users whose MyDrive you want to index. Leave blank to index all MyDrives.",
label: "My Drive Emails",
name: "my_drive_emails",
visibleCondition: (values, currentCredential) =>
values.include_my_drives &&
!currentCredential?.credential_json?.google_drive_tokens,
optional: true,
},
],
advanced_values: [
{
type: "text",
description:
"Enter a comma separated list of the URLs of the folders located in Shared Drives to index. The files located in these folders (and all subfolders) will be indexed. Note: This will be in addition to the 'Include Shared Drives' and 'Shared Drive URLs' settings, so leave those blank if you only want to index the folders specified here.",
label: "Folder URLs",
name: "shared_folder_urls",
optional: true,
},
],
advanced_values: [],
},
gmail: {
description: "Configure Gmail connector",
@@ -1025,7 +1053,7 @@ export interface GitlabConfig {
}
export interface GoogleDriveConfig {
folder_paths?: string[];
parent_urls?: string[];
include_shared?: boolean;
follow_shortcuts?: boolean;
only_org_public?: boolean;

View File

@@ -58,6 +58,7 @@ export interface GmailCredentialJson {
export interface GoogleDriveCredentialJson {
google_drive_tokens: string;
google_drive_primary_admin: string;
}
export interface GmailServiceAccountCredentialJson {
@@ -67,7 +68,7 @@ export interface GmailServiceAccountCredentialJson {
export interface GoogleDriveServiceAccountCredentialJson {
google_drive_service_account_key: string;
google_drive_delegated_user: string;
google_drive_primary_admin: string;
}
export interface SlabCredentialJson {
@@ -331,7 +332,7 @@ export const credentialDisplayNames: Record<string, string> = {
// Google Drive Service Account
google_drive_service_account_key: "Google Drive Service Account Key",
google_drive_delegated_user: "Google Drive Delegated User",
google_drive_primary_admin: "Google Drive Delegated User",
// Slab
slab_bot_token: "Slab Bot Token",