basic search restructure: WIP on fixing tests

This commit is contained in:
Evan Lohn 2025-01-22 16:15:22 -08:00
parent 8aa82be12a
commit dd260140b2
17 changed files with 403 additions and 305 deletions

View File

@ -197,6 +197,7 @@ def get_answer_stream(
rerank_settings=query_request.rerank_settings,
db_session=db_session,
use_agentic_search=query_request.use_agentic_search,
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
)
packets = stream_chat_message_objects(

View File

@ -1,20 +1,18 @@
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.runnables.config import RunnableConfig
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import BasicStateUpdate
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.stream_processing.utils import (
map_document_id_order,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_graph_builder() -> StateGraph:
@ -27,17 +25,33 @@ def basic_graph_builder() -> StateGraph:
### Add nodes ###
graph.add_node(
node="get_response",
action=get_response,
node="llm_tool_choice",
action=llm_tool_choice,
)
graph.add_node(
node="tool_call",
action=tool_call,
)
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="get_response")
graph.add_edge(start_key=START, end_key="llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_conditional_edges("get_response", should_continue, ["get_response", END])
graph.add_edge(
start_key="get_response",
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key=END,
)
@ -46,56 +60,10 @@ def basic_graph_builder() -> StateGraph:
def should_continue(state: BasicState) -> str:
return (
END if state["last_llm_call"] is None or state["calls"] > 1 else "get_response"
)
def get_response(state: BasicState, config: RunnableConfig) -> BasicStateUpdate:
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
llm = agent_a_config.primary_llm
current_llm_call = state["last_llm_call"]
if current_llm_call is None:
raise ValueError("last_llm_call is None")
structured_response_format = agent_a_config.structured_response_format
response_handler_manager = state["response_handler_manager"]
# DEBUG: good breakpoint
stream = llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
"required"
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
else None
),
structured_response_format=structured_response_format,
)
for response in response_handler_manager.handle_llm_response(stream):
dispatch_custom_event(
"basic_response",
response,
)
next_call = response_handler_manager.next_llm_call(current_llm_call)
if next_call is not None:
final_search_results, displayed_search_results = SearchTool.get_search_result(
next_call
) or ([], [])
else:
final_search_results, displayed_search_results = [], []
response_handler_manager.answer_handler.update(
(
final_search_results,
map_document_id_order(final_search_results),
map_document_id_order(displayed_search_results),
)
)
return BasicStateUpdate(
last_llm_call=next_call,
calls=state["calls"] + 1,
# If there are no tool calls, basic graph already streamed the answer
END
if state["tool_choice"] is None
else "tool_call"
)

View File

@ -1,20 +1,21 @@
from typing import TypedDict
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
## Update States
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.
## Graph Input State
class BasicInput(TypedDict):
base_question: str
last_llm_call: LLMCall | None
response_handler_manager: LLMResponseHandlerManager
calls: int
should_stream_answer: bool
## Graph Output State
@ -24,9 +25,22 @@ class BasicOutput(TypedDict):
pass
class BasicStateUpdate(TypedDict):
last_llm_call: LLMCall | None
calls: int
## Update States
class ToolCallUpdate(TypedDict):
tool_call_summary: ToolCallSummary
tool_call_kickoff: ToolCallKickoff
tool_call_responses: list[ToolResponse]
tool_call_final_result: ToolCallFinalResult
class ToolChoice(TypedDict):
tool: Tool
tool_args: dict
id: str | None
class ToolChoiceUpdate(TypedDict):
tool_choice: ToolChoice | None
## Graph State
@ -34,6 +48,8 @@ class BasicStateUpdate(TypedDict):
class BasicState(
BasicInput,
ToolCallUpdate,
ToolChoiceUpdate,
BasicOutput,
):
pass

View File

