diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py index ae9a84f57d1..70728251ea1 100644 --- a/backend/danswer/search/postprocessing/postprocessing.py +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -278,17 +278,17 @@ def search_postprocessing( _log_top_section_links(search_query.search_type.value, reranked_sections) yield reranked_sections - llm_section_selection = cast( - list[str] | None, - post_processing_results.get(str(llm_filter_task_id)) - if llm_filter_task_id - else None, - ) - if llm_section_selection is not None: - yield [ - index - for index, section in enumerate(reranked_sections or retrieved_sections) - if section.center_chunk.unique_id in llm_section_selection + llm_selected_section_ids = ( + [ + section.center_chunk.unique_id + for section in post_processing_results.get(str(llm_filter_task_id), []) ] - else: - yield cast(list[int], []) + if llm_filter_task_id + else [] + ) + + yield [ + index + for index, section in enumerate(reranked_sections or retrieved_sections) + if section.center_chunk.unique_id in llm_selected_section_ids + ]