mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-08-29 15:15:03 +02:00
no more duplicate files during folder indexing (#4579)
* no more duplicate files during folder indexing * cleanup checkpoint after a shared folder has been finished * cleanup * lint
This commit is contained in:
@@ -89,7 +89,7 @@ def _clean_requested_drive_ids(
|
||||
requested_drive_ids: set[str],
|
||||
requested_folder_ids: set[str],
|
||||
all_drive_ids_available: set[str],
|
||||
) -> tuple[set[str], set[str]]:
|
||||
) -> tuple[list[str], list[str]]:
|
||||
invalid_requested_drive_ids = requested_drive_ids - all_drive_ids_available
|
||||
filtered_folder_ids = requested_folder_ids - all_drive_ids_available
|
||||
if invalid_requested_drive_ids:
|
||||
@@ -100,7 +100,7 @@ def _clean_requested_drive_ids(
|
||||
filtered_folder_ids.update(invalid_requested_drive_ids)
|
||||
|
||||
valid_requested_drive_ids = requested_drive_ids - invalid_requested_drive_ids
|
||||
return valid_requested_drive_ids, filtered_folder_ids
|
||||
return sorted(valid_requested_drive_ids), sorted(filtered_folder_ids)
|
||||
|
||||
|
||||
class CredentialedRetrievalMethod(Protocol):
|
||||
@@ -326,7 +326,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
return all_drive_ids
|
||||
|
||||
def make_drive_id_iterator(
|
||||
self, drive_ids: set[str], checkpoint: GoogleDriveCheckpoint
|
||||
self, drive_ids: list[str], checkpoint: GoogleDriveCheckpoint
|
||||
) -> Callable[[str], Iterator[str]]:
|
||||
cv = threading.Condition()
|
||||
drive_id_status = {
|
||||
@@ -490,12 +490,18 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
|
||||
resuming = False # we are starting the next stage for the first time
|
||||
|
||||
# In the folder files section of service account retrieval we take extra care
|
||||
# to not retrieve duplicate docs. In particular, we only add a folder to
|
||||
# retrieved_folder_and_drive_ids when all users are finished retrieving files
|
||||
# from that folder, and maintain a set of all file ids that have been retrieved
|
||||
# for each folder. This might get rather large; in practice we assume that the
|
||||
# specific folders users choose to index don't have too many files.
|
||||
if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:
|
||||
|
||||
def _yield_from_folder_crawl(
|
||||
folder_id: str, folder_start: SecondsSinceUnixEpoch | None
|
||||
) -> Iterator[RetrievedDriveFile]:
|
||||
yield from crawl_folders_for_files(
|
||||
for retrieved_file in crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
is_slim=is_slim,
|
||||
@@ -504,7 +510,23 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
start=folder_start,
|
||||
end=end,
|
||||
)
|
||||
):
|
||||
with checkpoint.processed_folder_file_ids.lock:
|
||||
should_yield = False
|
||||
completed_ids = checkpoint.processed_folder_file_ids._dict.get(
|
||||
folder_id, set()
|
||||
)
|
||||
if (
|
||||
"id" in retrieved_file.drive_file
|
||||
and retrieved_file.drive_file["id"] not in completed_ids
|
||||
):
|
||||
completed_ids.add(retrieved_file.drive_file["id"])
|
||||
should_yield = True
|
||||
checkpoint.processed_folder_file_ids._dict[folder_id] = (
|
||||
completed_ids
|
||||
)
|
||||
if should_yield:
|
||||
yield retrieved_file
|
||||
|
||||
# resume from a checkpoint
|
||||
last_processed_folder = None
|
||||
@@ -517,6 +539,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
last_processed_folder = folder_id
|
||||
|
||||
skipping_seen_folders = last_processed_folder is not None
|
||||
# NOTE: this assumes a small number of folders to crawl. If someone
|
||||
# really wants to specify a large number of folders, we should use
|
||||
# binary search to find the first unseen folder.
|
||||
for folder_id in sorted_filtered_folder_ids:
|
||||
if skipping_seen_folders:
|
||||
skipping_seen_folders = folder_id != last_processed_folder
|
||||
@@ -552,7 +577,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
raise ValueError("user emails not set")
|
||||
all_org_emails = checkpoint.user_emails
|
||||
|
||||
drive_ids_to_retrieve, folder_ids_to_retrieve = self._determine_retrieval_ids(
|
||||
sorted_drive_ids, sorted_folder_ids = self._determine_retrieval_ids(
|
||||
checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
|
||||
)
|
||||
|
||||
@@ -570,16 +595,12 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
# fetching stuff
|
||||
logger.info(f"Found {len(all_org_emails)} users to impersonate")
|
||||
logger.debug(f"Users: {all_org_emails}")
|
||||
logger.info(f"Found {len(drive_ids_to_retrieve)} drives to retrieve")
|
||||
logger.debug(f"Drives: {drive_ids_to_retrieve}")
|
||||
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
|
||||
logger.debug(f"Folders: {folder_ids_to_retrieve}")
|
||||
logger.info(f"Found {len(sorted_drive_ids)} drives to retrieve")
|
||||
logger.debug(f"Drives: {sorted_drive_ids}")
|
||||
logger.info(f"Found {len(sorted_folder_ids)} folders to retrieve")
|
||||
logger.debug(f"Folders: {sorted_folder_ids}")
|
||||
|
||||
drive_id_iterator = self.make_drive_id_iterator(
|
||||
drive_ids_to_retrieve, checkpoint
|
||||
)
|
||||
|
||||
sorted_filtered_folder_ids = sorted(folder_ids_to_retrieve)
|
||||
drive_id_iterator = self.make_drive_id_iterator(sorted_drive_ids, checkpoint)
|
||||
|
||||
# only process emails that we haven't already completed retrieval for
|
||||
non_completed_org_emails = [
|
||||
@@ -606,7 +627,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
is_slim,
|
||||
checkpoint,
|
||||
drive_id_iterator,
|
||||
sorted_filtered_folder_ids,
|
||||
sorted_folder_ids,
|
||||
start,
|
||||
end,
|
||||
)
|
||||
@@ -619,7 +640,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
return
|
||||
|
||||
remaining_folders = (
|
||||
drive_ids_to_retrieve | folder_ids_to_retrieve
|
||||
set(sorted_drive_ids) | set(sorted_folder_ids)
|
||||
) - self._retrieved_ids
|
||||
if remaining_folders:
|
||||
logger.warning(
|
||||
@@ -637,26 +658,26 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
is_slim: bool,
|
||||
next_stage: DriveRetrievalStage,
|
||||
) -> tuple[set[str], set[str]]:
|
||||
) -> tuple[list[str], list[str]]:
|
||||
all_drive_ids = self.get_all_drive_ids()
|
||||
drive_ids_to_retrieve: set[str] = set()
|
||||
folder_ids_to_retrieve: set[str] = set()
|
||||
sorted_drive_ids: list[str] = []
|
||||
sorted_folder_ids: list[str] = []
|
||||
if checkpoint.completion_stage == DriveRetrievalStage.DRIVE_IDS:
|
||||
if self._requested_shared_drive_ids or self._requested_folder_ids:
|
||||
(
|
||||
drive_ids_to_retrieve,
|
||||
folder_ids_to_retrieve,
|
||||
sorted_drive_ids,
|
||||
sorted_folder_ids,
|
||||
) = _clean_requested_drive_ids(
|
||||
requested_drive_ids=self._requested_shared_drive_ids,
|
||||
requested_folder_ids=self._requested_folder_ids,
|
||||
all_drive_ids_available=all_drive_ids,
|
||||
)
|
||||
elif self.include_shared_drives:
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
sorted_drive_ids = sorted(all_drive_ids)
|
||||
|
||||
if not is_slim:
|
||||
checkpoint.drive_ids_to_retrieve = list(drive_ids_to_retrieve)
|
||||
checkpoint.folder_ids_to_retrieve = list(folder_ids_to_retrieve)
|
||||
checkpoint.drive_ids_to_retrieve = sorted_drive_ids
|
||||
checkpoint.folder_ids_to_retrieve = sorted_folder_ids
|
||||
checkpoint.completion_stage = next_stage
|
||||
else:
|
||||
if checkpoint.drive_ids_to_retrieve is None:
|
||||
@@ -664,10 +685,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
if checkpoint.folder_ids_to_retrieve is None:
|
||||
raise ValueError("folder ids to retrieve not set in checkpoint")
|
||||
# When loading from a checkpoint, load the previously cached drive and folder ids
|
||||
drive_ids_to_retrieve = set(checkpoint.drive_ids_to_retrieve)
|
||||
folder_ids_to_retrieve = set(checkpoint.folder_ids_to_retrieve)
|
||||
sorted_drive_ids = checkpoint.drive_ids_to_retrieve
|
||||
sorted_folder_ids = checkpoint.folder_ids_to_retrieve
|
||||
|
||||
return drive_ids_to_retrieve, folder_ids_to_retrieve
|
||||
return sorted_drive_ids, sorted_folder_ids
|
||||
|
||||
def _oauth_retrieval_all_files(
|
||||
self,
|
||||
@@ -704,7 +725,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
self,
|
||||
is_slim: bool,
|
||||
drive_service: GoogleDriveService,
|
||||
drive_ids_to_retrieve: set[str],
|
||||
drive_ids_to_retrieve: list[str],
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -886,7 +907,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
checkpoint.completion_stage = DriveRetrievalStage.DONE
|
||||
return
|
||||
|
||||
drive_ids_to_retrieve, folder_ids_to_retrieve = self._determine_retrieval_ids(
|
||||
sorted_drive_ids, sorted_folder_ids = self._determine_retrieval_ids(
|
||||
checkpoint, is_slim, DriveRetrievalStage.SHARED_DRIVE_FILES
|
||||
)
|
||||
|
||||
@@ -894,7 +915,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
yield from self._oauth_retrieval_drives(
|
||||
is_slim=is_slim,
|
||||
drive_service=drive_service,
|
||||
drive_ids_to_retrieve=drive_ids_to_retrieve,
|
||||
drive_ids_to_retrieve=sorted_drive_ids,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
@@ -906,8 +927,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
yield from self._oauth_retrieval_folders(
|
||||
is_slim=is_slim,
|
||||
drive_service=drive_service,
|
||||
drive_ids_to_retrieve=drive_ids_to_retrieve,
|
||||
folder_ids_to_retrieve=folder_ids_to_retrieve,
|
||||
drive_ids_to_retrieve=set(sorted_drive_ids),
|
||||
folder_ids_to_retrieve=set(sorted_folder_ids),
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
|
@@ -221,6 +221,9 @@ def get_files_in_shared_drive(
|
||||
# If we found any files, mark this drive as traversed. When a user has access to a drive,
|
||||
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
|
||||
# empty drives.
|
||||
# NOTE: ^^ the above is not actually true due to folder restrictions:
|
||||
# https://support.google.com/a/users/answer/12380484?hl=en
|
||||
# So we may have to change this logic for people who use folder restrictions.
|
||||
update_traversed_ids_func(drive_id)
|
||||
yield file
|
||||
|
||||
|
@@ -135,6 +135,9 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
# timestamp part is not used for folder crawling.
|
||||
completion_map: ThreadSafeDict[str, StageCompletion]
|
||||
|
||||
# only used for folder crawling. maps from parent folder id to seen file ids.
|
||||
processed_folder_file_ids: ThreadSafeDict[str, set[str]] = ThreadSafeDict()
|
||||
|
||||
# cached version of the drive and folder ids to retrieve
|
||||
drive_ids_to_retrieve: list[str] | None = None
|
||||
folder_ids_to_retrieve: list[str] | None = None
|
||||
@@ -152,5 +155,18 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
|
||||
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
|
||||
return ThreadSafeDict(
|
||||
{k: StageCompletion.model_validate(v) for k, v in v.items()}
|
||||
{k: StageCompletion.model_validate(val) for k, val in v.items()}
|
||||
)
|
||||
|
||||
@field_serializer("processed_folder_file_ids")
|
||||
def serialize_processed_folder_file_ids(
|
||||
self, processed_folder_file_ids: ThreadSafeDict[str, set[str]], _info: Any
|
||||
) -> dict[str, set[str]]:
|
||||
return processed_folder_file_ids._dict
|
||||
|
||||
@field_validator("processed_folder_file_ids", mode="before")
|
||||
def validate_processed_folder_file_ids(
|
||||
cls, v: Any
|
||||
) -> ThreadSafeDict[str, set[str]]:
|
||||
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
|
||||
return ThreadSafeDict({k: set(val) for k, val in v.items()})
|
||||
|
@@ -24,6 +24,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
|
||||
@@ -477,3 +478,33 @@ def test_specific_user_emails_restricted_folder(
|
||||
)
|
||||
test_docs = load_all_docs(test_connector)
|
||||
assert len(test_docs) == 0
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
def test_shared_drive_folder(
|
||||
mock_get_api_key: MagicMock,
|
||||
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
print("\n\nRunning test_shared_drive_folder")
|
||||
connector = google_drive_oauth_uploaded_connector_factory(
|
||||
primary_admin_email=TEST_USER_1_EMAIL,
|
||||
include_files_shared_with_me=False,
|
||||
include_shared_drives=False,
|
||||
include_my_drives=True,
|
||||
shared_folder_urls=FOLDER_1_URL,
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS
|
||||
|
||||
# test for deduping
|
||||
assert len(expected_file_ids) == len(retrieved_docs)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
Reference in New Issue
Block a user