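"""Deep research agents built on LangGraph.

Two graphs are defined here:

1. A deep research graph: generate search queries for the user's question, run
   each query against Onyx search in parallel, reflect on the combined results,
   and loop until the research is sufficient (or a loop cap is hit), then write
   the final answer.
2. A plan-and-execute "deep planner" graph: plan a sequence of tasks, execute
   each task (delegating search-like tasks to the deep research graph), and
   replan after every step until a final response is produced.
"""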
import random
from datetime import datetime
from json import JSONDecodeError
from pprint import pprint
from typing import cast

from langchain.globals import set_debug
from langchain.globals import set_verbose
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from langgraph.types import Send
from langgraph.types import StreamWriter

from onyx.agents.agent_search.deep_research.configuration import (
    DeepPlannerConfiguration,
)
from onyx.agents.agent_search.deep_research.configuration import (
    DeepResearchConfiguration,
)
from onyx.agents.agent_search.deep_research.prompts import answer_instructions
from onyx.agents.agent_search.deep_research.prompts import COMPANY_CONTEXT
from onyx.agents.agent_search.deep_research.prompts import COMPANY_NAME
from onyx.agents.agent_search.deep_research.prompts import get_current_date
from onyx.agents.agent_search.deep_research.prompts import planner_prompt
from onyx.agents.agent_search.deep_research.prompts import query_writer_instructions
from onyx.agents.agent_search.deep_research.prompts import reflection_instructions
from onyx.agents.agent_search.deep_research.prompts import replanner_prompt
from onyx.agents.agent_search.deep_research.prompts import task_completion_prompt
from onyx.agents.agent_search.deep_research.prompts import task_to_query_prompt
from onyx.agents.agent_search.deep_research.states import OnyxSearchState
from onyx.agents.agent_search.deep_research.states import OverallState
from onyx.agents.agent_search.deep_research.states import PlanExecute
from onyx.agents.agent_search.deep_research.states import QueryGenerationState
from onyx.agents.agent_search.deep_research.states import ReflectionState
from onyx.agents.agent_search.deep_research.tools_and_schemas import Act
from onyx.agents.agent_search.deep_research.tools_and_schemas import json_to_pydantic
from onyx.agents.agent_search.deep_research.tools_and_schemas import Plan
from onyx.agents.agent_search.deep_research.tools_and_schemas import Reflection
from onyx.agents.agent_search.deep_research.tools_and_schemas import Response
from onyx.agents.agent_search.deep_research.tools_and_schemas import SearchQueryList
from onyx.agents.agent_search.deep_research.utils import collate_messages
from onyx.agents.agent_search.deep_research.utils import get_research_topic
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Persona
from onyx.llm.factory import get_default_llms
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger

logger = setup_logger()

test_mode = False
IS_DEBUG = False
IS_VERBOSE = False
MAX_RETRIEVED_DOCS = 10


def mock_do_onyx_search(query: str) -> dict[str, str]:
    """Return a canned search result for offline testing of the research graph."""
    random_answers = [
        "Onyx is a startup founded by Yuhong Sun and Chris Weaver.",
        "Chris Weaver was born in the country of Wakanda",
        "Yuhong Sun is the CEO of Onyx",
        "Yuhong Sun was born in the country of Valhalla",
    ]
    return {"text": random.choice(random_answers)}


def do_onyx_search(query: str) -> dict[str, str]:
    """
    Perform a search using the SearchTool and return the results.

    Args:
        query: The search query string

    Returns:
        Dictionary containing the search results text
    """
    retrieved_docs: list[InferenceSection] = []
    primary_llm, fast_llm = get_default_llms()
    try:
        with get_session_with_current_tenant() as db_session:
            # Create a default persona with basic settings
            default_persona = Persona(
                name="default",
                chunks_above=2,
                chunks_below=2,
                description="Default persona for search",
            )

            search_tool = SearchTool(
                db_session=db_session,
                user=None,
                persona=default_persona,
                retrieval_options=None,
                prompt_config=PromptConfig(
                    system_prompt="You are a helpful assistant.",
                    task_prompt="Answer the user's question based on the provided context.",
                    datetime_aware=True,
                    include_citations=True,
                ),
                llm=primary_llm,
                fast_llm=fast_llm,
                pruning_config=DocumentPruningConfig(),
                answer_style_config=AnswerStyleConfig(
                    citation_config=CitationConfig(
                        include_citations=True, citation_style="inline"
                    )
                ),
                evaluation_type=LLMEvaluationType.SKIP,
            )

            for tool_response in search_tool.run(
                query=query,
                override_kwargs=SearchToolOverrideKwargs(
                    force_no_rerank=True,
                    alternate_db_session=db_session,
                    retrieved_sections_callback=None,
                    skip_query_analysis=False,
                ),
            ):
                if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                    response = cast(SearchResponseSummary, tool_response.response)
                    retrieved_docs = response.top_sections
                    break

        # Combine the retrieved documents into a single text
        combined_text = "\n\n".join(
            [doc.combined_content for doc in retrieved_docs[:MAX_RETRIEVED_DOCS]]
        )
        return {"text": combined_text}
    except Exception as e:
        logger.error(f"Error in do_onyx_search: {e}")
        return {"text": "Error in search, no results returned"}


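# Example usage of do_onyx_search (a sketch; assumes default LLMs and a
# current-tenant DB session are configured for this process):
#
#     hits = do_onyx_search("Who founded Onyx?")
#     print(hits["text"])
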
def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
    """
    LangGraph node that generates search queries based on the user's question.

    Uses an LLM to create optimized search queries for Onyx research based on
    the user's question.

    Args:
        state: Current graph state containing the user's question
        config: Configuration for the runnable

    Returns:
        Dictionary with state update, including the query_list key containing
        the generated queries
    """
    configurable = DeepResearchConfiguration.from_runnable_config(config)

    # check for custom initial search query count
    if state.get("initial_search_query_count") is None:
        state["initial_search_query_count"] = configurable.number_of_initial_queries

    primary_llm, fast_llm = get_default_llms()
    llm = primary_llm if configurable.query_generator_model == "primary" else fast_llm

    # Format the prompt
    current_date = get_current_date()
    formatted_prompt = query_writer_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        number_queries=state["initial_search_query_count"],
        company_name=COMPANY_NAME,
        company_context=COMPANY_CONTEXT,
        user_context=collate_messages(state["messages"]),
    )

    # Get the LLM response and parse it; fall back to using the raw response
    # as a single query if the JSON cannot be parsed
    llm_response = llm.invoke(formatted_prompt)
    try:
        result = json_to_pydantic(llm_response.content, SearchQueryList)
        return {"query_list": result.query}
    except JSONDecodeError:
        return {"query_list": [llm_response.content]}


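# The query-writer LLM is expected to return JSON that json_to_pydantic can
# parse into a SearchQueryList, e.g. (shape assumed from the `result.query`
# access in generate_query above):
#
#     {"query": ["onyx founders", "onyx company history"]}
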
def continue_to_onyx_research(state: QueryGenerationState) -> list[Send]:
    """
    LangGraph routing function that dispatches the search queries to the
    onyx_research node.

    This spawns one onyx_research node per search query, so the queries run
    in parallel.
    """
    return [
        Send("onyx_research", {"search_query": search_query, "id": int(idx)})
        for idx, search_query in enumerate(state["query_list"])
    ]


def onyx_research(state: OnyxSearchState, config: RunnableConfig) -> OverallState:
    """LangGraph node that performs onyx research using the onyx search interface.

    Executes an onyx search in combination with an LLM.

    Args:
        state: Current graph state containing the search query and research loop count
        config: Configuration for the runnable, including any search API or LLM settings

    Returns:
        Dictionary with state update, including sources_gathered, search_query,
        and onyx_research_result
    """
    # TODO: think about whether we should use any filtered returned results in addition to the final text answer
    response = do_onyx_search(state["search_query"])

    text = response["text"]
    sources_gathered = []

    return {
        "sources_gathered": sources_gathered,
        "search_query": [state["search_query"]],
        "onyx_research_result": [text],
    }


def get_combined_summaries(state: OverallState, llm=None) -> str:
    """Join the onyx research results and trim them to fit the LLM's context."""
    if llm is None:
        _, llm = get_default_llms()

    # Calculate tokens and trim if needed
    tokenizer = get_tokenizer(
        provider_type=llm.config.model_provider, model_name=llm.config.model_name
    )

    # Combine summaries and check token count
    combined_summaries = "\n\n---\n\n".join(state["onyx_research_result"])
    combined_summaries = tokenizer_trim_content(
        content=combined_summaries, desired_length=10000, tokenizer=tokenizer
    )
    return combined_summaries


def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
    """LangGraph node that identifies knowledge gaps and generates potential follow-up queries.

    Analyzes the current summary to identify areas for further research and generates
    potential follow-up queries. Uses structured output to extract
    the follow-up queries in JSON format.

    Args:
        state: Current graph state containing the running summary and research topic
        config: Configuration for the runnable, including LLM settings

    Returns:
        Dictionary with state update, including is_sufficient, knowledge_gap,
        follow_up_queries, research_loop_count, and number_of_ran_queries
    """
    configurable = DeepResearchConfiguration.from_runnable_config(config)
    # Increment the research loop count and get the reasoning model
    state["research_loop_count"] = state.get("research_loop_count", 0) + 1

    # Get the LLM to use for token counting
    primary_llm, fast_llm = get_default_llms()
    llm = primary_llm if configurable.reflection_model == "primary" else fast_llm
    combined_summaries = get_combined_summaries(state, llm)

    # Format the prompt:
    # First, collate the messages to give a historical context of the current conversation.
    # Then, produce a concatenation of the onyx research results and pass it to
    # the reflection instructions. The LLM will produce a JSON response with:
    # - is_sufficient: boolean indicating if the research is sufficient
    # - knowledge_gap: string describing the knowledge gap
    # - follow_up_queries: list of follow-up queries
    current_date = get_current_date()
    formatted_prompt = reflection_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        summaries=combined_summaries,
        company_name=COMPANY_NAME,
        company_context=COMPANY_CONTEXT,
    )

    # Get the result from the LLM and parse it into a Reflection
    result = json_to_pydantic(llm.invoke(formatted_prompt).content, Reflection)

    return {
        "is_sufficient": result.is_sufficient,
        "knowledge_gap": result.knowledge_gap,
        "follow_up_queries": result.follow_up_queries,
        "research_loop_count": state["research_loop_count"],
        "number_of_ran_queries": len(state["search_query"]),
    }


def strtobool(val):
    """Convert a string representation of truth to true (1) or false (0).

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
    are 'n', 'no', 'f', 'false', 'off', and '0'. Booleans are returned
    unchanged. Raises ValueError if 'val' is anything else.
    """
    if isinstance(val, bool):
        return val
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return 1
    elif val in ("n", "no", "f", "false", "off", "0"):
        return 0
    else:
        raise ValueError("invalid truth value %r" % (val,))


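# strtobool examples:
#     strtobool("yes")   -> 1
#     strtobool("off")   -> 0
#     strtobool(True)    -> True (bools pass through unchanged)
#     strtobool("maybe") -> raises ValueError
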
def evaluate_research(
    state: ReflectionState,
    config: RunnableConfig,
) -> str | list[Send]:
    """LangGraph routing function that determines the next step in the research flow.

    Controls the research loop by deciding whether to continue gathering information
    or to finalize the answer based on the configured maximum number of research loops.

    Args:
        state: Current graph state containing the research loop count
        config: Configuration for the runnable, including max_research_loops setting

    Returns:
        Either the string "finalize_answer" or a list of Send objects that
        dispatch the follow-up queries to the onyx_research node
    """
    configurable = DeepResearchConfiguration.from_runnable_config(config)
    max_research_loops = (
        state.get("max_research_loops")
        if state.get("max_research_loops") is not None
        else configurable.max_research_loops
    )
    # strtobool returns 1/0 for strings and passes booleans through, so rely
    # on truthiness rather than an identity check against True
    if (
        strtobool(state["is_sufficient"])
        or state["research_loop_count"] >= max_research_loops
    ):
        return "finalize_answer"
    else:
        return [
            Send(
                "onyx_research",
                {
                    "search_query": follow_up_query,
                    "id": state["number_of_ran_queries"] + int(idx),
                },
            )
            for idx, follow_up_query in enumerate(state["follow_up_queries"])
        ]


def finalize_answer(state: OverallState, config: RunnableConfig):
    """LangGraph node that finalizes the research answer.

    Prepares the final result based on the onyx research results.

    Args:
        state: Current graph state containing the running summary and sources gathered
        config: Configuration for the runnable, including the answer model setting

    Returns:
        Dictionary with state update, including a messages key containing the
        final answer and a sources_gathered key
    """
    configurable = DeepResearchConfiguration.from_runnable_config(config)
    answer_model = state.get("answer_model") or configurable.answer_model

    # get the LLM to generate the final answer
    primary_llm, fast_llm = get_default_llms()
    llm = primary_llm if answer_model == "primary" else fast_llm
    combined_summaries = get_combined_summaries(state, llm)

    # Format the prompt
    current_date = get_current_date()
    formatted_prompt = answer_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        summaries=combined_summaries,
        company_name=COMPANY_NAME,
        company_context=COMPANY_CONTEXT,
        user_context=collate_messages(state["messages"]),
    )

    result = llm.invoke(formatted_prompt)
    unique_sources = []
    return {
        "messages": [AIMessage(content=result.content)],
        "sources_gathered": unique_sources,
    }


def deep_research_graph_builder(test_mode: bool = False) -> StateGraph:
    """
    LangGraph graph builder for the deep research process.
    """

    graph = StateGraph(
        OverallState,
        config_schema=DeepResearchConfiguration,
    )

    ### Add nodes ###

    graph.add_node("generate_query", generate_query)
    graph.add_node("onyx_research", onyx_research)
    graph.add_node("reflection", reflection)
    graph.add_node("finalize_answer", finalize_answer)

    # Set the entrypoint as `generate_query`
    graph.add_edge(START, "generate_query")
    # Add conditional edge to continue with search queries in a parallel branch
    graph.add_conditional_edges(
        "generate_query", continue_to_onyx_research, ["onyx_research"]
    )
    # Reflect on the onyx research
    graph.add_edge("onyx_research", "reflection")
    # Evaluate the research
    graph.add_conditional_edges(
        "reflection", evaluate_research, ["onyx_research", "finalize_answer"]
    )
    # Finalize the answer
    graph.add_edge("finalize_answer", END)

    return graph


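# Example usage of deep_research_graph_builder (a sketch mirroring the initial
# state built in execute_step below; assumes default LLMs are configured):
#
#     research_graph = deep_research_graph_builder().compile()
#     out = research_graph.invoke(
#         {
#             "messages": [HumanMessage(content="Who founded Onyx?")],
#             "search_query": [],
#             "onyx_research_result": [],
#             "sources_gathered": [],
#             "initial_search_query_count": 3,
#             "max_research_loops": 2,
#             "research_loop_count": 0,
#         }
#     )
#     print(out["messages"][-1].content)
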
def translate_task_to_query(
    task: str,
    context=None,
    company_name=COMPANY_NAME,
    company_context=COMPANY_CONTEXT,
    initial_question=None,
) -> str:
    """
    Helper that translates a plan task into a search query using the fast LLM.
    """
    _, fast_llm = get_default_llms()

    formatted_prompt = task_to_query_prompt.format(
        initial_question=initial_question,
        task=task,
        context=context,
        company_name=company_name,
        company_context=company_context,
    )
    return fast_llm.invoke(formatted_prompt).content


def is_search_query(query: str) -> bool:
    """Heuristically decide whether a plan task should be run as a search."""
    terms = [
        "search",
        "query",
        "find",
        "look up",
        "look for",
        "find out",
        "find information",
        "find data",
        "find facts",
        "find statistics",
        "find trends",
        "find insights",
        "gather",
        "gather information",
        "gather data",
        "gather facts",
        "gather statistics",
        "gather trends",
        "gather insights",
    ]
    query = query.lower()
    return any(term in query for term in terms)


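# Example: is_search_query("Find information about Q3 revenue") returns True,
# while a synthesis task like "Summarize the results above" returns False.
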
def execute_step(
    state: PlanExecute, config: RunnableConfig, writer: StreamWriter = lambda _: None
):
    """
    LangGraph node that executes the next step of the plan, either by running
    the deep research graph (for search-like tasks) or by answering directly
    with the LLM.
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.prompt_builder.raw_user_query
    plan = state["plan"]
    task = plan[0]
    step_count = state.get("step_count", 0) + 1
    if is_search_query(task):
        query = translate_task_to_query(
            plan[0], context=state["past_steps"], initial_question=question
        )
        write_custom_event(
            "refined_agent_answer",
            AgentAnswerPiece(
                answer_piece=" executing a search query with Onyx...",
                level=0,
                level_question_num=0,
                answer_type="agent_level_answer",
            ),
            writer,
        )
        graph = deep_research_graph_builder()
        compiled_graph = graph.compile(debug=IS_DEBUG)
        # TODO: use this input_state for the deep research graph
        # input_state = DeepResearchInput(log_messages=[])

        initial_state = {
            "messages": [HumanMessage(content=query)],
            "search_query": [],
            "onyx_research_result": [],
            "sources_gathered": [],
            "initial_search_query_count": 3,  # Default value from Configuration
            "max_research_loops": 2,  # State does not seem to pick up this value
            "research_loop_count": 0,
            "reasoning_model": "primary",
        }

        result = compiled_graph.invoke(initial_state)

        return {
            "past_steps": [(task, query, result["messages"][-1].content)],
            "step_count": step_count,
        }
    else:
        primary_llm, _ = get_default_llms()
        formatted_prompt = task_completion_prompt.format(
            task=task,
            plan=state["plan"],
            past_steps=state["past_steps"],
            company_name=COMPANY_NAME,
            company_context=COMPANY_CONTEXT,
        )
        write_custom_event(
            "refined_agent_answer",
            AgentAnswerPiece(
                answer_piece=" accomplishing the planned task...",
                level=0,
                level_question_num=0,
                answer_type="agent_level_answer",
            ),
            writer,
        )
        response = primary_llm.invoke(formatted_prompt).content
        return {
            "past_steps": [(task, "task: " + task, response)],
            "step_count": step_count,
        }


def plan_step(
    state: PlanExecute, config: RunnableConfig, writer: StreamWriter = lambda _: None
):
    """
    LangGraph node that plans the deep research process.
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.prompt_builder.raw_user_query

    formatted_prompt = planner_prompt.format(
        input=question, company_name=COMPANY_NAME, company_context=COMPANY_CONTEXT
    )
    primary_llm, _ = get_default_llms()
    response = primary_llm.invoke(formatted_prompt).content
    plan = json_to_pydantic(response, Plan)
    write_custom_event(
        "refined_agent_answer",
        AgentAnswerPiece(
            answer_piece="Generating a plan to answer the user's question... ",
            level=0,
            level_question_num=0,
            answer_type="agent_level_answer",
        ),
        writer,
    )
    return {"plan": plan.steps}


def replan_step(
    state: PlanExecute, config: RunnableConfig, writer: StreamWriter = lambda _: None
):
    """
    LangGraph node that either returns the final response or replans the
    remaining steps of the deep research process.
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.prompt_builder.raw_user_query

    formatted_prompt = replanner_prompt.format(
        input=question,
        plan=state["plan"],
        past_steps=state["past_steps"],
        company_name=COMPANY_NAME,
        company_context=COMPANY_CONTEXT,
    )
    primary_llm, _ = get_default_llms()
    response = primary_llm.invoke(formatted_prompt).content
    output = json_to_pydantic(response, Act)
    # TODO: add a check for time limit too
    if isinstance(output.action, Response):
        # Check for canned response; if so, return the answer from the last step
        if output.action.response == "The final answer to the user's question":
            return {
                "response": state["past_steps"][-1][2],
            }
        else:
            return {"response": output.action.response}
    elif state["step_count"] >= state.get("max_steps", 5):
        return {
            "response": f"I've reached the maximum number of steps; my best guess is {state['past_steps'][-1][2]}."
        }
    else:
        write_custom_event(
            "refined_agent_answer",
            AgentAnswerPiece(
                answer_piece=" moving on to the next step...",
                level=0,
                level_question_num=0,
                answer_type="agent_level_answer",
            ),
            writer,
        )

        return {"plan": output.action.steps}


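# The replanner's JSON parses into an Act whose action is either a Response
# (finish with that answer) or a new Plan (keep executing). Shapes assumed
# from the attribute accesses in replan_step above:
#
#     Act(action=Response(response="final answer text"))
#     Act(action=Plan(steps=["next task", ...]))
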
def should_end(
    state: PlanExecute, config: RunnableConfig, writer: StreamWriter = lambda _: None
):
    """LangGraph routing function that ends the run once a response is set,
    streaming the final answer, and otherwise routes back to the agent node."""
    if "response" in state and state["response"]:
        write_custom_event(
            "refined_agent_answer",
            AgentAnswerPiece(
                answer_piece=state["response"],
                level=0,
                level_question_num=0,
                answer_type="agent_level_answer",
            ),
            writer,
        )
        return END
    else:
        return "agent"


def deep_planner_graph_builder(test_mode: bool = False) -> StateGraph:
    """
    LangGraph graph builder for the deep planner process.
    """
    workflow = StateGraph(PlanExecute, config_schema=DeepPlannerConfiguration)

    # Add the plan node
    workflow.add_node("planner", plan_step)

    # Add the execution step
    workflow.add_node("agent", execute_step)

    # Add a replan node
    workflow.add_node("replan", replan_step)

    workflow.add_edge(START, "planner")

    # From plan we go to agent
    workflow.add_edge("planner", "agent")

    # From agent, we replan
    workflow.add_edge("agent", "replan")

    workflow.add_conditional_edges(
        "replan",
        should_end,
        ["agent", END],
    )
    return workflow


if __name__ == "__main__":
    # Initialize the SQLAlchemy engine first
    from onyx.db.engine import SqlEngine

    SqlEngine.init_engine(
        pool_size=5,  # You can adjust these values based on your needs
        max_overflow=10,
        app_name="graph_builder",
    )

    # Set the debug and verbose flags for Langchain/Langgraph
    set_debug(IS_DEBUG)
    set_verbose(IS_VERBOSE)

    # Compile the graph
    query_start_time = datetime.now()
    logger.debug(f"Start at {query_start_time}")
    graph = deep_planner_graph_builder()
    compiled_graph = graph.compile(debug=IS_DEBUG)
    query_end_time = datetime.now()
    logger.debug(f"Graph compiled in {query_end_time - query_start_time}")

    # Full evaluation set (the harder subset below is the one actually run)
    queries = [
        "What is the capital of France?",
        "What is Onyx?",
        "Who are the founders of Onyx?",
        "Who is the CEO of Onyx?",
        "Where was the CEO of Onyx born?",
        "What is the highest contract value for last month?",
        "What is the most expensive component of our technical pipeline so far?",
        "Who are top 5 competitors who are not US based?",
        "What companies should we focus on to maximize our revenue?",
        "What are some of the biggest problems for our customers and their potential solutions?",
    ]

    hard_queries = [
        "Where was the CEO of Onyx born?",
        "Who are top 5 competitors who are not US based?",
        "What companies should we focus on to maximize our revenue?",
        "What are some of the biggest problems for our customers and their potential solutions?",
    ]

    for query in hard_queries:
        # Create the initial state with all required fields
        initial_state = {
            "input": [HumanMessage(content=query)],
            "plan": [],
            "past_steps": [],
            "response": "",
            "max_steps": 10,
            "step_count": 0,
        }

        result = compiled_graph.invoke(initial_state)
        print("Max planning loops: ", result["max_steps"])
        print("Steps: ", result["step_count"])
        print("Past steps: ")
        pprint(result["past_steps"], indent=4)
        print("Question: ", query)
        print("Answer: ", result["response"])
        print("--------------------------------")