Simpler approach (#4414)

This commit is contained in:
joachim-danswer 2025-04-01 16:52:59 -07:00 committed by GitHub
parent b7ece296e0
commit daef985b02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 1 deletions

View File

@ -5,11 +5,13 @@ from typing import cast
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import PromptConfig from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prune_and_merge import _merge_sections from onyx.chat.prune_and_merge import _merge_sections
from onyx.chat.prune_and_merge import ChunkRange from onyx.chat.prune_and_merge import ChunkRange
from onyx.chat.prune_and_merge import merge_chunk_intervals from onyx.chat.prune_and_merge import merge_chunk_intervals
from onyx.chat.prune_and_merge import prune_and_merge_sections
from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from onyx.context.search.enums import LLMEvaluationType from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import QueryFlow from onyx.context.search.enums import QueryFlow
@ -61,6 +63,7 @@ class SearchPipeline:
| None = None, | None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None, prompt_config: PromptConfig | None = None,
contextual_pruning_config: ContextualPruningConfig | None = None,
): ):
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None # NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
# and typically are None. The preprocessing will fetch default values to replace these empty overrides. # and typically are None. The preprocessing will fetch default values to replace these empty overrides.
@ -77,6 +80,9 @@ class SearchPipeline:
self.search_settings = get_current_search_settings(db_session) self.search_settings = get_current_search_settings(db_session)
self.document_index = get_default_document_index(self.search_settings, None) self.document_index = get_default_document_index(self.search_settings, None)
self.prompt_config: PromptConfig | None = prompt_config self.prompt_config: PromptConfig | None = prompt_config
self.contextual_pruning_config: ContextualPruningConfig | None = (
contextual_pruning_config
)
# Preprocessing steps generate this # Preprocessing steps generate this
self._search_query: SearchQuery | None = None self._search_query: SearchQuery | None = None
@ -420,7 +426,26 @@ class SearchPipeline:
if self._final_context_sections is not None: if self._final_context_sections is not None:
return self._final_context_sections return self._final_context_sections
self._final_context_sections = _merge_sections(sections=self.reranked_sections) if (
self.contextual_pruning_config is not None
and self.prompt_config is not None
):
self._final_context_sections = prune_and_merge_sections(
sections=self.reranked_sections,
section_relevance_list=None,
prompt_config=self.prompt_config,
llm_config=self.llm.config,
question=self.search_query.query,
contextual_pruning_config=self.contextual_pruning_config,
)
else:
logger.error(
"Contextual pruning or prompt config not set, using default merge"
)
self._final_context_sections = _merge_sections(
sections=self.reranked_sections
)
return self._final_context_sections return self._final_context_sections
@property @property

View File

@ -376,6 +376,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
db_session=alternate_db_session or self.db_session, db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config, prompt_config=self.prompt_config,
retrieved_sections_callback=retrieved_sections_callback, retrieved_sections_callback=retrieved_sections_callback,
contextual_pruning_config=self.contextual_pruning_config,
) )
search_query_info = SearchQueryInfo( search_query_info = SearchQueryInfo(
@ -447,6 +448,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
db_session=self.db_session, db_session=self.db_session,
bypass_acl=self.bypass_acl, bypass_acl=self.bypass_acl,
prompt_config=self.prompt_config, prompt_config=self.prompt_config,
contextual_pruning_config=self.contextual_pruning_config,
) )
# Log what we're doing # Log what we're doing