Mirror of https://github.com/danswer-ai/danswer.git
Evaluate LLM Answers via Reflexion (#430)
backend/danswer/configs/app_configs.py
@@ -220,3 +220,8 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
 DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get(
     "DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER", ""
 ).lower() not in ["false", ""]
+# Add a second LLM call post Answer to verify if the Answer is valid
+# Throws out answers that don't directly or fully answer the user query
+ENABLE_DANSWERBOT_REFLEXION = (
+    os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
+)
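
Note that the new flag is strict: after lowercasing, only the exact string "true" enables reflexion, unlike the not in ["false", ""] check just above, where anything other than "false" or an empty value counts as enabled. A standalone sketch of the parse (not part of the commit):

import os

# Mirrors the parse added above: anything other than "true" (case-insensitive) leaves reflexion off.
os.environ["ENABLE_DANSWERBOT_REFLEXION"] = "True"
assert os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"  # enabled

os.environ["ENABLE_DANSWERBOT_REFLEXION"] = "1"
assert os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() != "true"  # still disabled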
@@ -2,6 +2,7 @@ from sqlalchemy.orm import Session
 
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
+from danswer.configs.app_configs import ENABLE_DANSWERBOT_REFLEXION
 from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import IGNORE_FOR_QA
@@ -18,6 +19,7 @@ from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.secondary_llm_flows.answer_validation import get_answer_validity
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
 from danswer.utils.logger import setup_logger
@@ -34,6 +36,7 @@ def answer_qa_query(
     disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
     answer_generation_timeout: int = QA_TIMEOUT,
     real_time_flow: bool = True,
+    enable_reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
 ) -> QAResponse:
     query = question.query
     filters = question.filters
@@ -123,14 +126,31 @@ def answer_qa_query(
 
     error_msg = None
     try:
-        answer, quotes = qa_model.answer_question(query, usable_chunks)
+        d_answer, quotes = qa_model.answer_question(query, usable_chunks)
     except Exception as e:
         # exception is logged in the answer_question method, no need to re-log
-        answer, quotes = None, None
+        d_answer, quotes = None, None
         error_msg = f"Error occurred in call to LLM - {e}"
 
+    if not real_time_flow and enable_reflexion and d_answer is not None:
+        valid = False
+        if d_answer.answer is not None:
+            valid = get_answer_validity(query, d_answer.answer)
+
+        if not valid:
+            return QAResponse(
+                answer=None,
+                quotes=None,
+                top_ranked_docs=chunks_to_search_docs(ranked_chunks),
+                lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
+                predicted_flow=predicted_flow,
+                predicted_search=predicted_search,
+                error_msg=error_msg,
+                query_event_id=query_event_id,
+            )
+
     return QAResponse(
-        answer=answer.answer if answer else None,
+        answer=d_answer.answer if d_answer else None,
         quotes=quotes.quotes if quotes else None,
         top_ranked_docs=chunks_to_search_docs(ranked_chunks),
         lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
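
In short: reflexion only runs when real_time_flow is False (the DanswerBot path, judging by the flag name) and costs one extra LLM call per answered question. When the validity check fails, the generated answer and quotes are dropped but the ranked documents, flow predictions, and any error message are still returned. The gate restated in isolation, with comments (build_no_answer_response is a hypothetical stand-in for the QAResponse construction shown above, not a function in the codebase):

if not real_time_flow and enable_reflexion and d_answer is not None:
    # An answer object with no answer text is treated as invalid by default.
    valid = False
    if d_answer.answer is not None:
        # Second LLM call: asks the model whether the answer actually addresses the query.
        valid = get_answer_validity(query, d_answer.answer)
    if not valid:
        # Hypothetical helper standing in for the QAResponse(answer=None, quotes=None, ...) above.
        return build_no_answer_response()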
backend/danswer/direct_qa/qa_block.py
@@ -114,8 +114,8 @@ class SingleMessageQAHandler(QAHandler):
             f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}"
             f"{GENERAL_SEP_PAT}Sample response:"
             f"{CODE_BLOCK_PAT.format(json.dumps(EMPTY_SAMPLE_JSON))}\n"
-            f"{QUESTION_PAT} {query}"
-            "\nHint: Make the answer as detailed as possible and use a JSON! "
+            f"{QUESTION_PAT} {query}\n"
+            "Hint: Make the answer as detailed as possible and use a JSON! "
             "Quotes MUST be EXACT substrings from provided documents!"
         )
     ]
@@ -127,7 +127,7 @@ class SingleMessageScratchpadHandler(QAHandler):
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
         cot_block = (
-            f"{THOUGHT_PAT} Let's think step by step. Use this section as a scratchpad.\n"
+            f"{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.\n\n"
             f"{json.dumps(EMPTY_SAMPLE_JSON)}"
         )
 
@@ -141,9 +141,11 @@ class SingleMessageScratchpadHandler(QAHandler):
             "You can process and comprehend vast amounts of text and utilize this knowledge "
             "to provide accurate and detailed answers to diverse queries.\n"
             f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}{GENERAL_SEP_PAT}"
-            f"You MUST use the following format:\n"
+            f"You MUST respond in the following format:"
             f"{CODE_BLOCK_PAT.format(cot_block)}\n"
-            f"Begin!\n{QUESTION_PAT} {query}"
+            f"{QUESTION_PAT} {query}\n"
+            "Hint: Make the answer as detailed as possible and use a JSON! "
+            "Quotes can ONLY be EXACT substrings from provided documents!"
         )
     ]
     return prompt
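
The scratchpad prompt asks the model to reason in free text first (the cot_block above) and then emit the JSON object, so downstream parsing only needs to recover the trailing JSON. The repo's actual parser is not part of this diff; the following is just an illustrative sketch of that idea:

import json

def extract_trailing_json(model_output: str) -> dict | None:
    # Illustration only: take everything from the first "{" to the last "}"
    # and try to parse it, since the JSON is instructed to follow the
    # free-text scratchpad reasoning.
    start = model_output.find("{")
    end = model_output.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        return json.loads(model_output[start : end + 1])
    except json.JSONDecodeError:
        return None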
backend/danswer/direct_qa/qa_prompts.py
@@ -17,6 +17,7 @@ FINAL_ANSWER_PAT = "Final Answer:"
 UNCERTAINTY_PAT = "?"
 QUOTE_PAT = "Quote:"
 QUOTES_PAT_PLURAL = "Quotes:"
+INVALID_PAT = "Invalid:"
 
 BASE_PROMPT = (
     "Answer the query based on provided documents and quote relevant sections. "
@@ -38,8 +39,8 @@ SAMPLE_JSON_RESPONSE = {
 EMPTY_SAMPLE_JSON = {
     "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
     "quotes": [
-        "each quote must be UNEDITED and EXACTLY as shown in the provided documents!",
-        "HINT the quotes are not shown to the user!",
+        "each quote must be UNEDITED and EXACTLY as shown in the context documents!",
+        "HINT, quotes are not shown to the user!",
     ],
 }
 
backend/danswer/secondary_llm_flows/answer_validation.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
+from danswer.direct_qa.qa_prompts import ANSWER_PAT
+from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
+from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
+from danswer.direct_qa.qa_prompts import INVALID_PAT
+from danswer.direct_qa.qa_prompts import QUESTION_PAT
+from danswer.direct_qa.qa_prompts import THOUGHT_PAT
+from danswer.llm.build import get_default_llm
+from danswer.utils.logger import setup_logger
+from danswer.utils.timing import log_function_time
+
+logger = setup_logger()
+
+
+def get_answer_validation_messages(query: str, answer: str) -> list[dict[str, str]]:
+    cot_block = (
+        f"{THOUGHT_PAT} Use this as a scratchpad to write out in a step by step manner your reasoning "
+        f"about EACH criterion to ensure that your conclusion is correct.\n"
+        f"{INVALID_PAT} True or False"
+    )
+
+    q_a_block = f"{QUESTION_PAT} {query}\n\n" f"{ANSWER_PAT} {answer}"
+
+    messages = [
+        {
+            "role": "user",
+            "content": (
+                f"{CODE_BLOCK_PAT.format(q_a_block).lstrip()}{GENERAL_SEP_PAT}\n"
+                "Determine if the answer is valid for the query.\n"
+                f"The answer is invalid if ANY of the following is true:\n"
+                "- Does not directly answer the user query.\n"
+                "- Answers a related but different question.\n"
+                '- Contains anything meaning "I don\'t know" or "information not found".\n\n'
+                f"You must use the following format:"
+                f"{CODE_BLOCK_PAT.format(cot_block)}"
+                f'Hint: Invalid must be exactly "True" or "False" (without the quotes)'
+            ),
+        },
+    ]
+
+    return messages
+
+
+def extract_validity(model_output: str) -> bool:
+    if INVALID_PAT in model_output:
+        result = model_output.split(INVALID_PAT)[-1].strip()
+        if "true" in result.lower():
+            return False
+    return True  # If something is wrong, let's not toss away the answer
+
+
+@log_function_time()
+def get_answer_validity(
+    query: str,
+    answer: str,
+) -> bool:
+    messages = get_answer_validation_messages(query, answer)
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    model_output = get_default_llm().invoke(filled_llm_prompt)
+    logger.debug(model_output)
+
+    validity = extract_validity(model_output)
+    logger.info(
+        f'LLM Answer of "{answer}" was determined to be {"valid" if validity else "invalid"}.'
+    )
+
+    return validity
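
A note on the parsing contract: the model's verdict line is "Invalid: True" or "Invalid: False", so True means the answer gets thrown out, and any output that lacks the pattern is treated as valid so that a malformed validation response never discards an otherwise fine answer. Illustrative calls (the strings are made up):

assert extract_validity("...reasoning about each criterion...\nInvalid: False") is True  # answer kept
assert extract_validity("...reasoning about each criterion...\nInvalid: True") is False  # answer discarded
assert extract_validity("no recognizable verdict at all") is True  # fail open

# End to end, get_answer_validity is the call made from answer_qa_query above
# (one extra LLM round trip per answer), e.g.:
# get_answer_validity("example user question", "example generated answer")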