initial variable renaming

This commit is contained in:
joachim-danswer 2025-01-31 15:28:13 -08:00 committed by Evan Lohn
parent d5661baf98
commit 8342168658
21 changed files with 109 additions and 147 deletions

View File

@ -5,7 +5,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
) )
from onyx.agents.agent_search.deep_search.main.operations import logger from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import ( from onyx.agents.agent_search.deep_search.main.states import (
DecompAnswersUpdate, SubQuestionResultsUpdate,
) )
from onyx.agents.agent_search.shared_graph_utils.operators import ( from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections, dedup_inference_sections,
@ -14,7 +14,7 @@ from onyx.agents.agent_search.shared_graph_utils.operators import (
def format_initial_sub_answers( def format_initial_sub_answers(
state: AnswerQuestionOutput, state: AnswerQuestionOutput,
) -> DecompAnswersUpdate: ) -> SubQuestionResultsUpdate:
now_start = datetime.now() now_start = datetime.now()
logger.info(f"--------{now_start}--------INGEST ANSWERS---") logger.info(f"--------{now_start}--------INGEST ANSWERS---")
@ -32,7 +32,7 @@ def format_initial_sub_answers(
f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---" f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---"
) )
return DecompAnswersUpdate( return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph # Deduping is done by the documents operator for the main graph
# so we might not need to dedup here # so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []), verified_reranked_documents=dedup_inference_sections(documents, []),

View File

@ -19,7 +19,7 @@ def parallelize_initial_sub_question_answering(
state: SearchSQState, state: SearchSQState,
) -> list[Send | Hashable]: ) -> list[Send | Hashable]:
edge_start_time = datetime.now() edge_start_time = datetime.now()
if len(state.initial_decomp_questions) > 0: if len(state.initial_sub_questions) > 0:
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]] # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
# if len(state["sub_question_records"]) == 0: # if len(state["sub_question_records"]) == 0:
# if state["config"].use_persistence: # if state["config"].use_persistence:
@ -41,7 +41,7 @@ def parallelize_initial_sub_question_answering(
], ],
), ),
) )
for question_nr, question in enumerate(state.initial_decomp_questions) for question_nr, question in enumerate(state.initial_sub_questions)
] ]
else: else:

View File

