From 50bacc03b310e428ce86f301bbb0021ae9f58e80 Mon Sep 17 00:00:00 2001
From: Evan Lohn
Date: Wed, 22 Jan 2025 16:25:09 -0800
Subject: [PATCH] missed files from prev commit

---
 .../basic/nodes/basic_use_tool_response.py   |  57 ++++++++
 .../basic/nodes/llm_tool_choice.py           | 134 ++++++++++++++++++
 .../agent_search/basic/nodes/tool_call.py    |  69 +++++++++
 .../onyx/agents/agent_search/basic/utils.py  |  52 +++++++
 4 files changed, 312 insertions(+)
 create mode 100644 backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py
 create mode 100644 backend/onyx/agents/agent_search/basic/nodes/llm_tool_choice.py
 create mode 100644 backend/onyx/agents/agent_search/basic/nodes/tool_call.py
 create mode 100644 backend/onyx/agents/agent_search/basic/utils.py

diff --git a/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py b/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py
new file mode 100644
index 000000000000..81ca3cab7e5a
--- /dev/null
+++ b/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py
@@ -0,0 +1,57 @@
+from typing import cast
+
+from langchain_core.runnables.config import RunnableConfig
+
+from onyx.agents.agent_search.basic.states import BasicOutput
+from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.basic.utils import process_llm_stream
+from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.chat.models import LlmDoc
+from onyx.tools.tool_implementations.search_like_tool_utils import (
+    FINAL_CONTEXT_DOCUMENTS_ID,
+)
+from onyx.tools.tool_implementations.search_like_tool_utils import (
+    ORIGINAL_CONTEXT_DOCUMENTS_ID,
+)
+
+
+def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
+    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    structured_response_format = agent_config.structured_response_format
+    llm = agent_config.primary_llm
+    tool_choice = state["tool_choice"]
+    if tool_choice is None:
+        raise ValueError("Tool choice is None")
+    tool = tool_choice["tool"]
+    prompt_builder = agent_config.prompt_builder
+    tool_call_summary = state["tool_call_summary"]
+    tool_call_responses = state["tool_call_responses"]
+    state["tool_call_final_result"]
+    new_prompt_builder = tool.build_next_prompt(
+        prompt_builder=prompt_builder,
+        tool_call_summary=tool_call_summary,
+        tool_responses=tool_call_responses,
+        using_tool_calling_llm=agent_config.using_tool_calling_llm,
+    )
+
+    initial_search_results = []
+    for yield_item in tool_call_responses:
+        if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
+            cast(list[LlmDoc], yield_item.response)
+        elif yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID:
+            search_contexts = yield_item.response.contexts
+            for doc in search_contexts:
+                if doc.document_id not in [d.document_id for d in initial_search_results]:
+                    initial_search_results.append(doc)
+
+    initial_search_results = cast(list[LlmDoc], initial_search_results)
+
+    stream = llm.stream(
+        prompt=new_prompt_builder.build(),
+        structured_response_format=structured_response_format,
+    )
+
+    # For now, we don't do multiple tool calls, so we ignore the tool_message
+    process_llm_stream(stream, True)
+
+    return BasicOutput()
diff --git a/backend/onyx/agents/agent_search/basic/nodes/llm_tool_choice.py b/backend/onyx/agents/agent_search/basic/nodes/llm_tool_choice.py
new file mode 100644
index 000000000000..ffb210a819b6
--- /dev/null
+++ b/backend/onyx/agents/agent_search/basic/nodes/llm_tool_choice.py
@@ -0,0 +1,134 @@
+from typing import cast
+from uuid import uuid4
+
+from langchain_core.messages import ToolCall
+from langchain_core.runnables.config import RunnableConfig
+
+from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.basic.states import ToolChoice
+from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
+from onyx.agents.agent_search.basic.utils import process_llm_stream
+from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
+from onyx.chat.tool_handling.tool_response_handler import (
+    get_tool_call_for_non_tool_calling_llm_impl,
+)
+from onyx.tools.tool import Tool
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+# TODO: break this out into an implementation function
+# and a function that handles extracting the necessary fields
+# from the state and config
+# TODO: fan-out to multiple tool call nodes? Make this configurable?
+def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
+    """
+    This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
+    the node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
+    """
+    should_stream_answer = state["should_stream_answer"]
+
+    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    using_tool_calling_llm = agent_config.using_tool_calling_llm
+    prompt_builder = agent_config.prompt_builder
+    llm = agent_config.primary_llm
+    skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
+
+    structured_response_format = agent_config.structured_response_format
+    tools = agent_config.tools or []
+    force_use_tool = agent_config.force_use_tool
+
+    tool, tool_args = None, None
+    if force_use_tool.force_use and force_use_tool.args is not None:
+        tool_name, tool_args = (
+            force_use_tool.tool_name,
+            force_use_tool.args,
+        )
+        tool = get_tool_by_name(tools, tool_name)
+
+    # special pre-logic for non-tool calling LLM case
+    elif not using_tool_calling_llm and tools:
+        chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
+            force_use_tool=force_use_tool,
+            tools=tools,
+            prompt_builder=prompt_builder,
+            llm=llm,
+        )
+        if chosen_tool_and_args:
+            tool, tool_args = chosen_tool_and_args
+
+    # If we have a tool and tool args, we are ready to request a tool call.
+    # This only happens if the tool call was forced or we are using a non-tool calling LLM.
+    if tool and tool_args:
+        return ToolChoiceUpdate(
+            tool_choice=ToolChoice(
+                tool=tool,
+                tool_args=tool_args,
+                id=str(uuid4()),
+            ),
+        )
+
+    # if we're skipping gen ai answer generation, we should only
+    # continue if we're forcing a tool call (which will be emitted by
+    # the tool calling llm in the stream() below)
+    if skip_gen_ai_answer_generation and not force_use_tool.force_use:
+        return ToolChoiceUpdate(
+            tool_choice=None,
+        )
+
+    # At this point, we are either using a tool calling LLM or we are skipping the tool call.
+    # DEBUG: good breakpoint
+    stream = llm.stream(
+        # For tool calling LLMs, we want to insert the task prompt as part of this flow. This is because the LLM
+        # may choose not to call any tools and just generate the answer, in which case the task prompt is needed.
+        prompt=prompt_builder.build(),
+        tools=[tool.tool_definition() for tool in tools] or None,
+        tool_choice=("required" if tools and force_use_tool.force_use else None),
+        structured_response_format=structured_response_format,
+    )
+
+    tool_message = process_llm_stream(stream, should_stream_answer)
+
+    # If no tool calls are emitted by the LLM, we should not choose a tool
+    if len(tool_message.tool_calls) == 0:
+        return ToolChoiceUpdate(
+            tool_choice=None,
+        )
+
+    # TODO: here we could handle parallel tool calls. Right now
+    # we just pick the first one that matches.
+    selected_tool: Tool | None = None
+    selected_tool_call_request: ToolCall | None = None
+    for tool_call_request in tool_message.tool_calls:
+        known_tools_by_name = [
+            tool for tool in tools if tool.name == tool_call_request["name"]
+        ]
+
+        if known_tools_by_name:
+            selected_tool = known_tools_by_name[0]
+            selected_tool_call_request = tool_call_request
+            break
+
+        logger.error(
+            "Tool call requested with unknown name field. \n"
+            f"tools: {tools}\n"
+            f"tool_call_request: {tool_call_request}"
+        )
+
+    if not selected_tool or not selected_tool_call_request:
+        raise ValueError(
+            f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
+        )
+
+    logger.info(f"Selected tool: {selected_tool.name}")
+    logger.debug(f"Selected tool call request: {selected_tool_call_request}")
+
+    return ToolChoiceUpdate(
+        tool_choice=ToolChoice(
+            tool=selected_tool,
+            tool_args=selected_tool_call_request["args"],
+            id=selected_tool_call_request["id"],
+        ),
+    )
diff --git a/backend/onyx/agents/agent_search/basic/nodes/tool_call.py b/backend/onyx/agents/agent_search/basic/nodes/tool_call.py
new file mode 100644
index 000000000000..00f2c629b343
--- /dev/null
+++ b/backend/onyx/agents/agent_search/basic/nodes/tool_call.py
@@ -0,0 +1,69 @@
+from typing import cast
+
+from langchain_core.callbacks.manager import dispatch_custom_event
+from langchain_core.messages import AIMessageChunk
+from langchain_core.messages.tool import ToolCall
+from langchain_core.runnables.config import RunnableConfig
+
+from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.basic.states import ToolCallUpdate
+from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.chat.models import AnswerPacket
+from onyx.tools.message import build_tool_message
+from onyx.tools.message import ToolCallSummary
+from onyx.tools.tool_runner import ToolRunner
+from onyx.utils.logger import setup_logger
+
+
+logger = setup_logger()
+
+
+def emit_packet(packet: AnswerPacket) -> None:
+    dispatch_custom_event("basic_response", packet)
+
+
+# TODO: handle is_cancelled
+def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
+    """Calls the tool specified in the state and updates the state with the result"""
+    # TODO: implement
+
+    cast(AgentSearchConfig, config["metadata"]["config"])
+    # Unnecessary now, node should only be called if there is a tool call
+    # if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
+    #     return
+
+    tool_choice = state["tool_choice"]
+    if tool_choice is None:
+        raise ValueError("Cannot invoke tool call node without a tool choice")
+
+    tool = tool_choice["tool"]
+    tool_args = tool_choice["tool_args"]
+    tool_id = tool_choice["id"]
+    tool_runner = ToolRunner(tool, tool_args)
+    tool_kickoff = tool_runner.kickoff()
+
+    # TODO: custom events for yields
+    emit_packet(tool_kickoff)
+
+    tool_responses = []
+    for response in tool_runner.tool_responses():
+        tool_responses.append(response)
+        emit_packet(response)
+
+    tool_final_result = tool_runner.tool_final_result()
+    emit_packet(tool_final_result)
+
+    tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
+    tool_call_summary = ToolCallSummary(
+        tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
+        tool_call_result=build_tool_message(
+            tool_call, tool_runner.tool_message_content()
+        ),
+    )
+
+    return ToolCallUpdate(
+        tool_call_summary=tool_call_summary,
+        tool_call_kickoff=tool_kickoff,
+        tool_call_responses=tool_responses,
+        tool_call_final_result=tool_final_result,
+    )
diff --git a/backend/onyx/agents/agent_search/basic/utils.py b/backend/onyx/agents/agent_search/basic/utils.py
new file mode 100644
index 000000000000..7257770ca79e
--- /dev/null
+++ b/backend/onyx/agents/agent_search/basic/utils.py
@@ -0,0 +1,52 @@
+from collections.abc import Iterator
+from typing import cast
+
+from langchain_core.callbacks.manager import dispatch_custom_event
+from langchain_core.messages import AIMessageChunk
+from langchain_core.messages import BaseMessage
+
+from onyx.chat.models import LlmDoc
+from onyx.chat.models import OnyxAnswerPiece
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+# TODO: handle citations here; below is what was previously passed in
+# see basic_use_tool_response.py for where these variables come from
+# answer_handler = CitationResponseHandler(
+#     context_docs=final_search_results,
+#     final_doc_id_to_rank_map=map_document_id_order(final_search_results),
+#     display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
+# )
+
+
+def process_llm_stream(
+    stream: Iterator[BaseMessage],
+    should_stream_answer: bool,
+    final_search_results: list[LlmDoc] | None = None,
+    displayed_search_results: list[LlmDoc] | None = None,
+) -> AIMessageChunk:
+    tool_call_chunk = AIMessageChunk(content="")
+    # for response in response_handler_manager.handle_llm_response(stream):
+
+    # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
+    # the stream will contain AIMessageChunks with tool call information.
+    for response in stream:
+        answer_piece = response.content
+        if not isinstance(answer_piece, str):
+            # TODO: handle non-string content
+            logger.warning(f"Received non-string content: {type(answer_piece)}")
+            answer_piece = str(answer_piece)
+
+        if isinstance(response, AIMessageChunk) and (
+            response.tool_call_chunks or response.tool_calls
+        ):
+            tool_call_chunk += response  # type: ignore
+        elif should_stream_answer:
+            # TODO: handle emitting of CitationInfo
+            dispatch_custom_event(
+                "basic_response",
+                OnyxAnswerPiece(answer_piece=answer_piece),
+            )
+
+    return cast(AIMessageChunk, tool_call_chunk)
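
For orientation, below is a minimal sketch of how these three nodes might be wired together with LangGraph. It is illustrative only and not part of the patch: the graph wiring, node names, the route_after_tool_choice router, and the invocation shape are assumptions; only BasicState and the node functions come from the files added above.

# Hypothetical wiring sketch (not part of this patch).
from langgraph.graph import END, StateGraph

from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
    basic_use_tool_response,
)
from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
from onyx.agents.agent_search.basic.states import BasicState


def route_after_tool_choice(state: BasicState) -> str:
    # If llm_tool_choice picked a tool, run it; otherwise the answer was already
    # streamed by process_llm_stream and the graph can end.
    return "tool_call" if state["tool_choice"] is not None else END


graph = StateGraph(BasicState)
graph.add_node("llm_tool_choice", llm_tool_choice)
graph.add_node("tool_call", tool_call)
graph.add_node("basic_use_tool_response", basic_use_tool_response)
graph.set_entry_point("llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", route_after_tool_choice)
graph.add_edge("tool_call", "basic_use_tool_response")
graph.add_edge("basic_use_tool_response", END)
compiled = graph.compile()

# The nodes read the AgentSearchConfig from the RunnableConfig metadata, so an
# invocation would look roughly like this (hypothetical initial state):
# compiled.invoke(
#     {"should_stream_answer": True},
#     config={"metadata": {"config": agent_config}},
# )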