From 83421686588f77cb1b8086cc359ac2a427e97f83 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Fri, 31 Jan 2025 15:28:13 -0800 Subject: [PATCH] initial variable renaming --- .../nodes/format_initial_sub_answers.py | 6 +-- .../initial/generate_initial_answer/edges.py | 4 +- .../nodes/generate_initial_answer.py | 14 +++--- .../nodes/ingest_retrieved_documents.py | 42 ------------------ .../initial/generate_initial_answer/states.py | 22 +++++----- .../initial/generate_sub_answers/edges.py | 4 +- .../nodes/decompose_orig_question.py | 10 +++-- .../nodes/format_initial_sub_answers.py | 6 +-- .../initial/generate_sub_answers/states.py | 14 +++--- .../format_orig_question_search_output.py | 12 ++--- .../retrieve_orig_question_docs/states.py | 4 +- .../agent_search/deep_search/main/edges.py | 8 ++-- .../deep_search/main/nodes/compare_answers.py | 8 ++-- .../nodes/create_refined_sub_questions.py | 10 ++--- .../main/nodes/decide_refinement_need.py | 8 ++-- .../main/nodes/generate_refined_answer.py | 20 ++++----- .../main/nodes/ingest_refined_answers.py | 6 +-- .../agent_search/deep_search/main/states.py | 44 +++++++++---------- .../nodes/format_results.py | 8 ++-- .../nodes/retrieve_documents.py | 4 +- .../shared/expanded_retrieval/states.py | 2 +- 21 files changed, 109 insertions(+), 147 deletions(-) delete mode 100644 backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/ingest_retrieved_documents.py diff --git a/backend/onyx/agents/agent_search/deep_search/initial/general_sub_answers/nodes/format_initial_sub_answers.py b/backend/onyx/agents/agent_search/deep_search/initial/general_sub_answers/nodes/format_initial_sub_answers.py index 4867217c5..dfc7e2d2a 100644 --- a/backend/onyx/agents/agent_search/deep_search/initial/general_sub_answers/nodes/format_initial_sub_answers.py +++ b/backend/onyx/agents/agent_search/deep_search/initial/general_sub_answers/nodes/format_initial_sub_answers.py @@ -5,7 +5,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer ) from onyx.agents.agent_search.deep_search.main.operations import logger from onyx.agents.agent_search.deep_search.main.states import ( - DecompAnswersUpdate, + SubQuestionResultsUpdate, ) from onyx.agents.agent_search.shared_graph_utils.operators import ( dedup_inference_sections, @@ -14,7 +14,7 @@ from onyx.agents.agent_search.shared_graph_utils.operators import ( def format_initial_sub_answers( state: AnswerQuestionOutput, -) -> DecompAnswersUpdate: +) -> SubQuestionResultsUpdate: now_start = datetime.now() logger.info(f"--------{now_start}--------INGEST ANSWERS---") @@ -32,7 +32,7 @@ def format_initial_sub_answers( f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---" ) - return DecompAnswersUpdate( + return SubQuestionResultsUpdate( # Deduping is done by the documents operator for the main graph # so we might not need to dedup here verified_reranked_documents=dedup_inference_sections(documents, []), diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/edges.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/edges.py index 953f1cf28..55b1fe385 100644 --- a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/edges.py +++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/edges.py @@ -19,7 +19,7 @@ def parallelize_initial_sub_question_answering( state: SearchSQState, ) -> list[Send | Hashable]: edge_start_time = datetime.now() - if len(state.initial_decomp_questions) > 0: + if len(state.initial_sub_questions) > 0: # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]] # if len(state["sub_question_records"]) == 0: # if state["config"].use_persistence: @@ -41,7 +41,7 @@ def parallelize_initial_sub_question_answering( ], ), ) - for question_nr, question in enumerate(state.initial_decomp_questions) + for question_nr, question in enumerate(state.initial_sub_questions) ] else: diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py index 3bbe93d13..d64a7fb0a 100644 --- a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py +++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py @@ -68,11 +68,13 @@ def generate_initial_answer( prompt_enrichment_components = get_prompt_enrichment_components(agent_search_config) sub_questions_cited_documents = state.cited_documents - all_original_question_documents = state.all_original_question_documents + orig_question_retrieval_documents = state.orig_question_retrieval_documents consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents counter = 0 - for original_doc_number, original_doc in enumerate(all_original_question_documents): + for original_doc_number, original_doc in enumerate( + orig_question_retrieval_documents + ): if original_doc_number not in sub_questions_cited_documents: if ( counter <= AGENT_MIN_ORIG_QUESTION_DOCS @@ -89,7 +91,7 @@ def generate_initial_answer( decomp_questions = [] # Use the query info from the base document retrieval - query_info = get_query_info(state.original_question_retrieval_results) + query_info = get_query_info(state.orig_question_query_retrieval_results) if agent_search_config.search_tool is None: raise ValueError("search_tool must be provided for agentic search") @@ -229,7 +231,7 @@ def generate_initial_answer( answer = cast(str, response) initial_agent_stats = calculate_initial_agent_stats( - state.sub_question_results, state.original_question_retrieval_stats + state.sub_question_results, state.orig_question_retrieval_stats ) logger.debug( @@ -250,8 +252,8 @@ def generate_initial_answer( agent_base_metrics = AgentBaseMetrics( num_verified_documents_total=len(relevant_docs), - num_verified_documents_core=state.original_question_retrieval_stats.verified_count, - verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores, + num_verified_documents_core=state.orig_question_retrieval_stats.verified_count, + verified_avg_score_core=state.orig_question_retrieval_stats.verified_avg_scores, num_verified_documents_base=initial_agent_stats.sub_questions.get( "num_verified_documents" ), diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/ingest_retrieved_documents.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/ingest_retrieved_documents.py deleted file mode 100644 index 5af42108e..000000000 --- a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/ingest_retrieved_documents.py +++ /dev/null @@ -1,42 +0,0 @@ -from datetime import datetime - -from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import ( - BaseRawSearchOutput, -) -from onyx.agents.agent_search.deep_search.main.states import ( - ExpandedRetrievalUpdate, -) -from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats -from onyx.agents.agent_search.shared_graph_utils.utils import ( - get_langgraph_node_log_string, -) - - -def ingest_retrieved_documents( - state: BaseRawSearchOutput, -) -> ExpandedRetrievalUpdate: - node_start_time = datetime.now() - - sub_question_retrieval_stats = ( - state.base_expanded_retrieval_result.sub_question_retrieval_stats - ) - # if sub_question_retrieval_stats is None: - # sub_question_retrieval_stats = AgentChunkStats() - # else: - # sub_question_retrieval_stats = sub_question_retrieval_stats - - sub_question_retrieval_stats = sub_question_retrieval_stats or AgentChunkStats() - - return ExpandedRetrievalUpdate( - original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results, - all_original_question_documents=state.base_expanded_retrieval_result.context_documents, - original_question_retrieval_stats=sub_question_retrieval_stats, - log_messages=[ - get_langgraph_node_log_string( - graph_component="initial - generate initial answer", - node_name="ingest retrieved documents", - node_start_time=node_start_time, - result="", - ) - ], - ) diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/states.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/states.py index 8e2ccaa24..51a9447eb 100644 --- a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/states.py +++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/states.py @@ -3,13 +3,6 @@ from typing import Annotated from typing import TypedDict from onyx.agents.agent_search.core_state import CoreState -from onyx.agents.agent_search.deep_search.main.states import BaseDecompUpdate -from onyx.agents.agent_search.deep_search.main.states import ( - DecompAnswersUpdate, -) -from onyx.agents.agent_search.deep_search.main.states import ( - ExpandedRetrievalUpdate, -) from onyx.agents.agent_search.deep_search.main.states import ( ExploratorySearchUpdate, ) @@ -19,6 +12,15 @@ from onyx.agents.agent_search.deep_search.main.states import ( from onyx.agents.agent_search.deep_search.main.states import ( InitialAnswerUpdate, ) +from onyx.agents.agent_search.deep_search.main.states import ( + InitialQuestionDecompositionUpdate, +) +from onyx.agents.agent_search.deep_search.main.states import ( + OrigQuestionRetrievalUpdate, +) +from onyx.agents.agent_search.deep_search.main.states import ( + SubQuestionResultsUpdate, +) from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import ( ExpandedRetrievalResult, ) @@ -36,10 +38,10 @@ class SearchSQInput(CoreState): class SearchSQState( # This includes the core state SearchSQInput, - BaseDecompUpdate, + InitialQuestionDecompositionUpdate, InitialAnswerUpdate, - DecompAnswersUpdate, - ExpandedRetrievalUpdate, + SubQuestionResultsUpdate, + OrigQuestionRetrievalUpdate, InitialAnswerQualityUpdate, ExploratorySearchUpdate, ): diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/edges.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/edges.py index b18101833..aa564b4a0 100644 --- a/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/edges.py +++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/edges.py @@ -19,7 +19,7 @@ def parallelize_initial_sub_question_answering( state: SearchSQState, ) -> list[Send | Hashable]: edge_start_time = datetime.now() - if len(state.initial_decomp_questions) > 0: + if len(state.initial_sub_questions) > 0: # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]] # if len(state["sub_question_records"]) == 0: # if state["config"].use_persistence: @@ -41,7 +41,7 @@ def parallelize_initial_sub_question_answering( ], ), ) - for question_nr, question in enumerate(state.initial_decomp_questions) + for question_nr, question in enumerate(state.initial_sub_questions) ] else: diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/decompose_orig_question.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/decompose_orig_question.py index 8f80c11a1..ff965232b 100644 --- a/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/decompose_orig_question.py +++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/decompose_orig_question.py @@ -15,7 +15,9 @@ from onyx.agents.agent_search.deep_search.main.models import ( from onyx.agents.agent_search.deep_search.main.operations import ( dispatch_subquestion, ) -from onyx.agents.agent_search.deep_search.main.states import BaseDecompUpdate +from onyx.agents.agent_search.deep_search.main.states import ( + InitialQuestionDecompositionUpdate, +) from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( build_history_prompt, @@ -39,7 +41,7 @@ from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION def decompose_orig_question( state: SearchSQState, config: RunnableConfig, writer: StreamWriter = lambda _: None -) -> BaseDecompUpdate: +) -> InitialQuestionDecompositionUpdate: node_start_time = datetime.now() agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) @@ -123,8 +125,8 @@ def decompose_orig_question( decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""] - return BaseDecompUpdate( - initial_decomp_questions=decomp_list, + return InitialQuestionDecompositionUpdate( + initial_sub_questions=decomp_list, agent_start_time=agent_start_time, agent_refined_start_time=None, agent_refined_end_time=None, diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/format_initial_sub_answers.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/format_initial_sub_answers.py index 3afb7ed64..2d845be84 100644 --- a/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/format_initial_sub_answers.py +++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/format_initial_sub_answers.py @@ -4,7 +4,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer AnswerQuestionOutput, ) from onyx.agents.agent_search.deep_search.main.states import ( - DecompAnswersUpdate, + SubQuestionResultsUpdate, ) from onyx.agents.agent_search.shared_graph_utils.operators import ( dedup_inference_sections, @@ -16,7 +16,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import ( def format_initial_sub_answers( state: AnswerQuestionOutput, -) -> DecompAnswersUpdate: +) -> SubQuestionResultsUpdate: node_start_time = datetime.now() documents = [] @@ -28,7 +28,7 @@ def format_initial_sub_answers( context_documents.extend(answer_result.context_documents) cited_documents.extend(answer_result.cited_documents) - return DecompAnswersUpdate( + return SubQuestionResultsUpdate( # Deduping is done by the documents operator for the main graph # so we might not need to dedup here verified_reranked_documents=dedup_inference_sections(documents, []), diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/states.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/states.py index 5ea661c3b..87567b0ea 100644 --- a/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/states.py +++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/states.py @@ -1,13 +1,15 @@ from typing import TypedDict from onyx.agents.agent_search.core_state import CoreState -from onyx.agents.agent_search.deep_search.main.states import BaseDecompUpdate -from onyx.agents.agent_search.deep_search.main.states import ( - DecompAnswersUpdate, -) from onyx.agents.agent_search.deep_search.main.states import ( InitialAnswerUpdate, ) +from onyx.agents.agent_search.deep_search.main.states import ( + InitialQuestionDecompositionUpdate, +) +from onyx.agents.agent_search.deep_search.main.states import ( + SubQuestionResultsUpdate, +) ### States ### @@ -22,9 +24,9 @@ class SQInput(CoreState): class SQState( # This includes the core state SQInput, - BaseDecompUpdate, + InitialQuestionDecompositionUpdate, InitialAnswerUpdate, - DecompAnswersUpdate, + SubQuestionResultsUpdate, ): # expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add] pass diff --git a/backend/onyx/agents/agent_search/deep_search/initial/retrieve_orig_question_docs/nodes/format_orig_question_search_output.py b/backend/onyx/agents/agent_search/deep_search/initial/retrieve_orig_question_docs/nodes/format_orig_question_search_output.py index 48efcd2ac..8206e4dda 100644 --- a/backend/onyx/agents/agent_search/deep_search/initial/retrieve_orig_question_docs/nodes/format_orig_question_search_output.py +++ b/backend/onyx/agents/agent_search/deep_search/initial/retrieve_orig_question_docs/nodes/format_orig_question_search_output.py @@ -1,4 +1,4 @@ -from onyx.agents.agent_search.deep_search.main.states import ExpandedRetrievalUpdate +from onyx.agents.agent_search.deep_search.main.states import OrigQuestionRetrievalUpdate from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import ( ExpandedRetrievalOutput, ) @@ -10,7 +10,7 @@ logger = setup_logger() def format_orig_question_search_output( state: ExpandedRetrievalOutput, -) -> ExpandedRetrievalUpdate: +) -> OrigQuestionRetrievalUpdate: # return BaseRawSearchOutput( # base_expanded_retrieval_result=state.expanded_retrieval_result, # # base_retrieval_results=[state.expanded_retrieval_result], @@ -25,9 +25,9 @@ def format_orig_question_search_output( else: sub_question_retrieval_stats = sub_question_retrieval_stats - return ExpandedRetrievalUpdate( - original_question_retrieval_results=state.expanded_retrieval_result.expanded_queries_results, - all_original_question_documents=state.expanded_retrieval_result.context_documents, - original_question_retrieval_stats=sub_question_retrieval_stats, + return OrigQuestionRetrievalUpdate( + orig_question_query_retrieval_results=state.expanded_retrieval_result.expanded_queries_results, + orig_question_retrieval_documents=state.expanded_retrieval_result.context_documents, + orig_question_retrieval_stats=sub_question_retrieval_stats, log_messages=[], ) diff --git a/backend/onyx/agents/agent_search/deep_search/initial/retrieve_orig_question_docs/states.py b/backend/onyx/agents/agent_search/deep_search/initial/retrieve_orig_question_docs/states.py index a54ea0889..744081e76 100644 --- a/backend/onyx/agents/agent_search/deep_search/initial/retrieve_orig_question_docs/states.py +++ b/backend/onyx/agents/agent_search/deep_search/initial/retrieve_orig_question_docs/states.py @@ -1,7 +1,7 @@ from pydantic import BaseModel from onyx.agents.agent_search.deep_search.main.states import ( - ExpandedRetrievalUpdate, + OrigQuestionRetrievalUpdate, ) from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import ( ExpandedRetrievalResult, @@ -39,6 +39,6 @@ class BaseRawSearchOutput(BaseModel): class BaseRawSearchState( - BaseRawSearchInput, BaseRawSearchOutput, ExpandedRetrievalUpdate + BaseRawSearchInput, BaseRawSearchOutput, OrigQuestionRetrievalUpdate ): pass diff --git a/backend/onyx/agents/agent_search/deep_search/main/edges.py b/backend/onyx/agents/agent_search/deep_search/main/edges.py index b03f5d570..d5a28bace 100644 --- a/backend/onyx/agents/agent_search/deep_search/main/edges.py +++ b/backend/onyx/agents/agent_search/deep_search/main/edges.py @@ -14,7 +14,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer ) from onyx.agents.agent_search.deep_search.main.states import MainState from onyx.agents.agent_search.deep_search.main.states import ( - RequireRefinedAnswerUpdate, + RequireRefinementUpdate, ) from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id @@ -44,7 +44,7 @@ def parallelize_initial_sub_question_answering( state: MainState, ) -> list[Send | Hashable]: edge_start_time = datetime.now() - if len(state.initial_decomp_questions) > 0: + if len(state.initial_sub_questions) > 0: # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]] # if len(state["sub_question_records"]) == 0: # if state["config"].use_persistence: @@ -66,7 +66,7 @@ def parallelize_initial_sub_question_answering( ], ), ) - for question_nr, question in enumerate(state.initial_decomp_questions) + for question_nr, question in enumerate(state.initial_sub_questions) ] else: @@ -82,7 +82,7 @@ def parallelize_initial_sub_question_answering( # Define the function that determines whether to continue or not def continue_to_refined_answer_or_end( - state: RequireRefinedAnswerUpdate, + state: RequireRefinementUpdate, ) -> Literal["create_refined_sub_questions", "logging_node"]: if state.require_refined_answer_eval: return "create_refined_sub_questions" diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/compare_answers.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/compare_answers.py index 58ae33fb5..77c8015b7 100644 --- a/backend/onyx/agents/agent_search/deep_search/main/nodes/compare_answers.py +++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/compare_answers.py @@ -5,7 +5,9 @@ from langchain_core.messages import HumanMessage from langchain_core.runnables import RunnableConfig from langgraph.types import StreamWriter -from onyx.agents.agent_search.deep_search.main.states import AnswerComparison +from onyx.agents.agent_search.deep_search.main.states import ( + InitialVRefinedAnswerComparisonUpdate, +) from onyx.agents.agent_search.deep_search.main.states import MainState from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT @@ -18,7 +20,7 @@ from onyx.chat.models import RefinedAnswerImprovement def compare_answers( state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None -) -> AnswerComparison: +) -> InitialVRefinedAnswerComparisonUpdate: node_start_time = datetime.now() agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) @@ -50,7 +52,7 @@ def compare_answers( writer, ) - return AnswerComparison( + return InitialVRefinedAnswerComparisonUpdate( refined_answer_improvement_eval=refined_answer_improvement, log_messages=[ get_langgraph_node_log_string( diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/create_refined_sub_questions.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/create_refined_sub_questions.py index 092549846..6411ef5f9 100644 --- a/backend/onyx/agents/agent_search/deep_search/main/nodes/create_refined_sub_questions.py +++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/create_refined_sub_questions.py @@ -12,10 +12,10 @@ from onyx.agents.agent_search.deep_search.main.models import ( from onyx.agents.agent_search.deep_search.main.operations import ( dispatch_subquestion, ) -from onyx.agents.agent_search.deep_search.main.states import ( - FollowUpSubQuestionsUpdate, -) from onyx.agents.agent_search.deep_search.main.states import MainState +from onyx.agents.agent_search.deep_search.main.states import ( + RefinedQuestionDecompositionUpdate, +) from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( build_history_prompt, @@ -37,7 +37,7 @@ from onyx.tools.models import ToolCallKickoff def create_refined_sub_questions( state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None -) -> FollowUpSubQuestionsUpdate: +) -> RefinedQuestionDecompositionUpdate: """ """ agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) write_custom_event( @@ -114,7 +114,7 @@ def create_refined_sub_questions( refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question - return FollowUpSubQuestionsUpdate( + return RefinedQuestionDecompositionUpdate( refined_sub_questions=refined_sub_question_dict, agent_refined_start_time=agent_refined_start_time, log_messages=[ diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/decide_refinement_need.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/decide_refinement_need.py index 2f781900d..fcb8bf0e3 100644 --- a/backend/onyx/agents/agent_search/deep_search/main/nodes/decide_refinement_need.py +++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/decide_refinement_need.py @@ -5,7 +5,7 @@ from langchain_core.runnables import RunnableConfig from onyx.agents.agent_search.deep_search.main.states import MainState from onyx.agents.agent_search.deep_search.main.states import ( - RequireRefinedAnswerUpdate, + RequireRefinementUpdate, ) from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.shared_graph_utils.utils import ( @@ -15,7 +15,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import ( def decide_refinement_need( state: MainState, config: RunnableConfig -) -> RequireRefinedAnswerUpdate: +) -> RequireRefinementUpdate: node_start_time = datetime.now() agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) @@ -32,12 +32,12 @@ def decide_refinement_need( ] if agent_search_config.allow_refinement: - return RequireRefinedAnswerUpdate( + return RequireRefinementUpdate( require_refined_answer_eval=decision, log_messages=log_messages, ) else: - return RequireRefinedAnswerUpdate( + return RequireRefinementUpdate( require_refined_answer_eval=False, log_messages=log_messages, ) diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_refined_answer.py index 866091ab5..ef2e75439 100644 --- a/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_refined_answer.py +++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_refined_answer.py @@ -69,10 +69,9 @@ def generate_refined_answer( prompt_enrichment_components.persona_prompts.contextualized_prompt ) - initial_documents = state.verified_reranked_documents - refined_documents = state.refined_documents + verified_reranked_documents = state.verified_reranked_documents sub_questions_cited_documents = state.cited_documents - all_original_question_documents = state.all_original_question_documents + all_original_question_documents = state.orig_question_retrieval_documents consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents @@ -93,7 +92,7 @@ def generate_refined_answer( consolidated_context_docs, consolidated_context_docs ) - query_info = get_query_info(state.original_question_retrieval_results) + query_info = get_query_info(state.orig_question_query_retrieval_results) if agent_search_config.search_tool is None: raise ValueError("search_tool must be provided for agentic search") # stream refined answer docs @@ -117,15 +116,14 @@ def generate_refined_answer( writer, ) - if len(initial_documents) > 0: - revision_doc_effectiveness = len(relevant_docs) / len(initial_documents) - elif len(refined_documents) == 0: - revision_doc_effectiveness = 0.0 + if len(verified_reranked_documents) > 0: + refined_doc_effectiveness = len(relevant_docs) / len( + verified_reranked_documents + ) else: - revision_doc_effectiveness = 10.0 + refined_doc_effectiveness = 10.0 decomp_answer_results = state.sub_question_results - # revised_answer_results = state.refined_decomp_answer_results answered_qa_list: list[str] = [] decomp_questions = [] @@ -261,7 +259,7 @@ def generate_refined_answer( # ) refined_agent_stats = RefinedAgentStats( - revision_doc_efficiency=revision_doc_effectiveness, + revision_doc_efficiency=refined_doc_effectiveness, revision_question_efficiency=revision_question_efficiency, ) diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/ingest_refined_answers.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/ingest_refined_answers.py index ca816ebb5..6b92defd2 100644 --- a/backend/onyx/agents/agent_search/deep_search/main/nodes/ingest_refined_answers.py +++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/ingest_refined_answers.py @@ -4,7 +4,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer AnswerQuestionOutput, ) from onyx.agents.agent_search.deep_search.main.states import ( - DecompAnswersUpdate, + SubQuestionResultsUpdate, ) from onyx.agents.agent_search.shared_graph_utils.operators import ( dedup_inference_sections, @@ -16,7 +16,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import ( def ingest_refined_answers( state: AnswerQuestionOutput, -) -> DecompAnswersUpdate: +) -> SubQuestionResultsUpdate: node_start_time = datetime.now() documents = [] @@ -24,7 +24,7 @@ def ingest_refined_answers( for answer_result in answer_results: documents.extend(answer_result.verified_reranked_documents) - return DecompAnswersUpdate( + return SubQuestionResultsUpdate( # Deduping is done by the documents operator for the main graph # so we might not need to dedup here verified_reranked_documents=dedup_inference_sections(documents, []), diff --git a/backend/onyx/agents/agent_search/deep_search/main/states.py b/backend/onyx/agents/agent_search/deep_search/main/states.py index 81723932a..a43edfac2 100644 --- a/backend/onyx/agents/agent_search/deep_search/main/states.py +++ b/backend/onyx/agents/agent_search/deep_search/main/states.py @@ -55,10 +55,12 @@ class RefinedAgentEndStats(BaseModel): agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics() -class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate): +class InitialQuestionDecompositionUpdate( + RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate +): agent_start_time: datetime | None = None previous_history: str | None = None - initial_decomp_questions: list[str] = [] + initial_sub_questions: list[str] = [] class ExploratorySearchUpdate(LoggerUpdate): @@ -66,11 +68,11 @@ class ExploratorySearchUpdate(LoggerUpdate): previous_history_summary: str | None = None -class AnswerComparison(LoggerUpdate): +class InitialVRefinedAnswerComparisonUpdate(LoggerUpdate): refined_answer_improvement_eval: bool = False -class RoutingDecision(LoggerUpdate): +class RoutingDecisionUpdate(LoggerUpdate): routing_decision: str | None = None @@ -97,11 +99,11 @@ class InitialAnswerQualityUpdate(LoggerUpdate): initial_answer_quality_eval: bool = False -class RequireRefinedAnswerUpdate(LoggerUpdate): +class RequireRefinementUpdate(LoggerUpdate): require_refined_answer_eval: bool = True -class DecompAnswersUpdate(LoggerUpdate): +class SubQuestionResultsUpdate(LoggerUpdate): verified_reranked_documents: Annotated[ list[InferenceSection], dedup_inference_sections ] = [] @@ -114,17 +116,12 @@ class DecompAnswersUpdate(LoggerUpdate): ] = [] -class FollowUpDecompAnswersUpdate(LoggerUpdate): - refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = [] - refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = [] - - -class ExpandedRetrievalUpdate(LoggerUpdate): - all_original_question_documents: Annotated[ +class OrigQuestionRetrievalUpdate(LoggerUpdate): + orig_question_retrieval_documents: Annotated[ list[InferenceSection], dedup_inference_sections ] - original_question_retrieval_results: list[QueryResult] = [] - original_question_retrieval_stats: AgentChunkStats = AgentChunkStats() + orig_question_query_retrieval_results: list[QueryResult] = [] + orig_question_retrieval_stats: AgentChunkStats = AgentChunkStats() class EntityTermExtractionUpdate(LoggerUpdate): @@ -133,7 +130,7 @@ class EntityTermExtractionUpdate(LoggerUpdate): ) -class FollowUpSubQuestionsUpdate(RefinedAgentStartStats, LoggerUpdate): +class RefinedQuestionDecompositionUpdate(RefinedAgentStartStats, LoggerUpdate): refined_sub_questions: dict[int, FollowUpSubQuestion] = {} @@ -154,21 +151,20 @@ class MainState( ToolChoiceInput, ToolCallUpdate, ToolChoiceUpdate, - BaseDecompUpdate, + InitialQuestionDecompositionUpdate, InitialAnswerUpdate, InitialAnswerBASEUpdate, - DecompAnswersUpdate, - ExpandedRetrievalUpdate, + SubQuestionResultsUpdate, + OrigQuestionRetrievalUpdate, EntityTermExtractionUpdate, InitialAnswerQualityUpdate, - RequireRefinedAnswerUpdate, - FollowUpSubQuestionsUpdate, - FollowUpDecompAnswersUpdate, + RequireRefinementUpdate, + RefinedQuestionDecompositionUpdate, RefinedAnswerUpdate, RefinedAgentStartStats, RefinedAgentEndStats, - RoutingDecision, - AnswerComparison, + RoutingDecisionUpdate, + InitialVRefinedAnswerComparisonUpdate, ExploratorySearchUpdate, ): # expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add] diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py index da819849b..c7f7550d4 100644 --- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py @@ -31,7 +31,7 @@ def format_results( writer: StreamWriter = lambda _: None, ) -> ExpandedRetrievalUpdate: level, question_nr = parse_question_id(state.sub_question_id or "0_0") - query_info = get_query_info(state.expanded_retrieval_results) + query_info = get_query_info(state.query_retrieval_results) agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) # main question docs will be sent later after aggregation and deduping with sub-question docs @@ -42,7 +42,7 @@ def format_results( if len(reranked_documents) == 0: # The sub-question is used as the last query. If no verified documents are found, stream # the top 3 for that one. We may want to revisit this. - reranked_documents = state.expanded_retrieval_results[-1].search_results[:3] + reranked_documents = state.query_retrieval_results[-1].search_results[:3] if agent_search_config.search_tool is None: raise ValueError("search_tool must be provided for agentic search") @@ -68,7 +68,7 @@ def format_results( ) sub_question_retrieval_stats = calculate_sub_question_retrieval_stats( verified_documents=state.verified_documents, - expanded_retrieval_results=state.expanded_retrieval_results, + expanded_retrieval_results=state.query_retrieval_results, ) if sub_question_retrieval_stats is None: @@ -78,7 +78,7 @@ def format_results( return ExpandedRetrievalUpdate( expanded_retrieval_result=ExpandedRetrievalResult( - expanded_queries_results=state.expanded_retrieval_results, + expanded_queries_results=state.query_retrieval_results, verified_reranked_documents=reranked_documents, context_documents=state.reranked_documents, sub_question_retrieval_stats=sub_question_retrieval_stats, diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py index 46acd6228..d6dfe0249 100644 --- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py +++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py @@ -53,7 +53,7 @@ def retrieve_documents( logger.warning("Empty query, skipping retrieval") return DocRetrievalUpdate( - expanded_retrieval_results=[], + query_retrieval_results=[], retrieved_documents=[], log_messages=[ get_langgraph_node_log_string( @@ -109,7 +109,7 @@ def retrieve_documents( ) return DocRetrievalUpdate( - expanded_retrieval_results=[expanded_retrieval_result], + query_retrieval_results=[expanded_retrieval_result], retrieved_documents=retrieved_docs, log_messages=[ get_langgraph_node_log_string( diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/states.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/states.py index fba7c657d..d486c4754 100644 --- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/states.py +++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/states.py @@ -39,7 +39,7 @@ class DocVerificationUpdate(BaseModel): class DocRetrievalUpdate(LoggerUpdate, BaseModel): - expanded_retrieval_results: Annotated[list[QueryResult], add] = [] + query_retrieval_results: Annotated[list[QueryResult], add] = [] retrieved_documents: Annotated[ list[InferenceSection], dedup_inference_sections ] = []