Add basic chain of thought PromptProcessor (#316)

Chris Weaver 2023-08-20 18:48:24 -07:00 committed by GitHub
parent 3ec602b47f
commit 9f1898c384
3 changed files with 81 additions and 7 deletions


@@ -24,10 +24,6 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg


@@ -30,6 +30,13 @@ SAMPLE_JSON_RESPONSE = {
"located on the Champ de Mars in France.",
],
}
SAMPLE_RESPONSE_COT = (
"Let's think step by step. The user is asking for the "
"location of the Eiffel Tower. The first document describes the Eiffel Tower "
"as being an iconic symbol of Paris and that it is located on the Champ de Mars. "
"Since the Champ de Mars is in Paris, we know that the Eiffel Tower is in Paris."
f"\n\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
)
def _append_acknowledge_doc_messages(
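For context, the few-shot target that `SAMPLE_RESPONSE_COT` teaches the model is free-text reasoning, a blank line, and then the same JSON schema as `SAMPLE_JSON_RESPONSE`. A rough standalone rendering of that format is below; the full `SAMPLE_JSON_RESPONSE` dict is not shown in this hunk, so the `answer` value here is an assumption and only the "answer"/"quotes" keys are inferred from the visible tail above.

```python
import json

# Assumed shape of SAMPLE_JSON_RESPONSE; only the tail of the "quotes" list is
# visible in the hunk above, so the "answer" text here is illustrative.
SAMPLE_JSON_RESPONSE = {
    "answer": "The Eiffel Tower is located in Paris.",
    "quotes": [
        "The Eiffel Tower is an iconic symbol of Paris, "
        "located on the Champ de Mars in France.",
    ],
}

# Reasoning first, then a blank line, then the JSON payload -- the format the
# chain-of-thought prompt asks the model to reproduce.
print(
    "Let's think step by step. The first document places the Eiffel Tower "
    "on the Champ de Mars in Paris.\n\n" + json.dumps(SAMPLE_JSON_RESPONSE)
)
```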
@@ -154,7 +161,9 @@ class JsonChatProcessor(ChatPromptProcessor):
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
question: str,
chunks: list[InferenceChunk],
include_metadata: bool = False,
) -> list[dict[str, str]]:
metadata_prompt_section = (
"with metadata and contents " if include_metadata else ""
@@ -181,7 +190,6 @@ class JsonChatProcessor(ChatPromptProcessor):
f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
)
messages = [{"role": "system", "content": intro_msg}]
for chunk in chunks:
full_context = ""
if include_metadata:
@@ -197,6 +205,65 @@ class JsonChatProcessor(ChatPromptProcessor):
return messages
class JsonCoTChatProcessor(ChatPromptProcessor):
"""Pros: improves performance slightly over the regular JsonChatProcessor.
Cons: Much slower.
"""
@property
def specifies_json_output(self) -> bool:
return True
@staticmethod
def fill_prompt(
question: str,
chunks: list[InferenceChunk],
include_metadata: bool = True,
) -> list[dict[str, str]]:
metadata_prompt_section = (
"with metadata and contents " if include_metadata else ""
)
intro_msg = (
f"You are a Question Answering assistant that answers queries "
f"based on the provided documents.\n"
f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
)
complete_answer_not_found_response = (
'{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
)
task_msg = (
"Now answer the user query based on documents above and quote relevant sections.\n"
"When answering, you should think step by step, and verbalize your thought process.\n"
"Then respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
"All quotes MUST be EXACT substrings from provided documents.\n"
"Your responses should be informative, detailed, and consider all possibilities and edge cases.\n"
"You MUST prioritize information from provided documents over internal knowledge.\n"
"If the query cannot be answered based on the documents, respond with "
f"{complete_answer_not_found_response}\n"
"If the query requires aggregating the number of documents, respond with "
'{"answer": "Aggregations not supported", "quotes": []}\n'
f"Sample response:\n\n{SAMPLE_RESPONSE_COT}"
)
messages = [{"role": "system", "content": intro_msg}]
for chunk in chunks:
full_context = ""
if include_metadata:
full_context = _add_metadata_section(
full_context, chunk, prepend_tab=False, include_sep=False
)
full_context += chunk.content
messages = _append_acknowledge_doc_messages(messages, full_context)
messages.append({"role": "user", "content": task_msg})
messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n\n"})
messages.append({"role": "user", "content": "Let's think step by step."})
return messages
class WeakModelFreeformProcessor(NonChatPromptProcessor):
"""Avoid using this one if the model is capable of using another prompt
Intended for models that can't follow complex instructions or have short context windows
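The new `JsonCoTChatProcessor` above builds the same acknowledge-each-document conversation as `JsonChatProcessor`, then appends the CoT task message, the question, and a final "Let's think step by step." user turn that triggers the verbalized reasoning. A minimal standalone sketch of that message layout follows; plain strings stand in for `InferenceChunk` objects, and the "Acknowledged" assistant turn is an assumption about what `_append_acknowledge_doc_messages` does.

```python
def build_cot_messages(question: str, docs: list[str]) -> list[dict[str, str]]:
    """Sketch of the chat layout JsonCoTChatProcessor.fill_prompt produces."""
    intro = (
        "You are a Question Answering assistant that answers queries "
        "based on the provided documents.\n"
        'Start by reading the following documents and responding with "Acknowledged".'
    )
    task = (
        "Now answer the user query based on documents above and quote relevant sections.\n"
        "When answering, you should think step by step, and verbalize your thought process."
    )
    messages = [{"role": "system", "content": intro}]
    for doc in docs:
        # One user turn per document, each followed by a canned acknowledgement
        # (assumed behavior of _append_acknowledge_doc_messages).
        messages.append({"role": "user", "content": doc})
        messages.append({"role": "assistant", "content": "Acknowledged"})
    messages.append({"role": "user", "content": task})
    messages.append({"role": "user", "content": f"QUESTION:\n{question}\n\n"})
    # Trailing nudge that kicks off the chain-of-thought output.
    messages.append({"role": "user", "content": "Let's think step by step."})
    return messages
```

The extra reasoning tokens are what the docstring's "much slower" warning refers to: the model emits its thought process before the JSON, so the final answer arrives later in the stream.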


@@ -181,7 +181,6 @@ def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
def extract_quotes_from_completed_token_stream(
model_output: str, context_chunks: list[InferenceChunk]
) -> DanswerQuotes:
logger.debug(model_output)
answer, quotes = process_answer(model_output, context_chunks)
if answer:
logger.info(answer)
@@ -236,6 +235,18 @@ def process_model_tokens(
yield DanswerAnswerPiece(answer_piece=token)
hold_quote = ""
logger.debug(f"Raw model output: {model_output}")
# for a JSON prompt, make sure that we're only passing through the "JSON part"
# since that is what `extract_quotes_from_completed_token_stream` expects
if is_json_prompt:
try:
json_answer_ind = model_output.index('{"answer":')
if json_answer_ind != 0:
model_output = model_output[json_answer_ind:]
except ValueError:
logger.exception("Did not find answer pattern in response for JSON prompt")
yield extract_quotes_from_completed_token_stream(model_output, context_docs)
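Because the CoT prompt makes the model verbalize its reasoning before emitting JSON, the raw output no longer starts with the answer object; the trim above discards everything before the first '{"answer":' so that `extract_quotes_from_completed_token_stream` still receives a parseable JSON string. A small standalone illustration of the effect, with a made-up model output:

```python
model_output = (
    "Let's think step by step. The documents place the Eiffel Tower on the "
    "Champ de Mars, which is in Paris.\n\n"
    '{"answer": "Paris", "quotes": ["located on the Champ de Mars"]}'
)

# Same trim as in process_model_tokens: keep only the JSON part of the output.
json_answer_ind = model_output.index('{"answer":')
print(model_output[json_answer_ind:])
# {"answer": "Paris", "quotes": ["located on the Champ de Mars"]}
```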