@ -2,11 +2,15 @@ from dataclasses import dataclass
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@ -22,7 +26,18 @@ class AgentSearchConfig:
primary_llm: LLM
fast_llm: LLM
search_tool: SearchTool
use_agentic_search: bool = True
# Whether to force use of a tool, or to
# force tool args IF the tool is used
force_use_tool: ForceUseTool
# contains message history for the current chat session
# has the following (at most one is non-None)
# message_history: list[PreviousMessage] | None = None
# single_message_history: str | None = None
prompt_builder: AnswerPromptBuilder
use_agentic_search: bool = False
# For persisting agent search data
chat_session_id: UUID | None = None
@ -45,11 +60,25 @@ class AgentSearchConfig:
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
# Message history for the current chat session
message_history: list[PreviousMessage] | None = None
# Tools available for use
tools: list[Tool] | None = None
using_tool_calling_llm: bool = False
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
skip_gen_ai_answer_generation: bool = False
@model_validator(mode="after")
def validate_db_session(self) -> "AgentSearchConfig":
if self.use_persistence and self.db_session is None:
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
return self
class AgentDocumentCitations(BaseModel):
document_id: str

View File

@ -16,7 +16,6 @@ from onyx.agents.agent_search.deep_search_a.main.graph_builder import (
from onyx.agents.agent_search.deep_search_a.main.states import MainInput as MainInput_a
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
@ -25,7 +24,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.configs.dev_configs import GRAPH_NAME
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
@ -143,7 +141,6 @@ def run_graph(
config: AgentSearchConfig,
input: BasicInput | MainInput_a,
) -> AnswerStream:
input["base_question"] = config.search_request.query if config else ""
# TODO: add these to the environment
config.perform_initial_search_path_decision = True
config.perform_initial_search_decomposition = True
@ -192,17 +189,12 @@ def run_main_graph(
# TODO: unify input types, especially prosearchconfig
def run_basic_graph(
config: AgentSearchConfig,
last_llm_call: LLMCall | None,
response_handler_manager: LLMResponseHandlerManager,
) -> AnswerStream:
graph = basic_graph_builder()
compiled_graph = graph.compile()
# TODO: unify basic input
input = BasicInput(
base_question="",
last_llm_call=last_llm_call,
response_handler_manager=response_handler_manager,
calls=0,
should_stream_answer=True,
)
return run_graph(compiled_graph, config, input)

View File

@ -11,6 +11,7 @@ from typing import cast
from uuid import UUID
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
@ -21,6 +22,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
@ -31,6 +33,7 @@ from onyx.context.search.models import SearchRequest
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@ -207,14 +210,23 @@ def get_test_config(
config = AgentSearchConfig(
search_request=search_request,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_tool=search_tool,
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
prompt_builder=AnswerPromptBuilder(
user_message=HumanMessage(content=search_request.query),
message_history=[],
llm_config=primary_llm.config,
raw_user_query=search_request.query,
raw_user_uploaded_files=[],
),
# chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim
# chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan
message_id=1,
use_persistence=True,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_tool=search_tool,
db_session=db_session,
)
return config, search_tool

View File

@ -1,16 +1,12 @@
from collections import defaultdict
from collections.abc import Callable
from uuid import uuid4
from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
@ -18,18 +14,9 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
from onyx.chat.stream_processing.utils import (
map_document_id_order,
)
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.configs.constants import BASIC_KEY
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
@ -37,7 +24,6 @@ from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
@ -52,7 +38,7 @@ class Answer:
llm: LLM,
prompt_config: PromptConfig,
force_use_tool: ForceUseTool,
pro_search_config: AgentSearchConfig,
agent_search_config: AgentSearchConfig,
# must be the same length as `docs`. If None, all docs are considered "relevant"
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
@ -114,7 +100,7 @@ class Answer:
and not skip_explicit_tool_calling
)
self.pro_search_config = pro_search_config
self.agent_search_config = agent_search_config
self.db_session = db_session
def _get_tools_list(self) -> list[Tool]:
@ -132,130 +118,132 @@ class Answer:
return [tool]
# TODO: delete the function and move the full body to processed_streamed_output
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
current_llm_call = llm_calls[-1]
def _get_response(self) -> AnswerStream:
# current_llm_call = llm_calls[-1]
tool, tool_args = None, None
# handle the case where no decision has to be made; we simply run the tool
if (
current_llm_call.force_use_tool.force_use
and current_llm_call.force_use_tool.args is not None
):
tool_name, tool_args = (
current_llm_call.force_use_tool.tool_name,
current_llm_call.force_use_tool.args,
)
tool = get_tool_by_name(current_llm_call.tools, tool_name)
# tool, tool_args = None, None
# # handle the case where no decision has to be made; we simply run the tool
# if (
# current_llm_call.force_use_tool.force_use
# and current_llm_call.force_use_tool.args is not None
# ):
# tool_name, tool_args = (
# current_llm_call.force_use_tool.tool_name,
# current_llm_call.force_use_tool.args,
# )
# tool = get_tool_by_name(current_llm_call.tools, tool_name)
# special pre-logic for non-tool calling LLM case
elif not self.using_tool_calling_llm and current_llm_call.tools:
chosen_tool_and_args = (
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
current_llm_call, self.llm
)
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
# # special pre-logic for non-tool calling LLM case
# elif not self.using_tool_calling_llm and current_llm_call.tools:
# chosen_tool_and_args = (
# ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
# current_llm_call, self.llm
# )
# )
# if chosen_tool_and_args:
# tool, tool_args = chosen_tool_and_args
if tool and tool_args:
dummy_tool_call_chunk = AIMessageChunk(content="")
dummy_tool_call_chunk.tool_calls = [
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
]
# if tool and tool_args:
# dummy_tool_call_chunk = AIMessageChunk(content="")
# dummy_tool_call_chunk.tool_calls = [
# ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
# ]
response_handler_manager = LLMResponseHandlerManager(
ToolResponseHandler([tool]), None, self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)
# response_handler_manager = LLMResponseHandlerManager(
# ToolResponseHandler([tool]), None, self.is_cancelled
# )
# yield from response_handler_manager.handle_llm_response(
# iter([dummy_tool_call_chunk])
# )
tmp_call = response_handler_manager.next_llm_call(current_llm_call)
if tmp_call is None:
return # no more LLM calls to process
current_llm_call = tmp_call
# tmp_call = response_handler_manager.next_llm_call(current_llm_call)
# if tmp_call is None:
# return # no more LLM calls to process
# current_llm_call = tmp_call
# if we're skipping gen ai answer generation, we should break
# out unless we're forcing a tool call. If we don't, we might generate an
# answer, which is a no-no!
if (
self.skip_gen_ai_answer_generation
and not current_llm_call.force_use_tool.force_use
):
return
# # if we're skipping gen ai answer generation, we should break
# # out unless we're forcing a tool call. If we don't, we might generate an
# # answer, which is a no-no!
# if (
# self.skip_gen_ai_answer_generation
# and not current_llm_call.force_use_tool.force_use
# ):
# return
# set up "handlers" to listen to the LLM response stream and
# feed back the processed results + handle tool call requests
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
# # set up "handlers" to listen to the LLM response stream and
# # feed back the processed results + handle tool call requests
# # + figure out what the next LLM call should be
# tool_call_handler = ToolResponseHandler(current_llm_call.tools)
final_search_results, displayed_search_results = SearchTool.get_search_result(
current_llm_call
) or ([], [])
# final_search_results, displayed_search_results = SearchTool.get_search_result(
# current_llm_call
# ) or ([], [])
# NEXT: we still want to handle the LLM response stream, but it is now:
# 1. handle the tool call requests
# 2. feed back the processed results
# 3. handle the citations
# # NEXT: we still want to handle the LLM response stream, but it is now:
# # 1. handle the tool call requests
# # 2. feed back the processed results
# # 3. handle the citations
answer_handler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
# answer_handler = CitationResponseHandler(
# context_docs=final_search_results,
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )
# At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
# In the future, we'll want to handle citations and tool calls in the langgraph graph.
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
)
# # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
# # In the future, we'll want to handle citations and tool calls in the langgraph graph.
# response_handler_manager = LLMResponseHandlerManager(
# tool_call_handler, answer_handler, self.is_cancelled
# )
# In langgraph, whether we do the basic thing (call llm stream) or pro search
# is based on a flag in the pro search config
if self.pro_search_config.use_agentic_search:
if self.pro_search_config.search_request is None:
raise ValueError("Search request must be provided for pro search")
if self.db_session is None:
raise ValueError("db_session must be provided for pro search")
if self.fast_llm is None:
raise ValueError("fast_llm must be provided for pro search")
if self.agent_search_config.use_agentic_search:
if (
self.agent_search_config.db_session is None
and self.agent_search_config.use_persistence
):
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
stream = run_main_graph(
config=self.pro_search_config,
config=self.agent_search_config,
)
else:
stream = run_basic_graph(
config=self.pro_search_config,
last_llm_call=current_llm_call,
response_handler_manager=response_handler_manager,
config=self.agent_search_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
return
# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
"required"
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
else None
),
structured_response_format=self.answer_style_config.structured_response_format,
)
yield from response_handler_manager.handle_llm_response(stream)
# stream = self.llm.stream(
# # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
# prompt=current_llm_call.prompt_builder.build(),
# tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
# tool_choice=(
# "required"
# if current_llm_call.tools and current_llm_call.force_use_tool.force_use
# else None
# ),
# structured_response_format=self.answer_style_config.structured_response_format,
# )
# yield from response_handler_manager.handle_llm_response(stream)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
# new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
# if new_llm_call:
# yield from self._get_response(llm_calls + [new_llm_call])
@property
def processed_streamed_output(self) -> AnswerStream:
@ -263,33 +251,33 @@ class Answer:
yield from self._processed_stream
return
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=self.question,
prompt_config=self.prompt_config,
files=self.latest_query_files,
single_message_history=self.single_message_history,
),
message_history=self.message_history,
llm_config=self.llm.config,
raw_user_query=self.question,
raw_user_uploaded_files=self.latest_query_files or [],
single_message_history=self.single_message_history,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
llm_call = LLMCall(
prompt_builder=prompt_builder,
tools=self._get_tools_list(),
force_use_tool=self.force_use_tool,
files=self.latest_query_files,
tool_call_info=[],
using_tool_calling_llm=self.using_tool_calling_llm,
)
# prompt_builder = AnswerPromptBuilder(
# user_message=default_build_user_message(
# user_query=self.question,
# prompt_config=self.prompt_config,
# files=self.latest_query_files,
# single_message_history=self.single_message_history,
# ),
# message_history=self.message_history,
# llm_config=self.llm.config,
# raw_user_query=self.question,
# raw_user_uploaded_files=self.latest_query_files or [],
# single_message_history=self.single_message_history,
# )
# prompt_builder.update_system_prompt(
# default_build_system_message(self.prompt_config)
# )
# llm_call = LLMCall(
# prompt_builder=prompt_builder,
# tools=self._get_tools_list(),
# force_use_tool=self.force_use_tool,
# files=self.latest_query_files,
# tool_call_info=[],
# using_tool_calling_llm=self.using_tool_calling_llm,
# )
processed_stream = []
for processed_packet in self._get_response([llm_call]):
for processed_packet in self._get_response():
processed_stream.append(processed_packet)
yield processed_packet