@ -68,11 +68,13 @@ def generate_initial_answer(
prompt_enrichment_components = get_prompt_enrichment_components(agent_search_config) prompt_enrichment_components = get_prompt_enrichment_components(agent_search_config)
sub_questions_cited_documents = state.cited_documents sub_questions_cited_documents = state.cited_documents
all_original_question_documents = state.all_original_question_documents orig_question_retrieval_documents = state.orig_question_retrieval_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
counter = 0 counter = 0
for original_doc_number, original_doc in enumerate(all_original_question_documents): for original_doc_number, original_doc in enumerate(
orig_question_retrieval_documents
):
if original_doc_number not in sub_questions_cited_documents: if original_doc_number not in sub_questions_cited_documents:
if ( if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS counter <= AGENT_MIN_ORIG_QUESTION_DOCS
@ -89,7 +91,7 @@ def generate_initial_answer(
decomp_questions = [] decomp_questions = []
# Use the query info from the base document retrieval # Use the query info from the base document retrieval
query_info = get_query_info(state.original_question_retrieval_results) query_info = get_query_info(state.orig_question_query_retrieval_results)
if agent_search_config.search_tool is None: if agent_search_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search") raise ValueError("search_tool must be provided for agentic search")
@ -229,7 +231,7 @@ def generate_initial_answer(
answer = cast(str, response) answer = cast(str, response)
initial_agent_stats = calculate_initial_agent_stats( initial_agent_stats = calculate_initial_agent_stats(
state.sub_question_results, state.original_question_retrieval_stats state.sub_question_results, state.orig_question_retrieval_stats
) )
logger.debug( logger.debug(
@ -250,8 +252,8 @@ def generate_initial_answer(
agent_base_metrics = AgentBaseMetrics( agent_base_metrics = AgentBaseMetrics(
num_verified_documents_total=len(relevant_docs), num_verified_documents_total=len(relevant_docs),
num_verified_documents_core=state.original_question_retrieval_stats.verified_count, num_verified_documents_core=state.orig_question_retrieval_stats.verified_count,
verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores, verified_avg_score_core=state.orig_question_retrieval_stats.verified_avg_scores,
num_verified_documents_base=initial_agent_stats.sub_questions.get( num_verified_documents_base=initial_agent_stats.sub_questions.get(
"num_verified_documents" "num_verified_documents"
), ),

View File

@ -1,42 +0,0 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search.main.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def ingest_retrieved_documents(
state: BaseRawSearchOutput,
) -> ExpandedRetrievalUpdate:
node_start_time = datetime.now()
sub_question_retrieval_stats = (
state.base_expanded_retrieval_result.sub_question_retrieval_stats
)
# if sub_question_retrieval_stats is None:
# sub_question_retrieval_stats = AgentChunkStats()
# else:
# sub_question_retrieval_stats = sub_question_retrieval_stats
sub_question_retrieval_stats = sub_question_retrieval_stats or AgentChunkStats()
return ExpandedRetrievalUpdate(
original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
all_original_question_documents=state.base_expanded_retrieval_result.context_documents,
original_question_retrieval_stats=sub_question_retrieval_stats,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="ingest retrieved documents",
node_start_time=node_start_time,
result="",
)
],
)

View File

@ -3,13 +3,6 @@ from typing import Annotated
from typing import TypedDict from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.states import BaseDecompUpdate
from onyx.agents.agent_search.deep_search.main.states import (
DecompAnswersUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import ( from onyx.agents.agent_search.deep_search.main.states import (
ExploratorySearchUpdate, ExploratorySearchUpdate,
) )
@ -19,6 +12,15 @@ from onyx.agents.agent_search.deep_search.main.states import (
from onyx.agents.agent_search.deep_search.main.states import ( from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate, InitialAnswerUpdate,
) )
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
OrigQuestionRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import ( from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
ExpandedRetrievalResult, ExpandedRetrievalResult,
) )
@ -36,10 +38,10 @@ class SearchSQInput(CoreState):
class SearchSQState( class SearchSQState(
# This includes the core state # This includes the core state
SearchSQInput, SearchSQInput,
BaseDecompUpdate, InitialQuestionDecompositionUpdate,
InitialAnswerUpdate, InitialAnswerUpdate,
DecompAnswersUpdate, SubQuestionResultsUpdate,
ExpandedRetrievalUpdate, OrigQuestionRetrievalUpdate,
InitialAnswerQualityUpdate, InitialAnswerQualityUpdate,
ExploratorySearchUpdate, ExploratorySearchUpdate,
): ):

View File

@ -19,7 +19,7 @@ def parallelize_initial_sub_question_answering(
state: SearchSQState, state: SearchSQState,
) -> list[Send | Hashable]: ) -> list[Send | Hashable]:
edge_start_time = datetime.now() edge_start_time = datetime.now()
if len(state.initial_decomp_questions) > 0: if len(state.initial_sub_questions) > 0:
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]] # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
# if len(state["sub_question_records"]) == 0: # if len(state["sub_question_records"]) == 0:
# if state["config"].use_persistence: # if state["config"].use_persistence:
@ -41,7 +41,7 @@ def parallelize_initial_sub_question_answering(
], ],
), ),
) )
for question_nr, question in enumerate(state.initial_decomp_questions) for question_nr, question in enumerate(state.initial_sub_questions)
] ]
else: else:

View File

