diff --git a/backend/alembic/versions/98a5008d8711_agent_tracking.py b/backend/alembic/versions/98a5008d8711_agent_tracking.py index 2cc15aebb..3084bc47d 100644 --- a/backend/alembic/versions/98a5008d8711_agent_tracking.py +++ b/backend/alembic/versions/98a5008d8711_agent_tracking.py @@ -1,7 +1,7 @@ """agent_tracking Revision ID: 98a5008d8711 -Revises: 4d58345da04a +Revises: 33ea50e88f24 Create Date: 2025-01-29 17:00:00.000001 """ @@ -12,7 +12,7 @@ from sqlalchemy.dialects.postgresql import UUID # revision identifiers, used by Alembic. revision = "98a5008d8711" -down_revision = "4d58345da04a" +down_revision = "33ea50e88f24" branch_labels = None depends_on = None diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py index a7db34745..377ed32d0 100644 --- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py +++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py @@ -68,12 +68,16 @@ def retrieve_documents( query_info = None if search_tool is None: raise ValueError("search_tool must be provided for agentic search") + + callback_container: list[list[InferenceSection]] = [] + # new db session to avoid concurrency issues with get_session_context_manager() as db_session: for tool_response in search_tool.run( query=query_to_retrieve, force_no_rerank=True, alternate_db_session=db_session, + retrieved_sections_callback=callback_container.append, ): # get retrieved docs to send to the rest of the graph if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: @@ -87,13 +91,9 @@ def retrieve_documents( break retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] - pre_rerank_docs = retrieved_docs - if search_tool.search_pipeline is not None: - pre_rerank_docs = ( - search_tool.search_pipeline._retrieved_sections or retrieved_docs - ) if AGENT_RETRIEVAL_STATS: + pre_rerank_docs = callback_container[0] fit_scores = get_fit_scores( pre_rerank_docs, retrieved_docs, diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index aaab9f9b2..8c1de70ca 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -982,9 +982,10 @@ def stream_chat_message_objects( # Saving Gen AI answer and responding with message info + basic_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1]) info = ( - info_by_subq[SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])] - if BASIC_KEY in info_by_subq + info_by_subq[basic_key] + if basic_key in info_by_subq else info_by_subq[ SubQuestionKey( level=AGENT_SEARCH_INITIAL_KEY[0], @@ -1018,7 +1019,7 @@ def stream_chat_message_objects( ), ) - # TODO: add answers for levels >= 1, where each level has the previous as its parent. Use + # add answers for levels >= 1, where each level has the previous as its parent. Use # the answer_by_level method in answer.py to get the answers for each level next_level = 1 prev_message = gen_ai_response_message diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py index eace5b799..c3d5177cd 100644 --- a/backend/onyx/context/search/pipeline.py +++ b/backend/onyx/context/search/pipeline.py @@ -56,6 +56,8 @@ class SearchPipeline: retrieval_metrics_callback: ( Callable[[RetrievalMetricsContainer], None] | None ) = None, + retrieved_sections_callback: Callable[[list[InferenceSection]], None] + | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, prompt_config: PromptConfig | None = None, ): @@ -80,6 +82,8 @@ class SearchPipeline: self._retrieved_chunks: list[InferenceChunk] | None = None # Another call made to the document index to get surrounding sections self._retrieved_sections: list[InferenceSection] | None = None + + self.retrieved_sections_callback = retrieved_sections_callback # Reranking and LLM section selection can be run together # If only LLM selection is on, the reranked chunks are yielded immediatly self._reranked_sections: list[InferenceSection] | None = None @@ -328,9 +332,13 @@ class SearchPipeline: if self._reranked_sections is not None: return self._reranked_sections + retrieved_sections = self._get_sections() + if self.retrieved_sections_callback is not None: + self.retrieved_sections_callback(retrieved_sections) + self._postprocessing_generator = search_postprocessing( search_query=self.search_query, - retrieved_sections=self._get_sections(), + retrieved_sections=retrieved_sections, llm=self.fast_llm, rerank_metrics_callback=self.rerank_metrics_callback, ) diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index e4d698091..2e541f512 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -112,8 +112,6 @@ class SearchTool(Tool): self.fast_llm = fast_llm self.evaluation_type = evaluation_type - self.search_pipeline: SearchPipeline | None = None - self.selected_sections = selected_sections self.full_doc = full_doc @@ -282,6 +280,10 @@ class SearchTool(Tool): query = cast(str, kwargs["query"]) force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False)) alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None)) + retrieved_sections_callback = cast( + Callable[[list[InferenceSection]], None], + kwargs.get("retrieved_sections_callback"), + ) if self.selected_sections: yield from self._build_response_for_specified_sections(query) @@ -326,8 +328,8 @@ class SearchTool(Tool): bypass_acl=self.bypass_acl, db_session=alternate_db_session or self.db_session, prompt_config=self.prompt_config, + retrieved_sections_callback=retrieved_sections_callback, ) - self.search_pipeline = search_pipeline # used for agent_search metrics search_query_info = SearchQueryInfo( predicted_search=search_pipeline.search_query.search_type,