mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-22 17:16:20 +02:00
Better QA Prompts (#409)
This commit is contained in:
@@ -77,7 +77,7 @@ GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
|
|||||||
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
|
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
|
||||||
|
|
||||||
# Set this to be enough for an answer + quotes
|
# Set this to be enough for an answer + quotes
|
||||||
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "512"))
|
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "1024"))
|
||||||
|
|
||||||
# Danswer custom Deep Learning Models
|
# Danswer custom Deep Learning Models
|
||||||
INTENT_MODEL_VERSION = "danswer/intent-model"
|
INTENT_MODEL_VERSION = "danswer/intent-model"
|
||||||
|
@@ -18,10 +18,10 @@ from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
|
|||||||
from danswer.direct_qa.interfaces import QAModel
|
from danswer.direct_qa.interfaces import QAModel
|
||||||
from danswer.direct_qa.local_transformers import TransformerQA
|
from danswer.direct_qa.local_transformers import TransformerQA
|
||||||
from danswer.direct_qa.open_ai import OpenAICompletionQA
|
from danswer.direct_qa.open_ai import OpenAICompletionQA
|
||||||
from danswer.direct_qa.qa_block import JsonChatQAHandler
|
|
||||||
from danswer.direct_qa.qa_block import QABlock
|
from danswer.direct_qa.qa_block import QABlock
|
||||||
from danswer.direct_qa.qa_block import QAHandler
|
from danswer.direct_qa.qa_block import QAHandler
|
||||||
from danswer.direct_qa.qa_block import SimpleChatQAHandler
|
from danswer.direct_qa.qa_block import SimpleChatQAHandler
|
||||||
|
from danswer.direct_qa.qa_block import SingleMessageQAHandler
|
||||||
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
|
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
|
||||||
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
|
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
|
||||||
from danswer.direct_qa.request_model import RequestCompletionQA
|
from danswer.direct_qa.request_model import RequestCompletionQA
|
||||||
@@ -53,7 +53,7 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
|
|||||||
|
|
||||||
def get_default_qa_handler(model: str) -> QAHandler:
    """Pick the prompt handler used for direct QA with the given model.

    The OpenAI chat model value maps to ``SingleMessageQAHandler``; any other
    model value falls back to ``SimpleChatQAHandler``.
    """
    uses_openai_chat = model == DanswerGenAIModel.OPENAI_CHAT.value
    return SingleMessageQAHandler() if uses_openai_chat else SimpleChatQAHandler()
|
||||||
|
|
||||||
|
@@ -16,7 +16,10 @@ from danswer.direct_qa.interfaces import DanswerAnswer
|
|||||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||||
from danswer.direct_qa.interfaces import DanswerQuotes
|
from danswer.direct_qa.interfaces import DanswerQuotes
|
||||||
from danswer.direct_qa.interfaces import QAModel
|
from danswer.direct_qa.interfaces import QAModel
|
||||||
|
from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
|
||||||
|
from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
|
||||||
from danswer.direct_qa.qa_prompts import JsonChatProcessor
|
from danswer.direct_qa.qa_prompts import JsonChatProcessor
|
||||||
|
from danswer.direct_qa.qa_prompts import QUESTION_PAT
|
||||||
from danswer.direct_qa.qa_prompts import SAMPLE_JSON_RESPONSE
|
from danswer.direct_qa.qa_prompts import SAMPLE_JSON_RESPONSE
|
||||||
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
|
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
|
||||||
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
|
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
|
||||||
@@ -93,6 +96,49 @@ class SimpleChatQAHandler(QAHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleMessageQAHandler(QAHandler):
    """QA handler that sends instructions, context and query as one human message.

    The prompt asks the model for a JSON reply (an answer plus supporting
    quotes), and the streamed reply is processed as a JSON-prompt response.
    """

    def build_prompt(
        self, query: str, context_chunks: list[InferenceChunk]
    ) -> list[BaseMessage]:
        """Build the single-message prompt for ``query`` over ``context_chunks``.

        Returns a one-element list holding a ``HumanMessage`` whose content is,
        in order: the system-style instructions, the "don't know" JSON, a
        sample JSON response, the context documents, the query, and a final
        hint about JSON output and exact quotes.
        """
        # JSON the model is told to emit when it cannot answer.
        not_found_json = "".join(('{"answer": "', UNCERTAINTY_PAT, '", "quotes": []}'))

        # Every chunk's content is wrapped in its own fenced code block.
        fenced_docs = "\n".join(
            CODE_BLOCK_PAT.format(chunk.content) for chunk in context_chunks
        )

        sections = (
            "You are a question answering system that is constantly learning and improving. ",
            "You can process and comprehend vast amounts of text and utilize this knowledge ",
            "to provide accurate and detailed answers to diverse queries.\n",
            "You ALWAYS responds in a json containing an answer and quotes that support the answer.\n",
            "Your responses are as informative and detailed as possible.\n",
            "If you don't know the answer, respond with ",
            CODE_BLOCK_PAT.format(not_found_json),
            "\nSample response:",
            CODE_BLOCK_PAT.format(json.dumps(SAMPLE_JSON_RESPONSE)),
            f"{GENERAL_SEP_PAT}CONTEXT:\n\n{fenced_docs}",
            f"{GENERAL_SEP_PAT}{QUESTION_PAT} {query}",
            "\nHint: Make the answer as informative as possible and use a JSON! ",
            "Quotes MUST be EXACT substrings from provided documents!",
        )
        return [HumanMessage(content="".join(sections))]

    def process_response(
        self,
        tokens: Iterator[str],
        context_chunks: list[InferenceChunk],
    ) -> AnswerQuestionStreamReturn:
        """Stream the processed model output, treating it as a JSON-prompt reply."""
        processed_stream = process_model_tokens(
            tokens=tokens,
            context_docs=context_chunks,
            is_json_prompt=True,
        )
        yield from processed_stream
|
||||||
|
|
||||||
|
|
||||||
class JsonChatQAUnshackledHandler(QAHandler):
|
class JsonChatQAUnshackledHandler(QAHandler):
|
||||||
def build_prompt(
|
def build_prompt(
|
||||||
self, query: str, context_chunks: list[InferenceChunk]
|
self, query: str, context_chunks: list[InferenceChunk]
|
||||||
|
@@ -6,7 +6,8 @@ from danswer.configs.constants import DocumentSource
|
|||||||
from danswer.connectors.factory import identify_connector_class
|
from danswer.connectors.factory import identify_connector_class
|
||||||
|
|
||||||
|
|
||||||
GENERAL_SEP_PAT = "---\n"
|
GENERAL_SEP_PAT = "\n-----\n"
|
||||||
|
CODE_BLOCK_PAT = "\n```\n{}\n```\n"
|
||||||
DOC_SEP_PAT = "---NEW DOCUMENT---"
|
DOC_SEP_PAT = "---NEW DOCUMENT---"
|
||||||
DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
|
DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
|
||||||
QUESTION_PAT = "Query:"
|
QUESTION_PAT = "Query:"
|
||||||
|
@@ -246,6 +246,9 @@ def process_model_tokens(
|
|||||||
json_answer_ind = model_output.index('{"answer":')
|
json_answer_ind = model_output.index('{"answer":')
|
||||||
if json_answer_ind != 0:
|
if json_answer_ind != 0:
|
||||||
model_output = model_output[json_answer_ind:]
|
model_output = model_output[json_answer_ind:]
|
||||||
|
end = model_output.rfind("}")
|
||||||
|
if end != -1:
|
||||||
|
model_output = model_output[: end + 1]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.exception("Did not find answer pattern in response for JSON prompt")
|
logger.exception("Did not find answer pattern in response for JSON prompt")
|
||||||
|
|
||||||
|
@@ -4,50 +4,59 @@ from dataclasses import asdict
|
|||||||
|
|
||||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||||
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
|
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
|
||||||
|
from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
|
||||||
from danswer.llm.build import get_default_llm
|
from danswer.llm.build import get_default_llm
|
||||||
from danswer.server.models import QueryValidationResponse
|
from danswer.server.models import QueryValidationResponse
|
||||||
from danswer.server.utils import get_json_line
|
from danswer.server.utils import get_json_line
|
||||||
|
|
||||||
|
QUERY_PAT = "QUERY: "
|
||||||
REASONING_PAT = "REASONING: "
|
REASONING_PAT = "REASONING: "
|
||||||
ANSWERABLE_PAT = "ANSWERABLE: "
|
ANSWERABLE_PAT = "ANSWERABLE: "
|
||||||
COT_PAT = "\nLet's think step by step"
|
|
||||||
|
|
||||||
|
|
||||||
def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
    """Build the chat messages asking a model whether ``user_query`` is answerable.

    The result is a list with a single ``"user"`` message. Its content holds
    the task instructions, three worked examples (query, reasoning, verdict),
    and finally the user's query, each wrapped with ``CODE_BLOCK_PAT``.
    """

    def _worked_example(question: str, reasoning: str, answerable: bool) -> str:
        # One few-shot sample: query line, reasoning line, then the verdict.
        return (
            f"{QUERY_PAT}{question}\n"
            f"{REASONING_PAT}{reasoning}\n"
            f"{ANSWERABLE_PAT}{answerable}"
        )

    few_shot_examples = (
        _worked_example(
            "What is this Slack channel about?",
            "First the system must determine which Slack channel is "
            "being referred to. By fetching 5 documents related to Slack channel contents, "
            "it is not possible to determine which Slack channel the user is referring to.",
            False,
        ),
        _worked_example(
            "Danswer is unreachable.",
            "The system searches documents related to Danswer being "
            "unreachable. Assuming the documents from search contains situations where "
            "Danswer is not reachable and contains a fix, the query is answerable.",
            True,
        ),
        _worked_example(
            "How many customers do we have",
            "Assuming the retrieved documents contain up to date customer "
            "acquisition information including a list of customers, the query can be answered. "
            "It is important to note that if the information only exists in a database, "
            "the system is unable to execute SQL and won't find an answer.",
            True,
        ),
    )

    instructions = (
        "You are a helper tool to determine if a query is answerable using retrieval augmented "
        "generation. A system will try to answer the user query based on ONLY the top 5 most relevant "
        "documents found from search. Sources contain both up to date and proprietary information for "
        "the specific team. For named or unknown entities, assume the search will always find "
        "consistent knowledge about the entity.\n"
        "The system is not tuned for writing code nor for interfacing with structured data "
        "via query languages like SQL.\n"
        "Determine if that system should attempt to answer. "
        f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"\n'
    )

    # The examples come first and the real query last; each block is followed
    # by a newline.
    block_bodies = [*few_shot_examples, QUERY_PAT + user_query]
    wrapped_blocks = "".join(
        CODE_BLOCK_PAT.format(body) + "\n" for body in block_bodies
    )

    return [{"role": "user", "content": instructions + wrapped_blocks}]
|
||||||
|
Reference in New Issue
Block a user