Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-02 08:58:11 +02:00)
WIP, but working basic search using initial tool choice node
parent 4b0a4a2741
commit 982040c792
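In short: the basic graph gains a prepare_tool_input node in front of llm_tool_choice, and the tool-choice / tool-call nodes move to a shared orchestration package with Pydantic state models. A rough sketch of the wiring this commit sets up (node and edge names are taken from the hunks below; the builder boilerplate around them is abbreviated and partly assumed):

from langgraph.graph import END, START, StateGraph

from onyx.agents.agent_search.basic.nodes.prepare_tool_input import prepare_tool_input
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.tool_call import tool_call


def should_continue(state: BasicState) -> str:
    # No tool chosen -> the answer was already streamed, so end the run.
    return END if state.tool_choice is None else "tool_call"


graph = StateGraph(BasicState)  # assumed constructor call; not shown in the diff
graph.add_node(node="prepare_tool_input", action=prepare_tool_input)
graph.add_node(node="llm_tool_choice", action=llm_tool_choice)
graph.add_node(node="tool_call", action=tool_call)

graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])

The basic_use_tool_response node is also updated (it now reads the tool output from state.tool_call_output), but its wiring is outside the hunks shown here.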
@@ -5,11 +5,12 @@ from langgraph.graph import StateGraph
 from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
     basic_use_tool_response,
 )
-from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
-from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
+from onyx.agents.agent_search.basic.nodes.prepare_tool_input import prepare_tool_input
 from onyx.agents.agent_search.basic.states import BasicInput
 from onyx.agents.agent_search.basic.states import BasicOutput
 from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.orchestration.llm_tool_choice import llm_tool_choice
+from onyx.agents.agent_search.orchestration.tool_call import tool_call
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -24,6 +25,11 @@ def basic_graph_builder() -> StateGraph:

     ### Add nodes ###

+    graph.add_node(
+        node="prepare_tool_input",
+        action=prepare_tool_input,
+    )
+
     graph.add_node(
         node="llm_tool_choice",
         action=llm_tool_choice,
@@ -41,7 +47,9 @@ def basic_graph_builder() -> StateGraph:

     ### Add edges ###

-    graph.add_edge(start_key=START, end_key="llm_tool_choice")
+    graph.add_edge(start_key=START, end_key="prepare_tool_input")
+
+    graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")

     graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])

@@ -62,10 +70,27 @@ def should_continue(state: BasicState) -> str:
     return (
         # If there are no tool calls, basic graph already streamed the answer
         END
-        if state["tool_choice"] is None
+        if state.tool_choice is None
         else "tool_call"
     )


 if __name__ == "__main__":
-    pass
+    from onyx.db.engine import get_session_context_manager
+    from onyx.context.search.models import SearchRequest
+    from onyx.llm.factory import get_default_llms
+    from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
+
+    graph = basic_graph_builder()
+    compiled_graph = graph.compile()
+    # TODO: unify basic input
+    input = BasicInput(logs="")
+    primary_llm, fast_llm = get_default_llms()
+    with get_session_context_manager() as db_session:
+        config, _ = get_test_config(
+            db_session=db_session,
+            primary_llm=primary_llm,
+            fast_llm=fast_llm,
+            search_request=SearchRequest(query="How does onyx use FastAPI?"),
+        )
+        compiled_graph.invoke(input, config={"metadata": {"config": config}})
@@ -13,20 +13,26 @@ from onyx.tools.tool_implementations.search.search_tool import (
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+

 def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     structured_response_format = agent_config.structured_response_format
     llm = agent_config.primary_llm
-    tool_choice = state["tool_choice"]
+    tool_choice = state.tool_choice
     if tool_choice is None:
         raise ValueError("Tool choice is None")
-    tool = tool_choice["tool"]
+    tool = tool_choice.tool
     prompt_builder = agent_config.prompt_builder
-    tool_call_summary = state["tool_call_summary"]
-    tool_call_responses = state["tool_call_responses"]
-    state["tool_call_final_result"]
+    if state.tool_call_output is None:
+        raise ValueError("Tool call output is None")
+    tool_call_output = state.tool_call_output
+    tool_call_summary = tool_call_output.tool_call_summary
+    tool_call_responses = tool_call_output.tool_call_responses

     new_prompt_builder = tool.build_next_prompt(
         prompt_builder=prompt_builder,
         tool_call_summary=tool_call_summary,
@@ -53,11 +59,11 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
     )

     # For now, we don't do multiple tool calls, so we ignore the tool_message
-    process_llm_stream(
+    new_tool_call_chunk = process_llm_stream(
         stream,
         True,
         final_search_results=final_search_results,
         displayed_search_results=initial_search_results,
     )

-    return BasicOutput()
+    return BasicOutput(tool_call_chunk=new_tool_call_chunk)
@@ -0,0 +1,15 @@
+from typing import cast
+
+from langchain_core.runnables.config import RunnableConfig
+
+from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
+
+
+def prepare_tool_input(state: BasicState, config: RunnableConfig) -> ToolChoiceInput:
+    cast(AgentSearchConfig, config["metadata"]["config"])
+    return ToolChoiceInput(
+        should_stream_answer=True,
+        prompt_snapshot=None,  # uses default prompt builder
+    )
@@ -1,10 +1,11 @@
 from typing import TypedDict

-from onyx.tools.message import ToolCallSummary
-from onyx.tools.models import ToolCallFinalResult
-from onyx.tools.models import ToolCallKickoff
-from onyx.tools.models import ToolResponse
-from onyx.tools.tool import Tool
+from langchain_core.messages import AIMessageChunk
+from pydantic import BaseModel
+
+from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
+from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate

 # States contain values that change over the course of graph execution,
 # Config is for values that are set at the start and never change.
@@ -14,33 +15,19 @@ from onyx.tools.tool import Tool
 ## Graph Input State


-class BasicInput(TypedDict):
-    should_stream_answer: bool
+class BasicInput(BaseModel):
+    # TODO: subclass global log update state
+    logs: str = ""


 ## Graph Output State


 class BasicOutput(TypedDict):
-    pass
+    tool_call_chunk: AIMessageChunk


-## Update States
-class ToolCallUpdate(TypedDict):
-    tool_call_summary: ToolCallSummary
-    tool_call_kickoff: ToolCallKickoff
-    tool_call_responses: list[ToolResponse]
-    tool_call_final_result: ToolCallFinalResult
-
-
-class ToolChoice(TypedDict):
-    tool: Tool
-    tool_args: dict
-    id: str | None
-
-
-class ToolChoiceUpdate(TypedDict):
-    tool_choice: ToolChoice | None


 ## Graph State
@@ -48,8 +35,8 @@ class ToolChoiceUpdate(TypedDict):

 class BasicState(
-    BasicInput,
+    ToolChoiceInput,
     ToolCallUpdate,
     ToolChoiceUpdate,
     BasicOutput,
 ):
     pass
@@ -43,6 +43,7 @@ def process_llm_stream(
     else:
         answer_handler = PassThroughAnswerResponseHandler()

+    full_answer = ""
     # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
     # the stream will contain AIMessageChunks with tool call information.
     for response in stream:
@@ -51,6 +52,7 @@ def process_llm_stream(
             # TODO: handle non-string content
             logger.warning(f"Received non-string content: {type(answer_piece)}")
             answer_piece = str(answer_piece)
+        full_answer += answer_piece

         if isinstance(response, AIMessageChunk) and (
             response.tool_call_chunks or response.tool_calls
@@ -64,4 +66,5 @@ def process_llm_stream(
                 response_part,
             )

+    logger.info(f"Full answer: {full_answer}")
     return cast(AIMessageChunk, tool_call_chunk)
@@ -4,11 +4,12 @@ from uuid import uuid4
 from langchain_core.messages import ToolCall
 from langchain_core.runnables.config import RunnableConfig

-from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.basic.states import ToolChoice
-from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolChoice
+from onyx.agents.agent_search.orchestration.states import ToolChoiceState
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.chat.tool_handling.tool_response_handler import (
     get_tool_call_for_non_tool_calling_llm_impl,
@@ -23,16 +24,17 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
+# TODO: fan-out to multiple tool call nodes? Make this configurable?
-def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
+def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
     """
     This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
     The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
     """
-    should_stream_answer = state["should_stream_answer"]
+    should_stream_answer = state.should_stream_answer

     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     using_tool_calling_llm = agent_config.using_tool_calling_llm
-    prompt_builder = agent_config.prompt_builder
+    prompt_builder = state.prompt_snapshot or agent_config.prompt_builder

     llm = agent_config.primary_llm
     skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation

@@ -78,12 +80,17 @@ def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
             tool_choice=None,
         )

+    built_prompt = (
+        prompt_builder.build()
+        if isinstance(prompt_builder, AnswerPromptBuilder)
+        else prompt_builder.built_prompt
+    )
     # At this point, we are either using a tool calling LLM or we are skipping the tool call.
     # DEBUG: good breakpoint
     stream = llm.stream(
         # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
         # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
-        prompt=prompt_builder.build(),
+        prompt=built_prompt,
         tools=[tool.tool_definition() for tool in tools] or None,
         tool_choice=("required" if tools and force_use_tool.force_use else None),
         structured_response_format=structured_response_format,
@@ -93,6 +100,7 @@ def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:

     # If no tool calls are emitted by the LLM, we should not choose a tool
     if len(tool_message.tool_calls) == 0:
+        logger.info("No tool calls emitted by LLM")
         return ToolChoiceUpdate(
             tool_choice=None,
         )
backend/onyx/agents/agent_search/orchestration/states.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+from pydantic import BaseModel
+
+from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
+from onyx.tools.message import ToolCallSummary
+from onyx.tools.models import ToolCallFinalResult
+from onyx.tools.models import ToolCallKickoff
+from onyx.tools.models import ToolResponse
+from onyx.tools.tool import Tool
+
+
+class ToolChoiceInput(BaseModel):
+    should_stream_answer: bool = True
+    # default to the prompt builder from the config, but
+    # allow overrides for arbitrary tool calls
+    prompt_snapshot: PromptSnapshot | None = None
+
+
+class ToolCallOutput(BaseModel):
+    tool_call_summary: ToolCallSummary
+    tool_call_kickoff: ToolCallKickoff
+    tool_call_responses: list[ToolResponse]
+    tool_call_final_result: ToolCallFinalResult
+
+
+class ToolCallUpdate(BaseModel):
+    tool_call_output: ToolCallOutput | None = None
+
+
+class ToolChoice(BaseModel):
+    tool: Tool
+    tool_args: dict
+    id: str | None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class ToolChoiceUpdate(BaseModel):
+    tool_choice: ToolChoice | None = None
+
+
+class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
+    pass
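The orchestration states above are plain Pydantic models: nodes return the small *Update objects (ToolChoiceUpdate, ToolCallUpdate) and the graph merges them into the combined state, with ToolChoiceState bundling the tool-choice input and its update. A minimal, illustrative sketch of how they compose (not code from the commit):

from onyx.agents.agent_search.orchestration.states import (
    ToolChoiceInput,
    ToolChoiceState,
    ToolChoiceUpdate,
)

# prepare_tool_input produces the input half...
choice_input = ToolChoiceInput(should_stream_answer=True, prompt_snapshot=None)
# ...and llm_tool_choice returns an update; tool_choice=None means "no tool picked".
choice_update = ToolChoiceUpdate(tool_choice=None)

state = ToolChoiceState(**choice_input.model_dump(), **choice_update.model_dump())
assert state.tool_choice is None  # should_continue() would route this run to END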
@@ -5,9 +5,10 @@ from langchain_core.messages import AIMessageChunk
 from langchain_core.messages.tool import ToolCall
 from langchain_core.runnables.config import RunnableConfig

-from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.basic.states import ToolCallUpdate
 from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolCallOutput
+from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
 from onyx.chat.models import AnswerPacket
 from onyx.tools.message import build_tool_message
 from onyx.tools.message import ToolCallSummary
@@ -22,23 +23,18 @@ def emit_packet(packet: AnswerPacket) -> None:
     dispatch_custom_event("basic_response", packet)


 # TODO: handle is_cancelled
-def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
+def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
     """Calls the tool specified in the state and updates the state with the result"""
-    # TODO: implement
-
     cast(AgentSearchConfig, config["metadata"]["config"])
-    # Unnecessary now, node should only be called if there is a tool call
-    # if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
-    #     return

-    tool_choice = state["tool_choice"]
+    tool_choice = state.tool_choice
     if tool_choice is None:
         raise ValueError("Cannot invoke tool call node without a tool choice")

-    tool = tool_choice["tool"]
-    tool_args = tool_choice["tool_args"]
-    tool_id = tool_choice["id"]
+    tool = tool_choice.tool
+    tool_args = tool_choice.tool_args
+    tool_id = tool_choice.id
     tool_runner = ToolRunner(tool, tool_args)
     tool_kickoff = tool_runner.kickoff()

@@ -61,9 +57,10 @@ def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
         ),
     )

-    return ToolCallUpdate(
+    tool_call_output = ToolCallOutput(
         tool_call_summary=tool_call_summary,
         tool_call_kickoff=tool_kickoff,
         tool_call_responses=tool_responses,
         tool_call_final_result=tool_final_result,
     )
+    return ToolCallUpdate(tool_call_output=tool_call_output)
@@ -122,8 +122,11 @@ def _manage_async_event_streaming(
             except (StopAsyncIteration, GeneratorExit):
                 break
             finally:
-                for task in task_references.pop():
-                    task.cancel()
+                try:
+                    for task in task_references.pop():
+                        task.cancel()
+                except StopAsyncIteration:
+                    pass
         loop.close()

     return _yield_async_to_sync()
@@ -186,9 +189,7 @@ def run_basic_graph(
     graph = basic_graph_builder()
     compiled_graph = graph.compile()
     # TODO: unify basic input
-    input = BasicInput(
-        should_stream_answer=True,
-    )
+    input = BasicInput()
     return run_graph(compiled_graph, config, input)


@@ -228,6 +228,7 @@ def get_test_config(
         message_id=1,
         use_persistence=True,
         db_session=db_session,
+        tools=[search_tool],
     )

     return config, search_tool
@@ -4,6 +4,7 @@ from typing import cast
 from langchain_core.messages import BaseMessage
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import SystemMessage
+from pydantic import BaseModel
 from pydantic.v1 import BaseModel as BaseModel__v1

 from onyx.chat.models import PromptConfig
@@ -182,6 +183,13 @@ class AnswerPromptBuilder:
         )


+# Stores some parts of a prompt builder as needed for tool calls
+class PromptSnapshot(BaseModel):
+    raw_message_history: list[PreviousMessage]
+    raw_user_query: str
+    built_prompt: list[BaseMessage]
+
+
 # TODO: rename this? AnswerConfig maybe?
 class LLMCall(BaseModel__v1):
     prompt_builder: AnswerPromptBuilder
@@ -5,10 +5,10 @@ from typing import cast

 from langchain_core.messages import BaseMessage

-from onyx.chat.llm_response_handler import ResponsePart
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import LlmDoc
 from onyx.chat.models import OnyxAnswerPiece
+from onyx.chat.models import ResponsePart
 from onyx.chat.stream_processing.citation_processing import CitationProcessor
 from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
 from onyx.utils.logger import setup_logger
@@ -7,6 +7,7 @@ from langchain_core.messages import ToolCall
 from onyx.chat.models import ResponsePart
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
+from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
 from onyx.llm.interfaces import LLM
 from onyx.tools.force import ForceUseTool
 from onyx.tools.message import build_tool_message
@@ -158,7 +159,7 @@ class ToolResponseHandler:
 def get_tool_call_for_non_tool_calling_llm_impl(
     force_use_tool: ForceUseTool,
     tools: list[Tool],
-    prompt_builder: AnswerPromptBuilder,
+    prompt_builder: AnswerPromptBuilder | PromptSnapshot,
     llm: LLM,
 ) -> tuple[Tool, dict] | None:
     if force_use_tool.force_use:
@@ -4,7 +4,7 @@ from typing import Any
 from langchain_core.messages.ai import AIMessage
 from langchain_core.messages.tool import ToolCall
 from langchain_core.messages.tool import ToolMessage
-from pydantic.v1 import BaseModel as BaseModel__v1
+from pydantic import BaseModel

 from onyx.natural_language_processing.utils import BaseTokenizer

@@ -21,7 +21,8 @@ def build_tool_message(
     )


-class ToolCallSummary(BaseModel__v1):
+# TODO: does this NEED to be BaseModel__v1?
+class ToolCallSummary(BaseModel):
     tool_call_request: AIMessage
     tool_call_result: ToolMessage
