mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-05 04:01:31 +02:00
Add drive sections (#3040)
* ADd header support for drive * Fix mypy * Comment change * Improve * Cleanup * Add comment
This commit is contained in:
@ -3,8 +3,6 @@ from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.connectors.google_drive.connector_auth import (
|
||||
@ -24,6 +22,9 @@ from danswer.connectors.google_drive.file_retrieval import get_files_in_my_drive
|
||||
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.google_drive.models import GoogleDriveFileType
|
||||
from danswer.connectors.google_drive.resources import get_admin_service
|
||||
from danswer.connectors.google_drive.resources import get_drive_service
|
||||
from danswer.connectors.google_drive.resources import get_google_docs_service
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
@ -103,42 +104,49 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
|
||||
self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list)
|
||||
|
||||
self.primary_admin_email: str | None = None
|
||||
self._primary_admin_email: str | None = None
|
||||
self.google_domain: str | None = None
|
||||
|
||||
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
|
||||
self._TRAVERSED_PARENT_IDS: set[str] = set()
|
||||
|
||||
@property
|
||||
def primary_admin_email(self) -> str:
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._primary_admin_email
|
||||
|
||||
@property
|
||||
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
|
||||
if self._creds is None:
|
||||
raise RuntimeError(
|
||||
"Creds missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._creds
|
||||
|
||||
def _update_traversed_parent_ids(self, folder_id: str) -> None:
|
||||
self._TRAVERSED_PARENT_IDS.add(folder_id)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
|
||||
self.google_domain = primary_admin_email.split("@")[1]
|
||||
self.primary_admin_email = primary_admin_email
|
||||
self._primary_admin_email = primary_admin_email
|
||||
|
||||
self.creds, new_creds_dict = get_google_drive_creds(credentials)
|
||||
self._creds, new_creds_dict = get_google_drive_creds(credentials)
|
||||
return new_creds_dict
|
||||
|
||||
def get_google_resource(
|
||||
self,
|
||||
service_name: str = "drive",
|
||||
service_version: str = "v3",
|
||||
user_email: str | None = None,
|
||||
) -> Resource:
|
||||
if isinstance(self.creds, ServiceAccountCredentials):
|
||||
creds = self.creds.with_subject(user_email or self.primary_admin_email)
|
||||
service = build(service_name, service_version, credentials=creds)
|
||||
elif isinstance(self.creds, OAuthCredentials):
|
||||
service = build(service_name, service_version, credentials=self.creds)
|
||||
else:
|
||||
raise PermissionError("No credentials found")
|
||||
|
||||
return service
|
||||
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
admin_service = self.get_google_resource("admin", "directory_v1")
|
||||
admin_service = get_admin_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
@ -156,7 +164,10 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
primary_drive_service = self.get_google_resource()
|
||||
primary_drive_service = get_drive_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
|
||||
if self.include_shared_drives:
|
||||
shared_drive_urls = self.shared_drive_ids
|
||||
@ -212,7 +223,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
for email in all_user_emails:
|
||||
logger.info(f"Fetching personal files for user: {email}")
|
||||
user_drive_service = self.get_google_resource(user_email=email)
|
||||
user_drive_service = get_drive_service(self.creds, user_email=email)
|
||||
|
||||
yield from get_files_in_my_drive(
|
||||
service=user_drive_service,
|
||||
@ -233,11 +244,16 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
user_email = file.get("owners", [{}])[0].get("emailAddress")
|
||||
service = self.get_google_resource(user_email=user_email)
|
||||
user_email = (
|
||||
file.get("owners", [{}])[0].get("emailAddress")
|
||||
or self.primary_admin_email
|
||||
)
|
||||
user_drive_service = get_drive_service(self.creds, user_email=user_email)
|
||||
docs_service = get_google_docs_service(self.creds, user_email=user_email)
|
||||
if doc := convert_drive_item_to_document(
|
||||
file=file,
|
||||
service=service,
|
||||
drive_service=user_drive_service,
|
||||
docs_service=docs_service,
|
||||
):
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
|
@ -28,6 +28,8 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
|
||||
# this is counted under `/auth/drive.readonly`
|
||||
GOOGLE_DRIVE_SCOPES = [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||
|
@ -2,7 +2,6 @@ import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
@ -13,6 +12,9 @@ from danswer.connectors.google_drive.constants import ERRORS_TO_CONTINUE_ON
|
||||
from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
from danswer.connectors.google_drive.models import GDriveMimeType
|
||||
from danswer.connectors.google_drive.models import GoogleDriveFileType
|
||||
from danswer.connectors.google_drive.resources import GoogleDocsService
|
||||
from danswer.connectors.google_drive.resources import GoogleDriveService
|
||||
from danswer.connectors.google_drive.section_extraction import get_document_sections
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
@ -25,13 +27,17 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_text(file: dict[str, str], service: Resource) -> str:
|
||||
def _extract_sections_basic(
|
||||
file: dict[str, str], service: GoogleDriveService
|
||||
) -> list[Section]:
|
||||
mime_type = file["mimeType"]
|
||||
link = file["webViewLink"]
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
try:
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
@ -42,17 +48,26 @@ def _extract_text(file: dict[str, str], service: Resource) -> str:
|
||||
if mime_type != GDriveMimeType.SPREADSHEET.value
|
||||
else "text/csv"
|
||||
)
|
||||
return (
|
||||
text = (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
return [Section(link=link, text=text)]
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
return [
|
||||
Section(
|
||||
link=link,
|
||||
text=service.files()
|
||||
.get_media(fileId=file["id"])
|
||||
.execute()
|
||||
.decode("utf-8"),
|
||||
)
|
||||
]
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
@ -60,31 +75,64 @@ def _extract_text(file: dict[str, str], service: Resource) -> str:
|
||||
]:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
if get_unstructured_api_key():
|
||||
return unstructured_to_text(
|
||||
file=io.BytesIO(response), file_name=file.get("name", file["id"])
|
||||
return [
|
||||
Section(
|
||||
link=link,
|
||||
text=unstructured_to_text(
|
||||
file=io.BytesIO(response),
|
||||
file_name=file.get("name", file["id"]),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
if mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
return [
|
||||
Section(link=link, text=docx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
return [Section(link=link, text=text)]
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
return [
|
||||
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
except Exception:
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
|
||||
def convert_drive_item_to_document(
|
||||
file: GoogleDriveFileType, service: Resource
|
||||
file: GoogleDriveFileType,
|
||||
drive_service: GoogleDriveService,
|
||||
docs_service: GoogleDocsService,
|
||||
) -> Document | None:
|
||||
try:
|
||||
# Skip files that are shortcuts
|
||||
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
|
||||
logger.info("Ignoring Drive Shortcut Filetype")
|
||||
return None
|
||||
|
||||
sections: list[Section] = []
|
||||
|
||||
# Special handling for Google Docs to preserve structure, link
|
||||
# to headers
|
||||
if file.get("mimeType") == GDriveMimeType.DOC.value:
|
||||
try:
|
||||
text_contents = _extract_text(file, service) or ""
|
||||
sections = get_document_sections(docs_service, file["id"])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
|
||||
" Falling back to basic extraction."
|
||||
)
|
||||
|
||||
# NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
|
||||
if not sections:
|
||||
try:
|
||||
# For all other file types just extract the text
|
||||
sections = _extract_sections_basic(file, drive_service)
|
||||
|
||||
except HttpError as e:
|
||||
reason = e.error_details[0]["reason"] if e.error_details else e.reason
|
||||
message = e.error_details[0]["message"] if e.error_details else e.reason
|
||||
@ -96,15 +144,20 @@ def convert_drive_item_to_document(
|
||||
|
||||
raise
|
||||
|
||||
if not sections:
|
||||
return None
|
||||
|
||||
return Document(
|
||||
id=file["webViewLink"],
|
||||
sections=[Section(link=file["webViewLink"], text=text_contents)],
|
||||
sections=sections,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
|
||||
timezone.utc
|
||||
),
|
||||
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
|
||||
metadata={}
|
||||
if any(section.text for section in sections)
|
||||
else {IGNORE_FOR_QA: "True"},
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -28,7 +28,7 @@ def execute_paginated_retrieval(
|
||||
if next_page_token:
|
||||
request_kwargs["pageToken"] = next_page_token
|
||||
|
||||
results = add_retries(lambda: retrieval_function(**request_kwargs).execute())()
|
||||
results = (lambda: retrieval_function(**request_kwargs).execute())()
|
||||
|
||||
next_page_token = results.get("nextPageToken")
|
||||
for item in results.get(list_key, []):
|
||||
|
52
backend/danswer/connectors/google_drive/resources.py
Normal file
52
backend/danswer/connectors/google_drive/resources.py
Normal file
@ -0,0 +1,52 @@
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
|
||||
class GoogleDriveService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class GoogleDocsService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class AdminService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
def _get_google_service(
|
||||
service_name: str,
|
||||
service_version: str,
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDriveService:
|
||||
if isinstance(creds, ServiceAccountCredentials):
|
||||
creds = creds.with_subject(user_email)
|
||||
service = build(service_name, service_version, credentials=creds)
|
||||
elif isinstance(creds, OAuthCredentials):
|
||||
service = build(service_name, service_version, credentials=creds)
|
||||
|
||||
return service
|
||||
|
||||
|
||||
def get_google_docs_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDocsService:
|
||||
return _get_google_service("docs", "v1", creds, user_email)
|
||||
|
||||
|
||||
def get_drive_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDriveService:
|
||||
return _get_google_service("drive", "v3", creds, user_email)
|
||||
|
||||
|
||||
def get_admin_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str,
|
||||
) -> AdminService:
|
||||
return _get_google_service("admin", "directory_v1", creds, user_email)
|
105
backend/danswer/connectors/google_drive/section_extraction.py
Normal file
105
backend/danswer/connectors/google_drive/section_extraction.py
Normal file
@ -0,0 +1,105 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.connectors.google_drive.resources import GoogleDocsService
|
||||
from danswer.connectors.models import Section
|
||||
|
||||
|
||||
class CurrentHeading(BaseModel):
|
||||
id: str
|
||||
text: str
|
||||
|
||||
|
||||
def _build_gdoc_section_link(doc_id: str, heading_id: str) -> str:
|
||||
"""Builds a Google Doc link that jumps to a specific heading"""
|
||||
# NOTE: doesn't support docs with multiple tabs atm, if we need that ask
|
||||
# @Chris
|
||||
return (
|
||||
f"https://docs.google.com/document/d/{doc_id}/edit?tab=t.0#heading={heading_id}"
|
||||
)
|
||||
|
||||
|
||||
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
|
||||
"""Extracts the id from a heading paragraph element"""
|
||||
return paragraph["paragraphStyle"]["headingId"]
|
||||
|
||||
|
||||
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
|
||||
"""Extracts the text content from a paragraph element"""
|
||||
text_elements = []
|
||||
for element in paragraph.get("elements", []):
|
||||
if "textRun" in element:
|
||||
text_elements.append(element["textRun"].get("content", ""))
|
||||
return "".join(text_elements)
|
||||
|
||||
|
||||
def get_document_sections(
|
||||
docs_service: GoogleDocsService,
|
||||
doc_id: str,
|
||||
) -> list[Section]:
|
||||
"""Extracts sections from a Google Doc, including their headings and content"""
|
||||
# Fetch the document structure
|
||||
doc = docs_service.documents().get(documentId=doc_id).execute()
|
||||
|
||||
# Get the content
|
||||
content = doc.get("body", {}).get("content", [])
|
||||
|
||||
sections: list[Section] = []
|
||||
current_section: list[str] = []
|
||||
current_heading: CurrentHeading | None = None
|
||||
|
||||
for element in content:
|
||||
if "paragraph" not in element:
|
||||
continue
|
||||
|
||||
paragraph = element["paragraph"]
|
||||
|
||||
# Check if this is a heading
|
||||
if (
|
||||
"paragraphStyle" in paragraph
|
||||
and "namedStyleType" in paragraph["paragraphStyle"]
|
||||
):
|
||||
style = paragraph["paragraphStyle"]["namedStyleType"]
|
||||
is_heading = style.startswith("HEADING_")
|
||||
is_title = style.startswith("TITLE")
|
||||
|
||||
if is_heading or is_title:
|
||||
# If we were building a previous section, add it to sections list
|
||||
if current_heading is not None and current_section:
|
||||
heading_text = current_heading.text
|
||||
section_text = f"{heading_text}\n" + "\n".join(current_section)
|
||||
sections.append(
|
||||
Section(
|
||||
text=section_text.strip(),
|
||||
link=_build_gdoc_section_link(doc_id, current_heading.id),
|
||||
)
|
||||
)
|
||||
current_section = []
|
||||
|
||||
# Start new heading
|
||||
heading_id = _extract_id_from_heading(paragraph)
|
||||
heading_text = _extract_text_from_paragraph(paragraph)
|
||||
current_heading = CurrentHeading(
|
||||
id=heading_id,
|
||||
text=heading_text,
|
||||
)
|
||||
continue
|
||||
|
||||
# Add content to current section
|
||||
if current_heading is not None:
|
||||
text = _extract_text_from_paragraph(paragraph)
|
||||
if text.strip():
|
||||
current_section.append(text)
|
||||
|
||||
# Don't forget to add the last section
|
||||
if current_heading is not None and current_section:
|
||||
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
|
||||
sections.append(
|
||||
Section(
|
||||
text=section_text.strip(),
|
||||
link=_build_gdoc_section_link(doc_id, current_heading.id),
|
||||
)
|
||||
)
|
||||
|
||||
return sections
|
@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.access.models import ExternalAccess
|
||||
from danswer.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.google_drive.resources import get_drive_service
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
@ -56,7 +57,10 @@ def _fetch_permissions_for_permission_ids(
|
||||
return permissions
|
||||
|
||||
owner_email = permission_info.get("owner_email")
|
||||
drive_service = google_drive_connector.get_google_resource(user_email=owner_email)
|
||||
drive_service = get_drive_service(
|
||||
creds=google_drive_connector.creds,
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
)
|
||||
|
||||
# Otherwise, fetch all permissions and update cache
|
||||
fetched_permissions = execute_paginated_retrieval(
|
||||
|
@ -2,6 +2,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.google_drive.resources import get_admin_service
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
@ -19,8 +20,9 @@ def gdrive_group_sync(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
admin_service = google_drive_connector.get_google_resource("admin", "directory_v1")
|
||||
admin_service = get_admin_service(
|
||||
google_drive_connector.creds, google_drive_connector.primary_admin_email
|
||||
)
|
||||
|
||||
danswer_groups: list[ExternalUserGroup] = []
|
||||
for group in execute_paginated_retrieval(
|
||||
|
@ -31,6 +31,30 @@ def load_env_vars(env_file: str = ".env") -> None:
|
||||
load_env_vars()
|
||||
|
||||
|
||||
def parse_credentials(env_str: str) -> dict:
|
||||
"""
|
||||
Parse a double-escaped JSON string from environment variables into a Python dictionary.
|
||||
|
||||
Args:
|
||||
env_str (str): The double-escaped JSON string from environment variables
|
||||
|
||||
Returns:
|
||||
dict: Parsed OAuth credentials
|
||||
"""
|
||||
# first try normally
|
||||
try:
|
||||
return json.loads(env_str)
|
||||
except Exception:
|
||||
# First, try remove extra escaping backslashes
|
||||
unescaped = env_str.replace('\\"', '"')
|
||||
|
||||
# remove leading / trailing quotes
|
||||
unescaped = unescaped.strip('"')
|
||||
|
||||
# Now parse the JSON
|
||||
return json.loads(unescaped)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector]:
|
||||
def _connector_factory(
|
||||
@ -50,7 +74,7 @@ def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector
|
||||
)
|
||||
|
||||
json_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
|
||||
refried_json_string = json.loads(json_string)
|
||||
refried_json_string = json.dumps(parse_credentials(json_string))
|
||||
|
||||
credentials_json = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string,
|
||||
@ -84,7 +108,7 @@ def google_drive_service_acct_connector_factory() -> (
|
||||
)
|
||||
|
||||
json_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
|
||||
refried_json_string = json.loads(json_string)
|
||||
refried_json_string = json.dumps(parse_credentials(json_string))
|
||||
|
||||
# Load Service Account Credentials
|
||||
connector.load_credentials(
|
||||
|
@ -18,6 +18,7 @@ _SHARED_DRIVE_2_FILE_IDS = list(range(40, 45))
|
||||
_FOLDER_2_FILE_IDS = list(range(45, 50))
|
||||
_FOLDER_2_1_FILE_IDS = list(range(50, 55))
|
||||
_FOLDER_2_2_FILE_IDS = list(range(55, 60))
|
||||
_SECTIONS_FILE_IDS = [61]
|
||||
|
||||
_PUBLIC_FOLDER_RANGE = _FOLDER_1_2_FILE_IDS
|
||||
_PUBLIC_FILE_IDS = list(range(55, 57))
|
||||
@ -64,6 +65,7 @@ DRIVE_ID_MAPPING: dict[str, list[int]] = {
|
||||
"FOLDER_2": _FOLDER_2_FILE_IDS,
|
||||
"FOLDER_2_1": _FOLDER_2_1_FILE_IDS,
|
||||
"FOLDER_2_2": _FOLDER_2_2_FILE_IDS,
|
||||
"SECTIONS": _SECTIONS_FILE_IDS,
|
||||
}
|
||||
|
||||
# Dictionary for emails
|
||||
@ -100,6 +102,7 @@ ACCESS_MAPPING: dict[str, list[int]] = {
|
||||
+ _FOLDER_2_FILE_IDS
|
||||
+ _FOLDER_2_1_FILE_IDS
|
||||
+ _FOLDER_2_2_FILE_IDS
|
||||
+ _SECTIONS_FILE_IDS
|
||||
),
|
||||
# This user has access to drive 1
|
||||
# This user has redundant access to folder 1 because of group access
|
||||
@ -127,6 +130,21 @@ ACCESS_MAPPING: dict[str, list[int]] = {
|
||||
"TEST_USER_3": _TEST_USER_3_FILE_IDS,
|
||||
}
|
||||
|
||||
SPECIAL_FILE_ID_TO_CONTENT_MAP: dict[int, str] = {
|
||||
61: (
|
||||
"Title\n\n"
|
||||
"This is a Google Doc with sections - "
|
||||
"Section 1\n\n"
|
||||
"Section 1 content - "
|
||||
"Sub-Section 1-1\n\n"
|
||||
"Sub-Section 1-1 content - "
|
||||
"Sub-Section 1-2\n\n"
|
||||
"Sub-Section 1-2 content - "
|
||||
"Section 2\n\n"
|
||||
"Section 2 content"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
file_name_template = "file_{}.txt"
|
||||
file_text_template = "This is file {}"
|
||||
@ -142,18 +160,28 @@ def print_discrepencies(expected: set[str], retrieved: set[str]) -> None:
|
||||
print(expected - retrieved)
|
||||
|
||||
|
||||
def get_file_content(file_id: int) -> str:
|
||||
if file_id in SPECIAL_FILE_ID_TO_CONTENT_MAP:
|
||||
return SPECIAL_FILE_ID_TO_CONTENT_MAP[file_id]
|
||||
|
||||
return file_text_template.format(file_id)
|
||||
|
||||
|
||||
def assert_retrieved_docs_match_expected(
|
||||
retrieved_docs: list[Document], expected_file_ids: Sequence[int]
|
||||
) -> None:
|
||||
expected_file_names = {
|
||||
file_name_template.format(file_id) for file_id in expected_file_ids
|
||||
}
|
||||
expected_file_texts = {
|
||||
file_text_template.format(file_id) for file_id in expected_file_ids
|
||||
}
|
||||
expected_file_texts = {get_file_content(file_id) for file_id in expected_file_ids}
|
||||
|
||||
retrieved_file_names = set([doc.semantic_identifier for doc in retrieved_docs])
|
||||
retrieved_texts = set([doc.sections[0].text for doc in retrieved_docs])
|
||||
retrieved_texts = set(
|
||||
[
|
||||
" - ".join([section.text for section in doc.sections])
|
||||
for doc in retrieved_docs
|
||||
]
|
||||
)
|
||||
|
||||
# Check file names
|
||||
print_discrepencies(expected_file_names, retrieved_file_names)
|
||||
|
@ -41,6 +41,7 @@ def test_include_all(
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2"]
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
|
||||
+ DRIVE_ID_MAPPING["SECTIONS"]
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
@ -75,6 +76,7 @@ def test_include_shared_drives_only(
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2"]
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
|
||||
+ DRIVE_ID_MAPPING["SECTIONS"]
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
|
@ -0,0 +1,71 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from danswer.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from danswer.connectors.models import Document
|
||||
|
||||
|
||||
SECTIONS_FOLDER_URL = (
|
||||
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
"danswer.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
def test_google_drive_sections(
|
||||
mock_get_api_key: MagicMock,
|
||||
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
oauth_connector = google_drive_oauth_connector_factory(
|
||||
include_shared_drives=False,
|
||||
include_my_drives=False,
|
||||
shared_folder_urls=SECTIONS_FOLDER_URL,
|
||||
)
|
||||
service_acct_connector = google_drive_service_acct_connector_factory(
|
||||
include_shared_drives=False,
|
||||
include_my_drives=False,
|
||||
shared_folder_urls=SECTIONS_FOLDER_URL,
|
||||
)
|
||||
for connector in [oauth_connector, service_acct_connector]:
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# Verify we got the 1 doc with sections
|
||||
assert len(retrieved_docs) == 1
|
||||
|
||||
# Verify each section has the expected structure
|
||||
doc = retrieved_docs[0]
|
||||
assert len(doc.sections) == 5
|
||||
|
||||
header_section = doc.sections[0]
|
||||
assert header_section.text == "Title\n\nThis is a Google Doc with sections"
|
||||
assert header_section.link is not None
|
||||
assert header_section.link.endswith(
|
||||
"?tab=t.0#heading=h.hfjc17k6qwzt"
|
||||
) or header_section.link.endswith("?tab=t.0#heading=h.hfjc17k6qwzt")
|
||||
|
||||
section_1 = doc.sections[1]
|
||||
assert section_1.text == "Section 1\n\nSection 1 content"
|
||||
assert section_1.link is not None
|
||||
assert section_1.link.endswith("?tab=t.0#heading=h.8slfx752a3g5")
|
||||
|
||||
section_2 = doc.sections[2]
|
||||
assert section_2.text == "Sub-Section 1-1\n\nSub-Section 1-1 content"
|
||||
assert section_2.link is not None
|
||||
assert section_2.link.endswith("?tab=t.0#heading=h.4kj3ayade1bp")
|
||||
|
||||
section_3 = doc.sections[3]
|
||||
assert section_3.text == "Sub-Section 1-2\n\nSub-Section 1-2 content"
|
||||
assert section_3.link is not None
|
||||
assert section_3.link.endswith("?tab=t.0#heading=h.pm6wrpzgk69l")
|
||||
|
||||
section_4 = doc.sections[4]
|
||||
assert section_4.text == "Section 2\n\nSection 2 content"
|
||||
assert section_4.link is not None
|
||||
assert section_4.link.endswith("?tab=t.0#heading=h.2m0s9youe2k9")
|
@ -44,6 +44,7 @@ def test_include_all(
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2"]
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
|
||||
+ DRIVE_ID_MAPPING["SECTIONS"]
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
@ -78,6 +79,7 @@ def test_include_shared_drives_only(
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2"]
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
|
||||
+ DRIVE_ID_MAPPING["SECTIONS"]
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
|
@ -6,6 +6,7 @@ from unittest.mock import patch
|
||||
from danswer.access.models import ExternalAccess
|
||||
from danswer.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.google_drive.resources import get_admin_service
|
||||
from ee.danswer.external_permissions.google_drive.doc_sync import (
|
||||
_get_permissions_from_slim_doc,
|
||||
)
|
||||
@ -72,7 +73,10 @@ def assert_correct_access_for_user(
|
||||
# This function is supposed to map to the group_sync.py file for the google drive connector
|
||||
# TODO: Call it directly
|
||||
def get_group_map(google_drive_connector: GoogleDriveConnector) -> dict[str, list[str]]:
|
||||
admin_service = google_drive_connector.get_google_resource("admin", "directory_v1")
|
||||
admin_service = get_admin_service(
|
||||
creds=google_drive_connector.creds,
|
||||
user_email=google_drive_connector.primary_admin_email,
|
||||
)
|
||||
|
||||
group_map: dict[str, list[str]] = {}
|
||||
for group in execute_paginated_retrieval(
|
||||
@ -138,6 +142,7 @@ def test_all_permissions(
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2"]
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2_1"]
|
||||
+ DRIVE_ID_MAPPING["FOLDER_2_2"]
|
||||
+ DRIVE_ID_MAPPING["SECTIONS"]
|
||||
)
|
||||
|
||||
# Should get everything
|
||||
|
Reference in New Issue
Block a user