WIP, but basic search is working using the initial tool choice node

This commit is contained in:
Evan Lohn 2025-01-23 16:20:17 -08:00
parent 4b0a4a2741
commit 982040c792
14 changed files with 161 additions and 65 deletions

View File

@@ -5,11 +5,12 @@ from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
from onyx.agents.agent_search.basic.nodes.prepare_tool_input import prepare_tool_input
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.tool_call import tool_call
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -24,6 +25,11 @@ def basic_graph_builder() -> StateGraph:
### Add nodes ###
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
graph.add_node(
node="llm_tool_choice",
action=llm_tool_choice,
@@ -41,7 +47,9 @@ def basic_graph_builder() -> StateGraph:
### Add edges ###
graph.add_edge(start_key=START, end_key="llm_tool_choice")
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
@@ -62,10 +70,27 @@ def should_continue(state: BasicState) -> str:
return (
# If there are no tool calls, basic graph already streamed the answer
END
if state["tool_choice"] is None
if state.tool_choice is None
else "tool_call"
)
if __name__ == "__main__":
pass
from onyx.db.engine import get_session_context_manager
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
graph = basic_graph_builder()
compiled_graph = graph.compile()
# TODO: unify basic input
input = BasicInput(logs="")
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_request=SearchRequest(query="How does onyx use FastAPI?"),
)
compiled_graph.invoke(input, config={"metadata": {"config": config}})
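
For orientation, a minimal self-contained sketch of the routing this hunk sets up (the state and node bodies are stand-ins, and the edge out of tool_call is an assumption, since the downstream wiring to basic_use_tool_response is not shown in this hunk):

from typing import TypedDict
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

class _SketchState(TypedDict):
    tool_choice: str | None

def _prepare_tool_input(state: _SketchState) -> _SketchState:
    return state

def _llm_tool_choice(state: _SketchState) -> _SketchState:
    # stand-in: a real run would ask the LLM; here we pretend it picked a tool
    return {"tool_choice": "search"}

def _tool_call(state: _SketchState) -> _SketchState:
    return state

def _should_continue(state: _SketchState) -> str:
    # mirrors should_continue above: no tool chosen means the answer was already streamed
    return END if state["tool_choice"] is None else "tool_call"

sketch = StateGraph(_SketchState)
sketch.add_node("prepare_tool_input", _prepare_tool_input)
sketch.add_node("llm_tool_choice", _llm_tool_choice)
sketch.add_node("tool_call", _tool_call)
sketch.add_edge(START, "prepare_tool_input")
sketch.add_edge("prepare_tool_input", "llm_tool_choice")
sketch.add_conditional_edges("llm_tool_choice", _should_continue, ["tool_call", END])
sketch.add_edge("tool_call", END)  # assumption: the real graph continues to basic_use_tool_response
sketch.compile().invoke({"tool_choice": None})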

View File

@@ -13,20 +13,26 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
structured_response_format = agent_config.structured_response_format
llm = agent_config.primary_llm
tool_choice = state["tool_choice"]
tool_choice = state.tool_choice
if tool_choice is None:
raise ValueError("Tool choice is None")
tool = tool_choice["tool"]
tool = tool_choice.tool
prompt_builder = agent_config.prompt_builder
tool_call_summary = state["tool_call_summary"]
tool_call_responses = state["tool_call_responses"]
state["tool_call_final_result"]
if state.tool_call_output is None:
raise ValueError("Tool call output is None")
tool_call_output = state.tool_call_output
tool_call_summary = tool_call_output.tool_call_summary
tool_call_responses = tool_call_output.tool_call_responses
new_prompt_builder = tool.build_next_prompt(
prompt_builder=prompt_builder,
tool_call_summary=tool_call_summary,
@@ -53,11 +59,11 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO
)
# For now, we don't do multiple tool calls, so we ignore the tool_message
process_llm_stream(
new_tool_call_chunk = process_llm_stream(
stream,
True,
final_search_results=final_search_results,
displayed_search_results=initial_search_results,
)
return BasicOutput()
return BasicOutput(tool_call_chunk=new_tool_call_chunk)

View File

@@ -0,0 +1,15 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
def prepare_tool_input(state: BasicState, config: RunnableConfig) -> ToolChoiceInput:
cast(AgentSearchConfig, config["metadata"]["config"])
return ToolChoiceInput(
should_stream_answer=True,
prompt_snapshot=None, # uses default prompt builder
)

View File

@@ -1,10 +1,11 @@
from typing import TypedDict
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
@@ -14,33 +15,19 @@ from onyx.tools.tool import Tool
## Graph Input State
class BasicInput(TypedDict):
should_stream_answer: bool
class BasicInput(BaseModel):
# TODO: subclass global log update state
logs: str = ""
## Graph Output State
class BasicOutput(TypedDict):
pass
tool_call_chunk: AIMessageChunk
## Update States
class ToolCallUpdate(TypedDict):
tool_call_summary: ToolCallSummary
tool_call_kickoff: ToolCallKickoff
tool_call_responses: list[ToolResponse]
tool_call_final_result: ToolCallFinalResult
class ToolChoice(TypedDict):
tool: Tool
tool_args: dict
id: str | None
class ToolChoiceUpdate(TypedDict):
tool_choice: ToolChoice | None
## Graph State
@@ -48,8 +35,8 @@ class ToolChoiceUpdate(TypedDict):
class BasicState(
BasicInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
BasicOutput,
):
pass
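
The states in this file are now Pydantic models composed by multiple inheritance, so each node reads and writes its own slice of BasicState via attributes. A toy sketch of that composition pattern (class and field names invented, not from this commit):

from pydantic import BaseModel

class LogsPart(BaseModel):
    logs: str = ""

class ChoicePart(BaseModel):
    chosen_tool: str | None = None

class ComposedState(ChoicePart, LogsPart):
    pass

state = ComposedState(logs="started", chosen_tool="search")
print(state.chosen_tool)  # attribute access replaces the old state["..."] lookups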

View File

@@ -43,6 +43,7 @@ def process_llm_stream(
else:
answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for response in stream:
@@ -51,6 +52,7 @@
# TODO: handle non-string content
logger.warning(f"Received non-string content: {type(answer_piece)}")
answer_piece = str(answer_piece)
full_answer += answer_piece
if isinstance(response, AIMessageChunk) and (
response.tool_call_chunks or response.tool_calls
@@ -64,4 +66,5 @@
response_part,
)
logger.info(f"Full answer: {full_answer}")
return cast(AIMessageChunk, tool_call_chunk)
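
process_llm_stream now returns the accumulated tool-call chunk so the caller can inspect which tool, if any, the LLM asked for. A small sketch of how AIMessageChunk accumulation behaves (toy chunk data, not from this commit):

from langchain_core.messages import AIMessageChunk

chunks = [
    AIMessageChunk(content="", tool_call_chunks=[
        {"name": "run_search", "args": '{"query": ', "id": "call_1", "index": 0}
    ]),
    AIMessageChunk(content="", tool_call_chunks=[
        {"name": None, "args": '"FastAPI"}', "id": None, "index": 0}
    ]),
]
merged = chunks[0]
for chunk in chunks[1:]:
    merged = merged + chunk  # chunk addition stitches partial tool-call args together by index
# merged.tool_calls should now contain a single run_search call with args {"query": "FastAPI"}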

View File

@@ -4,11 +4,12 @@ from uuid import uuid4
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolChoice
from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
@@ -23,16 +24,17 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
"""
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
the node MAY emit an answer, depending on whether state.should_stream_answer is set.
"""
should_stream_answer = state["should_stream_answer"]
should_stream_answer = state.should_stream_answer
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
using_tool_calling_llm = agent_config.using_tool_calling_llm
prompt_builder = agent_config.prompt_builder
prompt_builder = state.prompt_snapshot or agent_config.prompt_builder
llm = agent_config.primary_llm
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
@@ -78,12 +80,17 @@ def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpda
tool_choice=None,
)
built_prompt = (
prompt_builder.build()
if isinstance(prompt_builder, AnswerPromptBuilder)
else prompt_builder.built_prompt
)
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
# DEBUG: good breakpoint
stream = llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=prompt_builder.build(),
prompt=built_prompt,
tools=[tool.tool_definition() for tool in tools] or None,
tool_choice=("required" if tools and force_use_tool.force_use else None),
structured_response_format=structured_response_format,
@@ -93,6 +100,7 @@ def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpda
# If no tool calls are emitted by the LLM, we should not choose a tool
if len(tool_message.tool_calls) == 0:
logger.info("No tool calls emitted by LLM")
return ToolChoiceUpdate(
tool_choice=None,
)
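
The new built_prompt expression dispatches on which prompt source the node received: a live AnswerPromptBuilder is built on demand, while a PromptSnapshot already carries its built prompt. A stand-alone sketch of the same idea with toy classes (not the real builders):

class _LiveBuilder:
    """Stands in for AnswerPromptBuilder: the prompt is assembled on demand."""

    def build(self) -> list[str]:
        return ["system prompt", "task prompt", "user question"]

class _FrozenSnapshot:
    """Stands in for PromptSnapshot: the prompt was already built and frozen."""

    built_prompt = ["system prompt", "task prompt", "user question"]

def _resolve_prompt(builder: "_LiveBuilder | _FrozenSnapshot") -> list[str]:
    # mirrors the isinstance dispatch above
    return builder.build() if isinstance(builder, _LiveBuilder) else builder.built_prompt

assert _resolve_prompt(_LiveBuilder()) == _resolve_prompt(_FrozenSnapshot())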

View File

@@ -0,0 +1,43 @@
from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
class ToolChoiceInput(BaseModel):
should_stream_answer: bool = True
# default to the prompt builder from the config, but
# allow overrides for arbitrary tool calls
prompt_snapshot: PromptSnapshot | None = None
class ToolCallOutput(BaseModel):
tool_call_summary: ToolCallSummary
tool_call_kickoff: ToolCallKickoff
tool_call_responses: list[ToolResponse]
tool_call_final_result: ToolCallFinalResult
class ToolCallUpdate(BaseModel):
tool_call_output: ToolCallOutput | None = None
class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
class Config:
arbitrary_types_allowed = True
class ToolChoiceUpdate(BaseModel):
tool_choice: ToolChoice | None = None
class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
pass
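
ToolChoice sets arbitrary_types_allowed because Tool is a plain Python class rather than a Pydantic model. A minimal, self-contained sketch of that Pydantic pattern (toy class names, not from this commit):

from pydantic import BaseModel

class Widget:
    """A non-Pydantic class, standing in for Tool."""

class Holder(BaseModel):
    item: Widget

    class Config:
        # without this, Pydantic rejects the Widget annotation at class-definition time
        arbitrary_types_allowed = True

Holder(item=Widget())  # validation reduces to an isinstance check for arbitrary types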

View File

@@ -5,9 +5,10 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolCallUpdate
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.chat.models import AnswerPacket
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
@@ -22,23 +23,18 @@ def emit_packet(packet: AnswerPacket) -> None:
dispatch_custom_event("basic_response", packet)
# TODO: handle is_cancelled
def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
"""Calls the tool specified in the state and updates the state with the result"""
# TODO: implement
cast(AgentSearchConfig, config["metadata"]["config"])
# Unnecessary now; this node should only be called if there is a tool call
# if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
# return
tool_choice = state["tool_choice"]
tool_choice = state.tool_choice
if tool_choice is None:
raise ValueError("Cannot invoke tool call node without a tool choice")
tool = tool_choice["tool"]
tool_args = tool_choice["tool_args"]
tool_id = tool_choice["id"]
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
@@ -61,9 +57,10 @@ def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
),
)
return ToolCallUpdate(
tool_call_output = ToolCallOutput(
tool_call_summary=tool_call_summary,
tool_call_kickoff=tool_kickoff,
tool_call_responses=tool_responses,
tool_call_final_result=tool_final_result,
)
return ToolCallUpdate(tool_call_output=tool_call_output)

View File

@@ -122,8 +122,11 @@ def _manage_async_event_streaming(
except (StopAsyncIteration, GeneratorExit):
break
finally:
for task in task_references.pop():
task.cancel()
try:
for task in task_references.pop():
task.cancel()
except StopAsyncIteration:
pass
loop.close()
return _yield_async_to_sync()
@@ -186,9 +189,7 @@ def run_basic_graph(
graph = basic_graph_builder()
compiled_graph = graph.compile()
# TODO: unify basic input
input = BasicInput(
should_stream_answer=True,
)
input = BasicInput()
return run_graph(compiled_graph, config, input)

View File

@@ -228,6 +228,7 @@ def get_test_config(
message_id=1,
use_persistence=True,
db_session=db_session,
tools=[search_tool],
)
return config, search_tool

View File

@@ -4,6 +4,7 @@ from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModel__v1
from onyx.chat.models import PromptConfig
@@ -182,6 +183,13 @@ class AnswerPromptBuilder:
)
# Stores some parts of a prompt builder as needed for tool calls
class PromptSnapshot(BaseModel):
raw_message_history: list[PreviousMessage]
raw_user_query: str
built_prompt: list[BaseMessage]
# TODO: rename this? AnswerConfig maybe?
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
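
PromptSnapshot freezes just the pieces of AnswerPromptBuilder that tool-call flows need. A hedged sketch of constructing one directly (the field values are made up, and any helper for snapshotting an existing builder is not shown in this diff):

from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot

snapshot = PromptSnapshot(
    raw_message_history=[],
    raw_user_query="How does onyx use FastAPI?",
    built_prompt=[
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="How does onyx use FastAPI?"),
    ],
)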

View File

@@ -5,10 +5,10 @@ from typing import cast
from langchain_core.messages import BaseMessage
from onyx.chat.llm_response_handler import ResponsePart
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import ResponsePart
from onyx.chat.stream_processing.citation_processing import CitationProcessor
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
from onyx.utils.logger import setup_logger

View File

@@ -7,6 +7,7 @@ from langchain_core.messages import ToolCall
from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.message import build_tool_message
@@ -158,7 +159,7 @@ class ToolResponseHandler:
def get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool: ForceUseTool,
tools: list[Tool],
prompt_builder: AnswerPromptBuilder,
prompt_builder: AnswerPromptBuilder | PromptSnapshot,
llm: LLM,
) -> tuple[Tool, dict] | None:
if force_use_tool.force_use:

View File

@@ -4,7 +4,7 @@ from typing import Any
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import ToolMessage
from pydantic.v1 import BaseModel as BaseModel__v1
from pydantic import BaseModel
from onyx.natural_language_processing.utils import BaseTokenizer
@@ -21,7 +21,8 @@ def build_tool_message(
)
class ToolCallSummary(BaseModel__v1):
# TODO: does this NEED to be BaseModel__v1?
class ToolCallSummary(BaseModel):
tool_call_request: AIMessage
tool_call_result: ToolMessage