From 9316b78f47a5be5774e0e520c128960f0de7fee4 Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Mon, 11 Sep 2023 14:45:13 -0700
Subject: [PATCH] Evaluate LLM Answers via Reflexion (#430)

---
 backend/danswer/configs/app_configs.py       |  5 ++
 backend/danswer/direct_qa/answer_question.py | 26 ++++++-
 backend/danswer/direct_qa/qa_block.py        | 12 ++--
 backend/danswer/direct_qa/qa_prompts.py      |  5 +-
 .../secondary_llm_flows/answer_validation.py | 67 +++++++++++++++++++
 5 files changed, 105 insertions(+), 10 deletions(-)
 create mode 100644 backend/danswer/secondary_llm_flows/answer_validation.py

diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py
index f31d64650a67..e52d29d6881c 100644
--- a/backend/danswer/configs/app_configs.py
+++ b/backend/danswer/configs/app_configs.py
@@ -220,3 +220,8 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
 DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get(
     "DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER", ""
 ).lower() not in ["false", ""]
+# Add a second LLM call post Answer to verify if the Answer is valid
+# Throws out answers that don't directly or fully answer the user query
+ENABLE_DANSWERBOT_REFLEXION = (
+    os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
+)
diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py
index 69466195241d..63e450c8b6aa 100644
--- a/backend/danswer/direct_qa/answer_question.py
+++ b/backend/danswer/direct_qa/answer_question.py
@@ -2,6 +2,7 @@ from sqlalchemy.orm import Session
 
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
+from danswer.configs.app_configs import ENABLE_DANSWERBOT_REFLEXION
 from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import IGNORE_FOR_QA
@@ -18,6 +19,7 @@ from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.secondary_llm_flows.answer_validation import get_answer_validity
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
 from danswer.utils.logger import setup_logger
@@ -34,6 +36,7 @@ def answer_qa_query(
     disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
     answer_generation_timeout: int = QA_TIMEOUT,
     real_time_flow: bool = True,
+    enable_reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
 ) -> QAResponse:
     query = question.query
     filters = question.filters
@@ -123,14 +126,31 @@ def answer_qa_query(
 
     error_msg = None
     try:
-        answer, quotes = qa_model.answer_question(query, usable_chunks)
+        d_answer, quotes = qa_model.answer_question(query, usable_chunks)
     except Exception as e:
         # exception is logged in the answer_question method, no need to re-log
-        answer, quotes = None, None
+        d_answer, quotes = None, None
         error_msg = f"Error occurred in call to LLM - {e}"
 
+    if not real_time_flow and enable_reflexion and d_answer is not None:
+        valid = False
+        if d_answer.answer is not None:
+            valid = get_answer_validity(query, d_answer.answer)
+
+        if not valid:
+            return QAResponse(
+                answer=None,
+                quotes=None,
+                top_ranked_docs=chunks_to_search_docs(ranked_chunks),
+                lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
+                predicted_flow=predicted_flow,
+                predicted_search=predicted_search,
+                error_msg=error_msg,
+                query_event_id=query_event_id,
+            )
+
     return QAResponse(
-        answer=answer.answer if answer else None,
+        answer=d_answer.answer if d_answer else None,
         quotes=quotes.quotes if quotes else None,
         top_ranked_docs=chunks_to_search_docs(ranked_chunks),
         lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py
index 9e474e596e85..46f07e176629 100644
--- a/backend/danswer/direct_qa/qa_block.py
+++ b/backend/danswer/direct_qa/qa_block.py
@@ -114,8 +114,8 @@ class SingleMessageQAHandler(QAHandler):
                 f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}"
                 f"{GENERAL_SEP_PAT}Sample response:"
                 f"{CODE_BLOCK_PAT.format(json.dumps(EMPTY_SAMPLE_JSON))}\n"
-                f"{QUESTION_PAT} {query}"
-                "\nHint: Make the answer as detailed as possible and use a JSON! "
+                f"{QUESTION_PAT} {query}\n"
+                "Hint: Make the answer as detailed as possible and use a JSON! "
                 "Quotes MUST be EXACT substrings from provided documents!"
             )
         ]
@@ -127,7 +127,7 @@ class SingleMessageScratchpadHandler(QAHandler):
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
         cot_block = (
-            f"{THOUGHT_PAT} Let's think step by step. Use this section as a scratchpad.\n"
+            f"{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.\n\n"
             f"{json.dumps(EMPTY_SAMPLE_JSON)}"
         )
 
@@ -141,9 +141,11 @@ class SingleMessageScratchpadHandler(QAHandler):
                 "You can process and comprehend vast amounts of text and utilize this knowledge "
                 "to provide accurate and detailed answers to diverse queries.\n"
                 f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}{GENERAL_SEP_PAT}"
-                f"You MUST use the following format:\n"
+                f"You MUST respond in the following format:"
                 f"{CODE_BLOCK_PAT.format(cot_block)}\n"
-                f"Begin!\n{QUESTION_PAT} {query}"
+                f"{QUESTION_PAT} {query}\n"
+                "Hint: Make the answer as detailed as possible and use a JSON! "
+                "Quotes can ONLY be EXACT substrings from provided documents!"
             )
         ]
         return prompt
diff --git a/backend/danswer/direct_qa/qa_prompts.py b/backend/danswer/direct_qa/qa_prompts.py
index 0e998dbc01b8..5eb6813fbdbf 100644
--- a/backend/danswer/direct_qa/qa_prompts.py
+++ b/backend/danswer/direct_qa/qa_prompts.py
@@ -17,6 +17,7 @@ FINAL_ANSWER_PAT = "Final Answer:"
 UNCERTAINTY_PAT = "?"
 QUOTE_PAT = "Quote:"
 QUOTES_PAT_PLURAL = "Quotes:"
+INVALID_PAT = "Invalid:"
 
 BASE_PROMPT = (
     "Answer the query based on provided documents and quote relevant sections. "
@@ -38,8 +39,8 @@ SAMPLE_JSON_RESPONSE = {
 EMPTY_SAMPLE_JSON = {
     "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
     "quotes": [
-        "each quote must be UNEDITED and EXACTLY as shown in the provided documents!",
-        "HINT the quotes are not shown to the user!",
+        "each quote must be UNEDITED and EXACTLY as shown in the context documents!",
+        "HINT, quotes are not shown to the user!",
     ],
 }
 
diff --git a/backend/danswer/secondary_llm_flows/answer_validation.py b/backend/danswer/secondary_llm_flows/answer_validation.py
new file mode 100644
index 000000000000..6fe74aa4a26d
--- /dev/null
+++ b/backend/danswer/secondary_llm_flows/answer_validation.py
@@ -0,0 +1,67 @@
+from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
+from danswer.direct_qa.qa_prompts import ANSWER_PAT
+from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
+from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
+from danswer.direct_qa.qa_prompts import INVALID_PAT
+from danswer.direct_qa.qa_prompts import QUESTION_PAT
+from danswer.direct_qa.qa_prompts import THOUGHT_PAT
+from danswer.llm.build import get_default_llm
+from danswer.utils.logger import setup_logger
+from danswer.utils.timing import log_function_time
+
+logger = setup_logger()
+
+
+def get_answer_validation_messages(query: str, answer: str) -> list[dict[str, str]]:
+    cot_block = (
+        f"{THOUGHT_PAT} Use this as a scratchpad to write out in a step by step manner your reasoning "
+        f"about EACH criterion to ensure that your conclusion is correct.\n"
+        f"{INVALID_PAT} True or False"
+    )
+
+    q_a_block = f"{QUESTION_PAT} {query}\n\n" f"{ANSWER_PAT} {answer}"
+
+    messages = [
+        {
+            "role": "user",
+            "content": (
+                f"{CODE_BLOCK_PAT.format(q_a_block).lstrip()}{GENERAL_SEP_PAT}\n"
+                "Determine if the answer is valid for the query.\n"
+                f"The answer is invalid if ANY of the following is true:\n"
+                "- Does not directly answer the user query.\n"
+                "- Answers a related but different question.\n"
+                '- Contains anything meaning "I don\'t know" or "information not found".\n\n'
+                f"You must use the following format:"
+                f"{CODE_BLOCK_PAT.format(cot_block)}"
+                f'Hint: Invalid must be exactly "True" or "False" (without the quotes)'
+            ),
+        },
+    ]
+
+    return messages
+
+
+def extract_validity(model_output: str) -> bool:
+    if INVALID_PAT in model_output:
+        result = model_output.split(INVALID_PAT)[-1].strip()
+        if "true" in result.lower():
+            return False
+    return True  # If something is wrong, let's not toss away the answer
+
+
+@log_function_time()
+def get_answer_validity(
+    query: str,
+    answer: str,
+) -> bool:
+    messages = get_answer_validation_messages(query, answer)
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    model_output = get_default_llm().invoke(filled_llm_prompt)
+    logger.debug(model_output)
+
+    validity = extract_validity(model_output)
+    logger.info(
+        f'LLM Answer of "{answer}" was determined to be {"valid" if validity else "invalid"}.'
+    )
+
+    return validity
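Reviewer note (not part of the patch): below is a minimal, standalone sketch of the parsing step that the new extract_validity() performs on the reflexion LLM's output. It inlines the INVALID_PAT constant ("Invalid:") so it runs without the danswer package installed; the sample model output string is hypothetical and only meant to illustrate the expected "Thought: ... / Invalid: True|False" shape the validation prompt asks for.

    # Illustrative sketch only; mirrors extract_validity from answer_validation.py
    # with INVALID_PAT inlined. The sample output below is hypothetical.
    INVALID_PAT = "Invalid:"  # marker the validation prompt asks the LLM to emit


    def extract_validity(model_output: str) -> bool:
        # The prompt asks the model to end its scratchpad with "Invalid: True" or
        # "Invalid: False"; whatever follows the last marker decides the verdict.
        if INVALID_PAT in model_output:
            result = model_output.split(INVALID_PAT)[-1].strip()
            if "true" in result.lower():
                return False  # model judged the answer invalid -> answer is dropped
        return True  # missing or malformed marker: keep the answer rather than toss it


    if __name__ == "__main__":
        sample_output = (
            "Thought: The answer directly addresses the user's refund question.\n"
            "Invalid: False"
        )
        print(extract_validity(sample_output))  # -> True (answer is kept)

Defaulting to valid when the marker is absent matches the comment in the patch: a parsing hiccup in the secondary LLM call should not throw away an otherwise usable answer.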