# danswer/backend/onyx/agents/agent_search/deep_research/graph_builder.py
import random
from datetime import datetime
from json import JSONDecodeError
from pprint import pprint
from typing import cast
from langchain.globals import set_debug
from langchain.globals import set_verbose
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from langgraph.types import Send
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_research.configuration import (
DeepPlannerConfiguration,
)
from onyx.agents.agent_search.deep_research.configuration import (
DeepResearchConfiguration,
)
from onyx.agents.agent_search.deep_research.prompts import answer_instructions
from onyx.agents.agent_search.deep_research.prompts import COMPANY_CONTEXT
from onyx.agents.agent_search.deep_research.prompts import COMPANY_NAME
from onyx.agents.agent_search.deep_research.prompts import get_current_date
from onyx.agents.agent_search.deep_research.prompts import planner_prompt
from onyx.agents.agent_search.deep_research.prompts import query_writer_instructions
from onyx.agents.agent_search.deep_research.prompts import reflection_instructions
from onyx.agents.agent_search.deep_research.prompts import replanner_prompt
from onyx.agents.agent_search.deep_research.prompts import task_completion_prompt
from onyx.agents.agent_search.deep_research.prompts import task_to_query_prompt
from onyx.agents.agent_search.deep_research.states import OnyxSearchState
from onyx.agents.agent_search.deep_research.states import OverallState
from onyx.agents.agent_search.deep_research.states import PlanExecute
from onyx.agents.agent_search.deep_research.states import QueryGenerationState
from onyx.agents.agent_search.deep_research.states import ReflectionState
from onyx.agents.agent_search.deep_research.tools_and_schemas import Act
from onyx.agents.agent_search.deep_research.tools_and_schemas import json_to_pydantic
from onyx.agents.agent_search.deep_research.tools_and_schemas import Plan
from onyx.agents.agent_search.deep_research.tools_and_schemas import Reflection
from onyx.agents.agent_search.deep_research.tools_and_schemas import Response
from onyx.agents.agent_search.deep_research.tools_and_schemas import SearchQueryList
from onyx.agents.agent_search.deep_research.utils import collate_messages
from onyx.agents.agent_search.deep_research.utils import get_research_topic
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Persona
from onyx.llm.factory import get_default_llms
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
IS_DEBUG = False
IS_VERBOSE = False
MAX_RETRIEVED_DOCS = 10
def mock_do_onyx_search(query: str) -> dict[str, str]:
random_answers = [
"Onyx is a startup founded by Yuhong Sun and Chris Weaver.",
"Chris Weaver was born in the country of Wakanda",
"Yuhong Sun is the CEO of Onyx",
"Yuhong Sun was born in the country of Valhalla",
]
return {"text": random.choice(random_answers)}
def do_onyx_search(query: str) -> dict[str, str]:
"""
Perform a search using the SearchTool and return the results.
Args:
query: The search query string
Returns:
Dictionary containing the search results text
"""
retrieved_docs: list[InferenceSection] = []
primary_llm, fast_llm = get_default_llms()
try:
with get_session_with_current_tenant() as db_session:
# Create a default persona with basic settings
default_persona = Persona(
name="default",
chunks_above=2,
chunks_below=2,
description="Default persona for search",
)
search_tool = SearchTool(
db_session=db_session,
user=None,
persona=default_persona,
retrieval_options=None,
prompt_config=PromptConfig(
system_prompt="You are a helpful assistant.",
task_prompt="Answer the user's question based on the provided context.",
datetime_aware=True,
include_citations=True,
),
llm=primary_llm,
fast_llm=fast_llm,
pruning_config=DocumentPruningConfig(),
answer_style_config=AnswerStyleConfig(
citation_config=CitationConfig(
include_citations=True, citation_style="inline"
)
),
evaluation_type=LLMEvaluationType.SKIP,
)
for tool_response in search_tool.run(
query=query,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=None,
skip_query_analysis=False,
),
):
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
break
# Combine the retrieved documents into a single text
combined_text = "\n\n".join(
[doc.combined_content for doc in retrieved_docs[:MAX_RETRIEVED_DOCS]]
)
return {"text": combined_text}
except Exception as e:
logger.error(f"Error in do_onyx_search: {e}")
return {"text": "Error in search, no results returned"}
def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
"""
    LangGraph node that generates search queries based on the user's question.
    Uses an LLM to create optimized search queries for Onyx research based on
    the user's question.
Args:
state: Current graph state containing the User's question
config: Configuration for the runnable
Returns:
        Dictionary with state update, including query_list containing the generated queries
"""
configurable = DeepResearchConfiguration.from_runnable_config(config)
# check for custom initial search query count
if state.get("initial_search_query_count") is None:
state["initial_search_query_count"] = configurable.number_of_initial_queries
primary_llm, fast_llm = get_default_llms()
llm = primary_llm if configurable.query_generator_model == "primary" else fast_llm
# Format the prompt
current_date = get_current_date()
formatted_prompt = query_writer_instructions.format(
current_date=current_date,
research_topic=get_research_topic(state["messages"]),
number_queries=state["initial_search_query_count"],
company_name=COMPANY_NAME,
company_context=COMPANY_CONTEXT,
user_context=collate_messages(state["messages"]),
)
# Get the LLM response and extract its content
llm_response = llm.invoke(formatted_prompt)
try:
result = json_to_pydantic(llm_response.content, SearchQueryList)
return {"query_list": result.query}
except JSONDecodeError:
return {"query_list": [llm_response.content]}
def continue_to_onyx_research(state: QueryGenerationState) -> list[Send]:
    """
    LangGraph routing function that sends the search queries to the onyx research node.
    This is used to spawn one onyx_research node per search query.
    """
return [
Send("onyx_research", {"search_query": search_query, "id": int(idx)})
for idx, search_query in enumerate(state["query_list"])
]
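# For example, a query_list of ["q1", "q2"] fans out into two parallel
# onyx_research runs via Send, each with its own {"search_query", "id"} payload.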
def onyx_research(state: OnyxSearchState, config: RunnableConfig) -> OverallState:
"""LangGraph node that performs onyx research using onyx search interface.
    Executes an Onyx search in combination with an LLM.
Args:
state: Current graph state containing the search query and research loop count
config: Configuration for the runnable, including any search API settings or llm settings
Returns:
        Dictionary with state update, including sources_gathered, search_query, and onyx_research_result
"""
# TODO: think about whether we should use any filtered returned results in addition to the final text answer
response = do_onyx_search(state["search_query"])
text = response["text"]
sources_gathered = []
return {
"sources_gathered": sources_gathered,
"search_query": [state["search_query"]],
"onyx_research_result": [text],
}
def get_combined_summaries(state: OverallState, llm=None) -> str:
    """Join the Onyx research results and trim them to a fixed token budget."""
if llm is None:
_, llm = get_default_llms()
# Calculate tokens and trim if needed
tokenizer = get_tokenizer(
provider_type=llm.config.model_provider, model_name=llm.config.model_name
)
# Combine summaries and check token count
combined_summaries = "\n\n---\n\n".join(state["onyx_research_result"])
combined_summaries = tokenizer_trim_content(
content=combined_summaries, desired_length=10000, tokenizer=tokenizer
)
return combined_summaries
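# Note: 10000 is a hard-coded token budget for the combined summaries; anything
# beyond that budget is trimmed by tokenizer_trim_content before prompting.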
def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
"""LangGraph node that identifies knowledge gaps and generates potential follow-up queries.
Analyzes the current summary to identify areas for further research and generates
potential follow-up queries. Uses structured output to extract
the follow-up query in JSON format.
Args:
state: Current graph state containing the running summary and research topic
config: Configuration for the runnable, including LLM settings
Returns:
        Dictionary with state update, including is_sufficient, knowledge_gap,
        follow_up_queries, research_loop_count, and number_of_ran_queries
"""
configurable = DeepResearchConfiguration.from_runnable_config(config)
# Increment the research loop count and get the reasoning model
state["research_loop_count"] = state.get("research_loop_count", 0) + 1
# Get the LLM to use for token counting
primary_llm, fast_llm = get_default_llms()
llm = primary_llm if configurable.reflection_model == "primary" else fast_llm
combined_summaries = get_combined_summaries(state, llm)
    # Build the reflection prompt:
    # 1. Collate the messages to give historical context for the conversation.
    # 2. Concatenate the Onyx research results.
    # 3. Pass both to the reflection instructions. The LLM then produces a JSON
    #    response with the following fields:
    #    - is_sufficient: boolean indicating if the research is sufficient
    #    - knowledge_gap: string describing the knowledge gap
    #    - follow_up_queries: list of follow-up queries
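    # Illustrative shape of the expected JSON (field names come from the
    # Reflection schema in tools_and_schemas; the values here are made up):
    # {
    #     "is_sufficient": false,
    #     "knowledge_gap": "No information on where the CEO was born",
    #     "follow_up_queries": ["Where was the CEO of Onyx born?"]
    # }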
current_date = get_current_date()
formatted_prompt = reflection_instructions.format(
current_date=current_date,
research_topic=get_research_topic(state["messages"]),
summaries=combined_summaries,
company_name=COMPANY_NAME,
company_context=COMPANY_CONTEXT,
)
# Get result from LLM
result = json_to_pydantic(llm.invoke(formatted_prompt).content, Reflection)
# TODO: convert to pydantic here
return {
"is_sufficient": result.is_sufficient,
"knowledge_gap": result.knowledge_gap,
"follow_up_queries": result.follow_up_queries,
"research_loop_count": state["research_loop_count"],
"number_of_ran_queries": len(state["search_query"]),
}
def strtobool(val):
"""Convert a string representation of truth to true (1) or false (0).
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
"""
if isinstance(val, bool):
return val
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return 1
elif val in ("n", "no", "f", "false", "off", "0"):
return 0
else:
raise ValueError("invalid truth value %r" % (val,))
def evaluate_research(
state: ReflectionState,
config: RunnableConfig,
) -> str | list[Send]:
"""LangGraph routing function that determines the next step in the research flow.
Controls the research loop by deciding whether to continue gathering information
or to finalize the summary based on the configured maximum number of research loops.
Args:
state: Current graph state containing the research loop count
config: Configuration for the runnable, including max_research_loops setting
Returns:
        "finalize_answer" when research is sufficient or the loop limit is reached,
        otherwise a list of Send objects routing follow-up queries to "onyx_research"
"""
configurable = DeepResearchConfiguration.from_runnable_config(config)
max_research_loops = (
state.get("max_research_loops")
if state.get("max_research_loops") is not None
else configurable.max_research_loops
)
    # strtobool returns 1/0 for strings (or the bool itself), so test truthiness
    # rather than identity with True ("1 is True" is False in Python).
    if (
        strtobool(state["is_sufficient"])
        or state["research_loop_count"] >= max_research_loops
    ):
return "finalize_answer"
else:
return [
Send(
"onyx_research",
{
"search_query": follow_up_query,
"id": state["number_of_ran_queries"] + int(idx),
},
)
for idx, follow_up_query in enumerate(state["follow_up_queries"])
]
def finalize_answer(state: OverallState, config: RunnableConfig):
"""LangGraph node that finalizes the research summary.
Prepares the final result based on the onyx research results.
Args:
state: Current graph state containing the running summary and sources gathered
Returns:
        Dictionary with state update, including the final answer as an AIMessage
        and the sources_gathered list
"""
configurable = DeepResearchConfiguration.from_runnable_config(config)
answer_model = state.get("answer_model") or configurable.answer_model
# get the LLM to generate the final answer
primary_llm, fast_llm = get_default_llms()
llm = primary_llm if answer_model == "primary" else fast_llm
combined_summaries = get_combined_summaries(state, llm)
# Format the prompt
current_date = get_current_date()
formatted_prompt = answer_instructions.format(
current_date=current_date,
research_topic=get_research_topic(state["messages"]),
summaries=combined_summaries,
company_name=COMPANY_NAME,
company_context=COMPANY_CONTEXT,
user_context=collate_messages(state["messages"]),
)
result = llm.invoke(formatted_prompt)
unique_sources = []
return {
"messages": [AIMessage(content=result.content)],
"sources_gathered": unique_sources,
}
def deep_research_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for deep research process.
"""
graph = StateGraph(
OverallState,
config_schema=DeepResearchConfiguration,
)
### Add nodes ###
graph.add_node("generate_query", generate_query)
graph.add_node("onyx_research", onyx_research)
graph.add_node("reflection", reflection)
graph.add_node("finalize_answer", finalize_answer)
# Set the entrypoint as `generate_query`
graph.add_edge(START, "generate_query")
# Add conditional edge to continue with search queries in a parallel branch
graph.add_conditional_edges(
"generate_query", continue_to_onyx_research, ["onyx_research"]
)
# Reflect on the onyx research
graph.add_edge("onyx_research", "reflection")
# Evaluate the research
graph.add_conditional_edges(
"reflection", evaluate_research, ["onyx_research", "finalize_answer"]
)
# Finalize the answer
graph.add_edge("finalize_answer", END)
return graph
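# A minimal sketch of running the research graph standalone (assumes the SQL
# engine and default LLMs are initialized, as in the __main__ block below);
# the initial state mirrors the one used in execute_step:
#
#     compiled = deep_research_graph_builder().compile()
#     out = compiled.invoke({
#         "messages": [HumanMessage(content="Who founded Onyx?")],
#         "search_query": [],
#         "onyx_research_result": [],
#         "sources_gathered": [],
#         "initial_search_query_count": 3,
#         "max_research_loops": 2,
#         "research_loop_count": 0,
#     })
#     print(out["messages"][-1].content)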
def translate_task_to_query(
task: str,
context=None,
company_name=COMPANY_NAME,
company_context=COMPANY_CONTEXT,
initial_question=None,
) -> str:
"""
LangGraph node that translates a task to a query.
"""
_, fast_llm = get_default_llms()
formatted_prompt = task_to_query_prompt.format(
initial_question=initial_question,
task=task,
context=context,
company_name=company_name,
company_context=company_context,
)
return fast_llm.invoke(formatted_prompt).content
def is_search_query(query: str) -> bool:
    """Heuristically decide whether a task should be executed as an Onyx search."""
    # Substring matching means "find" also covers "find out", "find information",
    # etc., and "gather" covers "gather data", "gather facts", and so on.
    terms = [
        "search",
        "query",
        "find",
        "look up",
        "look for",
        "gather",
    ]
    query = query.lower()
    return any(term in query for term in terms)
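# e.g. is_search_query("Find information about non-US competitors") -> True
#      is_search_query("Summarize the results so far") -> False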
def execute_step(
state: PlanExecute, config: RunnableConfig, writer: StreamWriter = lambda _: None
):
"""
LangGraph node that plans the deep research process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
plan = state["plan"]
task = plan[0]
step_count = state.get("step_count", 0) + 1
if is_search_query(task):
query = translate_task_to_query(
plan[0], context=state["past_steps"], initial_question=question
)
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=" executing a search query with Onyx...",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
graph = deep_research_graph_builder()
compiled_graph = graph.compile(debug=IS_DEBUG)
# TODO: use this input_state for the deep research graph
# input_state = DeepResearchInput(log_messages=[])
initial_state = {
"messages": [HumanMessage(content=query)],
"search_query": [],
"onyx_research_result": [],
"sources_gathered": [],
"initial_search_query_count": 3, # Default value from Configuration
"max_research_loops": 2, # State does not seem to pick up this value
"research_loop_count": 0,
"reasoning_model": "primary",
}
result = compiled_graph.invoke(initial_state)
return {
"past_steps": [(task, query, result["messages"][-1].content)],
"step_count": step_count,
}
else:
primary_llm, _ = get_default_llms()
formatted_prompt = task_completion_prompt.format(
task=task,
plan=state["plan"],
past_steps=state["past_steps"],
company_name=COMPANY_NAME,
company_context=COMPANY_CONTEXT,
)
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=" accomplishing the planned task...",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
response = primary_llm.invoke(formatted_prompt).content
return {
"past_steps": [(task, "task: " + task, response)],
"step_count": step_count,
}
def plan_step(
state: PlanExecute, config: RunnableConfig, writer: StreamWriter = lambda _: None
):
"""
LangGraph node that replans the deep research process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
formatted_prompt = planner_prompt.format(
input=question, company_name=COMPANY_NAME, company_context=COMPANY_CONTEXT
)
primary_llm, _ = get_default_llms()
response = primary_llm.invoke(formatted_prompt).content
plan = json_to_pydantic(response, Plan)
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece="Generating a plan to answer the user's question... ",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return {"plan": plan.steps}
def replan_step(
state: PlanExecute, config: RunnableConfig, writer: StreamWriter = lambda _: None
):
"""
LangGraph node that determines if the deep research process should end.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
formatted_prompt = replanner_prompt.format(
input=question,
plan=state["plan"],
past_steps=state["past_steps"],
company_name=COMPANY_NAME,
company_context=COMPANY_CONTEXT,
)
primary_llm, _ = get_default_llms()
response = primary_llm.invoke(formatted_prompt).content
output = json_to_pydantic(response, Act)
# TODO: add a check for time limit too
if isinstance(output.action, Response):
# Check for canned response, if so, return the answer from the last step
if output.action.response == "The final answer to the user's question":
return {
"response": state["past_steps"][-1][2],
}
else:
return {"response": output.action.response}
elif state["step_count"] >= state.get("max_steps", 5):
return {
"response": f"I've reached the maximum number of step, my best guess is {state['past_steps'][-1][2]}."
}
else:
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=" moving on to the next step...",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return {"plan": output.action.steps}
def should_end(
state: PlanExecute, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> str:
if "response" in state and state["response"]:
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=state["response"],
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return END
else:
return "agent"
def deep_planner_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for deep planner process.
"""
workflow = StateGraph(PlanExecute, config_schema=DeepPlannerConfiguration)
# Add the plan node
workflow.add_node("planner", plan_step)
# Add the execution step
workflow.add_node("agent", execute_step)
# Add a replan node
workflow.add_node("replan", replan_step)
workflow.add_edge(START, "planner")
# From plan we go to agent
workflow.add_edge("planner", "agent")
# From agent, we replan
workflow.add_edge("agent", "replan")
workflow.add_conditional_edges(
"replan",
should_end,
["agent", END],
)
return workflow
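# Control flow summary: START -> planner -> agent -> replan; should_end then
# routes back to agent for the next plan step, or to END once a response is set.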
if __name__ == "__main__":
# Initialize the SQLAlchemy engine first
from onyx.db.engine import SqlEngine
SqlEngine.init_engine(
pool_size=5, # You can adjust these values based on your needs
max_overflow=10,
app_name="graph_builder",
)
# Set the debug and verbose flags for Langchain/Langgraph
set_debug(IS_DEBUG)
set_verbose(IS_VERBOSE)
# Compile the graph
query_start_time = datetime.now()
logger.debug(f"Start at {query_start_time}")
graph = deep_planner_graph_builder()
compiled_graph = graph.compile(debug=IS_DEBUG)
query_end_time = datetime.now()
logger.debug(f"Graph compiled in {query_end_time - query_start_time} seconds")
queries = [
"What is the capital of France?",
"What is Onyx?",
"Who are the founders of Onyx?",
"Who is the CEO of Onyx?",
"Where was the CEO of Onyx born?",
"What is the highest contract value for last month?",
"What is the most expensive component of our technical pipeline so far?",
"Who are top 5 competitors who are not US based?",
"What companies should we focus on to maximize our revenue?",
"What are some of the biggest problems for our customers and their potential solutions?",
]
hard_queries = [
"Where was the CEO of Onyx born?",
"Who are top 5 competitors who are not US based?",
"What companies should we focus on to maximize our revenue?",
"What are some of the biggest problems for our customers and their potential solutions?",
]
for query in hard_queries:
# Create the initial state with all required fields
initial_state = {
"input": [HumanMessage(content=query)],
"plan": [],
"past_steps": [],
"response": "",
"max_steps": 10,
"step_count": 0,
}
result = compiled_graph.invoke(initial_state)
print("Max planning loops: ", result["max_steps"])
print("Steps: ", result["step_count"])
print("Past steps: ")
pprint(result["past_steps"], indent=4)
print("Question: ", query)
print("Answer: ", result["response"])
print("--------------------------------")