@ -15,7 +15,9 @@ from onyx.agents.agent_search.deep_search.main.models import (
from onyx.agents.agent_search.deep_search.main.operations import ( from onyx.agents.agent_search.deep_search.main.operations import (
dispatch_subquestion, dispatch_subquestion,
) )
from onyx.agents.agent_search.deep_search.main.states import BaseDecompUpdate from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt, build_history_prompt,
@ -39,7 +41,7 @@ from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
def decompose_orig_question( def decompose_orig_question(
state: SearchSQState, config: RunnableConfig, writer: StreamWriter = lambda _: None state: SearchSQState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BaseDecompUpdate: ) -> InitialQuestionDecompositionUpdate:
node_start_time = datetime.now() node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
@ -123,8 +125,8 @@ def decompose_orig_question(
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""] decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
return BaseDecompUpdate( return InitialQuestionDecompositionUpdate(
initial_decomp_questions=decomp_list, initial_sub_questions=decomp_list,
agent_start_time=agent_start_time, agent_start_time=agent_start_time,
agent_refined_start_time=None, agent_refined_start_time=None,
agent_refined_end_time=None, agent_refined_end_time=None,

View File

@ -4,7 +4,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
AnswerQuestionOutput, AnswerQuestionOutput,
) )
from onyx.agents.agent_search.deep_search.main.states import ( from onyx.agents.agent_search.deep_search.main.states import (
DecompAnswersUpdate, SubQuestionResultsUpdate,
) )
from onyx.agents.agent_search.shared_graph_utils.operators import ( from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections, dedup_inference_sections,
@ -16,7 +16,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
def format_initial_sub_answers( def format_initial_sub_answers(
state: AnswerQuestionOutput, state: AnswerQuestionOutput,
) -> DecompAnswersUpdate: ) -> SubQuestionResultsUpdate:
node_start_time = datetime.now() node_start_time = datetime.now()
documents = [] documents = []
@ -28,7 +28,7 @@ def format_initial_sub_answers(
context_documents.extend(answer_result.context_documents) context_documents.extend(answer_result.context_documents)
cited_documents.extend(answer_result.cited_documents) cited_documents.extend(answer_result.cited_documents)
return DecompAnswersUpdate( return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph # Deduping is done by the documents operator for the main graph
# so we might not need to dedup here # so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []), verified_reranked_documents=dedup_inference_sections(documents, []),

View File

@ -1,13 +1,15 @@
from typing import TypedDict from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.states import BaseDecompUpdate
from onyx.agents.agent_search.deep_search.main.states import (
DecompAnswersUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import ( from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate, InitialAnswerUpdate,
) )
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
### States ### ### States ###
@ -22,9 +24,9 @@ class SQInput(CoreState):
class SQState( class SQState(
# This includes the core state # This includes the core state
SQInput, SQInput,
BaseDecompUpdate, InitialQuestionDecompositionUpdate,
InitialAnswerUpdate, InitialAnswerUpdate,
DecompAnswersUpdate, SubQuestionResultsUpdate,
): ):
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add] # expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
pass pass

View File

@ -1,4 +1,4 @@
from onyx.agents.agent_search.deep_search.main.states import ExpandedRetrievalUpdate from onyx.agents.agent_search.deep_search.main.states import OrigQuestionRetrievalUpdate
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import ( from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput, ExpandedRetrievalOutput,
) )
@ -10,7 +10,7 @@ logger = setup_logger()
def format_orig_question_search_output( def format_orig_question_search_output(
state: ExpandedRetrievalOutput, state: ExpandedRetrievalOutput,
) -> ExpandedRetrievalUpdate: ) -> OrigQuestionRetrievalUpdate:
# return BaseRawSearchOutput( # return BaseRawSearchOutput(
# base_expanded_retrieval_result=state.expanded_retrieval_result, # base_expanded_retrieval_result=state.expanded_retrieval_result,
# # base_retrieval_results=[state.expanded_retrieval_result], # # base_retrieval_results=[state.expanded_retrieval_result],
@ -25,9 +25,9 @@ def format_orig_question_search_output(
else: else:
sub_question_retrieval_stats = sub_question_retrieval_stats sub_question_retrieval_stats = sub_question_retrieval_stats
return ExpandedRetrievalUpdate( return OrigQuestionRetrievalUpdate(
original_question_retrieval_results=state.expanded_retrieval_result.expanded_queries_results, orig_question_query_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
all_original_question_documents=state.expanded_retrieval_result.context_documents, orig_question_retrieval_documents=state.expanded_retrieval_result.context_documents,
original_question_retrieval_stats=sub_question_retrieval_stats, orig_question_retrieval_stats=sub_question_retrieval_stats,
log_messages=[], log_messages=[],
) )

View File

@ -1,7 +1,7 @@
from pydantic import BaseModel from pydantic import BaseModel
from onyx.agents.agent_search.deep_search.main.states import ( from onyx.agents.agent_search.deep_search.main.states import (
ExpandedRetrievalUpdate, OrigQuestionRetrievalUpdate,
) )
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import ( from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
ExpandedRetrievalResult, ExpandedRetrievalResult,
@ -39,6 +39,6 @@ class BaseRawSearchOutput(BaseModel):
class BaseRawSearchState( class BaseRawSearchState(
BaseRawSearchInput, BaseRawSearchOutput, ExpandedRetrievalUpdate BaseRawSearchInput, BaseRawSearchOutput, OrigQuestionRetrievalUpdate
): ):
pass pass

View File

@ -14,7 +14,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
) )
from onyx.agents.agent_search.deep_search.main.states import MainState from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import ( from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinedAnswerUpdate, RequireRefinementUpdate,
) )
from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
@ -44,7 +44,7 @@ def parallelize_initial_sub_question_answering(
state: MainState, state: MainState,
) -> list[Send | Hashable]: ) -> list[Send | Hashable]:
edge_start_time = datetime.now() edge_start_time = datetime.now()
if len(state.initial_decomp_questions) > 0: if len(state.initial_sub_questions) > 0:
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]] # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
# if len(state["sub_question_records"]) == 0: # if len(state["sub_question_records"]) == 0:
# if state["config"].use_persistence: # if state["config"].use_persistence:
@ -66,7 +66,7 @@ def parallelize_initial_sub_question_answering(
], ],
), ),
) )
for question_nr, question in enumerate(state.initial_decomp_questions) for question_nr, question in enumerate(state.initial_sub_questions)
] ]
else: else:
@ -82,7 +82,7 @@ def parallelize_initial_sub_question_answering(
# Define the function that determines whether to continue or not # Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end( def continue_to_refined_answer_or_end(
state: RequireRefinedAnswerUpdate, state: RequireRefinementUpdate,
) -> Literal["create_refined_sub_questions", "logging_node"]: ) -> Literal["create_refined_sub_questions", "logging_node"]:
if state.require_refined_answer_eval: if state.require_refined_answer_eval:
return "create_refined_sub_questions" return "create_refined_sub_questions"

