variable renaming

This commit is contained in:
joachim-danswer
2025-01-29 09:04:46 -08:00
committed by Evan Lohn
parent ff4df6f3bf
commit 4e17fc06ff
11 changed files with 21 additions and 20 deletions

View File

@@ -16,7 +16,7 @@ def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate
return RetrievalIngestionUpdate( return RetrievalIngestionUpdate(
expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results, expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
documents=state.expanded_retrieval_result.all_documents, documents=state.expanded_retrieval_result.reranked_documents,
context_documents=state.expanded_retrieval_result.context_documents, context_documents=state.expanded_retrieval_result.context_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats, sub_question_retrieval_stats=sub_question_retrieval_stats,
) )

View File

@@ -35,7 +35,7 @@ def initial_answer_quality_check(state: SearchSQState) -> InitialAnswerQualityUp
) )
return InitialAnswerQualityUpdate( return InitialAnswerQualityUpdate(
initial_answer_quality=verdict, initial_answer_quality_eval=verdict,
log_messages=[ log_messages=[
f"{now_start} -- Main - Initial answer quality check, Time taken: {now_end - now_start}" f"{now_start} -- Main - Initial answer quality check, Time taken: {now_end - now_start}"
], ],

View File

@@ -84,7 +84,7 @@ def parallelize_initial_sub_question_answering(
def continue_to_refined_answer_or_end( def continue_to_refined_answer_or_end(
state: RequireRefinedAnswerUpdate, state: RequireRefinedAnswerUpdate,
) -> Literal["refined_sub_question_creation", "logging_node"]: ) -> Literal["refined_sub_question_creation", "logging_node"]:
if state.require_refined_answer: if state.require_refined_answer_eval:
return "refined_sub_question_creation" return "refined_sub_question_creation"
else: else:
return "logging_node" return "logging_node"

View File

@@ -29,7 +29,7 @@ def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDeci
) )
return RoutingDecision( return RoutingDecision(
# Decide which route to take # Decide which route to take
routing=routing, routing_decision=routing,
log_messages=[ log_messages=[
f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}" f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
], ],

View File

@@ -10,7 +10,7 @@ def agent_path_routing(
state: MainState, state: MainState,
) -> Command[Literal["agent_search_start", "LLM"]]: ) -> Command[Literal["agent_search_start", "LLM"]]:
now_start = datetime.now() now_start = datetime.now()
routing = state.routing if hasattr(state, "routing") else "agent_search" routing = state.routing_decision if hasattr(state, "routing") else "agent_search"
if routing == "agent_search": if routing == "agent_search":
agent_path = "agent_search_start" agent_path = "agent_search_start"

View File

@@ -52,7 +52,7 @@ def agent_search_start(
return ExploratorySearchUpdate( return ExploratorySearchUpdate(
exploratory_search_results=exploratory_search_results, exploratory_search_results=exploratory_search_results,
previous_history=history, previous_history_summary=history,
log_messages=[ log_messages=[
f"{now_start} -- Main - Exploratory Search, Time taken: {now_end - now_start}" f"{now_start} -- Main - Exploratory Search, Time taken: {now_end - now_start}"
], ],

View File

@@ -53,7 +53,7 @@ def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerCompari
) )
return AnswerComparison( return AnswerComparison(
refined_answer_improvement=refined_answer_improvement, refined_answer_improvement_eval=refined_answer_improvement,
log_messages=[ log_messages=[
f"{now_start} -- Answer comparison: {refined_answer_improvement}, Time taken: {now_end - now_start}" f"{now_start} -- Answer comparison: {refined_answer_improvement}, Time taken: {now_end - now_start}"
], ],

View File

@@ -36,12 +36,12 @@ def refined_answer_decision(
] ]
if agent_a_config.allow_refinement: if agent_a_config.allow_refinement:
return RequireRefinedAnswerUpdate( return RequireRefinedAnswerUpdate(
require_refined_answer=decision, require_refined_answer_eval=decision,
log_messages=log_messages, log_messages=log_messages,
) )
else: else:
return RequireRefinedAnswerUpdate( return RequireRefinedAnswerUpdate(
require_refined_answer=False, require_refined_answer_eval=False,
log_messages=log_messages, log_messages=log_messages,
) )

View File

@@ -63,17 +63,18 @@ class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats):
class ExploratorySearchUpdate(LoggerUpdate): class ExploratorySearchUpdate(LoggerUpdate):
exploratory_search_results: list[InferenceSection] = [] exploratory_search_results: list[InferenceSection] = []
previous_history: str = "" previous_history_summary: str = ""
class AnswerComparison(LoggerUpdate): class AnswerComparison(LoggerUpdate):
refined_answer_improvement: bool = False refined_answer_improvement_eval: bool = False
class RoutingDecision(LoggerUpdate): class RoutingDecision(LoggerUpdate):
routing: str = "" routing_decision: str = ""
# Not used in current graph
class InitialAnswerBASEUpdate(BaseModel): class InitialAnswerBASEUpdate(BaseModel):
initial_base_answer: str = "" initial_base_answer: str = ""
@@ -93,11 +94,11 @@ class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
class InitialAnswerQualityUpdate(LoggerUpdate): class InitialAnswerQualityUpdate(LoggerUpdate):
initial_answer_quality: bool = False initial_answer_quality_eval: bool = False
class RequireRefinedAnswerUpdate(LoggerUpdate): class RequireRefinedAnswerUpdate(LoggerUpdate):
require_refined_answer: bool = True require_refined_answer_eval: bool = True
class DecompAnswersUpdate(LoggerUpdate): class DecompAnswersUpdate(LoggerUpdate):

View File

@@ -7,6 +7,6 @@ from onyx.context.search.models import InferenceSection
class ExpandedRetrievalResult(BaseModel): class ExpandedRetrievalResult(BaseModel):
expanded_queries_results: list[QueryResult] = [] expanded_queries_results: list[QueryResult] = []
all_documents: list[InferenceSection] = [] reranked_documents: list[InferenceSection] = []
context_documents: list[InferenceSection] = [] context_documents: list[InferenceSection] = []
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats() sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()

View File

@@ -37,20 +37,20 @@ def format_results(
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
# main question docs will be sent later after aggregation and deduping with sub-question docs # main question docs will be sent later after aggregation and deduping with sub-question docs
stream_documents = state.reranked_documents reranked_documents = state.reranked_documents
if not (level == 0 and question_nr == 0): if not (level == 0 and question_nr == 0):
if len(stream_documents) == 0: if len(reranked_documents) == 0:
# The sub-question is used as the last query. If no verified documents are found, stream # The sub-question is used as the last query. If no verified documents are found, stream
# the top 3 for that one. We may want to revisit this. # the top 3 for that one. We may want to revisit this.
stream_documents = state.expanded_retrieval_results[-1].search_results[:3] reranked_documents = state.expanded_retrieval_results[-1].search_results[:3]
if agent_a_config.search_tool is None: if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search") raise ValueError("search_tool must be provided for agentic search")
for tool_response in yield_search_responses( for tool_response in yield_search_responses(
query=state.question, query=state.question,
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.) reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)
final_context_sections=stream_documents, final_context_sections=reranked_documents,
search_query_info=query_infos[0], # TODO: handle differing query infos? search_query_info=query_infos[0], # TODO: handle differing query infos?
get_section_relevance=lambda: None, # TODO: add relevance get_section_relevance=lambda: None, # TODO: add relevance
search_tool=agent_a_config.search_tool, search_tool=agent_a_config.search_tool,
@@ -77,7 +77,7 @@ def format_results(
return ExpandedRetrievalUpdate( return ExpandedRetrievalUpdate(
expanded_retrieval_result=ExpandedRetrievalResult( expanded_retrieval_result=ExpandedRetrievalResult(
expanded_queries_results=state.expanded_retrieval_results, expanded_queries_results=state.expanded_retrieval_results,
all_documents=stream_documents, reranked_documents=reranked_documents,
context_documents=state.reranked_documents, context_documents=state.reranked_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats, sub_question_retrieval_stats=sub_question_retrieval_stats,
), ),