mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-02 17:08:22 +02:00
implemented top-level tool calling + force search
This commit is contained in:
parent
982040c792
commit
ddbfc65ad0
backend/onyx
agents/agent_search
basic
deep_search_a/main
orchestration
shared_graph_utils
db
@ -2,14 +2,14 @@ from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.basic.nodes.prepare_tool_input import prepare_tool_input
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.orchestration.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.prepare_tool_input import prepare_tool_input
|
||||
from onyx.agents.agent_search.orchestration.tool_call import tool_call
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
@ -14,12 +16,29 @@ from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
RequireRefinedAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def route_initial_tool_choice(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> Literal["tool_call", "agent_search_start", "logging_node"]:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if state.tool_choice is not None:
|
||||
if (
|
||||
agent_config.use_agentic_search
|
||||
and state.tool_choice.tool.name == agent_config.search_tool.name
|
||||
):
|
||||
return "agent_search_start"
|
||||
else:
|
||||
return "tool_call"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
|
@ -20,24 +20,18 @@ from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
parallelize_refined_sub_question_answering,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
route_initial_tool_choice,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_logging import (
|
||||
agent_logging,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_decision import (
|
||||
agent_path_decision,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_routing import (
|
||||
agent_path_routing,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import (
|
||||
agent_search_start,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.answer_comparison import (
|
||||
answer_comparison,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.direct_llm_handling import (
|
||||
direct_llm_handling,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import (
|
||||
entity_term_extraction_llm,
|
||||
)
|
||||
@ -73,6 +67,12 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.retrieval_consolidation i
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainInput
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.orchestration.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.prepare_tool_input import prepare_tool_input
|
||||
from onyx.agents.agent_search.orchestration.tool_call import tool_call
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@ -87,21 +87,37 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
input=MainInput,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="agent_path_decision",
|
||||
# action=agent_path_decision,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="agent_path_routing",
|
||||
# action=agent_path_routing,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="LLM",
|
||||
# action=direct_llm_handling,
|
||||
# )
|
||||
graph.add_node(
|
||||
node="agent_path_decision",
|
||||
action=agent_path_decision,
|
||||
node="prepare_tool_input",
|
||||
action=prepare_tool_input,
|
||||
)
|
||||
graph.add_node(
|
||||
node="initial_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
graph.add_node(
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="agent_path_routing",
|
||||
action=agent_path_routing,
|
||||
node="basic_use_tool_response",
|
||||
action=basic_use_tool_response,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="LLM",
|
||||
action=direct_llm_handling,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="agent_search_start",
|
||||
action=agent_search_start,
|
||||
@ -205,14 +221,35 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
|
||||
# raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
# end_key="agent_path_decision",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="agent_path_decision",
|
||||
# end_key="agent_path_routing",
|
||||
# )
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="agent_path_decision",
|
||||
start_key="prepare_tool_input",
|
||||
end_key="initial_tool_choice",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
"initial_tool_choice",
|
||||
route_initial_tool_choice,
|
||||
["tool_call", "agent_search_start", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="agent_path_decision",
|
||||
end_key="agent_path_routing",
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="basic_use_tool_response",
|
||||
end_key="logging_node",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
@ -245,10 +282,10 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
end_key="generate_initial_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="LLM",
|
||||
end_key=END,
|
||||
)
|
||||
# graph.add_edge(
|
||||
# start_key="LLM",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
|
@ -21,8 +21,8 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
|
||||
|
||||
agent_start_time = state.agent_start_time
|
||||
agent_base_end_time = state.agent_base_end_time
|
||||
agent_refined_start_time = state.agent_refined_start_time or None
|
||||
agent_refined_end_time = state.agent_refined_end_time or None
|
||||
agent_refined_start_time = state.agent_refined_start_time
|
||||
agent_refined_end_time = state.agent_refined_end_time
|
||||
agent_end_time = agent_refined_end_time or agent_base_end_time
|
||||
|
||||
agent_base_duration = None
|
||||
@ -67,14 +67,15 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
|
||||
|
||||
# log the agent metrics
|
||||
if agent_a_config.db_session is not None:
|
||||
log_agent_metrics(
|
||||
db_session=agent_a_config.db_session,
|
||||
user_id=user_id,
|
||||
persona_id=persona_id,
|
||||
agent_type=agent_type,
|
||||
start_time=agent_start_time,
|
||||
agent_metrics=combined_agent_metrics,
|
||||
)
|
||||
if agent_base_duration is not None:
|
||||
log_agent_metrics(
|
||||
db_session=agent_a_config.db_session,
|
||||
user_id=user_id,
|
||||
persona_id=persona_id,
|
||||
agent_type=agent_type,
|
||||
start_time=agent_start_time,
|
||||
agent_metrics=combined_agent_metrics,
|
||||
)
|
||||
|
||||
if agent_a_config.use_persistence:
|
||||
# Persist the sub-answer in the database
|
||||
|
@ -12,6 +12,9 @@ from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
@ -133,6 +136,9 @@ class MainInput(CoreState):
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
BaseDecompUpdate,
|
||||
InitialAnswerUpdate,
|
||||
InitialAnswerBASEUpdate,
|
||||
|
@ -39,7 +39,7 @@ def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoic
|
||||
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
|
||||
|
||||
structured_response_format = agent_config.structured_response_format
|
||||
tools = agent_config.tools or []
|
||||
tools = [tool for tool in (agent_config.tools or []) if tool.name in state.tools]
|
||||
force_use_tool = agent_config.force_use_tool
|
||||
|
||||
tool, tool_args = None, None
|
||||
|
@ -1,15 +1,16 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
|
||||
|
||||
def prepare_tool_input(state: BasicState, config: RunnableConfig) -> ToolChoiceInput:
|
||||
cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
return ToolChoiceInput(
|
||||
should_stream_answer=True,
|
||||
prompt_snapshot=None, # uses default prompt builder
|
||||
tools=[tool.name for tool in (agent_config.tools or [])],
|
||||
)
|
@ -8,12 +8,17 @@ from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
# TODO: adapt the tool choice/tool call to allow for parallel tool calls by
|
||||
# creating a subgraph that can be invoked in parallel via Send/Command APIs
|
||||
class ToolChoiceInput(BaseModel):
|
||||
should_stream_answer: bool = True
|
||||
# default to the prompt builder from the config, but
|
||||
# allow overrides for arbitrary tool calls
|
||||
prompt_snapshot: PromptSnapshot | None = None
|
||||
|
||||
# names of tools to use for tool calling. Filters the tools available in the config
|
||||
tools: list[str] = []
|
||||
|
||||
|
||||
class ToolCallOutput(BaseModel):
|
||||
tool_call_summary: ToolCallSummary
|
||||
|
@ -153,7 +153,11 @@ def generate_log_message(
|
||||
|
||||
|
||||
def get_test_config(
|
||||
db_session: Session, primary_llm: LLM, fast_llm: LLM, search_request: SearchRequest
|
||||
db_session: Session,
|
||||
primary_llm: LLM,
|
||||
fast_llm: LLM,
|
||||
search_request: SearchRequest,
|
||||
use_agentic_search: bool = True,
|
||||
) -> tuple[AgentSearchConfig, SearchTool]:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
@ -229,6 +233,7 @@ def get_test_config(
|
||||
use_persistence=True,
|
||||
db_session=db_session,
|
||||
tools=[search_tool],
|
||||
use_agentic_search=use_agentic_search,
|
||||
)
|
||||
|
||||
return config, search_tool
|
||||
|
@ -964,9 +964,11 @@ def log_agent_metrics(
|
||||
start_time=start_time,
|
||||
base_duration__s=agent_timings.base_duration__s,
|
||||
full_duration__s=agent_timings.full_duration__s,
|
||||
base_metrics=vars(agent_base_metrics),
|
||||
refined_metrics=vars(agent_refined_metrics),
|
||||
all_metrics=vars(agent_additional_metrics),
|
||||
base_metrics=vars(agent_base_metrics) if agent_base_metrics else None,
|
||||
refined_metrics=vars(agent_refined_metrics) if agent_refined_metrics else None,
|
||||
all_metrics=vars(agent_additional_metrics)
|
||||
if agent_additional_metrics
|
||||
else None,
|
||||
)
|
||||
|
||||
db_session.add(agent_metric_tracking)
|
||||
|
Loading…
x
Reference in New Issue
Block a user