fixed chat tests

This commit is contained in:
Evan Lohn
2025-01-24 11:42:40 -08:00
parent ddbfc65ad0
commit db2004542e
8 changed files with 77 additions and 48 deletions

View File

@@ -59,7 +59,6 @@ def process_llm_stream(
): ):
tool_call_chunk += response # type: ignore tool_call_chunk += response # type: ignore
elif should_stream_answer: elif should_stream_answer:
# TODO: handle emitting of CitationInfo
for response_part in answer_handler.handle_response_part(response, []): for response_part in answer_handler.handle_response_part(response, []):
dispatch_custom_event( dispatch_custom_event(
"basic_response", "basic_response",

View File

@@ -1,5 +1,6 @@
from typing import cast from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicOutput from onyx.agents.agent_search.basic.states import BasicOutput
@@ -53,17 +54,19 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO
initial_search_results = cast(list[LlmDoc], initial_search_results) initial_search_results = cast(list[LlmDoc], initial_search_results)
stream = llm.stream( new_tool_call_chunk = AIMessageChunk(content="")
prompt=new_prompt_builder.build(), if not agent_config.skip_gen_ai_answer_generation:
structured_response_format=structured_response_format, stream = llm.stream(
) prompt=new_prompt_builder.build(),
structured_response_format=structured_response_format,
)
# For now, we don't do multiple tool calls, so we ignore the tool_message # For now, we don't do multiple tool calls, so we ignore the tool_message
new_tool_call_chunk = process_llm_stream( new_tool_call_chunk = process_llm_stream(
stream, stream,
True, True,
final_search_results=final_search_results, final_search_results=final_search_results,
displayed_search_results=initial_search_results, displayed_search_results=initial_search_results,
) )
return BasicOutput(tool_call_chunk=new_tool_call_chunk) return BasicOutput(tool_call_chunk=new_tool_call_chunk)

View File

@@ -96,7 +96,9 @@ def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoic
structured_response_format=structured_response_format, structured_response_format=structured_response_format,
) )
tool_message = process_llm_stream(stream, should_stream_answer) tool_message = process_llm_stream(
stream, should_stream_answer and not agent_config.skip_gen_ai_answer_generation
)
# If no tool calls are emitted by the LLM, we should not choose a tool # If no tool calls are emitted by the LLM, we should not choose a tool
if len(tool_message.tool_calls) == 0: if len(tool_message.tool_calls) == 0:

View File

@@ -50,8 +50,6 @@ class Answer:
# but we only support them anyways # but we only support them anyways
# if set to True, then never use the LLMs provided tool-calling functonality # if set to True, then never use the LLMs provided tool-calling functonality
skip_explicit_tool_calling: bool = False, skip_explicit_tool_calling: bool = False,
# Returns the full document sections text from the search tool
return_contexts: bool = False,
skip_gen_ai_answer_generation: bool = False, skip_gen_ai_answer_generation: bool = False,
is_connected: Callable[[], bool] | None = None, is_connected: Callable[[], bool] | None = None,
fast_llm: LLM | None = None, fast_llm: LLM | None = None,
@@ -89,7 +87,6 @@ class Answer:
self._streamed_output: list[str] | None = None self._streamed_output: list[str] | None = None
self._processed_stream: (list[AnswerPacket] | None) = None self._processed_stream: (list[AnswerPacket] | None) = None
self._return_contexts = return_contexts
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False self._is_cancelled = False

View File

@@ -766,6 +766,8 @@ def stream_chat_message_objects(
raw_user_uploaded_files=latest_query_files or [], raw_user_uploaded_files=latest_query_files or [],
single_message_history=single_message_history, single_message_history=single_message_history,
) )
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
agent_search_config = AgentSearchConfig( agent_search_config = AgentSearchConfig(
search_request=search_request, search_request=search_request,
primary_llm=llm, primary_llm=llm,

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage from langchain_core.messages import SystemMessage
from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.models import AgentSearchConfig
@@ -15,6 +14,8 @@ from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.constants import DocumentSource from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection from onyx.context.search.models import InferenceSection
@@ -28,6 +29,7 @@ from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import ( from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID, FINAL_CONTEXT_DOCUMENTS_ID,
) )
from onyx.tools.utils import explicit_tool_calling_supported
QUERY = "Test question" QUERY = "Test question"
DEFAULT_SEARCH_ARGS = {"query": "search"} DEFAULT_SEARCH_ARGS = {"query": "search"}
@@ -40,26 +42,38 @@ def answer_style_config() -> AnswerStyleConfig:
@pytest.fixture @pytest.fixture
def agent_search_config( def agent_search_config(
mock_llm: LLM, mock_search_tool: SearchTool mock_llm: LLM, mock_search_tool: SearchTool, prompt_config: PromptConfig
) -> AgentSearchConfig: ) -> AgentSearchConfig:
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=QUERY,
prompt_config=prompt_config,
files=[],
single_message_history=None,
),
message_history=[],
llm_config=mock_llm.config,
raw_user_query=QUERY,
raw_user_uploaded_files=[],
single_message_history=None,
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
using_tool_calling_llm = explicit_tool_calling_supported(
mock_llm.config.model_provider, mock_llm.config.model_name
)
return AgentSearchConfig( return AgentSearchConfig(
search_request=SearchRequest(query=QUERY), search_request=SearchRequest(query=QUERY),
primary_llm=mock_llm, primary_llm=mock_llm,
fast_llm=mock_llm, fast_llm=mock_llm,
search_tool=mock_search_tool, search_tool=mock_search_tool,
force_use_tool=ForceUseTool(force_use=False, tool_name=""), force_use_tool=ForceUseTool(force_use=False, tool_name=""),
prompt_builder=AnswerPromptBuilder( prompt_builder=prompt_builder,
user_message=HumanMessage(content=QUERY),
message_history=[],
llm_config=mock_llm.config,
raw_user_query=QUERY,
raw_user_uploaded_files=[],
),
chat_session_id=None, chat_session_id=None,
message_id=1, message_id=1,
use_persistence=True, use_persistence=True,
db_session=None, db_session=None,
use_agentic_search=False, use_agentic_search=False,
using_tool_calling_llm=using_tool_calling_llm,
) )

View File

@@ -52,7 +52,7 @@ def answer_instance(
def test_basic_answer(answer_instance: Answer) -> None: def test_basic_answer(answer_instance: Answer) -> None:
mock_llm = cast(Mock, answer_instance.llm) mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
mock_llm.stream.return_value = [ mock_llm.stream.return_value = [
AIMessageChunk(content="This is a "), AIMessageChunk(content="This is a "),
AIMessageChunk(content="mock answer."), AIMessageChunk(content="mock answer."),
@@ -103,15 +103,16 @@ def test_basic_answer(answer_instance: Answer) -> None:
def test_answer_with_search_call( def test_answer_with_search_call(
answer_instance: Answer, answer_instance: Answer,
mock_search_results: list[LlmDoc], mock_search_results: list[LlmDoc],
mock_contexts: OnyxContexts,
mock_search_tool: MagicMock, mock_search_tool: MagicMock,
force_use_tool: ForceUseTool, force_use_tool: ForceUseTool,
expected_tool_args: dict, expected_tool_args: dict,
) -> None: ) -> None:
answer_instance.tools = [mock_search_tool] answer_instance.agent_search_config.tools = [mock_search_tool]
answer_instance.force_use_tool = force_use_tool answer_instance.agent_search_config.force_use_tool = force_use_tool
# Set up the LLM mock to return search results and then an answer # Set up the LLM mock to return search results and then an answer
mock_llm = cast(Mock, answer_instance.llm) mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
stream_side_effect: list[list[BaseMessage]] = [] stream_side_effect: list[list[BaseMessage]] = []
@@ -143,30 +144,40 @@ def test_answer_with_search_call(
) )
mock_llm.stream.side_effect = stream_side_effect mock_llm.stream.side_effect = stream_side_effect
print("side effect")
for v in stream_side_effect:
print(v)
print("-" * 300)
print(len(stream_side_effect))
print("-" * 300)
# Process the output # Process the output
output = list(answer_instance.processed_streamed_output) output = list(answer_instance.processed_streamed_output)
# Updated assertions # Updated assertions
assert len(output) == 7 # assert len(output) == 7
assert output[0] == ToolCallKickoff( assert output[0] == ToolCallKickoff(
tool_name="search", tool_args=expected_tool_args tool_name="search", tool_args=expected_tool_args
) )
assert output[1] == ToolResponse( assert output[1] == ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=mock_contexts,
)
assert output[2] == ToolResponse(
id="final_context_documents", id="final_context_documents",
response=mock_search_results, response=mock_search_results,
) )
assert output[2] == ToolCallFinalResult( assert output[3] == ToolCallFinalResult(
tool_name="search", tool_name="search",
tool_args=expected_tool_args, tool_args=expected_tool_args,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results], tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
) )
assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ") assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
expected_citation = CitationInfo(citation_num=1, document_id="doc1") expected_citation = CitationInfo(citation_num=1, document_id="doc1")
assert output[4] == expected_citation assert output[5] == expected_citation
assert output[5] == OnyxAnswerPiece( assert output[6] == OnyxAnswerPiece(
answer_piece="the answer is abc[[1]](https://example.com/doc1). " answer_piece="the answer is abc[[1]](https://example.com/doc1). "
) )
assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.") assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
expected_answer = ( expected_answer = (
"Based on the search results, " "Based on the search results, "

View File

@@ -1,16 +1,13 @@
from typing import Any from typing import Any
from typing import cast
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.answer import Answer from onyx.chat.answer import Answer
from onyx.chat.answer import AnswerStream
from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import PromptConfig from onyx.chat.models import PromptConfig
from onyx.context.search.models import SearchRequest
from onyx.tools.force import ForceUseTool from onyx.tools.force import ForceUseTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_implementations.search.search_tool import SearchTool
from tests.regression.answer_quality.run_qa import _process_and_write_query_results from tests.regression.answer_quality.run_qa import _process_and_write_query_results
@@ -33,6 +30,7 @@ def test_skip_gen_ai_answer_generation_flag(
config: dict[str, Any], config: dict[str, Any],
mock_search_tool: SearchTool, mock_search_tool: SearchTool,
answer_style_config: AnswerStyleConfig, answer_style_config: AnswerStyleConfig,
agent_search_config: AgentSearchConfig,
prompt_config: PromptConfig, prompt_config: PromptConfig,
) -> None: ) -> None:
question = config["question"] question = config["question"]
@@ -44,10 +42,12 @@ def test_skip_gen_ai_answer_generation_flag(
mock_llm.stream = Mock() mock_llm.stream = Mock()
mock_llm.stream.return_value = [Mock()] mock_llm.stream.return_value = [Mock()]
session = Mock() agent_search_config.primary_llm = mock_llm
agent_search_config, _ = get_test_config( agent_search_config.fast_llm = mock_llm
session, mock_llm, mock_llm, SearchRequest(query=question) agent_search_config.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
) agent_search_config.search_tool = mock_search_tool
agent_search_config.using_tool_calling_llm = False
agent_search_config.tools = [mock_search_tool]
answer = Answer( answer = Answer(
question=question, question=question,
@@ -64,14 +64,15 @@ def test_skip_gen_ai_answer_generation_flag(
) )
), ),
skip_explicit_tool_calling=True, skip_explicit_tool_calling=True,
return_contexts=True,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation, skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
agent_search_config=agent_search_config, agent_search_config=agent_search_config,
) )
count = 0 results = list(answer.processed_streamed_output)
for _ in cast(AnswerStream, answer.processed_streamed_output): for res in results:
count += 1 print(res)
assert count == 3 if skip_gen_ai_answer_generation else 4
expected_count = 4 if skip_gen_ai_answer_generation else 5
assert len(results) == expected_count
if not skip_gen_ai_answer_generation: if not skip_gen_ai_answer_generation:
mock_llm.stream.assert_called_once() mock_llm.stream.assert_called_once()
else: else: