diff --git a/backend/ee/onyx/server/query_and_chat/query_backend.py b/backend/ee/onyx/server/query_and_chat/query_backend.py index cb9003b7e..34fc9dbaf 100644 --- a/backend/ee/onyx/server/query_and_chat/query_backend.py +++ b/backend/ee/onyx/server/query_and_chat/query_backend.py @@ -197,6 +197,7 @@ def get_answer_stream( rerank_settings=query_request.rerank_settings, db_session=db_session, use_agentic_search=query_request.use_agentic_search, + skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation, ) packets = stream_chat_message_objects( diff --git a/backend/onyx/agents/agent_search/basic/graph_builder.py b/backend/onyx/agents/agent_search/basic/graph_builder.py index 4114ef036..411744ed7 100644 --- a/backend/onyx/agents/agent_search/basic/graph_builder.py +++ b/backend/onyx/agents/agent_search/basic/graph_builder.py @@ -1,20 +1,18 @@ -from typing import cast - -from langchain_core.callbacks.manager import dispatch_custom_event -from langchain_core.runnables.config import RunnableConfig from langgraph.graph import END from langgraph.graph import START from langgraph.graph import StateGraph +from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import ( + basic_use_tool_response, +) +from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice +from onyx.agents.agent_search.basic.nodes.tool_call import tool_call from onyx.agents.agent_search.basic.states import BasicInput from onyx.agents.agent_search.basic.states import BasicOutput from onyx.agents.agent_search.basic.states import BasicState -from onyx.agents.agent_search.basic.states import BasicStateUpdate -from onyx.agents.agent_search.models import AgentSearchConfig -from onyx.chat.stream_processing.utils import ( - map_document_id_order, -) -from onyx.tools.tool_implementations.search.search_tool import SearchTool +from onyx.utils.logger import setup_logger + +logger = setup_logger() def basic_graph_builder() -> StateGraph: @@ -27,17 +25,33 @@ def basic_graph_builder() -> StateGraph: ### Add nodes ### graph.add_node( - node="get_response", - action=get_response, + node="llm_tool_choice", + action=llm_tool_choice, + ) + + graph.add_node( + node="tool_call", + action=tool_call, + ) + + graph.add_node( + node="basic_use_tool_response", + action=basic_use_tool_response, ) ### Add edges ### - graph.add_edge(start_key=START, end_key="get_response") + graph.add_edge(start_key=START, end_key="llm_tool_choice") + + graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END]) - graph.add_conditional_edges("get_response", should_continue, ["get_response", END]) graph.add_edge( - start_key="get_response", + start_key="tool_call", + end_key="basic_use_tool_response", + ) + + graph.add_edge( + start_key="basic_use_tool_response", end_key=END, ) @@ -46,56 +60,10 @@ def basic_graph_builder() -> StateGraph: def should_continue(state: BasicState) -> str: return ( - END if state["last_llm_call"] is None or state["calls"] > 1 else "get_response" - ) - - -def get_response(state: BasicState, config: RunnableConfig) -> BasicStateUpdate: - agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) - llm = agent_a_config.primary_llm - current_llm_call = state["last_llm_call"] - if current_llm_call is None: - raise ValueError("last_llm_call is None") - structured_response_format = agent_a_config.structured_response_format - response_handler_manager = state["response_handler_manager"] - # DEBUG: good breakpoint - stream = llm.stream( - # For tool calling LLMs, we want to 
insert the task prompt as part of this flow, this is because the LLM - # may choose to not call any tools and just generate the answer, in which case the task prompt is needed. - prompt=current_llm_call.prompt_builder.build(), - tools=[tool.tool_definition() for tool in current_llm_call.tools] or None, - tool_choice=( - "required" - if current_llm_call.tools and current_llm_call.force_use_tool.force_use - else None - ), - structured_response_format=structured_response_format, - ) - - for response in response_handler_manager.handle_llm_response(stream): - dispatch_custom_event( - "basic_response", - response, - ) - - next_call = response_handler_manager.next_llm_call(current_llm_call) - if next_call is not None: - final_search_results, displayed_search_results = SearchTool.get_search_result( - next_call - ) or ([], []) - else: - final_search_results, displayed_search_results = [], [] - - response_handler_manager.answer_handler.update( - ( - final_search_results, - map_document_id_order(final_search_results), - map_document_id_order(displayed_search_results), - ) - ) - return BasicStateUpdate( - last_llm_call=next_call, - calls=state["calls"] + 1, + # If there are no tool calls, basic graph already streamed the answer + END + if state["tool_choice"] is None + else "tool_call" ) diff --git a/backend/onyx/agents/agent_search/basic/states.py b/backend/onyx/agents/agent_search/basic/states.py index 52a694e6c..b26564517 100644 --- a/backend/onyx/agents/agent_search/basic/states.py +++ b/backend/onyx/agents/agent_search/basic/states.py @@ -1,20 +1,21 @@ from typing import TypedDict -from onyx.chat.llm_response_handler import LLMResponseHandlerManager -from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall - - -## Update States +from onyx.tools.message import ToolCallSummary +from onyx.tools.models import ToolCallFinalResult +from onyx.tools.models import ToolCallKickoff +from onyx.tools.models import ToolResponse +from onyx.tools.tool import Tool +# States contain values that change over the course of graph execution, +# Config is for values that are set at the start and never change. +# If you are using a value from the config and realize it needs to change, +# you should add it to the state and use/update the version in the state. 
## Graph Input State class BasicInput(TypedDict): - base_question: str - last_llm_call: LLMCall | None - response_handler_manager: LLMResponseHandlerManager - calls: int + should_stream_answer: bool ## Graph Output State @@ -24,9 +25,22 @@ class BasicOutput(TypedDict): pass -class BasicStateUpdate(TypedDict): - last_llm_call: LLMCall | None - calls: int +## Update States +class ToolCallUpdate(TypedDict): + tool_call_summary: ToolCallSummary + tool_call_kickoff: ToolCallKickoff + tool_call_responses: list[ToolResponse] + tool_call_final_result: ToolCallFinalResult + + +class ToolChoice(TypedDict): + tool: Tool + tool_args: dict + id: str | None + + +class ToolChoiceUpdate(TypedDict): + tool_choice: ToolChoice | None ## Graph State @@ -34,6 +48,8 @@ class BasicStateUpdate(TypedDict): class BasicState( BasicInput, + ToolCallUpdate, + ToolChoiceUpdate, BasicOutput, ): pass diff --git a/backend/onyx/agents/agent_search/models.py b/backend/onyx/agents/agent_search/models.py index 97e68fb26..41913ae4c 100644 --- a/backend/onyx/agents/agent_search/models.py +++ b/backend/onyx/agents/agent_search/models.py @@ -2,11 +2,15 @@ from dataclasses import dataclass from uuid import UUID from pydantic import BaseModel +from pydantic import model_validator from sqlalchemy.orm import Session +from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.context.search.models import SearchRequest +from onyx.file_store.utils import InMemoryChatFile from onyx.llm.interfaces import LLM -from onyx.llm.models import PreviousMessage +from onyx.tools.force import ForceUseTool +from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_tool import SearchTool @@ -22,7 +26,18 @@ class AgentSearchConfig: primary_llm: LLM fast_llm: LLM search_tool: SearchTool - use_agentic_search: bool = True + + # Whether to force use of a tool, or to + # force tool args IF the tool is used + force_use_tool: ForceUseTool + + # contains message history for the current chat session + # has the following (at most one is non-None) + # message_history: list[PreviousMessage] | None = None + # single_message_history: str | None = None + prompt_builder: AnswerPromptBuilder + + use_agentic_search: bool = False # For persisting agent search data chat_session_id: UUID | None = None @@ -45,11 +60,25 @@ class AgentSearchConfig: # Whether to allow creation of refinement questions (and entity extraction, etc.) 
allow_refinement: bool = True - # Message history for the current chat session - message_history: list[PreviousMessage] | None = None + # Tools available for use + tools: list[Tool] | None = None + + using_tool_calling_llm: bool = False + + files: list[InMemoryChatFile] | None = None structured_response_format: dict | None = None + skip_gen_ai_answer_generation: bool = False + + @model_validator(mode="after") + def validate_db_session(self) -> "AgentSearchConfig": + if self.use_persistence and self.db_session is None: + raise ValueError( + "db_session must be provided for pro search when using persistence" + ) + return self + class AgentDocumentCitations(BaseModel): document_id: str diff --git a/backend/onyx/agents/agent_search/run_graph.py b/backend/onyx/agents/agent_search/run_graph.py index 5761f112c..bc8c1cd1f 100644 --- a/backend/onyx/agents/agent_search/run_graph.py +++ b/backend/onyx/agents/agent_search/run_graph.py @@ -16,7 +16,6 @@ from onyx.agents.agent_search.deep_search_a.main.graph_builder import ( from onyx.agents.agent_search.deep_search_a.main.states import MainInput as MainInput_a from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config -from onyx.chat.llm_response_handler import LLMResponseHandlerManager from onyx.chat.models import AgentAnswerPiece from onyx.chat.models import AnswerPacket from onyx.chat.models import AnswerStream @@ -25,7 +24,6 @@ from onyx.chat.models import StreamStopInfo from onyx.chat.models import SubQueryPiece from onyx.chat.models import SubQuestionPiece from onyx.chat.models import ToolResponse -from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall from onyx.configs.dev_configs import GRAPH_NAME from onyx.context.search.models import SearchRequest from onyx.db.engine import get_session_context_manager @@ -143,7 +141,6 @@ def run_graph( config: AgentSearchConfig, input: BasicInput | MainInput_a, ) -> AnswerStream: - input["base_question"] = config.search_request.query if config else "" # TODO: add these to the environment config.perform_initial_search_path_decision = True config.perform_initial_search_decomposition = True @@ -192,17 +189,12 @@ def run_main_graph( # TODO: unify input types, especially prosearchconfig def run_basic_graph( config: AgentSearchConfig, - last_llm_call: LLMCall | None, - response_handler_manager: LLMResponseHandlerManager, ) -> AnswerStream: graph = basic_graph_builder() compiled_graph = graph.compile() # TODO: unify basic input input = BasicInput( - base_question="", - last_llm_call=last_llm_call, - response_handler_manager=response_handler_manager, - calls=0, + should_stream_answer=True, ) return run_graph(compiled_graph, config, input) diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py index 3fe2acc48..8f01f8cb1 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py @@ -11,6 +11,7 @@ from typing import cast from uuid import UUID from langchain_core.messages import BaseMessage +from langchain_core.messages import HumanMessage from sqlalchemy.orm import Session from onyx.agents.agent_search.models import AgentSearchConfig @@ -21,6 +22,7 @@ from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationConfig from onyx.chat.models import DocumentPruningConfig from onyx.chat.models import PromptConfig +from 
onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from onyx.configs.constants import DEFAULT_PERSONA_ID @@ -31,6 +33,7 @@ from onyx.context.search.models import SearchRequest from onyx.db.persona import get_persona_by_id from onyx.db.persona import Persona from onyx.llm.interfaces import LLM +from onyx.tools.force import ForceUseTool from onyx.tools.tool_constructor import SearchToolConfig from onyx.tools.tool_implementations.search.search_tool import SearchTool @@ -207,14 +210,23 @@ def get_test_config( config = AgentSearchConfig( search_request=search_request, + primary_llm=primary_llm, + fast_llm=fast_llm, + search_tool=search_tool, + force_use_tool=ForceUseTool(force_use=False, tool_name=""), + prompt_builder=AnswerPromptBuilder( + user_message=HumanMessage(content=search_request.query), + message_history=[], + llm_config=primary_llm.config, + raw_user_query=search_request.query, + raw_user_uploaded_files=[], + ), # chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim # chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan message_id=1, use_persistence=True, - primary_llm=primary_llm, - fast_llm=fast_llm, - search_tool=search_tool, + db_session=db_session, ) return config, search_tool diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index 3d6e2f64d..6b755ec8a 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -1,16 +1,12 @@ from collections import defaultdict from collections.abc import Callable -from uuid import uuid4 from langchain.schema.messages import BaseMessage -from langchain_core.messages import AIMessageChunk -from langchain_core.messages import ToolCall from sqlalchemy.orm import Session from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.run_graph import run_basic_graph from onyx.agents.agent_search.run_graph import run_main_graph -from onyx.chat.llm_response_handler import LLMResponseHandlerManager from onyx.chat.models import AgentAnswerPiece from onyx.chat.models import AnswerPacket from onyx.chat.models import AnswerStream @@ -18,18 +14,9 @@ from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationInfo from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import PromptConfig -from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder -from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message -from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message -from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall -from onyx.chat.stream_processing.answer_response_handler import ( - CitationResponseHandler, -) -from onyx.chat.stream_processing.utils import ( - map_document_id_order, -) +from onyx.chat.models import StreamStopInfo +from onyx.chat.models import StreamStopReason from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name -from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler from onyx.configs.constants import BASIC_KEY from onyx.file_store.utils import InMemoryChatFile from onyx.llm.interfaces import LLM @@ -37,7 +24,6 @@ from onyx.llm.models import PreviousMessage from onyx.natural_language_processing.utils import get_tokenizer from onyx.tools.force import ForceUseTool 
from onyx.tools.tool import Tool -from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.utils import explicit_tool_calling_supported from onyx.utils.logger import setup_logger @@ -52,7 +38,7 @@ class Answer: llm: LLM, prompt_config: PromptConfig, force_use_tool: ForceUseTool, - pro_search_config: AgentSearchConfig, + agent_search_config: AgentSearchConfig, # must be the same length as `docs`. If None, all docs are considered "relevant" message_history: list[PreviousMessage] | None = None, single_message_history: str | None = None, @@ -114,7 +100,7 @@ class Answer: and not skip_explicit_tool_calling ) - self.pro_search_config = pro_search_config + self.agent_search_config = agent_search_config self.db_session = db_session def _get_tools_list(self) -> list[Tool]: @@ -132,130 +118,132 @@ class Answer: return [tool] # TODO: delete the function and move the full body to processed_streamed_output - def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: - current_llm_call = llm_calls[-1] + def _get_response(self) -> AnswerStream: + # current_llm_call = llm_calls[-1] - tool, tool_args = None, None - # handle the case where no decision has to be made; we simply run the tool - if ( - current_llm_call.force_use_tool.force_use - and current_llm_call.force_use_tool.args is not None - ): - tool_name, tool_args = ( - current_llm_call.force_use_tool.tool_name, - current_llm_call.force_use_tool.args, - ) - tool = get_tool_by_name(current_llm_call.tools, tool_name) + # tool, tool_args = None, None + # # handle the case where no decision has to be made; we simply run the tool + # if ( + # current_llm_call.force_use_tool.force_use + # and current_llm_call.force_use_tool.args is not None + # ): + # tool_name, tool_args = ( + # current_llm_call.force_use_tool.tool_name, + # current_llm_call.force_use_tool.args, + # ) + # tool = get_tool_by_name(current_llm_call.tools, tool_name) - # special pre-logic for non-tool calling LLM case - elif not self.using_tool_calling_llm and current_llm_call.tools: - chosen_tool_and_args = ( - ToolResponseHandler.get_tool_call_for_non_tool_calling_llm( - current_llm_call, self.llm - ) - ) - if chosen_tool_and_args: - tool, tool_args = chosen_tool_and_args + # # special pre-logic for non-tool calling LLM case + # elif not self.using_tool_calling_llm and current_llm_call.tools: + # chosen_tool_and_args = ( + # ToolResponseHandler.get_tool_call_for_non_tool_calling_llm( + # current_llm_call, self.llm + # ) + # ) + # if chosen_tool_and_args: + # tool, tool_args = chosen_tool_and_args - if tool and tool_args: - dummy_tool_call_chunk = AIMessageChunk(content="") - dummy_tool_call_chunk.tool_calls = [ - ToolCall(name=tool.name, args=tool_args, id=str(uuid4())) - ] + # if tool and tool_args: + # dummy_tool_call_chunk = AIMessageChunk(content="") + # dummy_tool_call_chunk.tool_calls = [ + # ToolCall(name=tool.name, args=tool_args, id=str(uuid4())) + # ] - response_handler_manager = LLMResponseHandlerManager( - ToolResponseHandler([tool]), None, self.is_cancelled - ) - yield from response_handler_manager.handle_llm_response( - iter([dummy_tool_call_chunk]) - ) + # response_handler_manager = LLMResponseHandlerManager( + # ToolResponseHandler([tool]), None, self.is_cancelled + # ) + # yield from response_handler_manager.handle_llm_response( + # iter([dummy_tool_call_chunk]) + # ) - tmp_call = response_handler_manager.next_llm_call(current_llm_call) - if tmp_call is None: - return # no more LLM calls to process - current_llm_call = tmp_call + # 
tmp_call = response_handler_manager.next_llm_call(current_llm_call) + # if tmp_call is None: + # return # no more LLM calls to process + # current_llm_call = tmp_call - # if we're skipping gen ai answer generation, we should break - # out unless we're forcing a tool call. If we don't, we might generate an - # answer, which is a no-no! - if ( - self.skip_gen_ai_answer_generation - and not current_llm_call.force_use_tool.force_use - ): - return + # # if we're skipping gen ai answer generation, we should break + # # out unless we're forcing a tool call. If we don't, we might generate an + # # answer, which is a no-no! + # if ( + # self.skip_gen_ai_answer_generation + # and not current_llm_call.force_use_tool.force_use + # ): + # return - # set up "handlers" to listen to the LLM response stream and - # feed back the processed results + handle tool call requests - # + figure out what the next LLM call should be - tool_call_handler = ToolResponseHandler(current_llm_call.tools) + # # set up "handlers" to listen to the LLM response stream and + # # feed back the processed results + handle tool call requests + # # + figure out what the next LLM call should be + # tool_call_handler = ToolResponseHandler(current_llm_call.tools) - final_search_results, displayed_search_results = SearchTool.get_search_result( - current_llm_call - ) or ([], []) + # final_search_results, displayed_search_results = SearchTool.get_search_result( + # current_llm_call + # ) or ([], []) - # NEXT: we still want to handle the LLM response stream, but it is now: - # 1. handle the tool call requests - # 2. feed back the processed results - # 3. handle the citations + # # NEXT: we still want to handle the LLM response stream, but it is now: + # # 1. handle the tool call requests + # # 2. feed back the processed results + # # 3. handle the citations - answer_handler = CitationResponseHandler( - context_docs=final_search_results, - final_doc_id_to_rank_map=map_document_id_order(final_search_results), - display_doc_id_to_rank_map=map_document_id_order(displayed_search_results), - ) + # answer_handler = CitationResponseHandler( + # context_docs=final_search_results, + # final_doc_id_to_rank_map=map_document_id_order(final_search_results), + # display_doc_id_to_rank_map=map_document_id_order(displayed_search_results), + # ) - # At the moment, this wrapper class passes streamed stuff through citation and tool handlers. - # In the future, we'll want to handle citations and tool calls in the langgraph graph. - response_handler_manager = LLMResponseHandlerManager( - tool_call_handler, answer_handler, self.is_cancelled - ) + # # At the moment, this wrapper class passes streamed stuff through citation and tool handlers. + # # In the future, we'll want to handle citations and tool calls in the langgraph graph. 
+ # response_handler_manager = LLMResponseHandlerManager( + # tool_call_handler, answer_handler, self.is_cancelled + # ) # In langgraph, whether we do the basic thing (call llm stream) or pro search # is based on a flag in the pro search config - if self.pro_search_config.use_agentic_search: - if self.pro_search_config.search_request is None: - raise ValueError("Search request must be provided for pro search") - - if self.db_session is None: - raise ValueError("db_session must be provided for pro search") - if self.fast_llm is None: - raise ValueError("fast_llm must be provided for pro search") + if self.agent_search_config.use_agentic_search: + if ( + self.agent_search_config.db_session is None + and self.agent_search_config.use_persistence + ): + raise ValueError( + "db_session must be provided for pro search when using persistence" + ) stream = run_main_graph( - config=self.pro_search_config, + config=self.agent_search_config, ) else: stream = run_basic_graph( - config=self.pro_search_config, - last_llm_call=current_llm_call, - response_handler_manager=response_handler_manager, + config=self.agent_search_config, ) processed_stream = [] for packet in stream: + if self.is_cancelled(): + packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) + yield packet + break processed_stream.append(packet) yield packet self._processed_stream = processed_stream return # DEBUG: good breakpoint - stream = self.llm.stream( - # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM - # may choose to not call any tools and just generate the answer, in which case the task prompt is needed. - prompt=current_llm_call.prompt_builder.build(), - tools=[tool.tool_definition() for tool in current_llm_call.tools] or None, - tool_choice=( - "required" - if current_llm_call.tools and current_llm_call.force_use_tool.force_use - else None - ), - structured_response_format=self.answer_style_config.structured_response_format, - ) - yield from response_handler_manager.handle_llm_response(stream) + # stream = self.llm.stream( + # # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM + # # may choose to not call any tools and just generate the answer, in which case the task prompt is needed. 
+ # prompt=current_llm_call.prompt_builder.build(), + # tools=[tool.tool_definition() for tool in current_llm_call.tools] or None, + # tool_choice=( + # "required" + # if current_llm_call.tools and current_llm_call.force_use_tool.force_use + # else None + # ), + # structured_response_format=self.answer_style_config.structured_response_format, + # ) + # yield from response_handler_manager.handle_llm_response(stream) - new_llm_call = response_handler_manager.next_llm_call(current_llm_call) - if new_llm_call: - yield from self._get_response(llm_calls + [new_llm_call]) + # new_llm_call = response_handler_manager.next_llm_call(current_llm_call) + # if new_llm_call: + # yield from self._get_response(llm_calls + [new_llm_call]) @property def processed_streamed_output(self) -> AnswerStream: @@ -263,33 +251,33 @@ class Answer: yield from self._processed_stream return - prompt_builder = AnswerPromptBuilder( - user_message=default_build_user_message( - user_query=self.question, - prompt_config=self.prompt_config, - files=self.latest_query_files, - single_message_history=self.single_message_history, - ), - message_history=self.message_history, - llm_config=self.llm.config, - raw_user_query=self.question, - raw_user_uploaded_files=self.latest_query_files or [], - single_message_history=self.single_message_history, - ) - prompt_builder.update_system_prompt( - default_build_system_message(self.prompt_config) - ) - llm_call = LLMCall( - prompt_builder=prompt_builder, - tools=self._get_tools_list(), - force_use_tool=self.force_use_tool, - files=self.latest_query_files, - tool_call_info=[], - using_tool_calling_llm=self.using_tool_calling_llm, - ) + # prompt_builder = AnswerPromptBuilder( + # user_message=default_build_user_message( + # user_query=self.question, + # prompt_config=self.prompt_config, + # files=self.latest_query_files, + # single_message_history=self.single_message_history, + # ), + # message_history=self.message_history, + # llm_config=self.llm.config, + # raw_user_query=self.question, + # raw_user_uploaded_files=self.latest_query_files or [], + # single_message_history=self.single_message_history, + # ) + # prompt_builder.update_system_prompt( + # default_build_system_message(self.prompt_config) + # ) + # llm_call = LLMCall( + # prompt_builder=prompt_builder, + # tools=self._get_tools_list(), + # force_use_tool=self.force_use_tool, + # files=self.latest_query_files, + # tool_call_info=[], + # using_tool_calling_llm=self.using_tool_calling_llm, + # ) processed_stream = [] - for processed_packet in self._get_response([llm_call]): + for processed_packet in self._get_response(): processed_stream.append(processed_packet) yield processed_packet diff --git a/backend/onyx/chat/chat_utils.py b/backend/onyx/chat/chat_utils.py index 526241187..96e53a54a 100644 --- a/backend/onyx/chat/chat_utils.py +++ b/backend/onyx/chat/chat_utils.py @@ -49,6 +49,7 @@ def prepare_chat_message_request( rerank_settings: RerankingDetails | None, db_session: Session, use_agentic_search: bool = False, + skip_gen_ai_answer_generation: bool = False, ) -> CreateChatMessageRequest: # Typically used for one shot flows like SlackBot or non-chat API endpoint use cases new_chat_session = create_chat_session( @@ -74,6 +75,7 @@ def prepare_chat_message_request( retrieval_options=retrieval_details, rerank_settings=rerank_settings, use_agentic_search=use_agentic_search, + skip_gen_ai_answer_generation=skip_gen_ai_answer_generation, ) diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 
da19e19b2..32eb395ae 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -33,6 +33,9 @@ from onyx.chat.models import QADocsResponse from onyx.chat.models import StreamingError from onyx.chat.models import StreamStopInfo from onyx.chat.models import StreamStopReason +from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder +from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message +from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT @@ -130,6 +133,7 @@ from onyx.tools.tool_implementations.search.search_tool import ( SECTION_RELEVANCE_LIST_ID, ) from onyx.tools.tool_runner import ToolCallFinalResult +from onyx.tools.utils import explicit_tool_calling_supported from onyx.utils.logger import setup_logger from onyx.utils.long_term_log import LongTermLogger from onyx.utils.telemetry import mt_cloud_telemetry @@ -137,7 +141,6 @@ from onyx.utils.timing import log_function_time from onyx.utils.timing import log_generator_function_time from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR - logger = setup_logger() @@ -534,11 +537,8 @@ def stream_chat_message_objects( files = load_all_chat_files( history_msgs, new_msg_req.file_descriptors, db_session ) - latest_query_files = [ - file - for file in files - if file.file_id in [f["id"] for f in new_msg_req.file_descriptors] - ] + req_file_ids = [f["id"] for f in new_msg_req.file_descriptors] + latest_query_files = [file for file in files if file.file_id in req_file_ids] if user_message: attach_files_to_chat_message( @@ -748,17 +748,42 @@ def stream_chat_message_objects( # TODO: handle multiple search tools raise ValueError("Multiple search tools found") search_tool = search_tools[0] - pro_search_config = AgentSearchConfig( - use_agentic_search=new_msg_req.use_agentic_search, - search_request=search_request, - chat_session_id=chat_session_id, - message_id=reserved_message_id, + using_tool_calling_llm = explicit_tool_calling_supported( + llm.config.model_provider, llm.config.model_name + ) + force_use_tool = _get_force_search_settings(new_msg_req, tools) + prompt_builder = AnswerPromptBuilder( + user_message=default_build_user_message( + user_query=final_msg.message, + prompt_config=prompt_config, + files=latest_query_files, + single_message_history=single_message_history, + ), + system_message=default_build_system_message(prompt_config), message_history=message_history, + llm_config=llm.config, + raw_user_query=final_msg.message, + raw_user_uploaded_files=latest_query_files or [], + single_message_history=single_message_history, + ) + agent_search_config = AgentSearchConfig( + search_request=search_request, primary_llm=llm, fast_llm=fast_llm, search_tool=search_tool, - structured_response_format=new_msg_req.structured_response_format, + force_use_tool=force_use_tool, + use_agentic_search=new_msg_req.use_agentic_search, + chat_session_id=chat_session_id, + message_id=reserved_message_id, + use_persistence=True, + allow_refinement=True, db_session=db_session, + prompt_builder=prompt_builder, + tools=tools, + using_tool_calling_llm=using_tool_calling_llm, + files=latest_query_files, + structured_response_format=new_msg_req.structured_response_format, + skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation, ) # 
TODO: add previous messages, answer style config, tools, etc. @@ -785,9 +810,9 @@ def stream_chat_message_objects( fast_llm=fast_llm, message_history=message_history, tools=tools, - force_use_tool=_get_force_search_settings(new_msg_req, tools), + force_use_tool=force_use_tool, single_message_history=single_message_history, - pro_search_config=pro_search_config, + agent_search_config=agent_search_config, db_session=db_session, ) diff --git a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py index 9395582ee..0ce876e20 100644 --- a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py +++ b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py @@ -84,6 +84,7 @@ class AnswerPromptBuilder: raw_user_query: str, raw_user_uploaded_files: list[InMemoryChatFile], single_message_history: str | None = None, + system_message: SystemMessage | None = None, ) -> None: self.max_tokens = compute_max_llm_input_tokens(llm_config) @@ -108,7 +109,14 @@ class AnswerPromptBuilder: ), ) - self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None + self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = ( + ( + system_message, + check_message_tokens(system_message, self.llm_tokenizer_encode_func), + ) + if system_message + else None + ) self.user_message_and_token_cnt = ( user_message, check_message_tokens( diff --git a/backend/onyx/chat/tool_handling/tool_response_handler.py b/backend/onyx/chat/tool_handling/tool_response_handler.py index 21359f827..8a6cce953 100644 --- a/backend/onyx/chat/tool_handling/tool_response_handler.py +++ b/backend/onyx/chat/tool_handling/tool_response_handler.py @@ -5,6 +5,7 @@ from langchain_core.messages import BaseMessage from langchain_core.messages import ToolCall from onyx.chat.models import ResponsePart +from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall from onyx.llm.interfaces import LLM from onyx.tools.force import ForceUseTool @@ -50,56 +51,12 @@ class ToolResponseHandler: def get_tool_call_for_non_tool_calling_llm( cls, llm_call: LLMCall, llm: LLM ) -> tuple[Tool, dict] | None: - if llm_call.force_use_tool.force_use: - # if we are forcing a tool, we don't need to check which tools to run - tool = get_tool_by_name(llm_call.tools, llm_call.force_use_tool.tool_name) - - tool_args = ( - llm_call.force_use_tool.args - if llm_call.force_use_tool.args is not None - else tool.get_args_for_non_tool_calling_llm( - query=llm_call.prompt_builder.raw_user_query, - history=llm_call.prompt_builder.raw_message_history, - llm=llm, - force_run=True, - ) - ) - - if tool_args is None: - raise RuntimeError(f"Tool '{tool.name}' did not return args") - - return (tool, tool_args) - else: - tool_options = check_which_tools_should_run_for_non_tool_calling_llm( - tools=llm_call.tools, - query=llm_call.prompt_builder.raw_user_query, - history=llm_call.prompt_builder.raw_message_history, - llm=llm, - ) - - available_tools_and_args = [ - (llm_call.tools[ind], args) - for ind, args in enumerate(tool_options) - if args is not None - ] - - logger.info( - f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}" - ) - - chosen_tool_and_args = ( - select_single_tool_for_non_tool_calling_llm( - tools_and_args=available_tools_and_args, - history=llm_call.prompt_builder.raw_message_history, - query=llm_call.prompt_builder.raw_user_query, - llm=llm, - ) - if 
available_tools_and_args - else None - ) - - logger.notice(f"Chosen tool: {chosen_tool_and_args}") - return chosen_tool_and_args + return get_tool_call_for_non_tool_calling_llm_impl( + force_use_tool=llm_call.force_use_tool, + tools=llm_call.tools, + prompt_builder=llm_call.prompt_builder, + llm=llm, + ) def _handle_tool_call(self) -> Generator[ResponsePart, None, None]: if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls: @@ -196,3 +153,61 @@ class ToolResponseHandler: self.tool_final_result, ], ) + + +def get_tool_call_for_non_tool_calling_llm_impl( + force_use_tool: ForceUseTool, + tools: list[Tool], + prompt_builder: AnswerPromptBuilder, + llm: LLM, +) -> tuple[Tool, dict] | None: + if force_use_tool.force_use: + # if we are forcing a tool, we don't need to check which tools to run + tool = get_tool_by_name(tools, force_use_tool.tool_name) + + tool_args = ( + force_use_tool.args + if force_use_tool.args is not None + else tool.get_args_for_non_tool_calling_llm( + query=prompt_builder.raw_user_query, + history=prompt_builder.raw_message_history, + llm=llm, + force_run=True, + ) + ) + + if tool_args is None: + raise RuntimeError(f"Tool '{tool.name}' did not return args") + + return (tool, tool_args) + else: + tool_options = check_which_tools_should_run_for_non_tool_calling_llm( + tools=tools, + query=prompt_builder.raw_user_query, + history=prompt_builder.raw_message_history, + llm=llm, + ) + + available_tools_and_args = [ + (tools[ind], args) + for ind, args in enumerate(tool_options) + if args is not None + ] + + logger.info( + f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}" + ) + + chosen_tool_and_args = ( + select_single_tool_for_non_tool_calling_llm( + tools_and_args=available_tools_and_args, + history=prompt_builder.raw_message_history, + query=prompt_builder.raw_user_query, + llm=llm, + ) + if available_tools_and_args + else None + ) + + logger.notice(f"Chosen tool: {chosen_tool_and_args}") + return chosen_tool_and_args diff --git a/backend/onyx/secondary_llm_flows/choose_search.py b/backend/onyx/secondary_llm_flows/choose_search.py index 7f48501af..328cc1b3c 100644 --- a/backend/onyx/secondary_llm_flows/choose_search.py +++ b/backend/onyx/secondary_llm_flows/choose_search.py @@ -81,7 +81,4 @@ def check_if_need_search( logger.debug(f"Run search prediction: {require_search_output}") - if (SKIP_SEARCH.split()[0]).lower() in require_search_output.lower(): - return False - - return True + return (SKIP_SEARCH.split()[0]).lower() not in require_search_output.lower() diff --git a/backend/onyx/server/query_and_chat/models.py b/backend/onyx/server/query_and_chat/models.py index ef3449845..bcb117747 100644 --- a/backend/onyx/server/query_and_chat/models.py +++ b/backend/onyx/server/query_and_chat/models.py @@ -138,6 +138,8 @@ class CreateChatMessageRequest(ChunkContext): # TODO: decide how many of the above options we want to pass through to pro search use_agentic_search: bool = False + skip_gen_ai_answer_generation: bool = False + @model_validator(mode="after") def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest": if self.search_doc_ids is None and self.retrieval_options is None: diff --git a/backend/onyx/tools/utils.py b/backend/onyx/tools/utils.py index f70e6591b..4c22ecda8 100644 --- a/backend/onyx/tools/utils.py +++ b/backend/onyx/tools/utils.py @@ -20,10 +20,7 @@ OPEN_AI_TOOL_CALLING_MODELS = { def explicit_tool_calling_supported(model_provider: str, model_name: str) -> bool: - if 
model_provider == "openai" and model_name in OPEN_AI_TOOL_CALLING_MODELS: - return True - - return False + return model_provider == "openai" and model_name in OPEN_AI_TOOL_CALLING_MODELS def compute_tool_tokens(tool: Tool, llm_tokenizer: BaseTokenizer) -> int: diff --git a/backend/tests/unit/onyx/chat/conftest.py b/backend/tests/unit/onyx/chat/conftest.py index 138668b19..684cfb620 100644 --- a/backend/tests/unit/onyx/chat/conftest.py +++ b/backend/tests/unit/onyx/chat/conftest.py @@ -3,15 +3,20 @@ from datetime import datetime from unittest.mock import MagicMock import pytest +from langchain_core.messages import HumanMessage from langchain_core.messages import SystemMessage +from onyx.agents.agent_search.models import AgentSearchConfig from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationConfig from onyx.chat.models import LlmDoc from onyx.chat.models import PromptConfig from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.configs.constants import DocumentSource +from onyx.context.search.models import SearchRequest +from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLMConfig +from onyx.tools.force import ForceUseTool from onyx.tools.models import ToolResponse from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_implementations.search_like_tool_utils import ( @@ -27,6 +32,31 @@ def answer_style_config() -> AnswerStyleConfig: return AnswerStyleConfig(citation_config=CitationConfig()) +@pytest.fixture +def agent_search_config( + mock_llm: LLM, mock_search_tool: SearchTool +) -> AgentSearchConfig: + return AgentSearchConfig( + search_request=SearchRequest(query=QUERY), + primary_llm=mock_llm, + fast_llm=mock_llm, + search_tool=mock_search_tool, + force_use_tool=ForceUseTool(force_use=False, tool_name=""), + prompt_builder=AnswerPromptBuilder( + user_message=HumanMessage(content=QUERY), + message_history=[], + llm_config=mock_llm.config, + raw_user_query=QUERY, + raw_user_uploaded_files=[], + ), + chat_session_id=None, + message_id=1, + use_persistence=True, + db_session=None, + use_agentic_search=False, + ) + + @pytest.fixture def prompt_config() -> PromptConfig: return PromptConfig( diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py index 2d483951e..50b27c17d 100644 --- a/backend/tests/unit/onyx/chat/test_answer.py +++ b/backend/tests/unit/onyx/chat/test_answer.py @@ -11,6 +11,7 @@ from langchain_core.messages import SystemMessage from langchain_core.messages import ToolCall from langchain_core.messages import ToolCallChunk +from onyx.agents.agent_search.models import AgentSearchConfig from onyx.chat.answer import Answer from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationInfo @@ -30,7 +31,10 @@ from tests.unit.onyx.chat.conftest import QUERY @pytest.fixture def answer_instance( - mock_llm: LLM, answer_style_config: AnswerStyleConfig, prompt_config: PromptConfig + mock_llm: LLM, + answer_style_config: AnswerStyleConfig, + prompt_config: PromptConfig, + agent_search_config: AgentSearchConfig, ) -> Answer: return Answer( question=QUERY, @@ -38,6 +42,7 @@ def answer_instance( llm=mock_llm, prompt_config=prompt_config, force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None), + agent_search_config=agent_search_config, ) @@ -284,7 +289,8 @@ def test_answer_with_search_no_tool_calling( def test_is_cancelled(answer_instance: Answer) -> None: # Set up the LLM mock to return 
multiple chunks mock_llm = Mock() - answer_instance.llm = mock_llm + answer_instance.agent_search_config.primary_llm = mock_llm + answer_instance.agent_search_config.fast_llm = mock_llm mock_llm.stream.return_value = [ AIMessageChunk(content="This is the "), AIMessageChunk(content="first part."), @@ -303,6 +309,7 @@ def test_is_cancelled(answer_instance: Answer) -> None: if i == 1: connection_status["connected"] = False + print(output) assert len(output) == 3 assert output[0] == OnyxAnswerPiece(answer_piece="This is the ") assert output[1] == OnyxAnswerPiece(answer_piece="first part.") diff --git a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py index e380b4aaa..5a061d97c 100644 --- a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py +++ b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py @@ -5,10 +5,12 @@ from unittest.mock import Mock import pytest from pytest_mock import MockerFixture +from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config from onyx.chat.answer import Answer from onyx.chat.answer import AnswerStream from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import PromptConfig +from onyx.context.search.models import SearchRequest from onyx.tools.force import ForceUseTool from onyx.tools.tool_implementations.search.search_tool import SearchTool from tests.regression.answer_quality.run_qa import _process_and_write_query_results @@ -41,6 +43,12 @@ def test_skip_gen_ai_answer_generation_flag( mock_llm.config.model_name = "gpt-4o-mini" mock_llm.stream = Mock() mock_llm.stream.return_value = [Mock()] + + session = Mock() + agent_search_config, _ = get_test_config( + session, mock_llm, mock_llm, SearchRequest(query=question) + ) + answer = Answer( question=question, answer_style_config=answer_style_config, @@ -58,6 +66,7 @@ def test_skip_gen_ai_answer_generation_flag( skip_explicit_tool_calling=True, return_contexts=True, skip_gen_ai_answer_generation=skip_gen_ai_answer_generation, + agent_search_config=agent_search_config, ) count = 0 for _ in cast(AnswerStream, answer.processed_streamed_output):
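
Reviewer note: the heart of this change is the reworked basic graph — the LLM first chooses a tool (or none), the chosen tool runs, and its responses are folded back in before the final answer. Below is a minimal, self-contained sketch of that wiring using a toy state in place of Onyx's BasicState/ToolChoiceUpdate; ToyState and the stub node bodies are illustrative assumptions, and only the node names, the edges, and the should_continue branch mirror the diff above.

from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class ToyState(TypedDict):
    # mirrors ToolChoiceUpdate: None means the LLM answered directly
    tool_choice: str | None
    tool_result: str | None


def llm_tool_choice(state: ToyState) -> dict:
    # stand-in for the real node: the LLM either asks for a tool or
    # streams the answer directly (in which case tool_choice stays None)
    return {"tool_choice": state["tool_choice"]}


def tool_call(state: ToyState) -> dict:
    # stand-in for running the chosen tool and collecting its responses
    return {"tool_result": f"ran {state['tool_choice']}"}


def basic_use_tool_response(state: ToyState) -> dict:
    # stand-in for building the final answer from the tool output
    return {"tool_result": f"answered using: {state['tool_result']}"}


def should_continue(state: ToyState) -> str:
    # same branch as the PR: no tool chosen means the answer was already streamed
    return END if state["tool_choice"] is None else "tool_call"


graph = StateGraph(ToyState)
graph.add_node("llm_tool_choice", llm_tool_choice)
graph.add_node("tool_call", tool_call)
graph.add_node("basic_use_tool_response", basic_use_tool_response)
graph.add_edge(START, "llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_edge("tool_call", "basic_use_tool_response")
graph.add_edge("basic_use_tool_response", END)

compiled = graph.compile()
# tool path: START -> llm_tool_choice -> tool_call -> basic_use_tool_response -> END
print(compiled.invoke({"tool_choice": "search", "tool_result": None}))
# no-tool path: should_continue short-circuits straight to END
print(compiled.invoke({"tool_choice": None, "tool_result": None}))

In the real graph the per-request values (prompt builder, tools, force_use_tool, skip_gen_ai_answer_generation) ride along in AgentSearchConfig rather than in the state, which is exactly the state-vs-config split the new comment in states.py describes.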