implemented top-level tool calling + force search

This commit is contained in:
Evan Lohn 2025-01-23 19:06:22 -08:00
parent 982040c792
commit ddbfc65ad0
11 changed files with 125 additions and 49 deletions

@ -2,14 +2,14 @@ from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.basic.nodes.prepare_tool_input import prepare_tool_input
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.prepare_tool_input import prepare_tool_input
from onyx.agents.agent_search.orchestration.tool_call import tool_call
from onyx.utils.logger import setup_logger

@ -1,7 +1,9 @@
from collections.abc import Hashable
from datetime import datetime
from typing import cast
from typing import Literal
from langchain_core.runnables import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
@ -14,12 +16,29 @@ from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import (
RequireRefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def route_initial_tool_choice(
    state: MainState, config: RunnableConfig
) -> Literal["tool_call", "agent_search_start", "logging_node"]:
    """Route the graph after the initial LLM tool-choice step.

    Returns the name of the next node: "logging_node" when the LLM selected
    no tool, "agent_search_start" when agentic search is enabled and the
    selected tool is the configured search tool, and "tool_call" for any
    other tool selection.
    """
    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
    chosen = state.tool_choice

    # No tool was chosen by the LLM: go straight to logging.
    if chosen is None:
        return "logging_node"

    # The search tool kicks off the full agentic search flow, but only
    # when agentic search is enabled in the config.
    is_agentic_search = (
        agent_config.use_agentic_search
        and chosen.tool.name == agent_config.search_tool.name
    )
    return "agent_search_start" if is_agentic_search else "tool_call"
def parallelize_initial_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:

@ -20,24 +20,18 @@ from onyx.agents.agent_search.deep_search_a.main.edges import (
from onyx.agents.agent_search.deep_search_a.main.edges import (
parallelize_refined_sub_question_answering,
)
from onyx.agents.agent_search.deep_search_a.main.edges import (
route_initial_tool_choice,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_logging import (
agent_logging,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_decision import (
agent_path_decision,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_routing import (
agent_path_routing,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import (
agent_search_start,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.answer_comparison import (
answer_comparison,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.direct_llm_handling import (
direct_llm_handling,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import (
entity_term_extraction_llm,
)
@ -73,6 +67,12 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.retrieval_consolidation i
)
from onyx.agents.agent_search.deep_search_a.main.states import MainInput
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.orchestration.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.prepare_tool_input import prepare_tool_input
from onyx.agents.agent_search.orchestration.tool_call import tool_call
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
@ -87,21 +87,37 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
input=MainInput,
)
# graph.add_node(
# node="agent_path_decision",
# action=agent_path_decision,
# )
# graph.add_node(
# node="agent_path_routing",
# action=agent_path_routing,
# )
# graph.add_node(
# node="LLM",
# action=direct_llm_handling,
# )
graph.add_node(
node="agent_path_decision",
action=agent_path_decision,
node="prepare_tool_input",
action=prepare_tool_input,
)
graph.add_node(
node="initial_tool_choice",
action=llm_tool_choice,
)
graph.add_node(
node="tool_call",
action=tool_call,
)
graph.add_node(
node="agent_path_routing",
action=agent_path_routing,
node="basic_use_tool_response",
action=basic_use_tool_response,
)
graph.add_node(
node="LLM",
action=direct_llm_handling,
)
graph.add_node(
node="agent_search_start",
action=agent_search_start,
@ -205,14 +221,35 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
# graph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
# graph.add_edge(
# start_key=START,
# end_key="agent_path_decision",
# )
# graph.add_edge(
# start_key="agent_path_decision",
# end_key="agent_path_routing",
# )
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(
start_key=START,
end_key="agent_path_decision",
start_key="prepare_tool_input",
end_key="initial_tool_choice",
)
graph.add_conditional_edges(
"initial_tool_choice",
route_initial_tool_choice,
["tool_call", "agent_search_start", "logging_node"],
)
graph.add_edge(
start_key="agent_path_decision",
end_key="agent_path_routing",
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key="logging_node",
)
graph.add_edge(
@ -245,10 +282,10 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="LLM",
end_key=END,
)
# graph.add_edge(
# start_key="LLM",
# end_key=END,
# )
# graph.add_edge(
# start_key=START,

@ -21,8 +21,8 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
agent_start_time = state.agent_start_time
agent_base_end_time = state.agent_base_end_time
agent_refined_start_time = state.agent_refined_start_time or None
agent_refined_end_time = state.agent_refined_end_time or None
agent_refined_start_time = state.agent_refined_start_time
agent_refined_end_time = state.agent_refined_end_time
agent_end_time = agent_refined_end_time or agent_base_end_time
agent_base_duration = None
@ -67,14 +67,15 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
# log the agent metrics
if agent_a_config.db_session is not None:
log_agent_metrics(
db_session=agent_a_config.db_session,
user_id=user_id,
persona_id=persona_id,
agent_type=agent_type,
start_time=agent_start_time,
agent_metrics=combined_agent_metrics,
)
if agent_base_duration is not None:
log_agent_metrics(
db_session=agent_a_config.db_session,
user_id=user_id,
persona_id=persona_id,
agent_type=agent_type,
start_time=agent_start_time,
agent_metrics=combined_agent_metrics,
)
if agent_a_config.use_persistence:
# Persist the sub-answer in the database

@ -12,6 +12,9 @@ from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
@ -133,6 +136,9 @@ class MainInput(CoreState):
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
BaseDecompUpdate,
InitialAnswerUpdate,
InitialAnswerBASEUpdate,

@ -39,7 +39,7 @@ def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoic
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
structured_response_format = agent_config.structured_response_format
tools = agent_config.tools or []
tools = [tool for tool in (agent_config.tools or []) if tool.name in state.tools]
force_use_tool = agent_config.force_use_tool
tool, tool_args = None, None

@ -1,15 +1,16 @@
from typing import Any
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
def prepare_tool_input(state: BasicState, config: RunnableConfig) -> ToolChoiceInput:
cast(AgentSearchConfig, config["metadata"]["config"])
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
    """Build the initial ToolChoiceInput for the tool-calling flow.

    The incoming graph state is unused; everything needed is read from the
    AgentSearchConfig stashed under config["metadata"]["config"].
    """
    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
    available_tools = agent_config.tools or []
    tool_names = [tool.name for tool in available_tools]
    return ToolChoiceInput(
        should_stream_answer=True,
        # prompt_snapshot=None falls back to the default prompt builder
        prompt_snapshot=None,
        tools=tool_names,
    )

@ -8,12 +8,17 @@ from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
# TODO: adapt the tool choice/tool call to allow for parallel tool calls by
# creating a subgraph that can be invoked in parallel via Send/Command APIs
class ToolChoiceInput(BaseModel):
    """Input state for the LLM tool-choice step of the orchestration graph."""

    # Whether the eventual answer should be streamed back to the client.
    should_stream_answer: bool = True
    # default to the prompt builder from the config, but
    # allow overrides for arbitrary tool calls
    prompt_snapshot: PromptSnapshot | None = None
    # names of tools to use for tool calling. Filters the tools available in the config
    tools: list[str] = []
class ToolCallOutput(BaseModel):
tool_call_summary: ToolCallSummary

@ -153,7 +153,11 @@ def generate_log_message(
def get_test_config(
db_session: Session, primary_llm: LLM, fast_llm: LLM, search_request: SearchRequest
db_session: Session,
primary_llm: LLM,
fast_llm: LLM,
search_request: SearchRequest,
use_agentic_search: bool = True,
) -> tuple[AgentSearchConfig, SearchTool]:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
document_pruning_config = DocumentPruningConfig(
@ -229,6 +233,7 @@ def get_test_config(
use_persistence=True,
db_session=db_session,
tools=[search_tool],
use_agentic_search=use_agentic_search,
)
return config, search_tool

@ -964,9 +964,11 @@ def log_agent_metrics(
start_time=start_time,
base_duration__s=agent_timings.base_duration__s,
full_duration__s=agent_timings.full_duration__s,
base_metrics=vars(agent_base_metrics),
refined_metrics=vars(agent_refined_metrics),
all_metrics=vars(agent_additional_metrics),
base_metrics=vars(agent_base_metrics) if agent_base_metrics else None,
refined_metrics=vars(agent_refined_metrics) if agent_refined_metrics else None,
all_metrics=vars(agent_additional_metrics)
if agent_additional_metrics
else None,
)
db_session.add(agent_metric_tracking)