Evaluate LLM Answers via Reflexion (#430)
backend/danswer/configs/app_configs.py
@@ -220,3 +220,8 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
 DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get(
     "DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER", ""
 ).lower() not in ["false", ""]
+# Add a second LLM call post Answer to verify if the Answer is valid
+# Throws out answers that don't directly or fully answer the user query
+ENABLE_DANSWERBOT_REFLEXION = (
+    os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
+)
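
Note on the flag semantics: unlike DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER above, which treats anything other than "false" or unset as enabled, reflexion is opt-in and only turns on when the env var is exactly "true" (case-insensitive). A minimal sketch of the parsing, with illustrative values:

    import os

    # Illustrative values only: reflexion is enabled solely by "true" (any casing)
    for value in ["true", "TRUE", "1", "yes", ""]:
        os.environ["ENABLE_DANSWERBOT_REFLEXION"] = value
        enabled = os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
        print(f"{value!r:>8} -> {enabled}")  # True only for "true"/"TRUE"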
backend/danswer/direct_qa/answer_question.py
@@ -2,6 +2,7 @@ from sqlalchemy.orm import Session
 
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
+from danswer.configs.app_configs import ENABLE_DANSWERBOT_REFLEXION
 from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import IGNORE_FOR_QA
@@ -18,6 +19,7 @@ from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.secondary_llm_flows.answer_validation import get_answer_validity
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
 from danswer.utils.logger import setup_logger
@@ -34,6 +36,7 @@ def answer_qa_query(
     disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
     answer_generation_timeout: int = QA_TIMEOUT,
     real_time_flow: bool = True,
+    enable_reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
 ) -> QAResponse:
     query = question.query
     filters = question.filters
@@ -123,14 +126,31 @@ def answer_qa_query(
 
     error_msg = None
     try:
-        answer, quotes = qa_model.answer_question(query, usable_chunks)
+        d_answer, quotes = qa_model.answer_question(query, usable_chunks)
     except Exception as e:
         # exception is logged in the answer_question method, no need to re-log
-        answer, quotes = None, None
+        d_answer, quotes = None, None
         error_msg = f"Error occurred in call to LLM - {e}"
 
+    if not real_time_flow and enable_reflexion and d_answer is not None:
+        valid = False
+        if d_answer.answer is not None:
+            valid = get_answer_validity(query, d_answer.answer)
+
+        if not valid:
+            return QAResponse(
+                answer=None,
+                quotes=None,
+                top_ranked_docs=chunks_to_search_docs(ranked_chunks),
+                lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
+                predicted_flow=predicted_flow,
+                predicted_search=predicted_search,
+                error_msg=error_msg,
+                query_event_id=query_event_id,
+            )
+
     return QAResponse(
-        answer=answer.answer if answer else None,
+        answer=d_answer.answer if d_answer else None,
         quotes=quotes.quotes if quotes else None,
         top_ranked_docs=chunks_to_search_docs(ranked_chunks),
         lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
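
The gate above runs reflexion only outside the real-time flow (i.e., for DanswerBot-style asynchronous answers); when the check fails, the answer and quotes are suppressed but the retrieved documents are still returned. A condensed sketch of just that decision, using the names from the diff (should_suppress_answer is a hypothetical helper, not part of the commit):

    from danswer.secondary_llm_flows.answer_validation import get_answer_validity

    def should_suppress_answer(query, d_answer, real_time_flow, enable_reflexion):
        """True if reflexion decides the generated answer should be thrown out."""
        if real_time_flow or not enable_reflexion or d_answer is None:
            return False  # reflexion is skipped entirely for real-time queries
        if d_answer.answer is None:
            return True  # no answer text: nothing to validate, treat as invalid
        return not get_answer_validity(query, d_answer.answer)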
backend/danswer/direct_qa/qa_block.py
@@ -114,8 +114,8 @@ class SingleMessageQAHandler(QAHandler):
                 f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}"
                 f"{GENERAL_SEP_PAT}Sample response:"
                 f"{CODE_BLOCK_PAT.format(json.dumps(EMPTY_SAMPLE_JSON))}\n"
-                f"{QUESTION_PAT} {query}"
-                "\nHint: Make the answer as detailed as possible and use a JSON! "
+                f"{QUESTION_PAT} {query}\n"
+                "Hint: Make the answer as detailed as possible and use a JSON! "
                 "Quotes MUST be EXACT substrings from provided documents!"
             )
         ]
@@ -127,7 +127,7 @@ class SingleMessageScratchpadHandler(QAHandler):
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
         cot_block = (
-            f"{THOUGHT_PAT} Let's think step by step. Use this section as a scratchpad.\n"
+            f"{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.\n\n"
             f"{json.dumps(EMPTY_SAMPLE_JSON)}"
         )
 
@@ -141,9 +141,11 @@ class SingleMessageScratchpadHandler(QAHandler):
                 "You can process and comprehend vast amounts of text and utilize this knowledge "
                 "to provide accurate and detailed answers to diverse queries.\n"
                 f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}{GENERAL_SEP_PAT}"
-                f"You MUST use the following format:\n"
+                f"You MUST respond in the following format:"
                 f"{CODE_BLOCK_PAT.format(cot_block)}\n"
-                f"Begin!\n{QUESTION_PAT} {query}"
+                f"{QUESTION_PAT} {query}\n"
+                "Hint: Make the answer as detailed as possible and use a JSON! "
+                "Quotes can ONLY be EXACT substrings from provided documents!"
             )
         ]
         return prompt
backend/danswer/direct_qa/qa_prompts.py
@@ -17,6 +17,7 @@ FINAL_ANSWER_PAT = "Final Answer:"
 UNCERTAINTY_PAT = "?"
 QUOTE_PAT = "Quote:"
 QUOTES_PAT_PLURAL = "Quotes:"
+INVALID_PAT = "Invalid:"
 
 BASE_PROMPT = (
     "Answer the query based on provided documents and quote relevant sections. "
@@ -38,8 +39,8 @@ SAMPLE_JSON_RESPONSE = {
 EMPTY_SAMPLE_JSON = {
     "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
     "quotes": [
-        "each quote must be UNEDITED and EXACTLY as shown in the provided documents!",
-        "HINT the quotes are not shown to the user!",
+        "each quote must be UNEDITED and EXACTLY as shown in the context documents!",
+        "HINT, quotes are not shown to the user!",
     ],
 }
 
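For reference, the prompts in qa_block.py embed this dict via json.dumps(EMPTY_SAMPLE_JSON), which renders it as a single-line, double-quoted JSON string; a minimal sketch:

    import json

    EMPTY_SAMPLE_JSON = {
        "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
        "quotes": [
            "each quote must be UNEDITED and EXACTLY as shown in the context documents!",
            "HINT, quotes are not shown to the user!",
        ],
    }
    print(json.dumps(EMPTY_SAMPLE_JSON))  # one line, ready to drop into a prompt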
backend/danswer/secondary_llm_flows/answer_validation.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
+from danswer.direct_qa.qa_prompts import ANSWER_PAT
+from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
+from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
+from danswer.direct_qa.qa_prompts import INVALID_PAT
+from danswer.direct_qa.qa_prompts import QUESTION_PAT
+from danswer.direct_qa.qa_prompts import THOUGHT_PAT
+from danswer.llm.build import get_default_llm
+from danswer.utils.logger import setup_logger
+from danswer.utils.timing import log_function_time
+
+logger = setup_logger()
+
+
+def get_answer_validation_messages(query: str, answer: str) -> list[dict[str, str]]:
+    cot_block = (
+        f"{THOUGHT_PAT} Use this as a scratchpad to write out in a step by step manner your reasoning "
+        f"about EACH criterion to ensure that your conclusion is correct.\n"
+        f"{INVALID_PAT} True or False"
+    )
+
+    q_a_block = f"{QUESTION_PAT} {query}\n\n" f"{ANSWER_PAT} {answer}"
+
+    messages = [
+        {
+            "role": "user",
+            "content": (
+                f"{CODE_BLOCK_PAT.format(q_a_block).lstrip()}{GENERAL_SEP_PAT}\n"
+                "Determine if the answer is valid for the query.\n"
+                f"The answer is invalid if ANY of the following is true:\n"
+                "- Does not directly answer the user query.\n"
+                "- Answers a related but different question.\n"
+                '- Contains anything meaning "I don\'t know" or "information not found".\n\n'
+                f"You must use the following format:"
+                f"{CODE_BLOCK_PAT.format(cot_block)}"
+                f'Hint: Invalid must be exactly "True" or "False" (without the quotes)'
+            ),
+        },
+    ]
+
+    return messages
+
+
+def extract_validity(model_output: str) -> bool:
+    if INVALID_PAT in model_output:
+        result = model_output.split(INVALID_PAT)[-1].strip()
+        if "true" in result.lower():
+            return False
+    return True  # If something is wrong, let's not toss away the answer
+
+
+@log_function_time()
+def get_answer_validity(
+    query: str,
+    answer: str,
+) -> bool:
+    messages = get_answer_validation_messages(query, answer)
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    model_output = get_default_llm().invoke(filled_llm_prompt)
+    logger.debug(model_output)
+
+    validity = extract_validity(model_output)
+    logger.info(
+        f'LLM Answer of "{answer}" was determined to be {"valid" if validity else "invalid"}.'
+    )
+
+    return validity
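
A quick sanity check of extract_validity against hand-written model outputs (the sample strings are hypothetical; INVALID_PAT is "Invalid:" as added to qa_prompts.py above). Note the inversion: the model answers "is this invalid?", so "Invalid: True" maps to a validity of False:

    from danswer.secondary_llm_flows.answer_validation import extract_validity

    # Hypothetical model outputs for illustration
    assert extract_validity("Thought: directly answers the query.\nInvalid: False") is True
    assert extract_validity("Thought: answers a different question.\nInvalid: True") is False
    # If the model never emits the pattern, the answer is kept rather than tossed
    assert extract_validity("no verdict emitted") is True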