mirror of https://github.com/danswer-ai/danswer.git (synced 2025-10-03 18:08:58 +02:00)
missed files from prev commit
@@ -0,0 +1,57 @@
from typing import cast

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import LlmDoc
from onyx.tools.tool_implementations.search_like_tool_utils import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
    ORIGINAL_CONTEXT_DOCUMENTS_ID,
)


def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
    structured_response_format = agent_config.structured_response_format
    llm = agent_config.primary_llm
    tool_choice = state["tool_choice"]
    if tool_choice is None:
        raise ValueError("Tool choice is None")
    tool = tool_choice["tool"]
    prompt_builder = agent_config.prompt_builder
    tool_call_summary = state["tool_call_summary"]
    tool_call_responses = state["tool_call_responses"]
    # the final tool result is also in the state, but this node does not use it
    state["tool_call_final_result"]
    new_prompt_builder = tool.build_next_prompt(
        prompt_builder=prompt_builder,
        tool_call_summary=tool_call_summary,
        tool_responses=tool_call_responses,
        using_tool_calling_llm=agent_config.using_tool_calling_llm,
    )

    initial_search_results = []
    for yield_item in tool_call_responses:
        if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
            # final context documents; currently unused here, pending the
            # citation handling TODO in process_llm_stream
            final_search_results = cast(list[LlmDoc], yield_item.response)
        elif yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID:
            search_contexts = yield_item.response.contexts
            for doc in search_contexts:
                # dedupe by document_id
                if doc.document_id not in [
                    existing.document_id for existing in initial_search_results
                ]:
                    initial_search_results.append(doc)

    initial_search_results = cast(list[LlmDoc], initial_search_results)

    stream = llm.stream(
        prompt=new_prompt_builder.build(),
        structured_response_format=structured_response_format,
    )

    # For now, we don't do multiple tool calls, so we ignore the tool_message
    process_llm_stream(stream, True)

    return BasicOutput()
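For context: each node in this commit pulls its AgentSearchConfig out of the RunnableConfig metadata via config["metadata"]["config"]. A minimal sketch of how a caller might thread that config through when invoking the graph (illustrative only, not part of this commit; the helper name is made up):

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.models import AgentSearchConfig


def build_runnable_config(agent_config: AgentSearchConfig) -> RunnableConfig:
    # the "config" metadata key matches what the nodes cast back out via
    # config["metadata"]["config"]
    return RunnableConfig(metadata={"config": agent_config})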
backend/onyx/agents/agent_search/basic/nodes/llm_tool_choice.py (new file, 134 lines)
@@ -0,0 +1,134 @@
from typing import cast
from uuid import uuid4

from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolChoice
from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
    get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger

logger = setup_logger()


# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
    """
    This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
    the node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
    """
    should_stream_answer = state["should_stream_answer"]

    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
    using_tool_calling_llm = agent_config.using_tool_calling_llm
    prompt_builder = agent_config.prompt_builder
    llm = agent_config.primary_llm
    skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation

    structured_response_format = agent_config.structured_response_format
    tools = agent_config.tools or []
    force_use_tool = agent_config.force_use_tool

    tool, tool_args = None, None
    if force_use_tool.force_use and force_use_tool.args is not None:
        tool_name, tool_args = (
            force_use_tool.tool_name,
            force_use_tool.args,
        )
        tool = get_tool_by_name(tools, tool_name)

    # special pre-logic for the non-tool calling LLM case
    elif not using_tool_calling_llm and tools:
        chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
            force_use_tool=force_use_tool,
            tools=tools,
            prompt_builder=prompt_builder,
            llm=llm,
        )
        if chosen_tool_and_args:
            tool, tool_args = chosen_tool_and_args

    # If we have a tool and tool args, we are ready to request a tool call.
    # This only happens if the tool call was forced or we are using a non-tool calling LLM.
    if tool and tool_args:
        return ToolChoiceUpdate(
            tool_choice=ToolChoice(
                tool=tool,
                tool_args=tool_args,
                id=str(uuid4()),
            ),
        )

    # if we're skipping gen ai answer generation, we should only
    # continue if we're forcing a tool call (which will be emitted by
    # the tool calling llm in the stream() below)
    if skip_gen_ai_answer_generation and not force_use_tool.force_use:
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    # At this point, we are either using a tool calling LLM or we are skipping the tool call.
    # DEBUG: good breakpoint
    stream = llm.stream(
        # For tool calling LLMs, we want to insert the task prompt as part of this flow; this is
        # because the LLM may choose not to call any tools and just generate the answer,
        # in which case the task prompt is needed.
        prompt=prompt_builder.build(),
        tools=[tool.tool_definition() for tool in tools] or None,
        tool_choice=("required" if tools and force_use_tool.force_use else None),
        structured_response_format=structured_response_format,
    )

    tool_message = process_llm_stream(stream, should_stream_answer)

    # If no tool calls are emitted by the LLM, we should not choose a tool
    if len(tool_message.tool_calls) == 0:
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    # TODO: here we could handle parallel tool calls. Right now
    # we just pick the first one that matches.
    selected_tool: Tool | None = None
    selected_tool_call_request: ToolCall | None = None
    for tool_call_request in tool_message.tool_calls:
        known_tools_by_name = [
            tool for tool in tools if tool.name == tool_call_request["name"]
        ]

        if known_tools_by_name:
            selected_tool = known_tools_by_name[0]
            selected_tool_call_request = tool_call_request
            break

        logger.error(
            "Tool call requested with unknown name field.\n"
            f"tools: {tools}\n"
            f"tool_call_request: {tool_call_request}"
        )

    if not selected_tool or not selected_tool_call_request:
        raise ValueError(
            f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
        )

    logger.info(f"Selected tool: {selected_tool.name}")
    logger.debug(f"Selected tool call request: {selected_tool_call_request}")

    return ToolChoiceUpdate(
        tool_choice=ToolChoice(
            tool=selected_tool,
            tool_args=selected_tool_call_request["args"],
            id=selected_tool_call_request["id"],
        ),
    )
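For context, since this node's contract is to either pick a tool or decline: a minimal sketch (an assumption, not the repo's actual graph builder) of how llm_tool_choice could be wired so that a non-None tool_choice routes to the tool_call node added below, while a None choice ends the run. The module path for basic_use_tool_response is assumed.

from langgraph.graph import END, StateGraph

from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
# path assumed based on the sibling node modules in this commit
from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
    basic_use_tool_response,
)
from onyx.agents.agent_search.basic.states import BasicState


def route_after_tool_choice(state: BasicState) -> str:
    # branch on the field that llm_tool_choice just wrote
    return "tool_call" if state["tool_choice"] is not None else END


builder = StateGraph(BasicState)
builder.add_node("llm_tool_choice", llm_tool_choice)
builder.add_node("tool_call", tool_call)
builder.add_node("basic_use_tool_response", basic_use_tool_response)
builder.set_entry_point("llm_tool_choice")
builder.add_conditional_edges("llm_tool_choice", route_after_tool_choice)
builder.add_edge("tool_call", "basic_use_tool_response")
builder.add_edge("basic_use_tool_response", END)
graph = builder.compile()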
backend/onyx/agents/agent_search/basic/nodes/tool_call.py (new file, 69 lines)
@@ -0,0 +1,69 @@
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolCallUpdate
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import AnswerPacket
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.tool_runner import ToolRunner
from onyx.utils.logger import setup_logger


logger = setup_logger()


def emit_packet(packet: AnswerPacket) -> None:
    dispatch_custom_event("basic_response", packet)


# TODO: handle is_cancelled
def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
    """Calls the tool specified in the state and updates the state with the result"""
    # the agent config is not needed yet in this node
    cast(AgentSearchConfig, config["metadata"]["config"])
    # Unnecessary now, node should only be called if there is a tool call
    # if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
    #     return

    tool_choice = state["tool_choice"]
    if tool_choice is None:
        raise ValueError("Cannot invoke tool call node without a tool choice")

    tool = tool_choice["tool"]
    tool_args = tool_choice["tool_args"]
    tool_id = tool_choice["id"]
    tool_runner = ToolRunner(tool, tool_args)
    tool_kickoff = tool_runner.kickoff()

    # TODO: custom events for yields
    emit_packet(tool_kickoff)

    tool_responses = []
    for response in tool_runner.tool_responses():
        tool_responses.append(response)
        emit_packet(response)

    tool_final_result = tool_runner.tool_final_result()
    emit_packet(tool_final_result)

    tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
    tool_call_summary = ToolCallSummary(
        tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
        tool_call_result=build_tool_message(
            tool_call, tool_runner.tool_message_content()
        ),
    )

    return ToolCallUpdate(
        tool_call_summary=tool_call_summary,
        tool_call_kickoff=tool_kickoff,
        tool_call_responses=tool_responses,
        tool_call_final_result=tool_final_result,
    )
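For context: the packets dispatched through dispatch_custom_event("basic_response", ...) surface to the caller as LangChain custom events. A rough sketch of consuming them on the caller side (illustrative only; assumes a compiled graph, an initial state, and a RunnableConfig are already available):

from typing import Any

from langchain_core.runnables.config import RunnableConfig


async def collect_basic_response_packets(
    graph: Any, initial_state: dict, runnable_config: RunnableConfig
) -> list:
    packets = []
    # custom events show up as "on_custom_event" entries in the v2 event stream
    async for event in graph.astream_events(
        initial_state, runnable_config, version="v2"
    ):
        if event["event"] == "on_custom_event" and event["name"] == "basic_response":
            packets.append(event["data"])
    return packets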
backend/onyx/agents/agent_search/basic/utils.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from collections.abc import Iterator
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage

from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.utils.logger import setup_logger

logger = setup_logger()

# TODO: handle citations here; below is what was previously passed in
# see basic_use_tool_response.py for where these variables come from
# answer_handler = CitationResponseHandler(
#     context_docs=final_search_results,
#     final_doc_id_to_rank_map=map_document_id_order(final_search_results),
#     display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )


def process_llm_stream(
    stream: Iterator[BaseMessage],
    should_stream_answer: bool,
    final_search_results: list[LlmDoc] | None = None,
    displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
    tool_call_chunk = AIMessageChunk(content="")
    # for response in response_handler_manager.handle_llm_response(stream):

    # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
    # the stream will contain AIMessageChunks with tool call information.
    for response in stream:
        answer_piece = response.content
        if not isinstance(answer_piece, str):
            # TODO: handle non-string content
            logger.warning(f"Received non-string content: {type(answer_piece)}")
            answer_piece = str(answer_piece)

        if isinstance(response, AIMessageChunk) and (
            response.tool_call_chunks or response.tool_calls
        ):
            tool_call_chunk += response  # type: ignore
        elif should_stream_answer:
            # TODO: handle emitting of CitationInfo
            dispatch_custom_event(
                "basic_response",
                OnyxAnswerPiece(answer_piece=answer_piece),
            )

    return cast(AIMessageChunk, tool_call_chunk)
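For context on the tool_call_chunk += response accumulation above: adding AIMessageChunk objects merges tool_call_chunks that share an index, so a tool call whose arguments arrive split across stream chunks ends up as one parsed tool call. A small standalone illustration (tool name and arguments are made up):

from langchain_core.messages import AIMessageChunk

acc = AIMessageChunk(content="")
acc += AIMessageChunk(
    content="",
    tool_call_chunks=[{"name": "run_search", "args": '{"que', "id": "call_1", "index": 0}],
)
acc += AIMessageChunk(
    content="",
    tool_call_chunks=[{"name": None, "args": 'ry": "hello"}', "id": None, "index": 0}],
)
# the merged chunk parses into a single complete tool call, roughly:
# [{'name': 'run_search', 'args': {'query': 'hello'}, 'id': 'call_1', ...}]
print(acc.tool_calls)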