Another fix for Salesforce perm sync (#4432)

* Another fix for Salesforce perm sync

* typing
This commit is contained in:
Chris Weaver 2025-04-02 11:08:40 -07:00 committed by GitHub
parent 1cf966ecc1
commit 7ec04484d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 11 deletions

View File

@ -55,7 +55,7 @@ def _post_query_chunk_censoring(
# if user is None, permissions are not enforced
return chunks
chunks_to_keep = []
final_chunk_dict: dict[str, InferenceChunk] = {}
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
sources_to_censor = _get_all_censoring_enabled_sources()
@ -64,7 +64,7 @@ def _post_query_chunk_censoring(
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
chunks_to_keep.append(chunk)
final_chunk_dict[chunk.unique_id] = chunk
# For each source, filter out the chunks using the permission
# check function for that source
@ -79,6 +79,16 @@ def _post_query_chunk_censoring(
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)
return chunks_to_keep
for censored_chunk in censored_chunks:
final_chunk_dict[censored_chunk.unique_id] = censored_chunk
# IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in
final_chunk_list: list[InferenceChunk] = []
for chunk in chunks:
# only if the chunk is in the final censored chunks, add it to the final list
# if it is missing, that means it was intentionally left out
if chunk.unique_id in final_chunk_dict:
final_chunk_list.append(final_chunk_dict[chunk.unique_id])
return final_chunk_list

View File

@ -227,13 +227,16 @@ class SearchPipeline:
# If ee is enabled, censor the chunk sections based on user access
# Otherwise, return the retrieved chunks
censored_chunks = fetch_ee_implementation_or_noop(
censored_chunks = cast(
list[InferenceChunk],
fetch_ee_implementation_or_noop(
"onyx.external_permissions.post_query_censoring",
"_post_query_chunk_censoring",
retrieved_chunks,
)(
chunks=retrieved_chunks,
user=self.user,
),
)
above = self.search_query.chunks_above