View File

@ -5,7 +5,9 @@ from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.states import AnswerComparison from onyx.agents.agent_search.deep_search.main.states import (
InitialVRefinedAnswerComparisonUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
@ -18,7 +20,7 @@ from onyx.chat.models import RefinedAnswerImprovement
def compare_answers( def compare_answers(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> AnswerComparison: ) -> InitialVRefinedAnswerComparisonUpdate:
node_start_time = datetime.now() node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
@ -50,7 +52,7 @@ def compare_answers(
writer, writer,
) )
return AnswerComparison( return InitialVRefinedAnswerComparisonUpdate(
refined_answer_improvement_eval=refined_answer_improvement, refined_answer_improvement_eval=refined_answer_improvement,
log_messages=[ log_messages=[
get_langgraph_node_log_string( get_langgraph_node_log_string(

View File

@ -12,10 +12,10 @@ from onyx.agents.agent_search.deep_search.main.models import (
from onyx.agents.agent_search.deep_search.main.operations import ( from onyx.agents.agent_search.deep_search.main.operations import (
dispatch_subquestion, dispatch_subquestion,
) )
from onyx.agents.agent_search.deep_search.main.states import (
FollowUpSubQuestionsUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt, build_history_prompt,
@ -37,7 +37,7 @@ from onyx.tools.models import ToolCallKickoff
def create_refined_sub_questions( def create_refined_sub_questions(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FollowUpSubQuestionsUpdate: ) -> RefinedQuestionDecompositionUpdate:
""" """ """ """
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
write_custom_event( write_custom_event(
@ -114,7 +114,7 @@ def create_refined_sub_questions(
refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question
return FollowUpSubQuestionsUpdate( return RefinedQuestionDecompositionUpdate(
refined_sub_questions=refined_sub_question_dict, refined_sub_questions=refined_sub_question_dict,
agent_refined_start_time=agent_refined_start_time, agent_refined_start_time=agent_refined_start_time,
log_messages=[ log_messages=[

View File

@ -5,7 +5,7 @@ from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.states import MainState from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import ( from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinedAnswerUpdate, RequireRefinementUpdate,
) )
from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import ( from onyx.agents.agent_search.shared_graph_utils.utils import (
@ -15,7 +15,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
def decide_refinement_need( def decide_refinement_need(
state: MainState, config: RunnableConfig state: MainState, config: RunnableConfig
) -> RequireRefinedAnswerUpdate: ) -> RequireRefinementUpdate:
node_start_time = datetime.now() node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
@ -32,12 +32,12 @@ def decide_refinement_need(
] ]
if agent_search_config.allow_refinement: if agent_search_config.allow_refinement:
return RequireRefinedAnswerUpdate( return RequireRefinementUpdate(
require_refined_answer_eval=decision, require_refined_answer_eval=decision,
log_messages=log_messages, log_messages=log_messages,
) )
else: else:
return RequireRefinedAnswerUpdate( return RequireRefinementUpdate(
require_refined_answer_eval=False, require_refined_answer_eval=False,
log_messages=log_messages, log_messages=log_messages,
) )

View File

@ -69,10 +69,9 @@ def generate_refined_answer(
prompt_enrichment_components.persona_prompts.contextualized_prompt prompt_enrichment_components.persona_prompts.contextualized_prompt
) )
initial_documents = state.verified_reranked_documents verified_reranked_documents = state.verified_reranked_documents
refined_documents = state.refined_documents
sub_questions_cited_documents = state.cited_documents sub_questions_cited_documents = state.cited_documents
all_original_question_documents = state.all_original_question_documents all_original_question_documents = state.orig_question_retrieval_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
@ -93,7 +92,7 @@ def generate_refined_answer(
consolidated_context_docs, consolidated_context_docs consolidated_context_docs, consolidated_context_docs
) )
query_info = get_query_info(state.original_question_retrieval_results) query_info = get_query_info(state.orig_question_query_retrieval_results)
if agent_search_config.search_tool is None: if agent_search_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search") raise ValueError("search_tool must be provided for agentic search")
# stream refined answer docs # stream refined answer docs
@ -117,15 +116,14 @@ def generate_refined_answer(
writer, writer,
) )
if len(initial_documents) > 0: if len(verified_reranked_documents) > 0:
revision_doc_effectiveness = len(relevant_docs) / len(initial_documents) refined_doc_effectiveness = len(relevant_docs) / len(
elif len(refined_documents) == 0: verified_reranked_documents
revision_doc_effectiveness = 0.0 )
else: else:
revision_doc_effectiveness = 10.0 refined_doc_effectiveness = 10.0
decomp_answer_results = state.sub_question_results decomp_answer_results = state.sub_question_results
# revised_answer_results = state.refined_decomp_answer_results
answered_qa_list: list[str] = [] answered_qa_list: list[str] = []
decomp_questions = [] decomp_questions = []
@ -261,7 +259,7 @@ def generate_refined_answer(
# ) # )
refined_agent_stats = RefinedAgentStats( refined_agent_stats = RefinedAgentStats(
revision_doc_efficiency=revision_doc_effectiveness, revision_doc_efficiency=refined_doc_effectiveness,
revision_question_efficiency=revision_question_efficiency, revision_question_efficiency=revision_question_efficiency,
) )

View File

@ -4,7 +4,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
AnswerQuestionOutput, AnswerQuestionOutput,
) )
from onyx.agents.agent_search.deep_search.main.states import ( from onyx.agents.agent_search.deep_search.main.states import (
DecompAnswersUpdate, SubQuestionResultsUpdate,
) )
from onyx.agents.agent_search.shared_graph_utils.operators import ( from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections, dedup_inference_sections,
@ -16,7 +16,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
def ingest_refined_answers( def ingest_refined_answers(
state: AnswerQuestionOutput, state: AnswerQuestionOutput,
) -> DecompAnswersUpdate: ) -> SubQuestionResultsUpdate:
node_start_time = datetime.now() node_start_time = datetime.now()
documents = [] documents = []
@ -24,7 +24,7 @@ def ingest_refined_answers(
for answer_result in answer_results: for answer_result in answer_results:
documents.extend(answer_result.verified_reranked_documents) documents.extend(answer_result.verified_reranked_documents)
return DecompAnswersUpdate( return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph # Deduping is done by the documents operator for the main graph
# so we might not need to dedup here # so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []), verified_reranked_documents=dedup_inference_sections(documents, []),

View File

@ -55,10 +55,12 @@ class RefinedAgentEndStats(BaseModel):
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics() agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate): class InitialQuestionDecompositionUpdate(
RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate
):
agent_start_time: datetime | None = None agent_start_time: datetime | None = None
previous_history: str | None = None previous_history: str | None = None
initial_decomp_questions: list[str] = [] initial_sub_questions: list[str] = []
class ExploratorySearchUpdate(LoggerUpdate): class ExploratorySearchUpdate(LoggerUpdate):
@ -66,11 +68,11 @@ class ExploratorySearchUpdate(LoggerUpdate):
previous_history_summary: str | None = None previous_history_summary: str | None = None
class AnswerComparison(LoggerUpdate): class InitialVRefinedAnswerComparisonUpdate(LoggerUpdate):
refined_answer_improvement_eval: bool = False refined_answer_improvement_eval: bool = False
class RoutingDecision(LoggerUpdate): class RoutingDecisionUpdate(LoggerUpdate):
routing_decision: str | None = None routing_decision: str | None = None
@ -97,11 +99,11 @@ class InitialAnswerQualityUpdate(LoggerUpdate):
initial_answer_quality_eval: bool = False initial_answer_quality_eval: bool = False
class RequireRefinedAnswerUpdate(LoggerUpdate): class RequireRefinementUpdate(LoggerUpdate):
require_refined_answer_eval: bool = True require_refined_answer_eval: bool = True
class DecompAnswersUpdate(LoggerUpdate): class SubQuestionResultsUpdate(LoggerUpdate):
verified_reranked_documents: Annotated[ verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections list[InferenceSection], dedup_inference_sections
] = [] ] = []
@ -114,17 +116,12 @@ class DecompAnswersUpdate(LoggerUpdate):
] = [] ] = []
class FollowUpDecompAnswersUpdate(LoggerUpdate): class OrigQuestionRetrievalUpdate(LoggerUpdate):
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = [] orig_question_retrieval_documents: Annotated[
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = []
class ExpandedRetrievalUpdate(LoggerUpdate):
all_original_question_documents: Annotated[
list[InferenceSection], dedup_inference_sections list[InferenceSection], dedup_inference_sections
] ]
original_question_retrieval_results: list[QueryResult] = [] orig_question_query_retrieval_results: list[QueryResult] = []
original_question_retrieval_stats: AgentChunkStats = AgentChunkStats() orig_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
class EntityTermExtractionUpdate(LoggerUpdate): class EntityTermExtractionUpdate(LoggerUpdate):
@ -133,7 +130,7 @@ class EntityTermExtractionUpdate(LoggerUpdate):
) )
class FollowUpSubQuestionsUpdate(RefinedAgentStartStats, LoggerUpdate): class RefinedQuestionDecompositionUpdate(RefinedAgentStartStats, LoggerUpdate):
refined_sub_questions: dict[int, FollowUpSubQuestion] = {} refined_sub_questions: dict[int, FollowUpSubQuestion] = {}
@ -154,21 +151,20 @@ class MainState(
ToolChoiceInput, ToolChoiceInput,
ToolCallUpdate, ToolCallUpdate,
ToolChoiceUpdate, ToolChoiceUpdate,
BaseDecompUpdate, InitialQuestionDecompositionUpdate,
InitialAnswerUpdate, InitialAnswerUpdate,
InitialAnswerBASEUpdate, InitialAnswerBASEUpdate,
DecompAnswersUpdate, SubQuestionResultsUpdate,
ExpandedRetrievalUpdate, OrigQuestionRetrievalUpdate,
EntityTermExtractionUpdate, EntityTermExtractionUpdate,
InitialAnswerQualityUpdate, InitialAnswerQualityUpdate,
RequireRefinedAnswerUpdate, RequireRefinementUpdate,
FollowUpSubQuestionsUpdate, RefinedQuestionDecompositionUpdate,
FollowUpDecompAnswersUpdate,
RefinedAnswerUpdate, RefinedAnswerUpdate,
RefinedAgentStartStats, RefinedAgentStartStats,
RefinedAgentEndStats, RefinedAgentEndStats,
RoutingDecision, RoutingDecisionUpdate,
AnswerComparison, InitialVRefinedAnswerComparisonUpdate,
ExploratorySearchUpdate, ExploratorySearchUpdate,
): ):
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add] # expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]

View File

@ -31,7 +31,7 @@ def format_results(
writer: StreamWriter = lambda _: None, writer: StreamWriter = lambda _: None,
) -> ExpandedRetrievalUpdate: ) -> ExpandedRetrievalUpdate:
level, question_nr = parse_question_id(state.sub_question_id or "0_0") level, question_nr = parse_question_id(state.sub_question_id or "0_0")
query_info = get_query_info(state.expanded_retrieval_results) query_info = get_query_info(state.query_retrieval_results)
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
# main question docs will be sent later after aggregation and deduping with sub-question docs # main question docs will be sent later after aggregation and deduping with sub-question docs
@ -42,7 +42,7 @@ def format_results(
if len(reranked_documents) == 0: if len(reranked_documents) == 0:
# The sub-question is used as the last query. If no verified documents are found, stream # The sub-question is used as the last query. If no verified documents are found, stream
# the top 3 for that one. We may want to revisit this. # the top 3 for that one. We may want to revisit this.
reranked_documents = state.expanded_retrieval_results[-1].search_results[:3] reranked_documents = state.query_retrieval_results[-1].search_results[:3]
if agent_search_config.search_tool is None: if agent_search_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search") raise ValueError("search_tool must be provided for agentic search")
@ -68,7 +68,7 @@ def format_results(
) )
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats( sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
verified_documents=state.verified_documents, verified_documents=state.verified_documents,
expanded_retrieval_results=state.expanded_retrieval_results, expanded_retrieval_results=state.query_retrieval_results,
) )
if sub_question_retrieval_stats is None: if sub_question_retrieval_stats is None:
@ -78,7 +78,7 @@ def format_results(
return ExpandedRetrievalUpdate( return ExpandedRetrievalUpdate(
expanded_retrieval_result=ExpandedRetrievalResult( expanded_retrieval_result=ExpandedRetrievalResult(
expanded_queries_results=state.expanded_retrieval_results, expanded_queries_results=state.query_retrieval_results,
verified_reranked_documents=reranked_documents, verified_reranked_documents=reranked_documents,
context_documents=state.reranked_documents, context_documents=state.reranked_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats, sub_question_retrieval_stats=sub_question_retrieval_stats,

View File

@ -53,7 +53,7 @@ def retrieve_documents(
logger.warning("Empty query, skipping retrieval") logger.warning("Empty query, skipping retrieval")
return DocRetrievalUpdate( return DocRetrievalUpdate(
expanded_retrieval_results=[], query_retrieval_results=[],
retrieved_documents=[], retrieved_documents=[],
log_messages=[ log_messages=[
get_langgraph_node_log_string( get_langgraph_node_log_string(
@ -109,7 +109,7 @@ def retrieve_documents(
) )
return DocRetrievalUpdate( return DocRetrievalUpdate(
expanded_retrieval_results=[expanded_retrieval_result], query_retrieval_results=[expanded_retrieval_result],
retrieved_documents=retrieved_docs, retrieved_documents=retrieved_docs,
log_messages=[ log_messages=[
get_langgraph_node_log_string( get_langgraph_node_log_string(

View File

@ -39,7 +39,7 @@ class DocVerificationUpdate(BaseModel):
class DocRetrievalUpdate(LoggerUpdate, BaseModel): class DocRetrievalUpdate(LoggerUpdate, BaseModel):
expanded_retrieval_results: Annotated[list[QueryResult], add] = [] query_retrieval_results: Annotated[list[QueryResult], add] = []
retrieved_documents: Annotated[ retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections list[InferenceSection], dedup_inference_sections
] = [] ] = []