WIP, but basic search is working using an initial tool choice node

Evan Lohn
2025-01-23 16:20:17 -08:00
parent 4b0a4a2741
commit 982040c792
14 changed files with 161 additions and 65 deletions

View File

@@ -5,11 +5,12 @@ from langgraph.graph import StateGraph
 from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
     basic_use_tool_response,
 )
-from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
-from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
+from onyx.agents.agent_search.basic.nodes.prepare_tool_input import prepare_tool_input
 from onyx.agents.agent_search.basic.states import BasicInput
 from onyx.agents.agent_search.basic.states import BasicOutput
 from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.orchestration.llm_tool_choice import llm_tool_choice
+from onyx.agents.agent_search.orchestration.tool_call import tool_call
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -24,6 +25,11 @@ def basic_graph_builder() -> StateGraph:
     ### Add nodes ###

+    graph.add_node(
+        node="prepare_tool_input",
+        action=prepare_tool_input,
+    )
     graph.add_node(
         node="llm_tool_choice",
         action=llm_tool_choice,
@@ -41,7 +47,9 @@ def basic_graph_builder() -> StateGraph:
     ### Add edges ###

-    graph.add_edge(start_key=START, end_key="llm_tool_choice")
+    graph.add_edge(start_key=START, end_key="prepare_tool_input")
+    graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")

     graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
@@ -62,10 +70,27 @@ def should_continue(state: BasicState) -> str:
     return (
         # If there are no tool calls, basic graph already streamed the answer
         END
-        if state["tool_choice"] is None
+        if state.tool_choice is None
         else "tool_call"
     )


 if __name__ == "__main__":
-    pass
+    from onyx.db.engine import get_session_context_manager
+    from onyx.context.search.models import SearchRequest
+    from onyx.llm.factory import get_default_llms
+    from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
+
+    graph = basic_graph_builder()
+    compiled_graph = graph.compile()
+    # TODO: unify basic input
+    input = BasicInput(logs="")
+    primary_llm, fast_llm = get_default_llms()
+    with get_session_context_manager() as db_session:
+        config, _ = get_test_config(
+            db_session=db_session,
+            primary_llm=primary_llm,
+            fast_llm=fast_llm,
+            search_request=SearchRequest(query="How does onyx use FastAPI?"),
+        )
+        compiled_graph.invoke(input, config={"metadata": {"config": config}})
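
Taken together, these hunks reshape the basic graph's wiring. The sketch below consolidates the topology using the same add_node/add_edge calls shown above; the StateGraph construction and the tool_call/basic_use_tool_response edges are assumptions carried over from the pre-existing graph, not part of this diff.

# Sketch only: consolidated wiring implied by the hunks above.
from langgraph.graph import END, START, StateGraph

graph = StateGraph(BasicState)  # assumed constructor; not shown in this diff
graph.add_node(node="prepare_tool_input", action=prepare_tool_input)
graph.add_node(node="llm_tool_choice", action=llm_tool_choice)
graph.add_node(node="tool_call", action=tool_call)
graph.add_node(node="basic_use_tool_response", action=basic_use_tool_response)  # assumed, pre-existing

graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_edge(start_key="tool_call", end_key="basic_use_tool_response")  # assumed, pre-existing
graph.add_edge(start_key="basic_use_tool_response", end_key=END)  # assumed, pre-existing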

View File

@@ -13,20 +13,26 @@ from onyx.tools.tool_implementations.search.search_tool import (
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()


 def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     structured_response_format = agent_config.structured_response_format
     llm = agent_config.primary_llm
-    tool_choice = state["tool_choice"]
+    tool_choice = state.tool_choice
     if tool_choice is None:
         raise ValueError("Tool choice is None")
-    tool = tool_choice["tool"]
+    tool = tool_choice.tool
     prompt_builder = agent_config.prompt_builder
-    tool_call_summary = state["tool_call_summary"]
-    tool_call_responses = state["tool_call_responses"]
-    state["tool_call_final_result"]
+    if state.tool_call_output is None:
+        raise ValueError("Tool call output is None")
+    tool_call_output = state.tool_call_output
+    tool_call_summary = tool_call_output.tool_call_summary
+    tool_call_responses = tool_call_output.tool_call_responses

     new_prompt_builder = tool.build_next_prompt(
         prompt_builder=prompt_builder,
         tool_call_summary=tool_call_summary,
@@ -53,11 +59,11 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
     )

     # For now, we don't do multiple tool calls, so we ignore the tool_message
-    process_llm_stream(
+    new_tool_call_chunk = process_llm_stream(
         stream,
         True,
         final_search_results=final_search_results,
         displayed_search_results=initial_search_results,
     )

-    return BasicOutput()
+    return BasicOutput(tool_call_chunk=new_tool_call_chunk)

View File

@@ -0,0 +1,15 @@
+from typing import cast
+
+from langchain_core.runnables.config import RunnableConfig
+
+from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
+
+
+def prepare_tool_input(state: BasicState, config: RunnableConfig) -> ToolChoiceInput:
+    cast(AgentSearchConfig, config["metadata"]["config"])
+    return ToolChoiceInput(
+        should_stream_answer=True,
+        prompt_snapshot=None,  # uses default prompt builder
+    )

View File

@@ -1,10 +1,11 @@
 from typing import TypedDict

-from onyx.tools.message import ToolCallSummary
-from onyx.tools.models import ToolCallFinalResult
-from onyx.tools.models import ToolCallKickoff
-from onyx.tools.models import ToolResponse
-from onyx.tools.tool import Tool
+from langchain_core.messages import AIMessageChunk
+from pydantic import BaseModel
+
+from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
+from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate

 # States contain values that change over the course of graph execution,
 # Config is for values that are set at the start and never change.
@@ -14,33 +15,19 @@ from onyx.tools.tool import Tool

 ## Graph Input State
-class BasicInput(TypedDict):
-    should_stream_answer: bool
+class BasicInput(BaseModel):
+    # TODO: subclass global log update state
+    logs: str = ""


 ## Graph Output State
 class BasicOutput(TypedDict):
-    pass
+    tool_call_chunk: AIMessageChunk


 ## Update States
-class ToolCallUpdate(TypedDict):
-    tool_call_summary: ToolCallSummary
-    tool_call_kickoff: ToolCallKickoff
-    tool_call_responses: list[ToolResponse]
-    tool_call_final_result: ToolCallFinalResult
-
-
-class ToolChoice(TypedDict):
-    tool: Tool
-    tool_args: dict
-    id: str | None
-
-
-class ToolChoiceUpdate(TypedDict):
-    tool_choice: ToolChoice | None


 ## Graph State
@@ -48,8 +35,8 @@ class ToolChoiceUpdate(TypedDict):
 class BasicState(
     BasicInput,
+    ToolChoiceInput,
     ToolCallUpdate,
     ToolChoiceUpdate,
-    BasicOutput,
 ):
     pass

View File

@@ -43,6 +43,7 @@ def process_llm_stream(
     else:
         answer_handler = PassThroughAnswerResponseHandler()

+    full_answer = ""
     # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
     # the stream will contain AIMessageChunks with tool call information.
     for response in stream:
@@ -51,6 +52,7 @@
             # TODO: handle non-string content
             logger.warning(f"Received non-string content: {type(answer_piece)}")
             answer_piece = str(answer_piece)
+        full_answer += answer_piece

         if isinstance(response, AIMessageChunk) and (
             response.tool_call_chunks or response.tool_calls
@@ -64,4 +66,5 @@
                 response_part,
             )

+    logger.info(f"Full answer: {full_answer}")
     return cast(AIMessageChunk, tool_call_chunk)

View File

@@ -4,11 +4,12 @@ from uuid import uuid4
 from langchain_core.messages import ToolCall
 from langchain_core.runnables.config import RunnableConfig

-from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.basic.states import ToolChoice
-from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolChoice
+from onyx.agents.agent_search.orchestration.states import ToolChoiceState
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.chat.tool_handling.tool_response_handler import (
     get_tool_call_for_non_tool_calling_llm_impl,
@@ -23,16 +24,17 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
-def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
+def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
     """
     This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
     The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
     """
-    should_stream_answer = state["should_stream_answer"]
+    should_stream_answer = state.should_stream_answer

     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])

     using_tool_calling_llm = agent_config.using_tool_calling_llm
-    prompt_builder = agent_config.prompt_builder
+    prompt_builder = state.prompt_snapshot or agent_config.prompt_builder

     llm = agent_config.primary_llm
     skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
@@ -78,12 +80,17 @@ def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
             tool_choice=None,
         )

+    built_prompt = (
+        prompt_builder.build()
+        if isinstance(prompt_builder, AnswerPromptBuilder)
+        else prompt_builder.built_prompt
+    )
     # At this point, we are either using a tool calling LLM or we are skipping the tool call.
     # DEBUG: good breakpoint
     stream = llm.stream(
         # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
         # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
-        prompt=prompt_builder.build(),
+        prompt=built_prompt,
         tools=[tool.tool_definition() for tool in tools] or None,
         tool_choice=("required" if tools and force_use_tool.force_use else None),
         structured_response_format=structured_response_format,
@@ -93,6 +100,7 @@ def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
     # If no tool calls are emitted by the LLM, we should not choose a tool
     if len(tool_message.tool_calls) == 0:
+        logger.info("No tool calls emitted by LLM")
         return ToolChoiceUpdate(
             tool_choice=None,
         )

View File

@@ -0,0 +1,43 @@
+from pydantic import BaseModel
+
+from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
+from onyx.tools.message import ToolCallSummary
+from onyx.tools.models import ToolCallFinalResult
+from onyx.tools.models import ToolCallKickoff
+from onyx.tools.models import ToolResponse
+from onyx.tools.tool import Tool
+
+
+class ToolChoiceInput(BaseModel):
+    should_stream_answer: bool = True
+    # default to the prompt builder from the config, but
+    # allow overrides for arbitrary tool calls
+    prompt_snapshot: PromptSnapshot | None = None
+
+
+class ToolCallOutput(BaseModel):
+    tool_call_summary: ToolCallSummary
+    tool_call_kickoff: ToolCallKickoff
+    tool_call_responses: list[ToolResponse]
+    tool_call_final_result: ToolCallFinalResult
+
+
+class ToolChoice(BaseModel):
+    tool: Tool
+    tool_args: dict
+    id: str | None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class ToolCallUpdate(BaseModel):
+    tool_call_output: ToolCallOutput | None = None
+
+
+class ToolChoiceUpdate(BaseModel):
+    tool_choice: ToolChoice | None = None
+
+
+class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
+    pass
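
A minimal sketch of how these models behave at runtime, assuming the module created in this diff is importable: ToolChoiceInput's defaults are what prepare_tool_input relies on, and ToolChoiceState merges the input fields with the tool_choice update so llm_tool_choice can read both from one state object.

# Minimal sketch (assumes onyx.agents.agent_search.orchestration.states from this diff).
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceState

default_input = ToolChoiceInput()  # defaults used by prepare_tool_input
assert default_input.should_stream_answer is True
assert default_input.prompt_snapshot is None

state = ToolChoiceState(should_stream_answer=False)  # tool_choice defaults to None
assert state.tool_choice is None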

View File

@@ -5,9 +5,10 @@ from langchain_core.messages import AIMessageChunk
 from langchain_core.messages.tool import ToolCall
 from langchain_core.runnables.config import RunnableConfig

-from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.basic.states import ToolCallUpdate
 from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolCallOutput
+from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
 from onyx.chat.models import AnswerPacket
 from onyx.tools.message import build_tool_message
 from onyx.tools.message import ToolCallSummary
@@ -22,23 +23,18 @@ def emit_packet(packet: AnswerPacket) -> None:
     dispatch_custom_event("basic_response", packet)


-# TODO: handle is_cancelled
-def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
+def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
     """Calls the tool specified in the state and updates the state with the result"""
-    # TODO: implement
     cast(AgentSearchConfig, config["metadata"]["config"])

-    # Unnecessary now, node should only be called if there is a tool call
-    # if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
-    #     return
-
-    tool_choice = state["tool_choice"]
+    tool_choice = state.tool_choice
     if tool_choice is None:
         raise ValueError("Cannot invoke tool call node without a tool choice")

-    tool = tool_choice["tool"]
-    tool_args = tool_choice["tool_args"]
-    tool_id = tool_choice["id"]
+    tool = tool_choice.tool
+    tool_args = tool_choice.tool_args
+    tool_id = tool_choice.id

     tool_runner = ToolRunner(tool, tool_args)
     tool_kickoff = tool_runner.kickoff()
@@ -61,9 +57,10 @@ def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
         ),
     )

-    return ToolCallUpdate(
+    tool_call_output = ToolCallOutput(
         tool_call_summary=tool_call_summary,
         tool_call_kickoff=tool_kickoff,
         tool_call_responses=tool_responses,
         tool_call_final_result=tool_final_result,
     )
+    return ToolCallUpdate(tool_call_output=tool_call_output)

View File

@@ -122,8 +122,11 @@ def _manage_async_event_streaming(
             except (StopAsyncIteration, GeneratorExit):
                 break
             finally:
-                for task in task_references.pop():
-                    task.cancel()
+                try:
+                    for task in task_references.pop():
+                        task.cancel()
+                except StopAsyncIteration:
+                    pass
                 loop.close()

     return _yield_async_to_sync()
@@ -186,9 +189,7 @@ def run_basic_graph(
     graph = basic_graph_builder()
     compiled_graph = graph.compile()
     # TODO: unify basic input
-    input = BasicInput(
-        should_stream_answer=True,
-    )
+    input = BasicInput()
     return run_graph(compiled_graph, config, input)

View File

@@ -228,6 +228,7 @@ def get_test_config(
         message_id=1,
         use_persistence=True,
         db_session=db_session,
+        tools=[search_tool],
     )

     return config, search_tool

View File

@@ -4,6 +4,7 @@ from typing import cast
 from langchain_core.messages import BaseMessage
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import SystemMessage
+from pydantic import BaseModel
 from pydantic.v1 import BaseModel as BaseModel__v1

 from onyx.chat.models import PromptConfig
@@ -182,6 +183,13 @@ class AnswerPromptBuilder:
     )


+# Stores some parts of a prompt builder as needed for tool calls
+class PromptSnapshot(BaseModel):
+    raw_message_history: list[PreviousMessage]
+    raw_user_query: str
+    built_prompt: list[BaseMessage]
+
+
 # TODO: rename this? AnswerConfig maybe?
 class LLMCall(BaseModel__v1):
     prompt_builder: AnswerPromptBuilder
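
PromptSnapshot lets a node carry a frozen copy of a prompt instead of the live builder (see ToolChoiceInput.prompt_snapshot above). A hypothetical helper for producing one is sketched below; it assumes the builder exposes matching raw_message_history and raw_user_query attributes, which this diff does not show.

# Hypothetical helper, not part of this commit: freeze an AnswerPromptBuilder
# into a PromptSnapshot for use as ToolChoiceInput.prompt_snapshot.
# Assumes raw_message_history / raw_user_query attributes exist on the builder.
def snapshot_from_builder(builder: AnswerPromptBuilder) -> PromptSnapshot:
    return PromptSnapshot(
        raw_message_history=builder.raw_message_history,
        raw_user_query=builder.raw_user_query,
        built_prompt=builder.build(),
    )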

View File

@@ -5,10 +5,10 @@ from typing import cast
 from langchain_core.messages import BaseMessage

-from onyx.chat.llm_response_handler import ResponsePart
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import LlmDoc
 from onyx.chat.models import OnyxAnswerPiece
+from onyx.chat.models import ResponsePart
 from onyx.chat.stream_processing.citation_processing import CitationProcessor
 from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
 from onyx.utils.logger import setup_logger

View File

@@ -7,6 +7,7 @@ from langchain_core.messages import ToolCall
 from onyx.chat.models import ResponsePart
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
+from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
 from onyx.llm.interfaces import LLM
 from onyx.tools.force import ForceUseTool
 from onyx.tools.message import build_tool_message
@@ -158,7 +159,7 @@ class ToolResponseHandler:
 def get_tool_call_for_non_tool_calling_llm_impl(
     force_use_tool: ForceUseTool,
     tools: list[Tool],
-    prompt_builder: AnswerPromptBuilder,
+    prompt_builder: AnswerPromptBuilder | PromptSnapshot,
     llm: LLM,
 ) -> tuple[Tool, dict] | None:
     if force_use_tool.force_use:

View File

@@ -4,7 +4,7 @@ from typing import Any
 from langchain_core.messages.ai import AIMessage
 from langchain_core.messages.tool import ToolCall
 from langchain_core.messages.tool import ToolMessage
-from pydantic.v1 import BaseModel as BaseModel__v1
+from pydantic import BaseModel

 from onyx.natural_language_processing.utils import BaseTokenizer
@@ -21,7 +21,8 @@ def build_tool_message(
     )


-class ToolCallSummary(BaseModel__v1):
+# TODO: does this NEED to be BaseModel__v1?
+class ToolCallSummary(BaseModel):
     tool_call_request: AIMessage
     tool_call_result: ToolMessage