Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-10-10 13:15:18 +02:00)
WIP, but working basic search using initial tool choice node
@@ -5,11 +5,12 @@ from langgraph.graph import StateGraph
 from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
     basic_use_tool_response,
 )
-from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
-from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
+from onyx.agents.agent_search.basic.nodes.prepare_tool_input import prepare_tool_input
 from onyx.agents.agent_search.basic.states import BasicInput
 from onyx.agents.agent_search.basic.states import BasicOutput
 from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.orchestration.llm_tool_choice import llm_tool_choice
+from onyx.agents.agent_search.orchestration.tool_call import tool_call
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -24,6 +25,11 @@ def basic_graph_builder() -> StateGraph:

     ### Add nodes ###

+    graph.add_node(
+        node="prepare_tool_input",
+        action=prepare_tool_input,
+    )
+
     graph.add_node(
         node="llm_tool_choice",
         action=llm_tool_choice,
@@ -41,7 +47,9 @@ def basic_graph_builder() -> StateGraph:

     ### Add edges ###

-    graph.add_edge(start_key=START, end_key="llm_tool_choice")
+    graph.add_edge(start_key=START, end_key="prepare_tool_input")
+
+    graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")

     graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])

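Note: the rewired flow is START -> prepare_tool_input -> llm_tool_choice, with a conditional hop to tool_call or END. Below is a minimal, self-contained LangGraph sketch of the same shape, using stub nodes and a toy state rather than the onyx implementations; the tool_call -> END edge is an assumption, since it sits outside the hunks shown here.

    # Toy reproduction of the node/edge layout added in this commit (stubs only).
    from typing import TypedDict

    from langgraph.graph import END, START, StateGraph


    class State(TypedDict):
        tool_choice: str | None


    def prepare_tool_input(state: State) -> State:
        return state  # the real node builds a ToolChoiceInput


    def llm_tool_choice(state: State) -> State:
        return {"tool_choice": "search"}  # the real node asks the LLM to pick a tool


    def tool_call(state: State) -> State:
        return state  # the real node runs the chosen tool


    def should_continue(state: State) -> str:
        return END if state["tool_choice"] is None else "tool_call"


    graph = StateGraph(State)
    graph.add_node(node="prepare_tool_input", action=prepare_tool_input)
    graph.add_node(node="llm_tool_choice", action=llm_tool_choice)
    graph.add_node(node="tool_call", action=tool_call)
    graph.add_edge(start_key=START, end_key="prepare_tool_input")
    graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
    graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
    graph.add_edge(start_key="tool_call", end_key=END)  # assumed terminal edge
    print(graph.compile().invoke({"tool_choice": None}))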
@@ -62,10 +70,27 @@ def should_continue(state: BasicState) -> str:
     return (
         # If there are no tool calls, basic graph already streamed the answer
         END
-        if state["tool_choice"] is None
+        if state.tool_choice is None
         else "tool_call"
     )


 if __name__ == "__main__":
-    pass
+    from onyx.db.engine import get_session_context_manager
+    from onyx.context.search.models import SearchRequest
+    from onyx.llm.factory import get_default_llms
+    from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
+
+    graph = basic_graph_builder()
+    compiled_graph = graph.compile()
+    # TODO: unify basic input
+    input = BasicInput(logs="")
+    primary_llm, fast_llm = get_default_llms()
+    with get_session_context_manager() as db_session:
+        config, _ = get_test_config(
+            db_session=db_session,
+            primary_llm=primary_llm,
+            fast_llm=fast_llm,
+            search_request=SearchRequest(query="How does onyx use FastAPI?"),
+        )
+        compiled_graph.invoke(input, config={"metadata": {"config": config}})
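Note: the smoke test above hands the AgentSearchConfig to every node through the invoke config, and the nodes read it back with cast(AgentSearchConfig, config["metadata"]["config"]). A stripped-down sketch of that propagation pattern, with a toy state and a plain dict standing in for the onyx config object:

    from typing import TypedDict

    from langchain_core.runnables.config import RunnableConfig
    from langgraph.graph import END, START, StateGraph


    class State(TypedDict):
        logs: str


    def node(state: State, config: RunnableConfig) -> State:
        # the onyx nodes do: cast(AgentSearchConfig, config["metadata"]["config"])
        settings = config["metadata"]["config"]
        return {"logs": f"query={settings['query']}"}


    graph = StateGraph(State)
    graph.add_node(node="node", action=node)
    graph.add_edge(start_key=START, end_key="node")
    graph.add_edge(start_key="node", end_key=END)
    compiled = graph.compile()
    print(
        compiled.invoke(
            {"logs": ""},
            config={"metadata": {"config": {"query": "How does onyx use FastAPI?"}}},
        )
    )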
@@ -13,20 +13,26 @@ from onyx.tools.tool_implementations.search.search_tool import (
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()


 def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     structured_response_format = agent_config.structured_response_format
     llm = agent_config.primary_llm
-    tool_choice = state["tool_choice"]
+    tool_choice = state.tool_choice
     if tool_choice is None:
         raise ValueError("Tool choice is None")
-    tool = tool_choice["tool"]
+    tool = tool_choice.tool
     prompt_builder = agent_config.prompt_builder
-    tool_call_summary = state["tool_call_summary"]
-    tool_call_responses = state["tool_call_responses"]
-    state["tool_call_final_result"]
+    if state.tool_call_output is None:
+        raise ValueError("Tool call output is None")
+    tool_call_output = state.tool_call_output
+    tool_call_summary = tool_call_output.tool_call_summary
+    tool_call_responses = tool_call_output.tool_call_responses
+
     new_prompt_builder = tool.build_next_prompt(
         prompt_builder=prompt_builder,
         tool_call_summary=tool_call_summary,
@@ -53,11 +59,11 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO
     )

     # For now, we don't do multiple tool calls, so we ignore the tool_message
-    process_llm_stream(
+    new_tool_call_chunk = process_llm_stream(
         stream,
         True,
         final_search_results=final_search_results,
         displayed_search_results=initial_search_results,
     )

-    return BasicOutput()
+    return BasicOutput(tool_call_chunk=new_tool_call_chunk)
@@ -0,0 +1,15 @@
+from typing import cast
+
+from langchain_core.runnables.config import RunnableConfig
+
+from onyx.agents.agent_search.basic.states import BasicState
+from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
+
+
+def prepare_tool_input(state: BasicState, config: RunnableConfig) -> ToolChoiceInput:
+    cast(AgentSearchConfig, config["metadata"]["config"])
+    return ToolChoiceInput(
+        should_stream_answer=True,
+        prompt_snapshot=None,  # uses default prompt builder
+    )
@@ -1,10 +1,11 @@
 from typing import TypedDict

-from onyx.tools.message import ToolCallSummary
-from onyx.tools.models import ToolCallFinalResult
-from onyx.tools.models import ToolCallKickoff
-from onyx.tools.models import ToolResponse
-from onyx.tools.tool import Tool
+from langchain_core.messages import AIMessageChunk
+from pydantic import BaseModel
+
+from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
+from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate

 # States contain values that change over the course of graph execution,
 # Config is for values that are set at the start and never change.
@@ -14,33 +15,19 @@ from onyx.tools.tool import Tool
 ## Graph Input State


-class BasicInput(TypedDict):
-    should_stream_answer: bool
+class BasicInput(BaseModel):
+    # TODO: subclass global log update state
+    logs: str = ""


 ## Graph Output State


 class BasicOutput(TypedDict):
-    pass
+    tool_call_chunk: AIMessageChunk


 ## Update States
-class ToolCallUpdate(TypedDict):
-    tool_call_summary: ToolCallSummary
-    tool_call_kickoff: ToolCallKickoff
-    tool_call_responses: list[ToolResponse]
-    tool_call_final_result: ToolCallFinalResult
-
-
-class ToolChoice(TypedDict):
-    tool: Tool
-    tool_args: dict
-    id: str | None
-
-
-class ToolChoiceUpdate(TypedDict):
-    tool_choice: ToolChoice | None


 ## Graph State
@@ -48,8 +35,8 @@ class ToolChoiceUpdate(TypedDict):

 class BasicState(
     BasicInput,
+    ToolChoiceInput,
     ToolCallUpdate,
     ToolChoiceUpdate,
-    BasicOutput,
 ):
     pass
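Note: BasicState now composes its slices by inheriting from the pydantic input/update models, so each node can declare only the fields it reads or writes while the graph tracks one merged state. A minimal stand-alone illustration with stand-in models (not the onyx ones):

    from pydantic import BaseModel


    class BasicInput(BaseModel):
        logs: str = ""


    class ToolChoiceInput(BaseModel):
        should_stream_answer: bool = True


    class ToolChoiceUpdate(BaseModel):
        tool_choice: str | None = None


    class BasicState(BasicInput, ToolChoiceInput, ToolChoiceUpdate):
        pass


    state = BasicState()
    print(state.logs, state.should_stream_answer, state.tool_choice)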
@@ -43,6 +43,7 @@ def process_llm_stream(
     else:
         answer_handler = PassThroughAnswerResponseHandler()

+    full_answer = ""
     # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
     # the stream will contain AIMessageChunks with tool call information.
     for response in stream:
@@ -51,6 +52,7 @@ def process_llm_stream(
             # TODO: handle non-string content
             logger.warning(f"Received non-string content: {type(answer_piece)}")
             answer_piece = str(answer_piece)
+        full_answer += answer_piece

         if isinstance(response, AIMessageChunk) and (
             response.tool_call_chunks or response.tool_calls
@@ -64,4 +66,5 @@ def process_llm_stream(
             response_part,
         )

+    logger.info(f"Full answer: {full_answer}")
     return cast(AIMessageChunk, tool_call_chunk)
@@ -4,11 +4,12 @@ from uuid import uuid4
 from langchain_core.messages import ToolCall
 from langchain_core.runnables.config import RunnableConfig

-from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.basic.states import ToolChoice
-from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolChoice
+from onyx.agents.agent_search.orchestration.states import ToolChoiceState
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.chat.tool_handling.tool_response_handler import (
     get_tool_call_for_non_tool_calling_llm_impl,
@@ -23,16 +24,17 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
-def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
+def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
     """
     This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
     The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
     """
-    should_stream_answer = state["should_stream_answer"]
+    should_stream_answer = state.should_stream_answer

     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     using_tool_calling_llm = agent_config.using_tool_calling_llm
-    prompt_builder = agent_config.prompt_builder
+    prompt_builder = state.prompt_snapshot or agent_config.prompt_builder
+
     llm = agent_config.primary_llm
     skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation

@@ -78,12 +80,17 @@ def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpda
             tool_choice=None,
         )

+    built_prompt = (
+        prompt_builder.build()
+        if isinstance(prompt_builder, AnswerPromptBuilder)
+        else prompt_builder.built_prompt
+    )
     # At this point, we are either using a tool calling LLM or we are skipping the tool call.
     # DEBUG: good breakpoint
     stream = llm.stream(
         # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
         # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
-        prompt=prompt_builder.build(),
+        prompt=built_prompt,
         tools=[tool.tool_definition() for tool in tools] or None,
         tool_choice=("required" if tools and force_use_tool.force_use else None),
         structured_response_format=structured_response_format,
@@ -93,6 +100,7 @@ def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpda

     # If no tool calls are emitted by the LLM, we should not choose a tool
     if len(tool_message.tool_calls) == 0:
+        logger.info("No tool calls emitted by LLM")
         return ToolChoiceUpdate(
             tool_choice=None,
         )
backend/onyx/agents/agent_search/orchestration/states.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+from pydantic import BaseModel
+
+from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
+from onyx.tools.message import ToolCallSummary
+from onyx.tools.models import ToolCallFinalResult
+from onyx.tools.models import ToolCallKickoff
+from onyx.tools.models import ToolResponse
+from onyx.tools.tool import Tool
+
+
+class ToolChoiceInput(BaseModel):
+    should_stream_answer: bool = True
+    # default to the prompt builder from the config, but
+    # allow overrides for arbitrary tool calls
+    prompt_snapshot: PromptSnapshot | None = None
+
+
+class ToolCallOutput(BaseModel):
+    tool_call_summary: ToolCallSummary
+    tool_call_kickoff: ToolCallKickoff
+    tool_call_responses: list[ToolResponse]
+    tool_call_final_result: ToolCallFinalResult
+
+
+class ToolCallUpdate(BaseModel):
+    tool_call_output: ToolCallOutput | None = None
+
+
+class ToolChoice(BaseModel):
+    tool: Tool
+    tool_args: dict
+    id: str | None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class ToolChoiceUpdate(BaseModel):
+    tool_choice: ToolChoice | None = None
+
+
+class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
+    pass
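Note: ToolChoice stores a live Tool instance, which pydantic will not accept as a field type unless arbitrary_types_allowed is enabled; the nested Config above opts in. A self-contained illustration of why that flag is needed, using a stand-in tool class rather than onyx's Tool:

    from pydantic import BaseModel


    class FakeTool:  # stand-in for onyx.tools.tool.Tool, which is not a pydantic model
        name = "search"


    class ToolChoice(BaseModel):
        tool: FakeTool
        tool_args: dict
        id: str | None

        class Config:
            arbitrary_types_allowed = True  # without this, pydantic rejects the FakeTool field


    choice = ToolChoice(tool=FakeTool(), tool_args={"query": "hello"}, id=None)
    print(choice.tool.name, choice.tool_args)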
@@ -5,9 +5,10 @@ from langchain_core.messages import AIMessageChunk
 from langchain_core.messages.tool import ToolCall
 from langchain_core.runnables.config import RunnableConfig

-from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.basic.states import ToolCallUpdate
 from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.orchestration.states import ToolCallOutput
+from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
+from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
 from onyx.chat.models import AnswerPacket
 from onyx.tools.message import build_tool_message
 from onyx.tools.message import ToolCallSummary
|
|||||||
dispatch_custom_event("basic_response", packet)
|
dispatch_custom_event("basic_response", packet)
|
||||||
|
|
||||||
|
|
||||||
# TODO: handle is_cancelled
|
def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
|
||||||
def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
|
|
||||||
"""Calls the tool specified in the state and updates the state with the result"""
|
"""Calls the tool specified in the state and updates the state with the result"""
|
||||||
# TODO: implement
|
|
||||||
|
|
||||||
cast(AgentSearchConfig, config["metadata"]["config"])
|
cast(AgentSearchConfig, config["metadata"]["config"])
|
||||||
# Unnecessary now, node should only be called if there is a tool call
|
|
||||||
# if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
|
|
||||||
# return
|
|
||||||
|
|
||||||
tool_choice = state["tool_choice"]
|
tool_choice = state.tool_choice
|
||||||
if tool_choice is None:
|
if tool_choice is None:
|
||||||
raise ValueError("Cannot invoke tool call node without a tool choice")
|
raise ValueError("Cannot invoke tool call node without a tool choice")
|
||||||
|
|
||||||
tool = tool_choice["tool"]
|
tool = tool_choice.tool
|
||||||
tool_args = tool_choice["tool_args"]
|
tool_args = tool_choice.tool_args
|
||||||
tool_id = tool_choice["id"]
|
tool_id = tool_choice.id
|
||||||
tool_runner = ToolRunner(tool, tool_args)
|
tool_runner = ToolRunner(tool, tool_args)
|
||||||
tool_kickoff = tool_runner.kickoff()
|
tool_kickoff = tool_runner.kickoff()
|
||||||
|
|
||||||
@@ -61,9 +57,10 @@ def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
         ),
     )

-    return ToolCallUpdate(
+    tool_call_output = ToolCallOutput(
         tool_call_summary=tool_call_summary,
         tool_call_kickoff=tool_kickoff,
         tool_call_responses=tool_responses,
         tool_call_final_result=tool_final_result,
     )
+    return ToolCallUpdate(tool_call_output=tool_call_output)
@@ -122,8 +122,11 @@ def _manage_async_event_streaming(
         except (StopAsyncIteration, GeneratorExit):
             break
         finally:
-            for task in task_references.pop():
-                task.cancel()
+            try:
+                for task in task_references.pop():
+                    task.cancel()
+            except StopAsyncIteration:
+                pass
             loop.close()

     return _yield_async_to_sync()
@@ -186,9 +189,7 @@ def run_basic_graph(
     graph = basic_graph_builder()
     compiled_graph = graph.compile()
     # TODO: unify basic input
-    input = BasicInput(
-        should_stream_answer=True,
-    )
+    input = BasicInput()
     return run_graph(compiled_graph, config, input)


@@ -228,6 +228,7 @@ def get_test_config(
         message_id=1,
         use_persistence=True,
         db_session=db_session,
+        tools=[search_tool],
     )

     return config, search_tool
@@ -4,6 +4,7 @@ from typing import cast
 from langchain_core.messages import BaseMessage
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import SystemMessage
+from pydantic import BaseModel
 from pydantic.v1 import BaseModel as BaseModel__v1

 from onyx.chat.models import PromptConfig
@@ -182,6 +183,13 @@ class AnswerPromptBuilder:
     )


+# Stores some parts of a prompt builder as needed for tool calls
+class PromptSnapshot(BaseModel):
+    raw_message_history: list[PreviousMessage]
+    raw_user_query: str
+    built_prompt: list[BaseMessage]
+
+
 # TODO: rename this? AnswerConfig maybe?
 class LLMCall(BaseModel__v1):
     prompt_builder: AnswerPromptBuilder
@@ -5,10 +5,10 @@ from typing import cast

 from langchain_core.messages import BaseMessage

-from onyx.chat.llm_response_handler import ResponsePart
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import LlmDoc
 from onyx.chat.models import OnyxAnswerPiece
+from onyx.chat.models import ResponsePart
 from onyx.chat.stream_processing.citation_processing import CitationProcessor
 from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
 from onyx.utils.logger import setup_logger
@@ -7,6 +7,7 @@ from langchain_core.messages import ToolCall
 from onyx.chat.models import ResponsePart
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
+from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
 from onyx.llm.interfaces import LLM
 from onyx.tools.force import ForceUseTool
 from onyx.tools.message import build_tool_message
@@ -158,7 +159,7 @@ class ToolResponseHandler:
 def get_tool_call_for_non_tool_calling_llm_impl(
     force_use_tool: ForceUseTool,
     tools: list[Tool],
-    prompt_builder: AnswerPromptBuilder,
+    prompt_builder: AnswerPromptBuilder | PromptSnapshot,
     llm: LLM,
 ) -> tuple[Tool, dict] | None:
     if force_use_tool.force_use:
@@ -4,7 +4,7 @@ from typing import Any
 from langchain_core.messages.ai import AIMessage
 from langchain_core.messages.tool import ToolCall
 from langchain_core.messages.tool import ToolMessage
-from pydantic.v1 import BaseModel as BaseModel__v1
+from pydantic import BaseModel

 from onyx.natural_language_processing.utils import BaseTokenizer

|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolCallSummary(BaseModel__v1):
|
# TODO: does this NEED to be BaseModel__v1?
|
||||||
|
class ToolCallSummary(BaseModel):
|
||||||
tool_call_request: AIMessage
|
tool_call_request: AIMessage
|
||||||
tool_call_result: ToolMessage
|
tool_call_result: ToolMessage
|
||||||
|
|
||||||