Index all Google Drive file types (#373)

This commit is contained in:
Yuhong Sun 2023-08-31 19:20:32 -07:00 committed by GitHub
parent 6bae93ad3c
commit e1fbffd141
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,8 +1,9 @@
import datetime
import io
import tempfile
from collections.abc import Generator
from collections.abc import Iterator
from collections.abc import Sequence
from enum import Enum
from itertools import chain
from typing import Any
from typing import cast
@ -44,14 +45,17 @@ logger = setup_logger()
# allow 10 minutes for modifiedTime to get propagated
DRIVE_START_TIME_OFFSET = 60 * 10
SUPPORTED_DRIVE_DOC_TYPES = [
"application/vnd.google-apps.document",
"application/vnd.google-apps.spreadsheet",
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
]
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
GoogleDriveFileType = dict[str, Any]
@ -63,7 +67,7 @@ def _run_drive_file_query(
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Generator[GoogleDriveFileType, None, None]:
) -> Iterator[GoogleDriveFileType]:
next_page_token = ""
while next_page_token is not None:
logger.debug(f"Running Google Drive fetch with query: {query}")
@ -148,7 +152,7 @@ def _get_folders(
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Generator[GoogleDriveFileType, None, None]:
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
if follow_shortcuts:
query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") "
@ -181,9 +185,8 @@ def _get_files(
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
supported_drive_doc_types: list[str] = SUPPORTED_DRIVE_DOC_TYPES,
batch_size: int = INDEX_BATCH_SIZE,
) -> Generator[GoogleDriveFileType, None, None]:
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
if time_range_start is not None:
time_start = (
@ -205,9 +208,8 @@ def _get_files(
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
for file in files:
if file["mimeType"] in supported_drive_doc_types:
yield file
return files
def get_all_files_batched(
@ -223,11 +225,11 @@ def get_all_files_batched(
# Only applies if folder_id is specified.
traverse_subfolders: bool = True,
folder_ids_traversed: list[str] | None = None,
) -> Generator[list[GoogleDriveFileType], None, None]:
) -> Iterator[list[GoogleDriveFileType]]:
"""Gets all files matching the criteria specified by the args from Google Drive
in batches of size `batch_size`.
"""
valid_files = _get_files(
found_files = _get_files(
service=service,
continue_on_failure=continue_on_failure,
time_range_start=time_range_start,
@ -238,7 +240,7 @@ def get_all_files_batched(
batch_size=batch_size,
)
yield from batch_generator(
items=valid_files,
items=found_files,
batch_size=batch_size,
pre_batch_yield=lambda batch_files: logger.info(
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
@ -279,32 +281,32 @@ def get_all_files_batched(
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
mime_type = file["mimeType"]
if mime_type == "application/vnd.google-apps.document":
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT
if mime_type == GDriveMimeType.DOC.value:
return (
service.files()
.export(fileId=file["id"], mimeType="text/plain")
.execute()
.decode("utf-8")
)
elif mime_type == "application/vnd.google-apps.spreadsheet":
elif mime_type == GDriveMimeType.SPREADSHEET.value:
return (
service.files()
.export(fileId=file["id"], mimeType="text/csv")
.execute()
.decode("utf-8")
)
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
elif mime_type == GDriveMimeType.WORD_DOC.value:
response = service.files().get_media(fileId=file["id"]).execute()
word_stream = io.BytesIO(response)
with tempfile.NamedTemporaryFile(delete=False) as temp:
temp.write(word_stream.getvalue())
temp_path = temp.name
return docx2txt.process(temp_path)
# Default download to PDF since most types can be exported as a PDF
else:
elif mime_type == GDriveMimeType.PDF.value:
response = service.files().get_media(fileId=file["id"]).execute()
pdf_stream = io.BytesIO(response)
pdf_reader = PdfReader(pdf_stream)
@ -317,6 +319,8 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
return "\n".join(page.extract_text() for page in pdf_reader.pages)
return UNSUPPORTED_FILE_TYPE_CONTENT
class GoogleDriveConnector(LoadConnector, PollConnector):
def __init__(
@ -450,7 +454,10 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
for file in files_batch:
try:
text_contents = extract_text(file, service)
full_context = file["name"] + " - " + text_contents
if text_contents:
full_context = file["name"] + " - " + text_contents
else:
full_context = file["name"]
doc_batch.append(
Document(