From 089c734f63143b65543e9c26b31e4adf6741aff1 Mon Sep 17 00:00:00 2001
From: evan-danswer
Date: Sun, 6 Oct 2024 14:10:02 -0400
Subject: [PATCH] disabled llm when skip_gen_ai_answer_question set (#2687)

* disabled llm when skip_gen_ai_answer_question set

* added unit test

* typing
---
 backend/danswer/llm/answering/answer.py       |  22 +--
 .../tests/regression/answer_quality/run_qa.py |  38 ++---
 .../danswer/llm/answering/test_skip_gen_ai.py | 132 ++++++++++++++++++
 3 files changed, 165 insertions(+), 27 deletions(-)
 create mode 100644 backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py

diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py
index 922d757d3..12c1bc25f 100644
--- a/backend/danswer/llm/answering/answer.py
+++ b/backend/danswer/llm/answering/answer.py
@@ -311,13 +311,13 @@ class Answer:
                 )
             )
             yield tool_runner.tool_final_result()
+            if not self.skip_gen_ai_answer_generation:
+                prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
 
-            prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
-
-            yield from self._process_llm_stream(
-                prompt=prompt,
-                tools=[tool.tool_definition() for tool in self.tools],
-            )
+                yield from self._process_llm_stream(
+                    prompt=prompt,
+                    tools=[tool.tool_definition() for tool in self.tools],
+                )
 
             return
 
@@ -413,6 +413,10 @@ class Answer:
             logger.notice(f"Chosen tool: {chosen_tool_and_args}")
 
         if not chosen_tool_and_args:
+            if self.skip_gen_ai_answer_generation:
+                raise ValueError(
+                    "skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated"
+                )
             prompt_builder.update_system_prompt(
                 default_build_system_message(self.prompt_config)
             )
@@ -477,10 +481,10 @@ class Answer:
             final = tool_runner.tool_final_result()
 
             yield final
+        if not self.skip_gen_ai_answer_generation:
+            prompt = prompt_builder.build()
 
-        prompt = prompt_builder.build()
-
-        yield from self._process_llm_stream(prompt=prompt, tools=None)
+            yield from self._process_llm_stream(prompt=prompt, tools=None)
 
     @property
     def processed_streamed_output(self) -> AnswerStream:
diff --git a/backend/tests/regression/answer_quality/run_qa.py b/backend/tests/regression/answer_quality/run_qa.py
index 5de034b37..f6dd0e0b5 100644
--- a/backend/tests/regression/answer_quality/run_qa.py
+++ b/backend/tests/regression/answer_quality/run_qa.py
@@ -77,14 +77,15 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
         "number_of_questions_in_dataset": len(questions),
     }
 
-    env_vars = get_docker_container_env_vars(config["env_name"])
-    if env_vars["ENV_SEED_CONFIGURATION"]:
-        del env_vars["ENV_SEED_CONFIGURATION"]
-    if env_vars["GPG_KEY"]:
-        del env_vars["GPG_KEY"]
-    if metadata["test_config"]["llm"]["api_key"]:
-        del metadata["test_config"]["llm"]["api_key"]
-    metadata.update(env_vars)
+    if config["env_name"]:
+        env_vars = get_docker_container_env_vars(config["env_name"])
+        if env_vars["ENV_SEED_CONFIGURATION"]:
+            del env_vars["ENV_SEED_CONFIGURATION"]
+        if env_vars["GPG_KEY"]:
+            del env_vars["GPG_KEY"]
+        if metadata["test_config"]["llm"]["api_key"]:
+            del metadata["test_config"]["llm"]["api_key"]
+        metadata.update(env_vars)
     metadata_path = os.path.join(test_output_folder, METADATA_FILENAME)
     print("saving metadata to:", metadata_path)
     with open(metadata_path, "w", encoding="utf-8") as yaml_file:
@@ -95,17 +96,18 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
     )
     shutil.copy2(questions_file_path, copied_questions_file_path)
 
-    zipped_files_path = config["zipped_documents_file"]
-    copied_zipped_documents_path = os.path.join(
-        test_output_folder, os.path.basename(zipped_files_path)
-    )
-    shutil.copy2(zipped_files_path, copied_zipped_documents_path)
+    if config["zipped_documents_file"]:
+        zipped_files_path = config["zipped_documents_file"]
+        copied_zipped_documents_path = os.path.join(
+            test_output_folder, os.path.basename(zipped_files_path)
+        )
+        shutil.copy2(zipped_files_path, copied_zipped_documents_path)
 
