From e0b87d9d4e3f4b730de7a294b21f23969556fc3c Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Mon, 4 Dec 2023 15:02:08 -0800
Subject: [PATCH] Fix Weak Model Prompt (#810)

---
 backend/danswer/direct_qa/qa_block.py        |  4 +++-
 backend/danswer/direct_qa/qa_utils.py        | 10 +++++++---
 backend/danswer/prompts/direct_qa_prompts.py | 15 ++++++++-------
 3 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py
index b7276eac3f53..dac919178595 100644
--- a/backend/danswer/direct_qa/qa_block.py
+++ b/backend/danswer/direct_qa/qa_block.py
@@ -113,7 +113,9 @@ class WeakLLMQAHandler(QAHandler):
     def build_prompt(
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
-        message = WEAK_LLM_PROMPT.format(single_reference_doc=context_chunks[0].content)
+        message = WEAK_LLM_PROMPT.format(
+            user_query=query, single_reference_doc=context_chunks[0].content
+        )
 
         return [HumanMessage(content=message)]
 
diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py
index 640bd4e0b788..a40c19731d44 100644
--- a/backend/danswer/direct_qa/qa_utils.py
+++ b/backend/danswer/direct_qa/qa_utils.py
@@ -196,9 +196,9 @@ def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
 
 
 def _extract_quotes_from_completed_token_stream(
-    model_output: str, context_chunks: list[InferenceChunk]
+    model_output: str, context_chunks: list[InferenceChunk], is_json_prompt: bool = True
 ) -> DanswerQuotes:
-    answer, quotes = process_answer(model_output, context_chunks)
+    answer, quotes = process_answer(model_output, context_chunks, is_json_prompt)
     if answer:
         logger.info(answer)
     elif model_output:
@@ -262,7 +262,11 @@ def process_model_tokens(
 
     logger.debug(f"Raw Model QnA Output: {model_output}")
 
-    yield _extract_quotes_from_completed_token_stream(model_output, context_docs)
+    yield _extract_quotes_from_completed_token_stream(
+        model_output=model_output,
+        context_chunks=context_docs,
+        is_json_prompt=is_json_prompt,
+    )
 
 
 def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
diff --git a/backend/danswer/prompts/direct_qa_prompts.py b/backend/danswer/prompts/direct_qa_prompts.py
index 474cdfbf752f..a6ab5908e47d 100644
--- a/backend/danswer/prompts/direct_qa_prompts.py
+++ b/backend/danswer/prompts/direct_qa_prompts.py
@@ -92,18 +92,19 @@ You MUST respond in the following format:
 
 
 # For weak LLM which only takes one chunk and cannot output json
+# Also not requiring quotes as it tends to not work
 WEAK_LLM_PROMPT = f"""
-Respond to the user query using a reference document.
-{GENERAL_SEP_PAT}
+Respond to the user query using the following reference document.
+
 Reference Document:
+{GENERAL_SEP_PAT}
 {{single_reference_doc}}
 {GENERAL_SEP_PAT}
 
-Answer the user query below based on the reference document above.
-Respond with an "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as needed to support \
-the answer.'
-{QUESTION_PAT.upper()} {{user_query}}
-{ANSWER_PAT.upper()}
+Answer the user query below based on the reference document above.
+
+{QUESTION_PAT.upper()}
+{{user_query}}
 """.strip()
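
Note for reviewers (not part of the applied diff): the removed lines show that WEAK_LLM_PROMPT already rendered a {user_query} placeholder, while the old WeakLLMQAHandler.build_prompt formatted the template with only single_reference_doc, leaving that named field unfilled, which is presumably the failure this patch addresses. The snippet below is a minimal, self-contained sketch of the patched call shape; it uses a simplified stand-in template (hard-coded separator instead of GENERAL_SEP_PAT) and a hypothetical build_weak_llm_message helper rather than the repo's actual classes.

# Minimal sketch of the patched prompt construction; stand-in template, not
# the repo's actual WEAK_LLM_PROMPT (separator and section labels simplified).
WEAK_LLM_PROMPT = """
Respond to the user query using the following reference document.

Reference Document:
----------
{single_reference_doc}
----------

Answer the user query below based on the reference document above.

QUESTION:
{user_query}
""".strip()


def build_weak_llm_message(query: str, top_chunk_content: str) -> str:
    # Both placeholders must be supplied; formatting this template with only
    # single_reference_doc would raise KeyError: 'user_query', mirroring the
    # bug fixed in WeakLLMQAHandler.build_prompt.
    return WEAK_LLM_PROMPT.format(
        user_query=query,
        single_reference_doc=top_chunk_content,
    )


if __name__ == "__main__":
    print(
        build_weak_llm_message(
            "What does the weak model prompt include?",
            "The weak LLM path uses a single reference document and no JSON output.",
        )
    )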