main nodes renaming

Evan Lohn
2025-01-30 13:10:57 -08:00
parent b0c3098693
commit 2b8cd63b34
14 changed files with 59 additions and 443 deletions

View File

@@ -53,8 +53,8 @@ def consolidate_sub_answers_graph_builder() -> StateGraph:
     # graph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
     # graph.add_edge(
-    #     start_key="agent_search_start",
-    #     end_key="entity_term_extraction_llm",
+    #     start_key="start_agent_search",
+    #     end_key="extract_entity_term",
     # )
     graph.add_edge(

View File

@@ -25,7 +25,7 @@ logger = setup_logger()
 def route_initial_tool_choice(
     state: MainState, config: RunnableConfig
-) -> Literal["tool_call", "agent_search_start", "logging_node"]:
+) -> Literal["tool_call", "start_agent_search", "logging_node"]:
     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     if state.tool_choice is not None:
         if (
@@ -33,7 +33,7 @@ def route_initial_tool_choice(
             and agent_config.search_tool is not None
             and state.tool_choice.tool.name == agent_config.search_tool.name
         ):
-            return "agent_search_start"
+            return "start_agent_search"
         else:
             return "tool_call"
     else:
@@ -83,9 +83,9 @@ def parallelize_initial_sub_question_answering(
 # Define the function that determines whether to continue or not
 def continue_to_refined_answer_or_end(
     state: RequireRefinedAnswerUpdate,
-) -> Literal["refined_sub_question_creation", "logging_node"]:
+) -> Literal["create_refined_sub_questions", "logging_node"]:
     if state.require_refined_answer_eval:
-        return "refined_sub_question_creation"
+        return "create_refined_sub_questions"
     else:
         return "logging_node"

View File

@@ -14,17 +14,14 @@ from onyx.agents.agent_search.deep_search_a.main.edges import (
 from onyx.agents.agent_search.deep_search_a.main.edges import (
     route_initial_tool_choice,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.agent_logging import (
-    agent_logging,
+from onyx.agents.agent_search.deep_search_a.main.nodes.compare_answers import (
+    compare_answers,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import (
-    agent_search_start,
+from onyx.agents.agent_search.deep_search_a.main.nodes.create_refined_sub_questions import (
+    create_refined_sub_questions,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.answer_comparison import (
-    answer_comparison,
-)
-from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import (
-    entity_term_extraction_llm,
+from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entity_term import (
+    extract_entity_term,
 )
 from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
     generate_refined_answer,
@@ -32,11 +29,14 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
 from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_refined_answers import (
     ingest_refined_answers,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision import (
-    refined_answer_decision,
+from onyx.agents.agent_search.deep_search_a.main.nodes.persist_agent_results import (
+    persist_agent_results,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import (
-    refined_sub_question_creation,
+from onyx.agents.agent_search.deep_search_a.main.nodes.start_agent_search import (
+    start_agent_search,
 )
+from onyx.agents.agent_search.deep_search_a.main.nodes.validate_refined_answer import (
+    validate_refined_answer,
+)
 from onyx.agents.agent_search.deep_search_a.main.states import MainInput
 from onyx.agents.agent_search.deep_search_a.main.states import MainState
@@ -65,20 +65,6 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         input=MainInput,
     )
-    # graph.add_node(
-    #     node="agent_path_decision",
-    #     action=agent_path_decision,
-    # )
-    # graph.add_node(
-    #     node="agent_path_routing",
-    #     action=agent_path_routing,
-    # )
-    # graph.add_node(
-    #     node="LLM",
-    #     action=direct_llm_handling,
-    # )
     graph.add_node(
         node="prepare_tool_input",
         action=prepare_tool_input,
@@ -97,42 +83,19 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         action=basic_use_tool_response,
     )
     graph.add_node(
-        node="agent_search_start",
-        action=agent_search_start,
+        node="start_agent_search",
+        action=start_agent_search,
     )
-    # graph.add_node(
-    #     node="initial_sub_question_creation",
-    #     action=initial_sub_question_creation,
-    # )
     generate_initial_answer_subgraph = generate_initial_answer_graph_builder().compile()
     graph.add_node(
         node="generate_initial_answer_subgraph",
         action=generate_initial_answer_subgraph,
     )
-    # answer_query_subgraph = answer_query_graph_builder().compile()
-    # graph.add_node(
-    #     node="answer_query_subgraph",
-    #     action=answer_query_subgraph,
-    # )
-    # base_raw_search_subgraph = base_raw_search_graph_builder().compile()
-    # graph.add_node(
-    #     node="base_raw_search_subgraph",
-    #     action=base_raw_search_subgraph,
-    # )
-    # refined_answer_subgraph = refined_answers_graph_builder().compile()
-    # graph.add_node(
-    #     node="refined_answer_subgraph",
-    #     action=refined_answer_subgraph,
-    # )
     graph.add_node(
-        node="refined_sub_question_creation",
-        action=refined_sub_question_creation,
+        node="create_refined_sub_questions",
+        action=create_refined_sub_questions,
     )
     answer_refined_question = answer_refined_query_graph_builder().compile()
@@ -151,70 +114,25 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         action=generate_refined_answer,
     )
-    # graph.add_node(
-    #     node="check_refined_answer",
-    #     action=check_refined_answer,
-    # )
-    # graph.add_node(
-    #     node="ingest_initial_retrieval",
-    #     action=ingest_initial_base_retrieval,
-    # )
-    # graph.add_node(
-    #     node="retrieval_consolidation",
-    #     action=retrieval_consolidation,
-    # )
-    # graph.add_node(
-    #     node="ingest_initial_sub_question_answers",
-    #     action=ingest_initial_sub_question_answers,
-    # )
-    # graph.add_node(
-    #     node="generate_initial_answer",
-    #     action=generate_initial_answer,
-    # )
-    # graph.add_node(
-    #     node="initial_answer_quality_check",
-    #     action=initial_answer_quality_check,
-    # )
     graph.add_node(
-        node="entity_term_extraction_llm",
-        action=entity_term_extraction_llm,
+        node="extract_entity_term",
+        action=extract_entity_term,
     )
     graph.add_node(
-        node="refined_answer_decision",
-        action=refined_answer_decision,
+        node="validate_refined_answer",
+        action=validate_refined_answer,
     )
     graph.add_node(
-        node="answer_comparison",
-        action=answer_comparison,
+        node="compare_answers",
+        action=compare_answers,
     )
     graph.add_node(
         node="logging_node",
-        action=agent_logging,
+        action=persist_agent_results,
     )
-    # if test_mode:
-    #     graph.add_node(
-    #         node="generate_initial_base_answer",
-    #         action=generate_initial_base_answer,
-    #     )
     ### Add edges ###
-    # raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
-    # graph.add_edge(
-    #     start_key=START,
-    #     end_key="agent_path_decision",
-    # )
-    # graph.add_edge(
-    #     start_key="agent_path_decision",
-    #     end_key="agent_path_routing",
-    # )
     graph.add_edge(start_key=START, end_key="prepare_tool_input")
     graph.add_edge(
@@ -225,7 +143,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     graph.add_conditional_edges(
         "initial_tool_choice",
         route_initial_tool_choice,
-        ["tool_call", "agent_search_start", "logging_node"],
+        ["tool_call", "start_agent_search", "logging_node"],
     )
     graph.add_edge(
@@ -238,96 +156,38 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     )
     graph.add_edge(
-        start_key="agent_search_start",
+        start_key="start_agent_search",
         end_key="generate_initial_answer_subgraph",
     )
-    # graph.add_edge(
-    #     start_key="agent_search_start",
-    #     end_key="base_raw_search_subgraph",
-    # )
-    # graph.add_edge(
-    #     start_key="agent_search_start",
-    #     end_key="initial_sub_question_creation",
-    # )
-    # graph.add_edge(
-    #     start_key="base_raw_search_subgraph",
-    #     end_key="ingest_initial_retrieval",
-    # )
-    # graph.add_edge(
-    #     start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
-    #     end_key="retrieval_consolidation",
-    # )
-    # graph.add_edge(
-    #     start_key="retrieval_consolidation",
-    #     end_key="generate_initial_answer",
-    # )
-    # graph.add_edge(
-    #     start_key="LLM",
-    #     end_key=END,
-    # )
-    # graph.add_edge(
-    #     start_key=START,
-    #     end_key="initial_sub_question_creation",
-    # )
-    # graph.add_conditional_edges(
-    #     source="initial_sub_question_creation",
-    #     path=parallelize_initial_sub_question_answering,
-    #     path_map=["answer_query_subgraph"],
-    # )
-    # graph.add_edge(
-    #     start_key="answer_query_subgraph",
-    #     end_key="ingest_initial_sub_question_answers",
-    # )
-    # graph.add_edge(
-    #     start_key="retrieval_consolidation",
-    #     end_key="generate_initial_answer",
-    # )
-    # graph.add_edge(
-    #     start_key="generate_initial_answer",
-    #     end_key="entity_term_extraction_llm",
-    # )
-    # graph.add_edge(
-    #     start_key="generate_initial_answer",
-    #     end_key="initial_answer_quality_check",
-    # )
-    # graph.add_edge(
-    #     start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
-    #     end_key="refined_answer_decision",
-    # )
-    # graph.add_edge(
-    #     start_key="initial_answer_quality_check",
-    #     end_key="refined_answer_decision",
-    # )
     graph.add_edge(
-        start_key="agent_search_start",
-        end_key="entity_term_extraction_llm",
+        start_key="start_agent_search",
+        end_key="extract_entity_term",
     )
     graph.add_edge(
-        start_key=["generate_initial_answer_subgraph", "entity_term_extraction_llm"],
-        end_key="refined_answer_decision",
+        start_key=["generate_initial_answer_subgraph", "extract_entity_term"],
+        end_key="validate_refined_answer",
     )
     graph.add_conditional_edges(
-        source="refined_answer_decision",
+        source="validate_refined_answer",
         path=continue_to_refined_answer_or_end,
-        path_map=["refined_sub_question_creation", "logging_node"],
+        path_map=["create_refined_sub_questions", "logging_node"],
     )
     graph.add_conditional_edges(
-        source="refined_sub_question_creation",  # DONE
+        source="create_refined_sub_questions",  # DONE
         path=parallelize_refined_sub_question_answering,
         path_map=["answer_refined_question"],
     )
@@ -341,23 +201,12 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
end_key="generate_refined_answer",
)
# graph.add_conditional_edges(
# source="refined_answer_decision",
# path=continue_to_refined_answer_or_end,
# path_map=["refined_answer_subgraph", END],
# )
# graph.add_edge(
# start_key="refined_answer_subgraph",
# end_key="generate_refined_answer",
# )
graph.add_edge(
start_key="generate_refined_answer",
end_key="answer_comparison",
end_key="compare_answers",
)
graph.add_edge(
start_key="answer_comparison",
start_key="compare_answers",
end_key="logging_node",
)
@@ -366,16 +215,6 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         end_key=END,
     )
-    # graph.add_edge(
-    #     start_key="generate_refined_answer",
-    #     end_key="check_refined_answer",
-    # )
-    # graph.add_edge(
-    #     start_key="check_refined_answer",
-    #     end_key=END,
-    # )
     return graph
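
For reference, the parallelize_* paths wired into add_conditional_edges above fan a single node out into parallel runs of a compiled subgraph. A hedged sketch of that mechanism using LangGraph's Send; the state shape and field names are assumptions, not the repository's:

from langgraph.types import Send

def parallelize_refined_sub_question_answering(state: dict) -> list[Send]:
    # One Send per sub-question: each spawns the "answer_refined_question"
    # subgraph with its own isolated input payload.
    return [
        Send("answer_refined_question", {"question": q, "question_id": i})
        for i, q in enumerate(state["refined_sub_questions"])
    ]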

View File

@@ -1,36 +0,0 @@
-from datetime import datetime
-from typing import cast
-
-from langchain_core.runnables import RunnableConfig
-
-from onyx.agents.agent_search.deep_search_a.main.operations import logger
-from onyx.agents.agent_search.deep_search_a.main.states import MainState
-from onyx.agents.agent_search.deep_search_a.main.states import RoutingDecision
-from onyx.agents.agent_search.models import AgentSearchConfig
-
-
-def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
-    now_start = datetime.now()
-
-    cast(AgentSearchConfig, config["metadata"]["config"])
-    # perform_initial_search_path_decision = (
-    #     agent_a_config.perform_initial_search_path_decision
-    # )
-
-    logger.info(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")
-
-    routing = "agent_search"
-
-    now_end = datetime.now()
-    logger.debug(
-        f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
-    )
-
-    return RoutingDecision(
-        # Decide which route to take
-        routing_decision=routing,
-        log_messages=[
-            f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
-        ],
-    )

View File

@@ -1,31 +0,0 @@
-from datetime import datetime
-from typing import Literal
-
-from langgraph.types import Command
-
-from onyx.agents.agent_search.deep_search_a.main.states import MainState
-
-
-def agent_path_routing(
-    state: MainState,
-) -> Command[Literal["agent_search_start", "LLM"]]:
-    now_start = datetime.now()
-
-    routing = state.routing_decision if hasattr(state, "routing") else "agent_search"
-
-    if routing == "agent_search":
-        agent_path = "agent_search_start"
-    else:
-        agent_path = "LLM"
-
-    now_end = datetime.now()
-
-    return Command(
-        # state update
-        update={
-            "log_messages": [
-                f"{now_start} -- Main - Path routing: {agent_path}, Time taken: {now_end - now_start}"
-            ]
-        },
-        # control flow
-        goto=agent_path,
-    )
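
The deleted router relied on LangGraph's Command, which bundles a state update with the goto control-flow decision in one return value, replacing a separate conditional edge. A minimal sketch of the idiom (assumed plain-dict state; the real module used MainState):

from typing import Literal

from langgraph.types import Command

def route_path(state: dict) -> Command[Literal["agent_search_start", "LLM"]]:
    use_search = state.get("routing_decision") == "agent_search"
    target = "agent_search_start" if use_search else "LLM"
    return Command(
        update={"log_messages": [f"routing to {target}"]},  # state update
        goto=target,  # control flow
    )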

View File

@@ -13,7 +13,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
 from onyx.chat.models import RefinedAnswerImprovement
 
 
-def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerComparison:
+def compare_answers(state: MainState, config: RunnableConfig) -> AnswerComparison:
     now_start = datetime.now()
 
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
@@ -23,11 +23,11 @@ def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerComparison:
     logger.info(f"--------{now_start}--------ANSWER COMPARISON STARTED--")
 
-    answer_comparison_prompt = ANSWER_COMPARISON_PROMPT.format(
+    compare_answers_prompt = ANSWER_COMPARISON_PROMPT.format(
         question=question, initial_answer=initial_answer, refined_answer=refined_answer
     )
 
-    msg = [HumanMessage(content=answer_comparison_prompt)]
+    msg = [HumanMessage(content=compare_answers_prompt)]
 
     # Get the rewritten queries in a defined format
     model = agent_a_config.fast_llm

View File

@@ -32,7 +32,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
 from onyx.tools.models import ToolCallKickoff
 
 
-def refined_sub_question_creation(
+def create_refined_sub_questions(
     state: MainState, config: RunnableConfig
 ) -> FollowUpSubQuestionsUpdate:
     """Create refined sub-questions for the second, refined answer pass."""

View File

@@ -1,83 +0,0 @@
-from datetime import datetime
-from typing import Any
-from typing import cast
-
-from langchain_core.callbacks.manager import dispatch_custom_event
-from langchain_core.messages import HumanMessage
-from langchain_core.messages import merge_content
-from langchain_core.runnables import RunnableConfig
-
-from onyx.agents.agent_search.deep_search_a.main.operations import logger
-from onyx.agents.agent_search.deep_search_a.main.states import (
-    InitialAnswerUpdate,
-)
-from onyx.agents.agent_search.deep_search_a.main.states import MainState
-from onyx.agents.agent_search.models import AgentSearchConfig
-from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
-from onyx.agents.agent_search.shared_graph_utils.utils import (
-    get_persona_agent_prompt_expressions,
-)
-from onyx.chat.models import AgentAnswerPiece
-
-
-def direct_llm_handling(
-    state: MainState, config: RunnableConfig
-) -> InitialAnswerUpdate:
-    now_start = datetime.now()
-
-    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = agent_a_config.search_request.query
-    persona_contextualialized_prompt = get_persona_agent_prompt_expressions(
-        agent_a_config.search_request.persona
-    ).contextualized_prompt
-
-    logger.info(f"--------{now_start}--------LLM HANDLING START---")
-
-    model = agent_a_config.fast_llm
-
-    msg = [
-        HumanMessage(
-            content=DIRECT_LLM_PROMPT.format(
-                persona_specification=persona_contextualialized_prompt,
-                question=question,
-            )
-        )
-    ]
-
-    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
-    for message in model.stream(msg):
-        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-        content = message.content
-        if not isinstance(content, str):
-            raise ValueError(
-                f"Expected content to be a string, but got {type(content)}"
-            )
-
-        dispatch_custom_event(
-            "initial_agent_answer",
-            AgentAnswerPiece(
-                answer_piece=content,
-                level=0,
-                level_question_nr=0,
-                answer_type="agent_level_answer",
-            ),
-        )
-        streamed_tokens.append(content)
-
-    response = merge_content(*streamed_tokens)
-    answer = cast(str, response)
-
-    now_end = datetime.now()
-    logger.info(f"--------{now_end}--{now_end - now_start}--------LLM HANDLING END---")
-
-    return InitialAnswerUpdate(
-        initial_answer=answer,
-        initial_agent_stats=None,
-        generated_sub_questions=[],
-        agent_base_end_time=now_end,
-        agent_base_metrics=None,
-        log_messages=[
-            f"{now_start} -- Main - LLM handling: {answer}, Time taken: {now_end - now_start}"
-        ],
-    )
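
The heart of the removed handler is a stream-dispatch-merge idiom: stream the model, surface each token as a custom event for the UI, and merge the accumulated pieces into the final answer. A condensed, self-contained sketch; the plain-dict event payload stands in for the project's AgentAnswerPiece:

from typing import Any

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, merge_content

def stream_and_collect(model: BaseChatModel, prompt: str) -> str:
    pieces: list[str | list[str | dict[str, Any]]] = [""]
    for chunk in model.stream([HumanMessage(content=prompt)]):
        content = chunk.content  # str for text-only models
        # Requires a LangChain runnable context (e.g. inside a graph node).
        dispatch_custom_event("answer_piece", {"answer_piece": content})
        pieces.append(content)
    return str(merge_content(*pieces))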

View File

@@ -25,7 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROM
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 
 
-def entity_term_extraction_llm(
+def extract_entity_term(
     state: MainState, config: RunnableConfig
 ) -> EntityTermExtractionUpdate:
     now_start = datetime.now()

View File

@@ -1,58 +0,0 @@
-from datetime import datetime
-from typing import cast
-
-from langchain_core.messages import HumanMessage
-from langchain_core.runnables import RunnableConfig
-
-from onyx.agents.agent_search.deep_search_a.main.operations import logger
-from onyx.agents.agent_search.deep_search_a.main.states import (
-    InitialAnswerBASEUpdate,
-)
-from onyx.agents.agent_search.deep_search_a.main.states import MainState
-from onyx.agents.agent_search.models import AgentSearchConfig
-from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
-    trim_prompt_piece,
-)
-from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT
-from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
-
-
-def generate_initial_base_search_only_answer(
-    state: MainState,
-    config: RunnableConfig,
-) -> InitialAnswerBASEUpdate:
-    now_start = datetime.now()
-
-    logger.info(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---")
-
-    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = agent_a_config.search_request.query
-    original_question_docs = state.all_original_question_documents
-
-    model = agent_a_config.fast_llm
-
-    doc_context = format_docs(original_question_docs)
-    doc_context = trim_prompt_piece(
-        model.config, doc_context, INITIAL_RAG_BASE_PROMPT + question
-    )
-
-    msg = [
-        HumanMessage(
-            content=INITIAL_RAG_BASE_PROMPT.format(
-                question=question,
-                context=doc_context,
-            )
-        )
-    ]
-
-    # Grader
-    response = model.invoke(msg)
-    answer = response.pretty_repr()
-
-    now_end = datetime.now()
-    logger.debug(
-        f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n"
-    )
-
-    return InitialAnswerBASEUpdate(initial_base_answer=answer)
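
The removed node also illustrates the context-budgeting step: trim the formatted docs so the fixed prompt plus question still fits the model before formatting the final message. trim_prompt_piece is token-aware; the stand-in below is a crude character-budget sketch (hypothetical helper, not the shared_graph_utils implementation):

def trim_to_budget(context: str, fixed_prompt: str, max_chars: int) -> str:
    # Reserve room for everything that is not document context,
    # then truncate the context to whatever budget remains.
    budget = max(0, max_chars - len(fixed_prompt))
    return context[:budget]

# e.g. doc_context = trim_to_budget(doc_context, prompt_template + question, 16_000)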

View File

@@ -16,7 +16,7 @@ from onyx.db.chat import log_agent_metrics
 from onyx.db.chat import log_agent_sub_question_results
 
 
-def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
+def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutput:
     now_start = datetime.now()
 
     logger.info(f"--------{now_start}--------LOGGING NODE---")
@@ -94,16 +94,6 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
         sub_question_answer_results=sub_question_answer_results,
     )
 
-    # if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None:
-    #     create_sub_answer(
-    #         db_session=db_session,
-    #         chat_session_id=chat_session_id,
-    #         primary_message_id=primary_message_id,
-    #         sub_question_id=sub_question_id,
-    #         answer=answer_str,
-    #     # )
-    #     pass
     now_end = datetime.now()
 
     main_output = MainOutput(
         log_messages=[

View File

@@ -17,7 +17,7 @@ from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
 from onyx.context.search.models import InferenceSection
 
 
-def agent_search_start(
+def start_agent_search(
     state: MainState, config: RunnableConfig
 ) -> ExploratorySearchUpdate:
     now_start = datetime.now()

View File

@@ -11,7 +11,7 @@ from onyx.agents.agent_search.deep_search_a.main.states import (
 from onyx.agents.agent_search.models import AgentSearchConfig
 
 
-def refined_answer_decision(
+def validate_refined_answer(
     state: MainState, config: RunnableConfig
 ) -> RequireRefinedAnswerUpdate:
     now_start = datetime.now()
@@ -19,12 +19,8 @@ def refined_answer_decision(
     logger.info(f"--------{now_start}--------REFINED ANSWER DECISION---")
 
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    if "?" in agent_a_config.search_request.query:
-        decision = False
-    else:
-        decision = True
-
-    decision = True
+    decision = True  # TODO: just for current testing purposes
 
     now_end = datetime.now()

View File

@@ -14,6 +14,5 @@ class QuestionAnswerResults(BaseModel):
     question: str
     answer: str
     quality: str
-    # expanded_retrieval_results: list[QueryResult]
     documents: list[InferenceSection]
     sub_question_retrieval_stats: AgentChunkStats