no more duplicate files during folder indexing (#4579)

* no more duplicate files during folder indexing

* cleanup checkpoint after a shared folder has been finished

* cleanup

* lint
This commit is contained in:
Evan Lohn
2025-04-27 18:01:20 -07:00
committed by GitHub
parent ea0664e203
commit 5db676967f
4 changed files with 105 additions and 34 deletions

View File

@@ -89,7 +89,7 @@ def _clean_requested_drive_ids(
requested_drive_ids: set[str],
requested_folder_ids: set[str],
all_drive_ids_available: set[str],
) -> tuple[set[str], set[str]]:
) -> tuple[list[str], list[str]]:
invalid_requested_drive_ids = requested_drive_ids - all_drive_ids_available
filtered_folder_ids = requested_folder_ids - all_drive_ids_available
if invalid_requested_drive_ids:
@@ -100,7 +100,7 @@ def _clean_requested_drive_ids(
filtered_folder_ids.update(invalid_requested_drive_ids)
valid_requested_drive_ids = requested_drive_ids - invalid_requested_drive_ids
return valid_requested_drive_ids, filtered_folder_ids
return sorted(valid_requested_drive_ids), sorted(filtered_folder_ids)
class CredentialedRetrievalMethod(Protocol):
@@ -326,7 +326,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
return all_drive_ids
def make_drive_id_iterator(
self, drive_ids: set[str], checkpoint: GoogleDriveCheckpoint
self, drive_ids: list[str], checkpoint: GoogleDriveCheckpoint
) -> Callable[[str], Iterator[str]]:
cv = threading.Condition()
drive_id_status = {
@@ -490,12 +490,18 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
resuming = False # we are starting the next stage for the first time
# In the folder files section of service account retrieval we take extra care
# to not retrieve duplicate docs. In particular, we only add a folder to
# retrieved_folder_and_drive_ids when all users are finished retrieving files
# from that folder, and maintain a set of all file ids that have been retrieved
# for each folder. This might get rather large; in practice we assume that the
# specific folders users choose to index don't have too many files.
if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:
def _yield_from_folder_crawl(
folder_id: str, folder_start: SecondsSinceUnixEpoch | None
) -> Iterator[RetrievedDriveFile]:
yield from crawl_folders_for_files(
for retrieved_file in crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
is_slim=is_slim,
@@ -504,7 +510,23 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
update_traversed_ids_func=self._update_traversed_parent_ids,
start=folder_start,
end=end,
)
):
with checkpoint.processed_folder_file_ids.lock:
should_yield = False
completed_ids = checkpoint.processed_folder_file_ids._dict.get(
folder_id, set()
)
if (
"id" in retrieved_file.drive_file
and retrieved_file.drive_file["id"] not in completed_ids
):
completed_ids.add(retrieved_file.drive_file["id"])
should_yield = True
checkpoint.processed_folder_file_ids._dict[folder_id] = (
completed_ids
)
if should_yield:
yield retrieved_file
# resume from a checkpoint
last_processed_folder = None
@@ -517,6 +539,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
last_processed_folder = folder_id
skipping_seen_folders = last_processed_folder is not None
    # NOTE: this assumes a small number of folders to crawl. If someone
# really wants to specify a large number of folders, we should use
# binary search to find the first unseen folder.
for folder_id in sorted_filtered_folder_ids:
if skipping_seen_folders:
skipping_seen_folders = folder_id != last_processed_folder
@@ -552,7 +577,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
raise ValueError("user emails not set")
all_org_emails = checkpoint.user_emails
drive_ids_to_retrieve, folder_ids_to_retrieve = self._determine_retrieval_ids(
sorted_drive_ids, sorted_folder_ids = self._determine_retrieval_ids(
checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
)
@@ -570,16 +595,12 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
# fetching stuff
logger.info(f"Found {len(all_org_emails)} users to impersonate")
logger.debug(f"Users: {all_org_emails}")
logger.info(f"Found {len(drive_ids_to_retrieve)} drives to retrieve")
logger.debug(f"Drives: {drive_ids_to_retrieve}")
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
logger.debug(f"Folders: {folder_ids_to_retrieve}")
logger.info(f"Found {len(sorted_drive_ids)} drives to retrieve")
logger.debug(f"Drives: {sorted_drive_ids}")
logger.info(f"Found {len(sorted_folder_ids)} folders to retrieve")
logger.debug(f"Folders: {sorted_folder_ids}")
drive_id_iterator = self.make_drive_id_iterator(
drive_ids_to_retrieve, checkpoint
)
sorted_filtered_folder_ids = sorted(folder_ids_to_retrieve)
drive_id_iterator = self.make_drive_id_iterator(sorted_drive_ids, checkpoint)
# only process emails that we haven't already completed retrieval for
non_completed_org_emails = [
@@ -606,7 +627,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
is_slim,
checkpoint,
drive_id_iterator,
sorted_filtered_folder_ids,
sorted_folder_ids,
start,
end,
)
@@ -619,7 +640,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
return
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
set(sorted_drive_ids) | set(sorted_folder_ids)
) - self._retrieved_ids
if remaining_folders:
logger.warning(
@@ -637,26 +658,26 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
checkpoint: GoogleDriveCheckpoint,
is_slim: bool,
next_stage: DriveRetrievalStage,
) -> tuple[set[str], set[str]]:
) -> tuple[list[str], list[str]]:
all_drive_ids = self.get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
sorted_drive_ids: list[str] = []
sorted_folder_ids: list[str] = []
if checkpoint.completion_stage == DriveRetrievalStage.DRIVE_IDS:
if self._requested_shared_drive_ids or self._requested_folder_ids:
(
drive_ids_to_retrieve,
folder_ids_to_retrieve,
sorted_drive_ids,
sorted_folder_ids,
) = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
sorted_drive_ids = sorted(all_drive_ids)
if not is_slim:
checkpoint.drive_ids_to_retrieve = list(drive_ids_to_retrieve)
checkpoint.folder_ids_to_retrieve = list(folder_ids_to_retrieve)
checkpoint.drive_ids_to_retrieve = sorted_drive_ids
checkpoint.folder_ids_to_retrieve = sorted_folder_ids
checkpoint.completion_stage = next_stage
else:
if checkpoint.drive_ids_to_retrieve is None:
@@ -664,10 +685,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
if checkpoint.folder_ids_to_retrieve is None:
raise ValueError("folder ids to retrieve not set in checkpoint")
# When loading from a checkpoint, load the previously cached drive and folder ids
drive_ids_to_retrieve = set(checkpoint.drive_ids_to_retrieve)
folder_ids_to_retrieve = set(checkpoint.folder_ids_to_retrieve)
sorted_drive_ids = checkpoint.drive_ids_to_retrieve
sorted_folder_ids = checkpoint.folder_ids_to_retrieve
return drive_ids_to_retrieve, folder_ids_to_retrieve
return sorted_drive_ids, sorted_folder_ids
def _oauth_retrieval_all_files(
self,
@@ -704,7 +725,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
self,
is_slim: bool,
drive_service: GoogleDriveService,
drive_ids_to_retrieve: set[str],
drive_ids_to_retrieve: list[str],
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
@@ -886,7 +907,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
checkpoint.completion_stage = DriveRetrievalStage.DONE
return
drive_ids_to_retrieve, folder_ids_to_retrieve = self._determine_retrieval_ids(
sorted_drive_ids, sorted_folder_ids = self._determine_retrieval_ids(
checkpoint, is_slim, DriveRetrievalStage.SHARED_DRIVE_FILES
)
@@ -894,7 +915,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
yield from self._oauth_retrieval_drives(
is_slim=is_slim,
drive_service=drive_service,
drive_ids_to_retrieve=drive_ids_to_retrieve,
drive_ids_to_retrieve=sorted_drive_ids,
checkpoint=checkpoint,
start=start,
end=end,
@@ -906,8 +927,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
yield from self._oauth_retrieval_folders(
is_slim=is_slim,
drive_service=drive_service,
drive_ids_to_retrieve=drive_ids_to_retrieve,
folder_ids_to_retrieve=folder_ids_to_retrieve,
drive_ids_to_retrieve=set(sorted_drive_ids),
folder_ids_to_retrieve=set(sorted_folder_ids),
checkpoint=checkpoint,
start=start,
end=end,

View File

@@ -221,6 +221,9 @@ def get_files_in_shared_drive(
# If we found any files, mark this drive as traversed. When a user has access to a drive,
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
# empty drives.
# NOTE: ^^ the above is not actually true due to folder restrictions:
# https://support.google.com/a/users/answer/12380484?hl=en
# So we may have to change this logic for people who use folder restrictions.
update_traversed_ids_func(drive_id)
yield file

View File

@@ -135,6 +135,9 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
# timestamp part is not used for folder crawling.
completion_map: ThreadSafeDict[str, StageCompletion]
# only used for folder crawling. maps from parent folder id to seen file ids.
processed_folder_file_ids: ThreadSafeDict[str, set[str]] = ThreadSafeDict()
# cached version of the drive and folder ids to retrieve
drive_ids_to_retrieve: list[str] | None = None
folder_ids_to_retrieve: list[str] | None = None
@@ -152,5 +155,18 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
return ThreadSafeDict(
{k: StageCompletion.model_validate(v) for k, v in v.items()}
{k: StageCompletion.model_validate(val) for k, val in v.items()}
)
@field_serializer("processed_folder_file_ids")
def serialize_processed_folder_file_ids(
self, processed_folder_file_ids: ThreadSafeDict[str, set[str]], _info: Any
) -> dict[str, set[str]]:
return processed_folder_file_ids._dict
@field_validator("processed_folder_file_ids", mode="before")
def validate_processed_folder_file_ids(
cls, v: Any
) -> ThreadSafeDict[str, set[str]]:
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
return ThreadSafeDict({k: set(val) for k, val in v.items()})

View File

@@ -24,6 +24,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
@@ -477,3 +478,33 @@ def test_specific_user_emails_restricted_folder(
)
test_docs = load_all_docs(test_connector)
assert len(test_docs) == 0
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_shared_drive_folder(
mock_get_api_key: MagicMock,
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_drive_folder")
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=TEST_USER_1_EMAIL,
include_files_shared_with_me=False,
include_shared_drives=False,
include_my_drives=True,
shared_folder_urls=FOLDER_1_URL,
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS
# test for deduping
assert len(expected_file_ids) == len(retrieved_docs)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)