DanswerBot Chat (#831)

Yuhong Sun
2023-12-17 18:18:48 -08:00
committed by GitHub
parent c7a91b1819
commit 5957b888a5
23 changed files with 526 additions and 385 deletions

View File

@@ -34,9 +34,9 @@ MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
 # Cross Encoder Settings
-# This following setting is for non-real-time-flows
-SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
-# This one is for real-time (streaming) flows
+ENABLE_RERANKING_ASYNC_FLOW = (
+    os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
+)
 ENABLE_RERANKING_REAL_TIME_FLOW = (
     os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
 )

View File

@@ -214,7 +214,9 @@ def build_qa_response_blocks(
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓" text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
) )
else: else:
answer_block = SectionBlock(text=remove_slack_text_interactions(answer)) answer_processed = remove_slack_text_interactions(answer)
answer_processed = answer_processed.encode("utf-8").decode("unicode_escape")
answer_block = SectionBlock(text=answer_processed)
if quotes: if quotes:
quotes_blocks = build_quotes_block(quotes) quotes_blocks = build_quotes_block(quotes)

View File

@@ -86,8 +86,14 @@ def handle_message(
     Query thrown out by filters due to config does not count as a failure that should be notified
     Danswer failing to answer/retrieve docs does count and should be notified
     """
-    msg = message_info.msg_content
     channel = message_info.channel_to_respond
+    logger = cast(
+        logging.Logger,
+        ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
+    )
+
+    messages = message_info.thread_messages
     message_ts_to_respond_to = message_info.msg_to_respond
     sender_id = message_info.sender
     bipass_filters = message_info.bipass_filters
@@ -95,11 +101,6 @@ def handle_message(
     engine = get_sqlalchemy_engine()

-    logger = cast(
-        logging.Logger,
-        ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
-    )
     document_set_names: list[str] | None = None
     persona = channel_config.persona if channel_config else None
     prompt = None
@@ -133,7 +134,7 @@ def handle_message(
         if (
             "questionmark_prefilter" in channel_conf["answer_filters"]
-            and "?" not in msg
+            and "?" not in messages[-1].message
         ):
             logger.info(
                 "Skipping message since it does not contain a question mark"
@@ -223,7 +224,7 @@ def handle_message(
     # This includes throwing out answer via reflexion
     answer = _get_answer(
         DirectQARequest(
-            query=msg,
+            messages=messages,
            prompt_id=prompt.id if prompt else None,
             persona_id=persona.id if persona is not None else 0,
             retrieval_options=retrieval_details,
@@ -275,7 +276,9 @@ def handle_message(
     top_docs = retrieval_info.top_documents
     if not top_docs and not should_respond_even_with_no_docs:
-        logger.error(f"Unable to answer question: '{msg}' - no documents found")
+        logger.error(
+            f"Unable to answer question: '{answer.rephrase}' - no documents found"
+        )
         # Optionally, respond in thread with the error message
         # Used primarily for debugging purposes
         if should_respond_with_error_msgs:
@@ -296,7 +299,7 @@ def handle_message(
         return True

     # If called with the DanswerBot slash command, the question is lost so we have to reshow it
-    restate_question_block = get_restate_blocks(msg, is_bot_msg)
+    restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)

     answer_blocks = build_qa_response_blocks(
         message_id=answer.chat_message_id,

View File

@@ -1,4 +1,3 @@
-import re
 import time
 from threading import Event
 from typing import Any
@@ -10,9 +9,10 @@ from slack_sdk.socket_mode.request import SocketModeRequest
 from slack_sdk.socket_mode.response import SocketModeResponse
 from sqlalchemy.orm import Session

+from danswer.configs.constants import MessageType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
 from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
-from danswer.configs.model_configs import SKIP_RERANKING
+from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
 from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
 from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
 from danswer.danswerbot.slack.handlers.handle_feedback import handle_slack_feedback
@@ -22,9 +22,13 @@ from danswer.danswerbot.slack.tokens import fetch_tokens
 from danswer.danswerbot.slack.utils import ChannelIdAdapter
 from danswer.danswerbot.slack.utils import decompose_block_id
 from danswer.danswerbot.slack.utils import get_channel_name_from_id
+from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
+from danswer.danswerbot.slack.utils import read_slack_thread
+from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
 from danswer.danswerbot.slack.utils import respond_in_thread
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.one_shot_answer.models import ThreadMessage
 from danswer.search.search_nlp_models import warm_up_models
 from danswer.server.manage.models import SlackBotTokens
 from danswer.utils.logger import setup_logger
@@ -63,7 +67,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
         return False

     if event_type == "message":
-        bot_tag_id = client.web_client.auth_test().get("user_id")
+        bot_tag_id = get_danswer_bot_app_id(client.web_client)
         # DMs with the bot don't pick up the @DanswerBot so we have to keep the
         # caught events_api
         if bot_tag_id and bot_tag_id in msg and event.get("channel_type") != "im":
@@ -87,8 +91,14 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
         message_ts = event.get("ts")
         thread_ts = event.get("thread_ts")
         # Pick the root of the thread (if a thread exists)
-        if thread_ts and message_ts != thread_ts:
-            channel_specific_logger.info(
+        # Can respond in thread if it's an "im" directly to Danswer or @DanswerBot is tagged
+        if (
+            thread_ts
+            and message_ts != thread_ts
+            and event_type != "app_mention"
+            and event.get("channel_type") != "im"
+        ):
+            channel_specific_logger.debug(
                 "Skipping message since it is not the root of a thread"
             )
             return False
@@ -156,18 +166,25 @@ def build_request_details(
         tagged = event.get("type") == "app_mention"
         message_ts = event.get("ts")
         thread_ts = event.get("thread_ts")
-        bot_tag_id = client.web_client.auth_test().get("user_id")
-        # Might exist even if not tagged, specifically in the case of @DanswerBot
-        # in DanswerBot DM channel
-        msg = re.sub(rf"<@{bot_tag_id}>\s", "", msg)
+        msg = remove_danswer_bot_tag(msg, client=client.web_client)

         if tagged:
             logger.info("User tagged DanswerBot")

+        if thread_ts != message_ts and thread_ts is not None:
+            thread_messages = read_slack_thread(
+                channel=channel, thread=thread_ts, client=client.web_client
+            )
+        else:
+            thread_messages = [
+                ThreadMessage(message=msg, sender=None, role=MessageType.USER)
+            ]
+
         return SlackMessageInfo(
-            msg_content=msg,
+            thread_messages=thread_messages,
             channel_to_respond=channel,
-            msg_to_respond=cast(str, thread_ts or message_ts),
+            msg_to_respond=cast(str, message_ts or thread_ts),
             sender=event.get("user") or None,
             bipass_filters=tagged,
             is_bot_msg=False,
@@ -178,8 +195,10 @@ def build_request_details(
         msg = req.payload["text"]
         sender = req.payload["user_id"]

+        single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
+
         return SlackMessageInfo(
-            msg_content=msg,
+            thread_messages=[single_msg],
             channel_to_respond=channel,
             msg_to_respond=None,
             sender=sender,
@@ -297,7 +316,7 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None:
 # NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
 # without issue.
 if __name__ == "__main__":
-    warm_up_models(skip_cross_encoders=SKIP_RERANKING)
+    warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW)

     slack_bot_tokens: SlackBotTokens | None = None
     socket_client: SocketModeClient | None = None

View File

@@ -1,8 +1,10 @@
 from pydantic import BaseModel

+from danswer.one_shot_answer.models import ThreadMessage
+

 class SlackMessageInfo(BaseModel):
-    msg_content: str
+    thread_messages: list[ThreadMessage]
     channel_to_respond: str
     msg_to_respond: str | None
     sender: str | None
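
For illustration only (not part of the commit): a minimal sketch of how the reworked model might be populated for a single, non-threaded Slack message. The import path and all field values here are assumptions.

    from danswer.configs.constants import MessageType
    from danswer.danswerbot.slack.models import SlackMessageInfo  # assumed module path
    from danswer.one_shot_answer.models import ThreadMessage

    # A one-message "thread" built from a plain channel message
    info = SlackMessageInfo(
        thread_messages=[
            ThreadMessage(
                message="How do I rotate my API key?", sender=None, role=MessageType.USER
            )
        ],
        channel_to_respond="C0123456789",  # hypothetical channel ID
        msg_to_respond="1702850000.000100",  # hypothetical message timestamp
        sender="U0123456789",  # hypothetical user ID
        bipass_filters=False,
        is_bot_msg=False,
    )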

View File

@@ -13,17 +13,34 @@ from slack_sdk.models.blocks import Block
 from slack_sdk.models.metadata import Metadata

 from danswer.configs.constants import ID_SEPARATOR
+from danswer.configs.constants import MessageType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
 from danswer.connectors.slack.utils import make_slack_api_rate_limited
 from danswer.connectors.slack.utils import SlackTextCleaner
 from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
 from danswer.danswerbot.slack.tokens import fetch_tokens
+from danswer.one_shot_answer.models import ThreadMessage
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import replace_whitespaces_w_space

 logger = setup_logger()

+DANSWER_BOT_APP_ID: str | None = None
+
+
+def get_danswer_bot_app_id(web_client: WebClient) -> Any:
+    global DANSWER_BOT_APP_ID
+    if DANSWER_BOT_APP_ID is None:
+        DANSWER_BOT_APP_ID = web_client.auth_test().get("user_id")
+    return DANSWER_BOT_APP_ID
+
+
+def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
+    bot_tag_id = get_danswer_bot_app_id(web_client=client)
+    return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
+
+
 class ChannelIdAdapter(logging.LoggerAdapter):
     """This is used to add the channel ID to all log messages
     emitted in this file"""
@@ -199,3 +216,57 @@ def fetch_userids_from_emails(user_emails: list[str], client: WebClient) -> list
     )

     return user_ids
+
+
+def fetch_user_semantic_id_from_id(user_id: str, client: WebClient) -> str | None:
+    response = client.users_info(user=user_id)
+    if not response["ok"]:
+        return None
+
+    user: dict = cast(dict[Any, dict], response.data).get("user", {})
+
+    return (
+        user.get("real_name")
+        or user.get("name")
+        or user.get("profile", {}).get("email")
+    )
+
+
+def read_slack_thread(
+    channel: str, thread: str, client: WebClient
+) -> list[ThreadMessage]:
+    thread_messages: list[ThreadMessage] = []
+    response = client.conversations_replies(channel=channel, ts=thread)
+    replies = cast(dict, response.data).get("messages", [])
+    for reply in replies:
+        if "user" in reply and "bot_id" not in reply:
+            message = remove_danswer_bot_tag(reply["text"], client=client)
+            user_sem_id = fetch_user_semantic_id_from_id(reply["user"], client)
+            message_type = MessageType.USER
+        else:
+            self_app_id = get_danswer_bot_app_id(client)
+
+            # Only include bot messages from Danswer, other bots are not taken in as context
+            if self_app_id != reply.get("user"):
+                continue
+
+            blocks = reply["blocks"]
+            if len(blocks) <= 1:
+                continue
+
+            # The useful block is the second one after the header block that says AI Answer
+            message = reply["blocks"][1]["text"]["text"]
+
+            if message.startswith("_Filters"):
+                if len(blocks) <= 2:
+                    continue
+
+                message = reply["blocks"][2]["text"]["text"]
+
+            user_sem_id = "Assistant"
+            message_type = MessageType.ASSISTANT
+
+        thread_messages.append(
+            ThreadMessage(message=message, sender=user_sem_id, role=message_type)
+        )
+
+    return thread_messages
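
A rough usage sketch, not part of the commit: reading a Slack thread into ThreadMessage objects. The token, channel ID, and timestamp are hypothetical and a real Slack workspace is assumed.

    from slack_sdk import WebClient

    from danswer.danswerbot.slack.utils import read_slack_thread

    client = WebClient(token="xoxb-...")  # hypothetical bot token
    thread_messages = read_slack_thread(
        channel="C0123456789", thread="1702850000.000100", client=client
    )
    for msg in thread_messages:
        print(f"{msg.role}: {msg.message[:80]}")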

View File

@@ -47,7 +47,6 @@ from danswer.db.credentials import create_initial_public_credential
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.document_index.factory import get_default_document_index
 from danswer.llm.factory import get_default_llm
-from danswer.one_shot_answer.factory import get_default_qa_model
 from danswer.search.search_nlp_models import warm_up_models
 from danswer.server.danswer_api.ingestion import get_danswer_api_key
 from danswer.server.danswer_api.ingestion import router as danswer_api_router
@@ -261,7 +260,6 @@ def get_application() -> FastAPI:
         # This is for the LLM, most LLMs will not need warming up
         get_default_llm().log_model_configs()
-        get_default_qa_model().warm_up_model()

         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
         nltk.download("stopwords", quiet=True)

View File

@@ -28,6 +28,8 @@ from danswer.llm.utils import get_default_llm_token_encode
 from danswer.one_shot_answer.factory import get_question_answer_model
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
+from danswer.one_shot_answer.models import QueryRephrase
+from danswer.one_shot_answer.qa_utils import combine_message_thread
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.models import SavedSearchDoc
@@ -35,6 +37,7 @@ from danswer.search.request_preprocessing import retrieval_preprocessing
 from danswer.search.search_runner import chunks_to_search_docs
 from danswer.search.search_runner import full_chunk_search_generator
 from danswer.secondary_llm_flows.answer_validation import get_answer_validity
+from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
 from danswer.server.query_and_chat.models import ChatMessageDetail
 from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
@@ -58,7 +61,8 @@ def stream_answer_objects(
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
     llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
 ) -> Iterator[
-    QADocsResponse
+    QueryRephrase
+    | QADocsResponse
     | LLMRelevanceFilterResponse
     | DanswerAnswerPiece
     | DanswerQuotes
@@ -73,6 +77,8 @@ def stream_answer_objects(
     4. [always] Details on the final AI response message that is created
     """
     user_id = user.id if user is not None else None
+    query_msg = query_req.messages[-1]
+    history = query_req.messages[:-1]

     chat_session = create_chat_session(
         db_session=db_session,
@@ -90,24 +96,20 @@ def stream_answer_objects(
         chat_session_id=chat_session.id, db_session=db_session
     )

-    # Create the first User query message
-    new_user_message = create_new_chat_message(
-        chat_session_id=chat_session.id,
-        parent_message=root_message,
-        prompt_id=query_req.prompt_id,
-        message=query_req.query,
-        token_count=len(llm_tokenizer(query_req.query)),
-        message_type=MessageType.USER,
-        db_session=db_session,
-        commit=True,
+    history_str = combine_message_thread(history)
+
+    rephrased_query = thread_based_query_rephrase(
+        user_query=query_msg.message,
+        history_str=history_str,
     )
+    yield QueryRephrase(rephrased_query=rephrased_query)

     (
         retrieval_request,
         predicted_search_type,
         predicted_flow,
     ) = retrieval_preprocessing(
-        query=query_req.query,
+        query=rephrased_query,
         retrieval_details=query_req.retrieval_options,
         persona=chat_session.persona,
         user=user,
@@ -190,9 +192,25 @@ def stream_answer_objects(
         llm_version=llm_override,
     )

+    full_prompt_str = qa_model.build_prompt(
+        query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
+    )
+
+    # Create the first User query message
+    new_user_message = create_new_chat_message(
+        chat_session_id=chat_session.id,
+        parent_message=root_message,
+        prompt_id=query_req.prompt_id,
+        message=full_prompt_str,
+        token_count=len(llm_tokenizer(full_prompt_str)),
+        message_type=MessageType.USER,
+        db_session=db_session,
+        commit=True,
+    )
+
     response_packets = qa_model.answer_question_stream(
-        query=query_req.query,
-        context_docs=llm_chunks,
+        prompt=full_prompt_str,
+        llm_context_docs=llm_chunks,
         metrics_callback=llm_metrics_callback,
     )
@@ -272,6 +290,8 @@ def get_one_shot_answer(
answer = "" answer = ""
for packet in results: for packet in results:
if isinstance(packet, QueryRephrase):
qa_response.rephrase = packet.rephrased_query
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece answer += packet.answer_piece
elif isinstance(packet, QADocsResponse): elif isinstance(packet, QADocsResponse):
@@ -289,6 +309,11 @@ def get_one_shot_answer(
     qa_response.answer = answer

     if enable_reflexion:
-        qa_response.answer_valid = get_answer_validity(query_req.query, answer)
+        # Because follow up messages are explicitly tagged, we don't need to verify the answer
+        if len(query_req.messages) == 1:
+            first_query = query_req.messages[0].message
+            qa_response.answer_valid = get_answer_validity(first_query, answer)
+        else:
+            qa_response.answer_valid = True

     return qa_response

View File

@@ -3,98 +3,42 @@ from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.db.models import Prompt
 from danswer.llm.factory import get_default_llm
 from danswer.one_shot_answer.interfaces import QAModel
-from danswer.one_shot_answer.qa_block import PromptBasedQAHandler
 from danswer.one_shot_answer.qa_block import QABlock
 from danswer.one_shot_answer.qa_block import QAHandler
 from danswer.one_shot_answer.qa_block import SingleMessageQAHandler
-from danswer.one_shot_answer.qa_block import SingleMessageScratchpadHandler
 from danswer.one_shot_answer.qa_block import WeakLLMQAHandler
 from danswer.utils.logger import setup_logger

 logger = setup_logger()


-def get_default_qa_handler(
-    chain_of_thought: bool = False,
-    user_selection: str | None = QA_PROMPT_OVERRIDE,
-) -> QAHandler:
-    if user_selection:
-        if user_selection.lower() == "default":
-            return SingleMessageQAHandler()
-        if user_selection.lower() == "cot":
-            return SingleMessageScratchpadHandler()
-        if user_selection.lower() == "weak":
-            return WeakLLMQAHandler()
-        raise ValueError("Invalid Question-Answering prompt selected")
-
-    if chain_of_thought:
-        return SingleMessageScratchpadHandler()
-
-    return SingleMessageQAHandler()
-
-
-def get_default_qa_model(
-    api_key: str | None = None,
-    timeout: int = QA_TIMEOUT,
-    chain_of_thought: bool = False,
-) -> QAModel:
-    llm = get_default_llm(api_key=api_key, timeout=timeout)
-    qa_handler = get_default_qa_handler(chain_of_thought=chain_of_thought)
-
-    return QABlock(
-        llm=llm,
-        qa_handler=qa_handler,
-    )
-
-
-def get_prompt_qa_model(
-    prompt: Prompt,
-    api_key: str | None = None,
-    timeout: int = QA_TIMEOUT,
-    llm_version: str | None = None,
-) -> QAModel:
-    return QABlock(
-        llm=get_default_llm(
-            api_key=api_key,
-            timeout=timeout,
-            gen_ai_model_version_override=llm_version,
-        ),
-        qa_handler=PromptBasedQAHandler(
-            system_prompt=prompt.system_prompt, task_prompt=prompt.task_prompt
-        ),
-    )
-
-
 def get_question_answer_model(
     prompt: Prompt | None,
     api_key: str | None = None,
     timeout: int = QA_TIMEOUT,
     chain_of_thought: bool = False,
     llm_version: str | None = None,
+    qa_model_version: str | None = QA_PROMPT_OVERRIDE,
 ) -> QAModel:
-    if prompt is None and llm_version is not None:
-        raise RuntimeError(
-            "Cannot specify llm version for QA model without providing prompt. "
-            "This flow is only intended for flows with a specified Persona/Prompt."
-        )
-
-    if prompt is not None and chain_of_thought:
-        raise RuntimeError(
-            "Cannot choose COT prompt with a customized Prompt object. "
-            "User can prompt the model to output COT themselves if they want."
-        )
-
-    if prompt is not None:
-        return get_prompt_qa_model(
-            prompt=prompt,
-            api_key=api_key,
-            timeout=timeout,
-            llm_version=llm_version,
-        )
-
-    return get_default_qa_model(
+    if chain_of_thought:
+        raise NotImplementedError("COT has been disabled")
+
+    system_prompt = prompt.system_prompt if prompt is not None else None
+    task_prompt = prompt.task_prompt if prompt is not None else None
+
+    llm = get_default_llm(
         api_key=api_key,
         timeout=timeout,
-        chain_of_thought=chain_of_thought,
+        gen_ai_model_version_override=llm_version,
     )
+
+    if qa_model_version == "weak":
+        qa_handler: QAHandler = WeakLLMQAHandler(
+            system_prompt=system_prompt, task_prompt=task_prompt
+        )
+    else:
+        qa_handler = SingleMessageQAHandler(
+            system_prompt=system_prompt, task_prompt=task_prompt
+        )
+
+    return QABlock(llm=llm, qa_handler=qa_handler)

View File

@@ -1,37 +1,26 @@
 import abc
 from collections.abc import Callable

-from danswer.chat.models import AnswerQuestionReturn
 from danswer.chat.models import AnswerQuestionStreamReturn
 from danswer.chat.models import LLMMetricsContainer
 from danswer.indexing.models import InferenceChunk


 class QAModel:
-    @property
-    def requires_api_key(self) -> bool:
-        """Is this model protected by security features
-        Does it need an api key to access the model for inference"""
-        return True
-
-    def warm_up_model(self) -> None:
-        """This is called during server start up to load the models into memory
-        pass if model is accessed via API"""
-
     @abc.abstractmethod
-    def answer_question(
+    def build_prompt(
         self,
         query: str,
-        context_docs: list[InferenceChunk],
-        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
-    ) -> AnswerQuestionReturn:
+        history_str: str,
+        context_chunks: list[InferenceChunk],
+    ) -> str:
         raise NotImplementedError

     @abc.abstractmethod
     def answer_question_stream(
         self,
-        query: str,
-        context_docs: list[InferenceChunk],
+        prompt: str,
+        llm_context_docs: list[InferenceChunk],
         metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
     ) -> AnswerQuestionStreamReturn:
         raise NotImplementedError

View File

@@ -5,11 +5,22 @@ from pydantic import root_validator
 from danswer.chat.models import DanswerQuotes
 from danswer.chat.models import QADocsResponse
+from danswer.configs.constants import MessageType
 from danswer.search.models import RetrievalDetails


+class QueryRephrase(BaseModel):
+    rephrased_query: str
+
+
+class ThreadMessage(BaseModel):
+    message: str
+    sender: str | None
+    role: MessageType
+
+
 class DirectQARequest(BaseModel):
-    query: str
+    messages: list[ThreadMessage]
     prompt_id: int | None
     persona_id: int
     retrieval_options: RetrievalDetails
@@ -35,6 +46,7 @@ class DirectQARequest(BaseModel):
 class OneShotQAResponse(BaseModel):
     # This is built piece by piece, any of these can be None as the flow could break
     answer: str | None = None
+    rephrase: str | None = None
     quotes: DanswerQuotes | None = None
     docs: QADocsResponse | None = None
     llm_chunks_indices: list[int] | None = None
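
As a rough illustration (not from the commit) of the new request shape, a multi-turn thread can be packed into a single DirectQARequest like this; the conversation content is invented and the example assumes RetrievalDetails can be built from its defaults.

    from danswer.configs.constants import MessageType
    from danswer.one_shot_answer.models import DirectQARequest, ThreadMessage
    from danswer.search.models import RetrievalDetails

    request = DirectQARequest(
        messages=[
            ThreadMessage(message="How do we rotate service tokens?", sender="Alice", role=MessageType.USER),
            ThreadMessage(message="Tokens are rotated via the admin panel.", sender="Assistant", role=MessageType.ASSISTANT),
            ThreadMessage(message="And how often should we do it?", sender="Alice", role=MessageType.USER),
        ],
        prompt_id=None,
        persona_id=0,
        retrieval_options=RetrievalDetails(),  # assumes default retrieval settings are acceptable
    )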

View File

@@ -4,11 +4,7 @@ from collections.abc import Callable
 from collections.abc import Iterator
 from typing import cast

-from langchain.schema.messages import BaseMessage
-from langchain.schema.messages import HumanMessage
-
 from danswer.chat.chat_utils import build_context_str
-from danswer.chat.models import AnswerQuestionReturn
 from danswer.chat.models import AnswerQuestionStreamReturn
 from danswer.chat.models import DanswerAnswer
 from danswer.chat.models import DanswerAnswerPiece
@@ -21,16 +17,21 @@ from danswer.indexing.models import InferenceChunk
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import check_number_of_tokens
 from danswer.llm.utils import get_default_llm_token_encode
-from danswer.llm.utils import tokenizer_trim_chunks
 from danswer.one_shot_answer.interfaces import QAModel
 from danswer.one_shot_answer.qa_utils import process_answer
 from danswer.one_shot_answer.qa_utils import process_model_tokens
+from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
 from danswer.prompts.direct_qa_prompts import COT_PROMPT
+from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
 from danswer.prompts.direct_qa_prompts import JSON_PROMPT
 from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
+from danswer.prompts.direct_qa_prompts import ONE_SHOT_SYSTEM_PROMPT
+from danswer.prompts.direct_qa_prompts import ONE_SHOT_TASK_PROMPT
 from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
 from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
 from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
+from danswer.prompts.direct_qa_prompts import WEAK_MODEL_SYSTEM_PROMPT
+from danswer.prompts.direct_qa_prompts import WEAK_MODEL_TASK_PROMPT
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_up_code_blocks
 from danswer.utils.text_processing import escape_newlines
@@ -39,12 +40,6 @@ logger = setup_logger()
 class QAHandler(abc.ABC):
-    @abc.abstractmethod
-    def build_prompt(
-        self, query: str, context_chunks: list[InferenceChunk]
-    ) -> list[BaseMessage]:
-        raise NotImplementedError
-
     @property
     @abc.abstractmethod
     def is_json_output(self) -> bool:
@@ -54,12 +49,14 @@ class QAHandler(abc.ABC):
         finetuned to recognize this."""
         raise NotImplementedError

-    def process_llm_output(
-        self, model_output: str, context_chunks: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, DanswerQuotes]:
-        return process_answer(
-            model_output, context_chunks, is_json_prompt=self.is_json_output
-        )
+    @abc.abstractmethod
+    def build_prompt(
+        self,
+        query: str,
+        history_str: str,
+        context_chunks: list[InferenceChunk],
+    ) -> str:
+        raise NotImplementedError

     def process_llm_token_stream(
         self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
@@ -78,70 +75,137 @@ class WeakLLMQAHandler(QAHandler):
     output format.
     """

+    def __init__(
+        self,
+        system_prompt: str | None,
+        task_prompt: str | None,
+    ) -> None:
+        if not system_prompt and not task_prompt:
+            self.system_prompt = WEAK_MODEL_SYSTEM_PROMPT
+            self.task_prompt = WEAK_MODEL_TASK_PROMPT
+        else:
+            self.system_prompt = system_prompt or ""
+            self.task_prompt = task_prompt or ""
+
+        self.task_prompt = self.task_prompt.rstrip()
+        if self.task_prompt and self.task_prompt[0] != "\n":
+            self.task_prompt = "\n" + self.task_prompt
+
     @property
     def is_json_output(self) -> bool:
         return False

     def build_prompt(
-        self, query: str, context_chunks: list[InferenceChunk]
-    ) -> list[BaseMessage]:
-        message = WEAK_LLM_PROMPT.format(
-            user_query=query, single_reference_doc=context_chunks[0].content
-        )
-        return [HumanMessage(content=message)]
+        self,
+        query: str,
+        history_str: str,
+        context_chunks: list[InferenceChunk],
+    ) -> str:
+        context_block = ""
+        if context_chunks:
+            context_block = CONTEXT_BLOCK.format(
+                context_docs_str=context_chunks[0].content
+            )
+
+        prompt_str = WEAK_LLM_PROMPT.format(
+            system_prompt=self.system_prompt,
+            context_block=context_block,
+            task_prompt=self.task_prompt,
+            user_query=query,
+        )
+        return prompt_str


 class SingleMessageQAHandler(QAHandler):
+    def __init__(
+        self,
+        system_prompt: str | None,
+        task_prompt: str | None,
+        use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
+    ) -> None:
+        self.use_language_hint = use_language_hint
+        if not system_prompt and not task_prompt:
+            self.system_prompt = ONE_SHOT_SYSTEM_PROMPT
+            self.task_prompt = ONE_SHOT_TASK_PROMPT
+        else:
+            self.system_prompt = system_prompt or ""
+            self.task_prompt = task_prompt or ""
+
+        self.task_prompt = self.task_prompt.rstrip()
+        if self.task_prompt and self.task_prompt[0] != "\n":
+            self.task_prompt = "\n" + self.task_prompt
+
     @property
     def is_json_output(self) -> bool:
         return True

     def build_prompt(
-        self,
-        query: str,
-        context_chunks: list[InferenceChunk],
-        use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
-    ) -> list[BaseMessage]:
-        context_docs_str = build_context_str(
-            cast(list[LlmDoc | InferenceChunk], context_chunks)
-        )
-
-        single_message = JSON_PROMPT.format(
-            context_docs_str=context_docs_str,
-            user_query=query,
-            language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
-        ).strip()
-
-        prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
-        return prompt
+        self, query: str, history_str: str, context_chunks: list[InferenceChunk]
+    ) -> str:
+        context_block = ""
+        if context_chunks:
+            context_docs_str = build_context_str(
+                cast(list[LlmDoc | InferenceChunk], context_chunks)
+            )
+            context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
+
+        history_block = ""
+        if history_str:
+            history_block = HISTORY_BLOCK.format(history_str=history_str)
+
+        full_prompt = JSON_PROMPT.format(
+            system_prompt=self.system_prompt,
+            context_block=context_block,
+            history_block=history_block,
+            task_prompt=self.task_prompt,
+            user_query=query,
+            language_hint_or_none=LANGUAGE_HINT.strip()
+            if self.use_language_hint
+            else "",
+        ).strip()
+        return full_prompt


+# This one isn't used, currently only streaming prompts are used
 class SingleMessageScratchpadHandler(QAHandler):
+    def __init__(
+        self,
+        system_prompt: str | None,
+        task_prompt: str | None,
+        use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
+    ) -> None:
+        self.use_language_hint = use_language_hint
+        if not system_prompt and not task_prompt:
+            self.system_prompt = ONE_SHOT_SYSTEM_PROMPT
+            self.task_prompt = ONE_SHOT_TASK_PROMPT
+        else:
+            self.system_prompt = system_prompt or ""
+            self.task_prompt = task_prompt or ""
+
+        self.task_prompt = self.task_prompt.rstrip()
+        if self.task_prompt and self.task_prompt[0] != "\n":
+            self.task_prompt = "\n" + self.task_prompt
+
     @property
     def is_json_output(self) -> bool:
+        # Even though the full LLM output isn't a valid json
+        # only the valid json portion is kept and passed along
+        # therefore it is treated as a json output
         return True

     def build_prompt(
-        self,
-        query: str,
-        context_chunks: list[InferenceChunk],
-        use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
-    ) -> list[BaseMessage]:
+        self, query: str, history_str: str, context_chunks: list[InferenceChunk]
+    ) -> str:
         context_docs_str = build_context_str(
             cast(list[LlmDoc | InferenceChunk], context_chunks)
         )

-        single_message = COT_PROMPT.format(
+        # Outdated
+        prompt = COT_PROMPT.format(
             context_docs_str=context_docs_str,
             user_query=query,
-            language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
+            language_hint_or_none=LANGUAGE_HINT.strip()
+            if self.use_language_hint
+            else "",
         ).strip()

-        prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
         return prompt

     def process_llm_output(
@@ -170,133 +234,53 @@ class SingleMessageScratchpadHandler(QAHandler):
         )


-class PromptBasedQAHandler(QAHandler):
-    def __init__(self, system_prompt: str, task_prompt: str) -> None:
-        self.system_prompt = system_prompt
-        self.task_prompt = task_prompt
-
-    @property
-    def is_json_output(self) -> bool:
-        return False
-
-    def build_prompt(
-        self,
-        query: str,
-        context_chunks: list[InferenceChunk],
-    ) -> list[BaseMessage]:
-        context_docs_str = build_context_str(
-            cast(list[LlmDoc | InferenceChunk], context_chunks)
-        )
-
-        if not context_chunks:
-            single_message = PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
-                user_query=query,
-                system_prompt=self.system_prompt,
-                task_prompt=self.task_prompt,
-            ).strip()
-        else:
-            single_message = PARAMATERIZED_PROMPT.format(
-                context_docs_str=context_docs_str,
-                user_query=query,
-                system_prompt=self.system_prompt,
-                task_prompt=self.task_prompt,
-            ).strip()
-
-        prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
-        return prompt
-
-    def build_dummy_prompt(self, retrieval_disabled: bool) -> str:
-        if retrieval_disabled:
-            return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
-                user_query="<USER_QUERY>",
-                system_prompt=self.system_prompt,
-                task_prompt=self.task_prompt,
-            ).strip()
-
-        return PARAMATERIZED_PROMPT.format(
-            context_docs_str="<CONTEXT_DOCS>",
-            user_query="<USER_QUERY>",
-            system_prompt=self.system_prompt,
-            task_prompt=self.task_prompt,
-        ).strip()
-
-    def process_llm_output(
-        self, model_output: str, context_chunks: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, DanswerQuotes]:
-        return DanswerAnswer(answer=model_output), DanswerQuotes(quotes=[])
-
-    def process_llm_token_stream(
-        self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
-    ) -> AnswerQuestionStreamReturn:
-        for token in tokens:
-            yield DanswerAnswerPiece(answer_piece=token)
-
-        yield DanswerQuotes(quotes=[])
+def build_dummy_prompt(
+    system_prompt: str, task_prompt: str, retrieval_disabled: bool
+) -> str:
+    if retrieval_disabled:
+        return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
+            user_query="<USER_QUERY>",
+            system_prompt=system_prompt,
+            task_prompt=task_prompt,
+        ).strip()
+
+    return PARAMATERIZED_PROMPT.format(
+        context_docs_str="<CONTEXT_DOCS>",
+        user_query="<USER_QUERY>",
+        system_prompt=system_prompt,
+        task_prompt=task_prompt,
+    ).strip()


 class QABlock(QAModel):
     def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
         self._llm = llm
         self._qa_handler = qa_handler

-    @property
-    def requires_api_key(self) -> bool:
-        return self._llm.requires_api_key
-
-    def warm_up_model(self) -> None:
-        """This is called during server start up to load the models into memory
-        in case the chosen LLM is not accessed via API"""
-        if self._llm.requires_warm_up:
-            logger.info("Warming up LLM with a first inference")
-            self._llm.invoke("Ignore this!")
-
-    def answer_question(
+    def build_prompt(
         self,
         query: str,
-        context_docs: list[InferenceChunk],
-        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
-    ) -> AnswerQuestionReturn:
-        trimmed_context_docs = tokenizer_trim_chunks(context_docs)
-        prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
-        model_out = self._llm.invoke(prompt)
-
-        if metrics_callback is not None:
-            prompt_tokens = sum(
-                [
-                    check_number_of_tokens(
-                        text=str(p.content), encode_fn=get_default_llm_token_encode()
-                    )
-                    for p in prompt
-                ]
-            )
-
-            response_tokens = check_number_of_tokens(
-                text=model_out, encode_fn=get_default_llm_token_encode()
-            )
-
-            metrics_callback(
-                LLMMetricsContainer(
-                    prompt_tokens=prompt_tokens, response_tokens=response_tokens
-                )
-            )
-
-        return self._qa_handler.process_llm_output(model_out, trimmed_context_docs)
+        history_str: str,
+        context_chunks: list[InferenceChunk],
+    ) -> str:
+        prompt = self._qa_handler.build_prompt(
+            query=query, history_str=history_str, context_chunks=context_chunks
+        )
+        return prompt

     def answer_question_stream(
         self,
-        query: str,
-        context_docs: list[InferenceChunk],
+        prompt: str,
+        llm_context_docs: list[InferenceChunk],
         metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
     ) -> AnswerQuestionStreamReturn:
-        trimmed_context_docs = tokenizer_trim_chunks(context_docs)
-        prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
         tokens_stream = self._llm.stream(prompt)

         captured_tokens = []

         try:
             for answer_piece in self._qa_handler.process_llm_token_stream(
-                iter(tokens_stream), trimmed_context_docs
+                iter(tokens_stream), llm_context_docs
             ):
                 if (
                     isinstance(answer_piece, DanswerAnswerPiece)
@@ -309,13 +293,8 @@ class QABlock(QAModel):
             yield StreamingError(error=str(e))

         if metrics_callback is not None:
-            prompt_tokens = sum(
-                [
-                    check_number_of_tokens(
-                        text=str(p.content), encode_fn=get_default_llm_token_encode()
-                    )
-                    for p in prompt
-                ]
+            prompt_tokens = check_number_of_tokens(
+                text=str(prompt), encode_fn=get_default_llm_token_encode()
             )

             response_tokens = check_number_of_tokens(

View File

@@ -1,5 +1,6 @@
 import math
 import re
+from collections.abc import Callable
 from collections.abc import Generator
 from collections.abc import Iterator
 from json.decoder import JSONDecodeError
@@ -13,7 +14,11 @@ from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import DanswerQuote
 from danswer.chat.models import DanswerQuotes
 from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
+from danswer.configs.constants import MessageType
+from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
 from danswer.indexing.models import InferenceChunk
+from danswer.llm.utils import get_default_llm_token_encode
+from danswer.one_shot_answer.models import ThreadMessage
 from danswer.prompts.constants import ANSWER_PAT
 from danswer.prompts.constants import QUOTE_PAT
 from danswer.prompts.constants import UNCERTAINTY_PAT
@@ -270,3 +275,41 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
"""Mock streaming by generating the passed in model output, character by character""" """Mock streaming by generating the passed in model output, character by character"""
for token in model_out: for token in model_out:
yield token yield token
def combine_message_thread(
messages: list[ThreadMessage],
token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
llm_tokenizer: Callable | None = None,
) -> str:
"""Used to create a single combined message context from threads"""
message_strs: list[str] = []
total_token_count = 0
if llm_tokenizer is None:
llm_tokenizer = get_default_llm_token_encode()
for message in reversed(messages):
if message.role == MessageType.USER:
role_str = message.role.value.upper()
if message.sender:
role_str += " " + message.sender
else:
# Since other messages might have the user identifying information
# better to use Unknown for symmetry
role_str += " Unknown"
else:
role_str = message.role.value.upper()
msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer(msg_str))
if (
token_limit is not None
and total_token_count + message_token_count > token_limit
):
break
message_strs.insert(0, msg_str)
total_token_count += message_token_count
return "\n\n".join(message_strs)

View File

@@ -1,6 +1,7 @@
 GENERAL_SEP_PAT = "--------------"  # Same length as Langchain's separator
 CODE_BLOCK_PAT = "```\n{}\n```"
 QUESTION_PAT = "Query:"
+FINAL_QUERY_PAT = "Final Query:"
 THOUGHT_PAT = "Thought:"
 ANSWER_PAT = "Answer:"
 ANSWERABLE_PAT = "Answerable:"

View File

@@ -2,24 +2,36 @@
 # It is used also for the one shot direct QA flow
 import json

-from danswer.prompts.constants import ANSWER_PAT
+from danswer.prompts.constants import FINAL_QUERY_PAT
 from danswer.prompts.constants import GENERAL_SEP_PAT
 from danswer.prompts.constants import QUESTION_PAT
-from danswer.prompts.constants import QUOTE_PAT
 from danswer.prompts.constants import THOUGHT_PAT
 from danswer.prompts.constants import UNCERTAINTY_PAT


-QA_HEADER = """
+ONE_SHOT_SYSTEM_PROMPT = """
 You are a question answering system that is constantly learning and improving.
 You can process and comprehend vast amounts of text and utilize this knowledge to provide \
 accurate and detailed answers to diverse queries.
 """.strip()

+ONE_SHOT_TASK_PROMPT = """
+Answer the final query below taking into account the context above where relevant. \
+Ignore any provided context that is not relevant to the query.
+""".strip()
+
+WEAK_MODEL_SYSTEM_PROMPT = """
+Respond to the user query using the following reference document.
+""".lstrip()
+
+WEAK_MODEL_TASK_PROMPT = """
+Answer the user query below based on the reference document above.
+"""
+
 REQUIRE_JSON = """
-You ALWAYS responds with only a json containing an answer and quotes that support the answer.
+You ALWAYS responds with ONLY a JSON containing an answer and quotes that support the answer.
+Your responses are as INFORMATIVE and DETAILED as possible.
 """.strip()
@@ -34,6 +46,21 @@ IMPORTANT: Respond in the same language as my query!
""" """
CONTEXT_BLOCK = f"""
REFERENCE DOCUMENTS:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}
"""
HISTORY_BLOCK = f"""
CONVERSATION HISTORY:
{GENERAL_SEP_PAT}
{{history_str}}
{GENERAL_SEP_PAT}
"""
# This has to be doubly escaped due to json containing { } which are also used for format strings # This has to be doubly escaped due to json containing { } which are also used for format strings
EMPTY_SAMPLE_JSON = { EMPTY_SAMPLE_JSON = {
"answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.", "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
@@ -48,44 +75,22 @@ ANSWER_NOT_FOUND_RESPONSE = f'{{"answer": "{UNCERTAINTY_PAT}", "quotes": []}}'
 # Default json prompt which can reference multiple docs and provide answer + quotes
+# system_like_header is similar to system message, can be user provided or defaults to QA_HEADER
+# context/history blocks are for context documents and conversation history, they can be blank
+# task prompt is the task message of the prompt, can be blank, there is no default
 JSON_PROMPT = f"""
-{QA_HEADER}
+{{system_prompt}}
 {REQUIRE_JSON}
+{{context_block}}{{history_block}}{{task_prompt}}

-CONTEXT:
-{GENERAL_SEP_PAT}
-{{context_docs_str}}
-{GENERAL_SEP_PAT}
-
-SAMPLE_RESPONSE:
+SAMPLE RESPONSE:
 ```
 {{{json.dumps(EMPTY_SAMPLE_JSON)}}}
 ```

-{QUESTION_PAT.upper()} {{user_query}}
-{JSON_HELPFUL_HINT}
-{{language_hint_or_none}}
-""".strip()
-
-
-# Default chain-of-thought style json prompt which uses multiple docs
-# This one has a section for the LLM to output some non-answer "thoughts"
-# COT (chain-of-thought) flow basically
-COT_PROMPT = f"""
-{QA_HEADER}
-
-CONTEXT:
-{GENERAL_SEP_PAT}
-{{context_docs_str}}
-{GENERAL_SEP_PAT}
-
-You MUST respond in the following format:
-```
-{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.
-{{{json.dumps(EMPTY_SAMPLE_JSON)}}}
-```
+{FINAL_QUERY_PAT.upper()}
+{{user_query}}

-{QUESTION_PAT.upper()} {{user_query}}
 {JSON_HELPFUL_HINT}
 {{language_hint_or_none}}
 """.strip()
@@ -94,35 +99,17 @@ You MUST respond in the following format:
 # For weak LLM which only takes one chunk and cannot output json
 # Also not requiring quotes as it tends to not work
 WEAK_LLM_PROMPT = f"""
-Respond to the user query using the following reference document.
-
-Reference Document:
-{GENERAL_SEP_PAT}
-{{single_reference_doc}}
-{GENERAL_SEP_PAT}
-
-Answer the user query below based on the reference document above.
+{{system_prompt}}
+{{context_block}}
+{{task_prompt}}

 {QUESTION_PAT.upper()}
 {{user_query}}
 """.strip()


-# For weak CHAT LLM which takes one chunk and cannot output json
-# The next message should have the user query
-# Note, no flow/config currently uses this one
-WEAK_CHAT_LLM_PROMPT = f"""
-You are a question answering assistant
-Respond to the user query with an "{ANSWER_PAT}" section and \
-as many "{QUOTE_PAT}" sections as needed to support the answer.
-Answer the user query based on the following document:
-{{first_chunk_content}}
-""".strip()
-
-
-# Parameterized prompt which allows the user to specify their
-# own system / task prompt
+# This is only for visualization for the users to specify their own prompts
+# The actual flow does not work like this
 PARAMATERIZED_PROMPT = f"""
 {{system_prompt}}
@@ -147,6 +134,31 @@ RESPONSE:
""".strip() """.strip()
# CURRENTLY DISABLED, CANNOT USE THIS ONE
# Default chain-of-thought style json prompt which uses multiple docs
# This one has a section for the LLM to output some non-answer "thoughts"
# COT (chain-of-thought) flow basically
COT_PROMPT = f"""
{ONE_SHOT_SYSTEM_PROMPT}
CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}
You MUST respond in the following format:
```
{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.
{{{json.dumps(EMPTY_SAMPLE_JSON)}}}
```
{QUESTION_PAT.upper()} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()
# User the following for easy viewing of prompts # User the following for easy viewing of prompts
if __name__ == "__main__": if __name__ == "__main__":
print(JSON_PROMPT) # Default prompt used in the Danswer UI flow print(JSON_PROMPT) # Default prompt used in the Danswer UI flow
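
To make the new block-based assembly concrete, here is an illustrative sketch (not from the commit) of how CONTEXT_BLOCK, HISTORY_BLOCK, and JSON_PROMPT compose into a single prompt string; the document text, history, and query are made up.

    from danswer.prompts.direct_qa_prompts import (
        CONTEXT_BLOCK,
        HISTORY_BLOCK,
        JSON_PROMPT,
        ONE_SHOT_SYSTEM_PROMPT,
        ONE_SHOT_TASK_PROMPT,
    )

    # Hypothetical inputs purely for illustration
    context_block = CONTEXT_BLOCK.format(context_docs_str="Backups are stored in S3 for 30 days.")
    history_block = HISTORY_BLOCK.format(history_str="USER Alice:\nWhere do we store backups?")

    full_prompt = JSON_PROMPT.format(
        system_prompt=ONE_SHOT_SYSTEM_PROMPT,
        context_block=context_block,
        history_block=history_block,
        task_prompt="\n" + ONE_SHOT_TASK_PROMPT,
        user_query="How long are backups retained?",
        language_hint_or_none="",
    ).strip()
    print(full_prompt)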

View File

@@ -3,8 +3,8 @@ from sqlalchemy.orm import Session
 from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
 from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
 from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
+from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
 from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
-from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.db.models import Persona
 from danswer.db.models import User
 from danswer.search.access_filters import build_access_filters_for_user
@@ -31,7 +31,7 @@ def retrieval_preprocessing(
     bypass_acl: bool = False,
     include_query_intent: bool = True,
     skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
-    skip_rerank_non_realtime: bool = SKIP_RERANKING,
+    skip_rerank_non_realtime: bool = not ENABLE_RERANKING_ASYNC_FLOW,
     disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
     disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
     favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,

View File

@@ -59,14 +59,7 @@ def multilingual_query_expansion(
     return query_rephrases


-def history_based_query_rephrase(
-    query_message: ChatMessage,
-    history: list[ChatMessage],
-    llm: LLM | None = None,
-    size_heuristic: int = 200,
-    punctuation_heuristic: int = 10,
-) -> str:
-    def _get_history_rephrase_messages(
-        question: str,
-        history_str: str,
-    ) -> list[dict[str, str]]:
+def get_contextual_rephrase_messages(
+    question: str,
+    history_str: str,
+) -> list[dict[str, str]]:
@@ -81,6 +74,14 @@ def history_based_query_rephrase(
     return messages


+def history_based_query_rephrase(
+    query_message: ChatMessage,
+    history: list[ChatMessage],
+    llm: LLM | None = None,
+    size_heuristic: int = 200,
+    punctuation_heuristic: int = 10,
+) -> str:
     user_query = cast(str, query_message.message)

     if not user_query:
@@ -98,7 +99,38 @@ def history_based_query_rephrase(
     history_str = combine_message_chain(history)

-    prompt_msgs = _get_history_rephrase_messages(
+    prompt_msgs = get_contextual_rephrase_messages(
+        question=user_query, history_str=history_str
+    )
+
+    if llm is None:
+        llm = get_default_llm()
+
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
+    rephrased_query = llm.invoke(filled_llm_prompt)
+
+    logger.debug(f"Rephrased combined query: {rephrased_query}")
+
+    return rephrased_query
+
+
+def thread_based_query_rephrase(
+    user_query: str,
+    history_str: str,
+    llm: LLM | None = None,
+    size_heuristic: int = 200,
+    punctuation_heuristic: int = 10,
+) -> str:
+    if not history_str:
+        return user_query
+
+    if len(user_query) >= size_heuristic:
+        return user_query
+
+    if count_punctuation(user_query) >= punctuation_heuristic:
+        return user_query
+
+    prompt_msgs = get_contextual_rephrase_messages(
         question=user_query, history_str=history_str
     )
View File

@@ -15,7 +15,7 @@ from danswer.db.chat import upsert_persona
 from danswer.db.document_set import get_document_sets_by_ids
 from danswer.db.engine import get_session
 from danswer.db.models import User
-from danswer.one_shot_answer.qa_block import PromptBasedQAHandler
+from danswer.one_shot_answer.qa_block import build_dummy_prompt
 from danswer.server.features.persona.models import CreatePersonaRequest
 from danswer.server.features.persona.models import PersonaSnapshot
 from danswer.server.features.persona.models import PromptTemplateResponse
@@ -149,9 +149,11 @@ def build_final_template_prompt(
     _: User | None = Depends(current_user),
 ) -> PromptTemplateResponse:
     return PromptTemplateResponse(
-        final_prompt_template=PromptBasedQAHandler(
-            system_prompt=system_prompt, task_prompt=task_prompt
-        ).build_dummy_prompt(retrieval_disabled=retrieval_disabled)
+        final_prompt_template=build_dummy_prompt(
+            system_prompt=system_prompt,
+            task_prompt=task_prompt,
+            retrieval_disabled=retrieval_disabled,
+        )
     )

View File

@@ -163,9 +163,8 @@ def get_answer_with_quote(
     user: User = Depends(current_user),
     db_session: Session = Depends(get_session),
 ) -> StreamingResponse:
-    logger.info(
-        f"Received query for one shot answer with quotes: {query_request.query}"
-    )
+    query = query_request.messages[0].message
+    logger.info(f"Received query for one shot answer with quotes: {query}")
     packets = stream_one_shot_answer(
         query_req=query_request, user=user, db_session=db_session
     )

View File

@@ -9,9 +9,11 @@ import yaml
 from sqlalchemy.orm import Session

 from danswer.chat.models import LLMMetricsContainer
+from danswer.configs.constants import MessageType
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.one_shot_answer.answer_question import get_one_shot_answer
 from danswer.one_shot_answer.models import DirectQARequest
+from danswer.one_shot_answer.models import ThreadMessage
 from danswer.search.models import IndexFilters
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RerankMetricsContainer
@@ -84,8 +86,10 @@ def get_answer_for_question(
         access_control_list=None,
     )

+    messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
+
     new_message_request = DirectQARequest(
-        query=query,
+        messages=messages,
         prompt_id=0,
         persona_id=0,
         retrieval_options=RetrievalDetails(

View File

@@ -121,7 +121,7 @@ services:
       - SIM_SCORE_RANGE_HIGH=${SIM_SCORE_RANGE_HIGH:-}
       - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
       - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
-      - SKIP_RERANKING=${SKIP_RERANKING:-}
+      - ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-}
       - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
       - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
       - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}

View File

@@ -19,9 +19,6 @@ NORMALIZE_EMBEDDINGS="True"
 SIM_SCORE_RANGE_LOW="0.6"
 SIM_SCORE_RANGE_LOW="0.8"

-# No recent multilingual reranking models small enough to run on CPU, so turning it off
-SKIP_RERANKING="True"
-
 # Use LLM to determine if chunks are relevant to the query
 # may not work well for languages that do not have much training data in the LLM training set
 DISABLE_LLM_CHUNK_FILTER="True"

View File

@@ -32,10 +32,17 @@ export const searchRequestStreamed = async ({
   let relevantDocuments: DanswerDocument[] | null = null;
   try {
     const filters = buildFilters(sources, documentSets, timeRange);
+
+    const threadMessage = {
+      message: query,
+      sender: null,
+      role: "user"
+    };
+
     const response = await fetch("/api/query/stream-answer-with-quote", {
       method: "POST",
       body: JSON.stringify({
-        query,
+        messages: [threadMessage],
         persona_id: persona.id,
         prompt_id: persona.id === 0 ? null : persona.prompts[0]?.id,
         retrieval_options: {