diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py index f387642d80..c7bece314b 100644 --- a/backend/onyx/context/search/pipeline.py +++ b/backend/onyx/context/search/pipeline.py @@ -5,11 +5,13 @@ from typing import cast from sqlalchemy.orm import Session +from onyx.chat.models import ContextualPruningConfig from onyx.chat.models import PromptConfig from onyx.chat.models import SectionRelevancePiece from onyx.chat.prune_and_merge import _merge_sections from onyx.chat.prune_and_merge import ChunkRange from onyx.chat.prune_and_merge import merge_chunk_intervals +from onyx.chat.prune_and_merge import prune_and_merge_sections from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE from onyx.context.search.enums import LLMEvaluationType from onyx.context.search.enums import QueryFlow @@ -61,6 +63,7 @@ class SearchPipeline: | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, prompt_config: PromptConfig | None = None, + contextual_pruning_config: ContextualPruningConfig | None = None, ): # NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None # and typically are None. The preprocessing will fetch default values to replace these empty overrides. @@ -77,6 +80,9 @@ class SearchPipeline: self.search_settings = get_current_search_settings(db_session) self.document_index = get_default_document_index(self.search_settings, None) self.prompt_config: PromptConfig | None = prompt_config + self.contextual_pruning_config: ContextualPruningConfig | None = ( + contextual_pruning_config + ) # Preprocessing steps generate this self._search_query: SearchQuery | None = None @@ -420,7 +426,26 @@ class SearchPipeline: if self._final_context_sections is not None: return self._final_context_sections - self._final_context_sections = _merge_sections(sections=self.reranked_sections) + if ( + self.contextual_pruning_config is not None + and self.prompt_config is not None + ): + self._final_context_sections = prune_and_merge_sections( + sections=self.reranked_sections, + section_relevance_list=None, + prompt_config=self.prompt_config, + llm_config=self.llm.config, + question=self.search_query.query, + contextual_pruning_config=self.contextual_pruning_config, + ) + + else: + logger.error( + "Contextual pruning or prompt config not set, using default merge" + ) + self._final_context_sections = _merge_sections( + sections=self.reranked_sections + ) return self._final_context_sections @property diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index 379b2dd2a9..08751e2119 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -376,6 +376,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]): db_session=alternate_db_session or self.db_session, prompt_config=self.prompt_config, retrieved_sections_callback=retrieved_sections_callback, + contextual_pruning_config=self.contextual_pruning_config, ) search_query_info = SearchQueryInfo( @@ -447,6 +448,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]): db_session=self.db_session, bypass_acl=self.bypass_acl, prompt_config=self.prompt_config, + contextual_pruning_config=self.contextual_pruning_config, ) # Log what we're doing