mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-10 21:09:51 +02:00
initial variable renaming
This commit is contained in:
parent
d5661baf98
commit
8342168658
@ -5,7 +5,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
DecompAnswersUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
@ -14,7 +14,7 @@ from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
|
||||
def format_initial_sub_answers(
|
||||
state: AnswerQuestionOutput,
|
||||
) -> DecompAnswersUpdate:
|
||||
) -> SubQuestionResultsUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------INGEST ANSWERS---")
|
||||
@ -32,7 +32,7 @@ def format_initial_sub_answers(
|
||||
f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---"
|
||||
)
|
||||
|
||||
return DecompAnswersUpdate(
|
||||
return SubQuestionResultsUpdate(
|
||||
# Deduping is done by the documents operator for the main graph
|
||||
# so we might not need to dedup here
|
||||
verified_reranked_documents=dedup_inference_sections(documents, []),
|
||||
|
@ -19,7 +19,7 @@ def parallelize_initial_sub_question_answering(
|
||||
state: SearchSQState,
|
||||
) -> list[Send | Hashable]:
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.initial_decomp_questions) > 0:
|
||||
if len(state.initial_sub_questions) > 0:
|
||||
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
|
||||
# if len(state["sub_question_records"]) == 0:
|
||||
# if state["config"].use_persistence:
|
||||
@ -41,7 +41,7 @@ def parallelize_initial_sub_question_answering(
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_nr, question in enumerate(state.initial_decomp_questions)
|
||||
for question_nr, question in enumerate(state.initial_sub_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
|
@ -68,11 +68,13 @@ def generate_initial_answer(
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(agent_search_config)
|
||||
|
||||
sub_questions_cited_documents = state.cited_documents
|
||||
all_original_question_documents = state.all_original_question_documents
|
||||
orig_question_retrieval_documents = state.orig_question_retrieval_documents
|
||||
|
||||
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
|
||||
counter = 0
|
||||
for original_doc_number, original_doc in enumerate(all_original_question_documents):
|
||||
for original_doc_number, original_doc in enumerate(
|
||||
orig_question_retrieval_documents
|
||||
):
|
||||
if original_doc_number not in sub_questions_cited_documents:
|
||||
if (
|
||||
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
@ -89,7 +91,7 @@ def generate_initial_answer(
|
||||
decomp_questions = []
|
||||
|
||||
# Use the query info from the base document retrieval
|
||||
query_info = get_query_info(state.original_question_retrieval_results)
|
||||
query_info = get_query_info(state.orig_question_query_retrieval_results)
|
||||
|
||||
if agent_search_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
@ -229,7 +231,7 @@ def generate_initial_answer(
|
||||
answer = cast(str, response)
|
||||
|
||||
initial_agent_stats = calculate_initial_agent_stats(
|
||||
state.sub_question_results, state.original_question_retrieval_stats
|
||||
state.sub_question_results, state.orig_question_retrieval_stats
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@ -250,8 +252,8 @@ def generate_initial_answer(
|
||||
|
||||
agent_base_metrics = AgentBaseMetrics(
|
||||
num_verified_documents_total=len(relevant_docs),
|
||||
num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
|
||||
verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
|
||||
num_verified_documents_core=state.orig_question_retrieval_stats.verified_count,
|
||||
verified_avg_score_core=state.orig_question_retrieval_stats.verified_avg_scores,
|
||||
num_verified_documents_base=initial_agent_stats.sub_questions.get(
|
||||
"num_verified_documents"
|
||||
),
|
||||
|
@ -1,42 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
|
||||
BaseRawSearchOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
ExpandedRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
|
||||
|
||||
def ingest_retrieved_documents(
|
||||
state: BaseRawSearchOutput,
|
||||
) -> ExpandedRetrievalUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
sub_question_retrieval_stats = (
|
||||
state.base_expanded_retrieval_result.sub_question_retrieval_stats
|
||||
)
|
||||
# if sub_question_retrieval_stats is None:
|
||||
# sub_question_retrieval_stats = AgentChunkStats()
|
||||
# else:
|
||||
# sub_question_retrieval_stats = sub_question_retrieval_stats
|
||||
|
||||
sub_question_retrieval_stats = sub_question_retrieval_stats or AgentChunkStats()
|
||||
|
||||
return ExpandedRetrievalUpdate(
|
||||
original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
|
||||
all_original_question_documents=state.base_expanded_retrieval_result.context_documents,
|
||||
original_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate initial answer",
|
||||
node_name="ingest retrieved documents",
|
||||
node_start_time=node_start_time,
|
||||
result="",
|
||||
)
|
||||
],
|
||||
)
|
@ -3,13 +3,6 @@ from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search.main.states import BaseDecompUpdate
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
DecompAnswersUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
ExpandedRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
ExploratorySearchUpdate,
|
||||
)
|
||||
@ -19,6 +12,15 @@ from onyx.agents.agent_search.deep_search.main.states import (
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
OrigQuestionRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
@ -36,10 +38,10 @@ class SearchSQInput(CoreState):
|
||||
class SearchSQState(
|
||||
# This includes the core state
|
||||
SearchSQInput,
|
||||
BaseDecompUpdate,
|
||||
InitialQuestionDecompositionUpdate,
|
||||
InitialAnswerUpdate,
|
||||
DecompAnswersUpdate,
|
||||
ExpandedRetrievalUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
OrigQuestionRetrievalUpdate,
|
||||
InitialAnswerQualityUpdate,
|
||||
ExploratorySearchUpdate,
|
||||
):
|
||||
|
@ -19,7 +19,7 @@ def parallelize_initial_sub_question_answering(
|
||||
state: SearchSQState,
|
||||
) -> list[Send | Hashable]:
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.initial_decomp_questions) > 0:
|
||||
if len(state.initial_sub_questions) > 0:
|
||||
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
|
||||
# if len(state["sub_question_records"]) == 0:
|
||||
# if state["config"].use_persistence:
|
||||
@ -41,7 +41,7 @@ def parallelize_initial_sub_question_answering(
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_nr, question in enumerate(state.initial_decomp_questions)
|
||||
for question_nr, question in enumerate(state.initial_sub_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
|
@ -15,7 +15,9 @@ from onyx.agents.agent_search.deep_search.main.models import (
|
||||
from onyx.agents.agent_search.deep_search.main.operations import (
|
||||
dispatch_subquestion,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import BaseDecompUpdate
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
@ -39,7 +41,7 @@ from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
|
||||
def decompose_orig_question(
|
||||
state: SearchSQState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> BaseDecompUpdate:
|
||||
) -> InitialQuestionDecompositionUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
@ -123,8 +125,8 @@ def decompose_orig_question(
|
||||
|
||||
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
|
||||
|
||||
return BaseDecompUpdate(
|
||||
initial_decomp_questions=decomp_list,
|
||||
return InitialQuestionDecompositionUpdate(
|
||||
initial_sub_questions=decomp_list,
|
||||
agent_start_time=agent_start_time,
|
||||
agent_refined_start_time=None,
|
||||
agent_refined_end_time=None,
|
||||
|
@ -4,7 +4,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
DecompAnswersUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
@ -16,7 +16,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
|
||||
def format_initial_sub_answers(
|
||||
state: AnswerQuestionOutput,
|
||||
) -> DecompAnswersUpdate:
|
||||
) -> SubQuestionResultsUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
documents = []
|
||||
@ -28,7 +28,7 @@ def format_initial_sub_answers(
|
||||
context_documents.extend(answer_result.context_documents)
|
||||
cited_documents.extend(answer_result.cited_documents)
|
||||
|
||||
return DecompAnswersUpdate(
|
||||
return SubQuestionResultsUpdate(
|
||||
# Deduping is done by the documents operator for the main graph
|
||||
# so we might not need to dedup here
|
||||
verified_reranked_documents=dedup_inference_sections(documents, []),
|
||||
|
@ -1,13 +1,15 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search.main.states import BaseDecompUpdate
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
DecompAnswersUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
|
||||
### States ###
|
||||
|
||||
@ -22,9 +24,9 @@ class SQInput(CoreState):
|
||||
class SQState(
|
||||
# This includes the core state
|
||||
SQInput,
|
||||
BaseDecompUpdate,
|
||||
InitialQuestionDecompositionUpdate,
|
||||
InitialAnswerUpdate,
|
||||
DecompAnswersUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
):
|
||||
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
pass
|
||||
|
@ -1,4 +1,4 @@
|
||||
from onyx.agents.agent_search.deep_search.main.states import ExpandedRetrievalUpdate
|
||||
from onyx.agents.agent_search.deep_search.main.states import OrigQuestionRetrievalUpdate
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
@ -10,7 +10,7 @@ logger = setup_logger()
|
||||
|
||||
def format_orig_question_search_output(
|
||||
state: ExpandedRetrievalOutput,
|
||||
) -> ExpandedRetrievalUpdate:
|
||||
) -> OrigQuestionRetrievalUpdate:
|
||||
# return BaseRawSearchOutput(
|
||||
# base_expanded_retrieval_result=state.expanded_retrieval_result,
|
||||
# # base_retrieval_results=[state.expanded_retrieval_result],
|
||||
@ -25,9 +25,9 @@ def format_orig_question_search_output(
|
||||
else:
|
||||
sub_question_retrieval_stats = sub_question_retrieval_stats
|
||||
|
||||
return ExpandedRetrievalUpdate(
|
||||
original_question_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
|
||||
all_original_question_documents=state.expanded_retrieval_result.context_documents,
|
||||
original_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
return OrigQuestionRetrievalUpdate(
|
||||
orig_question_query_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
|
||||
orig_question_retrieval_documents=state.expanded_retrieval_result.context_documents,
|
||||
orig_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
log_messages=[],
|
||||
)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
ExpandedRetrievalUpdate,
|
||||
OrigQuestionRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
@ -39,6 +39,6 @@ class BaseRawSearchOutput(BaseModel):
|
||||
|
||||
|
||||
class BaseRawSearchState(
|
||||
BaseRawSearchInput, BaseRawSearchOutput, ExpandedRetrievalUpdate
|
||||
BaseRawSearchInput, BaseRawSearchOutput, OrigQuestionRetrievalUpdate
|
||||
):
|
||||
pass
|
||||
|
@ -14,7 +14,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RequireRefinedAnswerUpdate,
|
||||
RequireRefinementUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
@ -44,7 +44,7 @@ def parallelize_initial_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.initial_decomp_questions) > 0:
|
||||
if len(state.initial_sub_questions) > 0:
|
||||
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
|
||||
# if len(state["sub_question_records"]) == 0:
|
||||
# if state["config"].use_persistence:
|
||||
@ -66,7 +66,7 @@ def parallelize_initial_sub_question_answering(
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_nr, question in enumerate(state.initial_decomp_questions)
|
||||
for question_nr, question in enumerate(state.initial_sub_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
@ -82,7 +82,7 @@ def parallelize_initial_sub_question_answering(
|
||||
|
||||
# Define the function that determines whether to continue or not
|
||||
def continue_to_refined_answer_or_end(
|
||||
state: RequireRefinedAnswerUpdate,
|
||||
state: RequireRefinementUpdate,
|
||||
) -> Literal["create_refined_sub_questions", "logging_node"]:
|
||||
if state.require_refined_answer_eval:
|
||||
return "create_refined_sub_questions"
|
||||
|
@ -5,7 +5,9 @@ from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.states import AnswerComparison
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialVRefinedAnswerComparisonUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
|
||||
@ -18,7 +20,7 @@ from onyx.chat.models import RefinedAnswerImprovement
|
||||
|
||||
def compare_answers(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> AnswerComparison:
|
||||
) -> InitialVRefinedAnswerComparisonUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
@ -50,7 +52,7 @@ def compare_answers(
|
||||
writer,
|
||||
)
|
||||
|
||||
return AnswerComparison(
|
||||
return InitialVRefinedAnswerComparisonUpdate(
|
||||
refined_answer_improvement_eval=refined_answer_improvement,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
|
@ -12,10 +12,10 @@ from onyx.agents.agent_search.deep_search.main.models import (
|
||||
from onyx.agents.agent_search.deep_search.main.operations import (
|
||||
dispatch_subquestion,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
FollowUpSubQuestionsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RefinedQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
@ -37,7 +37,7 @@ from onyx.tools.models import ToolCallKickoff
|
||||
|
||||
def create_refined_sub_questions(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> FollowUpSubQuestionsUpdate:
|
||||
) -> RefinedQuestionDecompositionUpdate:
|
||||
""" """
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
write_custom_event(
|
||||
@ -114,7 +114,7 @@ def create_refined_sub_questions(
|
||||
|
||||
refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question
|
||||
|
||||
return FollowUpSubQuestionsUpdate(
|
||||
return RefinedQuestionDecompositionUpdate(
|
||||
refined_sub_questions=refined_sub_question_dict,
|
||||
agent_refined_start_time=agent_refined_start_time,
|
||||
log_messages=[
|
||||
|
@ -5,7 +5,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RequireRefinedAnswerUpdate,
|
||||
RequireRefinementUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
@ -15,7 +15,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
|
||||
def decide_refinement_need(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> RequireRefinedAnswerUpdate:
|
||||
) -> RequireRefinementUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
@ -32,12 +32,12 @@ def decide_refinement_need(
|
||||
]
|
||||
|
||||
if agent_search_config.allow_refinement:
|
||||
return RequireRefinedAnswerUpdate(
|
||||
return RequireRefinementUpdate(
|
||||
require_refined_answer_eval=decision,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
else:
|
||||
return RequireRefinedAnswerUpdate(
|
||||
return RequireRefinementUpdate(
|
||||
require_refined_answer_eval=False,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
|
@ -69,10 +69,9 @@ def generate_refined_answer(
|
||||
prompt_enrichment_components.persona_prompts.contextualized_prompt
|
||||
)
|
||||
|
||||
initial_documents = state.verified_reranked_documents
|
||||
refined_documents = state.refined_documents
|
||||
verified_reranked_documents = state.verified_reranked_documents
|
||||
sub_questions_cited_documents = state.cited_documents
|
||||
all_original_question_documents = state.all_original_question_documents
|
||||
all_original_question_documents = state.orig_question_retrieval_documents
|
||||
|
||||
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
|
||||
|
||||
@ -93,7 +92,7 @@ def generate_refined_answer(
|
||||
consolidated_context_docs, consolidated_context_docs
|
||||
)
|
||||
|
||||
query_info = get_query_info(state.original_question_retrieval_results)
|
||||
query_info = get_query_info(state.orig_question_query_retrieval_results)
|
||||
if agent_search_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
# stream refined answer docs
|
||||
@ -117,15 +116,14 @@ def generate_refined_answer(
|
||||
writer,
|
||||
)
|
||||
|
||||
if len(initial_documents) > 0:
|
||||
revision_doc_effectiveness = len(relevant_docs) / len(initial_documents)
|
||||
elif len(refined_documents) == 0:
|
||||
revision_doc_effectiveness = 0.0
|
||||
if len(verified_reranked_documents) > 0:
|
||||
refined_doc_effectiveness = len(relevant_docs) / len(
|
||||
verified_reranked_documents
|
||||
)
|
||||
else:
|
||||
revision_doc_effectiveness = 10.0
|
||||
refined_doc_effectiveness = 10.0
|
||||
|
||||
decomp_answer_results = state.sub_question_results
|
||||
# revised_answer_results = state.refined_decomp_answer_results
|
||||
|
||||
answered_qa_list: list[str] = []
|
||||
decomp_questions = []
|
||||
@ -261,7 +259,7 @@ def generate_refined_answer(
|
||||
# )
|
||||
|
||||
refined_agent_stats = RefinedAgentStats(
|
||||
revision_doc_efficiency=revision_doc_effectiveness,
|
||||
revision_doc_efficiency=refined_doc_effectiveness,
|
||||
revision_question_efficiency=revision_question_efficiency,
|
||||
)
|
||||
|
||||
|
@ -4,7 +4,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
DecompAnswersUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
@ -16,7 +16,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
|
||||
def ingest_refined_answers(
|
||||
state: AnswerQuestionOutput,
|
||||
) -> DecompAnswersUpdate:
|
||||
) -> SubQuestionResultsUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
documents = []
|
||||
@ -24,7 +24,7 @@ def ingest_refined_answers(
|
||||
for answer_result in answer_results:
|
||||
documents.extend(answer_result.verified_reranked_documents)
|
||||
|
||||
return DecompAnswersUpdate(
|
||||
return SubQuestionResultsUpdate(
|
||||
# Deduping is done by the documents operator for the main graph
|
||||
# so we might not need to dedup here
|
||||
verified_reranked_documents=dedup_inference_sections(documents, []),
|
||||
|
@ -55,10 +55,12 @@ class RefinedAgentEndStats(BaseModel):
|
||||
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
|
||||
|
||||
|
||||
class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate):
|
||||
class InitialQuestionDecompositionUpdate(
|
||||
RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate
|
||||
):
|
||||
agent_start_time: datetime | None = None
|
||||
previous_history: str | None = None
|
||||
initial_decomp_questions: list[str] = []
|
||||
initial_sub_questions: list[str] = []
|
||||
|
||||
|
||||
class ExploratorySearchUpdate(LoggerUpdate):
|
||||
@ -66,11 +68,11 @@ class ExploratorySearchUpdate(LoggerUpdate):
|
||||
previous_history_summary: str | None = None
|
||||
|
||||
|
||||
class AnswerComparison(LoggerUpdate):
|
||||
class InitialVRefinedAnswerComparisonUpdate(LoggerUpdate):
|
||||
refined_answer_improvement_eval: bool = False
|
||||
|
||||
|
||||
class RoutingDecision(LoggerUpdate):
|
||||
class RoutingDecisionUpdate(LoggerUpdate):
|
||||
routing_decision: str | None = None
|
||||
|
||||
|
||||
@ -97,11 +99,11 @@ class InitialAnswerQualityUpdate(LoggerUpdate):
|
||||
initial_answer_quality_eval: bool = False
|
||||
|
||||
|
||||
class RequireRefinedAnswerUpdate(LoggerUpdate):
|
||||
class RequireRefinementUpdate(LoggerUpdate):
|
||||
require_refined_answer_eval: bool = True
|
||||
|
||||
|
||||
class DecompAnswersUpdate(LoggerUpdate):
|
||||
class SubQuestionResultsUpdate(LoggerUpdate):
|
||||
verified_reranked_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = []
|
||||
@ -114,17 +116,12 @@ class DecompAnswersUpdate(LoggerUpdate):
|
||||
] = []
|
||||
|
||||
|
||||
class FollowUpDecompAnswersUpdate(LoggerUpdate):
|
||||
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = []
|
||||
|
||||
|
||||
class ExpandedRetrievalUpdate(LoggerUpdate):
|
||||
all_original_question_documents: Annotated[
|
||||
class OrigQuestionRetrievalUpdate(LoggerUpdate):
|
||||
orig_question_retrieval_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
]
|
||||
original_question_retrieval_results: list[QueryResult] = []
|
||||
original_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
|
||||
orig_question_query_retrieval_results: list[QueryResult] = []
|
||||
orig_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
|
||||
|
||||
|
||||
class EntityTermExtractionUpdate(LoggerUpdate):
|
||||
@ -133,7 +130,7 @@ class EntityTermExtractionUpdate(LoggerUpdate):
|
||||
)
|
||||
|
||||
|
||||
class FollowUpSubQuestionsUpdate(RefinedAgentStartStats, LoggerUpdate):
|
||||
class RefinedQuestionDecompositionUpdate(RefinedAgentStartStats, LoggerUpdate):
|
||||
refined_sub_questions: dict[int, FollowUpSubQuestion] = {}
|
||||
|
||||
|
||||
@ -154,21 +151,20 @@ class MainState(
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
BaseDecompUpdate,
|
||||
InitialQuestionDecompositionUpdate,
|
||||
InitialAnswerUpdate,
|
||||
InitialAnswerBASEUpdate,
|
||||
DecompAnswersUpdate,
|
||||
ExpandedRetrievalUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
OrigQuestionRetrievalUpdate,
|
||||
EntityTermExtractionUpdate,
|
||||
InitialAnswerQualityUpdate,
|
||||
RequireRefinedAnswerUpdate,
|
||||
FollowUpSubQuestionsUpdate,
|
||||
FollowUpDecompAnswersUpdate,
|
||||
RequireRefinementUpdate,
|
||||
RefinedQuestionDecompositionUpdate,
|
||||
RefinedAnswerUpdate,
|
||||
RefinedAgentStartStats,
|
||||
RefinedAgentEndStats,
|
||||
RoutingDecision,
|
||||
AnswerComparison,
|
||||
RoutingDecisionUpdate,
|
||||
InitialVRefinedAnswerComparisonUpdate,
|
||||
ExploratorySearchUpdate,
|
||||
):
|
||||
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
|
@ -31,7 +31,7 @@ def format_results(
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ExpandedRetrievalUpdate:
|
||||
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
|
||||
query_info = get_query_info(state.expanded_retrieval_results)
|
||||
query_info = get_query_info(state.query_retrieval_results)
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
# main question docs will be sent later after aggregation and deduping with sub-question docs
|
||||
@ -42,7 +42,7 @@ def format_results(
|
||||
if len(reranked_documents) == 0:
|
||||
# The sub-question is used as the last query. If no verified documents are found, stream
|
||||
# the top 3 for that one. We may want to revisit this.
|
||||
reranked_documents = state.expanded_retrieval_results[-1].search_results[:3]
|
||||
reranked_documents = state.query_retrieval_results[-1].search_results[:3]
|
||||
|
||||
if agent_search_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
@ -68,7 +68,7 @@ def format_results(
|
||||
)
|
||||
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
|
||||
verified_documents=state.verified_documents,
|
||||
expanded_retrieval_results=state.expanded_retrieval_results,
|
||||
expanded_retrieval_results=state.query_retrieval_results,
|
||||
)
|
||||
|
||||
if sub_question_retrieval_stats is None:
|
||||
@ -78,7 +78,7 @@ def format_results(
|
||||
|
||||
return ExpandedRetrievalUpdate(
|
||||
expanded_retrieval_result=ExpandedRetrievalResult(
|
||||
expanded_queries_results=state.expanded_retrieval_results,
|
||||
expanded_queries_results=state.query_retrieval_results,
|
||||
verified_reranked_documents=reranked_documents,
|
||||
context_documents=state.reranked_documents,
|
||||
sub_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
|
@ -53,7 +53,7 @@ def retrieve_documents(
|
||||
logger.warning("Empty query, skipping retrieval")
|
||||
|
||||
return DocRetrievalUpdate(
|
||||
expanded_retrieval_results=[],
|
||||
query_retrieval_results=[],
|
||||
retrieved_documents=[],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
@ -109,7 +109,7 @@ def retrieve_documents(
|
||||
)
|
||||
|
||||
return DocRetrievalUpdate(
|
||||
expanded_retrieval_results=[expanded_retrieval_result],
|
||||
query_retrieval_results=[expanded_retrieval_result],
|
||||
retrieved_documents=retrieved_docs,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
|
@ -39,7 +39,7 @@ class DocVerificationUpdate(BaseModel):
|
||||
|
||||
|
||||
class DocRetrievalUpdate(LoggerUpdate, BaseModel):
|
||||
expanded_retrieval_results: Annotated[list[QueryResult], add] = []
|
||||
query_retrieval_results: Annotated[list[QueryResult], add] = []
|
||||
retrieved_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user