Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-10-10 05:05:34 +02:00)

Commit: main nodes renaming
@@ -53,8 +53,8 @@ def consolidate_sub_answers_graph_builder() -> StateGraph:
     # raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
 
     # graph.add_edge(
-    #     start_key="agent_search_start",
-    #     end_key="entity_term_extraction_llm",
+    #     start_key="start_agent_search",
+    #     end_key="extract_entity_term",
     # )
 
     graph.add_edge(
@@ -25,7 +25,7 @@ logger = setup_logger()
 
 def route_initial_tool_choice(
     state: MainState, config: RunnableConfig
-) -> Literal["tool_call", "agent_search_start", "logging_node"]:
+) -> Literal["tool_call", "start_agent_search", "logging_node"]:
     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     if state.tool_choice is not None:
         if (
@@ -33,7 +33,7 @@ def route_initial_tool_choice(
             and agent_config.search_tool is not None
             and state.tool_choice.tool.name == agent_config.search_tool.name
         ):
-            return "agent_search_start"
+            return "start_agent_search"
         else:
             return "tool_call"
     else:
@@ -83,9 +83,9 @@ def parallelize_initial_sub_question_answering(
 # Define the function that determines whether to continue or not
 def continue_to_refined_answer_or_end(
     state: RequireRefinedAnswerUpdate,
-) -> Literal["refined_sub_question_creation", "logging_node"]:
+) -> Literal["create_refined_sub_questions", "logging_node"]:
     if state.require_refined_answer_eval:
-        return "refined_sub_question_creation"
+        return "create_refined_sub_questions"
     else:
         return "logging_node"
 
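Note on the edges hunks above: LangGraph dispatches a conditional edge on the exact string the router returns, so a node rename has to land in three places at once — the add_node registration, the router's Literal annotation, and the returned value — which is why this commit touches all of them together. A minimal, self-contained sketch of that coupling (illustrative state schema and wiring, not this repo's code):

    from typing import Literal
    from typing_extensions import TypedDict

    from langgraph.graph import END, START, StateGraph


    class State(TypedDict):
        needs_search: bool


    def route(state: State) -> Literal["start_agent_search", "logging_node"]:
        # The returned string is looked up as a node name in the graph,
        # so it must match add_node exactly.
        return "start_agent_search" if state["needs_search"] else "logging_node"


    def start_agent_search(state: State) -> State:
        return state


    def logging_node(state: State) -> State:
        return state


    graph = StateGraph(State)
    graph.add_node(node="start_agent_search", action=start_agent_search)
    graph.add_node(node="logging_node", action=logging_node)
    graph.add_conditional_edges(START, route, ["start_agent_search", "logging_node"])
    graph.add_edge("start_agent_search", "logging_node")
    graph.add_edge("logging_node", END)
    compiled = graph.compile()
    compiled.invoke({"needs_search": True})

If the annotation or the returned string lags behind a rename, the graph compiles but routing fails at runtime, which is the failure mode this commit guards against.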
@@ -14,17 +14,14 @@ from onyx.agents.agent_search.deep_search_a.main.edges import (
 from onyx.agents.agent_search.deep_search_a.main.edges import (
     route_initial_tool_choice,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.agent_logging import (
-    agent_logging,
+from onyx.agents.agent_search.deep_search_a.main.nodes.compare_answers import (
+    compare_answers,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import (
-    agent_search_start,
+from onyx.agents.agent_search.deep_search_a.main.nodes.create_refined_sub_questions import (
+    create_refined_sub_questions,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.answer_comparison import (
-    answer_comparison,
-)
-from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import (
-    entity_term_extraction_llm,
+from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entity_term import (
+    extract_entity_term,
 )
 from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
     generate_refined_answer,
@@ -32,11 +29,14 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer i
 from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_refined_answers import (
     ingest_refined_answers,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision import (
-    refined_answer_decision,
+from onyx.agents.agent_search.deep_search_a.main.nodes.persist_agent_results import (
+    persist_agent_results,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import (
-    refined_sub_question_creation,
+from onyx.agents.agent_search.deep_search_a.main.nodes.start_agent_search import (
+    start_agent_search,
+)
+from onyx.agents.agent_search.deep_search_a.main.nodes.validate_refined_answer import (
+    validate_refined_answer,
 )
 from onyx.agents.agent_search.deep_search_a.main.states import MainInput
 from onyx.agents.agent_search.deep_search_a.main.states import MainState
@@ -65,20 +65,6 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         input=MainInput,
     )
 
-    # graph.add_node(
-    #     node="agent_path_decision",
-    #     action=agent_path_decision,
-    # )
-
-    # graph.add_node(
-    #     node="agent_path_routing",
-    #     action=agent_path_routing,
-    # )
-
-    # graph.add_node(
-    #     node="LLM",
-    #     action=direct_llm_handling,
-    # )
     graph.add_node(
         node="prepare_tool_input",
         action=prepare_tool_input,
@@ -97,42 +83,19 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         action=basic_use_tool_response,
     )
     graph.add_node(
-        node="agent_search_start",
-        action=agent_search_start,
+        node="start_agent_search",
+        action=start_agent_search,
     )
 
-    # graph.add_node(
-    #     node="initial_sub_question_creation",
-    #     action=initial_sub_question_creation,
-    # )
-
     generate_initial_answer_subgraph = generate_initial_answer_graph_builder().compile()
     graph.add_node(
         node="generate_initial_answer_subgraph",
         action=generate_initial_answer_subgraph,
     )
 
-    # answer_query_subgraph = answer_query_graph_builder().compile()
-    # graph.add_node(
-    #     node="answer_query_subgraph",
-    #     action=answer_query_subgraph,
-    # )
-
-    # base_raw_search_subgraph = base_raw_search_graph_builder().compile()
-    # graph.add_node(
-    #     node="base_raw_search_subgraph",
-    #     action=base_raw_search_subgraph,
-    # )
-
-    # refined_answer_subgraph = refined_answers_graph_builder().compile()
-    # graph.add_node(
-    #     node="refined_answer_subgraph",
-    #     action=refined_answer_subgraph,
-    # )
-
     graph.add_node(
-        node="refined_sub_question_creation",
-        action=refined_sub_question_creation,
+        node="create_refined_sub_questions",
+        action=create_refined_sub_questions,
     )
 
     answer_refined_question = answer_refined_query_graph_builder().compile()
@@ -151,70 +114,25 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         action=generate_refined_answer,
     )
 
-    # graph.add_node(
-    #     node="check_refined_answer",
-    #     action=check_refined_answer,
-    # )
-
-    # graph.add_node(
-    #     node="ingest_initial_retrieval",
-    #     action=ingest_initial_base_retrieval,
-    # )
-
-    # graph.add_node(
-    #     node="retrieval_consolidation",
-    #     action=retrieval_consolidation,
-    # )
-
-    # graph.add_node(
-    #     node="ingest_initial_sub_question_answers",
-    #     action=ingest_initial_sub_question_answers,
-    # )
-    # graph.add_node(
-    #     node="generate_initial_answer",
-    #     action=generate_initial_answer,
-    # )
-
-    # graph.add_node(
-    #     node="initial_answer_quality_check",
-    #     action=initial_answer_quality_check,
-    # )
-
     graph.add_node(
-        node="entity_term_extraction_llm",
-        action=entity_term_extraction_llm,
+        node="extract_entity_term",
+        action=extract_entity_term,
     )
     graph.add_node(
-        node="refined_answer_decision",
-        action=refined_answer_decision,
+        node="validate_refined_answer",
+        action=validate_refined_answer,
     )
     graph.add_node(
-        node="answer_comparison",
-        action=answer_comparison,
+        node="compare_answers",
+        action=compare_answers,
     )
     graph.add_node(
         node="logging_node",
-        action=agent_logging,
+        action=persist_agent_results,
    )
-    # if test_mode:
-    #     graph.add_node(
-    #         node="generate_initial_base_answer",
-    #         action=generate_initial_base_answer,
-    #     )
 
     ### Add edges ###
 
-    # raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
-
-    # graph.add_edge(
-    #     start_key=START,
-    #     end_key="agent_path_decision",
-    # )
-
-    # graph.add_edge(
-    #     start_key="agent_path_decision",
-    #     end_key="agent_path_routing",
-    # )
     graph.add_edge(start_key=START, end_key="prepare_tool_input")
 
     graph.add_edge(
@@ -225,7 +143,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     graph.add_conditional_edges(
         "initial_tool_choice",
         route_initial_tool_choice,
-        ["tool_call", "agent_search_start", "logging_node"],
+        ["tool_call", "start_agent_search", "logging_node"],
     )
 
     graph.add_edge(
@@ -238,96 +156,38 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     )
 
     graph.add_edge(
+<<<<<<< HEAD
         start_key="agent_search_start",
         end_key="generate_initial_answer_subgraph",
-    )
-    # graph.add_edge(
-    #     start_key="agent_search_start",
-    #     end_key="base_raw_search_subgraph",
-    # )
-
-    graph.add_edge(
-        start_key="agent_search_start",
-        end_key="entity_term_extraction_llm",
+=======
+        start_key="start_agent_search",
+        end_key="initial_search_sq_subgraph",
+>>>>>>> ab2510c4d (main nodes renaming)
     )
 
-    # graph.add_edge(
-    #     start_key="agent_search_start",
-    #     end_key="initial_sub_question_creation",
-    # )
+    graph.add_edge(
+        start_key="start_agent_search",
+        end_key="extract_entity_term",
+    )
 
-    # graph.add_edge(
-    #     start_key="base_raw_search_subgraph",
-    #     end_key="ingest_initial_retrieval",
-    # )
-
-    # graph.add_edge(
-    #     start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
-    #     end_key="retrieval_consolidation",
-    # )
-
-    # graph.add_edge(
-    #     start_key="retrieval_consolidation",
-    #     end_key="generate_initial_answer",
-    # )
-
-    # graph.add_edge(
-    #     start_key="LLM",
-    #     end_key=END,
-    # )
-
-    # graph.add_edge(
-    #     start_key=START,
-    #     end_key="initial_sub_question_creation",
-    # )
-
-    # graph.add_conditional_edges(
-    #     source="initial_sub_question_creation",
-    #     path=parallelize_initial_sub_question_answering,
-    #     path_map=["answer_query_subgraph"],
-    # )
-    # graph.add_edge(
-    #     start_key="answer_query_subgraph",
-    #     end_key="ingest_initial_sub_question_answers",
-    # )
-
-    # graph.add_edge(
-    #     start_key="retrieval_consolidation",
-    #     end_key="generate_initial_answer",
-    # )
-
-    # graph.add_edge(
-    #     start_key="generate_initial_answer",
-    #     end_key="entity_term_extraction_llm",
-    # )
-
-    # graph.add_edge(
-    #     start_key="generate_initial_answer",
-    #     end_key="initial_answer_quality_check",
-    # )
-
-    # graph.add_edge(
-    #     start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
-    #     end_key="refined_answer_decision",
-    # )
-    # graph.add_edge(
-    #     start_key="initial_answer_quality_check",
-    #     end_key="refined_answer_decision",
-    # )
-
     graph.add_edge(
+<<<<<<< HEAD
         start_key=["generate_initial_answer_subgraph", "entity_term_extraction_llm"],
         end_key="refined_answer_decision",
+=======
+        start_key=["initial_search_sq_subgraph", "extract_entity_term"],
+        end_key="validate_refined_answer",
+>>>>>>> ab2510c4d (main nodes renaming)
     )
 
     graph.add_conditional_edges(
-        source="refined_answer_decision",
+        source="validate_refined_answer",
         path=continue_to_refined_answer_or_end,
-        path_map=["refined_sub_question_creation", "logging_node"],
+        path_map=["create_refined_sub_questions", "logging_node"],
     )
 
     graph.add_conditional_edges(
-        source="refined_sub_question_creation",  # DONE
+        source="create_refined_sub_questions",  # DONE
         path=parallelize_refined_sub_question_answering,
         path_map=["answer_refined_question"],
     )
@@ -341,23 +201,12 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         end_key="generate_refined_answer",
     )
 
-    # graph.add_conditional_edges(
-    #     source="refined_answer_decision",
-    #     path=continue_to_refined_answer_or_end,
-    #     path_map=["refined_answer_subgraph", END],
-    # )
-
-    # graph.add_edge(
-    #     start_key="refined_answer_subgraph",
-    #     end_key="generate_refined_answer",
-    # )
-
     graph.add_edge(
         start_key="generate_refined_answer",
-        end_key="answer_comparison",
+        end_key="compare_answers",
     )
     graph.add_edge(
-        start_key="answer_comparison",
+        start_key="compare_answers",
         end_key="logging_node",
     )
 
@@ -366,16 +215,6 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         end_key=END,
     )
 
-    # graph.add_edge(
-    #     start_key="generate_refined_answer",
-    #     end_key="check_refined_answer",
-    # )
-
-    # graph.add_edge(
-    #     start_key="check_refined_answer",
-    #     end_key=END,
-    # )
-
     return graph
 
 
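Note on the graph_builder hunks above: the builder registers compiled subgraphs (e.g. generate_initial_answer_subgraph, answer_refined_question) directly as node actions. A compiled LangGraph graph is itself a runnable, so it can serve as a node of a parent graph that shares a compatible state schema. A hedged, self-contained sketch of that pattern (state schema and names are illustrative, not this repo's code):

    from typing_extensions import TypedDict

    from langgraph.graph import END, START, StateGraph


    class State(TypedDict):
        answer: str


    def draft_answer(state: State) -> State:
        # Stand-in for the subgraph's real work.
        return {"answer": state["answer"] + " [draft]"}


    sub = StateGraph(State)
    sub.add_node(node="draft_answer", action=draft_answer)
    sub.add_edge(START, "draft_answer")
    sub.add_edge("draft_answer", END)

    parent = StateGraph(State)
    # The compiled subgraph is passed as the node action, mirroring how
    # generate_initial_answer_subgraph is wired in above.
    parent.add_node(node="generate_initial_answer_subgraph", action=sub.compile())
    parent.add_edge(START, "generate_initial_answer_subgraph")
    parent.add_edge("generate_initial_answer_subgraph", END)

    result = parent.compile().invoke({"answer": ""})

This keeps each phase (initial answer, refined answers) independently testable while the parent graph only sees one node per phase.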
@@ -1,36 +0,0 @@
-from datetime import datetime
-from typing import cast
-
-from langchain_core.runnables import RunnableConfig
-
-from onyx.agents.agent_search.deep_search_a.main.operations import logger
-from onyx.agents.agent_search.deep_search_a.main.states import MainState
-from onyx.agents.agent_search.deep_search_a.main.states import RoutingDecision
-from onyx.agents.agent_search.models import AgentSearchConfig
-
-
-def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
-    now_start = datetime.now()
-
-    cast(AgentSearchConfig, config["metadata"]["config"])
-
-    # perform_initial_search_path_decision = (
-    #     agent_a_config.perform_initial_search_path_decision
-    # )
-
-    logger.info(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")
-
-    routing = "agent_search"
-
-    now_end = datetime.now()
-
-    logger.debug(
-        f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
-    )
-    return RoutingDecision(
-        # Decide which route to take
-        routing_decision=routing,
-        log_messages=[
-            f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
-        ],
-    )
@@ -1,31 +0,0 @@
-from datetime import datetime
-from typing import Literal
-
-from langgraph.types import Command
-
-from onyx.agents.agent_search.deep_search_a.main.states import MainState
-
-
-def agent_path_routing(
-    state: MainState,
-) -> Command[Literal["agent_search_start", "LLM"]]:
-    now_start = datetime.now()
-    routing = state.routing_decision if hasattr(state, "routing") else "agent_search"
-
-    if routing == "agent_search":
-        agent_path = "agent_search_start"
-    else:
-        agent_path = "LLM"
-
-    now_end = datetime.now()
-
-    return Command(
-        # state update
-        update={
-            "log_messages": [
-                f"{now_start} -- Main - Path routing: {agent_path}, Time taken: {now_end - now_start}"
-            ]
-        },
-        # control flow
-        goto=agent_path,
-    )
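Note on the deleted agent_path_routing node above: it used langgraph.types.Command, which lets a single node return both a state update and the next node to visit, instead of pairing a plain node with a separate conditional edge. A minimal sketch of that pattern (assuming a recent langgraph release with Command support; state and node names are illustrative):

    from typing import Literal
    from typing_extensions import TypedDict

    from langgraph.graph import END, START, StateGraph
    from langgraph.types import Command


    class State(TypedDict):
        log_messages: list


    def path_routing(state: State) -> Command[Literal["agent_search", "llm"]]:
        goto = "agent_search"  # the real decision logic would live here
        return Command(
            # state update
            update={"log_messages": state["log_messages"] + [f"routing -> {goto}"]},
            # control flow: the Literal annotation tells LangGraph the
            # possible destinations, so no explicit edge is needed
            goto=goto,
        )


    graph = StateGraph(State)
    graph.add_node(node="path_routing", action=path_routing)
    graph.add_node(node="agent_search", action=lambda s: s)
    graph.add_node(node="llm", action=lambda s: s)
    graph.add_edge(START, "path_routing")
    graph.add_edge("agent_search", END)
    graph.add_edge("llm", END)
    result = graph.compile().invoke({"log_messages": []})

With this commit the Command-based router is dropped entirely in favor of static edges from start_agent_search.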
@@ -13,7 +13,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISO
 from onyx.chat.models import RefinedAnswerImprovement
 
 
-def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerComparison:
+def compare_answers(state: MainState, config: RunnableConfig) -> AnswerComparison:
     now_start = datetime.now()
 
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
@@ -23,11 +23,11 @@ def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerCompari
 
     logger.info(f"--------{now_start}--------ANSWER COMPARISON STARTED--")
 
-    answer_comparison_prompt = ANSWER_COMPARISON_PROMPT.format(
+    compare_answers_prompt = ANSWER_COMPARISON_PROMPT.format(
         question=question, initial_answer=initial_answer, refined_answer=refined_answer
     )
 
-    msg = [HumanMessage(content=answer_comparison_prompt)]
+    msg = [HumanMessage(content=compare_answers_prompt)]
 
     # Get the rewritten queries in a defined format
     model = agent_a_config.fast_llm
@@ -32,7 +32,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
 from onyx.tools.models import ToolCallKickoff
 
 
-def refined_sub_question_creation(
+def create_refined_sub_questions(
     state: MainState, config: RunnableConfig
 ) -> FollowUpSubQuestionsUpdate:
     """ """
@@ -1,83 +0,0 @@
-from datetime import datetime
-from typing import Any
-from typing import cast
-
-from langchain_core.callbacks.manager import dispatch_custom_event
-from langchain_core.messages import HumanMessage
-from langchain_core.messages import merge_content
-from langchain_core.runnables import RunnableConfig
-
-from onyx.agents.agent_search.deep_search_a.main.operations import logger
-from onyx.agents.agent_search.deep_search_a.main.states import (
-    InitialAnswerUpdate,
-)
-from onyx.agents.agent_search.deep_search_a.main.states import MainState
-from onyx.agents.agent_search.models import AgentSearchConfig
-from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
-from onyx.agents.agent_search.shared_graph_utils.utils import (
-    get_persona_agent_prompt_expressions,
-)
-from onyx.chat.models import AgentAnswerPiece
-
-
-def direct_llm_handling(
-    state: MainState, config: RunnableConfig
-) -> InitialAnswerUpdate:
-    now_start = datetime.now()
-
-    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = agent_a_config.search_request.query
-    persona_contextualialized_prompt = get_persona_agent_prompt_expressions(
-        agent_a_config.search_request.persona
-    ).contextualized_prompt
-
-    logger.info(f"--------{now_start}--------LLM HANDLING START---")
-
-    model = agent_a_config.fast_llm
-
-    msg = [
-        HumanMessage(
-            content=DIRECT_LLM_PROMPT.format(
-                persona_specification=persona_contextualialized_prompt,
-                question=question,
-            )
-        )
-    ]
-
-    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
-
-    for message in model.stream(msg):
-        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-        content = message.content
-        if not isinstance(content, str):
-            raise ValueError(
-                f"Expected content to be a string, but got {type(content)}"
-            )
-        dispatch_custom_event(
-            "initial_agent_answer",
-            AgentAnswerPiece(
-                answer_piece=content,
-                level=0,
-                level_question_nr=0,
-                answer_type="agent_level_answer",
-            ),
-        )
-        streamed_tokens.append(content)
-
-    response = merge_content(*streamed_tokens)
-    answer = cast(str, response)
-
-    now_end = datetime.now()
-
-    logger.info(f"--------{now_end}--{now_end - now_start}--------LLM HANDLING END---")
-
-    return InitialAnswerUpdate(
-        initial_answer=answer,
-        initial_agent_stats=None,
-        generated_sub_questions=[],
-        agent_base_end_time=now_end,
-        agent_base_metrics=None,
-        log_messages=[
-            f"{now_start} -- Main - LLM handling: {answer}, Time taken: {now_end - now_start}"
-        ],
-    )
@@ -25,7 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROM
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 
 
-def entity_term_extraction_llm(
+def extract_entity_term(
     state: MainState, config: RunnableConfig
 ) -> EntityTermExtractionUpdate:
     now_start = datetime.now()
@@ -1,58 +0,0 @@
-from datetime import datetime
-from typing import cast
-
-from langchain_core.messages import HumanMessage
-from langchain_core.runnables import RunnableConfig
-
-from onyx.agents.agent_search.deep_search_a.main.operations import logger
-from onyx.agents.agent_search.deep_search_a.main.states import (
-    InitialAnswerBASEUpdate,
-)
-from onyx.agents.agent_search.deep_search_a.main.states import MainState
-from onyx.agents.agent_search.models import AgentSearchConfig
-from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
-    trim_prompt_piece,
-)
-from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT
-from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
-
-
-def generate_initial_base_search_only_answer(
-    state: MainState,
-    config: RunnableConfig,
-) -> InitialAnswerBASEUpdate:
-    now_start = datetime.now()
-
-    logger.info(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---")
-
-    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = agent_a_config.search_request.query
-    original_question_docs = state.all_original_question_documents
-
-    model = agent_a_config.fast_llm
-
-    doc_context = format_docs(original_question_docs)
-    doc_context = trim_prompt_piece(
-        model.config, doc_context, INITIAL_RAG_BASE_PROMPT + question
-    )
-
-    msg = [
-        HumanMessage(
-            content=INITIAL_RAG_BASE_PROMPT.format(
-                question=question,
-                context=doc_context,
-            )
-        )
-    ]
-
-    # Grader
-    response = model.invoke(msg)
-    answer = response.pretty_repr()
-
-    now_end = datetime.now()
-
-    logger.debug(
-        f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n"
-    )
-
-    return InitialAnswerBASEUpdate(initial_base_answer=answer)
@@ -16,7 +16,7 @@ from onyx.db.chat import log_agent_metrics
 from onyx.db.chat import log_agent_sub_question_results
 
 
-def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
+def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutput:
     now_start = datetime.now()
 
     logger.info(f"--------{now_start}--------LOGGING NODE---")
@@ -94,16 +94,6 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
         sub_question_answer_results=sub_question_answer_results,
     )
 
-    # if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None:
-    #     create_sub_answer(
-    #         db_session=db_session,
-    #         chat_session_id=chat_session_id,
-    #         primary_message_id=primary_message_id,
-    #         sub_question_id=sub_question_id,
-    #         answer=answer_str,
-    #     # )
-    # pass
-
     now_end = datetime.now()
     main_output = MainOutput(
         log_messages=[
@@ -17,7 +17,7 @@ from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
 from onyx.context.search.models import InferenceSection
 
 
-def agent_search_start(
+def start_agent_search(
     state: MainState, config: RunnableConfig
 ) -> ExploratorySearchUpdate:
     now_start = datetime.now()
@@ -11,7 +11,7 @@ from onyx.agents.agent_search.deep_search_a.main.states import (
 from onyx.agents.agent_search.models import AgentSearchConfig
 
 
-def refined_answer_decision(
+def validate_refined_answer(
     state: MainState, config: RunnableConfig
 ) -> RequireRefinedAnswerUpdate:
     now_start = datetime.now()
@@ -19,12 +19,8 @@ def refined_answer_decision(
     logger.info(f"--------{now_start}--------REFINED ANSWER DECISION---")
 
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    if "?" in agent_a_config.search_request.query:
-        decision = False
-    else:
-        decision = True
-
-    decision = True
+    decision = True  # TODO: just for current testing purposes
 
     now_end = datetime.now()
 
@@ -14,6 +14,5 @@ class QuestionAnswerResults(BaseModel):
     question: str
     answer: str
     quality: str
-    # expanded_retrieval_results: list[QueryResult]
     documents: list[InferenceSection]
     sub_question_retrieval_stats: AgentChunkStats