diff --git a/backend/danswer/bots/slack/handlers/handle_message.py b/backend/danswer/bots/slack/handlers/handle_message.py index d81c5d5ba519..e06f17b57459 100644 --- a/backend/danswer/bots/slack/handlers/handle_message.py +++ b/backend/danswer/bots/slack/handlers/handle_message.py @@ -42,6 +42,7 @@ def handle_message( user=None, db_session=db_session, answer_generation_timeout=answer_generation_timeout, + real_time_flow=False, ) if not answer.error_msg: return answer diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 77c2801d066e..f31d64650a67 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -209,7 +209,7 @@ DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int( ) DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5")) DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int( - os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "60") + os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90") ) DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get( "DANSWER_BOT_DISPLAY_ERROR_MSGS", "" diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index 536be11933a7..69466195241d 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -33,6 +33,7 @@ def answer_qa_query( db_session: Session, disable_generative_answer: bool = DISABLE_GENERATIVE_AI, answer_generation_timeout: int = QA_TIMEOUT, + real_time_flow: bool = True, ) -> QAResponse: query = question.query filters = question.filters @@ -88,7 +89,9 @@ def answer_qa_query( ) try: - qa_model = get_default_qa_model(timeout=answer_generation_timeout) + qa_model = get_default_qa_model( + timeout=answer_generation_timeout, real_time_flow=real_time_flow + ) except (UnknownModelError, OpenAIKeyMissing) as e: return QAResponse( answer=None, diff --git a/backend/danswer/direct_qa/llm_utils.py b/backend/danswer/direct_qa/llm_utils.py 
index 5f779cca8b3f..0341005b2e53 100644 --- a/backend/danswer/direct_qa/llm_utils.py +++ b/backend/danswer/direct_qa/llm_utils.py @@ -22,6 +22,7 @@ from danswer.direct_qa.qa_block import QABlock from danswer.direct_qa.qa_block import QAHandler from danswer.direct_qa.qa_block import SimpleChatQAHandler from danswer.direct_qa.qa_block import SingleMessageQAHandler +from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor from danswer.direct_qa.qa_utils import get_gen_ai_api_key from danswer.direct_qa.request_model import RequestCompletionQA @@ -51,9 +52,13 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool: return False -def get_default_qa_handler(model: str) -> QAHandler: +def get_default_qa_handler(model: str, real_time_flow: bool = True) -> QAHandler: if model == DanswerGenAIModel.OPENAI_CHAT.value: - return SingleMessageQAHandler() + return ( + SingleMessageQAHandler() + if real_time_flow + else SingleMessageScratchpadHandler() + ) return SimpleChatQAHandler() @@ -64,6 +69,7 @@ def get_default_qa_model( model_host_type: str | None = GEN_AI_HOST_TYPE, api_key: str | None = GEN_AI_API_KEY, timeout: int = QA_TIMEOUT, + real_time_flow: bool = True, **kwargs: Any, ) -> QAModel: if not api_key: @@ -76,7 +82,9 @@ def get_default_qa_model( # un-used arguments will be ignored by the underlying `LLM` class # if any args are missing, a `TypeError` will be thrown llm = get_default_llm(timeout=timeout) - qa_handler = get_default_qa_handler(model=internal_model) + qa_handler = get_default_qa_handler( + model=internal_model, real_time_flow=real_time_flow + ) return QABlock( llm=llm, diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py index 87f0aaaf32c3..9e474e596e85 100644 --- a/backend/danswer/direct_qa/qa_block.py +++ b/backend/danswer/direct_qa/qa_block.py @@ -13,21 +13,24 @@ from danswer.chunking.models import InferenceChunk from 
danswer.direct_qa.interfaces import AnswerQuestionReturn from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn from danswer.direct_qa.interfaces import DanswerAnswer -from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import DanswerQuotes from danswer.direct_qa.interfaces import QAModel from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT +from danswer.direct_qa.qa_prompts import EMPTY_SAMPLE_JSON from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT from danswer.direct_qa.qa_prompts import JsonChatProcessor from danswer.direct_qa.qa_prompts import QUESTION_PAT from danswer.direct_qa.qa_prompts import SAMPLE_JSON_RESPONSE +from danswer.direct_qa.qa_prompts import THOUGHT_PAT from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor +from danswer.direct_qa.qa_utils import process_answer from danswer.direct_qa.qa_utils import process_model_tokens from danswer.llm.llm import LLM from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import str_prompt_to_langchain_prompt from danswer.utils.logger import setup_logger +from danswer.utils.text_processing import escape_newlines logger = setup_logger() @@ -43,11 +46,26 @@ class QAHandler(abc.ABC): ) -> list[BaseMessage]: raise NotImplementedError - @abc.abstractmethod - def process_response( + @property + def is_json_output(self) -> bool: + """Whether the model is expected to output valid JSON""" + return True + + def process_llm_output( + self, model_output: str, context_chunks: list[InferenceChunk] + ) -> tuple[DanswerAnswer, DanswerQuotes]: + return process_answer( + model_output, context_chunks, is_json_prompt=self.is_json_output + ) + + def process_llm_token_stream( self, tokens: Iterator[str], context_chunks: list[InferenceChunk] ) -> AnswerQuestionStreamReturn: - raise NotImplementedError + yield from process_model_tokens( + tokens=tokens, + 
context_docs=context_chunks, + is_json_prompt=self.is_json_output, + ) class JsonChatQAHandler(QAHandler): @@ -60,19 +78,12 @@ class JsonChatQAHandler(QAHandler): ) ) - def process_response( - self, - tokens: Iterator[str], - context_chunks: list[InferenceChunk], - ) -> AnswerQuestionStreamReturn: - yield from process_model_tokens( - tokens=tokens, - context_docs=context_chunks, - is_json_prompt=True, - ) - class SimpleChatQAHandler(QAHandler): + @property + def is_json_output(self) -> bool: + return False + def build_prompt( self, query: str, context_chunks: list[InferenceChunk] ) -> list[BaseMessage]: @@ -84,26 +95,11 @@ class SimpleChatQAHandler(QAHandler): ) ) - def process_response( - self, - tokens: Iterator[str], - context_chunks: list[InferenceChunk], - ) -> AnswerQuestionStreamReturn: - yield from process_model_tokens( - tokens=tokens, - context_docs=context_chunks, - is_json_prompt=False, - ) - class SingleMessageQAHandler(QAHandler): def build_prompt( self, query: str, context_chunks: list[InferenceChunk] ) -> list[BaseMessage]: - complete_answer_not_found_response = ( - '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}' - ) - context_docs_str = "\n".join( f"{CODE_BLOCK_PAT.format(c.content)}" for c in context_chunks ) @@ -115,27 +111,64 @@ class SingleMessageQAHandler(QAHandler): "to provide accurate and detailed answers to diverse queries.\n" "You ALWAYS responds in a json containing an answer and quotes that support the answer.\n" "Your responses are as INFORMATIVE and DETAILED as possible.\n" - "If you don't know the answer, respond with " - f"{CODE_BLOCK_PAT.format(complete_answer_not_found_response)}" - "\nSample response:" - f"{CODE_BLOCK_PAT.format(json.dumps(SAMPLE_JSON_RESPONSE))}" f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}" - f"{GENERAL_SEP_PAT}{QUESTION_PAT} {query}" + f"{GENERAL_SEP_PAT}Sample response:" + f"{CODE_BLOCK_PAT.format(json.dumps(EMPTY_SAMPLE_JSON))}\n" + f"{QUESTION_PAT} {query}" "\nHint: Make the answer as detailed as 
possible and use a JSON! " "Quotes MUST be EXACT substrings from provided documents!" ) ] return prompt - def process_response( - self, - tokens: Iterator[str], - context_chunks: list[InferenceChunk], + +class SingleMessageScratchpadHandler(QAHandler): + def build_prompt( + self, query: str, context_chunks: list[InferenceChunk] + ) -> list[BaseMessage]: + cot_block = ( + f"{THOUGHT_PAT} Let's think step by step. Use this section as a scratchpad.\n" + f"{json.dumps(EMPTY_SAMPLE_JSON)}" + ) + + context_docs_str = "\n".join( + f"{CODE_BLOCK_PAT.format(c.content)}" for c in context_chunks + ) + + prompt: list[BaseMessage] = [ + HumanMessage( + content="You are a question answering system that is constantly learning and improving. " + "You can process and comprehend vast amounts of text and utilize this knowledge " + "to provide accurate and detailed answers to diverse queries.\n" + f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}{GENERAL_SEP_PAT}" + f"You MUST use the following format:\n" + f"{CODE_BLOCK_PAT.format(cot_block)}\n" + f"Begin!\n{QUESTION_PAT} {query}" + ) + ] + return prompt + + def process_llm_output( + self, model_output: str, context_chunks: list[InferenceChunk] + ) -> tuple[DanswerAnswer, DanswerQuotes]: + logger.debug(model_output) + + answer_start = model_output.find('{"answer":') + # Only found thoughts, no final answer + if answer_start == -1: + return DanswerAnswer(answer=None), DanswerQuotes(quotes=[]) + + final_json = escape_newlines(model_output[answer_start:]) + + return process_answer( + final_json, context_chunks, is_json_prompt=self.is_json_output + ) + + def process_llm_token_stream( + self, tokens: Iterator[str], context_chunks: list[InferenceChunk] ) -> AnswerQuestionStreamReturn: - yield from process_model_tokens( - tokens=tokens, - context_docs=context_chunks, - is_json_prompt=True, + raise ValueError( + "This Scratchpad approach is not suitable for real time uses like streaming" ) @@ -172,17 +205,6 @@ class 
JsonChatQAUnshackledHandler(QAHandler): return prompt - def process_response( - self, - tokens: Iterator[str], - context_chunks: list[InferenceChunk], - ) -> AnswerQuestionStreamReturn: - yield from process_model_tokens( - tokens=tokens, - context_docs=context_chunks, - is_json_prompt=True, - ) - def _tiktoken_trim_chunks( chunks: list[InferenceChunk], max_chunk_toks: int = 512 @@ -212,7 +234,7 @@ class QABlock(QAModel): def warm_up_model(self) -> None: """This is called during server start up to load the models into memory in case the chosen LLM is not accessed via API""" - self._llm.stream("Ignore this!") + self._llm.invoke("Ignore this!") def answer_question( self, @@ -221,21 +243,9 @@ class QABlock(QAModel): ) -> AnswerQuestionReturn: trimmed_context_docs = _tiktoken_trim_chunks(context_docs) prompt = self._qa_handler.build_prompt(query, trimmed_context_docs) - tokens = self._llm.stream(prompt) + model_out = self._llm.invoke(prompt) - final_answer = "" - quotes = DanswerQuotes([]) - for output in self._qa_handler.process_response(tokens, trimmed_context_docs): - if output is None: - continue - - if isinstance(output, DanswerAnswerPiece): - if output.answer_piece: - final_answer += output.answer_piece - elif isinstance(output, DanswerQuotes): - quotes = output - - return DanswerAnswer(final_answer), quotes + return self._qa_handler.process_llm_output(model_out, trimmed_context_docs) def answer_question_stream( self, @@ -245,4 +255,6 @@ class QABlock(QAModel): trimmed_context_docs = _tiktoken_trim_chunks(context_docs) prompt = self._qa_handler.build_prompt(query, trimmed_context_docs) tokens = self._llm.stream(prompt) - yield from self._qa_handler.process_response(tokens, trimmed_context_docs) + yield from self._qa_handler.process_llm_token_stream( + tokens, trimmed_context_docs + ) diff --git a/backend/danswer/direct_qa/qa_prompts.py b/backend/danswer/direct_qa/qa_prompts.py index 9878aaba5f37..0e998dbc01b8 100644 --- a/backend/danswer/direct_qa/qa_prompts.py 
+++ b/backend/danswer/direct_qa/qa_prompts.py @@ -11,9 +11,12 @@ CODE_BLOCK_PAT = "\n```\n{}\n```\n" DOC_SEP_PAT = "---NEW DOCUMENT---" DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n" QUESTION_PAT = "Query:" +THOUGHT_PAT = "Thought:" ANSWER_PAT = "Answer:" +FINAL_ANSWER_PAT = "Final Answer:" UNCERTAINTY_PAT = "?" QUOTE_PAT = "Quote:" +QUOTES_PAT_PLURAL = "Quotes:" BASE_PROMPT = ( "Answer the query based on provided documents and quote relevant sections. " @@ -31,6 +34,17 @@ SAMPLE_JSON_RESPONSE = { "located on the Champ de Mars in France.", ], } + +EMPTY_SAMPLE_JSON = { + "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.", + "quotes": [ + "each quote must be UNEDITED and EXACTLY as shown in the provided documents!", + "HINT the quotes are not shown to the user!", + ], +} + +ANSWER_NOT_FOUND_JSON = '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}' + SAMPLE_RESPONSE_COT = ( "Let's think step by step. The user is asking for the " "location of the Eiffel Tower. 
The first document describes the Eiffel Tower " diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py index 598d8de44fc6..d7a9fdef32e5 100644 --- a/backend/danswer/direct_qa/qa_utils.py +++ b/backend/danswer/direct_qa/qa_utils.py @@ -80,12 +80,17 @@ def extract_answer_quotes_json( def separate_answer_quotes( - answer_raw: str, + answer_raw: str, is_json_prompt: bool = False ) -> Tuple[Optional[str], Optional[list[str]]]: try: model_raw_json = json.loads(answer_raw) return extract_answer_quotes_json(model_raw_json) except ValueError: + if is_json_prompt: + logger.error( + "Model did not output in json format as expected, " + "trying to parse it regardless" + ) return extract_answer_quotes_freeform(answer_raw) @@ -149,9 +154,11 @@ def match_quotes_to_docs( def process_answer( - answer_raw: str, chunks: list[InferenceChunk] + answer_raw: str, + chunks: list[InferenceChunk], + is_json_prompt: bool = True, ) -> tuple[DanswerAnswer, DanswerQuotes]: - answer, quote_strings = separate_answer_quotes(answer_raw) + answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt) if answer == UNCERTAINTY_PAT or not answer: if answer == UNCERTAINTY_PAT: logger.debug("Answer matched UNCERTAINTY_PAT") diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index 1054a453e793..20e930e5bd2c 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -5,34 +5,35 @@ from dataclasses import asdict from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT +from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT from danswer.llm.build import get_default_llm from danswer.server.models import QueryValidationResponse from danswer.server.utils import get_json_line 
QUERY_PAT = "QUERY: " -REASONING_PAT = "REASONING: " +REASONING_PAT = "THOUGHT: " ANSWERABLE_PAT = "ANSWERABLE: " def get_query_validation_messages(user_query: str) -> list[dict[str, str]]: - ambiguous_example = ( - f"{QUERY_PAT}What is this Slack channel about?\n" + ambiguous_example_question = f"{QUERY_PAT}What is this Slack channel about?" + ambiguous_example_answer = ( f"{REASONING_PAT}First the system must determine which Slack channel is " f"being referred to. By fetching 5 documents related to Slack channel contents, " f"it is not possible to determine which Slack channel the user is referring to.\n" f"{ANSWERABLE_PAT}False" ) - debug_example = ( - f"{QUERY_PAT}Danswer is unreachable.\n" + debug_example_question = f"{QUERY_PAT}Danswer is unreachable." + debug_example_answer = ( f"{REASONING_PAT}The system searches documents related to Danswer being " f"unreachable. Assuming the documents from search contains situations where " - f"Danswer is not reachable and contains a fix, the query is answerable.\n" + f"Danswer is not reachable and contains a fix, the query may be answerable.\n" f"{ANSWERABLE_PAT}True" ) - up_to_date_example = ( - f"{QUERY_PAT}How many customers do we have\n" + up_to_date_example_question = f"{QUERY_PAT}How many customers do we have" + up_to_date_example_answer = ( f"{REASONING_PAT}Assuming the retrieved documents contain up to date customer " f"acquisition information including a list of customers, the query can be answered. " f"It is important to note that if the information only exists in a database, " @@ -44,18 +45,18 @@ def get_query_validation_messages(user_query: str) -> list[dict[str, str]]: { "role": "user", "content": "You are a helper tool to determine if a query is answerable using retrieval augmented " - f"generation. A system will try to answer the user query based on ONLY the top 5 most relevant " - f"documents found from search. Sources contain both up to date and proprietary information for " - f"the specific team. 
For named or unknown entities, assume the search will always find " - f"consistent knowledge about the entity.\n" - f"The system is not tuned for writing code nor for interfacing with structured data " - f"via query languages like SQL.\n" - f"Determine if that system should attempt to answer. " - f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"\n' - f"{CODE_BLOCK_PAT.format(ambiguous_example)}\n" - f"{CODE_BLOCK_PAT.format(debug_example)}\n" - f"{CODE_BLOCK_PAT.format(up_to_date_example)}\n" - f"{CODE_BLOCK_PAT.format(QUERY_PAT + user_query)}\n", + f"generation.\nThe main system will try to answer the user query based on ONLY the top 5 most relevant " + f"documents found from search.\nSources contain both up to date and proprietary information for " + f"the specific team.\nFor named or unknown entities, assume the search will find " + f"relevant and consistent knowledge about the entity.\n" + f"The system is not tuned for writing code.\n" + f"The system is not tuned for interfacing with structured data via query languages like SQL.\n" + f"Determine if that system should attempt to answer.\n" + f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"\n{GENERAL_SEP_PAT}\n' + f"{ambiguous_example_question}{CODE_BLOCK_PAT.format(ambiguous_example_answer)}\n" + f"{debug_example_question}{CODE_BLOCK_PAT.format(debug_example_answer)}\n" + f"{up_to_date_example_question}{CODE_BLOCK_PAT.format(up_to_date_example_answer)}\n" + f"{QUERY_PAT + user_query}", }, ] @@ -103,7 +104,8 @@ def stream_query_answerability(user_query: str) -> Iterator[str]: if not reasoning_pat_found and REASONING_PAT in model_output: reasoning_pat_found = True - remaining = model_output[len(REASONING_PAT) :] + reason_ind = model_output.find(REASONING_PAT) + remaining = model_output[reason_ind + len(REASONING_PAT) :] if remaining: yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining))) continue diff --git a/backend/danswer/utils/text_processing.py 
b/backend/danswer/utils/text_processing.py index c27c323f0bb2..fbd4885121b3 100644 --- a/backend/danswer/utils/text_processing.py +++ b/backend/danswer/utils/text_processing.py @@ -4,6 +4,10 @@ import bs4 from bs4 import BeautifulSoup +def escape_newlines(s: str) -> str: + return re.sub(r"(?<!\\)\n", "\\\\n", s) + + def clean_model_quote(quote: str, trim_length: int) -> str: quote_clean = quote.strip() if quote_clean[0] == '"':