Add drive sections (#3040)

* ADd header support for drive

* Fix mypy

* Comment change

* Improve

* Cleanup

* Add comment
This commit is contained in:
Chris Weaver
2024-11-03 14:10:45 -08:00
committed by GitHub
parent 56c3a5ff5b
commit c2d04f591d
14 changed files with 460 additions and 94 deletions

View File

@ -3,8 +3,6 @@ from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.connectors.google_drive.connector_auth import ( from danswer.connectors.google_drive.connector_auth import (
@ -24,6 +22,9 @@ from danswer.connectors.google_drive.file_retrieval import get_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.connectors.google_drive.models import GoogleDriveFileType from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_drive.resources import get_admin_service
from danswer.connectors.google_drive.resources import get_drive_service
from danswer.connectors.google_drive.resources import get_google_docs_service
from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import LoadConnector
@ -103,42 +104,49 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls) shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list) self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list)
self.primary_admin_email: str | None = None self._primary_admin_email: str | None = None
self.google_domain: str | None = None self.google_domain: str | None = None
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._TRAVERSED_PARENT_IDS: set[str] = set() self._TRAVERSED_PARENT_IDS: set[str] = set()
@property
def primary_admin_email(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email
@property
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
if self._creds is None:
raise RuntimeError(
"Creds missing, "
"should not call this property "
"before calling load_credentials"
)
return self._creds
def _update_traversed_parent_ids(self, folder_id: str) -> None: def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._TRAVERSED_PARENT_IDS.add(folder_id) self._TRAVERSED_PARENT_IDS.add(folder_id)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self.google_domain = primary_admin_email.split("@")[1] self.google_domain = primary_admin_email.split("@")[1]
self.primary_admin_email = primary_admin_email self._primary_admin_email = primary_admin_email
self.creds, new_creds_dict = get_google_drive_creds(credentials) self._creds, new_creds_dict = get_google_drive_creds(credentials)
return new_creds_dict return new_creds_dict
def get_google_resource(
self,
service_name: str = "drive",
service_version: str = "v3",
user_email: str | None = None,
) -> Resource:
if isinstance(self.creds, ServiceAccountCredentials):
creds = self.creds.with_subject(user_email or self.primary_admin_email)
service = build(service_name, service_version, credentials=creds)
elif isinstance(self.creds, OAuthCredentials):
service = build(service_name, service_version, credentials=self.creds)
else:
raise PermissionError("No credentials found")
return service
def _get_all_user_emails(self) -> list[str]: def _get_all_user_emails(self) -> list[str]:
admin_service = self.get_google_resource("admin", "directory_v1") admin_service = get_admin_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
emails = [] emails = []
for user in execute_paginated_retrieval( for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list, retrieval_function=admin_service.users().list,
@ -156,7 +164,10 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]: ) -> Iterator[GoogleDriveFileType]:
primary_drive_service = self.get_google_resource() primary_drive_service = get_drive_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
if self.include_shared_drives: if self.include_shared_drives:
shared_drive_urls = self.shared_drive_ids shared_drive_urls = self.shared_drive_ids
@ -212,7 +223,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
for email in all_user_emails: for email in all_user_emails:
logger.info(f"Fetching personal files for user: {email}") logger.info(f"Fetching personal files for user: {email}")
user_drive_service = self.get_google_resource(user_email=email) user_drive_service = get_drive_service(self.creds, user_email=email)
yield from get_files_in_my_drive( yield from get_files_in_my_drive(
service=user_drive_service, service=user_drive_service,
@ -233,11 +244,16 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
start=start, start=start,
end=end, end=end,
): ):
user_email = file.get("owners", [{}])[0].get("emailAddress") user_email = (
service = self.get_google_resource(user_email=user_email) file.get("owners", [{}])[0].get("emailAddress")
or self.primary_admin_email
)
user_drive_service = get_drive_service(self.creds, user_email=user_email)
docs_service = get_google_docs_service(self.creds, user_email=user_email)
if doc := convert_drive_item_to_document( if doc := convert_drive_item_to_document(
file=file, file=file,
service=service, drive_service=user_drive_service,
docs_service=docs_service,
): ):
doc_batch.append(doc) doc_batch.append(doc)
if len(doc_batch) >= self.batch_size: if len(doc_batch) >= self.batch_size:

View File

@ -28,6 +28,8 @@ from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
# this is counted under `/auth/drive.readonly`
GOOGLE_DRIVE_SCOPES = [ GOOGLE_DRIVE_SCOPES = [
"https://www.googleapis.com/auth/drive.readonly", "https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly", "https://www.googleapis.com/auth/drive.metadata.readonly",

View File

@ -2,7 +2,6 @@ import io
from datetime import datetime from datetime import datetime
from datetime import timezone from datetime import timezone
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@ -13,6 +12,9 @@ from danswer.connectors.google_drive.constants import ERRORS_TO_CONTINUE_ON
from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from danswer.connectors.google_drive.models import GDriveMimeType from danswer.connectors.google_drive.models import GDriveMimeType
from danswer.connectors.google_drive.models import GoogleDriveFileType from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_drive.resources import GoogleDocsService
from danswer.connectors.google_drive.resources import GoogleDriveService
from danswer.connectors.google_drive.section_extraction import get_document_sections
from danswer.connectors.models import Document from danswer.connectors.models import Document
from danswer.connectors.models import Section from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text from danswer.file_processing.extract_file_text import docx_to_text
@ -25,86 +27,137 @@ from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
def _extract_text(file: dict[str, str], service: Resource) -> str: def _extract_sections_basic(
file: dict[str, str], service: GoogleDriveService
) -> list[Section]:
mime_type = file["mimeType"] mime_type = file["mimeType"]
link = file["webViewLink"]
if mime_type not in set(item.value for item in GDriveMimeType): if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful # Unsupported file types can still have a title, finding this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
if mime_type in [ try:
GDriveMimeType.DOC.value, if mime_type in [
GDriveMimeType.PPT.value, GDriveMimeType.DOC.value,
GDriveMimeType.SPREADSHEET.value, GDriveMimeType.PPT.value,
]: GDriveMimeType.SPREADSHEET.value,
export_mime_type = ( ]:
"text/plain" export_mime_type = (
if mime_type != GDriveMimeType.SPREADSHEET.value "text/plain"
else "text/csv" if mime_type != GDriveMimeType.SPREADSHEET.value
) else "text/csv"
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
) )
text = (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
return [Section(link=link, text=text)]
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return [
Section(
link=link,
text=service.files()
.get_media(fileId=file["id"])
.execute()
.decode("utf-8"),
)
]
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return [
Section(
link=link,
text=unstructured_to_text(
file=io.BytesIO(response),
file_name=file.get("name", file["id"]),
),
)
]
if mime_type == GDriveMimeType.WORD_DOC.value: if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response)) return [
elif mime_type == GDriveMimeType.PDF.value: Section(link=link, text=docx_to_text(file=io.BytesIO(response)))
text, _ = read_pdf_file(file=io.BytesIO(response)) ]
return text elif mime_type == GDriveMimeType.PDF.value:
elif mime_type == GDriveMimeType.POWERPOINT.value: text, _ = read_pdf_file(file=io.BytesIO(response))
return pptx_to_text(file=io.BytesIO(response)) return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
return [
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
]
return UNSUPPORTED_FILE_TYPE_CONTENT return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
except Exception:
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
def convert_drive_item_to_document( def convert_drive_item_to_document(
file: GoogleDriveFileType, service: Resource file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
) -> Document | None: ) -> Document | None:
try: try:
# Skip files that are shortcuts # Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE: if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype") logger.info("Ignoring Drive Shortcut Filetype")
return None return None
try:
text_contents = _extract_text(file, service) or ""
except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
return None
raise sections: list[Section] = []
# Special handling for Google Docs to preserve structure, link
# to headers
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
sections = get_document_sections(docs_service, file["id"])
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
" Falling back to basic extraction."
)
# NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
if not sections:
try:
# For all other file types just extract the text
sections = _extract_sections_basic(file, drive_service)
except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
return None
raise
if not sections:
return None
return Document( return Document(
id=file["webViewLink"], id=file["webViewLink"],
sections=[Section(link=file["webViewLink"], text=text_contents)], sections=sections,
source=DocumentSource.GOOGLE_DRIVE, source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"], semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone( doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc timezone.utc
), ),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"}, metadata={}
if any(section.text for section in sections)
else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"), additional_info=file.get("id"),
) )
except Exception as e: except Exception as e:

View File

@ -28,7 +28,7 @@ def execute_paginated_retrieval(
if next_page_token: if next_page_token:
request_kwargs["pageToken"] = next_page_token request_kwargs["pageToken"] = next_page_token
results = add_retries(lambda: retrieval_function(**request_kwargs).execute())() results = (lambda: retrieval_function(**request_kwargs).execute())()
next_page_token = results.get("nextPageToken") next_page_token = results.get("nextPageToken")
for item in results.get(list_key, []): for item in results.get(list_key, []):

View File

@ -0,0 +1,52 @@
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
class GoogleDriveService(Resource):
pass
class GoogleDocsService(Resource):
pass
class AdminService(Resource):
pass
def _get_google_service(
service_name: str,
service_version: str,
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService:
if isinstance(creds, ServiceAccountCredentials):
creds = creds.with_subject(user_email)
service = build(service_name, service_version, credentials=creds)
elif isinstance(creds, OAuthCredentials):
service = build(service_name, service_version, credentials=creds)
return service
def get_google_docs_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDocsService:
return _get_google_service("docs", "v1", creds, user_email)
def get_drive_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService:
return _get_google_service("drive", "v3", creds, user_email)
def get_admin_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str,
) -> AdminService:
return _get_google_service("admin", "directory_v1", creds, user_email)

View File

@ -0,0 +1,105 @@
from typing import Any
from pydantic import BaseModel
from danswer.connectors.google_drive.resources import GoogleDocsService
from danswer.connectors.models import Section
class CurrentHeading(BaseModel):
id: str
text: str
def _build_gdoc_section_link(doc_id: str, heading_id: str) -> str:
"""Builds a Google Doc link that jumps to a specific heading"""
# NOTE: doesn't support docs with multiple tabs atm, if we need that ask
# @Chris
return (
f"https://docs.google.com/document/d/{doc_id}/edit?tab=t.0#heading={heading_id}"
)
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
"""Extracts the id from a heading paragraph element"""
return paragraph["paragraphStyle"]["headingId"]
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
"""Extracts the text content from a paragraph element"""
text_elements = []
for element in paragraph.get("elements", []):
if "textRun" in element:
text_elements.append(element["textRun"].get("content", ""))
return "".join(text_elements)
def get_document_sections(
docs_service: GoogleDocsService,
doc_id: str,
) -> list[Section]:
"""Extracts sections from a Google Doc, including their headings and content"""
# Fetch the document structure
doc = docs_service.documents().get(documentId=doc_id).execute()
# Get the content
content = doc.get("body", {}).get("content", [])
sections: list[Section] = []
current_section: list[str] = []
current_heading: CurrentHeading | None = None
for element in content:
if "paragraph" not in element:
continue
paragraph = element["paragraph"]
# Check if this is a heading
if (
"paragraphStyle" in paragraph
and "namedStyleType" in paragraph["paragraphStyle"]
):
style = paragraph["paragraphStyle"]["namedStyleType"]
is_heading = style.startswith("HEADING_")
is_title = style.startswith("TITLE")
if is_heading or is_title:
# If we were building a previous section, add it to sections list
if current_heading is not None and current_section:
heading_text = current_heading.text
section_text = f"{heading_text}\n" + "\n".join(current_section)
sections.append(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
)
current_section = []
# Start new heading
heading_id = _extract_id_from_heading(paragraph)
heading_text = _extract_text_from_paragraph(paragraph)
current_heading = CurrentHeading(
id=heading_id,
text=heading_text,
)
continue
# Add content to current section
if current_heading is not None:
text = _extract_text_from_paragraph(paragraph)
if text.strip():
current_section.append(text)
# Don't forget to add the last section
if current_heading is not None and current_section:
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
sections.append(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
)
return sections

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess from danswer.access.models import ExternalAccess
from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.connectors.google_drive.resources import get_drive_service
from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.models import SlimDocument from danswer.connectors.models import SlimDocument
from danswer.db.models import ConnectorCredentialPair from danswer.db.models import ConnectorCredentialPair
@ -56,7 +57,10 @@ def _fetch_permissions_for_permission_ids(
return permissions return permissions
owner_email = permission_info.get("owner_email") owner_email = permission_info.get("owner_email")
drive_service = google_drive_connector.get_google_resource(user_email=owner_email) drive_service = get_drive_service(
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# Otherwise, fetch all permissions and update cache # Otherwise, fetch all permissions and update cache
fetched_permissions = execute_paginated_retrieval( fetched_permissions = execute_paginated_retrieval(

View File

@ -2,6 +2,7 @@ from sqlalchemy.orm import Session
from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.connectors.google_drive.resources import get_admin_service
from danswer.db.models import ConnectorCredentialPair from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
@ -19,8 +20,9 @@ def gdrive_group_sync(
**cc_pair.connector.connector_specific_config **cc_pair.connector.connector_specific_config
) )
google_drive_connector.load_credentials(cc_pair.credential.credential_json) google_drive_connector.load_credentials(cc_pair.credential.credential_json)
admin_service = get_admin_service(
admin_service = google_drive_connector.get_google_resource("admin", "directory_v1") google_drive_connector.creds, google_drive_connector.primary_admin_email
)
danswer_groups: list[ExternalUserGroup] = [] danswer_groups: list[ExternalUserGroup] = []
for group in execute_paginated_retrieval( for group in execute_paginated_retrieval(

View File

@ -31,6 +31,30 @@ def load_env_vars(env_file: str = ".env") -> None:
load_env_vars() load_env_vars()
def parse_credentials(env_str: str) -> dict:
"""
Parse a double-escaped JSON string from environment variables into a Python dictionary.
Args:
env_str (str): The double-escaped JSON string from environment variables
Returns:
dict: Parsed OAuth credentials
"""
# first try normally
try:
return json.loads(env_str)
except Exception:
# First, try remove extra escaping backslashes
unescaped = env_str.replace('\\"', '"')
# remove leading / trailing quotes
unescaped = unescaped.strip('"')
# Now parse the JSON
return json.loads(unescaped)
@pytest.fixture @pytest.fixture
def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector]: def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector]:
def _connector_factory( def _connector_factory(
@ -50,7 +74,7 @@ def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector
) )
json_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"] json_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
refried_json_string = json.loads(json_string) refried_json_string = json.dumps(parse_credentials(json_string))
credentials_json = { credentials_json = {
DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string, DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string,
@ -84,7 +108,7 @@ def google_drive_service_acct_connector_factory() -> (
) )
json_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"] json_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
refried_json_string = json.loads(json_string) refried_json_string = json.dumps(parse_credentials(json_string))
# Load Service Account Credentials # Load Service Account Credentials
connector.load_credentials( connector.load_credentials(

View File

@ -18,6 +18,7 @@ _SHARED_DRIVE_2_FILE_IDS = list(range(40, 45))
_FOLDER_2_FILE_IDS = list(range(45, 50)) _FOLDER_2_FILE_IDS = list(range(45, 50))
_FOLDER_2_1_FILE_IDS = list(range(50, 55)) _FOLDER_2_1_FILE_IDS = list(range(50, 55))
_FOLDER_2_2_FILE_IDS = list(range(55, 60)) _FOLDER_2_2_FILE_IDS = list(range(55, 60))
_SECTIONS_FILE_IDS = [61]
_PUBLIC_FOLDER_RANGE = _FOLDER_1_2_FILE_IDS _PUBLIC_FOLDER_RANGE = _FOLDER_1_2_FILE_IDS
_PUBLIC_FILE_IDS = list(range(55, 57)) _PUBLIC_FILE_IDS = list(range(55, 57))
@ -64,6 +65,7 @@ DRIVE_ID_MAPPING: dict[str, list[int]] = {
"FOLDER_2": _FOLDER_2_FILE_IDS, "FOLDER_2": _FOLDER_2_FILE_IDS,
"FOLDER_2_1": _FOLDER_2_1_FILE_IDS, "FOLDER_2_1": _FOLDER_2_1_FILE_IDS,
"FOLDER_2_2": _FOLDER_2_2_FILE_IDS, "FOLDER_2_2": _FOLDER_2_2_FILE_IDS,
"SECTIONS": _SECTIONS_FILE_IDS,
} }
# Dictionary for emails # Dictionary for emails
@ -100,6 +102,7 @@ ACCESS_MAPPING: dict[str, list[int]] = {
+ _FOLDER_2_FILE_IDS + _FOLDER_2_FILE_IDS
+ _FOLDER_2_1_FILE_IDS + _FOLDER_2_1_FILE_IDS
+ _FOLDER_2_2_FILE_IDS + _FOLDER_2_2_FILE_IDS
+ _SECTIONS_FILE_IDS
), ),
# This user has access to drive 1 # This user has access to drive 1
# This user has redundant access to folder 1 because of group access # This user has redundant access to folder 1 because of group access
@ -127,6 +130,21 @@ ACCESS_MAPPING: dict[str, list[int]] = {
"TEST_USER_3": _TEST_USER_3_FILE_IDS, "TEST_USER_3": _TEST_USER_3_FILE_IDS,
} }
SPECIAL_FILE_ID_TO_CONTENT_MAP: dict[int, str] = {
61: (
"Title\n\n"
"This is a Google Doc with sections - "
"Section 1\n\n"
"Section 1 content - "
"Sub-Section 1-1\n\n"
"Sub-Section 1-1 content - "
"Sub-Section 1-2\n\n"
"Sub-Section 1-2 content - "
"Section 2\n\n"
"Section 2 content"
),
}
file_name_template = "file_{}.txt" file_name_template = "file_{}.txt"
file_text_template = "This is file {}" file_text_template = "This is file {}"
@ -142,18 +160,28 @@ def print_discrepencies(expected: set[str], retrieved: set[str]) -> None:
print(expected - retrieved) print(expected - retrieved)
def get_file_content(file_id: int) -> str:
if file_id in SPECIAL_FILE_ID_TO_CONTENT_MAP:
return SPECIAL_FILE_ID_TO_CONTENT_MAP[file_id]
return file_text_template.format(file_id)
def assert_retrieved_docs_match_expected( def assert_retrieved_docs_match_expected(
retrieved_docs: list[Document], expected_file_ids: Sequence[int] retrieved_docs: list[Document], expected_file_ids: Sequence[int]
) -> None: ) -> None:
expected_file_names = { expected_file_names = {
file_name_template.format(file_id) for file_id in expected_file_ids file_name_template.format(file_id) for file_id in expected_file_ids
} }
expected_file_texts = { expected_file_texts = {get_file_content(file_id) for file_id in expected_file_ids}
file_text_template.format(file_id) for file_id in expected_file_ids
}
retrieved_file_names = set([doc.semantic_identifier for doc in retrieved_docs]) retrieved_file_names = set([doc.semantic_identifier for doc in retrieved_docs])
retrieved_texts = set([doc.sections[0].text for doc in retrieved_docs]) retrieved_texts = set(
[
" - ".join([section.text for section in doc.sections])
for doc in retrieved_docs
]
)
# Check file names # Check file names
print_discrepencies(expected_file_names, retrieved_file_names) print_discrepencies(expected_file_names, retrieved_file_names)

View File

@ -41,6 +41,7 @@ def test_include_all(
+ DRIVE_ID_MAPPING["FOLDER_2"] + DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"] + DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"] + DRIVE_ID_MAPPING["FOLDER_2_2"]
+ DRIVE_ID_MAPPING["SECTIONS"]
) )
assert_retrieved_docs_match_expected( assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs, retrieved_docs=retrieved_docs,
@ -75,6 +76,7 @@ def test_include_shared_drives_only(
+ DRIVE_ID_MAPPING["FOLDER_2"] + DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"] + DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"] + DRIVE_ID_MAPPING["FOLDER_2_2"]
+ DRIVE_ID_MAPPING["SECTIONS"]
) )
assert_retrieved_docs_match_expected( assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs, retrieved_docs=retrieved_docs,

View File

@ -0,0 +1,71 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.models import Document
SECTIONS_FOLDER_URL = (
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
)
@patch(
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_google_drive_sections(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
oauth_connector = google_drive_oauth_connector_factory(
include_shared_drives=False,
include_my_drives=False,
shared_folder_urls=SECTIONS_FOLDER_URL,
)
service_acct_connector = google_drive_service_acct_connector_factory(
include_shared_drives=False,
include_my_drives=False,
shared_folder_urls=SECTIONS_FOLDER_URL,
)
for connector in [oauth_connector, service_acct_connector]:
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
# Verify we got the 1 doc with sections
assert len(retrieved_docs) == 1
# Verify each section has the expected structure
doc = retrieved_docs[0]
assert len(doc.sections) == 5
header_section = doc.sections[0]
assert header_section.text == "Title\n\nThis is a Google Doc with sections"
assert header_section.link is not None
assert header_section.link.endswith(
"?tab=t.0#heading=h.hfjc17k6qwzt"
) or header_section.link.endswith("?tab=t.0#heading=h.hfjc17k6qwzt")
section_1 = doc.sections[1]
assert section_1.text == "Section 1\n\nSection 1 content"
assert section_1.link is not None
assert section_1.link.endswith("?tab=t.0#heading=h.8slfx752a3g5")
section_2 = doc.sections[2]
assert section_2.text == "Sub-Section 1-1\n\nSub-Section 1-1 content"
assert section_2.link is not None
assert section_2.link.endswith("?tab=t.0#heading=h.4kj3ayade1bp")
section_3 = doc.sections[3]
assert section_3.text == "Sub-Section 1-2\n\nSub-Section 1-2 content"
assert section_3.link is not None
assert section_3.link.endswith("?tab=t.0#heading=h.pm6wrpzgk69l")
section_4 = doc.sections[4]
assert section_4.text == "Section 2\n\nSection 2 content"
assert section_4.link is not None
assert section_4.link.endswith("?tab=t.0#heading=h.2m0s9youe2k9")

View File

@ -44,6 +44,7 @@ def test_include_all(
+ DRIVE_ID_MAPPING["FOLDER_2"] + DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"] + DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"] + DRIVE_ID_MAPPING["FOLDER_2_2"]
+ DRIVE_ID_MAPPING["SECTIONS"]
) )
assert_retrieved_docs_match_expected( assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs, retrieved_docs=retrieved_docs,
@ -78,6 +79,7 @@ def test_include_shared_drives_only(
+ DRIVE_ID_MAPPING["FOLDER_2"] + DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"] + DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"] + DRIVE_ID_MAPPING["FOLDER_2_2"]
+ DRIVE_ID_MAPPING["SECTIONS"]
) )
assert_retrieved_docs_match_expected( assert_retrieved_docs_match_expected(
retrieved_docs=retrieved_docs, retrieved_docs=retrieved_docs,

View File

@ -6,6 +6,7 @@ from unittest.mock import patch
from danswer.access.models import ExternalAccess from danswer.access.models import ExternalAccess
from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.connectors.google_drive.resources import get_admin_service
from ee.danswer.external_permissions.google_drive.doc_sync import ( from ee.danswer.external_permissions.google_drive.doc_sync import (
_get_permissions_from_slim_doc, _get_permissions_from_slim_doc,
) )
@ -72,7 +73,10 @@ def assert_correct_access_for_user(
# This function is supposed to map to the group_sync.py file for the google drive connector # This function is supposed to map to the group_sync.py file for the google drive connector
# TODO: Call it directly # TODO: Call it directly
def get_group_map(google_drive_connector: GoogleDriveConnector) -> dict[str, list[str]]: def get_group_map(google_drive_connector: GoogleDriveConnector) -> dict[str, list[str]]:
admin_service = google_drive_connector.get_google_resource("admin", "directory_v1") admin_service = get_admin_service(
creds=google_drive_connector.creds,
user_email=google_drive_connector.primary_admin_email,
)
group_map: dict[str, list[str]] = {} group_map: dict[str, list[str]] = {}
for group in execute_paginated_retrieval( for group in execute_paginated_retrieval(
@ -138,6 +142,7 @@ def test_all_permissions(
+ DRIVE_ID_MAPPING["FOLDER_2"] + DRIVE_ID_MAPPING["FOLDER_2"]
+ DRIVE_ID_MAPPING["FOLDER_2_1"] + DRIVE_ID_MAPPING["FOLDER_2_1"]
+ DRIVE_ID_MAPPING["FOLDER_2_2"] + DRIVE_ID_MAPPING["FOLDER_2_2"]
+ DRIVE_ID_MAPPING["SECTIONS"]
) )
# Should get everything # Should get everything