From 5ca898bde2b5c8f5424ad1142bb1ae3e225f4a62 Mon Sep 17 00:00:00 2001
From: evan-danswer
Date: Mon, 17 Feb 2025 13:01:24 -0800
Subject: [PATCH] Force use tool overrides (#4024)

* initial rename + timeout bump

* query override
---
 .../agent_search/basic/graph_builder.py       | 26 +++++++++----------
 .../agent_search/deep_search/main/edges.py    |  4 +--
 .../deep_search/main/graph_builder.py         | 20 +++++++-------
 .../nodes/{tool_call.py => call_tool.py}      |  2 +-
 .../{llm_tool_choice.py => choose_tool.py}    |  2 +-
 ..._tool_response.py => use_tool_response.py} |  0
 backend/onyx/chat/answer.py                   | 13 ++++++++++
 backend/onyx/chat/process_message.py          |  2 +-
 backend/onyx/configs/agent_configs.py         |  2 +-
 .../search/search_tool.py                     |  9 ++++---
 10 files changed, 47 insertions(+), 33 deletions(-)
 rename backend/onyx/agents/agent_search/orchestration/nodes/{tool_call.py => call_tool.py} (99%)
 rename backend/onyx/agents/agent_search/orchestration/nodes/{llm_tool_choice.py => choose_tool.py} (99%)
 rename backend/onyx/agents/agent_search/orchestration/nodes/{basic_use_tool_response.py => use_tool_response.py} (100%)

diff --git a/backend/onyx/agents/agent_search/basic/graph_builder.py b/backend/onyx/agents/agent_search/basic/graph_builder.py
index 8e09b62e16..07689568a9 100644
--- a/backend/onyx/agents/agent_search/basic/graph_builder.py
+++ b/backend/onyx/agents/agent_search/basic/graph_builder.py
@@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
 from onyx.agents.agent_search.basic.states import BasicInput
 from onyx.agents.agent_search.basic.states import BasicOutput
 from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
-    basic_use_tool_response,
-)
-from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
+from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
+from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
 from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
     prepare_tool_input,
 )
-from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
+from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
+    basic_use_tool_response,
+)
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
     )
 
     graph.add_node(
-        node="llm_tool_choice",
-        action=llm_tool_choice,
+        node="choose_tool",
+        action=choose_tool,
     )
 
     graph.add_node(
-        node="tool_call",
-        action=tool_call,
+        node="call_tool",
+        action=call_tool,
     )
 
     graph.add_node(
@@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:
 
     graph.add_edge(start_key=START, end_key="prepare_tool_input")
 
-    graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
+    graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
 
-    graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
+    graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
 
     graph.add_edge(
-        start_key="tool_call",
+        start_key="call_tool",
         end_key="basic_use_tool_response",
     )
 
@@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
         # If there are no tool calls, basic graph already streamed the answer
         END
         if state.tool_choice is None
-        else "tool_call"
+        else "call_tool"
     )
 
 
diff --git a/backend/onyx/agents/agent_search/deep_search/main/edges.py b/backend/onyx/agents/agent_search/deep_search/main/edges.py
index 3f9d5a873f..79989e8c9d 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/edges.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/edges.py
@@ -25,7 +25,7 @@ logger = setup_logger()
 
 def route_initial_tool_choice(
     state: MainState, config: RunnableConfig
-) -> Literal["tool_call", "start_agent_search", "logging_node"]:
+) -> Literal["call_tool", "start_agent_search", "logging_node"]:
     """
     LangGraph edge to route to agent search.
     """
@@ -38,7 +38,7 @@ def route_initial_tool_choice(
         ):
             return "start_agent_search"
         else:
-            return "tool_call"
+            return "call_tool"
     else:
         return "logging_node"
 
diff --git a/backend/onyx/agents/agent_search/deep_search/main/graph_builder.py b/backend/onyx/agents/agent_search/deep_search/main/graph_builder.py
index 1af167ae9d..75e23a7920 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/graph_builder.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/graph_builder.py
@@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
     answer_refined_query_graph_builder,
 )
-from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
-    basic_use_tool_response,
-)
-from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
+from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
+from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
 from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
     prepare_tool_input,
 )
-from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
+from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
+    basic_use_tool_response,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
 from onyx.utils.logger import setup_logger
 
@@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     # Choose the initial tool
     graph.add_node(
         node="initial_tool_choice",
-        action=llm_tool_choice,
+        action=choose_tool,
     )
 
     # Call the tool, if required
     graph.add_node(
-        node="tool_call",
-        action=tool_call,
+        node="call_tool",
+        action=call_tool,
     )
 
     # Use the tool response
@@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     graph.add_conditional_edges(
         "initial_tool_choice",
         route_initial_tool_choice,
-        ["tool_call", "start_agent_search", "logging_node"],
+        ["call_tool", "start_agent_search", "logging_node"],
     )
 
     graph.add_edge(
-        start_key="tool_call",
+        start_key="call_tool",
         end_key="basic_use_tool_response",
     )
     graph.add_edge(
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py b/backend/onyx/agents/agent_search/orchestration/nodes/call_tool.py
similarity index 99%
rename from backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py
rename to backend/onyx/agents/agent_search/orchestration/nodes/call_tool.py
index 17f8411e13..5265d5a61d 100644
--- a/backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py
+++ b/backend/onyx/agents/agent_search/orchestration/nodes/call_tool.py
@@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
     write_custom_event("basic_response", packet, writer)
 
 
-def tool_call(
+def call_tool(
     state: ToolChoiceUpdate,
     config: RunnableConfig,
     writer: StreamWriter = lambda _: None,
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/llm_tool_choice.py b/backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py
similarity index 99%
rename from backend/onyx/agents/agent_search/orchestration/nodes/llm_tool_choice.py
rename to backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py
index 22f646a719..f7fdd71e50 100644
--- a/backend/onyx/agents/agent_search/orchestration/nodes/llm_tool_choice.py
+++ b/backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py
@@ -25,7 +25,7 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
-def llm_tool_choice(
+def choose_tool(
     state: ToolChoiceState,
     config: RunnableConfig,
     writer: StreamWriter = lambda _: None,
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/basic_use_tool_response.py b/backend/onyx/agents/agent_search/orchestration/nodes/use_tool_response.py
similarity index 100%
rename from backend/onyx/agents/agent_search/orchestration/nodes/basic_use_tool_response.py
rename to backend/onyx/agents/agent_search/orchestration/nodes/use_tool_response.py
diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py
index 118c7aaf54..eb9b2130dd 100644
--- a/backend/onyx/chat/answer.py
+++ b/backend/onyx/chat/answer.py
@@ -27,6 +27,7 @@ from onyx.file_store.utils import InMemoryChatFile
 from onyx.llm.interfaces import LLM
 from onyx.tools.force import ForceUseTool
 from onyx.tools.tool import Tool
+from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.utils import explicit_tool_calling_supported
 from onyx.utils.gpu_utils import gpu_status_request
@@ -89,6 +90,18 @@ class Answer:
         )
         allow_agent_reranking = gpu_status_request() or using_cloud_reranking
 
+        # TODO: this is a hack to force the query to be used for the search tool
+        # this should be removed once we fully unify graph inputs (i.e.
+        # remove SearchQuery entirely)
+        if (
+            force_use_tool.force_use
+            and search_tool
+            and force_use_tool.args
+            and force_use_tool.tool_name == search_tool.name
+            and QUERY_FIELD in force_use_tool.args
+        ):
+            search_request.query = force_use_tool.args[QUERY_FIELD]
+
         self.graph_inputs = GraphInputs(
             search_request=search_request,
             prompt_builder=prompt_builder,
diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py
index d1e5fbea60..f04be63d93 100644
--- a/backend/onyx/chat/process_message.py
+++ b/backend/onyx/chat/process_message.py
@@ -7,7 +7,7 @@ from typing import cast
 
 from sqlalchemy.orm import Session
 
-from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException
+from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
 from onyx.chat.answer import Answer
 from onyx.chat.chat_utils import create_chat_chain
 from onyx.chat.chat_utils import create_temporary_persona
diff --git a/backend/onyx/configs/agent_configs.py b/backend/onyx/configs/agent_configs.py
index cf028fee9f..0f5d8e60c9 100644
--- a/backend/onyx/configs/agent_configs.py
+++ b/backend/onyx/configs/agent_configs.py
@@ -223,7 +223,7 @@ AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
     or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
 )
 
-AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 15  # in seconds
+AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30  # in seconds
 AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
     os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
     or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py
index 11d147526a..4b556e4711 100644
--- a/backend/onyx/tools/tool_implementations/search/search_tool.py
+++ b/backend/onyx/tools/tool_implementations/search/search_tool.py
@@ -58,6 +58,7 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
 SEARCH_DOC_CONTENT_ID = "search_doc_content"
 SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
 SEARCH_EVALUATION_ID = "llm_doc_eval"
+QUERY_FIELD = "query"
 
 
 class SearchResponseSummary(SearchQueryInfo):
@@ -179,12 +180,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
                 "parameters": {
                     "type": "object",
                     "properties": {
-                        "query": {
+                        QUERY_FIELD: {
                             "type": "string",
                             "description": "What to search for",
                         },
                     },
-                    "required": ["query"],
+                    "required": [QUERY_FIELD],
                 },
             },
         }
@@ -223,7 +224,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
         rephrased_query = history_based_query_rephrase(
             query=query, history=history, llm=llm
         )
-        return {"query": rephrased_query}
+        return {QUERY_FIELD: rephrased_query}
 
     """Actual tool execution"""
 
@@ -279,7 +280,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
     def run(
         self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
     ) -> Generator[ToolResponse, None, None]:
-        query = cast(str, llm_kwargs["query"])
+        query = cast(str, llm_kwargs[QUERY_FIELD])
         force_no_rerank = False
         alternate_db_session = None
         retrieved_sections_callback = None
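
Illustrative sketch (not part of the patch): the new override in Answer.__init__ only
fires when the forced tool is the search tool and its args carry the "query" field.
Assuming ForceUseTool exposes the force_use / tool_name / args fields that the diff
reads (the example query value is hypothetical, and search_tool comes from the normal
tool setup), a caller could trigger the override roughly like this:

    from onyx.tools.force import ForceUseTool
    from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD

    # Force the search tool to run with an exact query string.
    force_use_tool = ForceUseTool(
        force_use=True,
        tool_name=search_tool.name,
        args={QUERY_FIELD: "status of project onyx"},
    )
    # Answer.__init__ then copies args[QUERY_FIELD] onto search_request.query before
    # building GraphInputs, so the forced query is what the search pipeline actually runs.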