diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py
index 416486dad1dc..1c5ec18c5de4 100644
--- a/backend/onyx/chat/answer.py
+++ b/backend/onyx/chat/answer.py
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from collections.abc import Callable
+from uuid import UUID
 
-from langchain.schema.messages import BaseMessage
 from sqlalchemy.orm import Session
 
 from onyx.agents.agent_search.models import AgentSearchConfig
@@ -13,17 +13,18 @@ from onyx.chat.models import AnswerStream
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import OnyxAnswerPiece
-from onyx.chat.models import PromptConfig
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.configs.constants import BASIC_KEY
+from onyx.context.search.models import SearchRequest
 from onyx.file_store.utils import InMemoryChatFile
 from onyx.llm.interfaces import LLM
-from onyx.llm.models import PreviousMessage
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.tools.force import ForceUseTool
 from onyx.tools.tool import Tool
+from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.utils import explicit_tool_calling_supported
 from onyx.utils.logger import setup_logger
 
@@ -33,15 +34,14 @@ logger = setup_logger()
 class Answer:
     def __init__(
         self,
-        question: str,
+        prompt_builder: AnswerPromptBuilder,
         answer_style_config: AnswerStyleConfig,
         llm: LLM,
-        prompt_config: PromptConfig,
+        fast_llm: LLM,
         force_use_tool: ForceUseTool,
-        agent_search_config: AgentSearchConfig,
-        # must be the same length as `docs`. If None, all docs are considered "relevant"
-        message_history: list[PreviousMessage] | None = None,
-        single_message_history: str | None = None,
+        search_request: SearchRequest,
+        chat_session_id: UUID,
+        current_agent_message_id: int,
         # newly passed in files to include as part of this question
         # TODO THIS NEEDS TO BE HANDLED
         latest_query_files: list[InMemoryChatFile] | None = None,
@@ -52,28 +52,18 @@ class Answer:
         skip_explicit_tool_calling: bool = False,
         skip_gen_ai_answer_generation: bool = False,
         is_connected: Callable[[], bool] | None = None,
-        fast_llm: LLM | None = None,
         db_session: Session | None = None,
+        use_agentic_search: bool = False,
     ) -> None:
-        if single_message_history and message_history:
-            raise ValueError(
-                "Cannot provide both `message_history` and `single_message_history`"
-            )
-
-        self.question = question
         self.is_connected: Callable[[], bool] | None = is_connected
 
         self.latest_query_files = latest_query_files or []
 
         self.tools = tools or []
         self.force_use_tool = force_use_tool
-
-        self.message_history = message_history or []
-        # used for QA flow where we only want to send a single message
-        self.single_message_history = single_message_history
 
         self.answer_style_config = answer_style_config
-        self.prompt_config = prompt_config
 
         self.llm = llm
         self.fast_llm = fast_llm
@@ -82,8 +72,6 @@ class Answer:
             model_name=llm.config.model_name,
         )
 
-        self._final_prompt: list[BaseMessage] | None = None
-
         self._streamed_output: list[str] | None = None
         self._processed_stream: (list[AnswerPacket] | None) = None
 
@@ -97,6 +85,42 @@ class Answer:
             and not skip_explicit_tool_calling
         )
 
+        search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
+        search_tool: SearchTool | None = None
+
+        if len(search_tools) > 1:
+            # TODO: handle multiple search tools
+            logger.warning("Multiple search tools found, using first one")
+            search_tool = search_tools[0]
+        elif len(search_tools) == 1:
+            search_tool = search_tools[0]
+        else:
+            logger.warning("No search tool found")
+            if use_agentic_search:
+                raise ValueError("No search tool found, cannot use agentic search")
+
+        using_tool_calling_llm = explicit_tool_calling_supported(
+            llm.config.model_provider, llm.config.model_name
+        )
+        agent_search_config = AgentSearchConfig(
+            search_request=search_request,
+            primary_llm=llm,
+            fast_llm=fast_llm,
+            search_tool=search_tool,
+            force_use_tool=force_use_tool,
+            use_agentic_search=use_agentic_search,
+            chat_session_id=chat_session_id,
+            message_id=current_agent_message_id,
+            use_persistence=True,
+            allow_refinement=True,
+            db_session=db_session,
+            prompt_builder=prompt_builder,
+            tools=tools,
+            using_tool_calling_llm=using_tool_calling_llm,
+            files=latest_query_files,
+            structured_response_format=answer_style_config.structured_response_format,
+            skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
+        )
         self.agent_search_config = agent_search_config
         self.db_session = db_session
 
diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py
index b71a5c5eefca..78a2b15dafda 100644
--- a/backend/onyx/chat/models.py
+++ b/backend/onyx/chat/models.py
@@ -284,7 +284,7 @@ class AnswerStyleConfig(BaseModel):
 
 class PromptConfig(BaseModel):
     """Final representation of the Prompt configuration passed
-    into the `Answer` object."""
+    into the `PromptBuilder` object."""
 
     system_prompt: str
     task_prompt: str
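Note on the `answer.py` hunk above: `AgentSearchConfig` is now assembled inside `Answer.__init__` from raw inputs, so call sites only pass those inputs. A minimal sketch of the new call shape, assuming `prompt_builder`, `llm`, `fast_llm`, `answer_style_config`, `tools`, and `db_session` are already built the way `process_message.py` builds them (the query string and UUID below are illustrative placeholders, not project defaults):

from uuid import UUID

from onyx.chat.answer import Answer
from onyx.context.search.models import SearchRequest
from onyx.tools.force import ForceUseTool

# Sketch only: prompt_builder, llm, fast_llm, answer_style_config, tools,
# and db_session are assumed to exist, constructed as in process_message.py.
answer = Answer(
    prompt_builder=prompt_builder,
    answer_style_config=answer_style_config,
    llm=llm,
    fast_llm=fast_llm,
    force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
    search_request=SearchRequest(query="example question"),
    chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
    current_agent_message_id=0,
    tools=tools,
    db_session=db_session,
    use_agentic_search=False,  # True requires a SearchTool in `tools`
)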
diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py
index 9037fb635a60..bf7f988365b9 100644
--- a/backend/onyx/chat/process_message.py
+++ b/backend/onyx/chat/process_message.py
@@ -8,7 +8,6 @@ from typing import cast
 
 from sqlalchemy.orm import Session
 
-from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.chat.answer import Answer
 from onyx.chat.chat_utils import create_chat_chain
 from onyx.chat.chat_utils import create_temporary_persona
@@ -134,7 +133,6 @@ from onyx.tools.tool_implementations.search.search_tool import (
     SECTION_RELEVANCE_LIST_ID,
 )
 from onyx.tools.tool_runner import ToolCallFinalResult
-from onyx.tools.utils import explicit_tool_calling_supported
 from onyx.utils.logger import setup_logger
 from onyx.utils.long_term_log import LongTermLogger
 from onyx.utils.telemetry import mt_cloud_telemetry
@@ -712,6 +710,7 @@ def stream_chat_message_objects(
         for tool_list in tool_dict.values():
             tools.extend(tool_list)
 
+        # TODO: unify message history with single message history
        message_history = [
             PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
         ]
@@ -739,26 +738,7 @@ def stream_chat_message_objects(
                 else None
             ),
         )
-        # TODO: Since we're deleting the current main path in Answer,
-        # we should construct this unconditionally inside Answer instead
-        # Leaving it here for the time being to avoid breaking changes
-        search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
-        search_tool: SearchTool | None = None
-        if len(search_tools) > 1:
-            # TODO: handle multiple search tools
-            logger.warning("Multiple search tools found, using first one")
-            search_tool = search_tools[0]
-        elif len(search_tools) == 1:
-            search_tool = search_tools[0]
-        else:
-            logger.warning("No search tool found")
-            if new_msg_req.use_agentic_search:
-                raise ValueError("No search tool found, cannot use agentic search")
-
-        using_tool_calling_llm = explicit_tool_calling_supported(
-            llm.config.model_provider, llm.config.model_name
-        )
+
         force_use_tool = _get_force_search_settings(new_msg_req, tools)
         prompt_builder = AnswerPromptBuilder(
             user_message=default_build_user_message(
@@ -776,35 +756,12 @@ def stream_chat_message_objects(
         )
         prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
 
-        agent_search_config = AgentSearchConfig(
-            search_request=search_request,
-            primary_llm=llm,
-            fast_llm=fast_llm,
-            search_tool=search_tool,
-            force_use_tool=force_use_tool,
-            use_agentic_search=new_msg_req.use_agentic_search,
-            chat_session_id=chat_session_id,
-            message_id=reserved_message_id,
-            use_persistence=True,
-            allow_refinement=True,
-            db_session=db_session,
-            prompt_builder=prompt_builder,
-            tools=tools,
-            using_tool_calling_llm=using_tool_calling_llm,
-            files=latest_query_files,
-            structured_response_format=new_msg_req.structured_response_format,
-            skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
-        )
-
-        # TODO: add previous messages, answer style config, tools, etc.
-        # LLM prompt building, response capturing, etc.
 
         answer = Answer(
+            prompt_builder=prompt_builder,
             is_connected=is_connected,
-            question=final_msg.message,
             latest_query_files=latest_query_files,
             answer_style_config=answer_style_config,
-            prompt_config=prompt_config,
             llm=(
                 llm
                 or get_main_llm_from_tuple(
@@ -818,12 +775,13 @@ def stream_chat_message_objects(
                 )
             ),
             fast_llm=fast_llm,
-            message_history=message_history,
-            tools=tools,
             force_use_tool=force_use_tool,
-            single_message_history=single_message_history,
-            agent_search_config=agent_search_config,
+            search_request=search_request,
+            chat_session_id=chat_session_id,
+            current_agent_message_id=reserved_message_id,
+            tools=tools,
             db_session=db_session,
+            use_agentic_search=new_msg_req.use_agentic_search,
         )
 
         # reference_db_search_docs = None
diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py
index 020e551fd154..cb80be9e1f2f 100644
--- a/backend/tests/unit/onyx/chat/test_answer.py
+++ b/backend/tests/unit/onyx/chat/test_answer.py
@@ -2,6 +2,7 @@ import json
 from typing import cast
 from unittest.mock import MagicMock
 from unittest.mock import Mock
+from uuid import UUID
 
 import pytest
 from langchain_core.messages import AIMessageChunk
@@ -11,7 +12,6 @@ from langchain_core.messages import SystemMessage
 from langchain_core.messages import ToolCall
 from langchain_core.messages import ToolCallChunk
 
-from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.chat.answer import Answer
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationInfo
@@ -21,6 +21,10 @@ from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
+from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
+from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
+from onyx.context.search.models import SearchRequest
 from onyx.llm.interfaces import LLM
 from onyx.tools.force import ForceUseTool
 from onyx.tools.models import ToolCallFinalResult
@@ -39,15 +43,28 @@ def answer_instance(
     mock_llm: LLM,
     answer_style_config: AnswerStyleConfig,
     prompt_config: PromptConfig,
-    agent_search_config: AgentSearchConfig,
 ) -> Answer:
     return Answer(
-        question=QUERY,
+        prompt_builder=AnswerPromptBuilder(
+            user_message=default_build_user_message(
+                user_query=QUERY,
+                prompt_config=prompt_config,
+                files=[],
+                single_message_history=None,
+            ),
+            system_message=default_build_system_message(prompt_config),
+            message_history=[],
+            llm_config=mock_llm.config,
+            raw_user_query=QUERY,
+            raw_user_uploaded_files=[],
+        ),
         answer_style_config=answer_style_config,
         llm=mock_llm,
-        prompt_config=prompt_config,
+        fast_llm=mock_llm,
         force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
-        agent_search_config=agent_search_config,
+        search_request=SearchRequest(query=QUERY),
+        chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
+        current_agent_message_id=0,
     )
 
 
@@ -57,6 +74,8 @@ def test_basic_answer(answer_instance: Answer) -> None:
         AIMessageChunk(content="This is a "),
         AIMessageChunk(content="mock answer."),
     ]
+    answer_instance.agent_search_config.fast_llm = mock_llm
+    answer_instance.agent_search_config.primary_llm = mock_llm
 
     output = list(answer_instance.processed_streamed_output)
     assert len(output) == 2
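A behavioral edge worth noting between these two test files: since the SearchTool lookup now runs inside `Answer.__init__`, constructing an `Answer` with `use_agentic_search=True` but no `SearchTool` raises at construction time. A hedged sketch of a test for that path, reusing the fixture and constant names from `test_answer.py` above (this test is not part of the diff):

import pytest

def test_agentic_search_requires_search_tool(
    mock_llm: LLM,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
) -> None:
    # Sketch only: mirrors the answer_instance fixture, but with no tools
    # and agentic search enabled, which the constructor should reject.
    with pytest.raises(ValueError, match="cannot use agentic search"):
        Answer(
            prompt_builder=AnswerPromptBuilder(
                user_message=default_build_user_message(
                    user_query=QUERY,
                    prompt_config=prompt_config,
                    files=[],
                    single_message_history=None,
                ),
                system_message=default_build_system_message(prompt_config),
                message_history=[],
                llm_config=mock_llm.config,
                raw_user_query=QUERY,
                raw_user_uploaded_files=[],
            ),
            answer_style_config=answer_style_config,
            llm=mock_llm,
            fast_llm=mock_llm,
            force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
            search_request=SearchRequest(query=QUERY),
            chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
            current_agent_message_id=0,
            tools=[],  # no SearchTool supplied
            use_agentic_search=True,
        )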
diff --git a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py
index 206e05846b44..3aa937565fd0 100644
--- a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py
+++ b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py
@@ -1,13 +1,16 @@
 from typing import Any
 from unittest.mock import Mock
+from uuid import UUID
 
 import pytest
+from langchain_core.messages import HumanMessage
 from pytest_mock import MockerFixture
 
-from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.chat.answer import Answer
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import PromptConfig
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
+from onyx.context.search.models import SearchRequest
 from onyx.tools.force import ForceUseTool
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from tests.regression.answer_quality.run_qa import _process_and_write_query_results
@@ -30,7 +33,6 @@ def test_skip_gen_ai_answer_generation_flag(
     config: dict[str, Any],
     mock_search_tool: SearchTool,
     answer_style_config: AnswerStyleConfig,
-    agent_search_config: AgentSearchConfig,
     prompt_config: PromptConfig,
 ) -> None:
     question = config["question"]
@@ -42,30 +44,28 @@ def test_skip_gen_ai_answer_generation_flag(
     mock_llm.stream = Mock()
     mock_llm.stream.return_value = [Mock()]
 
-    agent_search_config.primary_llm = mock_llm
-    agent_search_config.fast_llm = mock_llm
-    agent_search_config.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
-    agent_search_config.search_tool = mock_search_tool
-    agent_search_config.using_tool_calling_llm = False
-    agent_search_config.tools = [mock_search_tool]
-
     answer = Answer(
-        question=question,
         answer_style_config=answer_style_config,
-        prompt_config=prompt_config,
         llm=mock_llm,
-        single_message_history="history",
+        fast_llm=mock_llm,
         tools=[mock_search_tool],
-        force_use_tool=(
-            ForceUseTool(
-                tool_name=mock_search_tool.name,
-                args={"query": question},
-                force_use=True,
-            )
+        force_use_tool=ForceUseTool(
+            tool_name=mock_search_tool.name,
+            args={"query": question},
+            force_use=True,
         ),
         skip_explicit_tool_calling=True,
         skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
-        agent_search_config=agent_search_config,
+        search_request=SearchRequest(query=question),
+        prompt_builder=AnswerPromptBuilder(
+            user_message=HumanMessage(content=question),
+            message_history=[],
+            llm_config=mock_llm.config,
+            raw_user_query=question,
+            raw_user_uploaded_files=[],
+        ),
+        chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
+        current_agent_message_id=0,
     )
     results = list(answer.processed_streamed_output)
     for res in results:
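Both unit tests drive the pipeline through `Answer.processed_streamed_output`. For reference, a condensed sketch of how the streamed packets can be folded into a final answer string (assumes an `answer` built as in the tests above, and that `OnyxAnswerPiece` exposes its text as `answer_piece`):

from onyx.chat.models import OnyxAnswerPiece

# Sketch only: `answer` is an Answer instance constructed as in the tests.
answer_text = "".join(
    packet.answer_piece or ""
    for packet in answer.processed_streamed_output
    if isinstance(packet, OnyxAnswerPiece)
)
print(answer_text)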