Evaluate LLM Answers via Reflexion (#430)

Yuhong Sun
2023-09-11 14:45:13 -07:00
committed by GitHub
parent ddfa8cf8a6
commit 9316b78f47
5 changed files with 105 additions and 10 deletions

View File

@@ -220,3 +220,8 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get(
"DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER", ""
).lower() not in ["false", ""]
+# Add a second LLM call after the Answer is generated to verify that the Answer is valid
+# Throws out answers that don't directly or fully answer the user query
+ENABLE_DANSWERBOT_REFLEXION = (
+    os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
+)
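Note (not part of the commit): the flag's parsing is strict. Only the literal string "true", case-insensitive, enables reflexion; leaving the variable unset, empty, or set to anything else keeps it off. A minimal standalone sketch of the same check:

import os

def reflexion_enabled() -> bool:
    # Mirrors the config line above: only "true" (any casing) enables reflexion
    return os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"

os.environ["ENABLE_DANSWERBOT_REFLEXION"] = "True"
assert reflexion_enabled()
os.environ["ENABLE_DANSWERBOT_REFLEXION"] = "1"
assert not reflexion_enabled()  # any value other than "true" is treated as off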

View File

@@ -2,6 +2,7 @@ from sqlalchemy.orm import Session
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
+from danswer.configs.app_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import IGNORE_FOR_QA
@@ -18,6 +19,7 @@ from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.utils.logger import setup_logger
@@ -34,6 +36,7 @@ def answer_qa_query(
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
answer_generation_timeout: int = QA_TIMEOUT,
real_time_flow: bool = True,
+enable_reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
) -> QAResponse:
query = question.query
filters = question.filters
@@ -123,14 +126,31 @@ def answer_qa_query(
error_msg = None
try:
-answer, quotes = qa_model.answer_question(query, usable_chunks)
+d_answer, quotes = qa_model.answer_question(query, usable_chunks)
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
-answer, quotes = None, None
+d_answer, quotes = None, None
error_msg = f"Error occurred in call to LLM - {e}"
+    if not real_time_flow and enable_reflexion and d_answer is not None:
+        valid = False
+        if d_answer.answer is not None:
+            valid = get_answer_validity(query, d_answer.answer)
+        if not valid:
+            return QAResponse(
+                answer=None,
+                quotes=None,
+                top_ranked_docs=chunks_to_search_docs(ranked_chunks),
+                lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
+                predicted_flow=predicted_flow,
+                predicted_search=predicted_search,
+                error_msg=error_msg,
+                query_event_id=query_event_id,
+            )
return QAResponse(
-answer=answer.answer if answer else None,
+answer=d_answer.answer if d_answer else None,
quotes=quotes.quotes if quotes else None,
top_ranked_docs=chunks_to_search_docs(ranked_chunks),
lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
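Taken together, the change above only kicks in for non-real-time (DanswerBot) requests: the generated answer is passed through a second validity check, and if it fails, the response keeps the retrieved documents but drops the answer and quotes. A condensed sketch of that gating logic, with illustrative names rather than the exact Danswer signatures:

def apply_reflexion(query, d_answer, quotes, real_time_flow, enable_reflexion):
    # Sketch only: mirrors the control flow added in the hunk above
    if real_time_flow or not enable_reflexion or d_answer is None:
        return d_answer, quotes  # reflexion is skipped entirely
    if d_answer.answer is None or not get_answer_validity(query, d_answer.answer):
        return None, None  # invalid answer is discarded; retrieved docs are still returned
    return d_answer, quotes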

View File

@@ -114,8 +114,8 @@ class SingleMessageQAHandler(QAHandler):
f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}"
f"{GENERAL_SEP_PAT}Sample response:"
f"{CODE_BLOCK_PAT.format(json.dumps(EMPTY_SAMPLE_JSON))}\n"
f"{QUESTION_PAT} {query}"
"\nHint: Make the answer as detailed as possible and use a JSON! "
f"{QUESTION_PAT} {query}\n"
"Hint: Make the answer as detailed as possible and use a JSON! "
"Quotes MUST be EXACT substrings from provided documents!"
)
]
@@ -127,7 +127,7 @@ class SingleMessageScratchpadHandler(QAHandler):
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
cot_block = (
f"{THOUGHT_PAT} Let's think step by step. Use this section as a scratchpad.\n"
f"{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.\n\n"
f"{json.dumps(EMPTY_SAMPLE_JSON)}"
)
@@ -141,9 +141,11 @@ class SingleMessageScratchpadHandler(QAHandler):
"You can process and comprehend vast amounts of text and utilize this knowledge "
"to provide accurate and detailed answers to diverse queries.\n"
f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}{GENERAL_SEP_PAT}"
f"You MUST use the following format:\n"
f"You MUST respond in the following format:"
f"{CODE_BLOCK_PAT.format(cot_block)}\n"
f"Begin!\n{QUESTION_PAT} {query}"
f"{QUESTION_PAT} {query}\n"
"Hint: Make the answer as detailed as possible and use a JSON! "
"Quotes can ONLY be EXACT substrings from provided documents!"
)
]
return prompt
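To make the new scratchpad format concrete, here is a rough rendering of the prompt tail after this change. The constant values below are assumptions for illustration only; the real THOUGHT_PAT, QUESTION_PAT, and CODE_BLOCK_PAT live in qa_prompts.py:

import json

# Assumed stand-ins for the real qa_prompts constants (illustrative only)
THOUGHT_PAT = "Thought:"
QUESTION_PAT = "Query:"
CODE_BLOCK_PAT = "\n```\n{}\n```\n"
EMPTY_SAMPLE_JSON = {"answer": "...", "quotes": ["..."]}

cot_block = (
    f"{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.\n\n"
    f"{json.dumps(EMPTY_SAMPLE_JSON)}"
)
prompt_tail = (
    "You MUST respond in the following format:"
    f"{CODE_BLOCK_PAT.format(cot_block)}\n"
    f"{QUESTION_PAT} How do I enable reflexion?\n"
    "Hint: Make the answer as detailed as possible and use a JSON! "
    "Quotes can ONLY be EXACT substrings from provided documents!"
)
print(prompt_tail)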

View File

@@ -17,6 +17,7 @@ FINAL_ANSWER_PAT = "Final Answer:"
UNCERTAINTY_PAT = "?"
QUOTE_PAT = "Quote:"
QUOTES_PAT_PLURAL = "Quotes:"
INVALID_PAT = "Invalid:"
BASE_PROMPT = (
"Answer the query based on provided documents and quote relevant sections. "
@@ -38,8 +39,8 @@ SAMPLE_JSON_RESPONSE = {
EMPTY_SAMPLE_JSON = {
"answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
"quotes": [
"each quote must be UNEDITED and EXACTLY as shown in the provided documents!",
"HINT the quotes are not shown to the user!",
"each quote must be UNEDITED and EXACTLY as shown in the context documents!",
"HINT, quotes are not shown to the user!",
],
}
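The hint about exact quotes is mechanical enough to check in code; a minimal sketch (the helper name is hypothetical, not part of the commit) of verifying that each quote is an unedited substring of a context document:

def quotes_are_exact(quotes: list[str], context_docs: list[str]) -> bool:
    # A quote passes only if it appears verbatim in at least one context document
    return all(any(quote in doc for doc in context_docs) for quote in quotes)

docs = ["Reflexion adds a second LLM call to validate generated answers."]
assert quotes_are_exact(["a second LLM call"], docs)
assert not quotes_are_exact(["a 2nd LLM call"], docs)  # edited quote is rejected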

View File

@@ -0,0 +1,67 @@
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
from danswer.direct_qa.qa_prompts import INVALID_PAT
from danswer.direct_qa.qa_prompts import QUESTION_PAT
from danswer.direct_qa.qa_prompts import THOUGHT_PAT
from danswer.llm.build import get_default_llm
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time

logger = setup_logger()

def get_answer_validation_messages(query: str, answer: str) -> list[dict[str, str]]:
    cot_block = (
        f"{THOUGHT_PAT} Use this as a scratchpad to write out in a step by step manner your reasoning "
        f"about EACH criterion to ensure that your conclusion is correct.\n"
        f"{INVALID_PAT} True or False"
    )

    q_a_block = f"{QUESTION_PAT} {query}\n\n" f"{ANSWER_PAT} {answer}"

    messages = [
        {
            "role": "user",
            "content": (
                f"{CODE_BLOCK_PAT.format(q_a_block).lstrip()}{GENERAL_SEP_PAT}\n"
                "Determine if the answer is valid for the query.\n"
                f"The answer is invalid if ANY of the following is true:\n"
                "- Does not directly answer the user query.\n"
                "- Answers a related but different question.\n"
                '- Contains anything meaning "I don\'t know" or "information not found".\n\n'
                f"You must use the following format:"
                f"{CODE_BLOCK_PAT.format(cot_block)}"
                f'Hint: Invalid must be exactly "True" or "False" (without the quotes)'
            ),
        },
    ]

    return messages

def extract_validity(model_output: str) -> bool:
    if INVALID_PAT in model_output:
        result = model_output.split(INVALID_PAT)[-1].strip()
        if "true" in result.lower():
            return False
    return True  # If something is wrong, let's not toss away the answer

@log_function_time()
def get_answer_validity(
    query: str,
    answer: str,
) -> bool:
    messages = get_answer_validation_messages(query, answer)
    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
    model_output = get_default_llm().invoke(filled_llm_prompt)
    logger.debug(model_output)

    validity = extract_validity(model_output)
    logger.info(
        f'LLM Answer of "{answer}" was determined to be {"valid" if validity else "invalid"}.'
    )

    return validity
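Downstream usage is a single get_answer_validity call per answer. The verdict parsing is deliberately lenient: it looks at the text after the Invalid: label and only discards the answer when that text says True, so a malformed model output keeps the answer. A small sketch of the expected behavior (the example strings are illustrative):

assert extract_validity("Thought: reasoning goes here\nInvalid: True") is False   # answer rejected
assert extract_validity("Thought: reasoning goes here\nInvalid: False") is True   # answer kept
assert extract_validity("output with no verdict") is True                         # fail open, keep answer

# End-to-end call (requires a configured LLM); the query/answer pair is illustrative:
# is_valid = get_answer_validity(
#     "How do I enable reflexion?",
#     "Set ENABLE_DANSWERBOT_REFLEXION=true in the environment.",
# )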