disabled llm when skip_gen_ai_answer_question set (#2687)

* disabled llm when skip_gen_ai_answer_question set

* added unit test

* typing
This commit is contained in:
evan-danswer 2024-10-06 14:10:02 -04:00 committed by GitHub
parent 0da736bed9
commit 089c734f63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 165 additions and 27 deletions

View File

@ -311,13 +311,13 @@ class Answer:
)
)
yield tool_runner.tool_final_result()
if not self.skip_gen_ai_answer_generation:
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
yield from self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
yield from self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
return
@ -413,6 +413,10 @@ class Answer:
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
if self.skip_gen_ai_answer_generation:
raise ValueError(
"skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated"
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
@ -477,10 +481,10 @@ class Answer:
final = tool_runner.tool_final_result()
yield final
if not self.skip_gen_ai_answer_generation:
prompt = prompt_builder.build()
prompt = prompt_builder.build()
yield from self._process_llm_stream(prompt=prompt, tools=None)
yield from self._process_llm_stream(prompt=prompt, tools=None)
@property
def processed_streamed_output(self) -> AnswerStream:

View File

@ -77,14 +77,15 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
"number_of_questions_in_dataset": len(questions),
}
env_vars = get_docker_container_env_vars(config["env_name"])
if env_vars["ENV_SEED_CONFIGURATION"]:
del env_vars["ENV_SEED_CONFIGURATION"]
if env_vars["GPG_KEY"]:
del env_vars["GPG_KEY"]
if metadata["test_config"]["llm"]["api_key"]:
del metadata["test_config"]["llm"]["api_key"]
metadata.update(env_vars)
if config["env_name"]:
env_vars = get_docker_container_env_vars(config["env_name"])
if env_vars["ENV_SEED_CONFIGURATION"]:
del env_vars["ENV_SEED_CONFIGURATION"]
if env_vars["GPG_KEY"]:
del env_vars["GPG_KEY"]
if metadata["test_config"]["llm"]["api_key"]:
del metadata["test_config"]["llm"]["api_key"]
metadata.update(env_vars)
metadata_path = os.path.join(test_output_folder, METADATA_FILENAME)
print("saving metadata to:", metadata_path)
with open(metadata_path, "w", encoding="utf-8") as yaml_file:
@ -95,17 +96,18 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
)
shutil.copy2(questions_file_path, copied_questions_file_path)
zipped_files_path = config["zipped_documents_file"]
copied_zipped_documents_path = os.path.join(
test_output_folder, os.path.basename(zipped_files_path)
)
shutil.copy2(zipped_files_path, copied_zipped_documents_path)
if config["zipped_documents_file"]:
zipped_files_path = config["zipped_documents_file"]
copied_zipped_documents_path = os.path.join(
test_output_folder, os.path.basename(zipped_files_path)
)
shutil.copy2(zipped_files_path, copied_zipped_documents_path)
zipped_files_folder = os.path.dirname(zipped_files_path)
jsonl_file_path = os.path.join(zipped_files_folder, "target_docs.jsonl")
if os.path.exists(jsonl_file_path):
copied_jsonl_path = os.path.join(test_output_folder, "target_docs.jsonl")
shutil.copy2(jsonl_file_path, copied_jsonl_path)
zipped_files_folder = os.path.dirname(zipped_files_path)
jsonl_file_path = os.path.join(zipped_files_folder, "target_docs.jsonl")
if os.path.exists(jsonl_file_path):
copied_jsonl_path = os.path.join(test_output_folder, "target_docs.jsonl")
shutil.copy2(jsonl_file_path, copied_jsonl_path)
return test_output_folder, questions

View File

@ -0,0 +1,132 @@
from typing import Any
from typing import cast
from unittest.mock import Mock
import pytest
from pytest_mock import MockerFixture
from danswer.llm.answering.answer import Answer
from danswer.one_shot_answer.answer_question import AnswerObjectIterator
from danswer.tools.force import ForceUseTool
from tests.regression.answer_quality.run_qa import _process_and_write_query_results
@pytest.mark.parametrize(
"config",
[
{
"skip_gen_ai_answer_generation": True,
"question": "What is the capital of the moon?",
},
{
"skip_gen_ai_answer_generation": False,
"question": "What is the capital of the moon but twice?",
},
],
)
def test_skip_gen_ai_answer_generation_flag(config: dict[str, Any]) -> None:
search_tool = Mock()
search_tool.name = "search"
search_tool.run = Mock()
search_tool.run.return_value = [Mock()]
mock_llm = Mock()
mock_llm.config = Mock()
mock_llm.config.model_name = "gpt-4o-mini"
mock_llm.stream = Mock()
mock_llm.stream.return_value = [Mock()]
answer = Answer(
question=config["question"],
answer_style_config=Mock(),
prompt_config=Mock(),
llm=mock_llm,
single_message_history="history",
tools=[search_tool],
force_use_tool=(
ForceUseTool(
tool_name=search_tool.name,
args={"query": config["question"]},
force_use=True,
)
),
skip_explicit_tool_calling=True,
return_contexts=True,
skip_gen_ai_answer_generation=config["skip_gen_ai_answer_generation"],
)
count = 0
for _ in cast(AnswerObjectIterator, answer.processed_streamed_output):
count += 1
assert count == 2
if not config["skip_gen_ai_answer_generation"]:
mock_llm.stream.assert_called_once()
else:
mock_llm.stream.assert_not_called()
##### From here down is the client side test that was not working #####
class FinishedTestException(Exception):
pass
# could not get this to work, it seems like the mock is not being used
# tests that the main run_qa function passes the skip_gen_ai_answer_generation flag to the Answer object
@pytest.mark.parametrize(
"config, questions",
[
(
{
"skip_gen_ai_answer_generation": True,
"output_folder": "./test_output_folder",
"zipped_documents_file": "./test_docs.jsonl",
"questions_file": "./test_questions.jsonl",
"commit_sha": None,
"launch_web_ui": False,
"only_retrieve_docs": True,
"use_cloud_gpu": False,
"model_server_ip": "PUT_PUBLIC_CLOUD_IP_HERE",
"model_server_port": "PUT_PUBLIC_CLOUD_PORT_HERE",
"environment_name": "",
"env_name": "",
"limit": None,
},
[{"uid": "1", "question": "What is the capital of the moon?"}],
),
(
{
"skip_gen_ai_answer_generation": False,
"output_folder": "./test_output_folder",
"zipped_documents_file": "./test_docs.jsonl",
"questions_file": "./test_questions.jsonl",
"commit_sha": None,
"launch_web_ui": False,
"only_retrieve_docs": True,
"use_cloud_gpu": False,
"model_server_ip": "PUT_PUBLIC_CLOUD_IP_HERE",
"model_server_port": "PUT_PUBLIC_CLOUD_PORT_HERE",
"environment_name": "",
"env_name": "",
"limit": None,
},
[{"uid": "1", "question": "What is the capital of the moon but twice?"}],
),
],
)
@pytest.mark.skip(reason="not working")
def test_run_qa_skip_gen_ai(
config: dict[str, Any], questions: list[dict[str, Any]], mocker: MockerFixture
) -> None:
mocker.patch(
"tests.regression.answer_quality.run_qa._initialize_files",
return_value=("test", questions),
)
def arg_checker(question_data: dict, config: dict, question_number: int) -> None:
assert question_data == questions[0]
raise FinishedTestException()
mocker.patch(
"tests.regression.answer_quality.run_qa._process_question", arg_checker
)
with pytest.raises(FinishedTestException):
_process_and_write_query_results(config)