mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-01 18:49:27 +02:00
Force use tool overrides (#4024)
* initial rename + timeout bump * querry override
This commit is contained in:
parent
58b252727f
commit
5ca898bde2
@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
|
|||||||
from onyx.agents.agent_search.basic.states import BasicInput
|
from onyx.agents.agent_search.basic.states import BasicInput
|
||||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||||
from onyx.agents.agent_search.basic.states import BasicState
|
from onyx.agents.agent_search.basic.states import BasicState
|
||||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||||
basic_use_tool_response,
|
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||||
)
|
|
||||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
|
||||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||||
prepare_tool_input,
|
prepare_tool_input,
|
||||||
)
|
)
|
||||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||||
|
basic_use_tool_response,
|
||||||
|
)
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
|
|||||||
)
|
)
|
||||||
|
|
||||||
graph.add_node(
|
graph.add_node(
|
||||||
node="llm_tool_choice",
|
node="choose_tool",
|
||||||
action=llm_tool_choice,
|
action=choose_tool,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph.add_node(
|
graph.add_node(
|
||||||
node="tool_call",
|
node="call_tool",
|
||||||
action=tool_call,
|
action=call_tool,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph.add_node(
|
graph.add_node(
|
||||||
@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:
|
|||||||
|
|
||||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||||
|
|
||||||
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
|
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
|
||||||
|
|
||||||
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
|
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
|
||||||
|
|
||||||
graph.add_edge(
|
graph.add_edge(
|
||||||
start_key="tool_call",
|
start_key="call_tool",
|
||||||
end_key="basic_use_tool_response",
|
end_key="basic_use_tool_response",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
|
|||||||
# If there are no tool calls, basic graph already streamed the answer
|
# If there are no tool calls, basic graph already streamed the answer
|
||||||
END
|
END
|
||||||
if state.tool_choice is None
|
if state.tool_choice is None
|
||||||
else "tool_call"
|
else "call_tool"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ logger = setup_logger()
|
|||||||
|
|
||||||
def route_initial_tool_choice(
|
def route_initial_tool_choice(
|
||||||
state: MainState, config: RunnableConfig
|
state: MainState, config: RunnableConfig
|
||||||
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
|
) -> Literal["call_tool", "start_agent_search", "logging_node"]:
|
||||||
"""
|
"""
|
||||||
LangGraph edge to route to agent search.
|
LangGraph edge to route to agent search.
|
||||||
"""
|
"""
|
||||||
@ -38,7 +38,7 @@ def route_initial_tool_choice(
|
|||||||
):
|
):
|
||||||
return "start_agent_search"
|
return "start_agent_search"
|
||||||
else:
|
else:
|
||||||
return "tool_call"
|
return "call_tool"
|
||||||
else:
|
else:
|
||||||
return "logging_node"
|
return "logging_node"
|
||||||
|
|
||||||
|
@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
|
|||||||
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
|
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
|
||||||
answer_refined_query_graph_builder,
|
answer_refined_query_graph_builder,
|
||||||
)
|
)
|
||||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||||
basic_use_tool_response,
|
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||||
)
|
|
||||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
|
||||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||||
prepare_tool_input,
|
prepare_tool_input,
|
||||||
)
|
)
|
||||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||||
|
basic_use_tool_response,
|
||||||
|
)
|
||||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
|
|
||||||
@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
|||||||
# Choose the initial tool
|
# Choose the initial tool
|
||||||
graph.add_node(
|
graph.add_node(
|
||||||
node="initial_tool_choice",
|
node="initial_tool_choice",
|
||||||
action=llm_tool_choice,
|
action=choose_tool,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the tool, if required
|
# Call the tool, if required
|
||||||
graph.add_node(
|
graph.add_node(
|
||||||
node="tool_call",
|
node="call_tool",
|
||||||
action=tool_call,
|
action=call_tool,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the tool response
|
# Use the tool response
|
||||||
@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
|||||||
graph.add_conditional_edges(
|
graph.add_conditional_edges(
|
||||||
"initial_tool_choice",
|
"initial_tool_choice",
|
||||||
route_initial_tool_choice,
|
route_initial_tool_choice,
|
||||||
["tool_call", "start_agent_search", "logging_node"],
|
["call_tool", "start_agent_search", "logging_node"],
|
||||||
)
|
)
|
||||||
|
|
||||||
graph.add_edge(
|
graph.add_edge(
|
||||||
start_key="tool_call",
|
start_key="call_tool",
|
||||||
end_key="basic_use_tool_response",
|
end_key="basic_use_tool_response",
|
||||||
)
|
)
|
||||||
graph.add_edge(
|
graph.add_edge(
|
||||||
|
@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
|
|||||||
write_custom_event("basic_response", packet, writer)
|
write_custom_event("basic_response", packet, writer)
|
||||||
|
|
||||||
|
|
||||||
def tool_call(
|
def call_tool(
|
||||||
state: ToolChoiceUpdate,
|
state: ToolChoiceUpdate,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
writer: StreamWriter = lambda _: None,
|
writer: StreamWriter = lambda _: None,
|
@ -25,7 +25,7 @@ logger = setup_logger()
|
|||||||
# and a function that handles extracting the necessary fields
|
# and a function that handles extracting the necessary fields
|
||||||
# from the state and config
|
# from the state and config
|
||||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||||
def llm_tool_choice(
|
def choose_tool(
|
||||||
state: ToolChoiceState,
|
state: ToolChoiceState,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
writer: StreamWriter = lambda _: None,
|
writer: StreamWriter = lambda _: None,
|
@ -27,6 +27,7 @@ from onyx.file_store.utils import InMemoryChatFile
|
|||||||
from onyx.llm.interfaces import LLM
|
from onyx.llm.interfaces import LLM
|
||||||
from onyx.tools.force import ForceUseTool
|
from onyx.tools.force import ForceUseTool
|
||||||
from onyx.tools.tool import Tool
|
from onyx.tools.tool import Tool
|
||||||
|
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
|
||||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||||
from onyx.tools.utils import explicit_tool_calling_supported
|
from onyx.tools.utils import explicit_tool_calling_supported
|
||||||
from onyx.utils.gpu_utils import gpu_status_request
|
from onyx.utils.gpu_utils import gpu_status_request
|
||||||
@ -89,6 +90,18 @@ class Answer:
|
|||||||
)
|
)
|
||||||
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
|
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
|
||||||
|
|
||||||
|
# TODO: this is a hack to force the query to be used for the search tool
|
||||||
|
# this should be removed once we fully unify graph inputs (i.e.
|
||||||
|
# remove SearchQuery entirely)
|
||||||
|
if (
|
||||||
|
force_use_tool.force_use
|
||||||
|
and search_tool
|
||||||
|
and force_use_tool.args
|
||||||
|
and force_use_tool.tool_name == search_tool.name
|
||||||
|
and QUERY_FIELD in force_use_tool.args
|
||||||
|
):
|
||||||
|
search_request.query = force_use_tool.args[QUERY_FIELD]
|
||||||
|
|
||||||
self.graph_inputs = GraphInputs(
|
self.graph_inputs = GraphInputs(
|
||||||
search_request=search_request,
|
search_request=search_request,
|
||||||
prompt_builder=prompt_builder,
|
prompt_builder=prompt_builder,
|
||||||
|
@ -7,7 +7,7 @@ from typing import cast
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException
|
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
|
||||||
from onyx.chat.answer import Answer
|
from onyx.chat.answer import Answer
|
||||||
from onyx.chat.chat_utils import create_chat_chain
|
from onyx.chat.chat_utils import create_chat_chain
|
||||||
from onyx.chat.chat_utils import create_temporary_persona
|
from onyx.chat.chat_utils import create_temporary_persona
|
||||||
|
@ -223,7 +223,7 @@ AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
|
|||||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||||
)
|
)
|
||||||
|
|
||||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 15 # in seconds
|
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
|
||||||
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
|
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
|
||||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
|
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
|
||||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||||
|
@ -58,6 +58,7 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
|
|||||||
SEARCH_DOC_CONTENT_ID = "search_doc_content"
|
SEARCH_DOC_CONTENT_ID = "search_doc_content"
|
||||||
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
|
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
|
||||||
SEARCH_EVALUATION_ID = "llm_doc_eval"
|
SEARCH_EVALUATION_ID = "llm_doc_eval"
|
||||||
|
QUERY_FIELD = "query"
|
||||||
|
|
||||||
|
|
||||||
class SearchResponseSummary(SearchQueryInfo):
|
class SearchResponseSummary(SearchQueryInfo):
|
||||||
@ -179,12 +180,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
|||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"query": {
|
QUERY_FIELD: {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "What to search for",
|
"description": "What to search for",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["query"],
|
"required": [QUERY_FIELD],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -223,7 +224,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
|||||||
rephrased_query = history_based_query_rephrase(
|
rephrased_query = history_based_query_rephrase(
|
||||||
query=query, history=history, llm=llm
|
query=query, history=history, llm=llm
|
||||||
)
|
)
|
||||||
return {"query": rephrased_query}
|
return {QUERY_FIELD: rephrased_query}
|
||||||
|
|
||||||
"""Actual tool execution"""
|
"""Actual tool execution"""
|
||||||
|
|
||||||
@ -279,7 +280,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
|||||||
def run(
|
def run(
|
||||||
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
|
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
|
||||||
) -> Generator[ToolResponse, None, None]:
|
) -> Generator[ToolResponse, None, None]:
|
||||||
query = cast(str, llm_kwargs["query"])
|
query = cast(str, llm_kwargs[QUERY_FIELD])
|
||||||
force_no_rerank = False
|
force_no_rerank = False
|
||||||
alternate_db_session = None
|
alternate_db_session = None
|
||||||
retrieved_sections_callback = None
|
retrieved_sections_callback = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user