from collections.abc import Generator

from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall

from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_runner import (
    check_which_tools_should_run_for_non_tool_calling_llm,
)
from onyx.tools.tool_runner import ToolRunner
from onyx.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from onyx.utils.logger import setup_logger


logger = setup_logger()

class ToolResponseHandler:
    """Accumulates streamed tool-call chunks from the LLM, runs the selected
    tool, and records the kickoff/responses/final result needed to build the
    follow-up LLM call."""

    def __init__(self, tools: list[Tool]):
        self.tools = tools

        self.tool_call_chunk: AIMessageChunk | None = None
        self.tool_call_requests: list[ToolCall] = []

        self.tool_runner: ToolRunner | None = None
        self.tool_call_summary: ToolCallSummary | None = None

        self.tool_kickoff: ToolCallKickoff | None = None
        self.tool_responses: list[ToolResponse] = []
        self.tool_final_result: ToolCallFinalResult | None = None

    @classmethod
    def get_tool_call_for_non_tool_calling_llm(
        cls, llm_call: LLMCall, llm: LLM
    ) -> tuple[Tool, dict] | None:
        """For LLMs without native tool calling, decide which tool (if any)
        to run and with what args: either the forced tool, or the single best
        candidate selected via LLM prompting."""
        if llm_call.force_use_tool.force_use:
            # if we are forcing a tool, we don't need to check which tools to run
            tool = next(
                (
                    t
                    for t in llm_call.tools
                    if t.name == llm_call.force_use_tool.tool_name
                ),
                None,
            )
            if not tool:
                raise RuntimeError(
                    f"Tool '{llm_call.force_use_tool.tool_name}' not found"
                )

            tool_args = (
                llm_call.force_use_tool.args
                if llm_call.force_use_tool.args is not None
                else tool.get_args_for_non_tool_calling_llm(
                    query=llm_call.prompt_builder.raw_user_query,
                    history=llm_call.prompt_builder.raw_message_history,
                    llm=llm,
                    force_run=True,
                )
            )

            if tool_args is None:
                raise RuntimeError(f"Tool '{tool.name}' did not return args")

            return (tool, tool_args)
        else:
            tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
                tools=llm_call.tools,
                query=llm_call.prompt_builder.raw_user_query,
                history=llm_call.prompt_builder.raw_message_history,
                llm=llm,
            )

            available_tools_and_args = [
                (llm_call.tools[ind], args)
                for ind, args in enumerate(tool_options)
                if args is not None
            ]

            logger.info(
                f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
            )

            chosen_tool_and_args = (
                select_single_tool_for_non_tool_calling_llm(
                    tools_and_args=available_tools_and_args,
                    history=llm_call.prompt_builder.raw_message_history,
                    query=llm_call.prompt_builder.raw_user_query,
                    llm=llm,
                )
                if available_tools_and_args
                else None
            )

            logger.notice(f"Chosen tool: {chosen_tool_and_args}")
            return chosen_tool_and_args

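    # Illustrative usage sketch (not part of the original module): for a
    # non-tool-calling LLM, a caller could resolve the tool up front and run
    # it directly via ToolRunner. `llm_call` and `llm` are assumed to be
    # supplied by the surrounding chat flow.
    #
    #   chosen = ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
    #       llm_call, llm
    #   )
    #   if chosen is not None:
    #       tool, tool_args = chosen
    #       runner = ToolRunner(tool, tool_args)
    #       kickoff = runner.kickoff()
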
    def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
        """Run the first recognized tool call buffered in `tool_call_chunk`,
        yielding the kickoff, intermediate responses, and final result."""
        if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
            return

        self.tool_call_requests = self.tool_call_chunk.tool_calls

        selected_tool: Tool | None = None
        selected_tool_call_request: ToolCall | None = None
        for tool_call_request in self.tool_call_requests:
            known_tools_by_name = [
                tool for tool in self.tools if tool.name == tool_call_request["name"]
            ]

            if not known_tools_by_name:
                logger.error(
                    "Tool call requested with unknown name field.\n"
                    f"self.tools: {self.tools}\n"
                    f"tool_call_request: {tool_call_request}"
                )
                continue
            else:
                selected_tool = known_tools_by_name[0]
                selected_tool_call_request = tool_call_request

            if selected_tool and selected_tool_call_request:
                break

        if not selected_tool or not selected_tool_call_request:
            return

        logger.info(f"Selected tool: {selected_tool.name}")
        logger.debug(f"Selected tool call request: {selected_tool_call_request}")
        self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"])
        self.tool_kickoff = self.tool_runner.kickoff()
        yield self.tool_kickoff

        for response in self.tool_runner.tool_responses():
            self.tool_responses.append(response)
            yield response

        self.tool_final_result = self.tool_runner.tool_final_result()
        yield self.tool_final_result

        self.tool_call_summary = ToolCallSummary(
            tool_call_request=self.tool_call_chunk,
            tool_call_result=build_tool_message(
                selected_tool_call_request, self.tool_runner.tool_message_content()
            ),
        )

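    # Note on chunk accumulation (a sketch based on LangChain's documented
    # AIMessageChunk addition semantics, not code from this module): summing
    # chunks merges their partial tool_call_chunks by index, so the combined
    # chunk exposes complete, parsed tool_calls.
    #
    #   from langchain_core.messages import AIMessageChunk
    #   first = AIMessageChunk(content="", tool_call_chunks=[
    #       {"name": "search", "args": '{"query": "on', "id": "call_1", "index": 0}
    #   ])
    #   second = AIMessageChunk(content="", tool_call_chunks=[
    #       {"name": None, "args": 'yx"}', "id": None, "index": 0}
    #   ])
    #   merged = first + second
    #   # merged.tool_calls -> [{"name": "search", "args": {"query": "onyx"}, ...}]
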
    def handle_response_part(
        self,
        response_item: BaseMessage | None,
        previous_response_items: list[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        """Feed one streamed message into the handler. Tool-call chunks are
        accumulated; a `None` item signals end of stream and triggers the
        buffered tool call."""
        if response_item is None:
            yield from self._handle_tool_call()

        if isinstance(response_item, AIMessageChunk) and (
            response_item.tool_call_chunks or response_item.tool_calls
        ):
            if self.tool_call_chunk is None:
                self.tool_call_chunk = response_item
            else:
                self.tool_call_chunk += response_item  # type: ignore

        return

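    # Illustrative driver loop (not part of the original module): stream the
    # LLM output through the handler, then pass `None` to flush any buffered
    # tool call. `llm.stream(...)` and `prompt_builder.build()` are assumptions
    # about the surrounding codebase.
    #
    #   handler = ToolResponseHandler(tools=llm_call.tools)
    #   messages: list[BaseMessage] = []
    #   for message in llm.stream(prompt=llm_call.prompt_builder.build()):
    #       yield from handler.handle_response_part(message, messages)
    #       messages.append(message)
    #   yield from handler.handle_response_part(None, messages)
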
    def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
        """Build the follow-up LLM call carrying the tool output, or return
        `None` if no tool ran and the current response is final."""
        if (
            self.tool_runner is None
            or self.tool_call_summary is None
            or self.tool_kickoff is None
            or self.tool_final_result is None
        ):
            return None

        tool_runner = self.tool_runner
        new_prompt_builder = tool_runner.tool.build_next_prompt(
            prompt_builder=current_llm_call.prompt_builder,
            tool_call_summary=self.tool_call_summary,
            tool_responses=self.tool_responses,
            using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
        )
        return LLMCall(
            prompt_builder=new_prompt_builder,
            tools=[],  # for now, only allow one tool call per response
            force_use_tool=ForceUseTool(
                force_use=False,
                tool_name="",
                args=None,
            ),
            files=current_llm_call.files,
            using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
            tool_call_info=[
                self.tool_kickoff,
                *self.tool_responses,
                self.tool_final_result,
            ],
        )
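

# End-to-end sketch (illustrative only, not part of the original module): the
# handler supports a simple "run tool, then ask the LLM again" loop. Names like
# `llm.stream`, `prompt_builder.build()`, and `initial_llm_call` are
# assumptions about the surrounding codebase.
#
#   current_call: LLMCall | None = initial_llm_call
#   while current_call is not None:
#       handler = ToolResponseHandler(tools=current_call.tools)
#       for message in llm.stream(prompt=current_call.prompt_builder.build()):
#           for part in handler.handle_response_part(message, []):
#               ...  # forward ResponsePart items (kickoff / responses / result)
#       for part in handler.handle_response_part(None, []):
#           ...  # flush the buffered tool call at end of stream
#       current_call = handler.next_llm_call(current_call)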