diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index acfb976cdf5b..2f23ef35c048 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -34,9 +34,9 @@ MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1) # Cross Encoder Settings -# This following setting is for non-real-time-flows -SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true" -# This one is for real-time (streaming) flows +ENABLE_RERANKING_ASYNC_FLOW = ( + os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true" +) ENABLE_RERANKING_REAL_TIME_FLOW = ( os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true" ) diff --git a/backend/danswer/danswerbot/slack/blocks.py b/backend/danswer/danswerbot/slack/blocks.py index d26703960fff..191d90b0ce2b 100644 --- a/backend/danswer/danswerbot/slack/blocks.py +++ b/backend/danswer/danswerbot/slack/blocks.py @@ -214,7 +214,9 @@ def build_qa_response_blocks( text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓" ) else: - answer_block = SectionBlock(text=remove_slack_text_interactions(answer)) + answer_processed = remove_slack_text_interactions(answer) + answer_processed = answer_processed.encode("utf-8").decode("unicode_escape") + answer_block = SectionBlock(text=answer_processed) if quotes: quotes_blocks = build_quotes_block(quotes) diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index ba61edafcf43..b8e77e670446 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -86,8 +86,14 @@ def handle_message( Query thrown out by filters due to config does not count as a failure that should be notified Danswer failing to answer/retrieve docs does count and should be notified """ - msg = message_info.msg_content channel = message_info.channel_to_respond + + logger = cast( + logging.Logger, + ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}), + ) + + messages = message_info.thread_messages message_ts_to_respond_to = message_info.msg_to_respond sender_id = message_info.sender bipass_filters = message_info.bipass_filters @@ -95,11 +101,6 @@ def handle_message( engine = get_sqlalchemy_engine() - logger = cast( - logging.Logger, - ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}), - ) - document_set_names: list[str] | None = None persona = channel_config.persona if channel_config else None prompt = None @@ -133,7 +134,7 @@ def handle_message( if ( "questionmark_prefilter" in channel_conf["answer_filters"] - and "?" not in msg + and "?" not in messages[-1].message ): logger.info( "Skipping message since it does not contain a question mark" @@ -223,7 +224,7 @@ def handle_message( # This includes throwing out answer via reflexion answer = _get_answer( DirectQARequest( - query=msg, + messages=messages, prompt_id=prompt.id if prompt else None, persona_id=persona.id if persona is not None else 0, retrieval_options=retrieval_details, @@ -275,7 +276,9 @@ def handle_message( top_docs = retrieval_info.top_documents if not top_docs and not should_respond_even_with_no_docs: - logger.error(f"Unable to answer question: '{msg}' - no documents found") + logger.error( + f"Unable to answer question: '{answer.rephrase}' - no documents found" + ) # Optionally, respond in thread with the error message # Used primarily for debugging purposes if should_respond_with_error_msgs: @@ -296,7 +299,7 @@ def handle_message( return True # If called with the DanswerBot slash command, the question is lost so we have to reshow it - restate_question_block = get_restate_blocks(msg, is_bot_msg) + restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg) answer_blocks = build_qa_response_blocks( message_id=answer.chat_message_id, diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index 8f51c2934716..b4f45d0a2961 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -1,4 +1,3 @@ -import re import time from threading import Event from typing import Any @@ -10,9 +9,10 @@ from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.response import SocketModeResponse from sqlalchemy.orm import Session +from danswer.configs.constants import MessageType from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER -from danswer.configs.model_configs import SKIP_RERANKING +from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID from danswer.danswerbot.slack.handlers.handle_feedback import handle_slack_feedback @@ -22,9 +22,13 @@ from danswer.danswerbot.slack.tokens import fetch_tokens from danswer.danswerbot.slack.utils import ChannelIdAdapter from danswer.danswerbot.slack.utils import decompose_block_id from danswer.danswerbot.slack.utils import get_channel_name_from_id +from danswer.danswerbot.slack.utils import get_danswer_bot_app_id +from danswer.danswerbot.slack.utils import read_slack_thread +from danswer.danswerbot.slack.utils import remove_danswer_bot_tag from danswer.danswerbot.slack.utils import respond_in_thread from danswer.db.engine import get_sqlalchemy_engine from danswer.dynamic_configs.interface import ConfigNotFoundError +from danswer.one_shot_answer.models import ThreadMessage from danswer.search.search_nlp_models import warm_up_models from danswer.server.manage.models import SlackBotTokens from danswer.utils.logger import setup_logger @@ -63,7 +67,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool return False if event_type == "message": - bot_tag_id = client.web_client.auth_test().get("user_id") + bot_tag_id = get_danswer_bot_app_id(client.web_client) # DMs with the bot don't pick up the @DanswerBot so we have to keep the # caught events_api if bot_tag_id and bot_tag_id in msg and event.get("channel_type") != "im": @@ -87,8 +91,14 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool message_ts = event.get("ts") thread_ts = event.get("thread_ts") # Pick the root of the thread (if a thread exists) - if thread_ts and message_ts != thread_ts: - channel_specific_logger.info( + # Can respond in thread if it's an "im" directly to Danswer or @DanswerBot is tagged + if ( + thread_ts + and message_ts != thread_ts + and event_type != "app_mention" + and event.get("channel_type") != "im" + ): + channel_specific_logger.debug( "Skipping message since it is not the root of a thread" ) return False @@ -156,18 +166,25 @@ def build_request_details( tagged = event.get("type") == "app_mention" message_ts = event.get("ts") thread_ts = event.get("thread_ts") - bot_tag_id = client.web_client.auth_test().get("user_id") - # Might exist even if not tagged, specifically in the case of @DanswerBot - # in DanswerBot DM channel - msg = re.sub(rf"<@{bot_tag_id}>\s", "", msg) + + msg = remove_danswer_bot_tag(msg, client=client.web_client) if tagged: logger.info("User tagged DanswerBot") + if thread_ts != message_ts and thread_ts is not None: + thread_messages = read_slack_thread( + channel=channel, thread=thread_ts, client=client.web_client + ) + else: + thread_messages = [ + ThreadMessage(message=msg, sender=None, role=MessageType.USER) + ] + return SlackMessageInfo( - msg_content=msg, + thread_messages=thread_messages, channel_to_respond=channel, - msg_to_respond=cast(str, thread_ts or message_ts), + msg_to_respond=cast(str, message_ts or thread_ts), sender=event.get("user") or None, bipass_filters=tagged, is_bot_msg=False, @@ -178,8 +195,10 @@ def build_request_details( msg = req.payload["text"] sender = req.payload["user_id"] + single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER) + return SlackMessageInfo( - msg_content=msg, + thread_messages=[single_msg], channel_to_respond=channel, msg_to_respond=None, sender=sender, @@ -297,7 +316,7 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None: # NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC # without issue. if __name__ == "__main__": - warm_up_models(skip_cross_encoders=SKIP_RERANKING) + warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW) slack_bot_tokens: SlackBotTokens | None = None socket_client: SocketModeClient | None = None diff --git a/backend/danswer/danswerbot/slack/models.py b/backend/danswer/danswerbot/slack/models.py index d6483c7371cc..95cf69ac148f 100644 --- a/backend/danswer/danswerbot/slack/models.py +++ b/backend/danswer/danswerbot/slack/models.py @@ -1,8 +1,10 @@ from pydantic import BaseModel +from danswer.one_shot_answer.models import ThreadMessage + class SlackMessageInfo(BaseModel): - msg_content: str + thread_messages: list[ThreadMessage] channel_to_respond: str msg_to_respond: str | None sender: str | None diff --git a/backend/danswer/danswerbot/slack/utils.py b/backend/danswer/danswerbot/slack/utils.py index d0fa9d9cc85f..bdba7bd1534b 100644 --- a/backend/danswer/danswerbot/slack/utils.py +++ b/backend/danswer/danswerbot/slack/utils.py @@ -13,17 +13,34 @@ from slack_sdk.models.blocks import Block from slack_sdk.models.metadata import Metadata from danswer.configs.constants import ID_SEPARATOR +from danswer.configs.constants import MessageType from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.connectors.slack.utils import SlackTextCleaner from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID from danswer.danswerbot.slack.tokens import fetch_tokens +from danswer.one_shot_answer.models import ThreadMessage from danswer.utils.logger import setup_logger from danswer.utils.text_processing import replace_whitespaces_w_space logger = setup_logger() +DANSWER_BOT_APP_ID: str | None = None + + +def get_danswer_bot_app_id(web_client: WebClient) -> Any: + global DANSWER_BOT_APP_ID + if DANSWER_BOT_APP_ID is None: + DANSWER_BOT_APP_ID = web_client.auth_test().get("user_id") + return DANSWER_BOT_APP_ID + + +def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str: + bot_tag_id = get_danswer_bot_app_id(web_client=client) + return re.sub(rf"<@{bot_tag_id}>\s", "", message_str) + + class ChannelIdAdapter(logging.LoggerAdapter): """This is used to add the channel ID to all log messages emitted in this file""" @@ -199,3 +216,57 @@ def fetch_userids_from_emails(user_emails: list[str], client: WebClient) -> list ) return user_ids + + +def fetch_user_semantic_id_from_id(user_id: str, client: WebClient) -> str | None: + response = client.users_info(user=user_id) + if not response["ok"]: + return None + + user: dict = cast(dict[Any, dict], response.data).get("user", {}) + + return ( + user.get("real_name") + or user.get("name") + or user.get("profile", {}).get("email") + ) + + +def read_slack_thread( + channel: str, thread: str, client: WebClient +) -> list[ThreadMessage]: + thread_messages: list[ThreadMessage] = [] + response = client.conversations_replies(channel=channel, ts=thread) + replies = cast(dict, response.data).get("messages", []) + for reply in replies: + if "user" in reply and "bot_id" not in reply: + message = remove_danswer_bot_tag(reply["text"], client=client) + user_sem_id = fetch_user_semantic_id_from_id(reply["user"], client) + message_type = MessageType.USER + else: + self_app_id = get_danswer_bot_app_id(client) + + # Only include bot messages from Danswer, other bots are not taken in as context + if self_app_id != reply.get("user"): + continue + + blocks = reply["blocks"] + if len(blocks) <= 1: + continue + + # The useful block is the second one after the header block that says AI Answer + message = reply["blocks"][1]["text"]["text"] + + if message.startswith("_Filters"): + if len(blocks) <= 2: + continue + message = reply["blocks"][2]["text"]["text"] + + user_sem_id = "Assistant" + message_type = MessageType.ASSISTANT + + thread_messages.append( + ThreadMessage(message=message, sender=user_sem_id, role=message_type) + ) + + return thread_messages diff --git a/backend/danswer/main.py b/backend/danswer/main.py index cbda09eb34bb..aeaa028b2249 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -47,7 +47,6 @@ from danswer.db.credentials import create_initial_public_credential from danswer.db.engine import get_sqlalchemy_engine from danswer.document_index.factory import get_default_document_index from danswer.llm.factory import get_default_llm -from danswer.one_shot_answer.factory import get_default_qa_model from danswer.search.search_nlp_models import warm_up_models from danswer.server.danswer_api.ingestion import get_danswer_api_key from danswer.server.danswer_api.ingestion import router as danswer_api_router @@ -261,7 +260,6 @@ def get_application() -> FastAPI: # This is for the LLM, most LLMs will not need warming up get_default_llm().log_model_configs() - get_default_qa_model().warm_up_model() logger.info("Verifying query preprocessing (NLTK) data is downloaded") nltk.download("stopwords", quiet=True) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 740b59299538..c5a4120ae65d 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -28,6 +28,8 @@ from danswer.llm.utils import get_default_llm_token_encode from danswer.one_shot_answer.factory import get_question_answer_model from danswer.one_shot_answer.models import DirectQARequest from danswer.one_shot_answer.models import OneShotQAResponse +from danswer.one_shot_answer.models import QueryRephrase +from danswer.one_shot_answer.qa_utils import combine_message_thread from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import SavedSearchDoc @@ -35,6 +37,7 @@ from danswer.search.request_preprocessing import retrieval_preprocessing from danswer.search.search_runner import chunks_to_search_docs from danswer.search.search_runner import full_chunk_search_generator from danswer.secondary_llm_flows.answer_validation import get_answer_validity +from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.utils import get_json_line from danswer.utils.logger import setup_logger @@ -58,7 +61,8 @@ def stream_answer_objects( rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, ) -> Iterator[ - QADocsResponse + QueryRephrase + | QADocsResponse | LLMRelevanceFilterResponse | DanswerAnswerPiece | DanswerQuotes @@ -73,6 +77,8 @@ def stream_answer_objects( 4. [always] Details on the final AI response message that is created """ user_id = user.id if user is not None else None + query_msg = query_req.messages[-1] + history = query_req.messages[:-1] chat_session = create_chat_session( db_session=db_session, @@ -90,24 +96,20 @@ def stream_answer_objects( chat_session_id=chat_session.id, db_session=db_session ) - # Create the first User query message - new_user_message = create_new_chat_message( - chat_session_id=chat_session.id, - parent_message=root_message, - prompt_id=query_req.prompt_id, - message=query_req.query, - token_count=len(llm_tokenizer(query_req.query)), - message_type=MessageType.USER, - db_session=db_session, - commit=True, + history_str = combine_message_thread(history) + + rephrased_query = thread_based_query_rephrase( + user_query=query_msg.message, + history_str=history_str, ) + yield QueryRephrase(rephrased_query=rephrased_query) ( retrieval_request, predicted_search_type, predicted_flow, ) = retrieval_preprocessing( - query=query_req.query, + query=rephrased_query, retrieval_details=query_req.retrieval_options, persona=chat_session.persona, user=user, @@ -190,9 +192,25 @@ def stream_answer_objects( llm_version=llm_override, ) + full_prompt_str = qa_model.build_prompt( + query=query_msg.message, history_str=history_str, context_chunks=llm_chunks + ) + + # Create the first User query message + new_user_message = create_new_chat_message( + chat_session_id=chat_session.id, + parent_message=root_message, + prompt_id=query_req.prompt_id, + message=full_prompt_str, + token_count=len(llm_tokenizer(full_prompt_str)), + message_type=MessageType.USER, + db_session=db_session, + commit=True, + ) + response_packets = qa_model.answer_question_stream( - query=query_req.query, - context_docs=llm_chunks, + prompt=full_prompt_str, + llm_context_docs=llm_chunks, metrics_callback=llm_metrics_callback, ) @@ -272,6 +290,8 @@ def get_one_shot_answer( answer = "" for packet in results: + if isinstance(packet, QueryRephrase): + qa_response.rephrase = packet.rephrased_query if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: answer += packet.answer_piece elif isinstance(packet, QADocsResponse): @@ -289,6 +309,11 @@ def get_one_shot_answer( qa_response.answer = answer if enable_reflexion: - qa_response.answer_valid = get_answer_validity(query_req.query, answer) + # Because follow up messages are explicitly tagged, we don't need to verify the answer + if len(query_req.messages) == 1: + first_query = query_req.messages[0].message + qa_response.answer_valid = get_answer_validity(first_query, answer) + else: + qa_response.answer_valid = True return qa_response diff --git a/backend/danswer/one_shot_answer/factory.py b/backend/danswer/one_shot_answer/factory.py index 3dd5f020b080..47be1fd25e36 100644 --- a/backend/danswer/one_shot_answer/factory.py +++ b/backend/danswer/one_shot_answer/factory.py @@ -3,98 +3,42 @@ from danswer.configs.chat_configs import QA_TIMEOUT from danswer.db.models import Prompt from danswer.llm.factory import get_default_llm from danswer.one_shot_answer.interfaces import QAModel -from danswer.one_shot_answer.qa_block import PromptBasedQAHandler from danswer.one_shot_answer.qa_block import QABlock from danswer.one_shot_answer.qa_block import QAHandler from danswer.one_shot_answer.qa_block import SingleMessageQAHandler -from danswer.one_shot_answer.qa_block import SingleMessageScratchpadHandler from danswer.one_shot_answer.qa_block import WeakLLMQAHandler from danswer.utils.logger import setup_logger logger = setup_logger() -def get_default_qa_handler( - chain_of_thought: bool = False, - user_selection: str | None = QA_PROMPT_OVERRIDE, -) -> QAHandler: - if user_selection: - if user_selection.lower() == "default": - return SingleMessageQAHandler() - if user_selection.lower() == "cot": - return SingleMessageScratchpadHandler() - if user_selection.lower() == "weak": - return WeakLLMQAHandler() - - raise ValueError("Invalid Question-Answering prompt selected") - - if chain_of_thought: - return SingleMessageScratchpadHandler() - - return SingleMessageQAHandler() - - -def get_default_qa_model( - api_key: str | None = None, - timeout: int = QA_TIMEOUT, - chain_of_thought: bool = False, -) -> QAModel: - llm = get_default_llm(api_key=api_key, timeout=timeout) - qa_handler = get_default_qa_handler(chain_of_thought=chain_of_thought) - - return QABlock( - llm=llm, - qa_handler=qa_handler, - ) - - -def get_prompt_qa_model( - prompt: Prompt, - api_key: str | None = None, - timeout: int = QA_TIMEOUT, - llm_version: str | None = None, -) -> QAModel: - return QABlock( - llm=get_default_llm( - api_key=api_key, - timeout=timeout, - gen_ai_model_version_override=llm_version, - ), - qa_handler=PromptBasedQAHandler( - system_prompt=prompt.system_prompt, task_prompt=prompt.task_prompt - ), - ) - - def get_question_answer_model( prompt: Prompt | None, api_key: str | None = None, timeout: int = QA_TIMEOUT, chain_of_thought: bool = False, llm_version: str | None = None, + qa_model_version: str | None = QA_PROMPT_OVERRIDE, ) -> QAModel: - if prompt is None and llm_version is not None: - raise RuntimeError( - "Cannot specify llm version for QA model without providing prompt. " - "This flow is only intended for flows with a specified Persona/Prompt." - ) + if chain_of_thought: + raise NotImplementedError("COT has been disabled") - if prompt is not None and chain_of_thought: - raise RuntimeError( - "Cannot choose COT prompt with a customized Prompt object. " - "User can prompt the model to output COT themselves if they want." - ) + system_prompt = prompt.system_prompt if prompt is not None else None + task_prompt = prompt.task_prompt if prompt is not None else None - if prompt is not None: - return get_prompt_qa_model( - prompt=prompt, - api_key=api_key, - timeout=timeout, - llm_version=llm_version, - ) - - return get_default_qa_model( + llm = get_default_llm( api_key=api_key, timeout=timeout, - chain_of_thought=chain_of_thought, + gen_ai_model_version_override=llm_version, ) + + if qa_model_version == "weak": + qa_handler: QAHandler = WeakLLMQAHandler( + system_prompt=system_prompt, task_prompt=task_prompt + ) + else: + qa_handler = SingleMessageQAHandler( + system_prompt=system_prompt, task_prompt=task_prompt + ) + + return QABlock(llm=llm, qa_handler=qa_handler) diff --git a/backend/danswer/one_shot_answer/interfaces.py b/backend/danswer/one_shot_answer/interfaces.py index 6993384a40ae..ca916d699df3 100644 --- a/backend/danswer/one_shot_answer/interfaces.py +++ b/backend/danswer/one_shot_answer/interfaces.py @@ -1,37 +1,26 @@ import abc from collections.abc import Callable -from danswer.chat.models import AnswerQuestionReturn from danswer.chat.models import AnswerQuestionStreamReturn from danswer.chat.models import LLMMetricsContainer from danswer.indexing.models import InferenceChunk class QAModel: - @property - def requires_api_key(self) -> bool: - """Is this model protected by security features - Does it need an api key to access the model for inference""" - return True - - def warm_up_model(self) -> None: - """This is called during server start up to load the models into memory - pass if model is accessed via API""" - @abc.abstractmethod - def answer_question( + def build_prompt( self, query: str, - context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, - ) -> AnswerQuestionReturn: + history_str: str, + context_chunks: list[InferenceChunk], + ) -> str: raise NotImplementedError @abc.abstractmethod def answer_question_stream( self, - query: str, - context_docs: list[InferenceChunk], + prompt: str, + llm_context_docs: list[InferenceChunk], metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, ) -> AnswerQuestionStreamReturn: raise NotImplementedError diff --git a/backend/danswer/one_shot_answer/models.py b/backend/danswer/one_shot_answer/models.py index 0d25bcbd122d..1e5d94d27c77 100644 --- a/backend/danswer/one_shot_answer/models.py +++ b/backend/danswer/one_shot_answer/models.py @@ -5,11 +5,22 @@ from pydantic import root_validator from danswer.chat.models import DanswerQuotes from danswer.chat.models import QADocsResponse +from danswer.configs.constants import MessageType from danswer.search.models import RetrievalDetails +class QueryRephrase(BaseModel): + rephrased_query: str + + +class ThreadMessage(BaseModel): + message: str + sender: str | None + role: MessageType + + class DirectQARequest(BaseModel): - query: str + messages: list[ThreadMessage] prompt_id: int | None persona_id: int retrieval_options: RetrievalDetails @@ -35,6 +46,7 @@ class DirectQARequest(BaseModel): class OneShotQAResponse(BaseModel): # This is built piece by piece, any of these can be None as the flow could break answer: str | None = None + rephrase: str | None = None quotes: DanswerQuotes | None = None docs: QADocsResponse | None = None llm_chunks_indices: list[int] | None = None diff --git a/backend/danswer/one_shot_answer/qa_block.py b/backend/danswer/one_shot_answer/qa_block.py index a3e0a03c4d36..455b23cb126b 100644 --- a/backend/danswer/one_shot_answer/qa_block.py +++ b/backend/danswer/one_shot_answer/qa_block.py @@ -4,11 +4,7 @@ from collections.abc import Callable from collections.abc import Iterator from typing import cast -from langchain.schema.messages import BaseMessage -from langchain.schema.messages import HumanMessage - from danswer.chat.chat_utils import build_context_str -from danswer.chat.models import AnswerQuestionReturn from danswer.chat.models import AnswerQuestionStreamReturn from danswer.chat.models import DanswerAnswer from danswer.chat.models import DanswerAnswerPiece @@ -21,16 +17,21 @@ from danswer.indexing.models import InferenceChunk from danswer.llm.interfaces import LLM from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import get_default_llm_token_encode -from danswer.llm.utils import tokenizer_trim_chunks from danswer.one_shot_answer.interfaces import QAModel from danswer.one_shot_answer.qa_utils import process_answer from danswer.one_shot_answer.qa_utils import process_model_tokens +from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK from danswer.prompts.direct_qa_prompts import COT_PROMPT +from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK from danswer.prompts.direct_qa_prompts import JSON_PROMPT from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT +from danswer.prompts.direct_qa_prompts import ONE_SHOT_SYSTEM_PROMPT +from danswer.prompts.direct_qa_prompts import ONE_SHOT_TASK_PROMPT from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT +from danswer.prompts.direct_qa_prompts import WEAK_MODEL_SYSTEM_PROMPT +from danswer.prompts.direct_qa_prompts import WEAK_MODEL_TASK_PROMPT from danswer.utils.logger import setup_logger from danswer.utils.text_processing import clean_up_code_blocks from danswer.utils.text_processing import escape_newlines @@ -39,12 +40,6 @@ logger = setup_logger() class QAHandler(abc.ABC): - @abc.abstractmethod - def build_prompt( - self, query: str, context_chunks: list[InferenceChunk] - ) -> list[BaseMessage]: - raise NotImplementedError - @property @abc.abstractmethod def is_json_output(self) -> bool: @@ -54,12 +49,14 @@ class QAHandler(abc.ABC): finetuned to recognize this.""" raise NotImplementedError - def process_llm_output( - self, model_output: str, context_chunks: list[InferenceChunk] - ) -> tuple[DanswerAnswer, DanswerQuotes]: - return process_answer( - model_output, context_chunks, is_json_prompt=self.is_json_output - ) + @abc.abstractmethod + def build_prompt( + self, + query: str, + history_str: str, + context_chunks: list[InferenceChunk], + ) -> str: + raise NotImplementedError def process_llm_token_stream( self, tokens: Iterator[str], context_chunks: list[InferenceChunk] @@ -78,70 +75,137 @@ class WeakLLMQAHandler(QAHandler): output format. """ + def __init__( + self, + system_prompt: str | None, + task_prompt: str | None, + ) -> None: + if not system_prompt and not task_prompt: + self.system_prompt = WEAK_MODEL_SYSTEM_PROMPT + self.task_prompt = WEAK_MODEL_TASK_PROMPT + else: + self.system_prompt = system_prompt or "" + self.task_prompt = task_prompt or "" + + self.task_prompt = self.task_prompt.rstrip() + if self.task_prompt and self.task_prompt[0] != "\n": + self.task_prompt = "\n" + self.task_prompt + @property def is_json_output(self) -> bool: return False def build_prompt( - self, query: str, context_chunks: list[InferenceChunk] - ) -> list[BaseMessage]: - message = WEAK_LLM_PROMPT.format( - user_query=query, single_reference_doc=context_chunks[0].content - ) + self, + query: str, + history_str: str, + context_chunks: list[InferenceChunk], + ) -> str: + context_block = "" + if context_chunks: + context_block = CONTEXT_BLOCK.format( + context_docs_str=context_chunks[0].content + ) - return [HumanMessage(content=message)] + prompt_str = WEAK_LLM_PROMPT.format( + system_prompt=self.system_prompt, + context_block=context_block, + task_prompt=self.task_prompt, + user_query=query, + ) + return prompt_str class SingleMessageQAHandler(QAHandler): + def __init__( + self, + system_prompt: str | None, + task_prompt: str | None, + use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), + ) -> None: + self.use_language_hint = use_language_hint + if not system_prompt and not task_prompt: + self.system_prompt = ONE_SHOT_SYSTEM_PROMPT + self.task_prompt = ONE_SHOT_TASK_PROMPT + else: + self.system_prompt = system_prompt or "" + self.task_prompt = task_prompt or "" + + self.task_prompt = self.task_prompt.rstrip() + if self.task_prompt and self.task_prompt[0] != "\n": + self.task_prompt = "\n" + self.task_prompt + @property def is_json_output(self) -> bool: return True def build_prompt( - self, - query: str, - context_chunks: list[InferenceChunk], - use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), - ) -> list[BaseMessage]: - context_docs_str = build_context_str( - cast(list[LlmDoc | InferenceChunk], context_chunks) - ) + self, query: str, history_str: str, context_chunks: list[InferenceChunk] + ) -> str: + context_block = "" + if context_chunks: + context_docs_str = build_context_str( + cast(list[LlmDoc | InferenceChunk], context_chunks) + ) + context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str) - single_message = JSON_PROMPT.format( - context_docs_str=context_docs_str, + history_block = "" + if history_str: + history_block = HISTORY_BLOCK.format(history_str=history_str) + + full_prompt = JSON_PROMPT.format( + system_prompt=self.system_prompt, + context_block=context_block, + history_block=history_block, + task_prompt=self.task_prompt, user_query=query, - language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "", + language_hint_or_none=LANGUAGE_HINT.strip() + if self.use_language_hint + else "", ).strip() - - prompt: list[BaseMessage] = [HumanMessage(content=single_message)] - return prompt + return full_prompt +# This one isn't used, currently only streaming prompts are used class SingleMessageScratchpadHandler(QAHandler): + def __init__( + self, + system_prompt: str | None, + task_prompt: str | None, + use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), + ) -> None: + self.use_language_hint = use_language_hint + if not system_prompt and not task_prompt: + self.system_prompt = ONE_SHOT_SYSTEM_PROMPT + self.task_prompt = ONE_SHOT_TASK_PROMPT + else: + self.system_prompt = system_prompt or "" + self.task_prompt = task_prompt or "" + + self.task_prompt = self.task_prompt.rstrip() + if self.task_prompt and self.task_prompt[0] != "\n": + self.task_prompt = "\n" + self.task_prompt + @property def is_json_output(self) -> bool: - # Even though the full LLM output isn't a valid json - # only the valid json portion is kept and passed along - # therefore it is treated as a json output return True def build_prompt( - self, - query: str, - context_chunks: list[InferenceChunk], - use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), - ) -> list[BaseMessage]: + self, query: str, history_str: str, context_chunks: list[InferenceChunk] + ) -> str: context_docs_str = build_context_str( cast(list[LlmDoc | InferenceChunk], context_chunks) ) - single_message = COT_PROMPT.format( + # Outdated + prompt = COT_PROMPT.format( context_docs_str=context_docs_str, user_query=query, - language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "", + language_hint_or_none=LANGUAGE_HINT.strip() + if self.use_language_hint + else "", ).strip() - prompt: list[BaseMessage] = [HumanMessage(content=single_message)] return prompt def process_llm_output( @@ -170,68 +234,22 @@ class SingleMessageScratchpadHandler(QAHandler): ) -class PromptBasedQAHandler(QAHandler): - def __init__(self, system_prompt: str, task_prompt: str) -> None: - self.system_prompt = system_prompt - self.task_prompt = task_prompt - - @property - def is_json_output(self) -> bool: - return False - - def build_prompt( - self, - query: str, - context_chunks: list[InferenceChunk], - ) -> list[BaseMessage]: - context_docs_str = build_context_str( - cast(list[LlmDoc | InferenceChunk], context_chunks) - ) - - if not context_chunks: - single_message = PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( - user_query=query, - system_prompt=self.system_prompt, - task_prompt=self.task_prompt, - ).strip() - else: - single_message = PARAMATERIZED_PROMPT.format( - context_docs_str=context_docs_str, - user_query=query, - system_prompt=self.system_prompt, - task_prompt=self.task_prompt, - ).strip() - - prompt: list[BaseMessage] = [HumanMessage(content=single_message)] - return prompt - - def build_dummy_prompt(self, retrieval_disabled: bool) -> str: - if retrieval_disabled: - return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( - user_query="", - system_prompt=self.system_prompt, - task_prompt=self.task_prompt, - ).strip() - - return PARAMATERIZED_PROMPT.format( - context_docs_str="", +def build_dummy_prompt( + system_prompt: str, task_prompt: str, retrieval_disabled: bool +) -> str: + if retrieval_disabled: + return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( user_query="", - system_prompt=self.system_prompt, - task_prompt=self.task_prompt, + system_prompt=system_prompt, + task_prompt=task_prompt, ).strip() - def process_llm_output( - self, model_output: str, context_chunks: list[InferenceChunk] - ) -> tuple[DanswerAnswer, DanswerQuotes]: - return DanswerAnswer(answer=model_output), DanswerQuotes(quotes=[]) - - def process_llm_token_stream( - self, tokens: Iterator[str], context_chunks: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - for token in tokens: - yield DanswerAnswerPiece(answer_piece=token) - - yield DanswerQuotes(quotes=[]) + return PARAMATERIZED_PROMPT.format( + context_docs_str="", + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() class QABlock(QAModel): @@ -239,64 +257,30 @@ class QABlock(QAModel): self._llm = llm self._qa_handler = qa_handler - @property - def requires_api_key(self) -> bool: - return self._llm.requires_api_key - - def warm_up_model(self) -> None: - """This is called during server start up to load the models into memory - in case the chosen LLM is not accessed via API""" - if self._llm.requires_warm_up: - logger.info("Warming up LLM with a first inference") - self._llm.invoke("Ignore this!") - - def answer_question( + def build_prompt( self, query: str, - context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, - ) -> AnswerQuestionReturn: - trimmed_context_docs = tokenizer_trim_chunks(context_docs) - prompt = self._qa_handler.build_prompt(query, trimmed_context_docs) - model_out = self._llm.invoke(prompt) - - if metrics_callback is not None: - prompt_tokens = sum( - [ - check_number_of_tokens( - text=str(p.content), encode_fn=get_default_llm_token_encode() - ) - for p in prompt - ] - ) - - response_tokens = check_number_of_tokens( - text=model_out, encode_fn=get_default_llm_token_encode() - ) - - metrics_callback( - LLMMetricsContainer( - prompt_tokens=prompt_tokens, response_tokens=response_tokens - ) - ) - - return self._qa_handler.process_llm_output(model_out, trimmed_context_docs) + history_str: str, + context_chunks: list[InferenceChunk], + ) -> str: + prompt = self._qa_handler.build_prompt( + query=query, history_str=history_str, context_chunks=context_chunks + ) + return prompt def answer_question_stream( self, - query: str, - context_docs: list[InferenceChunk], + prompt: str, + llm_context_docs: list[InferenceChunk], metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, ) -> AnswerQuestionStreamReturn: - trimmed_context_docs = tokenizer_trim_chunks(context_docs) - prompt = self._qa_handler.build_prompt(query, trimmed_context_docs) tokens_stream = self._llm.stream(prompt) captured_tokens = [] try: for answer_piece in self._qa_handler.process_llm_token_stream( - iter(tokens_stream), trimmed_context_docs + iter(tokens_stream), llm_context_docs ): if ( isinstance(answer_piece, DanswerAnswerPiece) @@ -309,13 +293,8 @@ class QABlock(QAModel): yield StreamingError(error=str(e)) if metrics_callback is not None: - prompt_tokens = sum( - [ - check_number_of_tokens( - text=str(p.content), encode_fn=get_default_llm_token_encode() - ) - for p in prompt - ] + prompt_tokens = check_number_of_tokens( + text=str(prompt), encode_fn=get_default_llm_token_encode() ) response_tokens = check_number_of_tokens( diff --git a/backend/danswer/one_shot_answer/qa_utils.py b/backend/danswer/one_shot_answer/qa_utils.py index ce40f176b2f7..a26cf03ef4f4 100644 --- a/backend/danswer/one_shot_answer/qa_utils.py +++ b/backend/danswer/one_shot_answer/qa_utils.py @@ -1,5 +1,6 @@ import math import re +from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterator from json.decoder import JSONDecodeError @@ -13,7 +14,11 @@ from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import DanswerQuote from danswer.chat.models import DanswerQuotes from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT +from danswer.configs.constants import MessageType +from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.indexing.models import InferenceChunk +from danswer.llm.utils import get_default_llm_token_encode +from danswer.one_shot_answer.models import ThreadMessage from danswer.prompts.constants import ANSWER_PAT from danswer.prompts.constants import QUOTE_PAT from danswer.prompts.constants import UNCERTAINTY_PAT @@ -270,3 +275,41 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]: """Mock streaming by generating the passed in model output, character by character""" for token in model_out: yield token + + +def combine_message_thread( + messages: list[ThreadMessage], + token_limit: int | None = GEN_AI_HISTORY_CUTOFF, + llm_tokenizer: Callable | None = None, +) -> str: + """Used to create a single combined message context from threads""" + message_strs: list[str] = [] + total_token_count = 0 + if llm_tokenizer is None: + llm_tokenizer = get_default_llm_token_encode() + + for message in reversed(messages): + if message.role == MessageType.USER: + role_str = message.role.value.upper() + if message.sender: + role_str += " " + message.sender + else: + # Since other messages might have the user identifying information + # better to use Unknown for symmetry + role_str += " Unknown" + else: + role_str = message.role.value.upper() + + msg_str = f"{role_str}:\n{message.message}" + message_token_count = len(llm_tokenizer(msg_str)) + + if ( + token_limit is not None + and total_token_count + message_token_count > token_limit + ): + break + + message_strs.insert(0, msg_str) + total_token_count += message_token_count + + return "\n\n".join(message_strs) diff --git a/backend/danswer/prompts/constants.py b/backend/danswer/prompts/constants.py index 089f17905bbb..74e488aeb200 100644 --- a/backend/danswer/prompts/constants.py +++ b/backend/danswer/prompts/constants.py @@ -1,6 +1,7 @@ GENERAL_SEP_PAT = "--------------" # Same length as Langchain's separator CODE_BLOCK_PAT = "```\n{}\n```" QUESTION_PAT = "Query:" +FINAL_QUERY_PAT = "Final Query:" THOUGHT_PAT = "Thought:" ANSWER_PAT = "Answer:" ANSWERABLE_PAT = "Answerable:" diff --git a/backend/danswer/prompts/direct_qa_prompts.py b/backend/danswer/prompts/direct_qa_prompts.py index a52680c9f382..ddfdf2e08975 100644 --- a/backend/danswer/prompts/direct_qa_prompts.py +++ b/backend/danswer/prompts/direct_qa_prompts.py @@ -2,24 +2,36 @@ # It is used also for the one shot direct QA flow import json -from danswer.prompts.constants import ANSWER_PAT +from danswer.prompts.constants import FINAL_QUERY_PAT from danswer.prompts.constants import GENERAL_SEP_PAT from danswer.prompts.constants import QUESTION_PAT -from danswer.prompts.constants import QUOTE_PAT from danswer.prompts.constants import THOUGHT_PAT from danswer.prompts.constants import UNCERTAINTY_PAT -QA_HEADER = """ +ONE_SHOT_SYSTEM_PROMPT = """ You are a question answering system that is constantly learning and improving. You can process and comprehend vast amounts of text and utilize this knowledge to provide \ accurate and detailed answers to diverse queries. """.strip() +ONE_SHOT_TASK_PROMPT = """ +Answer the final query below taking into account the context above where relevant. \ +Ignore any provided context that is not relevant to the query. +""".strip() + + +WEAK_MODEL_SYSTEM_PROMPT = """ +Respond to the user query using the following reference document. +""".lstrip() + +WEAK_MODEL_TASK_PROMPT = """ +Answer the user query below based on the reference document above. +""" + REQUIRE_JSON = """ -You ALWAYS responds with only a json containing an answer and quotes that support the answer. -Your responses are as INFORMATIVE and DETAILED as possible. +You ALWAYS responds with ONLY a JSON containing an answer and quotes that support the answer. """.strip() @@ -34,6 +46,21 @@ IMPORTANT: Respond in the same language as my query! """ +CONTEXT_BLOCK = f""" +REFERENCE DOCUMENTS: +{GENERAL_SEP_PAT} +{{context_docs_str}} +{GENERAL_SEP_PAT} +""" + +HISTORY_BLOCK = f""" +CONVERSATION HISTORY: +{GENERAL_SEP_PAT} +{{history_str}} +{GENERAL_SEP_PAT} +""" + + # This has to be doubly escaped due to json containing { } which are also used for format strings EMPTY_SAMPLE_JSON = { "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.", @@ -48,44 +75,22 @@ ANSWER_NOT_FOUND_RESPONSE = f'{{"answer": "{UNCERTAINTY_PAT}", "quotes": []}}' # Default json prompt which can reference multiple docs and provide answer + quotes +# system_like_header is similar to system message, can be user provided or defaults to QA_HEADER +# context/history blocks are for context documents and conversation history, they can be blank +# task prompt is the task message of the prompt, can be blank, there is no default JSON_PROMPT = f""" -{QA_HEADER} +{{system_prompt}} {REQUIRE_JSON} +{{context_block}}{{history_block}}{{task_prompt}} -CONTEXT: -{GENERAL_SEP_PAT} -{{context_docs_str}} -{GENERAL_SEP_PAT} - -SAMPLE_RESPONSE: +SAMPLE RESPONSE: ``` {{{json.dumps(EMPTY_SAMPLE_JSON)}}} ``` -{QUESTION_PAT.upper()} {{user_query}} -{JSON_HELPFUL_HINT} -{{language_hint_or_none}} -""".strip() +{FINAL_QUERY_PAT.upper()} +{{user_query}} -# Default chain-of-thought style json prompt which uses multiple docs -# This one has a section for the LLM to output some non-answer "thoughts" -# COT (chain-of-thought) flow basically -COT_PROMPT = f""" -{QA_HEADER} - -CONTEXT: -{GENERAL_SEP_PAT} -{{context_docs_str}} -{GENERAL_SEP_PAT} - -You MUST respond in the following format: -``` -{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer. - -{{{json.dumps(EMPTY_SAMPLE_JSON)}}} -``` - -{QUESTION_PAT.upper()} {{user_query}} {JSON_HELPFUL_HINT} {{language_hint_or_none}} """.strip() @@ -94,35 +99,17 @@ You MUST respond in the following format: # For weak LLM which only takes one chunk and cannot output json # Also not requiring quotes as it tends to not work WEAK_LLM_PROMPT = f""" -Respond to the user query using the following reference document. - -Reference Document: -{GENERAL_SEP_PAT} -{{single_reference_doc}} -{GENERAL_SEP_PAT} - -Answer the user query below based on the reference document above. +{{system_prompt}} +{{context_block}} +{{task_prompt}} {QUESTION_PAT.upper()} {{user_query}} """.strip() -# For weak CHAT LLM which takes one chunk and cannot output json -# The next message should have the user query -# Note, no flow/config currently uses this one -WEAK_CHAT_LLM_PROMPT = f""" -You are a question answering assistant -Respond to the user query with an "{ANSWER_PAT}" section and \ -as many "{QUOTE_PAT}" sections as needed to support the answer. -Answer the user query based on the following document: - -{{first_chunk_content}} -""".strip() - - -# Parameterized prompt which allows the user to specify their -# own system / task prompt +# This is only for visualization for the users to specify their own prompts +# The actual flow does not work like this PARAMATERIZED_PROMPT = f""" {{system_prompt}} @@ -147,6 +134,31 @@ RESPONSE: """.strip() +# CURRENTLY DISABLED, CANNOT USE THIS ONE +# Default chain-of-thought style json prompt which uses multiple docs +# This one has a section for the LLM to output some non-answer "thoughts" +# COT (chain-of-thought) flow basically +COT_PROMPT = f""" +{ONE_SHOT_SYSTEM_PROMPT} + +CONTEXT: +{GENERAL_SEP_PAT} +{{context_docs_str}} +{GENERAL_SEP_PAT} + +You MUST respond in the following format: +``` +{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer. + +{{{json.dumps(EMPTY_SAMPLE_JSON)}}} +``` + +{QUESTION_PAT.upper()} {{user_query}} +{JSON_HELPFUL_HINT} +{{language_hint_or_none}} +""".strip() + + # User the following for easy viewing of prompts if __name__ == "__main__": print(JSON_PROMPT) # Default prompt used in the Danswer UI flow diff --git a/backend/danswer/search/request_preprocessing.py b/backend/danswer/search/request_preprocessing.py index 9af5da12451e..ee4c1353e0d8 100644 --- a/backend/danswer/search/request_preprocessing.py +++ b/backend/danswer/search/request_preprocessing.py @@ -3,8 +3,8 @@ from sqlalchemy.orm import Session from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER +from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW -from danswer.configs.model_configs import SKIP_RERANKING from danswer.db.models import Persona from danswer.db.models import User from danswer.search.access_filters import build_access_filters_for_user @@ -31,7 +31,7 @@ def retrieval_preprocessing( bypass_acl: bool = False, include_query_intent: bool = True, skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW, - skip_rerank_non_realtime: bool = SKIP_RERANKING, + skip_rerank_non_realtime: bool = not ENABLE_RERANKING_ASYNC_FLOW, disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION, disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER, diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py index ee4b5d8cf9bd..1be27d545478 100644 --- a/backend/danswer/secondary_llm_flows/query_expansion.py +++ b/backend/danswer/secondary_llm_flows/query_expansion.py @@ -59,6 +59,22 @@ def multilingual_query_expansion( return query_rephrases +def get_contextual_rephrase_messages( + question: str, + history_str: str, +) -> list[dict[str, str]]: + messages = [ + { + "role": "user", + "content": HISTORY_QUERY_REPHRASE.format( + question=question, chat_history=history_str + ), + }, + ] + + return messages + + def history_based_query_rephrase( query_message: ChatMessage, history: list[ChatMessage], @@ -66,21 +82,6 @@ def history_based_query_rephrase( size_heuristic: int = 200, punctuation_heuristic: int = 10, ) -> str: - def _get_history_rephrase_messages( - question: str, - history_str: str, - ) -> list[dict[str, str]]: - messages = [ - { - "role": "user", - "content": HISTORY_QUERY_REPHRASE.format( - question=question, chat_history=history_str - ), - }, - ] - - return messages - user_query = cast(str, query_message.message) if not user_query: @@ -98,7 +99,38 @@ def history_based_query_rephrase( history_str = combine_message_chain(history) - prompt_msgs = _get_history_rephrase_messages( + prompt_msgs = get_contextual_rephrase_messages( + question=user_query, history_str=history_str + ) + + if llm is None: + llm = get_default_llm() + + filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs) + rephrased_query = llm.invoke(filled_llm_prompt) + + logger.debug(f"Rephrased combined query: {rephrased_query}") + + return rephrased_query + + +def thread_based_query_rephrase( + user_query: str, + history_str: str, + llm: LLM | None = None, + size_heuristic: int = 200, + punctuation_heuristic: int = 10, +) -> str: + if not history_str: + return user_query + + if len(user_query) >= size_heuristic: + return user_query + + if count_punctuation(user_query) >= punctuation_heuristic: + return user_query + + prompt_msgs = get_contextual_rephrase_messages( question=user_query, history_str=history_str ) diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index b6879bd1e453..fc3d2ae70630 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -15,7 +15,7 @@ from danswer.db.chat import upsert_persona from danswer.db.document_set import get_document_sets_by_ids from danswer.db.engine import get_session from danswer.db.models import User -from danswer.one_shot_answer.qa_block import PromptBasedQAHandler +from danswer.one_shot_answer.qa_block import build_dummy_prompt from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import PersonaSnapshot from danswer.server.features.persona.models import PromptTemplateResponse @@ -149,9 +149,11 @@ def build_final_template_prompt( _: User | None = Depends(current_user), ) -> PromptTemplateResponse: return PromptTemplateResponse( - final_prompt_template=PromptBasedQAHandler( - system_prompt=system_prompt, task_prompt=task_prompt - ).build_dummy_prompt(retrieval_disabled=retrieval_disabled) + final_prompt_template=build_dummy_prompt( + system_prompt=system_prompt, + task_prompt=task_prompt, + retrieval_disabled=retrieval_disabled, + ) ) diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index baf37df7e576..ad5dd6f6e964 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -163,9 +163,8 @@ def get_answer_with_quote( user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> StreamingResponse: - logger.info( - f"Received query for one shot answer with quotes: {query_request.query}" - ) + query = query_request.messages[0].message + logger.info(f"Received query for one shot answer with quotes: {query}") packets = stream_one_shot_answer( query_req=query_request, user=user, db_session=db_session ) diff --git a/backend/tests/regression/answer_quality/eval_direct_qa.py b/backend/tests/regression/answer_quality/eval_direct_qa.py index 7ea908889035..963676e078ea 100644 --- a/backend/tests/regression/answer_quality/eval_direct_qa.py +++ b/backend/tests/regression/answer_quality/eval_direct_qa.py @@ -9,9 +9,11 @@ import yaml from sqlalchemy.orm import Session from danswer.chat.models import LLMMetricsContainer +from danswer.configs.constants import MessageType from danswer.db.engine import get_sqlalchemy_engine from danswer.one_shot_answer.answer_question import get_one_shot_answer from danswer.one_shot_answer.models import DirectQARequest +from danswer.one_shot_answer.models import ThreadMessage from danswer.search.models import IndexFilters from danswer.search.models import OptionalSearchSetting from danswer.search.models import RerankMetricsContainer @@ -84,8 +86,10 @@ def get_answer_for_question( access_control_list=None, ) + messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)] + new_message_request = DirectQARequest( - query=query, + messages=messages, prompt_id=0, persona_id=0, retrieval_options=RetrievalDetails( diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 32e4ddd1566d..263253fcbf13 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -121,7 +121,7 @@ services: - SIM_SCORE_RANGE_HIGH=${SIM_SCORE_RANGE_HIGH:-} - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-} - - SKIP_RERANKING=${SKIP_RERANKING:-} + - ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-} - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-} - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-} diff --git a/deployment/docker_compose/env.multilingual.template b/deployment/docker_compose/env.multilingual.template index 2083bddd29cc..5b310b742900 100644 --- a/deployment/docker_compose/env.multilingual.template +++ b/deployment/docker_compose/env.multilingual.template @@ -19,9 +19,6 @@ NORMALIZE_EMBEDDINGS="True" SIM_SCORE_RANGE_LOW="0.6" SIM_SCORE_RANGE_LOW="0.8" -# No recent multilingual reranking models small enough to run on CPU, so turning it off -SKIP_RERANKING="True" - # Use LLM to determine if chunks are relevant to the query # may not work well for languages that do not have much training data in the LLM training set DISABLE_LLM_CHUNK_FILTER="True" diff --git a/web/src/lib/search/streamingQa.ts b/web/src/lib/search/streamingQa.ts index 18ee2d7df494..f51a226523cb 100644 --- a/web/src/lib/search/streamingQa.ts +++ b/web/src/lib/search/streamingQa.ts @@ -32,10 +32,17 @@ export const searchRequestStreamed = async ({ let relevantDocuments: DanswerDocument[] | null = null; try { const filters = buildFilters(sources, documentSets, timeRange); + + const threadMessage = { + message: query, + sender: null, + role: "user" + }; + const response = await fetch("/api/query/stream-answer-with-quote", { method: "POST", body: JSON.stringify({ - query, + messages: [threadMessage], persona_id: persona.id, prompt_id: persona.id === 0 ? null : persona.prompts[0]?.id, retrieval_options: {