View File

@ -49,6 +49,7 @@ def prepare_chat_message_request(
rerank_settings: RerankingDetails | None,
db_session: Session,
use_agentic_search: bool = False,
skip_gen_ai_answer_generation: bool = False,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
@ -74,6 +75,7 @@ def prepare_chat_message_request(
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
)

View File

@ -33,6 +33,9 @@ from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
@ -130,6 +133,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
@ -137,7 +141,6 @@ from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@ -534,11 +537,8 @@ def stream_chat_message_objects(
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
)
latest_query_files = [
file
for file in files
if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
]
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
if user_message:
attach_files_to_chat_message(
@ -748,17 +748,42 @@ def stream_chat_message_objects(
# TODO: handle multiple search tools
raise ValueError("Multiple search tools found")
search_tool = search_tools[0]
pro_search_config = AgentSearchConfig(
use_agentic_search=new_msg_req.use_agentic_search,
search_request=search_request,
chat_session_id=chat_session_id,
message_id=reserved_message_id,
using_tool_calling_llm = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
force_use_tool = _get_force_search_settings(new_msg_req, tools)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=final_msg.message,
prompt_config=prompt_config,
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config),
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
raw_user_uploaded_files=latest_query_files or [],
single_message_history=single_message_history,
)
agent_search_config = AgentSearchConfig(
search_request=search_request,
primary_llm=llm,
fast_llm=fast_llm,
search_tool=search_tool,
structured_response_format=new_msg_req.structured_response_format,
force_use_tool=force_use_tool,
use_agentic_search=new_msg_req.use_agentic_search,
chat_session_id=chat_session_id,
message_id=reserved_message_id,
use_persistence=True,
allow_refinement=True,
db_session=db_session,
prompt_builder=prompt_builder,
tools=tools,
using_tool_calling_llm=using_tool_calling_llm,
files=latest_query_files,
structured_response_format=new_msg_req.structured_response_format,
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
)
# TODO: add previous messages, answer style config, tools, etc.
@ -785,9 +810,9 @@ def stream_chat_message_objects(
fast_llm=fast_llm,
message_history=message_history,
tools=tools,
force_use_tool=_get_force_search_settings(new_msg_req, tools),
force_use_tool=force_use_tool,
single_message_history=single_message_history,
pro_search_config=pro_search_config,
agent_search_config=agent_search_config,
db_session=db_session,
)

View File

@ -84,6 +84,7 @@ class AnswerPromptBuilder:
raw_user_query: str,
raw_user_uploaded_files: list[InMemoryChatFile],
single_message_history: str | None = None,
system_message: SystemMessage | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
@ -108,7 +109,14 @@ class AnswerPromptBuilder:
),
)
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
(
system_message,
check_message_tokens(system_message, self.llm_tokenizer_encode_func),
)
if system_message
else None
)
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(

View File

@ -5,6 +5,7 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
@ -50,56 +51,12 @@ class ToolResponseHandler:
def get_tool_call_for_non_tool_calling_llm(
cls, llm_call: LLMCall, llm: LLM
) -> tuple[Tool, dict] | None:
if llm_call.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = get_tool_by_name(llm_call.tools, llm_call.force_use_tool.tool_name)
tool_args = (
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.raw_user_query,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.raw_user_query,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
available_tools_and_args = [
(llm_call.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.raw_user_query,
llm=llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
return chosen_tool_and_args
return get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool=llm_call.force_use_tool,
tools=llm_call.tools,
prompt_builder=llm_call.prompt_builder,
llm=llm,
)
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
@ -196,3 +153,61 @@ class ToolResponseHandler:
self.tool_final_result,
],
)
def get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool: ForceUseTool,
tools: list[Tool],
prompt_builder: AnswerPromptBuilder,
llm: LLM,
) -> tuple[Tool, dict] | None:
if force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = get_tool_by_name(tools, force_use_tool.tool_name)
tool_args = (
force_use_tool.args
if force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=prompt_builder.raw_user_query,
history=prompt_builder.raw_message_history,
llm=llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=tools,
query=prompt_builder.raw_user_query,
history=prompt_builder.raw_message_history,
llm=llm,
)
available_tools_and_args = [
(tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=prompt_builder.raw_message_history,
query=prompt_builder.raw_user_query,
llm=llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
return chosen_tool_and_args

View File

@ -81,7 +81,4 @@ def check_if_need_search(
logger.debug(f"Run search prediction: {require_search_output}")
if (SKIP_SEARCH.split()[0]).lower() in require_search_output.lower():
return False
return True
return (SKIP_SEARCH.split()[0]).lower() not in require_search_output.lower()

View File

@ -138,6 +138,8 @@ class CreateChatMessageRequest(ChunkContext):
# TODO: decide how many of the above options we want to pass through to pro search
use_agentic_search: bool = False
skip_gen_ai_answer_generation: bool = False
@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:

View File

@ -20,10 +20,7 @@ OPEN_AI_TOOL_CALLING_MODELS = {
def explicit_tool_calling_supported(model_provider: str, model_name: str) -> bool:
if model_provider == "openai" and model_name in OPEN_AI_TOOL_CALLING_MODELS:
return True
return False
return model_provider == "openai" and model_name in OPEN_AI_TOOL_CALLING_MODELS
def compute_tool_tokens(tool: Tool, llm_tokenizer: BaseTokenizer) -> int:

View File

@ -3,15 +3,20 @@ from datetime import datetime
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import (
@ -27,6 +32,31 @@ def answer_style_config() -> AnswerStyleConfig:
return AnswerStyleConfig(citation_config=CitationConfig())
@pytest.fixture
def agent_search_config(
mock_llm: LLM, mock_search_tool: SearchTool
) -> AgentSearchConfig:
return AgentSearchConfig(
search_request=SearchRequest(query=QUERY),
primary_llm=mock_llm,
fast_llm=mock_llm,
search_tool=mock_search_tool,
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
prompt_builder=AnswerPromptBuilder(
user_message=HumanMessage(content=QUERY),
message_history=[],
llm_config=mock_llm.config,
raw_user_query=QUERY,
raw_user_uploaded_files=[],
),
chat_session_id=None,
message_id=1,
use_persistence=True,
db_session=None,
use_agentic_search=False,
)
@pytest.fixture
def prompt_config() -> PromptConfig:
return PromptConfig(

View File

@ -11,6 +11,7 @@ from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolCall
from langchain_core.messages import ToolCallChunk
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.answer import Answer
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
@ -30,7 +31,10 @@ from tests.unit.onyx.chat.conftest import QUERY
@pytest.fixture
def answer_instance(
mock_llm: LLM, answer_style_config: AnswerStyleConfig, prompt_config: PromptConfig
mock_llm: LLM,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
agent_search_config: AgentSearchConfig,
) -> Answer:
return Answer(
question=QUERY,
@ -38,6 +42,7 @@ def answer_instance(
llm=mock_llm,
prompt_config=prompt_config,
force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
agent_search_config=agent_search_config,
)
@ -284,7 +289,8 @@ def test_answer_with_search_no_tool_calling(
def test_is_cancelled(answer_instance: Answer) -> None:
# Set up the LLM mock to return multiple chunks
mock_llm = Mock()
answer_instance.llm = mock_llm
answer_instance.agent_search_config.primary_llm = mock_llm
answer_instance.agent_search_config.fast_llm = mock_llm
mock_llm.stream.return_value = [
AIMessageChunk(content="This is the "),
AIMessageChunk(content="first part."),
@ -303,6 +309,7 @@ def test_is_cancelled(answer_instance: Answer) -> None:
if i == 1:
connection_status["connected"] = False
print(output)
assert len(output) == 3
assert output[0] == OnyxAnswerPiece(answer_piece="This is the ")
assert output[1] == OnyxAnswerPiece(answer_piece="first part.")

View File

@ -5,10 +5,12 @@ from unittest.mock import Mock
import pytest
from pytest_mock import MockerFixture
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.answer import Answer
from onyx.chat.answer import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import PromptConfig
from onyx.context.search.models import SearchRequest
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from tests.regression.answer_quality.run_qa import _process_and_write_query_results
@ -41,6 +43,12 @@ def test_skip_gen_ai_answer_generation_flag(
mock_llm.config.model_name = "gpt-4o-mini"
mock_llm.stream = Mock()
mock_llm.stream.return_value = [Mock()]
session = Mock()
agent_search_config, _ = get_test_config(
session, mock_llm, mock_llm, SearchRequest(query=question)
)
answer = Answer(
question=question,
answer_style_config=answer_style_config,
@ -58,6 +66,7 @@ def test_skip_gen_ai_answer_generation_flag(
skip_explicit_tool_calling=True,
return_contexts=True,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
agent_search_config=agent_search_config,
)
count = 0
for _ in cast(AnswerStream, answer.processed_streamed_output):