From 927dffecb560b125fd5769a89500896b640e77ee Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 2 Nov 2023 23:26:47 -0700 Subject: [PATCH] Prompt Layer Rework (#688) --- backend/danswer/chat/chat_llm.py | 4 +- backend/danswer/chat/chat_prompts.py | 4 +- backend/danswer/configs/app_configs.py | 3 + backend/danswer/configs/constants.py | 14 - backend/danswer/direct_qa/answer_question.py | 2 +- .../direct_qa/{llm_utils.py => factory.py} | 26 +- backend/danswer/direct_qa/interfaces.py | 1 - backend/danswer/direct_qa/qa_block.py | 190 ++++-------- backend/danswer/direct_qa/qa_prompts.py | 283 ------------------ backend/danswer/direct_qa/qa_utils.py | 6 +- backend/danswer/llm/custom_llm.py | 4 + backend/danswer/llm/gpt_4_all.py | 10 + backend/danswer/llm/interfaces.py | 4 + backend/danswer/llm/utils.py | 38 ++- backend/danswer/main.py | 3 +- backend/danswer/prompts/__init__.py | 0 backend/danswer/prompts/constants.py | 11 + backend/danswer/prompts/direct_qa_prompts.py | 111 +++++++ .../danswer/prompts/secondary_llm_flows.py | 96 ++++++ .../secondary_llm_flows/answer_validation.py | 30 +- .../secondary_llm_flows/extract_filters.py | 15 +- .../secondary_llm_flows/query_validation.py | 67 +---- backend/danswer/server/chat_backend.py | 6 +- backend/danswer/server/manage.py | 2 +- .../docker_compose/docker-compose.dev.yml | 3 + 25 files changed, 383 insertions(+), 550 deletions(-) rename backend/danswer/direct_qa/{llm_utils.py => factory.py} (52%) delete mode 100644 backend/danswer/direct_qa/qa_prompts.py create mode 100644 backend/danswer/prompts/__init__.py create mode 100644 backend/danswer/prompts/constants.py create mode 100644 backend/danswer/prompts/direct_qa_prompts.py create mode 100644 backend/danswer/prompts/secondary_llm_flows.py diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py index 0ebceb7129d8..695e6ad6f6f3 100644 --- a/backend/danswer/chat/chat_llm.py +++ b/backend/danswer/chat/chat_llm.py @@ -35,7 +35,7 @@ from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import InferenceChunk from danswer.llm.factory import get_default_llm from danswer.llm.interfaces import LLM -from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import get_default_llm_token_encode from danswer.llm.utils import translate_danswer_msg_to_langchain from danswer.search.access_filters import build_access_filters_for_user from danswer.search.models import IndexFilters @@ -259,7 +259,7 @@ def llm_contextless_chat_answer( prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages] if system_text: - tokenizer = tokenizer or get_default_llm_tokenizer() + tokenizer = tokenizer or get_default_llm_token_encode() system_tokens = len(tokenizer(system_text)) system_msg = SystemMessage(content=system_text) diff --git a/backend/danswer/chat/chat_prompts.py b/backend/danswer/chat/chat_prompts.py index 2dfc18552f6b..97d361b93f59 100644 --- a/backend/danswer/chat/chat_prompts.py +++ b/backend/danswer/chat/chat_prompts.py @@ -2,12 +2,12 @@ from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage -from danswer.configs.constants import CODE_BLOCK_PAT from danswer.configs.constants import MessageType from danswer.db.models import ChatMessage from danswer.db.models import ToolInfo from danswer.indexing.models import InferenceChunk from danswer.llm.utils import translate_danswer_msg_to_langchain +from 
danswer.prompts.constants import CODE_BLOCK_PAT DANSWER_TOOL_NAME = "Current Search" DANSWER_TOOL_DESCRIPTION = ( @@ -176,7 +176,7 @@ def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str: return "No Results Found" return "\n".join( - f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(chunk.content)}" + f"DOCUMENT {ind}:\n{CODE_BLOCK_PAT.format(chunk.content)}\n" for ind, chunk in enumerate(chunks, start=1) ) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 538a4766be1e..2c952a07e69f 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -212,6 +212,9 @@ DYNAMIC_CONFIG_STORE = os.environ.get( "DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore" ) DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage") +# For selecting a different LLM question-answering prompt format +# Valid values: default, cot, weak +QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None # notset, debug, info, warning, error, or critical LOG_LEVEL = os.environ.get("LOG_LEVEL", "info") # NOTE: Currently only supported in the Confluence and Google Drive connectors + diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 98d199e17d18..46752517a504 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -36,20 +36,6 @@ ID_SEPARATOR = ":;:" DEFAULT_BOOST = 0 SESSION_KEY = "session" -# Prompt building constants: -GENERAL_SEP_PAT = "\n-----\n" -CODE_BLOCK_PAT = "\n```\n{}\n```\n" -DOC_SEP_PAT = "---NEW DOCUMENT---" -DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n" -QUESTION_PAT = "Query:" -THOUGHT_PAT = "Thought:" -ANSWER_PAT = "Answer:" -FINAL_ANSWER_PAT = "Final Answer:" -UNCERTAINTY_PAT = "?" 
-QUOTE_PAT = "Quote:" -QUOTES_PAT_PLURAL = "Quotes:" -INVALID_PAT = "Invalid:" - class DocumentSource(str, Enum): SLACK = "slack" diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index d984c5058e8b..b4e64c713f5e 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -9,9 +9,9 @@ from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.constants import IGNORE_FOR_QA from danswer.db.feedback import create_query_event from danswer.db.models import User +from danswer.direct_qa.factory import get_default_qa_model from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import StreamingError -from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.direct_qa.models import LLMMetricsContainer from danswer.direct_qa.qa_utils import get_usable_chunks from danswer.document_index.factory import get_default_document_index diff --git a/backend/danswer/direct_qa/llm_utils.py b/backend/danswer/direct_qa/factory.py similarity index 52% rename from backend/danswer/direct_qa/llm_utils.py rename to backend/danswer/direct_qa/factory.py index 80f813e7ef85..3225f903356d 100644 --- a/backend/danswer/direct_qa/llm_utils.py +++ b/backend/danswer/direct_qa/factory.py @@ -1,21 +1,35 @@ +from danswer.configs.app_configs import QA_PROMPT_OVERRIDE from danswer.configs.app_configs import QA_TIMEOUT from danswer.direct_qa.interfaces import QAModel from danswer.direct_qa.qa_block import QABlock from danswer.direct_qa.qa_block import QAHandler from danswer.direct_qa.qa_block import SingleMessageQAHandler from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler +from danswer.direct_qa.qa_block import WeakLLMQAHandler from danswer.llm.factory import get_default_llm from danswer.utils.logger import setup_logger logger = setup_logger() -# TODO introduce the prompt choice parameter -def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler: - return ( - SingleMessageQAHandler() if real_time_flow else SingleMessageScratchpadHandler() - ) - # return SimpleChatQAHandler() +def get_default_qa_handler( + real_time_flow: bool = True, + user_selection: str | None = QA_PROMPT_OVERRIDE, +) -> QAHandler: + if user_selection: + if user_selection.lower() == "default": + return SingleMessageQAHandler() + if user_selection.lower() == "cot": + return SingleMessageScratchpadHandler() + if user_selection.lower() == "weak": + return WeakLLMQAHandler() + + raise ValueError("Invalid Question-Answering prompt selected") + + if not real_time_flow: + return SingleMessageScratchpadHandler() + + return SingleMessageQAHandler() def get_default_qa_model( diff --git a/backend/danswer/direct_qa/interfaces.py b/backend/danswer/direct_qa/interfaces.py index 60897023fa1d..688d0f002bd1 100644 --- a/backend/danswer/direct_qa/interfaces.py +++ b/backend/danswer/direct_qa/interfaces.py @@ -52,7 +52,6 @@ class QAModel: def requires_api_key(self) -> bool: """Is this model protected by security features Does it need an api key to access the model for inference""" - # TODO, this should be false for custom request model and gpt4all return True def warm_up_model(self) -> None: diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py index b96736fba5d5..0ea404c7daf2 100644 --- a/backend/danswer/direct_qa/qa_block.py +++ b/backend/danswer/direct_qa/qa_block.py @@ -1,38 +1,28 @@ import abc -import json import re from collections.abc 
import Callable from collections.abc import Iterator -from copy import copy -import tiktoken -from langchain.schema.messages import AIMessage from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage -from langchain.schema.messages import SystemMessage -from danswer.configs.constants import CODE_BLOCK_PAT -from danswer.configs.constants import GENERAL_SEP_PAT -from danswer.configs.constants import QUESTION_PAT -from danswer.configs.constants import THOUGHT_PAT -from danswer.configs.constants import UNCERTAINTY_PAT from danswer.direct_qa.interfaces import AnswerQuestionReturn from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn from danswer.direct_qa.interfaces import DanswerAnswer from danswer.direct_qa.interfaces import DanswerQuotes from danswer.direct_qa.interfaces import QAModel from danswer.direct_qa.models import LLMMetricsContainer -from danswer.direct_qa.qa_prompts import EMPTY_SAMPLE_JSON -from danswer.direct_qa.qa_prompts import JsonChatProcessor -from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor from danswer.direct_qa.qa_utils import process_answer from danswer.direct_qa.qa_utils import process_model_tokens from danswer.indexing.models import InferenceChunk from danswer.llm.interfaces import LLM from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import dict_based_prompt_to_langchain_prompt -from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import str_prompt_to_langchain_prompt +from danswer.llm.utils import get_default_llm_token_encode +from danswer.llm.utils import tokenizer_trim_chunks +from danswer.prompts.constants import CODE_BLOCK_PAT +from danswer.prompts.direct_qa_prompts import COT_PROMPT +from danswer.prompts.direct_qa_prompts import JSON_PROMPT +from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT from danswer.utils.logger import setup_logger from danswer.utils.text_processing import clean_up_code_blocks from danswer.utils.text_processing import escape_newlines @@ -41,10 +31,6 @@ logger = setup_logger() class QAHandler(abc.ABC): - """Evolution of the `PromptProcessor` - handles both building the prompt and - processing the response. These are necessarily coupled, since the prompt determines - the response format (and thus how it should be parsed into an answer + quotes).""" - @abc.abstractmethod def build_prompt( self, query: str, context_chunks: list[InferenceChunk] @@ -52,9 +38,13 @@ class QAHandler(abc.ABC): raise NotImplementedError @property + @abc.abstractmethod def is_json_output(self) -> bool: - """Does the model expected to output a valid json""" - return True + """Does the model output a valid json with answer and quotes keys? Most flows with a + capable model should output a json. This hints to the model that the output is used + with a downstream system rather than freeform creative output. 
Most models should be + finetuned to recognize this.""" + raise NotImplementedError def process_llm_output( self, model_output: str, context_chunks: list[InferenceChunk] @@ -73,18 +63,13 @@ class QAHandler(abc.ABC): ) -class JsonChatQAHandler(QAHandler): - def build_prompt( - self, query: str, context_chunks: list[InferenceChunk] - ) -> list[BaseMessage]: - return dict_based_prompt_to_langchain_prompt( - JsonChatProcessor.fill_prompt( - question=query, chunks=context_chunks, include_metadata=False - ) - ) +class WeakLLMQAHandler(QAHandler): + """Since Danswer supports a variety of LLMs, this less demanding prompt is provided + as an option to use with weaker LLMs such as small version, low float precision, quantized, + or distilled models. It only uses one context document and has very weak requirements of + output format. + """ - -class SimpleChatQAHandler(QAHandler): @property def is_json_output(self) -> bool: return False @@ -92,67 +77,51 @@ class SimpleChatQAHandler(QAHandler): def build_prompt( self, query: str, context_chunks: list[InferenceChunk] ) -> list[BaseMessage]: - return str_prompt_to_langchain_prompt( - WeakModelFreeformProcessor.fill_prompt( - question=query, - chunks=context_chunks, - include_metadata=False, - ) - ) + message = WEAK_LLM_PROMPT.format(single_reference_doc=context_chunks[0].content) + + return [HumanMessage(content=message)] class SingleMessageQAHandler(QAHandler): + @property + def is_json_output(self) -> bool: + return True + def build_prompt( self, query: str, context_chunks: list[InferenceChunk] ) -> list[BaseMessage]: context_docs_str = "\n".join( - f"{CODE_BLOCK_PAT.format(c.content)}" for c in context_chunks + f"\n{CODE_BLOCK_PAT.format(c.content)}\n" for c in context_chunks ) - prompt: list[BaseMessage] = [ - HumanMessage( - content="You are a question answering system that is constantly learning and improving. " - "You can process and comprehend vast amounts of text and utilize this knowledge " - "to provide accurate and detailed answers to diverse queries.\n" - "You ALWAYS responds with only a json containing an answer and quotes that support the answer.\n" - "Your responses are as INFORMATIVE and DETAILED as possible.\n" - f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}" - f"{GENERAL_SEP_PAT}Sample response:" - f"{CODE_BLOCK_PAT.format(json.dumps(EMPTY_SAMPLE_JSON))}\n" - f"{QUESTION_PAT} {query}\n" - "Hint: Make the answer as DETAILED as possible and respond in JSON format!\n" - "Quotes MUST be EXACT substrings from provided documents!" - ) - ] + single_message = JSON_PROMPT.format( + context_docs_str=context_docs_str, user_query=query + ) + + prompt: list[BaseMessage] = [HumanMessage(content=single_message)] return prompt class SingleMessageScratchpadHandler(QAHandler): + @property + def is_json_output(self) -> bool: + # Even though the full LLM output isn't a valid json + # only the valid json portion is kept and passed along + # therefore it is treated as a json output + return True + def build_prompt( self, query: str, context_chunks: list[InferenceChunk] ) -> list[BaseMessage]: - cot_block = ( - f"{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.\n\n" - f"{json.dumps(EMPTY_SAMPLE_JSON)}" - ) - context_docs_str = "\n".join( - f"{CODE_BLOCK_PAT.format(c.content)}" for c in context_chunks + f"\n{CODE_BLOCK_PAT.format(c.content)}\n" for c in context_chunks ) - prompt: list[BaseMessage] = [ - HumanMessage( - content="You are a question answering system that is constantly learning and improving. 
" - "You can process and comprehend vast amounts of text and utilize this knowledge " - "to provide accurate and detailed answers to diverse queries.\n" - f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}{GENERAL_SEP_PAT}" - f"You MUST respond in the following format:" - f"{CODE_BLOCK_PAT.format(cot_block)}\n" - f"{QUESTION_PAT} {query}\n" - "Hint: Make the answer as detailed as possible and use a JSON! " - "Quotes can ONLY be EXACT substrings from provided documents!" - ) - ] + single_message = COT_PROMPT.format( + context_docs_str=context_docs_str, user_query=query + ) + + prompt: list[BaseMessage] = [HumanMessage(content=single_message)] return prompt def process_llm_output( @@ -175,77 +144,26 @@ class SingleMessageScratchpadHandler(QAHandler): def process_llm_token_stream( self, tokens: Iterator[str], context_chunks: list[InferenceChunk] ) -> AnswerQuestionStreamReturn: + # Can be supported but the parsing is more involved, not handling until needed raise ValueError( "This Scratchpad approach is not suitable for real time uses like streaming" ) -class JsonChatQAUnshackledHandler(QAHandler): - def build_prompt( - self, query: str, context_chunks: list[InferenceChunk] - ) -> list[BaseMessage]: - prompt: list[BaseMessage] = [] - - complete_answer_not_found_response = ( - '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}' - ) - prompt.append( - SystemMessage( - content=( - "Use the following pieces of context to answer the users question. Your response " - "should be in JSON format and contain an answer and (optionally) quotes that help support the answer. " - "Your responses should be informative, detailed, and consider all possibilities and edge cases. " - f"If you don't know the answer, respond with '{complete_answer_not_found_response}'\n" - f"Sample response:\n\n{json.dumps(EMPTY_SAMPLE_JSON)}" - ) - ) - ) - prompt.append( - SystemMessage( - content='Start by reading the following documents and responding with "Acknowledged".' - ) - ) - for chunk in context_chunks: - prompt.append(SystemMessage(content=chunk.content)) - prompt.append(AIMessage(content="Acknowledged")) - - prompt.append(HumanMessage(content=f"Question: {query}\n")) - - return prompt - - -def _tiktoken_trim_chunks( - chunks: list[InferenceChunk], max_chunk_toks: int = 512 -) -> list[InferenceChunk]: - """Edit chunks that have too high token count. 
Generally due to parsing issues or - characters from another language that are 1 char = 1 token - Trimming by tokens leads to information loss but currently no better way of handling - NOTE: currently gpt-3.5 / gpt-4 tokenizer across all LLMs currently - TODO: make "chunk modification" its own step in the pipeline - """ - encoder = tiktoken.get_encoding("cl100k_base") - new_chunks = copy(chunks) - for ind, chunk in enumerate(new_chunks): - tokens = encoder.encode(chunk.content) - if len(tokens) > max_chunk_toks: - new_chunk = copy(chunk) - new_chunk.content = encoder.decode(tokens[:max_chunk_toks]) - new_chunks[ind] = new_chunk - return new_chunks - - class QABlock(QAModel): def __init__(self, llm: LLM, qa_handler: QAHandler) -> None: self._llm = llm self._qa_handler = qa_handler + @property + def requires_api_key(self) -> bool: + return self._llm.requires_api_key + def warm_up_model(self) -> None: """This is called during server start up to load the models into memory in case the chosen LLM is not accessed via API""" if self._llm.requires_warm_up: - logger.info( - "Warming up LLM, this should only run for in memory LLMs like GPT4All" - ) + logger.info("Warming up LLM with a first inference") self._llm.invoke("Ignore this!") def answer_question( @@ -254,7 +172,7 @@ class QABlock(QAModel): context_docs: list[InferenceChunk], metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, ) -> AnswerQuestionReturn: - trimmed_context_docs = _tiktoken_trim_chunks(context_docs) + trimmed_context_docs = tokenizer_trim_chunks(context_docs) prompt = self._qa_handler.build_prompt(query, trimmed_context_docs) model_out = self._llm.invoke(prompt) @@ -262,14 +180,14 @@ class QABlock(QAModel): prompt_tokens = sum( [ check_number_of_tokens( - text=p.content, encode_fn=get_default_llm_tokenizer() + text=p.content, encode_fn=get_default_llm_token_encode() ) for p in prompt ] ) response_tokens = check_number_of_tokens( - text=model_out, encode_fn=get_default_llm_tokenizer() + text=model_out, encode_fn=get_default_llm_token_encode() ) metrics_callback( @@ -285,7 +203,7 @@ class QABlock(QAModel): query: str, context_docs: list[InferenceChunk], ) -> AnswerQuestionStreamReturn: - trimmed_context_docs = _tiktoken_trim_chunks(context_docs) + trimmed_context_docs = tokenizer_trim_chunks(context_docs) prompt = self._qa_handler.build_prompt(query, trimmed_context_docs) tokens = self._llm.stream(prompt) yield from self._qa_handler.process_llm_token_stream( diff --git a/backend/danswer/direct_qa/qa_prompts.py b/backend/danswer/direct_qa/qa_prompts.py deleted file mode 100644 index 4439afea9b1d..000000000000 --- a/backend/danswer/direct_qa/qa_prompts.py +++ /dev/null @@ -1,283 +0,0 @@ -import abc -import json - -from danswer.configs.constants import ANSWER_PAT -from danswer.configs.constants import DOC_CONTENT_START_PAT -from danswer.configs.constants import DOC_SEP_PAT -from danswer.configs.constants import DocumentSource -from danswer.configs.constants import GENERAL_SEP_PAT -from danswer.configs.constants import QUESTION_PAT -from danswer.configs.constants import QUOTE_PAT -from danswer.configs.constants import UNCERTAINTY_PAT -from danswer.connectors.factory import identify_connector_class -from danswer.indexing.models import InferenceChunk - - -BASE_PROMPT = ( - "Answer the query based on provided documents and quote relevant sections. " - "Respond with a json containing a concise answer and up to three most relevant quotes from the documents. " - 'Respond with "?" 
for the answer if the query cannot be answered based on the documents. ' - "The quotes must be EXACT substrings from the documents." -) - -EMPTY_SAMPLE_JSON = { - "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.", - "quotes": [ - "each quote must be UNEDITED and EXACTLY as shown in the context documents!", - "HINT, quotes are not shown to the user!", - ], -} - - -def _append_acknowledge_doc_messages( - current_messages: list[dict[str, str]], new_chunk_content: str -) -> list[dict[str, str]]: - updated_messages = current_messages.copy() - updated_messages.extend( - [ - { - "role": "user", - "content": new_chunk_content, - }, - {"role": "assistant", "content": "Acknowledged"}, - ] - ) - return updated_messages - - -def _add_metadata_section( - prompt_current: str, - chunk: InferenceChunk, - prepend_tab: bool = False, - include_sep: bool = False, -) -> str: - """ - Inserts a metadata section at the start of a document, providing additional context to the upcoming document. - - Parameters: - prompt_current (str): The existing content of the prompt so far with. - chunk (InferenceChunk): An object that contains the document's source type and metadata information to be added. - prepend_tab (bool, optional): If set to True, a tab character is added at the start of each line in the metadata - section for consistent spacing for LLM. - include_sep (bool, optional): If set to True, includes default section separator pattern at the end of the metadata - section. - - Returns: - str: The prompt with the newly added metadata section. - """ - - def _prepend(s: str, ppt: bool) -> str: - return "\t" + s if ppt else s - - prompt_current += _prepend(f"DOCUMENT SOURCE: {chunk.source_type}\n", prepend_tab) - if chunk.metadata: - prompt_current += _prepend("METADATA:\n", prepend_tab) - connector_class = identify_connector_class(DocumentSource(chunk.source_type)) - for metadata_line in connector_class.parse_metadata(chunk.metadata): - prompt_current += _prepend(f"\t{metadata_line}\n", prepend_tab) - prompt_current += _prepend(DOC_CONTENT_START_PAT, prepend_tab) - if include_sep: - prompt_current += GENERAL_SEP_PAT - return prompt_current - - -class PromptProcessor(abc.ABC): - """Take the most relevant chunks and fills out a LLM prompt using the chunk contents - and optionally metadata about the chunk""" - - @property - @abc.abstractmethod - def specifies_json_output(self) -> bool: - raise NotImplementedError - - @staticmethod - @abc.abstractmethod - def fill_prompt( - question: str, chunks: list[InferenceChunk], include_metadata: bool = False - ) -> str | list[dict[str, str]]: - raise NotImplementedError - - -class NonChatPromptProcessor(PromptProcessor): - @staticmethod - @abc.abstractmethod - def fill_prompt( - question: str, chunks: list[InferenceChunk], include_metadata: bool = False - ) -> str: - raise NotImplementedError - - -class ChatPromptProcessor(PromptProcessor): - @staticmethod - @abc.abstractmethod - def fill_prompt( - question: str, chunks: list[InferenceChunk], include_metadata: bool = False - ) -> list[dict[str, str]]: - raise NotImplementedError - - -class JsonProcessor(NonChatPromptProcessor): - @property - def specifies_json_output(self) -> bool: - return True - - @staticmethod - def fill_prompt( - question: str, chunks: list[InferenceChunk], include_metadata: bool = False - ) -> str: - prompt = ( - BASE_PROMPT + f" Sample response:\n{json.dumps(EMPTY_SAMPLE_JSON)}\n\n" - f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n' - ) - - 
for chunk in chunks: - prompt += f"\n\n{DOC_SEP_PAT}\n" - if include_metadata: - prompt = _add_metadata_section( - prompt, chunk, prepend_tab=False, include_sep=True - ) - - prompt += chunk.content - - prompt += "\n\n---\n\n" - prompt += f"{QUESTION_PAT}\n{question}\n" - return prompt - - -class JsonChatProcessor(ChatPromptProcessor): - @property - def specifies_json_output(self) -> bool: - return True - - @staticmethod - def fill_prompt( - question: str, - chunks: list[InferenceChunk], - include_metadata: bool = False, - ) -> list[dict[str, str]]: - metadata_prompt_section = ( - "with metadata and contents " if include_metadata else "" - ) - intro_msg = ( - f"You are a Question Answering assistant that answers queries " - f"based on the provided most relevant documents.\n" - f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".' - ) - - complete_answer_not_found_response = ( - '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}' - ) - task_msg = ( - "Now answer the next user query based on documents above and quote relevant sections.\n" - "Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n" - "All quotes MUST be EXACT substrings from provided documents.\n" - "Your responses should be informative and concise.\n" - "You MUST prioritize information from provided documents over internal knowledge.\n" - "If the query cannot be answered based on the documents, respond with " - f"{complete_answer_not_found_response}\n" - "If the query requires aggregating the number of documents, respond with " - '{"answer": "Aggregations not supported", "quotes": []}\n' - f"Sample response:\n{json.dumps(EMPTY_SAMPLE_JSON)}" - ) - messages = [{"role": "system", "content": intro_msg}] - for chunk in chunks: - full_context = "" - if include_metadata: - full_context = _add_metadata_section( - full_context, chunk, prepend_tab=False, include_sep=False - ) - full_context += chunk.content - messages = _append_acknowledge_doc_messages(messages, full_context) - messages.append({"role": "system", "content": task_msg}) - - messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"}) - - return messages - - -class WeakModelFreeformProcessor(NonChatPromptProcessor): - """Avoid using this one if the model is capable of using another prompt - Intended for models that can't follow complex instructions or have short context windows - This prompt only uses 1 reference document chunk - """ - - @property - def specifies_json_output(self) -> bool: - return False - - @staticmethod - def fill_prompt( - question: str, chunks: list[InferenceChunk], include_metadata: bool = False - ) -> str: - first_chunk_content = chunks[0].content if chunks else "No Document Provided" - - prompt = ( - f"Reference Document:\n{first_chunk_content}\n{GENERAL_SEP_PAT}" - f"Answer the user query below based on the reference document above. " - f'Respond with an "{ANSWER_PAT}" section and ' - f'as many "{QUOTE_PAT}" sections as needed to support the answer.' 
- f"\n{GENERAL_SEP_PAT}" - f"{QUESTION_PAT} {question}\n" - f"{ANSWER_PAT}" - ) - - return prompt - - -class WeakChatModelFreeformProcessor(ChatPromptProcessor): - """Avoid using this one if the model is capable of using another prompt - Intended for models that can't follow complex instructions or have short context windows - This prompt only uses 1 reference document chunk - """ - - @property - def specifies_json_output(self) -> bool: - return False - - @staticmethod - def fill_prompt( - question: str, chunks: list[InferenceChunk], include_metadata: bool = False - ) -> list[dict[str, str]]: - first_chunk_content = chunks[0].content if chunks else "No Document Provided" - intro_msg = ( - f"You are a question answering assistant. " - f'Respond to the query with an "{ANSWER_PAT}" section and ' - f'as many "{QUOTE_PAT}" sections as needed to support the answer. ' - f"Answer the user query based on the following document:\n\n{first_chunk_content}" - ) - - messages = [{"role": "system", "content": intro_msg}] - - user_query = f"{QUESTION_PAT} {question}" - messages.append({"role": "user", "content": user_query}) - - return messages - - -# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may revisit in future - - -class FreeformProcessor(NonChatPromptProcessor): - @property - def specifies_json_output(self) -> bool: - return False - - @staticmethod - def fill_prompt( - question: str, chunks: list[InferenceChunk], include_metadata: bool = False - ) -> str: - prompt = ( - f"Answer the query based on the documents below and quote the documents segments containing the answer. " - f'Respond with one "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as is relevant. ' - f'Start each quote with "{QUOTE_PAT}". Each quote should be a single continuous segment from a document. ' - f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". 
' - f'Each document is prefixed with "{DOC_SEP_PAT}".\n\n' - ) - - for chunk in chunks: - prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}" - - prompt += "\n\n---\n\n" - prompt += f"{QUESTION_PAT}\n{question}\n" - prompt += f"{ANSWER_PAT}\n" - return prompt diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py index 94d996cc5bf2..7f45fceee031 100644 --- a/backend/danswer/direct_qa/qa_utils.py +++ b/backend/danswer/direct_qa/qa_utils.py @@ -15,11 +15,11 @@ from danswer.direct_qa.interfaces import DanswerAnswer from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import DanswerQuote from danswer.direct_qa.interfaces import DanswerQuotes -from danswer.direct_qa.qa_prompts import ANSWER_PAT -from danswer.direct_qa.qa_prompts import QUOTE_PAT -from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT from danswer.indexing.models import InferenceChunk from danswer.llm.utils import check_number_of_tokens +from danswer.prompts.constants import ANSWER_PAT +from danswer.prompts.constants import QUOTE_PAT +from danswer.prompts.constants import UNCERTAINTY_PAT from danswer.utils.logger import setup_logger from danswer.utils.text_processing import clean_model_quote from danswer.utils.text_processing import clean_up_code_blocks diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py index e39b10049848..5de9b0d97cd1 100644 --- a/backend/danswer/llm/custom_llm.py +++ b/backend/danswer/llm/custom_llm.py @@ -21,6 +21,10 @@ class CustomModelServer(LLM): https://medium.com/@yuhongsun96/how-to-augment-llms-with-private-data-29349bd8ae9f """ + @property + def requires_api_key(self) -> bool: + return False + def __init__( self, # Not used here but you probably want a model server that isn't completely open diff --git a/backend/danswer/llm/gpt_4_all.py b/backend/danswer/llm/gpt_4_all.py index 57aeecc32609..316c4e7aacfc 100644 --- a/backend/danswer/llm/gpt_4_all.py +++ b/backend/danswer/llm/gpt_4_all.py @@ -39,6 +39,16 @@ class DanswerGPT4All(LLM): """Option to run an LLM locally, however this is significantly slower and answers tend to be much worse""" + @property + def requires_warm_up(self) -> bool: + """GPT4All models are lazy loaded, load them on server start so that the + first inference isn't extremely delayed""" + return True + + @property + def requires_api_key(self) -> bool: + return False + def __init__( self, timeout: int, diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py index 081762ea5bc2..9e56b0934741 100644 --- a/backend/danswer/llm/interfaces.py +++ b/backend/danswer/llm/interfaces.py @@ -18,6 +18,10 @@ class LLM(abc.ABC): """Is this model running in memory and needs an initial call to warm it up?""" return False + @property + def requires_api_key(self) -> bool: + return True + @abc.abstractmethod def invoke(self, prompt: LanguageModelInput) -> str: raise NotImplementedError diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 5534a2f09ec0..1b5541c75202 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -1,5 +1,6 @@ from collections.abc import Callable from collections.abc import Iterator +from copy import copy from typing import Any from typing import cast @@ -13,30 +14,61 @@ from langchain.schema.messages import BaseMessage from langchain.schema.messages import BaseMessageChunk from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage +from tiktoken.core import 
Encoding from danswer.configs.app_configs import LOG_LEVEL from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY from danswer.configs.constants import MessageType +from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.configs.model_configs import GEN_AI_API_KEY from danswer.db.models import ChatMessage from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError +from danswer.indexing.models import InferenceChunk from danswer.llm.interfaces import LLM from danswer.utils.logger import setup_logger logger = setup_logger() -_LLM_TOKENIZER: Callable[[str], Any] | None = None +_LLM_TOKENIZER: Any = None +_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None -def get_default_llm_tokenizer() -> Callable: +def get_default_llm_tokenizer() -> Any: """Currently only supports the OpenAI default tokenizer: tiktoken""" global _LLM_TOKENIZER if _LLM_TOKENIZER is None: - _LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base").encode + _LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base") return _LLM_TOKENIZER +def get_default_llm_token_encode() -> Callable[[str], Any]: + global _LLM_TOKENIZER_ENCODE + if _LLM_TOKENIZER_ENCODE is None: + tokenizer = get_default_llm_tokenizer() + if isinstance(tokenizer, Encoding): + return tokenizer.encode # type: ignore + + # Currently only supports OpenAI encoder + raise ValueError("Invalid Encoder selected") + + return _LLM_TOKENIZER_ENCODE + + +def tokenizer_trim_chunks( + chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE +) -> list[InferenceChunk]: + tokenizer = get_default_llm_tokenizer() + new_chunks = copy(chunks) + for ind, chunk in enumerate(new_chunks): + tokens = tokenizer.encode(chunk.content) + if len(tokens) > max_chunk_toks: + new_chunk = copy(chunk) + new_chunk.content = tokenizer.decode(tokens[:max_chunk_toks]) + new_chunks[ind] = new_chunk + return new_chunks + + def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage: if ( msg.message_type == MessageType.SYSTEM diff --git a/backend/danswer/main.py b/backend/danswer/main.py index ef98a360e22e..634e3fb2a63d 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -30,7 +30,7 @@ from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.configs.model_configs import SKIP_RERANKING from danswer.db.credentials import create_initial_public_credential -from danswer.direct_qa.llm_utils import get_default_qa_model +from danswer.direct_qa.factory import get_default_qa_model from danswer.document_index.factory import get_default_document_index from danswer.server.cc_pair.api import router as cc_pair_router from danswer.server.chat_backend import router as chat_router @@ -179,6 +179,7 @@ def get_application() -> FastAPI: logger.info("Warming up local NLP models.") warm_up_models() qa_model = get_default_qa_model() + # This is for the LLM, most LLMs will not need warming up qa_model.warm_up_model() logger.info("Verifying query preprocessing (NLTK) data is downloaded") diff --git a/backend/danswer/prompts/__init__.py b/backend/danswer/prompts/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/backend/danswer/prompts/constants.py b/backend/danswer/prompts/constants.py new file mode 100644 index 000000000000..e1ba5c47b1eb --- /dev/null +++ b/backend/danswer/prompts/constants.py @@ -0,0 +1,11 @@ +GENERAL_SEP_PAT = "-----" +CODE_BLOCK_PAT = 
"```\n{}\n```" +QUESTION_PAT = "Query:" +THOUGHT_PAT = "Thought:" +ANSWER_PAT = "Answer:" +ANSWERABLE_PAT = "Answerable:" +FINAL_ANSWER_PAT = "Final Answer:" +UNCERTAINTY_PAT = "?" +QUOTE_PAT = "Quote:" +QUOTES_PAT_PLURAL = "Quotes:" +INVALID_PAT = "Invalid:" diff --git a/backend/danswer/prompts/direct_qa_prompts.py b/backend/danswer/prompts/direct_qa_prompts.py new file mode 100644 index 000000000000..0a41f2ae0fb9 --- /dev/null +++ b/backend/danswer/prompts/direct_qa_prompts.py @@ -0,0 +1,111 @@ +import json + +from danswer.prompts.constants import ANSWER_PAT +from danswer.prompts.constants import GENERAL_SEP_PAT +from danswer.prompts.constants import QUESTION_PAT +from danswer.prompts.constants import QUOTE_PAT +from danswer.prompts.constants import THOUGHT_PAT +from danswer.prompts.constants import UNCERTAINTY_PAT + + +QA_HEADER = """ +You are a question answering system that is constantly learning and improving. +You can process and comprehend vast amounts of text and utilize this knowledge to provide \ +accurate and detailed answers to diverse queries. +""".strip() + + +REQUIRE_JSON = """ +You ALWAYS responds with only a json containing an answer and quotes that support the answer. +Your responses are as INFORMATIVE and DETAILED as possible. +""".strip() + + +JSON_HELPFUL_HINT = """ +Hint: Make the answer as DETAILED as possible and respond in JSON format! \ +Quotes MUST be EXACT substrings from provided documents! +""".strip() + + +# This has to be doubly escaped due to json containing { } which are also used for format strings +EMPTY_SAMPLE_JSON = { + "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.", + "quotes": [ + "each quote must be UNEDITED and EXACTLY as shown in the context documents!", + "HINT, quotes are not shown to the user!", + ], +} + + +ANSWER_NOT_FOUND_RESPONSE = f'{{"answer": "{UNCERTAINTY_PAT}", "quotes": []}}' + + +# Default json prompt which can reference multiple docs and provide answer + quotes +JSON_PROMPT = f""" +{QA_HEADER} +{REQUIRE_JSON} +{GENERAL_SEP_PAT} +CONTEXT: +{{context_docs_str}} +{GENERAL_SEP_PAT} +SAMPLE_RESPONSE: +``` +{{{json.dumps(EMPTY_SAMPLE_JSON)}}} +``` +{QUESTION_PAT} {{user_query}} +{JSON_HELPFUL_HINT} +""".strip() + + +# Default chain-of-thought style json prompt which uses multiple docs +# This one has a section for the LLM to output some non-answer "thoughts" +# COT (chain-of-thought) flow basically +COT_PROMPT = f""" +{QA_HEADER} +{GENERAL_SEP_PAT} +CONTEXT: +{{context_docs_str}} +{GENERAL_SEP_PAT} +You MUST respond in the following format: +``` +{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer. + +{{{json.dumps(EMPTY_SAMPLE_JSON)}}} +``` + +{QUESTION_PAT} {{user_query}} +{JSON_HELPFUL_HINT} +""".strip() + + +# For weak LLM which only takes one chunk and cannot output json +WEAK_LLM_PROMPT = f""" +Respond to the user query using a reference document. +{GENERAL_SEP_PAT} +Reference Document: +{{single_reference_doc}} +{GENERAL_SEP_PAT} +Answer the user query below based on the reference document above. +Respond with an "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as needed to support the answer.' 
+ +{QUESTION_PAT} {{user_query}} +{ANSWER_PAT} +""".strip() + + +# For weak CHAT LLM which takes one chunk and cannot output json +# The next message should have the user query +# Note, no flow/config currently uses this one +WEAK_CHAT_LLM_PROMPT = f""" +You are a question answering assistant +Respond to the user query with an "{ANSWER_PAT}" section and \ +as many "{QUOTE_PAT}" sections as needed to support the answer. +Answer the user query based on the following document: + +{{first_chunk_content}} +""".strip() + + +# User the following for easy viewing of prompts +if __name__ == "__main__": + print(JSON_PROMPT) # Default prompt used in the Danswer UI flow diff --git a/backend/danswer/prompts/secondary_llm_flows.py b/backend/danswer/prompts/secondary_llm_flows.py new file mode 100644 index 000000000000..d0abec9d0ddc --- /dev/null +++ b/backend/danswer/prompts/secondary_llm_flows.py @@ -0,0 +1,96 @@ +from danswer.prompts.constants import ANSWER_PAT +from danswer.prompts.constants import ANSWERABLE_PAT +from danswer.prompts.constants import GENERAL_SEP_PAT +from danswer.prompts.constants import QUESTION_PAT +from danswer.prompts.constants import THOUGHT_PAT + + +ANSWER_VALIDITY_PROMPT = f""" +You are an assistant to identify invalid query/answer pairs coming from a large language model. +The query/answer pair is invalid if any of the following are True: +1. Query is asking for information that varies by person or is subjective. If there is not a \ +globally true answer, the language model should not respond, therefore any answer is invalid. +2. Answer addresses a related but different query. To be helpful, the model may provide provide \ +related information about a query but it won't match what the user is asking, this is invalid. +3. Answer is just some form of "I don\'t know" or "not enough information" without significant \ +additional useful information. Explaining why it does not know or cannot answer is invalid. + +{QUESTION_PAT} {{user_query}} +{ANSWER_PAT} {{llm_answer}} + +------------------------ +You MUST answer in EXACTLY the following format: +``` +1. True or False +2. True or False +3. True or False +Final Answer: Valid or Invalid +``` + +Hint: Remember, if ANY of the conditions are True, it is Invalid. +""".strip() + + +TIME_FILTER_PROMPT = """ +You are a tool to identify time filters to apply to a user query for a downstream search \ +application. The downstream application is able to use a recency bias or apply a hard cutoff to \ +remove all documents before the cutoff. Identify the correct filters to apply for the user query. + +Always answer with ONLY a json which contains the keys "filter_type", "filter_value", \ +"value_multiple" and "date". + +The valid values for "filter_type" are "hard cutoff", "favors recent", or "not time sensitive". +The valid values for "filter_value" are "day", "week", "month", "quarter", "half", or "year". +The valid values for "value_multiple" is any number. +The valid values for "date" is a date in format MM/DD/YYYY. +""".strip() + + +ANSWERABLE_PROMPT = f""" +You are a helper tool to determine if a query is answerable using retrieval augmented generation. +The main system will try to answer the user query based on ONLY the top 5 most relevant \ +documents found from search. +Sources contain both up to date and proprietary information for the specific team. +For named or unknown entities, assume the search will find relevant and consistent knowledge \ +about the entity. +The system is not tuned for writing code. 
+The system is not tuned for interfacing with structured data via query languages like SQL. +If the question might not require code or query language, then assume it can be answered without \ +code or query language. +Determine if that system should attempt to answer. +"ANSWERABLE" must be exactly "True" or "False" + +{GENERAL_SEP_PAT} + +{QUESTION_PAT.upper()} What is this Slack channel about? +``` +{THOUGHT_PAT.upper()} First the system must determine which Slack channel is being referred to. \ +By fetching 5 documents related to Slack channel contents, it is not possible to determine which \ +Slack channel the user is referring to. +{ANSWERABLE_PAT.upper()} False +``` + +{QUESTION_PAT.upper()} Danswer is unreachable. +``` +{THOUGHT_PAT.upper()} The system searches documents related to Danswer being unreachable. \ +Assuming the documents from search contains situations where Danswer is not reachable and \ +contains a fix, the query may be answerable. +{ANSWERABLE_PAT.upper()} True +``` + +{QUESTION_PAT.upper()} How many customers do we have +``` +{THOUGHT_PAT.upper()} Assuming the retrieved documents contain up to date customer acquisition \ +information including a list of customers, the query can be answered. It is important to note \ +that if the information only exists in a SQL database, the system is unable to execute SQL and \ +won't find an answer. +{ANSWERABLE_PAT.upper()} True +``` + +{QUESTION_PAT.upper()} {{user_query}} +""".strip() + + +# User the following for easy viewing of prompts +if __name__ == "__main__": + print(ANSWERABLE_PROMPT) diff --git a/backend/danswer/secondary_llm_flows/answer_validation.py b/backend/danswer/secondary_llm_flows/answer_validation.py index 189dece6f5ae..4ef8e8bef318 100644 --- a/backend/danswer/secondary_llm_flows/answer_validation.py +++ b/backend/danswer/secondary_llm_flows/answer_validation.py @@ -1,8 +1,6 @@ -from danswer.configs.constants import ANSWER_PAT -from danswer.configs.constants import CODE_BLOCK_PAT -from danswer.configs.constants import QUESTION_PAT -from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt from danswer.llm.factory import get_default_llm +from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.prompts.secondary_llm_flows import ANSWER_VALIDITY_PROMPT from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time @@ -27,31 +25,11 @@ def get_answer_validity( # f"{FINAL_ANSWER_PAT} Valid or Invalid" # ) - format_demo = ( - "1. True or False\n" - "2. True or False\n" - "3. True or False\n" - "Final Answer: Valid or Invalid" - ) - messages = [ { "role": "user", - "content": ( - "You are an assistant to identify invalid query/answer pairs coming from a large language model. " - "The query/answer pair is invalid if any of the following are True:\n" - "1. Query is asking for information that varies by person or is subjective." - "If there is not a globally true answer, the language model should not respond, " - "therefore any answer is invalid.\n" - "2. Answer addresses a related but different query. Sometimes to be helpful, the model will " - "provide related information about a query but it won't match what the user is asking, " - "this is invalid.\n" - '3. Answer is just some form of "I don\'t know" or "not enough information" without significant ' - "additional useful information. 
Explaining why it does not know or cannot answer is invalid.\n\n" - f"{QUESTION_PAT} {query}\n{ANSWER_PAT} {answer}" - "\n\n------------------------\n" - f"You MUST answer in EXACTLY the following format:{CODE_BLOCK_PAT.format(format_demo)}\n" - "Hint: Remember, if ANY of the conditions are True, it is Invalid." + "content": ANSWER_VALIDITY_PROMPT.format( + user_query=query, llm_answer=answer ), }, ] diff --git a/backend/danswer/secondary_llm_flows/extract_filters.py b/backend/danswer/secondary_llm_flows/extract_filters.py index df4efe7ae9c6..00cd024df028 100644 --- a/backend/danswer/secondary_llm_flows/extract_filters.py +++ b/backend/danswer/secondary_llm_flows/extract_filters.py @@ -8,6 +8,7 @@ from dateutil.parser import parse from danswer.configs.app_configs import DISABLE_TIME_FILTER_EXTRACTION from danswer.llm.factory import get_default_llm from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.prompts.secondary_llm_flows import TIME_FILTER_PROMPT from danswer.server.models import QuestionRequest from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time @@ -50,19 +51,7 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]: messages = [ { "role": "system", - "content": "You are a tool to identify time filters to apply to a user query for " - "a downstream search application. The downstream application is able to " - "use a recency bias or apply a hard cutoff to remove all documents " - "before the cutoff. Identify the correct filters to apply for the user " - "query.\n\n" - "Always answer with ONLY a json which contains the keys " - '"filter_type", "filter_value", "value_multiple" and "date".\n\n' - 'The valid values for "filter_type" are "hard cutoff", ' - '"favors recent", or "not time sensitive".\n' - 'The valid values for "filter_value" are "day", "week", "month", ' - '"quarter", "half", or "year".\n' - 'The valid values for "value_multiple" is any number.\n' - 'The valid values for "date" is a date in format MM/DD/YYYY.', + "content": TIME_FILTER_PROMPT, }, { "role": "user", diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index 22e8ba15ddfb..0ebc9c587215 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -1,12 +1,13 @@ import re from collections.abc import Iterator -from danswer.configs.constants import CODE_BLOCK_PAT -from danswer.configs.constants import GENERAL_SEP_PAT from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import StreamingError -from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt from danswer.llm.factory import get_default_llm +from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.prompts.constants import ANSWERABLE_PAT +from danswer.prompts.constants import THOUGHT_PAT +from danswer.prompts.secondary_llm_flows import ANSWERABLE_PROMPT from danswer.server.models import QueryValidationResponse from danswer.server.utils import get_json_line from danswer.utils.logger import setup_logger @@ -14,55 +15,11 @@ from danswer.utils.logger import setup_logger logger = setup_logger() -QUERY_PAT = "QUERY: " -REASONING_PAT = "THOUGHT: " -ANSWERABLE_PAT = "ANSWERABLE: " - - def get_query_validation_messages(user_query: str) -> list[dict[str, str]]: - ambiguous_example_question = f"{QUERY_PAT}What is this Slack channel about?" 
- ambiguous_example_answer = ( - f"{REASONING_PAT}First the system must determine which Slack channel is " - f"being referred to. By fetching 5 documents related to Slack channel contents, " - f"it is not possible to determine which Slack channel the user is referring to.\n" - f"{ANSWERABLE_PAT}False" - ) - - debug_example_question = f"{QUERY_PAT}Danswer is unreachable." - debug_example_answer = ( - f"{REASONING_PAT}The system searches documents related to Danswer being " - f"unreachable. Assuming the documents from search contains situations where " - f"Danswer is not reachable and contains a fix, the query may be answerable.\n" - f"{ANSWERABLE_PAT}True" - ) - - up_to_date_example_question = f"{QUERY_PAT}How many customers do we have" - up_to_date_example_answer = ( - f"{REASONING_PAT}Assuming the retrieved documents contain up to date customer " - f"acquisition information including a list of customers, the query can be answered. " - f"It is important to note that if the information only exists in a database, " - f"the system is unable to execute SQL and won't find an answer." - f"\n{ANSWERABLE_PAT}True" - ) - messages = [ { "role": "user", - "content": "You are a helper tool to determine if a query is answerable using retrieval augmented " - f"generation.\nThe main system will try to answer the user query based on ONLY the top 5 most relevant " - f"documents found from search.\nSources contain both up to date and proprietary information for " - f"the specific team.\nFor named or unknown entities, assume the search will find " - f"relevant and consistent knowledge about the entity.\n" - f"The system is not tuned for writing code.\n" - f"The system is not tuned for interfacing with structured data via query languages like SQL.\n" - f"If the question might not require code or query language, " - f"then assume it can be answered without code or query language.\n" - f"Determine if that system should attempt to answer.\n" - f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"\n{GENERAL_SEP_PAT}\n' - f"{ambiguous_example_question}{CODE_BLOCK_PAT.format(ambiguous_example_answer)}\n" - f"{debug_example_question}{CODE_BLOCK_PAT.format(debug_example_answer)}\n" - f"{up_to_date_example_question}{CODE_BLOCK_PAT.format(up_to_date_example_answer)}\n" - f"{QUERY_PAT + user_query}", + "content": ANSWERABLE_PROMPT.format(user_query=user_query), }, ] @@ -71,14 +28,14 @@ def get_query_validation_messages(user_query: str) -> list[dict[str, str]]: def extract_answerability_reasoning(model_raw: str) -> str: reasoning_match = re.search( - f"{REASONING_PAT}(.*?){ANSWERABLE_PAT}", model_raw, re.DOTALL + f"{THOUGHT_PAT.upper()}(.*?){ANSWERABLE_PAT.upper()}", model_raw, re.DOTALL ) reasoning_text = reasoning_match.group(1).strip() if reasoning_match else "" return reasoning_text def extract_answerability_bool(model_raw: str) -> bool: - answerable_match = re.search(f"{ANSWERABLE_PAT}(.+)", model_raw) + answerable_match = re.search(f"{ANSWERABLE_PAT.upper()}(.+)", model_raw) answerable_text = answerable_match.group(1).strip() if answerable_match else "" answerable = True if answerable_text.strip().lower() in ["true", "yes"] else False return answerable @@ -106,13 +63,13 @@ def stream_query_answerability(user_query: str) -> Iterator[str]: for token in tokens: model_output = model_output + token - if ANSWERABLE_PAT in model_output: + if ANSWERABLE_PAT.upper() in model_output: continue - if not reasoning_pat_found and REASONING_PAT in model_output: + if not reasoning_pat_found and THOUGHT_PAT.upper() in model_output: 
reasoning_pat_found = True - reason_ind = model_output.find(REASONING_PAT) - remaining = model_output[reason_ind + len(REASONING_PAT) :] + reason_ind = model_output.find(THOUGHT_PAT.upper()) + remaining = model_output[reason_ind + len(THOUGHT_PAT.upper()) :] if remaining: yield get_json_line( DanswerAnswerPiece(answer_piece=remaining).dict() @@ -121,7 +78,7 @@ def stream_query_answerability(user_query: str) -> Iterator[str]: if reasoning_pat_found: hold_answerable = hold_answerable + token - if hold_answerable == ANSWERABLE_PAT[: len(hold_answerable)]: + if hold_answerable == ANSWERABLE_PAT.upper()[: len(hold_answerable)]: continue yield get_json_line( DanswerAnswerPiece(answer_piece=hold_answerable).dict() diff --git a/backend/danswer/server/chat_backend.py b/backend/danswer/server/chat_backend.py index 2f706887b528..3586a3a09d6e 100644 --- a/backend/danswer/server/chat_backend.py +++ b/backend/danswer/server/chat_backend.py @@ -24,7 +24,7 @@ from danswer.db.feedback import create_chat_message_feedback from danswer.db.models import ChatMessage from danswer.db.models import User from danswer.direct_qa.interfaces import DanswerAnswerPiece -from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import get_default_llm_token_encode from danswer.secondary_llm_flows.chat_helpers import get_new_chat_name from danswer.server.models import ChatFeedbackRequest from danswer.server.models import ChatMessageDetail @@ -246,7 +246,7 @@ def handle_new_chat_message( parent_edit_number = chat_message.parent_edit_number user_id = user.id if user is not None else None - llm_tokenizer = get_default_llm_tokenizer() + llm_tokenizer = get_default_llm_token_encode() chat_session = fetch_chat_session_by_id(chat_session_id, db_session) persona = ( @@ -351,7 +351,7 @@ def regenerate_message_given_parent( edit_number = parent_message.edit_number user_id = user.id if user is not None else None - llm_tokenizer = get_default_llm_tokenizer() + llm_tokenizer = get_default_llm_token_encode() chat_message = fetch_chat_message( chat_session_id=chat_session_id, diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index d850c7a85a0c..2cde4faf8df6 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -23,7 +23,7 @@ from danswer.db.feedback import fetch_docs_ranked_by_boost from danswer.db.feedback import update_document_boost from danswer.db.feedback import update_document_hidden from danswer.db.models import User -from danswer.direct_qa.llm_utils import get_default_qa_model +from danswer.direct_qa.factory import get_default_qa_model from danswer.document_index.factory import get_default_document_index from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 9c4aa3d84e22..cf455ae2c05b 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -40,6 +40,8 @@ services: - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-} - SKIP_RERANKING=${SKIP_RERANKING:-} + - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} + - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-} # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} volumes: @@ -89,6 +91,7 @@ services: - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-} - 
SKIP_RERANKING=${SKIP_RERANKING:-} + - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-} - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} # Set to debug to get more fine-grained logs
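
Note on the new configuration knob (illustrative, not part of the patch itself): `QA_PROMPT_OVERRIDE`, surfaced in `app_configs.py` and `docker-compose.dev.yml` above, selects which `QAHandler` the factory in `backend/danswer/direct_qa/factory.py` builds. The sketch below mirrors that selection logic with stand-in classes so it runs on its own; the real handlers live in `backend/danswer/direct_qa/qa_block.py`.

```python
# Reviewer sketch (not part of the patch): rough mirror of how QA_PROMPT_OVERRIDE
# is expected to select a QAHandler in backend/danswer/direct_qa/factory.py.
# The classes below are stand-ins, not the real imports.
import os


class SingleMessageQAHandler:          # default single-message JSON prompt (JSON_PROMPT)
    pass


class SingleMessageScratchpadHandler:  # chain-of-thought prompt (COT_PROMPT)
    pass


class WeakLLMQAHandler:                # single-document prompt for weaker models (WEAK_LLM_PROMPT)
    pass


def pick_qa_handler(real_time_flow: bool = True, user_selection: str | None = None):
    """Simplified stand-in for get_default_qa_handler as added in this diff."""
    if user_selection:
        selection = user_selection.lower()
        if selection == "default":
            return SingleMessageQAHandler()
        if selection == "cot":
            return SingleMessageScratchpadHandler()
        if selection == "weak":
            return WeakLLMQAHandler()
        raise ValueError("Invalid Question-Answering prompt selected")
    # With no override, streaming (real-time) flows get the plain JSON prompt and
    # non-real-time flows get the chain-of-thought scratchpad prompt.
    return SingleMessageQAHandler() if real_time_flow else SingleMessageScratchpadHandler()


if __name__ == "__main__":
    # e.g. export QA_PROMPT_OVERRIDE=cot, or set it in docker-compose.dev.yml
    handler = pick_qa_handler(user_selection=os.environ.get("QA_PROMPT_OVERRIDE"))
    print(type(handler).__name__)
```

Valid values are `default`, `cot`, and `weak`; per the factory code above, any other non-empty value raises a `ValueError` when the QA model is constructed, and leaving it unset falls back to the real-time vs. non-real-time split. The new prompt modules also end with small `__main__` blocks, so the rendered templates can be inspected directly (running `backend/danswer/prompts/direct_qa_prompts.py` prints `JSON_PROMPT`).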
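
Separately, chunk trimming moved from the private `_tiktoken_trim_chunks` helper in `qa_block.py` into `danswer.llm.utils.tokenizer_trim_chunks`, with the cap now defaulting to `DOC_EMBEDDING_CONTEXT_SIZE` instead of a hard-coded 512. A rough standalone sketch of the same idea on plain strings (assumes the `tiktoken` package is installed; the real function operates on `InferenceChunk` objects and the 512 default here is only a placeholder):

```python
# Reviewer sketch (not part of the patch): token-based truncation in the style of
# tokenizer_trim_chunks, applied to a plain string instead of InferenceChunk objects.
import tiktoken


def trim_to_tokens(text: str, max_toks: int = 512) -> str:
    """Truncate text to at most max_toks tokens using the cl100k_base tokenizer,
    mirroring how over-long chunk contents are trimmed before prompting."""
    encoder = tiktoken.get_encoding("cl100k_base")
    tokens = encoder.encode(text)
    if len(tokens) <= max_toks:
        return text
    return encoder.decode(tokens[:max_toks])


if __name__ == "__main__":
    long_text = "chunk content " * 1000
    trimmed = trim_to_tokens(long_text)
    print(len(trimmed), "characters kept of", len(long_text))
```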