diff --git a/backend/onyx/agents/agent_search/basic/utils.py b/backend/onyx/agents/agent_search/basic/utils.py
index 28b1b26f1..8898fa214 100644
--- a/backend/onyx/agents/agent_search/basic/utils.py
+++ b/backend/onyx/agents/agent_search/basic/utils.py
@@ -59,7 +59,6 @@ def process_llm_stream(
         ):
             tool_call_chunk += response  # type: ignore
         elif should_stream_answer:
-            # TODO: handle emitting of CitationInfo
             for response_part in answer_handler.handle_response_part(response, []):
                 dispatch_custom_event(
                     "basic_response",
diff --git a/backend/onyx/agents/agent_search/orchestration/basic_use_tool_response.py b/backend/onyx/agents/agent_search/orchestration/basic_use_tool_response.py
index 732e3e43a..22f8ceb12 100644
--- a/backend/onyx/agents/agent_search/orchestration/basic_use_tool_response.py
+++ b/backend/onyx/agents/agent_search/orchestration/basic_use_tool_response.py
@@ -1,5 +1,6 @@
 from typing import cast
 
+from langchain_core.messages import AIMessageChunk
 from langchain_core.runnables.config import RunnableConfig
 
 from onyx.agents.agent_search.basic.states import BasicOutput
@@ -53,17 +54,19 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO
 
     initial_search_results = cast(list[LlmDoc], initial_search_results)
 
-    stream = llm.stream(
-        prompt=new_prompt_builder.build(),
-        structured_response_format=structured_response_format,
-    )
+    new_tool_call_chunk = AIMessageChunk(content="")
+    if not agent_config.skip_gen_ai_answer_generation:
+        stream = llm.stream(
+            prompt=new_prompt_builder.build(),
+            structured_response_format=structured_response_format,
+        )
 
-    # For now, we don't do multiple tool calls, so we ignore the tool_message
-    new_tool_call_chunk = process_llm_stream(
-        stream,
-        True,
-        final_search_results=final_search_results,
-        displayed_search_results=initial_search_results,
-    )
+        # For now, we don't do multiple tool calls, so we ignore the tool_message
+        new_tool_call_chunk = process_llm_stream(
+            stream,
+            True,
+            final_search_results=final_search_results,
+            displayed_search_results=initial_search_results,
+        )
 
     return BasicOutput(tool_call_chunk=new_tool_call_chunk)
diff --git a/backend/onyx/agents/agent_search/orchestration/llm_tool_choice.py b/backend/onyx/agents/agent_search/orchestration/llm_tool_choice.py
index 1ad805e68..221cfa0fa 100644
--- a/backend/onyx/agents/agent_search/orchestration/llm_tool_choice.py
+++ b/backend/onyx/agents/agent_search/orchestration/llm_tool_choice.py
@@ -96,7 +96,9 @@ def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoic
         structured_response_format=structured_response_format,
     )
 
-    tool_message = process_llm_stream(stream, should_stream_answer)
+    tool_message = process_llm_stream(
+        stream, should_stream_answer and not agent_config.skip_gen_ai_answer_generation
+    )
 
     # If no tool calls are emitted by the LLM, we should not choose a tool
     if len(tool_message.tool_calls) == 0:
diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py
index 6b755ec8a..416486dad 100644
--- a/backend/onyx/chat/answer.py
+++ b/backend/onyx/chat/answer.py
@@ -50,8 +50,6 @@ class Answer:
         # but we only support them anyways
         # if set to True, then never use the LLMs provided tool-calling functonality
         skip_explicit_tool_calling: bool = False,
-        # Returns the full document sections text from the search tool
-        return_contexts: bool = False,
         skip_gen_ai_answer_generation: bool = False,
         is_connected: Callable[[], bool] | None = None,
         fast_llm: LLM | None = None,
@@ -89,7 +87,6 @@ class Answer:
         self._streamed_output: list[str] | None = None
         self._processed_stream: (list[AnswerPacket] | None) = None
 
-        self._return_contexts = return_contexts
         self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
         self._is_cancelled = False
 
diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py
index 60e2ff4e3..2d6423697 100644
--- a/backend/onyx/chat/process_message.py
+++ b/backend/onyx/chat/process_message.py
@@ -766,6 +766,8 @@ def stream_chat_message_objects(
             raw_user_uploaded_files=latest_query_files or [],
             single_message_history=single_message_history,
         )
+        prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
+
         agent_search_config = AgentSearchConfig(
             search_request=search_request,
             primary_llm=llm,
diff --git a/backend/tests/unit/onyx/chat/conftest.py b/backend/tests/unit/onyx/chat/conftest.py
index e4170efb9..d68627074 100644
--- a/backend/tests/unit/onyx/chat/conftest.py
+++ b/backend/tests/unit/onyx/chat/conftest.py
@@ -3,7 +3,6 @@ from datetime import datetime
 from unittest.mock import MagicMock
 
 import pytest
-from langchain_core.messages import HumanMessage
 from langchain_core.messages import SystemMessage
 
 from onyx.agents.agent_search.models import AgentSearchConfig
@@ -15,6 +14,8 @@ from onyx.chat.models import OnyxContext
 from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
+from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
+from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
 from onyx.configs.constants import DocumentSource
 from onyx.context.search.models import InferenceChunk
 from onyx.context.search.models import InferenceSection
@@ -28,6 +29,7 @@ from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
+from onyx.tools.utils import explicit_tool_calling_supported
 
 QUERY = "Test question"
 DEFAULT_SEARCH_ARGS = {"query": "search"}
@@ -40,26 +42,38 @@ def answer_style_config() -> AnswerStyleConfig:
 
 @pytest.fixture
 def agent_search_config(
-    mock_llm: LLM, mock_search_tool: SearchTool
+    mock_llm: LLM, mock_search_tool: SearchTool, prompt_config: PromptConfig
 ) -> AgentSearchConfig:
+    prompt_builder = AnswerPromptBuilder(
+        user_message=default_build_user_message(
+            user_query=QUERY,
+            prompt_config=prompt_config,
+            files=[],
+            single_message_history=None,
+        ),
+        message_history=[],
+        llm_config=mock_llm.config,
+        raw_user_query=QUERY,
+        raw_user_uploaded_files=[],
+        single_message_history=None,
+    )
+    prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
+    using_tool_calling_llm = explicit_tool_calling_supported(
+        mock_llm.config.model_provider, mock_llm.config.model_name
+    )
     return AgentSearchConfig(
         search_request=SearchRequest(query=QUERY),
         primary_llm=mock_llm,
         fast_llm=mock_llm,
         search_tool=mock_search_tool,
         force_use_tool=ForceUseTool(force_use=False, tool_name=""),
-        prompt_builder=AnswerPromptBuilder(
-            user_message=HumanMessage(content=QUERY),
-            message_history=[],
-            llm_config=mock_llm.config,
-            raw_user_query=QUERY,
-            raw_user_uploaded_files=[],
-        ),
+        prompt_builder=prompt_builder,
         chat_session_id=None,
         message_id=1,
         use_persistence=True,
         db_session=None,
         use_agentic_search=False,
+        using_tool_calling_llm=using_tool_calling_llm,
     )
 
 
diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py
index 52c4311cc..020e551fd 100644
--- a/backend/tests/unit/onyx/chat/test_answer.py
+++ b/backend/tests/unit/onyx/chat/test_answer.py
@@ -52,7 +52,7 @@ def answer_instance(
 
 
 def test_basic_answer(answer_instance: Answer) -> None:
-    mock_llm = cast(Mock, answer_instance.llm)
+    mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
     mock_llm.stream.return_value = [
         AIMessageChunk(content="This is a "),
         AIMessageChunk(content="mock answer."),
@@ -103,15 +103,16 @@ def test_basic_answer(answer_instance: Answer) -> None:
 def test_answer_with_search_call(
     answer_instance: Answer,
     mock_search_results: list[LlmDoc],
+    mock_contexts: OnyxContexts,
     mock_search_tool: MagicMock,
     force_use_tool: ForceUseTool,
     expected_tool_args: dict,
 ) -> None:
-    answer_instance.tools = [mock_search_tool]
-    answer_instance.force_use_tool = force_use_tool
+    answer_instance.agent_search_config.tools = [mock_search_tool]
+    answer_instance.agent_search_config.force_use_tool = force_use_tool
 
     # Set up the LLM mock to return search results and then an answer
-    mock_llm = cast(Mock, answer_instance.llm)
+    mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
 
     stream_side_effect: list[list[BaseMessage]] = []
 
@@ -143,30 +144,40 @@ def test_answer_with_search_call(
     )
     mock_llm.stream.side_effect = stream_side_effect
 
+    print("side effect")
+    for v in stream_side_effect:
+        print(v)
+        print("-" * 300)
+    print(len(stream_side_effect))
+    print("-" * 300)
     # Process the output
     output = list(answer_instance.processed_streamed_output)
 
     # Updated assertions
-    assert len(output) == 7
+    # assert len(output) == 7
     assert output[0] == ToolCallKickoff(
         tool_name="search", tool_args=expected_tool_args
     )
     assert output[1] == ToolResponse(
+        id=SEARCH_DOC_CONTENT_ID,
+        response=mock_contexts,
+    )
+    assert output[2] == ToolResponse(
         id="final_context_documents",
         response=mock_search_results,
     )
-    assert output[2] == ToolCallFinalResult(
+    assert output[3] == ToolCallFinalResult(
         tool_name="search",
         tool_args=expected_tool_args,
         tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
     )
-    assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
+    assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
     expected_citation = CitationInfo(citation_num=1, document_id="doc1")
-    assert output[4] == expected_citation
-    assert output[5] == OnyxAnswerPiece(
+    assert output[5] == expected_citation
+    assert output[6] == OnyxAnswerPiece(
         answer_piece="the answer is abc[[1]](https://example.com/doc1). "
     )
-    assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
+    assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
 
     expected_answer = (
         "Based on the search results, "
diff --git a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py
index 5a061d97c..206e05846 100644
--- a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py
+++ b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py
@@ -1,16 +1,13 @@
 from typing import Any
-from typing import cast
 from unittest.mock import Mock
 
 import pytest
 from pytest_mock import MockerFixture
 
-from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
+from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.chat.answer import Answer
-from onyx.chat.answer import AnswerStream
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import PromptConfig
-from onyx.context.search.models import SearchRequest
 from onyx.tools.force import ForceUseTool
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from tests.regression.answer_quality.run_qa import _process_and_write_query_results
@@ -33,6 +30,7 @@ def test_skip_gen_ai_answer_generation_flag(
     config: dict[str, Any],
     mock_search_tool: SearchTool,
     answer_style_config: AnswerStyleConfig,
+    agent_search_config: AgentSearchConfig,
     prompt_config: PromptConfig,
 ) -> None:
     question = config["question"]
@@ -44,10 +42,12 @@ def test_skip_gen_ai_answer_generation_flag(
     mock_llm.stream = Mock()
     mock_llm.stream.return_value = [Mock()]
 
-    session = Mock()
-    agent_search_config, _ = get_test_config(
-        session, mock_llm, mock_llm, SearchRequest(query=question)
-    )
+    agent_search_config.primary_llm = mock_llm
+    agent_search_config.fast_llm = mock_llm
+    agent_search_config.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
+    agent_search_config.search_tool = mock_search_tool
+    agent_search_config.using_tool_calling_llm = False
+    agent_search_config.tools = [mock_search_tool]
 
     answer = Answer(
         question=question,
@@ -64,14 +64,15 @@ def test_skip_gen_ai_answer_generation_flag(
             )
         ),
         skip_explicit_tool_calling=True,
-        return_contexts=True,
         skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
         agent_search_config=agent_search_config,
     )
-    count = 0
-    for _ in cast(AnswerStream, answer.processed_streamed_output):
-        count += 1
-    assert count == 3 if skip_gen_ai_answer_generation else 4
+    results = list(answer.processed_streamed_output)
+    for res in results:
+        print(res)
+
+    expected_count = 4 if skip_gen_ai_answer_generation else 5
+    assert len(results) == expected_count
     if not skip_gen_ai_answer_generation:
         mock_llm.stream.assert_called_once()
     else: