mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-30 09:40:50 +02:00
disabled llm when skip_gen_ai_answer_question set (#2687)
* disabled llm when skip_gen_ai_answer_question set * added unit test * typing
This commit is contained in:
@ -311,7 +311,7 @@ class Answer:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
yield tool_runner.tool_final_result()
|
yield tool_runner.tool_final_result()
|
||||||
|
if not self.skip_gen_ai_answer_generation:
|
||||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||||
|
|
||||||
yield from self._process_llm_stream(
|
yield from self._process_llm_stream(
|
||||||
@ -413,6 +413,10 @@ class Answer:
|
|||||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||||
|
|
||||||
if not chosen_tool_and_args:
|
if not chosen_tool_and_args:
|
||||||
|
if self.skip_gen_ai_answer_generation:
|
||||||
|
raise ValueError(
|
||||||
|
"skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated"
|
||||||
|
)
|
||||||
prompt_builder.update_system_prompt(
|
prompt_builder.update_system_prompt(
|
||||||
default_build_system_message(self.prompt_config)
|
default_build_system_message(self.prompt_config)
|
||||||
)
|
)
|
||||||
@ -477,7 +481,7 @@ class Answer:
|
|||||||
final = tool_runner.tool_final_result()
|
final = tool_runner.tool_final_result()
|
||||||
|
|
||||||
yield final
|
yield final
|
||||||
|
if not self.skip_gen_ai_answer_generation:
|
||||||
prompt = prompt_builder.build()
|
prompt = prompt_builder.build()
|
||||||
|
|
||||||
yield from self._process_llm_stream(prompt=prompt, tools=None)
|
yield from self._process_llm_stream(prompt=prompt, tools=None)
|
||||||
|
@ -77,6 +77,7 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
|
|||||||
"number_of_questions_in_dataset": len(questions),
|
"number_of_questions_in_dataset": len(questions),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config["env_name"]:
|
||||||
env_vars = get_docker_container_env_vars(config["env_name"])
|
env_vars = get_docker_container_env_vars(config["env_name"])
|
||||||
if env_vars["ENV_SEED_CONFIGURATION"]:
|
if env_vars["ENV_SEED_CONFIGURATION"]:
|
||||||
del env_vars["ENV_SEED_CONFIGURATION"]
|
del env_vars["ENV_SEED_CONFIGURATION"]
|
||||||
@ -95,6 +96,7 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
|
|||||||
)
|
)
|
||||||
shutil.copy2(questions_file_path, copied_questions_file_path)
|
shutil.copy2(questions_file_path, copied_questions_file_path)
|
||||||
|
|
||||||
|
if config["zipped_documents_file"]:
|
||||||
zipped_files_path = config["zipped_documents_file"]
|
zipped_files_path = config["zipped_documents_file"]
|
||||||
copied_zipped_documents_path = os.path.join(
|
copied_zipped_documents_path = os.path.join(
|
||||||
test_output_folder, os.path.basename(zipped_files_path)
|
test_output_folder, os.path.basename(zipped_files_path)
|
||||||
|
132
backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py
Normal file
132
backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
from typing import Any
|
||||||
|
from typing import cast
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from danswer.llm.answering.answer import Answer
|
||||||
|
from danswer.one_shot_answer.answer_question import AnswerObjectIterator
|
||||||
|
from danswer.tools.force import ForceUseTool
|
||||||
|
from tests.regression.answer_quality.run_qa import _process_and_write_query_results
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"config",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"skip_gen_ai_answer_generation": True,
|
||||||
|
"question": "What is the capital of the moon?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"skip_gen_ai_answer_generation": False,
|
||||||
|
"question": "What is the capital of the moon but twice?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_skip_gen_ai_answer_generation_flag(config: dict[str, Any]) -> None:
|
||||||
|
search_tool = Mock()
|
||||||
|
search_tool.name = "search"
|
||||||
|
search_tool.run = Mock()
|
||||||
|
search_tool.run.return_value = [Mock()]
|
||||||
|
mock_llm = Mock()
|
||||||
|
mock_llm.config = Mock()
|
||||||
|
mock_llm.config.model_name = "gpt-4o-mini"
|
||||||
|
mock_llm.stream = Mock()
|
||||||
|
mock_llm.stream.return_value = [Mock()]
|
||||||
|
answer = Answer(
|
||||||
|
question=config["question"],
|
||||||
|
answer_style_config=Mock(),
|
||||||
|
prompt_config=Mock(),
|
||||||
|
llm=mock_llm,
|
||||||
|
single_message_history="history",
|
||||||
|
tools=[search_tool],
|
||||||
|
force_use_tool=(
|
||||||
|
ForceUseTool(
|
||||||
|
tool_name=search_tool.name,
|
||||||
|
args={"query": config["question"]},
|
||||||
|
force_use=True,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
skip_explicit_tool_calling=True,
|
||||||
|
return_contexts=True,
|
||||||
|
skip_gen_ai_answer_generation=config["skip_gen_ai_answer_generation"],
|
||||||
|
)
|
||||||
|
count = 0
|
||||||
|
for _ in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||||
|
count += 1
|
||||||
|
assert count == 2
|
||||||
|
if not config["skip_gen_ai_answer_generation"]:
|
||||||
|
mock_llm.stream.assert_called_once()
|
||||||
|
else:
|
||||||
|
mock_llm.stream.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
##### From here down is the client side test that was not working #####
|
||||||
|
|
||||||
|
|
||||||
|
class FinishedTestException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# could not get this to work, it seems like the mock is not being used
|
||||||
|
# tests that the main run_qa function passes the skip_gen_ai_answer_generation flag to the Answer object
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"config, questions",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"skip_gen_ai_answer_generation": True,
|
||||||
|
"output_folder": "./test_output_folder",
|
||||||
|
"zipped_documents_file": "./test_docs.jsonl",
|
||||||
|
"questions_file": "./test_questions.jsonl",
|
||||||
|
"commit_sha": None,
|
||||||
|
"launch_web_ui": False,
|
||||||
|
"only_retrieve_docs": True,
|
||||||
|
"use_cloud_gpu": False,
|
||||||
|
"model_server_ip": "PUT_PUBLIC_CLOUD_IP_HERE",
|
||||||
|
"model_server_port": "PUT_PUBLIC_CLOUD_PORT_HERE",
|
||||||
|
"environment_name": "",
|
||||||
|
"env_name": "",
|
||||||
|
"limit": None,
|
||||||
|
},
|
||||||
|
[{"uid": "1", "question": "What is the capital of the moon?"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"skip_gen_ai_answer_generation": False,
|
||||||
|
"output_folder": "./test_output_folder",
|
||||||
|
"zipped_documents_file": "./test_docs.jsonl",
|
||||||
|
"questions_file": "./test_questions.jsonl",
|
||||||
|
"commit_sha": None,
|
||||||
|
"launch_web_ui": False,
|
||||||
|
"only_retrieve_docs": True,
|
||||||
|
"use_cloud_gpu": False,
|
||||||
|
"model_server_ip": "PUT_PUBLIC_CLOUD_IP_HERE",
|
||||||
|
"model_server_port": "PUT_PUBLIC_CLOUD_PORT_HERE",
|
||||||
|
"environment_name": "",
|
||||||
|
"env_name": "",
|
||||||
|
"limit": None,
|
||||||
|
},
|
||||||
|
[{"uid": "1", "question": "What is the capital of the moon but twice?"}],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.skip(reason="not working")
|
||||||
|
def test_run_qa_skip_gen_ai(
|
||||||
|
config: dict[str, Any], questions: list[dict[str, Any]], mocker: MockerFixture
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"tests.regression.answer_quality.run_qa._initialize_files",
|
||||||
|
return_value=("test", questions),
|
||||||
|
)
|
||||||
|
|
||||||
|
def arg_checker(question_data: dict, config: dict, question_number: int) -> None:
|
||||||
|
assert question_data == questions[0]
|
||||||
|
raise FinishedTestException()
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"tests.regression.answer_quality.run_qa._process_question", arg_checker
|
||||||
|
)
|
||||||
|
with pytest.raises(FinishedTestException):
|
||||||
|
_process_and_write_query_results(config)
|
Reference in New Issue
Block a user