From 9f1898c384c04ab4c33c4504c6adea6b4f809486 Mon Sep 17 00:00:00 2001
From: Chris Weaver <25087905+Weves@users.noreply.github.com>
Date: Sun, 20 Aug 2023 18:48:24 -0700
Subject: [PATCH] Add basic chain of thought PromptProcessor (#316)

---
 backend/danswer/direct_qa/open_ai.py    |  4 --
 backend/danswer/direct_qa/qa_prompts.py | 71 ++++++++++++++++++++++++-
 backend/danswer/direct_qa/qa_utils.py   | 13 ++++-
 3 files changed, 81 insertions(+), 7 deletions(-)

diff --git a/backend/danswer/direct_qa/open_ai.py b/backend/danswer/direct_qa/open_ai.py
index f8694d9ae..8e4dbdb14 100644
--- a/backend/danswer/direct_qa/open_ai.py
+++ b/backend/danswer/direct_qa/open_ai.py
@@ -24,10 +24,6 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
-from danswer.direct_qa.interfaces import DanswerAnswer
-from danswer.direct_qa.interfaces import DanswerAnswerPiece
-from danswer.direct_qa.interfaces import DanswerQuote
-from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ChatPromptProcessor
 from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
diff --git a/backend/danswer/direct_qa/qa_prompts.py b/backend/danswer/direct_qa/qa_prompts.py
index 1e1ca9548..320409f93 100644
--- a/backend/danswer/direct_qa/qa_prompts.py
+++ b/backend/danswer/direct_qa/qa_prompts.py
@@ -30,6 +30,13 @@ SAMPLE_JSON_RESPONSE = {
         "located on the Champ de Mars in France.",
     ],
 }
+SAMPLE_RESPONSE_COT = (
+    "Let's think step by step. The user is asking for the "
+    "location of the Eiffel Tower. The first document describes the Eiffel Tower "
+    "as being an iconic symbol of Paris and that it is located on the Champ de Mars. "
+    "Since the Champ de Mars is in Paris, we know that the Eiffel Tower is in Paris."
+    f"\n\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
+)
 
 
 def _append_acknowledge_doc_messages(
@@ -154,7 +161,9 @@ class JsonChatProcessor(ChatPromptProcessor):
 
     @staticmethod
     def fill_prompt(
-        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
+        question: str,
+        chunks: list[InferenceChunk],
+        include_metadata: bool = False,
     ) -> list[dict[str, str]]:
         metadata_prompt_section = (
             "with metadata and contents " if include_metadata else ""
@@ -181,7 +190,6 @@ class JsonChatProcessor(ChatPromptProcessor):
             f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
         )
         messages = [{"role": "system", "content": intro_msg}]
-
         for chunk in chunks:
             full_context = ""
             if include_metadata:
@@ -197,6 +205,65 @@ class JsonChatProcessor(ChatPromptProcessor):
         return messages
 
 
+class JsonCoTChatProcessor(ChatPromptProcessor):
+    """Pros: improves performance slightly over the regular JsonChatProcessor.
+    Cons: Much slower.
+    """
+
+    @property
+    def specifies_json_output(self) -> bool:
+        return True
+
+    @staticmethod
+    def fill_prompt(
+        question: str,
+        chunks: list[InferenceChunk],
+        include_metadata: bool = True,
+    ) -> list[dict[str, str]]:
+        metadata_prompt_section = (
+            "with metadata and contents " if include_metadata else ""
+        )
+        intro_msg = (
+            f"You are a Question Answering assistant that answers queries "
+            f"based on the provided documents.\n"
+            f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
+        )
+
+        complete_answer_not_found_response = (
+            '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
+        )
+        task_msg = (
+            "Now answer the user query based on documents above and quote relevant sections.\n"
+            "When answering, you should think step by step, and verbalize your thought process.\n"
+            "Then respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
+            "All quotes MUST be EXACT substrings from provided documents.\n"
+            "Your responses should be informative, detailed, and consider all possibilities and edge cases.\n"
+            "You MUST prioritize information from provided documents over internal knowledge.\n"
+            "If the query cannot be answered based on the documents, respond with "
+            f"{complete_answer_not_found_response}\n"
+            "If the query requires aggregating the number of documents, respond with "
+            '{"answer": "Aggregations not supported", "quotes": []}\n'
+            f"Sample response:\n\n{SAMPLE_RESPONSE_COT}"
+        )
+        messages = [{"role": "system", "content": intro_msg}]
+
+        for chunk in chunks:
+            full_context = ""
+            if include_metadata:
+                full_context = _add_metadata_section(
+                    full_context, chunk, prepend_tab=False, include_sep=False
+                )
+            full_context += chunk.content
+            messages = _append_acknowledge_doc_messages(messages, full_context)
+        messages.append({"role": "user", "content": task_msg})
+
+        messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n\n"})
+
+        messages.append({"role": "user", "content": "Let's think step by step."})
+
+        return messages
+
+
 class WeakModelFreeformProcessor(NonChatPromptProcessor):
     """Avoid using this one if the model is capable of using another prompt
     Intended for models that can't follow complex instructions or have short context windows
diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py
index eaa95ee88..b9cefe760 100644
--- a/backend/danswer/direct_qa/qa_utils.py
+++ b/backend/danswer/direct_qa/qa_utils.py
@@ -181,7 +181,6 @@ def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
 def extract_quotes_from_completed_token_stream(
     model_output: str, context_chunks: list[InferenceChunk]
 ) -> DanswerQuotes:
-    logger.debug(model_output)
     answer, quotes = process_answer(model_output, context_chunks)
     if answer:
         logger.info(answer)
@@ -236,6 +235,18 @@ def process_model_tokens(
             yield DanswerAnswerPiece(answer_piece=token)
             hold_quote = ""
 
+    logger.debug(f"Raw model output: {model_output}")
+
+    # for a JSON prompt, make sure that we're only passing through the "JSON part"
+    # since that is what `extract_quotes_from_completed_token_stream` expects
+    if is_json_prompt:
+        try:
+            json_answer_ind = model_output.index('{"answer":')
+            if json_answer_ind != 0:
+                model_output = model_output[json_answer_ind:]
+        except ValueError:
+            logger.exception("Did not find answer pattern in response for JSON prompt")
+
     yield extract_quotes_from_completed_token_stream(model_output, context_docs)
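
Review note: below is a minimal, self-contained sketch of the chat message layout that the new JsonCoTChatProcessor.fill_prompt builds: one acknowledged turn per document, the task message, the question, and a trailing "Let's think step by step." nudge. It is illustrative only; plain strings stand in for InferenceChunk contents, a literal "QUESTION:" stands in for QUESTION_PAT, and it assumes _append_acknowledge_doc_messages alternates user/assistant turns (none of these stand-ins are part of the patch).

# Illustrative only: approximates the message layout produced by
# JsonCoTChatProcessor.fill_prompt, with plain strings standing in for
# InferenceChunk contents and "QUESTION:" standing in for QUESTION_PAT.
def build_cot_messages(question: str, documents: list[str]) -> list[dict[str, str]]:
    intro_msg = (
        "You are a Question Answering assistant that answers queries "
        "based on the provided documents.\n"
        'Start by reading the following documents and responding with "Acknowledged".'
    )
    task_msg = (
        "Now answer the user query based on documents above and quote relevant sections.\n"
        "When answering, you should think step by step, and verbalize your thought process."
    )
    messages = [{"role": "system", "content": intro_msg}]
    for doc in documents:
        # assumed behaviour of _append_acknowledge_doc_messages: one user message
        # per document, answered by an assistant "Acknowledged" turn
        messages.append({"role": "user", "content": doc})
        messages.append({"role": "assistant", "content": "Acknowledged"})
    messages.append({"role": "user", "content": task_msg})
    messages.append({"role": "user", "content": f"QUESTION:\n{question}\n\n"})
    # the trailing nudge that elicits the verbalized chain of thought
    messages.append({"role": "user", "content": "Let's think step by step."})
    return messages


if __name__ == "__main__":
    for message in build_cot_messages(
        "Where is the Eiffel Tower?",
        ["The Eiffel Tower is an iconic symbol of Paris on the Champ de Mars."],
    ):
        print(f'{message["role"]}: {message["content"][:60]}')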
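
The qa_utils.py change exists because the chain-of-thought prompt makes the model verbalize its reasoning before emitting the JSON payload, while extract_quotes_from_completed_token_stream expects only the JSON. A small standalone illustration of that slicing follows; the sample completion is invented for the example.

# Illustrative only: shows why process_model_tokens now slices the completed
# output at '{"answer":' for JSON prompts. The sample completion is invented.
sample_model_output = (
    "Let's think step by step. The documents place the Eiffel Tower on the "
    "Champ de Mars, which is in Paris.\n\n"
    '{"answer": "Paris", "quotes": ["located on the Champ de Mars"]}'
)

try:
    json_answer_ind = sample_model_output.index('{"answer":')
    if json_answer_ind != 0:
        # drop the verbalized reasoning so only the JSON part reaches
        # extract_quotes_from_completed_token_stream
        sample_model_output = sample_model_output[json_answer_ind:]
except ValueError:
    # mirrors the patch: report the miss and fall through with the raw output
    print("Did not find answer pattern in response for JSON prompt")

print(sample_model_output)
# -> {"answer": "Paris", "quotes": ["located on the Champ de Mars"]}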