mirror of https://github.com/danswer-ai/danswer.git
drive file deduping (#4648)
* drive file deduping
* switched to version that does not require thread safety
* thanks greptile
* CW comments
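The dedup itself ends up being a plain `set` of file ids kept on the checkpoint (`all_retrieved_file_ids`, added in the hunks below), checked once per retrieved file before yielding. Per the commit message, this replaces an earlier per-folder `ThreadSafeDict` approach and needs no locking because the check runs in a single consuming generator. A minimal, self-contained sketch of that pattern (the `Checkpoint` and `DriveFile` classes here are illustrative stand-ins, not the connector's real types):

```python
from dataclasses import dataclass, field
from typing import Iterator


@dataclass
class Checkpoint:
    # ids of every file already yielded; persisted so a resumed run
    # does not re-emit files it already handed to the indexer
    all_retrieved_file_ids: set[str] = field(default_factory=set)


@dataclass
class DriveFile:
    id: str
    name: str


def dedup_files(
    files: Iterator[DriveFile], checkpoint: Checkpoint
) -> Iterator[DriveFile]:
    # single consumer, so a plain set is enough -- no lock required
    for file in files:
        if file.id in checkpoint.all_retrieved_file_ids:
            continue  # same file reached via a second parent folder or drive
        checkpoint.all_retrieved_file_ids.add(file.id)
        yield file


if __name__ == "__main__":
    cp = Checkpoint()
    batch = [DriveFile("a", "doc1"), DriveFile("b", "doc2"), DriveFile("a", "doc1")]
    assert [f.id for f in dedup_files(iter(batch), cp)] == ["a", "b"]
```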
@@ -217,7 +217,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
         self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
 
         self._retrieved_ids: set[str] = set()
+        # ids of folders and shared drives that have been traversed
+        self._retrieved_folder_and_drive_ids: set[str] = set()
 
         self.allow_images = False
 
         self.size_threshold = GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
@@ -270,7 +272,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
         return new_creds_dict
 
     def _update_traversed_parent_ids(self, folder_id: str) -> None:
-        self._retrieved_ids.add(folder_id)
+        self._retrieved_folder_and_drive_ids.add(folder_id)
 
     def _get_all_user_emails(self) -> list[str]:
         if self._specific_user_emails:
@@ -329,21 +331,26 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
         self, drive_ids: list[str], checkpoint: GoogleDriveCheckpoint
     ) -> Callable[[str], Iterator[str]]:
         cv = threading.Condition()
-        drive_id_status = {
-            drive_id: (
-                DriveIdStatus.FINISHED
-                if drive_id in self._retrieved_ids
-                else DriveIdStatus.AVAILABLE
-            )
-            for drive_id in drive_ids
-        }
-
-        def _get_available_drive_id(
-            processed_ids: set[str], thread_id: str
-        ) -> tuple[str | None, bool]:
+        in_progress_drive_ids = {
+            completion.current_folder_or_drive_id: user_email
+            for user_email, completion in checkpoint.completion_map.items()
+            if completion.stage == DriveRetrievalStage.SHARED_DRIVE_FILES
+            and completion.current_folder_or_drive_id is not None
+        }
+        drive_id_status: dict[str, DriveIdStatus] = {}
+        for drive_id in drive_ids:
+            if drive_id in self._retrieved_folder_and_drive_ids:
+                drive_id_status[drive_id] = DriveIdStatus.FINISHED
+            elif drive_id in in_progress_drive_ids:
+                drive_id_status[drive_id] = DriveIdStatus.IN_PROGRESS
+            else:
+                drive_id_status[drive_id] = DriveIdStatus.AVAILABLE
+
+        def _get_available_drive_id(processed_ids: set[str]) -> tuple[str | None, bool]:
             found_future_work = False
             for drive_id, status in drive_id_status.items():
-                if drive_id in self._retrieved_ids:
+                if drive_id in self._retrieved_folder_and_drive_ids:
+                    drive_id_status[drive_id] = DriveIdStatus.FINISHED
                     continue
                 if drive_id in processed_ids:
@@ -357,19 +364,38 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
         def drive_id_iterator(thread_id: str) -> Iterator[str]:
             completion = checkpoint.completion_map[thread_id]
 
+            def record_drive_processing(drive_id: str) -> None:
+                with cv:
+                    completion.processed_drive_ids.add(drive_id)
+                    drive_id_status[drive_id] = (
+                        DriveIdStatus.FINISHED
+                        if drive_id in self._retrieved_folder_and_drive_ids
+                        else DriveIdStatus.AVAILABLE
+                    )
+                    # wake up other threads waiting for work
+                    cv.notify_all()
+
+            # when entering the iterator with a previous id in the checkpoint, the user
+            # just finished that drive from a previous run.
+            if (
+                completion.stage == DriveRetrievalStage.MY_DRIVE_FILES
+                and completion.current_folder_or_drive_id is not None
+            ):
+                record_drive_processing(completion.current_folder_or_drive_id)
             # continue iterating until this thread has no more work to do
             while True:
                 # this locks operations on _retrieved_ids and drive_id_status
                 with cv:
                     available_drive_id, found_future_work = _get_available_drive_id(
-                        completion.processed_drive_ids, thread_id
+                        completion.processed_drive_ids
                     )
 
                     # wait while there is no work currently available but still drives that may need processing
                     while available_drive_id is None and found_future_work:
                         cv.wait()
                         available_drive_id, found_future_work = _get_available_drive_id(
-                            completion.processed_drive_ids, thread_id
+                            completion.processed_drive_ids
                         )
 
                     # if there is no work available and no future work, we are done
@@ -379,15 +405,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
                     drive_id_status[available_drive_id] = DriveIdStatus.IN_PROGRESS
 
                 yield available_drive_id
-                with cv:
-                    completion.processed_drive_ids.add(available_drive_id)
-                    drive_id_status[available_drive_id] = (
-                        DriveIdStatus.FINISHED
-                        if available_drive_id in self._retrieved_ids
-                        else DriveIdStatus.AVAILABLE
-                    )
-                    # wake up other threads waiting for work
-                    cv.notify_all()
+                record_drive_processing(available_drive_id)
 
         return drive_id_iterator
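The coordination above is easier to see in isolation. Below is a simplified, self-contained sketch of the same condition-variable pattern: workers pull an available drive id, mark it in progress, and wait when nothing is available but other workers are still busy. The `make_drive_id_iterator` helper and `Status` names are illustrative only; the real code also hands a drive back as available if its retrieval did not complete.

```python
import threading
from enum import Enum
from typing import Callable, Iterator


class Status(Enum):
    AVAILABLE = "available"
    IN_PROGRESS = "in_progress"
    FINISHED = "finished"


def make_drive_id_iterator(drive_ids: list[str]) -> Callable[[str], Iterator[str]]:
    cv = threading.Condition()
    status = {drive_id: Status.AVAILABLE for drive_id in drive_ids}

    def _next_available() -> tuple[str | None, bool]:
        # second element: whether unfinished work might still appear later
        future_work = False
        for drive_id, st in status.items():
            if st == Status.AVAILABLE:
                return drive_id, True
            if st == Status.IN_PROGRESS:
                future_work = True  # another worker may hand it back
        return None, future_work

    def iterator(worker_name: str) -> Iterator[str]:
        while True:
            with cv:
                drive_id, future_work = _next_available()
                # nothing to do right now, but other workers are still busy
                while drive_id is None and future_work:
                    cv.wait()
                    drive_id, future_work = _next_available()
                if drive_id is None:
                    return  # no work now and none coming
                status[drive_id] = Status.IN_PROGRESS
            yield drive_id
            with cv:
                status[drive_id] = Status.FINISHED
                cv.notify_all()  # wake workers waiting for work

    return iterator


if __name__ == "__main__":
    it = make_drive_id_iterator(["drive-1", "drive-2"])
    assert list(it("worker-0")) == ["drive-1", "drive-2"]
```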
@@ -472,7 +490,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
 
             # resume from a checkpoint
             if resuming:
-                drive_id = curr_stage.completed_until_parent_id
+                drive_id = curr_stage.current_folder_or_drive_id
                 if drive_id is None:
                     raise ValueError("drive id not set in checkpoint")
                 resume_start = curr_stage.completed_until
@@ -485,7 +503,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
                     f"Getting files in shared drive '{drive_id}' as '{user_email}'"
                 )
                 curr_stage.completed_until = 0
-                curr_stage.completed_until_parent_id = drive_id
+                curr_stage.current_folder_or_drive_id = drive_id
                 yield from _yield_from_drive(drive_id, start)
             curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
             resuming = False  # we are starting the next stage for the first time
@@ -506,32 +524,17 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
                 parent_id=folder_id,
                 is_slim=is_slim,
                 user_email=user_email,
-                traversed_parent_ids=self._retrieved_ids,
+                traversed_parent_ids=self._retrieved_folder_and_drive_ids,
                 update_traversed_ids_func=self._update_traversed_parent_ids,
                 start=folder_start,
                 end=end,
             ):
-                with checkpoint.processed_folder_file_ids.lock:
-                    should_yield = False
-                    completed_ids = checkpoint.processed_folder_file_ids._dict.get(
-                        folder_id, set()
-                    )
-                    if (
-                        "id" in retrieved_file.drive_file
-                        and retrieved_file.drive_file["id"] not in completed_ids
-                    ):
-                        completed_ids.add(retrieved_file.drive_file["id"])
-                        should_yield = True
-                        checkpoint.processed_folder_file_ids._dict[folder_id] = (
-                            completed_ids
-                        )
-                    if should_yield:
-                        yield retrieved_file
+                yield retrieved_file
 
             # resume from a checkpoint
             last_processed_folder = None
             if resuming:
-                folder_id = curr_stage.completed_until_parent_id
+                folder_id = curr_stage.current_folder_or_drive_id
                 if folder_id is None:
                     raise ValueError("folder id not set in checkpoint")
                 resume_start = curr_stage.completed_until
@@ -547,11 +550,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
                     skipping_seen_folders = folder_id != last_processed_folder
                     continue
 
-                if folder_id in self._retrieved_ids:
+                if folder_id in self._retrieved_folder_and_drive_ids:
                     continue
 
                 curr_stage.completed_until = 0
-                curr_stage.completed_until_parent_id = folder_id
+                curr_stage.current_folder_or_drive_id = folder_id
                 logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
                 yield from _yield_from_folder_crawl(folder_id, start)
 
@@ -641,7 +644,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
 
         remaining_folders = (
             set(sorted_drive_ids) | set(sorted_folder_ids)
-        ) - self._retrieved_ids
+        ) - self._retrieved_folder_and_drive_ids
         if remaining_folders:
             logger.warning(
                 f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
@@ -754,7 +757,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
         ):
             drive_id = checkpoint.completion_map[
                 self.primary_admin_email
-            ].completed_until_parent_id
+            ].current_folder_or_drive_id
             if drive_id is None:
                 raise ValueError("drive id not set in checkpoint")
             resume_start = checkpoint.completion_map[
@@ -763,7 +766,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
             yield from _yield_from_drive(drive_id, resume_start)
 
         for drive_id in drive_ids_to_retrieve:
-            if drive_id in self._retrieved_ids:
+            if drive_id in self._retrieved_folder_and_drive_ids:
                 logger.info(
                     f"Skipping drive '{drive_id}' as it has already been retrieved"
                 )
@@ -790,7 +793,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
         """
         # Even if no folders were requested, we still check if any drives were requested
         # that could be folders.
-        remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
+        remaining_folders = (
+            folder_ids_to_retrieve - self._retrieved_folder_and_drive_ids
+        )
 
         def _yield_from_folder_crawl(
             folder_id: str, folder_start: SecondsSinceUnixEpoch | None
@@ -800,7 +805,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
                 parent_id=folder_id,
                 is_slim=is_slim,
                 user_email=self.primary_admin_email,
-                traversed_parent_ids=self._retrieved_ids,
+                traversed_parent_ids=self._retrieved_folder_and_drive_ids,
                 update_traversed_ids_func=self._update_traversed_parent_ids,
                 start=folder_start,
                 end=end,
@@ -813,7 +818,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
         ):
             folder_id = checkpoint.completion_map[
                 self.primary_admin_email
-            ].completed_until_parent_id
+            ].current_folder_or_drive_id
             if folder_id is None:
                 raise ValueError("folder id not set in checkpoint")
             resume_start = checkpoint.completion_map[
@@ -831,7 +836,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
 
         remaining_folders = (
             drive_ids_to_retrieve | folder_ids_to_retrieve
-        ) - self._retrieved_ids
+        ) - self._retrieved_folder_and_drive_ids
         if remaining_folders:
             logger.warning(
                 f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
@@ -861,9 +866,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
                 completed_until=datetime.fromisoformat(
                     file.drive_file[GoogleFields.MODIFIED_TIME.value]
                 ).timestamp(),
-                completed_until_parent_id=file.parent_id,
+                current_folder_or_drive_id=file.parent_id,
             )
-            yield file
+            if file.drive_file["id"] not in checkpoint.all_retrieved_file_ids:
+                checkpoint.all_retrieved_file_ids.add(file.drive_file["id"])
+                yield file
 
     def _manage_oauth_retrieval(
         self,
@@ -877,7 +884,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
         checkpoint.completion_map[self.primary_admin_email] = StageCompletion(
             stage=DriveRetrievalStage.START,
             completed_until=0,
-            completed_until_parent_id=None,
+            current_folder_or_drive_id=None,
         )
 
         drive_service = get_drive_service(self.creds, self.primary_admin_email)
@@ -1037,7 +1044,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
                 files_batch = []
 
             if batches_complete > BATCHES_PER_CHECKPOINT:
-                checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
+                checkpoint.retrieved_folder_and_drive_ids = (
+                    self._retrieved_folder_and_drive_ids
+                )
                 return  # create a new checkpoint
 
         # Process any remaining files
@@ -1062,14 +1071,14 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
         )
 
         checkpoint = copy.deepcopy(checkpoint)
-        self._retrieved_ids = checkpoint.retrieved_folder_and_drive_ids
+        self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
         try:
             yield from self._extract_docs_from_google_drive(checkpoint, start, end)
         except Exception as e:
             if MISSING_SCOPES_ERROR_STR in str(e):
                 raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
            raise e
-        checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
+        checkpoint.retrieved_folder_and_drive_ids = self._retrieved_folder_and_drive_ids
         if checkpoint.completion_stage == DriveRetrievalStage.DONE:
             checkpoint.has_more = False
         return checkpoint
@@ -1172,6 +1181,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
             retrieved_folder_and_drive_ids=set(),
             completion_stage=DriveRetrievalStage.START,
             completion_map=ThreadSafeDict(),
+            all_retrieved_file_ids=set(),
             has_more=True,
         )
@@ -76,7 +76,7 @@ class StageCompletion(BaseModel):
 
     stage: DriveRetrievalStage
     completed_until: SecondsSinceUnixEpoch
-    completed_until_parent_id: str | None = None
+    current_folder_or_drive_id: str | None = None
 
     # only used for shared drives
     processed_drive_ids: set[str] = set()
@@ -85,11 +85,11 @@ class StageCompletion(BaseModel):
         self,
         stage: DriveRetrievalStage,
         completed_until: SecondsSinceUnixEpoch,
-        completed_until_parent_id: str | None = None,
+        current_folder_or_drive_id: str | None = None,
     ) -> None:
         self.stage = stage
         self.completed_until = completed_until
-        self.completed_until_parent_id = completed_until_parent_id
+        self.current_folder_or_drive_id = current_folder_or_drive_id
 
 
 class RetrievedDriveFile(BaseModel):
@@ -135,8 +135,8 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
     # timestamp part is not used for folder crawling.
     completion_map: ThreadSafeDict[str, StageCompletion]
 
-    # only used for folder crawling. maps from parent folder id to seen file ids.
-    processed_folder_file_ids: ThreadSafeDict[str, set[str]] = ThreadSafeDict()
+    # all file ids that have been retrieved
+    all_retrieved_file_ids: set[str] = set()
 
     # cached version of the drive and folder ids to retrieve
     drive_ids_to_retrieve: list[str] | None = None
@@ -157,16 +157,3 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
         return ThreadSafeDict(
             {k: StageCompletion.model_validate(val) for k, val in v.items()}
         )
-
-    @field_serializer("processed_folder_file_ids")
-    def serialize_processed_folder_file_ids(
-        self, processed_folder_file_ids: ThreadSafeDict[str, set[str]], _info: Any
-    ) -> dict[str, set[str]]:
-        return processed_folder_file_ids._dict
-
-    @field_validator("processed_folder_file_ids", mode="before")
-    def validate_processed_folder_file_ids(
-        cls, v: Any
-    ) -> ThreadSafeDict[str, set[str]]:
-        assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
-        return ThreadSafeDict({k: set(val) for k, val in v.items()})
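One consequence of the model change above: a plain `set[str]` field round-trips through Pydantic's default JSON serialization (sets are emitted as arrays and rebuilt on validation), which is presumably why the custom `field_serializer`/`field_validator` pair for `processed_folder_file_ids` can be dropped. A small sketch, assuming Pydantic v2 and using a stand-in model rather than the real `GoogleDriveCheckpoint`:

```python
from pydantic import BaseModel


class CheckpointSketch(BaseModel):
    # stand-in for GoogleDriveCheckpoint; only the dedup field is shown
    all_retrieved_file_ids: set[str] = set()


cp = CheckpointSketch(all_retrieved_file_ids={"file-a", "file-b"})
payload = cp.model_dump_json()  # the set is emitted as a JSON array
restored = CheckpointSketch.model_validate_json(payload)
assert restored.all_retrieved_file_ids == {"file-a", "file-b"}
```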