mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-09 12:47:13 +02:00
main nodes renaming
This commit is contained in:
@@ -53,8 +53,8 @@ def consolidate_sub_answers_graph_builder() -> StateGraph:
|
||||
# raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="agent_search_start",
|
||||
# end_key="entity_term_extraction_llm",
|
||||
# start_key="start_agent_search",
|
||||
# end_key="extract_entity_term",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
|
@@ -25,7 +25,7 @@ logger = setup_logger()
|
||||
|
||||
def route_initial_tool_choice(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> Literal["tool_call", "agent_search_start", "logging_node"]:
|
||||
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if state.tool_choice is not None:
|
||||
if (
|
||||
@@ -33,7 +33,7 @@ def route_initial_tool_choice(
|
||||
and agent_config.search_tool is not None
|
||||
and state.tool_choice.tool.name == agent_config.search_tool.name
|
||||
):
|
||||
return "agent_search_start"
|
||||
return "start_agent_search"
|
||||
else:
|
||||
return "tool_call"
|
||||
else:
|
||||
@@ -83,9 +83,9 @@ def parallelize_initial_sub_question_answering(
|
||||
# Define the function that determines whether to continue or not
|
||||
def continue_to_refined_answer_or_end(
|
||||
state: RequireRefinedAnswerUpdate,
|
||||
) -> Literal["refined_sub_question_creation", "logging_node"]:
|
||||
) -> Literal["create_refined_sub_questions", "logging_node"]:
|
||||
if state.require_refined_answer_eval:
|
||||
return "refined_sub_question_creation"
|
||||
return "create_refined_sub_questions"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
@@ -14,17 +14,14 @@ from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
route_initial_tool_choice,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_logging import (
|
||||
agent_logging,
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.compare_answers import (
|
||||
compare_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import (
|
||||
agent_search_start,
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.create_refined_sub_questions import (
|
||||
create_refined_sub_questions,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.answer_comparison import (
|
||||
answer_comparison,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import (
|
||||
entity_term_extraction_llm,
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entity_term import (
|
||||
extract_entity_term,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
|
||||
generate_refined_answer,
|
||||
@@ -32,11 +29,14 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer i
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_refined_answers import (
|
||||
ingest_refined_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision import (
|
||||
refined_answer_decision,
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.persist_agent_results import (
|
||||
persist_agent_results,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import (
|
||||
refined_sub_question_creation,
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.start_agent_search import (
|
||||
start_agent_search,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.validate_refined_answer import (
|
||||
validate_refined_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainInput
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
@@ -65,20 +65,6 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
input=MainInput,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="agent_path_decision",
|
||||
# action=agent_path_decision,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="agent_path_routing",
|
||||
# action=agent_path_routing,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="LLM",
|
||||
# action=direct_llm_handling,
|
||||
# )
|
||||
graph.add_node(
|
||||
node="prepare_tool_input",
|
||||
action=prepare_tool_input,
|
||||
@@ -97,42 +83,19 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
action=basic_use_tool_response,
|
||||
)
|
||||
graph.add_node(
|
||||
node="agent_search_start",
|
||||
action=agent_search_start,
|
||||
node="start_agent_search",
|
||||
action=start_agent_search,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="initial_sub_question_creation",
|
||||
# action=initial_sub_question_creation,
|
||||
# )
|
||||
|
||||
generate_initial_answer_subgraph = generate_initial_answer_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="generate_initial_answer_subgraph",
|
||||
action=generate_initial_answer_subgraph,
|
||||
)
|
||||
|
||||
# answer_query_subgraph = answer_query_graph_builder().compile()
|
||||
# graph.add_node(
|
||||
# node="answer_query_subgraph",
|
||||
# action=answer_query_subgraph,
|
||||
# )
|
||||
|
||||
# base_raw_search_subgraph = base_raw_search_graph_builder().compile()
|
||||
# graph.add_node(
|
||||
# node="base_raw_search_subgraph",
|
||||
# action=base_raw_search_subgraph,
|
||||
# )
|
||||
|
||||
# refined_answer_subgraph = refined_answers_graph_builder().compile()
|
||||
# graph.add_node(
|
||||
# node="refined_answer_subgraph",
|
||||
# action=refined_answer_subgraph,
|
||||
# )
|
||||
|
||||
graph.add_node(
|
||||
node="refined_sub_question_creation",
|
||||
action=refined_sub_question_creation,
|
||||
node="create_refined_sub_questions",
|
||||
action=create_refined_sub_questions,
|
||||
)
|
||||
|
||||
answer_refined_question = answer_refined_query_graph_builder().compile()
|
||||
@@ -151,70 +114,25 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
action=generate_refined_answer,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="check_refined_answer",
|
||||
# action=check_refined_answer,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="ingest_initial_retrieval",
|
||||
# action=ingest_initial_base_retrieval,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="retrieval_consolidation",
|
||||
# action=retrieval_consolidation,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="ingest_initial_sub_question_answers",
|
||||
# action=ingest_initial_sub_question_answers,
|
||||
# )
|
||||
# graph.add_node(
|
||||
# node="generate_initial_answer",
|
||||
# action=generate_initial_answer,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="initial_answer_quality_check",
|
||||
# action=initial_answer_quality_check,
|
||||
# )
|
||||
|
||||
graph.add_node(
|
||||
node="entity_term_extraction_llm",
|
||||
action=entity_term_extraction_llm,
|
||||
node="extract_entity_term",
|
||||
action=extract_entity_term,
|
||||
)
|
||||
graph.add_node(
|
||||
node="refined_answer_decision",
|
||||
action=refined_answer_decision,
|
||||
node="validate_refined_answer",
|
||||
action=validate_refined_answer,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_comparison",
|
||||
action=answer_comparison,
|
||||
node="compare_answers",
|
||||
action=compare_answers,
|
||||
)
|
||||
graph.add_node(
|
||||
node="logging_node",
|
||||
action=agent_logging,
|
||||
action=persist_agent_results,
|
||||
)
|
||||
# if test_mode:
|
||||
# graph.add_node(
|
||||
# node="generate_initial_base_answer",
|
||||
# action=generate_initial_base_answer,
|
||||
# )
|
||||
|
||||
### Add edges ###
|
||||
|
||||
# raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
# end_key="agent_path_decision",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="agent_path_decision",
|
||||
# end_key="agent_path_routing",
|
||||
# )
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(
|
||||
@@ -225,7 +143,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph.add_conditional_edges(
|
||||
"initial_tool_choice",
|
||||
route_initial_tool_choice,
|
||||
["tool_call", "agent_search_start", "logging_node"],
|
||||
["tool_call", "start_agent_search", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
@@ -238,96 +156,38 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
<<<<<<< HEAD
|
||||
start_key="agent_search_start",
|
||||
end_key="generate_initial_answer_subgraph",
|
||||
)
|
||||
# graph.add_edge(
|
||||
# start_key="agent_search_start",
|
||||
# end_key="base_raw_search_subgraph",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="agent_search_start",
|
||||
end_key="entity_term_extraction_llm",
|
||||
=======
|
||||
start_key="start_agent_search",
|
||||
end_key="initial_search_sq_subgraph",
|
||||
>>>>>>> ab2510c4d (main nodes renaming)
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="agent_search_start",
|
||||
# end_key="initial_sub_question_creation",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="base_raw_search_subgraph",
|
||||
# end_key="ingest_initial_retrieval",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
|
||||
# end_key="retrieval_consolidation",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="retrieval_consolidation",
|
||||
# end_key="generate_initial_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="LLM",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
# end_key="initial_sub_question_creation",
|
||||
# )
|
||||
|
||||
# graph.add_conditional_edges(
|
||||
# source="initial_sub_question_creation",
|
||||
# path=parallelize_initial_sub_question_answering,
|
||||
# path_map=["answer_query_subgraph"],
|
||||
# )
|
||||
# graph.add_edge(
|
||||
# start_key="answer_query_subgraph",
|
||||
# end_key="ingest_initial_sub_question_answers",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="retrieval_consolidation",
|
||||
# end_key="generate_initial_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_initial_answer",
|
||||
# end_key="entity_term_extraction_llm",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_initial_answer",
|
||||
# end_key="initial_answer_quality_check",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
|
||||
# end_key="refined_answer_decision",
|
||||
# )
|
||||
# graph.add_edge(
|
||||
# start_key="initial_answer_quality_check",
|
||||
# end_key="refined_answer_decision",
|
||||
# )
|
||||
graph.add_edge(
|
||||
start_key="start_agent_search",
|
||||
end_key="extract_entity_term",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
<<<<<<< HEAD
|
||||
start_key=["generate_initial_answer_subgraph", "entity_term_extraction_llm"],
|
||||
end_key="refined_answer_decision",
|
||||
=======
|
||||
start_key=["initial_search_sq_subgraph", "extract_entity_term"],
|
||||
end_key="validate_refined_answer",
|
||||
>>>>>>> ab2510c4d (main nodes renaming)
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="refined_answer_decision",
|
||||
source="validate_refined_answer",
|
||||
path=continue_to_refined_answer_or_end,
|
||||
path_map=["refined_sub_question_creation", "logging_node"],
|
||||
path_map=["create_refined_sub_questions", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="refined_sub_question_creation", # DONE
|
||||
source="create_refined_sub_questions", # DONE
|
||||
path=parallelize_refined_sub_question_answering,
|
||||
path_map=["answer_refined_question"],
|
||||
)
|
||||
@@ -341,23 +201,12 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
end_key="generate_refined_answer",
|
||||
)
|
||||
|
||||
# graph.add_conditional_edges(
|
||||
# source="refined_answer_decision",
|
||||
# path=continue_to_refined_answer_or_end,
|
||||
# path_map=["refined_answer_subgraph", END],
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="refined_answer_subgraph",
|
||||
# end_key="generate_refined_answer",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_refined_answer",
|
||||
end_key="answer_comparison",
|
||||
end_key="compare_answers",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_comparison",
|
||||
start_key="compare_answers",
|
||||
end_key="logging_node",
|
||||
)
|
||||
|
||||
@@ -366,16 +215,6 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_refined_answer",
|
||||
# end_key="check_refined_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="check_refined_answer",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
|
@@ -1,36 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import RoutingDecision
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
|
||||
|
||||
def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
|
||||
now_start = datetime.now()
|
||||
|
||||
cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
|
||||
# perform_initial_search_path_decision = (
|
||||
# agent_a_config.perform_initial_search_path_decision
|
||||
# )
|
||||
|
||||
logger.info(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")
|
||||
|
||||
routing = "agent_search"
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
|
||||
)
|
||||
return RoutingDecision(
|
||||
# Decide which route to take
|
||||
routing_decision=routing,
|
||||
log_messages=[
|
||||
f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
@@ -1,31 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
|
||||
|
||||
def agent_path_routing(
|
||||
state: MainState,
|
||||
) -> Command[Literal["agent_search_start", "LLM"]]:
|
||||
now_start = datetime.now()
|
||||
routing = state.routing_decision if hasattr(state, "routing") else "agent_search"
|
||||
|
||||
if routing == "agent_search":
|
||||
agent_path = "agent_search_start"
|
||||
else:
|
||||
agent_path = "LLM"
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
return Command(
|
||||
# state update
|
||||
update={
|
||||
"log_messages": [
|
||||
f"{now_start} -- Main - Path routing: {agent_path}, Time taken: {now_end - now_start}"
|
||||
]
|
||||
},
|
||||
# control flow
|
||||
goto=agent_path,
|
||||
)
|
@@ -13,7 +13,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISO
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
|
||||
|
||||
def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerComparison:
|
||||
def compare_answers(state: MainState, config: RunnableConfig) -> AnswerComparison:
|
||||
now_start = datetime.now()
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
@@ -23,11 +23,11 @@ def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerCompari
|
||||
|
||||
logger.info(f"--------{now_start}--------ANSWER COMPARISON STARTED--")
|
||||
|
||||
answer_comparison_prompt = ANSWER_COMPARISON_PROMPT.format(
|
||||
compare_answers_prompt = ANSWER_COMPARISON_PROMPT.format(
|
||||
question=question, initial_answer=initial_answer, refined_answer=refined_answer
|
||||
)
|
||||
|
||||
msg = [HumanMessage(content=answer_comparison_prompt)]
|
||||
msg = [HumanMessage(content=compare_answers_prompt)]
|
||||
|
||||
# Get the rewritten queries in a defined format
|
||||
model = agent_a_config.fast_llm
|
@@ -32,7 +32,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
|
||||
|
||||
def refined_sub_question_creation(
|
||||
def create_refined_sub_questions(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> FollowUpSubQuestionsUpdate:
|
||||
""" """
|
@@ -1,83 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
|
||||
|
||||
def direct_llm_handling(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> InitialAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona_contextualialized_prompt = get_persona_agent_prompt_expressions(
|
||||
agent_a_config.search_request.persona
|
||||
).contextualized_prompt
|
||||
|
||||
logger.info(f"--------{now_start}--------LLM HANDLING START---")
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=DIRECT_LLM_PROMPT.format(
|
||||
persona_specification=persona_contextualialized_prompt,
|
||||
question=question,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
dispatch_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_end}--{now_end - now_start}--------LLM HANDLING END---")
|
||||
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=answer,
|
||||
initial_agent_stats=None,
|
||||
generated_sub_questions=[],
|
||||
agent_base_end_time=now_end,
|
||||
agent_base_metrics=None,
|
||||
log_messages=[
|
||||
f"{now_start} -- Main - LLM handling: {answer}, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
@@ -25,7 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROM
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
|
||||
|
||||
def entity_term_extraction_llm(
|
||||
def extract_entity_term(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> EntityTermExtractionUpdate:
|
||||
now_start = datetime.now()
|
@@ -1,58 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
InitialAnswerBASEUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
|
||||
|
||||
def generate_initial_base_search_only_answer(
|
||||
state: MainState,
|
||||
config: RunnableConfig,
|
||||
) -> InitialAnswerBASEUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
original_question_docs = state.all_original_question_documents
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
doc_context = format_docs(original_question_docs)
|
||||
doc_context = trim_prompt_piece(
|
||||
model.config, doc_context, INITIAL_RAG_BASE_PROMPT + question
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=INITIAL_RAG_BASE_PROMPT.format(
|
||||
question=question,
|
||||
context=doc_context,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
response = model.invoke(msg)
|
||||
answer = response.pretty_repr()
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n"
|
||||
)
|
||||
|
||||
return InitialAnswerBASEUpdate(initial_base_answer=answer)
|
@@ -16,7 +16,7 @@ from onyx.db.chat import log_agent_metrics
|
||||
from onyx.db.chat import log_agent_sub_question_results
|
||||
|
||||
|
||||
def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
|
||||
def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutput:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------LOGGING NODE---")
|
||||
@@ -94,16 +94,6 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
|
||||
sub_question_answer_results=sub_question_answer_results,
|
||||
)
|
||||
|
||||
# if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None:
|
||||
# create_sub_answer(
|
||||
# db_session=db_session,
|
||||
# chat_session_id=chat_session_id,
|
||||
# primary_message_id=primary_message_id,
|
||||
# sub_question_id=sub_question_id,
|
||||
# answer=answer_str,
|
||||
# # )
|
||||
# pass
|
||||
|
||||
now_end = datetime.now()
|
||||
main_output = MainOutput(
|
||||
log_messages=[
|
@@ -17,7 +17,7 @@ from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def agent_search_start(
|
||||
def start_agent_search(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> ExploratorySearchUpdate:
|
||||
now_start = datetime.now()
|
@@ -11,7 +11,7 @@ from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
|
||||
|
||||
def refined_answer_decision(
|
||||
def validate_refined_answer(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> RequireRefinedAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
@@ -19,12 +19,8 @@ def refined_answer_decision(
|
||||
logger.info(f"--------{now_start}--------REFINED ANSWER DECISION---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if "?" in agent_a_config.search_request.query:
|
||||
decision = False
|
||||
else:
|
||||
decision = True
|
||||
|
||||
decision = True
|
||||
decision = True # TODO: just for current testing purposes
|
||||
|
||||
now_end = datetime.now()
|
||||
|
@@ -14,6 +14,5 @@ class QuestionAnswerResults(BaseModel):
|
||||
question: str
|
||||
answer: str
|
||||
quality: str
|
||||
# expanded_retrieval_results: list[QueryResult]
|
||||
documents: list[InferenceSection]
|
||||
sub_question_retrieval_stats: AgentChunkStats
|
||||
|
Reference in New Issue
Block a user