-    zipped_files_folder = os.path.dirname(zipped_files_path)
-    jsonl_file_path = os.path.join(zipped_files_folder, "target_docs.jsonl")
-    if os.path.exists(jsonl_file_path):
-        copied_jsonl_path = os.path.join(test_output_folder, "target_docs.jsonl")
-        shutil.copy2(jsonl_file_path, copied_jsonl_path)
+        zipped_files_folder = os.path.dirname(zipped_files_path)
+        jsonl_file_path = os.path.join(zipped_files_folder, "target_docs.jsonl")
+        if os.path.exists(jsonl_file_path):
+            copied_jsonl_path = os.path.join(test_output_folder, "target_docs.jsonl")
+            shutil.copy2(jsonl_file_path, copied_jsonl_path)
 
     return test_output_folder, questions
 
diff --git a/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py b/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py
new file mode 100644
index 000000000..998b2932c
--- /dev/null
+++ b/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py
@@ -0,0 +1,132 @@
+from typing import Any
+from typing import cast
+from unittest.mock import Mock
+
+import pytest
+from pytest_mock import MockerFixture
+
+from danswer.llm.answering.answer import Answer
+from danswer.one_shot_answer.answer_question import AnswerObjectIterator
+from danswer.tools.force import ForceUseTool
+from tests.regression.answer_quality.run_qa import _process_and_write_query_results
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        {
+            "skip_gen_ai_answer_generation": True,
+            "question": "What is the capital of the moon?",
+        },
+        {
+            "skip_gen_ai_answer_generation": False,
+            "question": "What is the capital of the moon but twice?",
+        },
+    ],
+)
+def test_skip_gen_ai_answer_generation_flag(config: dict[str, Any]) -> None:
+    search_tool = Mock()
+    search_tool.name = "search"
+    search_tool.run = Mock()
+    search_tool.run.return_value = [Mock()]
+    mock_llm = Mock()
+    mock_llm.config = Mock()
+    mock_llm.config.model_name = "gpt-4o-mini"
+    mock_llm.stream = Mock()
+    mock_llm.stream.return_value = [Mock()]
+    answer = Answer(
+        question=config["question"],
+        answer_style_config=Mock(),
+        prompt_config=Mock(),
+        llm=mock_llm,
+        single_message_history="history",
+        tools=[search_tool],
+        force_use_tool=(
+            ForceUseTool(
+                tool_name=search_tool.name,
+                args={"query": config["question"]},
+                force_use=True,
+            )
+        ),
+        skip_explicit_tool_calling=True,
+        return_contexts=True,
+        skip_gen_ai_answer_generation=config["skip_gen_ai_answer_generation"],
+    )
+    count = 0
+    for _ in cast(AnswerObjectIterator, answer.processed_streamed_output):
+        count += 1
+    assert count == 2
+    if not config["skip_gen_ai_answer_generation"]:
+        mock_llm.stream.assert_called_once()
+    else:
+        mock_llm.stream.assert_not_called()
+
+
+##### From here down is the client side test that was not working #####
+
+
+class FinishedTestException(Exception):
+    pass
+
+
+# could not get this to work, it seems like the mock is not being used
+# tests that the main run_qa function passes the skip_gen_ai_answer_generation flag to the Answer object
+@pytest.mark.parametrize(
+    "config, questions",
+    [
+        (
+            {
+                "skip_gen_ai_answer_generation": True,
+                "output_folder": "./test_output_folder",
+                "zipped_documents_file": "./test_docs.jsonl",
"questions_file": "./test_questions.jsonl", + "commit_sha": None, + "launch_web_ui": False, + "only_retrieve_docs": True, + "use_cloud_gpu": False, + "model_server_ip": "PUT_PUBLIC_CLOUD_IP_HERE", + "model_server_port": "PUT_PUBLIC_CLOUD_PORT_HERE", + "environment_name": "", + "env_name": "", + "limit": None, + }, + [{"uid": "1", "question": "What is the capital of the moon?"}], + ), + ( + { + "skip_gen_ai_answer_generation": False, + "output_folder": "./test_output_folder", + "zipped_documents_file": "./test_docs.jsonl", + "questions_file": "./test_questions.jsonl", + "commit_sha": None, + "launch_web_ui": False, + "only_retrieve_docs": True, + "use_cloud_gpu": False, + "model_server_ip": "PUT_PUBLIC_CLOUD_IP_HERE", + "model_server_port": "PUT_PUBLIC_CLOUD_PORT_HERE", + "environment_name": "", + "env_name": "", + "limit": None, + }, + [{"uid": "1", "question": "What is the capital of the moon but twice?"}], + ), + ], +) +@pytest.mark.skip(reason="not working") +def test_run_qa_skip_gen_ai( + config: dict[str, Any], questions: list[dict[str, Any]], mocker: MockerFixture +) -> None: + mocker.patch( + "tests.regression.answer_quality.run_qa._initialize_files", + return_value=("test", questions), + ) + + def arg_checker(question_data: dict, config: dict, question_number: int) -> None: + assert question_data == questions[0] + raise FinishedTestException() + + mocker.patch( + "tests.regression.answer_quality.run_qa._process_question", arg_checker + ) + with pytest.raises(FinishedTestException): + _process_and_write_query_results(config)