drive file deduping (#4648)

* drive file deduping

* switched to version that does not require thread safety

* thanks greptile

* CW comments
This commit is contained in:
Evan Lohn
2025-05-02 10:58:16 -07:00
committed by GitHub
parent 75fa10cead
commit 6d9693dc51
2 changed files with 75 additions and 78 deletions

View File

@@ -217,7 +217,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._retrieved_ids: set[str] = set()
# ids of folders and shared drives that have been traversed
self._retrieved_folder_and_drive_ids: set[str] = set()
self.allow_images = False
self.size_threshold = GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
@@ -270,7 +272,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
return new_creds_dict
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._retrieved_ids.add(folder_id)
self._retrieved_folder_and_drive_ids.add(folder_id)
def _get_all_user_emails(self) -> list[str]:
if self._specific_user_emails:
@@ -329,21 +331,26 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
self, drive_ids: list[str], checkpoint: GoogleDriveCheckpoint
) -> Callable[[str], Iterator[str]]:
cv = threading.Condition()
drive_id_status = {
drive_id: (
DriveIdStatus.FINISHED
if drive_id in self._retrieved_ids
else DriveIdStatus.AVAILABLE
)
for drive_id in drive_ids
}
def _get_available_drive_id(
processed_ids: set[str], thread_id: str
) -> tuple[str | None, bool]:
in_progress_drive_ids = {
completion.current_folder_or_drive_id: user_email
for user_email, completion in checkpoint.completion_map.items()
if completion.stage == DriveRetrievalStage.SHARED_DRIVE_FILES
and completion.current_folder_or_drive_id is not None
}
drive_id_status: dict[str, DriveIdStatus] = {}
for drive_id in drive_ids:
if drive_id in self._retrieved_folder_and_drive_ids:
drive_id_status[drive_id] = DriveIdStatus.FINISHED
elif drive_id in in_progress_drive_ids:
drive_id_status[drive_id] = DriveIdStatus.IN_PROGRESS
else:
drive_id_status[drive_id] = DriveIdStatus.AVAILABLE
def _get_available_drive_id(processed_ids: set[str]) -> tuple[str | None, bool]:
found_future_work = False
for drive_id, status in drive_id_status.items():
if drive_id in self._retrieved_ids:
if drive_id in self._retrieved_folder_and_drive_ids:
drive_id_status[drive_id] = DriveIdStatus.FINISHED
continue
if drive_id in processed_ids:
@@ -357,19 +364,38 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
def drive_id_iterator(thread_id: str) -> Iterator[str]:
completion = checkpoint.completion_map[thread_id]
def record_drive_processing(drive_id: str) -> None:
with cv:
completion.processed_drive_ids.add(drive_id)
drive_id_status[drive_id] = (
DriveIdStatus.FINISHED
if drive_id in self._retrieved_folder_and_drive_ids
else DriveIdStatus.AVAILABLE
)
# wake up other threads waiting for work
cv.notify_all()
# when entering the iterator with a previous id in the checkpoint, the user
# just finished that drive from a previous run.
if (
completion.stage == DriveRetrievalStage.MY_DRIVE_FILES
and completion.current_folder_or_drive_id is not None
):
record_drive_processing(completion.current_folder_or_drive_id)
# continue iterating until this thread has no more work to do
while True:
# this locks operations on _retrieved_folder_and_drive_ids and drive_id_status
with cv:
available_drive_id, found_future_work = _get_available_drive_id(
completion.processed_drive_ids, thread_id
completion.processed_drive_ids
)
# wait while there is no work currently available but still drives that may need processing
while available_drive_id is None and found_future_work:
cv.wait()
available_drive_id, found_future_work = _get_available_drive_id(
completion.processed_drive_ids, thread_id
completion.processed_drive_ids
)
# if there is no work available and no future work, we are done
@@ -379,15 +405,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
drive_id_status[available_drive_id] = DriveIdStatus.IN_PROGRESS
yield available_drive_id
with cv:
completion.processed_drive_ids.add(available_drive_id)
drive_id_status[available_drive_id] = (
DriveIdStatus.FINISHED
if available_drive_id in self._retrieved_ids
else DriveIdStatus.AVAILABLE
)
# wake up other threads waiting for work
cv.notify_all()
record_drive_processing(available_drive_id)
return drive_id_iterator
@@ -472,7 +490,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
# resume from a checkpoint
if resuming:
drive_id = curr_stage.completed_until_parent_id
drive_id = curr_stage.current_folder_or_drive_id
if drive_id is None:
raise ValueError("drive id not set in checkpoint")
resume_start = curr_stage.completed_until
@@ -485,7 +503,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
f"Getting files in shared drive '{drive_id}' as '{user_email}'"
)
curr_stage.completed_until = 0
curr_stage.completed_until_parent_id = drive_id
curr_stage.current_folder_or_drive_id = drive_id
yield from _yield_from_drive(drive_id, start)
curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
resuming = False # we are starting the next stage for the first time
@@ -506,32 +524,17 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
parent_id=folder_id,
is_slim=is_slim,
user_email=user_email,
traversed_parent_ids=self._retrieved_ids,
traversed_parent_ids=self._retrieved_folder_and_drive_ids,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=folder_start,
end=end,
):
with checkpoint.processed_folder_file_ids.lock:
should_yield = False
completed_ids = checkpoint.processed_folder_file_ids._dict.get(
folder_id, set()
)
if (
"id" in retrieved_file.drive_file
and retrieved_file.drive_file["id"] not in completed_ids
):
completed_ids.add(retrieved_file.drive_file["id"])
should_yield = True
checkpoint.processed_folder_file_ids._dict[folder_id] = (
completed_ids
)
if should_yield:
yield retrieved_file
yield retrieved_file
# resume from a checkpoint
last_processed_folder = None
if resuming:
folder_id = curr_stage.completed_until_parent_id
folder_id = curr_stage.current_folder_or_drive_id
if folder_id is None:
raise ValueError("folder id not set in checkpoint")
resume_start = curr_stage.completed_until
@@ -547,11 +550,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
skipping_seen_folders = folder_id != last_processed_folder
continue
if folder_id in self._retrieved_ids:
if folder_id in self._retrieved_folder_and_drive_ids:
continue
curr_stage.completed_until = 0
curr_stage.completed_until_parent_id = folder_id
curr_stage.current_folder_or_drive_id = folder_id
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
yield from _yield_from_folder_crawl(folder_id, start)
@@ -641,7 +644,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
remaining_folders = (
set(sorted_drive_ids) | set(sorted_folder_ids)
) - self._retrieved_ids
) - self._retrieved_folder_and_drive_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
@@ -754,7 +757,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
):
drive_id = checkpoint.completion_map[
self.primary_admin_email
].completed_until_parent_id
].current_folder_or_drive_id
if drive_id is None:
raise ValueError("drive id not set in checkpoint")
resume_start = checkpoint.completion_map[
@@ -763,7 +766,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
yield from _yield_from_drive(drive_id, resume_start)
for drive_id in drive_ids_to_retrieve:
if drive_id in self._retrieved_ids:
if drive_id in self._retrieved_folder_and_drive_ids:
logger.info(
f"Skipping drive '{drive_id}' as it has already been retrieved"
)
@@ -790,7 +793,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
"""
# Even if no folders were requested, we still check if any drives were requested
# that could be folders.
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
remaining_folders = (
folder_ids_to_retrieve - self._retrieved_folder_and_drive_ids
)
def _yield_from_folder_crawl(
folder_id: str, folder_start: SecondsSinceUnixEpoch | None
@@ -800,7 +805,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
parent_id=folder_id,
is_slim=is_slim,
user_email=self.primary_admin_email,
traversed_parent_ids=self._retrieved_ids,
traversed_parent_ids=self._retrieved_folder_and_drive_ids,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=folder_start,
end=end,
@@ -813,7 +818,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
):
folder_id = checkpoint.completion_map[
self.primary_admin_email
].completed_until_parent_id
].current_folder_or_drive_id
if folder_id is None:
raise ValueError("folder id not set in checkpoint")
resume_start = checkpoint.completion_map[
@@ -831,7 +836,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
) - self._retrieved_folder_and_drive_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
@@ -861,9 +866,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
completed_until=datetime.fromisoformat(
file.drive_file[GoogleFields.MODIFIED_TIME.value]
).timestamp(),
completed_until_parent_id=file.parent_id,
current_folder_or_drive_id=file.parent_id,
)
yield file
if file.drive_file["id"] not in checkpoint.all_retrieved_file_ids:
checkpoint.all_retrieved_file_ids.add(file.drive_file["id"])
yield file
def _manage_oauth_retrieval(
self,
@@ -877,7 +884,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
checkpoint.completion_map[self.primary_admin_email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
completed_until_parent_id=None,
current_folder_or_drive_id=None,
)
drive_service = get_drive_service(self.creds, self.primary_admin_email)
@@ -1037,7 +1044,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
files_batch = []
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
checkpoint.retrieved_folder_and_drive_ids = (
self._retrieved_folder_and_drive_ids
)
return # create a new checkpoint
# Process any remaining files
@@ -1062,14 +1071,14 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
)
checkpoint = copy.deepcopy(checkpoint)
self._retrieved_ids = checkpoint.retrieved_folder_and_drive_ids
self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
try:
yield from self._extract_docs_from_google_drive(checkpoint, start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_folder_and_drive_ids
if checkpoint.completion_stage == DriveRetrievalStage.DONE:
checkpoint.has_more = False
return checkpoint
@@ -1172,6 +1181,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
retrieved_folder_and_drive_ids=set(),
completion_stage=DriveRetrievalStage.START,
completion_map=ThreadSafeDict(),
all_retrieved_file_ids=set(),
has_more=True,
)

View File

@@ -76,7 +76,7 @@ class StageCompletion(BaseModel):
stage: DriveRetrievalStage
completed_until: SecondsSinceUnixEpoch
completed_until_parent_id: str | None = None
current_folder_or_drive_id: str | None = None
# only used for shared drives
processed_drive_ids: set[str] = set()
@@ -85,11 +85,11 @@ class StageCompletion(BaseModel):
self,
stage: DriveRetrievalStage,
completed_until: SecondsSinceUnixEpoch,
completed_until_parent_id: str | None = None,
current_folder_or_drive_id: str | None = None,
) -> None:
self.stage = stage
self.completed_until = completed_until
self.completed_until_parent_id = completed_until_parent_id
self.current_folder_or_drive_id = current_folder_or_drive_id
class RetrievedDriveFile(BaseModel):
@@ -135,8 +135,8 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
# timestamp part is not used for folder crawling.
completion_map: ThreadSafeDict[str, StageCompletion]
# only used for folder crawling. maps from parent folder id to seen file ids.
processed_folder_file_ids: ThreadSafeDict[str, set[str]] = ThreadSafeDict()
# all file ids that have been retrieved
all_retrieved_file_ids: set[str] = set()
# cached version of the drive and folder ids to retrieve
drive_ids_to_retrieve: list[str] | None = None
@@ -157,16 +157,3 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
return ThreadSafeDict(
{k: StageCompletion.model_validate(val) for k, val in v.items()}
)
@field_serializer("processed_folder_file_ids")
def serialize_processed_folder_file_ids(
self, processed_folder_file_ids: ThreadSafeDict[str, set[str]], _info: Any
) -> dict[str, set[str]]:
return processed_folder_file_ids._dict
@field_validator("processed_folder_file_ids", mode="before")
def validate_processed_folder_file_ids(
cls, v: Any
) -> ThreadSafeDict[str, set[str]]:
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
return ThreadSafeDict({k: set(val) for k, val in v.items()})