From 7ec04484d43238e45e83a39e4b66c2486b744e7d Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Wed, 2 Apr 2025 11:08:40 -0700 Subject: [PATCH] Another fix for Salesforce perm sync (#4432) * Another fix for Salesforce perm sync * typing --- .../post_query_censoring.py | 18 ++++++++++++++---- backend/onyx/context/search/pipeline.py | 17 ++++++++++------- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/backend/ee/onyx/external_permissions/post_query_censoring.py b/backend/ee/onyx/external_permissions/post_query_censoring.py index 4d25643eb7..7d8f7a86fa 100644 --- a/backend/ee/onyx/external_permissions/post_query_censoring.py +++ b/backend/ee/onyx/external_permissions/post_query_censoring.py @@ -55,7 +55,7 @@ def _post_query_chunk_censoring( # if user is None, permissions are not enforced return chunks - chunks_to_keep = [] + final_chunk_dict: dict[str, InferenceChunk] = {} chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {} sources_to_censor = _get_all_censoring_enabled_sources() @@ -64,7 +64,7 @@ def _post_query_chunk_censoring( if chunk.source_type in sources_to_censor: chunks_to_process.setdefault(chunk.source_type, []).append(chunk) else: - chunks_to_keep.append(chunk) + final_chunk_dict[chunk.unique_id] = chunk # For each source, filter out the chunks using the permission # check function for that source @@ -79,6 +79,16 @@ def _post_query_chunk_censoring( f" chunks for this source and continuing: {e}" ) continue - chunks_to_keep.extend(censored_chunks) - return chunks_to_keep + for censored_chunk in censored_chunks: + final_chunk_dict[censored_chunk.unique_id] = censored_chunk + + # IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in + final_chunk_list: list[InferenceChunk] = [] + for chunk in chunks: + # only if the chunk is in the final censored chunks, add it to the final list + # if it is missing, that means it was intentionally left out + if chunk.unique_id in final_chunk_dict: + final_chunk_list.append(final_chunk_dict[chunk.unique_id]) + + return final_chunk_list diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py index c7bece314b..f1b1cf81db 100644 --- a/backend/onyx/context/search/pipeline.py +++ b/backend/onyx/context/search/pipeline.py @@ -227,13 +227,16 @@ class SearchPipeline: # If ee is enabled, censor the chunk sections based on user access # Otherwise, return the retrieved chunks - censored_chunks = fetch_ee_implementation_or_noop( - "onyx.external_permissions.post_query_censoring", - "_post_query_chunk_censoring", - retrieved_chunks, - )( - chunks=retrieved_chunks, - user=self.user, + censored_chunks = cast( + list[InferenceChunk], + fetch_ee_implementation_or_noop( + "onyx.external_permissions.post_query_censoring", + "_post_query_chunk_censoring", + retrieved_chunks, + )( + chunks=retrieved_chunks, + user=self.user, + ), ) above = self.search_query.chunks_above