Index all Google Drive file types (#373)

Yuhong Sun
2023-08-31 19:20:32 -07:00
committed by GitHub
parent 6bae93ad3c
commit e1fbffd141


@@ -1,8 +1,9 @@
 import datetime
 import io
 import tempfile
-from collections.abc import Generator
+from collections.abc import Iterator
 from collections.abc import Sequence
+from enum import Enum
 from itertools import chain
 from typing import Any
 from typing import cast
@@ -44,14 +45,17 @@ logger = setup_logger()
 # allow 10 minutes for modifiedTime to get propagated
 DRIVE_START_TIME_OFFSET = 60 * 10
-SUPPORTED_DRIVE_DOC_TYPES = [
-    "application/vnd.google-apps.document",
-    "application/vnd.google-apps.spreadsheet",
-    "application/pdf",
-    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
-]
 DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
 DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
+UNSUPPORTED_FILE_TYPE_CONTENT = ""  # keep empty for now
+
+
+class GDriveMimeType(str, Enum):
+    DOC = "application/vnd.google-apps.document"
+    SPREADSHEET = "application/vnd.google-apps.spreadsheet"
+    PDF = "application/pdf"
+    WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+

 GoogleDriveFileType = dict[str, Any]
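
(Sketch, not part of the commit: the new str-backed enum can drive the same "is this file parseable" check that the removed SUPPORTED_DRIVE_DOC_TYPES list used to, which is how extract_text below uses it. The is_supported helper and the sample mime types are hypothetical.)

    from enum import Enum


    class GDriveMimeType(str, Enum):
        DOC = "application/vnd.google-apps.document"
        SPREADSHEET = "application/vnd.google-apps.spreadsheet"
        PDF = "application/pdf"
        WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"


    def is_supported(mime_type: str) -> bool:
        # GDriveMimeType subclasses str, so .value is the raw mime-type string
        return mime_type in {item.value for item in GDriveMimeType}


    print(is_supported("application/pdf"))                      # True
    print(is_supported("application/vnd.google-apps.drawing"))  # False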
@@ -63,7 +67,7 @@ def _run_drive_file_query(
     include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
     follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
     batch_size: int = INDEX_BATCH_SIZE,
-) -> Generator[GoogleDriveFileType, None, None]:
+) -> Iterator[GoogleDriveFileType]:
     next_page_token = ""
     while next_page_token is not None:
         logger.debug(f"Running Google Drive fetch with query: {query}")
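
(Note on the return-type change: a generator function that only yields values can be annotated as Iterator[T]; Generator[T, None, None] says the same thing with explicit send/return types. A tiny stand-alone sketch with a hypothetical count_up generator:)

    from collections.abc import Iterator


    def count_up(limit: int) -> Iterator[int]:
        # A generator function: Iterator[int] is equivalent to
        # Generator[int, None, None] when send() and the return value are unused.
        n = 0
        while n < limit:
            yield n
            n += 1


    print(list(count_up(3)))  # [0, 1, 2]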
@@ -148,7 +152,7 @@ def _get_folders(
     include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
     follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
     batch_size: int = INDEX_BATCH_SIZE,
-) -> Generator[GoogleDriveFileType, None, None]:
+) -> Iterator[GoogleDriveFileType]:
     query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
     if follow_shortcuts:
         query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") "
@@ -181,9 +185,8 @@ def _get_files(
     folder_id: str | None = None,  # if specified, only fetches files within this folder
     include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
     follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
-    supported_drive_doc_types: list[str] = SUPPORTED_DRIVE_DOC_TYPES,
     batch_size: int = INDEX_BATCH_SIZE,
-) -> Generator[GoogleDriveFileType, None, None]:
+) -> Iterator[GoogleDriveFileType]:
     query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
     if time_range_start is not None:
         time_start = (
@@ -205,9 +208,8 @@ def _get_files(
         follow_shortcuts=follow_shortcuts,
         batch_size=batch_size,
     )
-    for file in files:
-        if file["mimeType"] in supported_drive_doc_types:
-            yield file
+
+    return files


 def get_all_files_batched(
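
(With the supported_drive_doc_types filter gone, _get_files returns every non-folder file instead of silently dropping unsupported mime types; whether a file's content can be parsed is now decided later, in extract_text. A rough before/after sketch using hypothetical data and an abbreviated supported-type set:)

    # Hypothetical file listing to illustrate the change in _get_files.
    files = [
        {"name": "roadmap", "mimeType": "application/vnd.google-apps.document"},
        {"name": "diagram", "mimeType": "application/vnd.google-apps.drawing"},
    ]

    # Before: files with unsupported mime types were dropped while listing.
    supported = {"application/vnd.google-apps.document", "application/pdf"}
    old_result = [f for f in files if f["mimeType"] in supported]

    # After: everything is returned; extraction decides later what it can parse.
    new_result = files

    print(len(old_result), len(new_result))  # 1 2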
@@ -223,11 +225,11 @@ def get_all_files_batched(
     # Only applies if folder_id is specified.
     traverse_subfolders: bool = True,
     folder_ids_traversed: list[str] | None = None,
-) -> Generator[list[GoogleDriveFileType], None, None]:
+) -> Iterator[list[GoogleDriveFileType]]:
     """Gets all files matching the criteria specified by the args from Google Drive
     in batches of size `batch_size`.
     """
-    valid_files = _get_files(
+    found_files = _get_files(
         service=service,
         continue_on_failure=continue_on_failure,
         time_range_start=time_range_start,
@@ -238,7 +240,7 @@ def get_all_files_batched(
         batch_size=batch_size,
     )
     yield from batch_generator(
-        items=valid_files,
+        items=found_files,
         batch_size=batch_size,
         pre_batch_yield=lambda batch_files: logger.info(
             f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
@@ -279,32 +281,32 @@ def get_all_files_batched(
 def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
     mime_type = file["mimeType"]
-    if mime_type == "application/vnd.google-apps.document":
+    if mime_type not in set(item.value for item in GDriveMimeType):
+        # Unsupported file types can still have a title, finding this way is still useful
+        return UNSUPPORTED_FILE_TYPE_CONTENT
+
+    if mime_type == GDriveMimeType.DOC.value:
         return (
             service.files()
             .export(fileId=file["id"], mimeType="text/plain")
             .execute()
             .decode("utf-8")
         )
-    elif mime_type == "application/vnd.google-apps.spreadsheet":
+    elif mime_type == GDriveMimeType.SPREADSHEET.value:
         return (
             service.files()
             .export(fileId=file["id"], mimeType="text/csv")
             .execute()
             .decode("utf-8")
         )
-    elif (
-        mime_type
-        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
-    ):
+    elif mime_type == GDriveMimeType.WORD_DOC.value:
         response = service.files().get_media(fileId=file["id"]).execute()
         word_stream = io.BytesIO(response)
         with tempfile.NamedTemporaryFile(delete=False) as temp:
             temp.write(word_stream.getvalue())
             temp_path = temp.name
         return docx2txt.process(temp_path)
-    # Default download to PDF since most types can be exported as a PDF
-    else:
+    elif mime_type == GDriveMimeType.PDF.value:
         response = service.files().get_media(fileId=file["id"]).execute()
         pdf_stream = io.BytesIO(response)
         pdf_reader = PdfReader(pdf_stream)
@@ -317,6 +319,8 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
return "\n".join(page.extract_text() for page in pdf_reader.pages) return "\n".join(page.extract_text() for page in pdf_reader.pages)
return UNSUPPORTED_FILE_TYPE_CONTENT
class GoogleDriveConnector(LoadConnector, PollConnector): class GoogleDriveConnector(LoadConnector, PollConnector):
def __init__( def __init__(
@@ -450,7 +454,10 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
             for file in files_batch:
                 try:
                     text_contents = extract_text(file, service)
-                    full_context = file["name"] + " - " + text_contents
+                    if text_contents:
+                        full_context = file["name"] + " - " + text_contents
+                    else:
+                        full_context = file["name"]
                     doc_batch.append(
                         Document(
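
(Net effect at indexing time: a file whose content cannot be extracted still becomes a document, keyed only by its title. A self-contained sketch of the fallback, with a hypothetical build_full_context helper standing in for the inline logic above:)

    UNSUPPORTED_FILE_TYPE_CONTENT = ""  # same empty sentinel as above


    def build_full_context(name: str, text_contents: str) -> str:
        # Mirrors the new fallback: unsupported types come back as "" from
        # extract_text, so only the file name is indexed for them.
        if text_contents:
            return name + " - " + text_contents
        return name


    print(build_full_context("Q3 roadmap", "goals for the quarter"))  # Q3 roadmap - goals for the quarter
    print(build_full_context("whiteboard.drawing", UNSUPPORTED_FILE_TYPE_CONTENT))  # whiteboard.drawing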