diff --git a/backend/onyx/agents/agent_search/basic/graph_builder.py b/backend/onyx/agents/agent_search/basic/graph_builder.py
index 411744ed7..6ed7c8963 100644
--- a/backend/onyx/agents/agent_search/basic/graph_builder.py
+++ b/backend/onyx/agents/agent_search/basic/graph_builder.py
@@ -5,11 +5,12 @@ from langgraph.graph import StateGraph
 from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
     basic_use_tool_response,
 )
-from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
-from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
+from onyx.agents.agent_search.basic.nodes.prepare_tool_input import prepare_tool_input
 from onyx.agents.agent_search.basic.states import BasicInput
 from onyx.agents.agent_search.basic.states import BasicOutput
 from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.orchestration.llm_tool_choice import llm_tool_choice
+from onyx.agents.agent_search.orchestration.tool_call import tool_call
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -24,6 +25,11 @@ def basic_graph_builder() -> StateGraph:
 
     ### Add nodes ###
 
+    graph.add_node(
+        node="prepare_tool_input",
+        action=prepare_tool_input,
+    )
+
     graph.add_node(
         node="llm_tool_choice",
         action=llm_tool_choice,
@@ -41,7 +47,9 @@ def basic_graph_builder() -> StateGraph:
 
     ### Add edges ###
 
-    graph.add_edge(start_key=START, end_key="llm_tool_choice")
+    graph.add_edge(start_key=START, end_key="prepare_tool_input")
+
+    graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
 
     graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
 
@@ -62,10 +70,27 @@ def should_continue(state: BasicState) -> str:
     return (
         # If there are no tool calls, basic graph already streamed the answer
         END
-        if state["tool_choice"] is None
+        if state.tool_choice is None
         else "tool_call"
     )
 
 
 if __name__ == "__main__":
-    pass
+    from onyx.db.engine import get_session_context_manager
+    from onyx.context.search.models import SearchRequest
+    from onyx.llm.factory import get_default_llms
+    from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
+
+    graph = basic_graph_builder()
+    compiled_graph = graph.compile()
+    # TODO: unify basic input
+    input = BasicInput(logs="")
+    primary_llm, fast_llm = get_default_llms()
+    with get_session_context_manager() as db_session:
+        config, _ = get_test_config(
+            db_session=db_session,
+            primary_llm=primary_llm,
+            fast_llm=fast_llm,
+            search_request=SearchRequest(query="How does onyx use FastAPI?"),
+        )
+        compiled_graph.invoke(input, config={"metadata": {"config": config}})
diff --git a/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py b/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py
index c251db434..732e3e43a 100644
--- a/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py
+++ b/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py
@@ -13,20 +13,26 @@ from onyx.tools.tool_implementations.search.search_tool import (
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
 
 
 def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     structured_response_format = agent_config.structured_response_format
     llm = agent_config.primary_llm
-    tool_choice = state["tool_choice"]
+    tool_choice = state.tool_choice
     if tool_choice is None:
         raise ValueError("Tool choice is None")
-    tool = tool_choice["tool"]
+    tool = tool_choice.tool
     prompt_builder = agent_config.prompt_builder
-    tool_call_summary = state["tool_call_summary"]
-    tool_call_responses = state["tool_call_responses"]
-    state["tool_call_final_result"]
+    if state.tool_call_output is None:
+        raise ValueError("Tool call output is None")
+    tool_call_output = state.tool_call_output
+    tool_call_summary = tool_call_output.tool_call_summary
+    tool_call_responses = tool_call_output.tool_call_responses
+
     new_prompt_builder = tool.build_next_prompt(
         prompt_builder=prompt_builder,
         tool_call_summary=tool_call_summary,
@@ -53,11 +59,11 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO
     )
 
     # For now, we don't do multiple tool calls, so we ignore the tool_message
-    process_llm_stream(
+    new_tool_call_chunk = process_llm_stream(
         stream,
         True,
         final_search_results=final_search_results,
         displayed_search_results=initial_search_results,
     )
 
-    return BasicOutput()
+    return BasicOutput(tool_call_chunk=new_tool_call_chunk)
diff --git a/backend/onyx/agents/agent_search/basic/nodes/prepare_tool_input.py b/backend/onyx/agents/agent_search/basic/nodes/prepare_tool_input.py
new file mode 100644
index 000000000..847e7c895
--- /dev/null
+++ b/backend/onyx/agents/agent_search/basic/nodes/prepare_tool_input.py
@@ -0,0 +1,15 @@
+from typing import cast
+
+from langchain_core.runnables.config import RunnableConfig
+
+from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
+
+
+def prepare_tool_input(state: BasicState, config: RunnableConfig) -> ToolChoiceInput:
+    cast(AgentSearchConfig, config["metadata"]["config"])
+    return ToolChoiceInput(
+        should_stream_answer=True,
+        prompt_snapshot=None,  # uses default prompt builder
+    )
diff --git a/backend/onyx/agents/agent_search/basic/states.py b/backend/onyx/agents/agent_search/basic/states.py
index b26564517..ee549b712 100644
--- a/backend/onyx/agents/agent_search/basic/states.py
+++ b/backend/onyx/agents/agent_search/basic/states.py
@@ -1,10 +1,11 @@
 from typing import TypedDict
 
-from onyx.tools.message import ToolCallSummary
-from onyx.tools.models import ToolCallFinalResult
-from onyx.tools.models import ToolCallKickoff
-from onyx.tools.models import ToolResponse
-from onyx.tools.tool import Tool
+from langchain_core.messages import AIMessageChunk
+from pydantic import BaseModel
+
+from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
+from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
 
 # States contain values that change over the course of graph execution,
 # Config is for values that are set at the start and never change.
@@ -14,33 +15,19 @@ from onyx.tools.tool import Tool
 ## Graph Input State
 
-class BasicInput(TypedDict):
-    should_stream_answer: bool
+class BasicInput(BaseModel):
+    # TODO: subclass global log update state
+    logs: str = ""
 
 
 ## Graph Output State
 
 class BasicOutput(TypedDict):
-    pass
+    tool_call_chunk: AIMessageChunk
 
 
 ## Update States
 
-class ToolCallUpdate(TypedDict):
-    tool_call_summary: ToolCallSummary
-    tool_call_kickoff: ToolCallKickoff
-    tool_call_responses: list[ToolResponse]
-    tool_call_final_result: ToolCallFinalResult
-
-
-class ToolChoice(TypedDict):
-    tool: Tool
-    tool_args: dict
-    id: str | None
-
-
-class ToolChoiceUpdate(TypedDict):
-    tool_choice: ToolChoice | None
 
 
 ## Graph State
 
@@ -48,8 +35,8 @@ class ToolChoiceUpdate(TypedDict):
 
 class BasicState(
     BasicInput,
+    ToolChoiceInput,
     ToolCallUpdate,
     ToolChoiceUpdate,
-    BasicOutput,
 ):
     pass
diff --git a/backend/onyx/agents/agent_search/basic/utils.py b/backend/onyx/agents/agent_search/basic/utils.py
index a62bf0369..28b1b26f1 100644
--- a/backend/onyx/agents/agent_search/basic/utils.py
+++ b/backend/onyx/agents/agent_search/basic/utils.py
@@ -43,6 +43,7 @@ def process_llm_stream(
     else:
         answer_handler = PassThroughAnswerResponseHandler()
 
+    full_answer = ""
     # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
     # the stream will contain AIMessageChunks with tool call information.
     for response in stream:
@@ -51,6 +52,7 @@ def process_llm_stream(
             # TODO: handle non-string content
             logger.warning(f"Received non-string content: {type(answer_piece)}")
             answer_piece = str(answer_piece)
+        full_answer += answer_piece
 
         if isinstance(response, AIMessageChunk) and (
             response.tool_call_chunks or response.tool_calls
@@ -64,4 +66,5 @@ def process_llm_stream(
             response_part,
         )
 
+    logger.info(f"Full answer: {full_answer}")
     return cast(AIMessageChunk, tool_call_chunk)
diff --git a/backend/onyx/agents/agent_search/basic/nodes/llm_tool_choice.py b/backend/onyx/agents/agent_search/orchestration/llm_tool_choice.py
similarity index 86%
rename from backend/onyx/agents/agent_search/basic/nodes/llm_tool_choice.py
rename to backend/onyx/agents/agent_search/orchestration/llm_tool_choice.py
index ffb210a81..43b8ce5ae 100644
--- a/backend/onyx/agents/agent_search/basic/nodes/llm_tool_choice.py
+++ b/backend/onyx/agents/agent_search/orchestration/llm_tool_choice.py
@@ -4,11 +4,12 @@ from uuid import uuid4
 from langchain_core.messages import ToolCall
 from langchain_core.runnables.config import RunnableConfig
 
-from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.basic.states import ToolChoice
-from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolChoice
+from onyx.agents.agent_search.orchestration.states import ToolChoiceState
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.chat.tool_handling.tool_response_handler import (
     get_tool_call_for_non_tool_calling_llm_impl,
@@ -23,16 +24,17 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
-def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
+def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
     """
     This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
     The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
     """
-    should_stream_answer = state["should_stream_answer"]
+    should_stream_answer = state.should_stream_answer
 
     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     using_tool_calling_llm = agent_config.using_tool_calling_llm
-    prompt_builder = agent_config.prompt_builder
+    prompt_builder = state.prompt_snapshot or agent_config.prompt_builder
+
     llm = agent_config.primary_llm
     skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
@@ -78,12 +80,17 @@
             tool_choice=None,
         )
 
+    built_prompt = (
+        prompt_builder.build()
+        if isinstance(prompt_builder, AnswerPromptBuilder)
+        else prompt_builder.built_prompt
+    )
     # At this point, we are either using a tool calling LLM or we are skipping the tool call.
     # DEBUG: good breakpoint
     stream = llm.stream(
         # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
         # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
-        prompt=prompt_builder.build(),
+        prompt=built_prompt,
         tools=[tool.tool_definition() for tool in tools] or None,
         tool_choice=("required" if tools and force_use_tool.force_use else None),
         structured_response_format=structured_response_format,
@@ -93,6 +100,7 @@
 
     # If no tool calls are emitted by the LLM, we should not choose a tool
     if len(tool_message.tool_calls) == 0:
+        logger.info("No tool calls emitted by LLM")
        return ToolChoiceUpdate(
            tool_choice=None,
        )
diff --git a/backend/onyx/agents/agent_search/orchestration/states.py b/backend/onyx/agents/agent_search/orchestration/states.py
new file mode 100644
index 000000000..4ce8d6ecb
--- /dev/null
+++ b/backend/onyx/agents/agent_search/orchestration/states.py
@@ -0,0 +1,43 @@
+from pydantic import BaseModel
+
+from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
+from onyx.tools.message import ToolCallSummary
+from onyx.tools.models import ToolCallFinalResult
+from onyx.tools.models import ToolCallKickoff
+from onyx.tools.models import ToolResponse
+from onyx.tools.tool import Tool
+
+
+class ToolChoiceInput(BaseModel):
+    should_stream_answer: bool = True
+    # default to the prompt builder from the config, but
+    # allow overrides for arbitrary tool calls
+    prompt_snapshot: PromptSnapshot | None = None
+
+
+class ToolCallOutput(BaseModel):
+    tool_call_summary: ToolCallSummary
+    tool_call_kickoff: ToolCallKickoff
+    tool_call_responses: list[ToolResponse]
+    tool_call_final_result: ToolCallFinalResult
+
+
+class ToolCallUpdate(BaseModel):
+    tool_call_output: ToolCallOutput | None = None
+
+
+class ToolChoice(BaseModel):
+    tool: Tool
+    tool_args: dict
+    id: str | None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class ToolChoiceUpdate(BaseModel):
+    tool_choice: ToolChoice | None = None
+
+
+class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
+    pass
diff --git a/backend/onyx/agents/agent_search/basic/nodes/tool_call.py b/backend/onyx/agents/agent_search/orchestration/tool_call.py
similarity index 75%
rename from backend/onyx/agents/agent_search/basic/nodes/tool_call.py
rename to backend/onyx/agents/agent_search/orchestration/tool_call.py
index 00f2c629b..bf238ca13 100644
--- a/backend/onyx/agents/agent_search/basic/nodes/tool_call.py
+++ b/backend/onyx/agents/agent_search/orchestration/tool_call.py
@@ -5,9 +5,10 @@ from langchain_core.messages import AIMessageChunk
 from langchain_core.messages.tool import ToolCall
 from langchain_core.runnables.config import RunnableConfig
 
-from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.basic.states import ToolCallUpdate
 from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolCallOutput
+from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
 from onyx.chat.models import AnswerPacket
 from onyx.tools.message import build_tool_message
 from onyx.tools.message import ToolCallSummary
@@ -22,23 +23,18 @@ def emit_packet(packet: AnswerPacket) -> None:
     dispatch_custom_event("basic_response", packet)
 
 
-# TODO: handle is_cancelled
-def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
+def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
     """Calls the tool specified in the state and updates the state with the result"""
 
-    # TODO: implement
     cast(AgentSearchConfig, config["metadata"]["config"])
 
-    # Unnecessary now, node should only be called if there is a tool call
-    # if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
-    #     return
-
-    tool_choice = state["tool_choice"]
+    tool_choice = state.tool_choice
     if tool_choice is None:
         raise ValueError("Cannot invoke tool call node without a tool choice")
-    tool = tool_choice["tool"]
-    tool_args = tool_choice["tool_args"]
-    tool_id = tool_choice["id"]
+    tool = tool_choice.tool
+    tool_args = tool_choice.tool_args
+    tool_id = tool_choice.id
 
     tool_runner = ToolRunner(tool, tool_args)
     tool_kickoff = tool_runner.kickoff()
@@ -61,9 +57,10 @@ def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
         ),
     )
 
-    return ToolCallUpdate(
+    tool_call_output = ToolCallOutput(
         tool_call_summary=tool_call_summary,
         tool_call_kickoff=tool_kickoff,
         tool_call_responses=tool_responses,
         tool_call_final_result=tool_final_result,
     )
+    return ToolCallUpdate(tool_call_output=tool_call_output)
diff --git a/backend/onyx/agents/agent_search/run_graph.py b/backend/onyx/agents/agent_search/run_graph.py
index 6f460a30a..c94f8e984 100644
--- a/backend/onyx/agents/agent_search/run_graph.py
+++ b/backend/onyx/agents/agent_search/run_graph.py
@@ -122,8 +122,11 @@ def _manage_async_event_streaming(
             except (StopAsyncIteration, GeneratorExit):
                 break
             finally:
-                for task in task_references.pop():
-                    task.cancel()
+                try:
+                    for task in task_references.pop():
+                        task.cancel()
+                except StopAsyncIteration:
+                    pass
         loop.close()
 
     return _yield_async_to_sync()
@@ -186,9 +189,7 @@ def run_basic_graph(
     graph = basic_graph_builder()
     compiled_graph = graph.compile()
     # TODO: unify basic input
-    input = BasicInput(
-        should_stream_answer=True,
-    )
+    input = BasicInput()
     return run_graph(compiled_graph, config, input)
 
diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py
index c48c33a5b..ae62e0eec 100644
--- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py
+++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py
@@ -228,6 +228,7 @@ def get_test_config(
         message_id=1,
         use_persistence=True,
         db_session=db_session,
+        tools=[search_tool],
     )
 
     return config, search_tool
diff --git a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
index 0ce876e20..c7cdec8f9 100644
--- a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
+++ b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
@@ -4,6 +4,7 @@ from typing import cast
 from langchain_core.messages import BaseMessage
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import SystemMessage
+from pydantic import BaseModel
 from pydantic.v1 import BaseModel as BaseModel__v1
 
 from onyx.chat.models import PromptConfig
@@ -182,6 +183,13 @@ class AnswerPromptBuilder:
         )
 
 
+# Stores some parts of a prompt builder as needed for tool calls
+class PromptSnapshot(BaseModel):
+    raw_message_history: list[PreviousMessage]
+    raw_user_query: str
+    built_prompt: list[BaseMessage]
+
+
 # TODO: rename this? AnswerConfig maybe?
 class LLMCall(BaseModel__v1):
     prompt_builder: AnswerPromptBuilder
diff --git a/backend/onyx/chat/stream_processing/answer_response_handler.py b/backend/onyx/chat/stream_processing/answer_response_handler.py
index 9c90c2e22..055011ae3 100644
--- a/backend/onyx/chat/stream_processing/answer_response_handler.py
+++ b/backend/onyx/chat/stream_processing/answer_response_handler.py
@@ -5,10 +5,10 @@ from typing import cast
 
 from langchain_core.messages import BaseMessage
 
-from onyx.chat.llm_response_handler import ResponsePart
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import LlmDoc
 from onyx.chat.models import OnyxAnswerPiece
+from onyx.chat.models import ResponsePart
 from onyx.chat.stream_processing.citation_processing import CitationProcessor
 from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
 from onyx.utils.logger import setup_logger
diff --git a/backend/onyx/chat/tool_handling/tool_response_handler.py b/backend/onyx/chat/tool_handling/tool_response_handler.py
index 8a6cce953..21f4830aa 100644
--- a/backend/onyx/chat/tool_handling/tool_response_handler.py
+++ b/backend/onyx/chat/tool_handling/tool_response_handler.py
@@ -7,6 +7,7 @@ from langchain_core.messages import ToolCall
 from onyx.chat.models import ResponsePart
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
+from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
 from onyx.llm.interfaces import LLM
 from onyx.tools.force import ForceUseTool
 from onyx.tools.message import build_tool_message
@@ -158,7 +159,7 @@ class ToolResponseHandler:
 def get_tool_call_for_non_tool_calling_llm_impl(
     force_use_tool: ForceUseTool,
     tools: list[Tool],
-    prompt_builder: AnswerPromptBuilder,
+    prompt_builder: AnswerPromptBuilder | PromptSnapshot,
     llm: LLM,
 ) -> tuple[Tool, dict] | None:
     if force_use_tool.force_use:
diff --git a/backend/onyx/tools/message.py b/backend/onyx/tools/message.py
index 659f38731..bb71d56be 100644
--- a/backend/onyx/tools/message.py
+++ b/backend/onyx/tools/message.py
@@ -4,7 +4,7 @@ from typing import Any
 from langchain_core.messages.ai import AIMessage
 from langchain_core.messages.tool import ToolCall
 from langchain_core.messages.tool import ToolMessage
-from pydantic.v1 import BaseModel as BaseModel__v1
+from pydantic import BaseModel
 
 from onyx.natural_language_processing.utils import BaseTokenizer
 
@@ -21,7 +21,8 @@ def build_tool_message(
     )
 
 
-class ToolCallSummary(BaseModel__v1):
+# TODO: does this NEED to be BaseModel__v1?
+class ToolCallSummary(BaseModel):
     tool_call_request: AIMessage
     tool_call_result: ToolMessage