Force use tool overrides (#4024)

* initial rename + timeout bump

* query override
This commit is contained in:
evan-danswer 2025-02-17 13:01:24 -08:00 committed by GitHub
parent 58b252727f
commit 5ca898bde2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 47 additions and 33 deletions

View File

@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.states import BasicInput from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import ( from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
basic_use_tool_response, from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import ( from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input, prepare_tool_input,
) )
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
) )
graph.add_node( graph.add_node(
node="llm_tool_choice", node="choose_tool",
action=llm_tool_choice, action=choose_tool,
) )
graph.add_node( graph.add_node(
node="tool_call", node="call_tool",
action=tool_call, action=call_tool,
) )
graph.add_node( graph.add_node(
@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:
graph.add_edge(start_key=START, end_key="prepare_tool_input") graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice") graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END]) graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
graph.add_edge( graph.add_edge(
start_key="tool_call", start_key="call_tool",
end_key="basic_use_tool_response", end_key="basic_use_tool_response",
) )
@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
# If there are no tool calls, basic graph already streamed the answer # If there are no tool calls, basic graph already streamed the answer
END END
if state.tool_choice is None if state.tool_choice is None
else "tool_call" else "call_tool"
) )

View File

@ -25,7 +25,7 @@ logger = setup_logger()
def route_initial_tool_choice( def route_initial_tool_choice(
state: MainState, config: RunnableConfig state: MainState, config: RunnableConfig
) -> Literal["tool_call", "start_agent_search", "logging_node"]: ) -> Literal["call_tool", "start_agent_search", "logging_node"]:
""" """
LangGraph edge to route to agent search. LangGraph edge to route to agent search.
""" """
@ -38,7 +38,7 @@ def route_initial_tool_choice(
): ):
return "start_agent_search" return "start_agent_search"
else: else:
return "tool_call" return "call_tool"
else: else:
return "logging_node" return "logging_node"

View File

@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import ( from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
answer_refined_query_graph_builder, answer_refined_query_graph_builder,
) )
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import ( from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
basic_use_tool_response, from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import ( from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input, prepare_tool_input,
) )
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
# Choose the initial tool # Choose the initial tool
graph.add_node( graph.add_node(
node="initial_tool_choice", node="initial_tool_choice",
action=llm_tool_choice, action=choose_tool,
) )
# Call the tool, if required # Call the tool, if required
graph.add_node( graph.add_node(
node="tool_call", node="call_tool",
action=tool_call, action=call_tool,
) )
# Use the tool response # Use the tool response
@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph.add_conditional_edges( graph.add_conditional_edges(
"initial_tool_choice", "initial_tool_choice",
route_initial_tool_choice, route_initial_tool_choice,
["tool_call", "start_agent_search", "logging_node"], ["call_tool", "start_agent_search", "logging_node"],
) )
graph.add_edge( graph.add_edge(
start_key="tool_call", start_key="call_tool",
end_key="basic_use_tool_response", end_key="basic_use_tool_response",
) )
graph.add_edge( graph.add_edge(

View File

@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
write_custom_event("basic_response", packet, writer) write_custom_event("basic_response", packet, writer)
def tool_call( def call_tool(
state: ToolChoiceUpdate, state: ToolChoiceUpdate,
config: RunnableConfig, config: RunnableConfig,
writer: StreamWriter = lambda _: None, writer: StreamWriter = lambda _: None,

View File

@ -25,7 +25,7 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields # and a function that handles extracting the necessary fields
# from the state and config # from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable? # TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice( def choose_tool(
state: ToolChoiceState, state: ToolChoiceState,
config: RunnableConfig, config: RunnableConfig,
writer: StreamWriter = lambda _: None, writer: StreamWriter = lambda _: None,

View File

@ -27,6 +27,7 @@ from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request from onyx.utils.gpu_utils import gpu_status_request
@ -89,6 +90,18 @@ class Answer:
) )
allow_agent_reranking = gpu_status_request() or using_cloud_reranking allow_agent_reranking = gpu_status_request() or using_cloud_reranking
# TODO: this is a hack to force the query to be used for the search tool;
# it should be removed once we fully unify graph inputs (i.e.
# remove SearchQuery entirely)
if (
force_use_tool.force_use
and search_tool
and force_use_tool.args
and force_use_tool.tool_name == search_tool.name
and QUERY_FIELD in force_use_tool.args
):
search_request.query = force_use_tool.args[QUERY_FIELD]
self.graph_inputs = GraphInputs( self.graph_inputs = GraphInputs(
search_request=search_request, search_request=search_request,
prompt_builder=prompt_builder, prompt_builder=prompt_builder,

View File

@ -7,7 +7,7 @@ from typing import cast
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
from onyx.chat.answer import Answer from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona from onyx.chat.chat_utils import create_temporary_persona

View File

@ -223,7 +223,7 @@ AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
) )
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 15 # in seconds AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int( AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION") os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION

View File

@ -58,6 +58,7 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content" SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list" SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
SEARCH_EVALUATION_ID = "llm_doc_eval" SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"
class SearchResponseSummary(SearchQueryInfo): class SearchResponseSummary(SearchQueryInfo):
@ -179,12 +180,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"query": { QUERY_FIELD: {
"type": "string", "type": "string",
"description": "What to search for", "description": "What to search for",
}, },
}, },
"required": ["query"], "required": [QUERY_FIELD],
}, },
}, },
} }
@ -223,7 +224,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
rephrased_query = history_based_query_rephrase( rephrased_query = history_based_query_rephrase(
query=query, history=history, llm=llm query=query, history=history, llm=llm
) )
return {"query": rephrased_query} return {QUERY_FIELD: rephrased_query}
"""Actual tool execution""" """Actual tool execution"""
@ -279,7 +280,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
def run( def run(
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]: ) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs["query"]) query = cast(str, llm_kwargs[QUERY_FIELD])
force_no_rerank = False force_no_rerank = False
alternate_db_session = None alternate_db_session = None
retrieved_sections_callback = None retrieved_sections_callback = None