Mirror of https://github.com/danswer-ai/danswer.git
DanswerBot Chat (#831)
@@ -34,9 +34,9 @@ MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)

# Cross Encoder Settings
# This following setting is for non-real-time-flows
SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
# This one is for real-time (streaming) flows
ENABLE_RERANKING_ASYNC_FLOW = (
    os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
)
ENABLE_RERANKING_REAL_TIME_FLOW = (
    os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
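For context, these are plain boolean environment flags; a minimal sketch of how one of them resolves at import time (the explicit override below is hypothetical, the variables are unset by default so both flags fall back to False):

import os

# Hypothetical override purely for illustration; when the variable is unset,
# the flag evaluates to False and the corresponding reranking flow stays off.
os.environ["ENABLE_RERANKING_REAL_TIME_FLOW"] = "true"
ENABLE_RERANKING_REAL_TIME_FLOW = (
    os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
print(ENABLE_RERANKING_REAL_TIME_FLOW)  # True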
@@ -214,7 +214,9 @@ def build_qa_response_blocks(
            text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
        )
    else:
        answer_block = SectionBlock(text=remove_slack_text_interactions(answer))
        answer_processed = remove_slack_text_interactions(answer)
        answer_processed = answer_processed.encode("utf-8").decode("unicode_escape")
        answer_block = SectionBlock(text=answer_processed)

    if quotes:
        quotes_blocks = build_quotes_block(quotes)
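The encode/decode round-trip added above turns literal escape sequences that a model sometimes emits (a two-character backslash-n, for example) into real characters before the text is handed to Slack. A minimal standalone sketch of the same transformation, independent of the Danswer helpers:

# The input string below contains a literal backslash followed by "n", not a newline.
raw_answer = "First line.\\nSecond line."
decoded = raw_answer.encode("utf-8").decode("unicode_escape")
print(decoded)
# First line.
# Second line.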
@@ -86,8 +86,14 @@ def handle_message(
    Query thrown out by filters due to config does not count as a failure that should be notified
    Danswer failing to answer/retrieve docs does count and should be notified
    """
    msg = message_info.msg_content
    channel = message_info.channel_to_respond

    logger = cast(
        logging.Logger,
        ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
    )

    messages = message_info.thread_messages
    message_ts_to_respond_to = message_info.msg_to_respond
    sender_id = message_info.sender
    bipass_filters = message_info.bipass_filters
@@ -95,11 +101,6 @@ def handle_message(

    engine = get_sqlalchemy_engine()

    logger = cast(
        logging.Logger,
        ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
    )

    document_set_names: list[str] | None = None
    persona = channel_config.persona if channel_config else None
    prompt = None
@@ -133,7 +134,7 @@ def handle_message(

        if (
            "questionmark_prefilter" in channel_conf["answer_filters"]
            and "?" not in msg
            and "?" not in messages[-1].message
        ):
            logger.info(
                "Skipping message since it does not contain a question mark"
@@ -223,7 +224,7 @@ def handle_message(
        # This includes throwing out answer via reflexion
        answer = _get_answer(
            DirectQARequest(
                query=msg,
                messages=messages,
                prompt_id=prompt.id if prompt else None,
                persona_id=persona.id if persona is not None else 0,
                retrieval_options=retrieval_details,
@@ -275,7 +276,9 @@ def handle_message(

    top_docs = retrieval_info.top_documents
    if not top_docs and not should_respond_even_with_no_docs:
        logger.error(f"Unable to answer question: '{msg}' - no documents found")
        logger.error(
            f"Unable to answer question: '{answer.rephrase}' - no documents found"
        )
        # Optionally, respond in thread with the error message
        # Used primarily for debugging purposes
        if should_respond_with_error_msgs:
@@ -296,7 +299,7 @@ def handle_message(
        return True

    # If called with the DanswerBot slash command, the question is lost so we have to reshow it
    restate_question_block = get_restate_blocks(msg, is_bot_msg)
    restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)

    answer_blocks = build_qa_response_blocks(
        message_id=answer.chat_message_id,
@@ -1,4 +1,3 @@
import re
import time
from threading import Event
from typing import Any
@@ -10,9 +9,10 @@ from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session

from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.handlers.handle_feedback import handle_slack_feedback
@@ -22,9 +22,13 @@ from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import decompose_block_id
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.search_nlp_models import warm_up_models
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
@@ -63,7 +67,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
            return False

    if event_type == "message":
        bot_tag_id = client.web_client.auth_test().get("user_id")
        bot_tag_id = get_danswer_bot_app_id(client.web_client)
        # DMs with the bot don't pick up the @DanswerBot so we have to keep the
        # caught events_api
        if bot_tag_id and bot_tag_id in msg and event.get("channel_type") != "im":
@@ -87,8 +91,14 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
        message_ts = event.get("ts")
        thread_ts = event.get("thread_ts")
        # Pick the root of the thread (if a thread exists)
        if thread_ts and message_ts != thread_ts:
            channel_specific_logger.info(
        # Can respond in thread if it's an "im" directly to Danswer or @DanswerBot is tagged
        if (
            thread_ts
            and message_ts != thread_ts
            and event_type != "app_mention"
            and event.get("channel_type") != "im"
        ):
            channel_specific_logger.debug(
                "Skipping message since it is not the root of a thread"
            )
            return False
@@ -156,18 +166,25 @@ def build_request_details(
        tagged = event.get("type") == "app_mention"
        message_ts = event.get("ts")
        thread_ts = event.get("thread_ts")
        bot_tag_id = client.web_client.auth_test().get("user_id")
        # Might exist even if not tagged, specifically in the case of @DanswerBot
        # in DanswerBot DM channel
        msg = re.sub(rf"<@{bot_tag_id}>\s", "", msg)

        msg = remove_danswer_bot_tag(msg, client=client.web_client)

        if tagged:
            logger.info("User tagged DanswerBot")

        if thread_ts != message_ts and thread_ts is not None:
            thread_messages = read_slack_thread(
                channel=channel, thread=thread_ts, client=client.web_client
            )
        else:
            thread_messages = [
                ThreadMessage(message=msg, sender=None, role=MessageType.USER)
            ]

        return SlackMessageInfo(
            msg_content=msg,
            thread_messages=thread_messages,
            channel_to_respond=channel,
            msg_to_respond=cast(str, thread_ts or message_ts),
            msg_to_respond=cast(str, message_ts or thread_ts),
            sender=event.get("user") or None,
            bipass_filters=tagged,
            is_bot_msg=False,
@@ -178,8 +195,10 @@ def build_request_details(
        msg = req.payload["text"]
        sender = req.payload["user_id"]

        single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)

        return SlackMessageInfo(
            msg_content=msg,
            thread_messages=[single_msg],
            channel_to_respond=channel,
            msg_to_respond=None,
            sender=sender,
@@ -297,7 +316,7 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None:
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
    warm_up_models(skip_cross_encoders=SKIP_RERANKING)
    warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW)

    slack_bot_tokens: SlackBotTokens | None = None
    socket_client: SocketModeClient | None = None
@@ -1,8 +1,10 @@
from pydantic import BaseModel

from danswer.one_shot_answer.models import ThreadMessage


class SlackMessageInfo(BaseModel):
    msg_content: str
    thread_messages: list[ThreadMessage]
    channel_to_respond: str
    msg_to_respond: str | None
    sender: str | None
@@ -13,17 +13,34 @@ from slack_sdk.models.blocks import Block
from slack_sdk.models.metadata import Metadata

from danswer.configs.constants import ID_SEPARATOR
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.one_shot_answer.models import ThreadMessage
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import replace_whitespaces_w_space

logger = setup_logger()


DANSWER_BOT_APP_ID: str | None = None


def get_danswer_bot_app_id(web_client: WebClient) -> Any:
    global DANSWER_BOT_APP_ID
    if DANSWER_BOT_APP_ID is None:
        DANSWER_BOT_APP_ID = web_client.auth_test().get("user_id")
    return DANSWER_BOT_APP_ID


def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
    bot_tag_id = get_danswer_bot_app_id(web_client=client)
    return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)


class ChannelIdAdapter(logging.LoggerAdapter):
    """This is used to add the channel ID to all log messages
    emitted in this file"""
@@ -199,3 +216,57 @@ def fetch_userids_from_emails(user_emails: list[str], client: WebClient) -> list
        )

    return user_ids


def fetch_user_semantic_id_from_id(user_id: str, client: WebClient) -> str | None:
    response = client.users_info(user=user_id)
    if not response["ok"]:
        return None

    user: dict = cast(dict[Any, dict], response.data).get("user", {})

    return (
        user.get("real_name")
        or user.get("name")
        or user.get("profile", {}).get("email")
    )


def read_slack_thread(
    channel: str, thread: str, client: WebClient
) -> list[ThreadMessage]:
    thread_messages: list[ThreadMessage] = []
    response = client.conversations_replies(channel=channel, ts=thread)
    replies = cast(dict, response.data).get("messages", [])
    for reply in replies:
        if "user" in reply and "bot_id" not in reply:
            message = remove_danswer_bot_tag(reply["text"], client=client)
            user_sem_id = fetch_user_semantic_id_from_id(reply["user"], client)
            message_type = MessageType.USER
        else:
            self_app_id = get_danswer_bot_app_id(client)

            # Only include bot messages from Danswer, other bots are not taken in as context
            if self_app_id != reply.get("user"):
                continue

            blocks = reply["blocks"]
            if len(blocks) <= 1:
                continue

            # The useful block is the second one after the header block that says AI Answer
            message = reply["blocks"][1]["text"]["text"]

            if message.startswith("_Filters"):
                if len(blocks) <= 2:
                    continue
                message = reply["blocks"][2]["text"]["text"]

            user_sem_id = "Assistant"
            message_type = MessageType.ASSISTANT

        thread_messages.append(
            ThreadMessage(message=message, sender=user_sem_id, role=message_type)
        )

    return thread_messages
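For reference, the tag-stripping helper above operates on Slack's mention markup. A standalone sketch of the same regex behavior (the bot user id is made up for the example):

import re

bot_tag_id = "U012ABCDEF"  # hypothetical id as returned by auth_test()["user_id"]
raw = "<@U012ABCDEF> how do I rotate my API key?"

# Same pattern as remove_danswer_bot_tag: drop the leading mention plus one whitespace character.
cleaned = re.sub(rf"<@{bot_tag_id}>\s", "", raw)
print(cleaned)  # how do I rotate my API key?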
@@ -47,7 +47,6 @@ from danswer.db.credentials import create_initial_public_credential
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.factory import get_default_document_index
from danswer.llm.factory import get_default_llm
from danswer.one_shot_answer.factory import get_default_qa_model
from danswer.search.search_nlp_models import warm_up_models
from danswer.server.danswer_api.ingestion import get_danswer_api_key
from danswer.server.danswer_api.ingestion import router as danswer_api_router
@@ -261,7 +260,6 @@ def get_application() -> FastAPI:

        # This is for the LLM, most LLMs will not need warming up
        get_default_llm().log_model_configs()
        get_default_qa_model().warm_up_model()

        logger.info("Verifying query preprocessing (NLTK) data is downloaded")
        nltk.download("stopwords", quiet=True)
@@ -28,6 +28,8 @@ from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.factory import get_question_answer_model
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SavedSearchDoc
@@ -35,6 +37,7 @@ from danswer.search.request_preprocessing import retrieval_preprocessing
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import full_chunk_search_generator
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
@@ -58,7 +61,8 @@ def stream_answer_objects(
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
    llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> Iterator[
    QADocsResponse
    QueryRephrase
    | QADocsResponse
    | LLMRelevanceFilterResponse
    | DanswerAnswerPiece
    | DanswerQuotes
@@ -73,6 +77,8 @@ def stream_answer_objects(
    4. [always] Details on the final AI response message that is created
    """
    user_id = user.id if user is not None else None
    query_msg = query_req.messages[-1]
    history = query_req.messages[:-1]

    chat_session = create_chat_session(
        db_session=db_session,
@@ -90,24 +96,20 @@ def stream_answer_objects(
        chat_session_id=chat_session.id, db_session=db_session
    )

    # Create the first User query message
    new_user_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=root_message,
        prompt_id=query_req.prompt_id,
        message=query_req.query,
        token_count=len(llm_tokenizer(query_req.query)),
        message_type=MessageType.USER,
        db_session=db_session,
        commit=True,
    history_str = combine_message_thread(history)

    rephrased_query = thread_based_query_rephrase(
        user_query=query_msg.message,
        history_str=history_str,
    )
    yield QueryRephrase(rephrased_query=rephrased_query)

    (
        retrieval_request,
        predicted_search_type,
        predicted_flow,
    ) = retrieval_preprocessing(
        query=query_req.query,
        query=rephrased_query,
        retrieval_details=query_req.retrieval_options,
        persona=chat_session.persona,
        user=user,
@@ -190,9 +192,25 @@ def stream_answer_objects(
        llm_version=llm_override,
    )

    full_prompt_str = qa_model.build_prompt(
        query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
    )

    # Create the first User query message
    new_user_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=root_message,
        prompt_id=query_req.prompt_id,
        message=full_prompt_str,
        token_count=len(llm_tokenizer(full_prompt_str)),
        message_type=MessageType.USER,
        db_session=db_session,
        commit=True,
    )

    response_packets = qa_model.answer_question_stream(
        query=query_req.query,
        context_docs=llm_chunks,
        prompt=full_prompt_str,
        llm_context_docs=llm_chunks,
        metrics_callback=llm_metrics_callback,
    )

@@ -272,6 +290,8 @@ def get_one_shot_answer(

    answer = ""
    for packet in results:
        if isinstance(packet, QueryRephrase):
            qa_response.rephrase = packet.rephrased_query
        if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            answer += packet.answer_piece
        elif isinstance(packet, QADocsResponse):
@@ -289,6 +309,11 @@ def get_one_shot_answer(
    qa_response.answer = answer

    if enable_reflexion:
        qa_response.answer_valid = get_answer_validity(query_req.query, answer)
        # Because follow up messages are explicitly tagged, we don't need to verify the answer
        if len(query_req.messages) == 1:
            first_query = query_req.messages[0].message
            qa_response.answer_valid = get_answer_validity(first_query, answer)
        else:
            qa_response.answer_valid = True

    return qa_response
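The one-shot flow now yields a typed packet stream (query rephrase, retrieved docs, answer pieces, quotes). A hedged sketch of what a consumer of such a stream looks like, using hand-built packets instead of a real stream_answer_objects() call:

def consume(packets) -> tuple[str, str | None]:
    # Accumulate streamed answer pieces and capture the rephrased query.
    answer, rephrase = "", None
    for packet in packets:
        if isinstance(packet, QueryRephrase):
            rephrase = packet.rephrased_query
        elif isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            answer += packet.answer_piece
    return answer, rephrase

demo = [
    QueryRephrase(rephrased_query="slack bot token setup"),
    DanswerAnswerPiece(answer_piece="Use the bot and app tokens from the Slack admin page."),
]
print(consume(demo))  # ('Use the bot and app tokens from the Slack admin page.', 'slack bot token setup')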
@@ -3,98 +3,42 @@ from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.db.models import Prompt
from danswer.llm.factory import get_default_llm
from danswer.one_shot_answer.interfaces import QAModel
from danswer.one_shot_answer.qa_block import PromptBasedQAHandler
from danswer.one_shot_answer.qa_block import QABlock
from danswer.one_shot_answer.qa_block import QAHandler
from danswer.one_shot_answer.qa_block import SingleMessageQAHandler
from danswer.one_shot_answer.qa_block import SingleMessageScratchpadHandler
from danswer.one_shot_answer.qa_block import WeakLLMQAHandler
from danswer.utils.logger import setup_logger

logger = setup_logger()


def get_default_qa_handler(
    chain_of_thought: bool = False,
    user_selection: str | None = QA_PROMPT_OVERRIDE,
) -> QAHandler:
    if user_selection:
        if user_selection.lower() == "default":
            return SingleMessageQAHandler()
        if user_selection.lower() == "cot":
            return SingleMessageScratchpadHandler()
        if user_selection.lower() == "weak":
            return WeakLLMQAHandler()

        raise ValueError("Invalid Question-Answering prompt selected")

    if chain_of_thought:
        return SingleMessageScratchpadHandler()

    return SingleMessageQAHandler()


def get_default_qa_model(
    api_key: str | None = None,
    timeout: int = QA_TIMEOUT,
    chain_of_thought: bool = False,
) -> QAModel:
    llm = get_default_llm(api_key=api_key, timeout=timeout)
    qa_handler = get_default_qa_handler(chain_of_thought=chain_of_thought)

    return QABlock(
        llm=llm,
        qa_handler=qa_handler,
    )


def get_prompt_qa_model(
    prompt: Prompt,
    api_key: str | None = None,
    timeout: int = QA_TIMEOUT,
    llm_version: str | None = None,
) -> QAModel:
    return QABlock(
        llm=get_default_llm(
            api_key=api_key,
            timeout=timeout,
            gen_ai_model_version_override=llm_version,
        ),
        qa_handler=PromptBasedQAHandler(
            system_prompt=prompt.system_prompt, task_prompt=prompt.task_prompt
        ),
    )


def get_question_answer_model(
    prompt: Prompt | None,
    api_key: str | None = None,
    timeout: int = QA_TIMEOUT,
    chain_of_thought: bool = False,
    llm_version: str | None = None,
    qa_model_version: str | None = QA_PROMPT_OVERRIDE,
) -> QAModel:
    if prompt is None and llm_version is not None:
        raise RuntimeError(
            "Cannot specify llm version for QA model without providing prompt. "
            "This flow is only intended for flows with a specified Persona/Prompt."
        )
    if chain_of_thought:
        raise NotImplementedError("COT has been disabled")

    if prompt is not None and chain_of_thought:
        raise RuntimeError(
            "Cannot choose COT prompt with a customized Prompt object. "
            "User can prompt the model to output COT themselves if they want."
        )
    system_prompt = prompt.system_prompt if prompt is not None else None
    task_prompt = prompt.task_prompt if prompt is not None else None

    if prompt is not None:
        return get_prompt_qa_model(
            prompt=prompt,
            api_key=api_key,
            timeout=timeout,
            llm_version=llm_version,
        )

    return get_default_qa_model(
    llm = get_default_llm(
        api_key=api_key,
        timeout=timeout,
        chain_of_thought=chain_of_thought,
        gen_ai_model_version_override=llm_version,
    )

    if qa_model_version == "weak":
        qa_handler: QAHandler = WeakLLMQAHandler(
            system_prompt=system_prompt, task_prompt=task_prompt
        )
    else:
        qa_handler = SingleMessageQAHandler(
            system_prompt=system_prompt, task_prompt=task_prompt
        )

    return QABlock(llm=llm, qa_handler=qa_handler)
@@ -1,37 +1,26 @@
import abc
from collections.abc import Callable

from danswer.chat.models import AnswerQuestionReturn
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import LLMMetricsContainer
from danswer.indexing.models import InferenceChunk


class QAModel:
    @property
    def requires_api_key(self) -> bool:
        """Is this model protected by security features
        Does it need an api key to access the model for inference"""
        return True

    def warm_up_model(self) -> None:
        """This is called during server start up to load the models into memory
        pass if model is accessed via API"""

    @abc.abstractmethod
    def answer_question(
    def build_prompt(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
    ) -> AnswerQuestionReturn:
        history_str: str,
        context_chunks: list[InferenceChunk],
    ) -> str:
        raise NotImplementedError

    @abc.abstractmethod
    def answer_question_stream(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        prompt: str,
        llm_context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
    ) -> AnswerQuestionStreamReturn:
        raise NotImplementedError
@@ -5,11 +5,22 @@ from pydantic import root_validator

from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.search.models import RetrievalDetails


class QueryRephrase(BaseModel):
    rephrased_query: str


class ThreadMessage(BaseModel):
    message: str
    sender: str | None
    role: MessageType


class DirectQARequest(BaseModel):
    query: str
    messages: list[ThreadMessage]
    prompt_id: int | None
    persona_id: int
    retrieval_options: RetrievalDetails
@@ -35,6 +46,7 @@ class DirectQARequest(BaseModel):
class OneShotQAResponse(BaseModel):
    # This is built piece by piece, any of these can be None as the flow could break
    answer: str | None = None
    rephrase: str | None = None
    quotes: DanswerQuotes | None = None
    docs: QADocsResponse | None = None
    llm_chunks_indices: list[int] | None = None
@@ -4,11 +4,7 @@ from collections.abc import Callable
from collections.abc import Iterator
from typing import cast

from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage

from danswer.chat.chat_utils import build_context_str
from danswer.chat.models import AnswerQuestionReturn
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
@@ -21,16 +17,21 @@ from danswer.indexing.models import InferenceChunk
from danswer.llm.interfaces import LLM
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_token_encode
from danswer.llm.utils import tokenizer_trim_chunks
from danswer.one_shot_answer.interfaces import QAModel
from danswer.one_shot_answer.qa_utils import process_answer
from danswer.one_shot_answer.qa_utils import process_model_tokens
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import COT_PROMPT
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.direct_qa_prompts import ONE_SHOT_SYSTEM_PROMPT
from danswer.prompts.direct_qa_prompts import ONE_SHOT_TASK_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_MODEL_SYSTEM_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_MODEL_TASK_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_up_code_blocks
from danswer.utils.text_processing import escape_newlines
@@ -39,12 +40,6 @@ logger = setup_logger()


class QAHandler(abc.ABC):
    @abc.abstractmethod
    def build_prompt(
        self, query: str, context_chunks: list[InferenceChunk]
    ) -> list[BaseMessage]:
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def is_json_output(self) -> bool:
@@ -54,12 +49,14 @@ class QAHandler(abc.ABC):
        finetuned to recognize this."""
        raise NotImplementedError

    def process_llm_output(
        self, model_output: str, context_chunks: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, DanswerQuotes]:
        return process_answer(
            model_output, context_chunks, is_json_prompt=self.is_json_output
        )
    @abc.abstractmethod
    def build_prompt(
        self,
        query: str,
        history_str: str,
        context_chunks: list[InferenceChunk],
    ) -> str:
        raise NotImplementedError

    def process_llm_token_stream(
        self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
@@ -78,70 +75,137 @@ class WeakLLMQAHandler(QAHandler):
    output format.
    """

    def __init__(
        self,
        system_prompt: str | None,
        task_prompt: str | None,
    ) -> None:
        if not system_prompt and not task_prompt:
            self.system_prompt = WEAK_MODEL_SYSTEM_PROMPT
            self.task_prompt = WEAK_MODEL_TASK_PROMPT
        else:
            self.system_prompt = system_prompt or ""
            self.task_prompt = task_prompt or ""

        self.task_prompt = self.task_prompt.rstrip()
        if self.task_prompt and self.task_prompt[0] != "\n":
            self.task_prompt = "\n" + self.task_prompt

    @property
    def is_json_output(self) -> bool:
        return False

    def build_prompt(
        self, query: str, context_chunks: list[InferenceChunk]
    ) -> list[BaseMessage]:
        message = WEAK_LLM_PROMPT.format(
            user_query=query, single_reference_doc=context_chunks[0].content
        )
        self,
        query: str,
        history_str: str,
        context_chunks: list[InferenceChunk],
    ) -> str:
        context_block = ""
        if context_chunks:
            context_block = CONTEXT_BLOCK.format(
                context_docs_str=context_chunks[0].content
            )

        return [HumanMessage(content=message)]
        prompt_str = WEAK_LLM_PROMPT.format(
            system_prompt=self.system_prompt,
            context_block=context_block,
            task_prompt=self.task_prompt,
            user_query=query,
        )
        return prompt_str


class SingleMessageQAHandler(QAHandler):
    def __init__(
        self,
        system_prompt: str | None,
        task_prompt: str | None,
        use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
    ) -> None:
        self.use_language_hint = use_language_hint
        if not system_prompt and not task_prompt:
            self.system_prompt = ONE_SHOT_SYSTEM_PROMPT
            self.task_prompt = ONE_SHOT_TASK_PROMPT
        else:
            self.system_prompt = system_prompt or ""
            self.task_prompt = task_prompt or ""

        self.task_prompt = self.task_prompt.rstrip()
        if self.task_prompt and self.task_prompt[0] != "\n":
            self.task_prompt = "\n" + self.task_prompt

    @property
    def is_json_output(self) -> bool:
        return True

    def build_prompt(
        self,
        query: str,
        context_chunks: list[InferenceChunk],
        use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
    ) -> list[BaseMessage]:
        context_docs_str = build_context_str(
            cast(list[LlmDoc | InferenceChunk], context_chunks)
        )
        self, query: str, history_str: str, context_chunks: list[InferenceChunk]
    ) -> str:
        context_block = ""
        if context_chunks:
            context_docs_str = build_context_str(
                cast(list[LlmDoc | InferenceChunk], context_chunks)
            )
            context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)

        single_message = JSON_PROMPT.format(
            context_docs_str=context_docs_str,
        history_block = ""
        if history_str:
            history_block = HISTORY_BLOCK.format(history_str=history_str)

        full_prompt = JSON_PROMPT.format(
            system_prompt=self.system_prompt,
            context_block=context_block,
            history_block=history_block,
            task_prompt=self.task_prompt,
            user_query=query,
            language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
            language_hint_or_none=LANGUAGE_HINT.strip()
            if self.use_language_hint
            else "",
        ).strip()

        prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
        return prompt
        return full_prompt


# This one isn't used, currently only streaming prompts are used
class SingleMessageScratchpadHandler(QAHandler):
    def __init__(
        self,
        system_prompt: str | None,
        task_prompt: str | None,
        use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
    ) -> None:
        self.use_language_hint = use_language_hint
        if not system_prompt and not task_prompt:
            self.system_prompt = ONE_SHOT_SYSTEM_PROMPT
            self.task_prompt = ONE_SHOT_TASK_PROMPT
        else:
            self.system_prompt = system_prompt or ""
            self.task_prompt = task_prompt or ""

        self.task_prompt = self.task_prompt.rstrip()
        if self.task_prompt and self.task_prompt[0] != "\n":
            self.task_prompt = "\n" + self.task_prompt

    @property
    def is_json_output(self) -> bool:
        # Even though the full LLM output isn't a valid json
        # only the valid json portion is kept and passed along
        # therefore it is treated as a json output
        return True

    def build_prompt(
        self,
        query: str,
        context_chunks: list[InferenceChunk],
        use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
    ) -> list[BaseMessage]:
        self, query: str, history_str: str, context_chunks: list[InferenceChunk]
    ) -> str:
        context_docs_str = build_context_str(
            cast(list[LlmDoc | InferenceChunk], context_chunks)
        )

        single_message = COT_PROMPT.format(
        # Outdated
        prompt = COT_PROMPT.format(
            context_docs_str=context_docs_str,
            user_query=query,
            language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
            language_hint_or_none=LANGUAGE_HINT.strip()
            if self.use_language_hint
            else "",
        ).strip()

        prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
        return prompt

    def process_llm_output(
@@ -170,68 +234,22 @@ class SingleMessageScratchpadHandler(QAHandler):
        )


class PromptBasedQAHandler(QAHandler):
    def __init__(self, system_prompt: str, task_prompt: str) -> None:
        self.system_prompt = system_prompt
        self.task_prompt = task_prompt

    @property
    def is_json_output(self) -> bool:
        return False

    def build_prompt(
        self,
        query: str,
        context_chunks: list[InferenceChunk],
    ) -> list[BaseMessage]:
        context_docs_str = build_context_str(
            cast(list[LlmDoc | InferenceChunk], context_chunks)
        )

        if not context_chunks:
            single_message = PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
                user_query=query,
                system_prompt=self.system_prompt,
                task_prompt=self.task_prompt,
            ).strip()
        else:
            single_message = PARAMATERIZED_PROMPT.format(
                context_docs_str=context_docs_str,
                user_query=query,
                system_prompt=self.system_prompt,
                task_prompt=self.task_prompt,
            ).strip()

        prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
        return prompt

    def build_dummy_prompt(self, retrieval_disabled: bool) -> str:
        if retrieval_disabled:
            return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
                user_query="<USER_QUERY>",
                system_prompt=self.system_prompt,
                task_prompt=self.task_prompt,
            ).strip()

        return PARAMATERIZED_PROMPT.format(
            context_docs_str="<CONTEXT_DOCS>",
def build_dummy_prompt(
    system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
    if retrieval_disabled:
        return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
            user_query="<USER_QUERY>",
            system_prompt=self.system_prompt,
            task_prompt=self.task_prompt,
            system_prompt=system_prompt,
            task_prompt=task_prompt,
        ).strip()

    def process_llm_output(
        self, model_output: str, context_chunks: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, DanswerQuotes]:
        return DanswerAnswer(answer=model_output), DanswerQuotes(quotes=[])

    def process_llm_token_stream(
        self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
    ) -> AnswerQuestionStreamReturn:
        for token in tokens:
            yield DanswerAnswerPiece(answer_piece=token)

        yield DanswerQuotes(quotes=[])
    return PARAMATERIZED_PROMPT.format(
        context_docs_str="<CONTEXT_DOCS>",
        user_query="<USER_QUERY>",
        system_prompt=system_prompt,
        task_prompt=task_prompt,
    ).strip()


class QABlock(QAModel):
@@ -239,64 +257,30 @@ class QABlock(QAModel):
        self._llm = llm
        self._qa_handler = qa_handler

    @property
    def requires_api_key(self) -> bool:
        return self._llm.requires_api_key

    def warm_up_model(self) -> None:
        """This is called during server start up to load the models into memory
        in case the chosen LLM is not accessed via API"""
        if self._llm.requires_warm_up:
            logger.info("Warming up LLM with a first inference")
            self._llm.invoke("Ignore this!")

    def answer_question(
    def build_prompt(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
    ) -> AnswerQuestionReturn:
        trimmed_context_docs = tokenizer_trim_chunks(context_docs)
        prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
        model_out = self._llm.invoke(prompt)

        if metrics_callback is not None:
            prompt_tokens = sum(
                [
                    check_number_of_tokens(
                        text=str(p.content), encode_fn=get_default_llm_token_encode()
                    )
                    for p in prompt
                ]
            )

            response_tokens = check_number_of_tokens(
                text=model_out, encode_fn=get_default_llm_token_encode()
            )

            metrics_callback(
                LLMMetricsContainer(
                    prompt_tokens=prompt_tokens, response_tokens=response_tokens
                )
            )

        return self._qa_handler.process_llm_output(model_out, trimmed_context_docs)
        history_str: str,
        context_chunks: list[InferenceChunk],
    ) -> str:
        prompt = self._qa_handler.build_prompt(
            query=query, history_str=history_str, context_chunks=context_chunks
        )
        return prompt

    def answer_question_stream(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        prompt: str,
        llm_context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
    ) -> AnswerQuestionStreamReturn:
        trimmed_context_docs = tokenizer_trim_chunks(context_docs)
        prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
        tokens_stream = self._llm.stream(prompt)

        captured_tokens = []

        try:
            for answer_piece in self._qa_handler.process_llm_token_stream(
                iter(tokens_stream), trimmed_context_docs
                iter(tokens_stream), llm_context_docs
            ):
                if (
                    isinstance(answer_piece, DanswerAnswerPiece)
@@ -309,13 +293,8 @@ class QABlock(QAModel):
            yield StreamingError(error=str(e))

        if metrics_callback is not None:
            prompt_tokens = sum(
                [
                    check_number_of_tokens(
                        text=str(p.content), encode_fn=get_default_llm_token_encode()
                    )
                    for p in prompt
                ]
            prompt_tokens = check_number_of_tokens(
                text=str(prompt), encode_fn=get_default_llm_token_encode()
            )

            response_tokens = check_number_of_tokens(
@@ -1,5 +1,6 @@
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from json.decoder import JSONDecodeError
@@ -13,7 +14,11 @@ from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import ThreadMessage
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.prompts.constants import UNCERTAINTY_PAT
@@ -270,3 +275,41 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
    """Mock streaming by generating the passed in model output, character by character"""
    for token in model_out:
        yield token


def combine_message_thread(
    messages: list[ThreadMessage],
    token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
    llm_tokenizer: Callable | None = None,
) -> str:
    """Used to create a single combined message context from threads"""
    message_strs: list[str] = []
    total_token_count = 0
    if llm_tokenizer is None:
        llm_tokenizer = get_default_llm_token_encode()

    for message in reversed(messages):
        if message.role == MessageType.USER:
            role_str = message.role.value.upper()
            if message.sender:
                role_str += " " + message.sender
            else:
                # Since other messages might have the user identifying information
                # better to use Unknown for symmetry
                role_str += " Unknown"
        else:
            role_str = message.role.value.upper()

        msg_str = f"{role_str}:\n{message.message}"
        message_token_count = len(llm_tokenizer(msg_str))

        if (
            token_limit is not None
            and total_token_count + message_token_count > token_limit
        ):
            break

        message_strs.insert(0, msg_str)
        total_token_count += message_token_count

    return "\n\n".join(message_strs)
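As a usage illustration, the helper above walks the thread newest-to-oldest and stops once the token budget would be exceeded, so it is the oldest messages that get dropped. A minimal sketch with a stand-in whitespace tokenizer (the real code falls back to the LLM's own token encoder):

toy_tokenizer = lambda text: text.split()  # stand-in; not the real tokenizer

thread = [
    ThreadMessage(message="How do I set up the Slack bot?", sender="Alice", role=MessageType.USER),
    ThreadMessage(message="You need the bot and app tokens.", sender=None, role=MessageType.ASSISTANT),
    ThreadMessage(message="Where do I put them?", sender="Alice", role=MessageType.USER),
]

history_str = combine_message_thread(thread, token_limit=50, llm_tokenizer=toy_tokenizer)
print(history_str)
# USER Alice:
# How do I set up the Slack bot?
#
# ASSISTANT:
# You need the bot and app tokens.
#
# USER Alice:
# Where do I put them?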
@@ -1,6 +1,7 @@
GENERAL_SEP_PAT = "--------------"  # Same length as Langchain's separator
CODE_BLOCK_PAT = "```\n{}\n```"
QUESTION_PAT = "Query:"
FINAL_QUERY_PAT = "Final Query:"
THOUGHT_PAT = "Thought:"
ANSWER_PAT = "Answer:"
ANSWERABLE_PAT = "Answerable:"
@@ -2,24 +2,36 @@
# It is used also for the one shot direct QA flow
import json

from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import FINAL_QUERY_PAT
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.prompts.constants import QUESTION_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.prompts.constants import THOUGHT_PAT
from danswer.prompts.constants import UNCERTAINTY_PAT


QA_HEADER = """
ONE_SHOT_SYSTEM_PROMPT = """
You are a question answering system that is constantly learning and improving.
You can process and comprehend vast amounts of text and utilize this knowledge to provide \
accurate and detailed answers to diverse queries.
""".strip()

ONE_SHOT_TASK_PROMPT = """
Answer the final query below taking into account the context above where relevant. \
Ignore any provided context that is not relevant to the query.
""".strip()


WEAK_MODEL_SYSTEM_PROMPT = """
Respond to the user query using the following reference document.
""".lstrip()

WEAK_MODEL_TASK_PROMPT = """
Answer the user query below based on the reference document above.
"""


REQUIRE_JSON = """
You ALWAYS responds with only a json containing an answer and quotes that support the answer.
Your responses are as INFORMATIVE and DETAILED as possible.
You ALWAYS responds with ONLY a JSON containing an answer and quotes that support the answer.
""".strip()


@@ -34,6 +46,21 @@ IMPORTANT: Respond in the same language as my query!
"""


CONTEXT_BLOCK = f"""
REFERENCE DOCUMENTS:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}
"""

HISTORY_BLOCK = f"""
CONVERSATION HISTORY:
{GENERAL_SEP_PAT}
{{history_str}}
{GENERAL_SEP_PAT}
"""


# This has to be doubly escaped due to json containing { } which are also used for format strings
EMPTY_SAMPLE_JSON = {
    "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
@@ -48,44 +75,22 @@ ANSWER_NOT_FOUND_RESPONSE = f'{{"answer": "{UNCERTAINTY_PAT}", "quotes": []}}'


# Default json prompt which can reference multiple docs and provide answer + quotes
# system_like_header is similar to system message, can be user provided or defaults to QA_HEADER
# context/history blocks are for context documents and conversation history, they can be blank
# task prompt is the task message of the prompt, can be blank, there is no default
JSON_PROMPT = f"""
{QA_HEADER}
{{system_prompt}}
{REQUIRE_JSON}
{{context_block}}{{history_block}}{{task_prompt}}

CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}

SAMPLE_RESPONSE:
SAMPLE RESPONSE:
```
{{{json.dumps(EMPTY_SAMPLE_JSON)}}}
```
{QUESTION_PAT.upper()} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()

{FINAL_QUERY_PAT.upper()}
{{user_query}}

# Default chain-of-thought style json prompt which uses multiple docs
# This one has a section for the LLM to output some non-answer "thoughts"
# COT (chain-of-thought) flow basically
COT_PROMPT = f"""
{QA_HEADER}

CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}

You MUST respond in the following format:
```
{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.

{{{json.dumps(EMPTY_SAMPLE_JSON)}}}
```

{QUESTION_PAT.upper()} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()
@@ -94,35 +99,17 @@ You MUST respond in the following format:

# For weak LLM which only takes one chunk and cannot output json
# Also not requiring quotes as it tends to not work
WEAK_LLM_PROMPT = f"""
Respond to the user query using the following reference document.

Reference Document:
{GENERAL_SEP_PAT}
{{single_reference_doc}}
{GENERAL_SEP_PAT}

Answer the user query below based on the reference document above.
{{system_prompt}}
{{context_block}}
{{task_prompt}}

{QUESTION_PAT.upper()}
{{user_query}}
""".strip()


# For weak CHAT LLM which takes one chunk and cannot output json
# The next message should have the user query
# Note, no flow/config currently uses this one
WEAK_CHAT_LLM_PROMPT = f"""
You are a question answering assistant
Respond to the user query with an "{ANSWER_PAT}" section and \
as many "{QUOTE_PAT}" sections as needed to support the answer.
Answer the user query based on the following document:

{{first_chunk_content}}
""".strip()


# Parameterized prompt which allows the user to specify their
# own system / task prompt
# This is only for visualization for the users to specify their own prompts
# The actual flow does not work like this
PARAMATERIZED_PROMPT = f"""
{{system_prompt}}

@@ -147,6 +134,31 @@ RESPONSE:
""".strip()


# CURRENTLY DISABLED, CANNOT USE THIS ONE
# Default chain-of-thought style json prompt which uses multiple docs
# This one has a section for the LLM to output some non-answer "thoughts"
# COT (chain-of-thought) flow basically
COT_PROMPT = f"""
{ONE_SHOT_SYSTEM_PROMPT}

CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}

You MUST respond in the following format:
```
{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.

{{{json.dumps(EMPTY_SAMPLE_JSON)}}}
```

{QUESTION_PAT.upper()} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()


# User the following for easy viewing of prompts
if __name__ == "__main__":
    print(JSON_PROMPT)  # Default prompt used in the Danswer UI flow
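To see how the new block-based template composes, here is a hedged sketch of the same pattern outside the Danswer modules: a format-string prompt whose context and history sections simply collapse to nothing when there is no material to inject (all names below are stand-ins, not the real constants):

SEP = "--------------"
CONTEXT_TMPL = f"\nREFERENCE DOCUMENTS:\n{SEP}\n{{context_docs_str}}\n{SEP}\n"
HISTORY_TMPL = f"\nCONVERSATION HISTORY:\n{SEP}\n{{history_str}}\n{SEP}\n"


def render_prompt(
    system_prompt: str,
    task_prompt: str,
    user_query: str,
    context_docs_str: str = "",
    history_str: str = "",
) -> str:
    # Optional blocks: when there are no docs or no history, the whole section
    # disappears instead of leaving an empty header behind.
    context_block = CONTEXT_TMPL.format(context_docs_str=context_docs_str) if context_docs_str else ""
    history_block = HISTORY_TMPL.format(history_str=history_str) if history_str else ""
    return f"{system_prompt}\n{context_block}{history_block}{task_prompt}\n\nFINAL QUERY:\n{user_query}".strip()


print(render_prompt("You answer questions.", "Answer the final query.", "What is Danswer?"))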
@@ -3,8 +3,8 @@ from sqlalchemy.orm import Session
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.search.access_filters import build_access_filters_for_user
@@ -31,7 +31,7 @@ def retrieval_preprocessing(
    bypass_acl: bool = False,
    include_query_intent: bool = True,
    skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
    skip_rerank_non_realtime: bool = SKIP_RERANKING,
    skip_rerank_non_realtime: bool = not ENABLE_RERANKING_ASYNC_FLOW,
    disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
    disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
    favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
@@ -59,6 +59,22 @@ def multilingual_query_expansion(
    return query_rephrases


def get_contextual_rephrase_messages(
    question: str,
    history_str: str,
) -> list[dict[str, str]]:
    messages = [
        {
            "role": "user",
            "content": HISTORY_QUERY_REPHRASE.format(
                question=question, chat_history=history_str
            ),
        },
    ]

    return messages


def history_based_query_rephrase(
    query_message: ChatMessage,
    history: list[ChatMessage],
@@ -66,21 +82,6 @@ def history_based_query_rephrase(
    size_heuristic: int = 200,
    punctuation_heuristic: int = 10,
) -> str:
    def _get_history_rephrase_messages(
        question: str,
        history_str: str,
    ) -> list[dict[str, str]]:
        messages = [
            {
                "role": "user",
                "content": HISTORY_QUERY_REPHRASE.format(
                    question=question, chat_history=history_str
                ),
            },
        ]

        return messages

    user_query = cast(str, query_message.message)

    if not user_query:
@@ -98,7 +99,38 @@ def history_based_query_rephrase(

    history_str = combine_message_chain(history)

    prompt_msgs = _get_history_rephrase_messages(
    prompt_msgs = get_contextual_rephrase_messages(
        question=user_query, history_str=history_str
    )

    if llm is None:
        llm = get_default_llm()

    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
    rephrased_query = llm.invoke(filled_llm_prompt)

    logger.debug(f"Rephrased combined query: {rephrased_query}")

    return rephrased_query


def thread_based_query_rephrase(
    user_query: str,
    history_str: str,
    llm: LLM | None = None,
    size_heuristic: int = 200,
    punctuation_heuristic: int = 10,
) -> str:
    if not history_str:
        return user_query

    if len(user_query) >= size_heuristic:
        return user_query

    if count_punctuation(user_query) >= punctuation_heuristic:
        return user_query

    prompt_msgs = get_contextual_rephrase_messages(
        question=user_query, history_str=history_str
    )
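The thread-based rephrase above is gated by cheap heuristics so that the extra LLM call is skipped when it is unlikely to help. A hedged sketch of the same gating logic in isolation (thresholds copied from the defaults above, count_punctuation re-implemented locally for the example):

import string


def count_punctuation(text: str) -> int:
    # Local stand-in for the helper used by the real module.
    return sum(1 for ch in text if ch in string.punctuation)


def should_rephrase(
    user_query: str,
    history_str: str,
    size_heuristic: int = 200,
    punctuation_heuristic: int = 10,
) -> bool:
    if not history_str:  # nothing to fold into the query
        return False
    if len(user_query) >= size_heuristic:  # long queries carry their own context
        return False
    if count_punctuation(user_query) >= punctuation_heuristic:  # code/log-like queries
        return False
    return True


print(should_rephrase("Where do I put them?", "USER Alice:\nHow do I set up the Slack bot?"))  # True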
@@ -15,7 +15,7 @@ from danswer.db.chat import upsert_persona
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.one_shot_answer.qa_block import PromptBasedQAHandler
from danswer.one_shot_answer.qa_block import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
@@ -149,9 +149,11 @@ def build_final_template_prompt(
    _: User | None = Depends(current_user),
) -> PromptTemplateResponse:
    return PromptTemplateResponse(
        final_prompt_template=PromptBasedQAHandler(
            system_prompt=system_prompt, task_prompt=task_prompt
        ).build_dummy_prompt(retrieval_disabled=retrieval_disabled)
        final_prompt_template=build_dummy_prompt(
            system_prompt=system_prompt,
            task_prompt=task_prompt,
            retrieval_disabled=retrieval_disabled,
        )
    )
@@ -163,9 +163,8 @@ def get_answer_with_quote(
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    logger.info(
        f"Received query for one shot answer with quotes: {query_request.query}"
    )
    query = query_request.messages[0].message
    logger.info(f"Received query for one shot answer with quotes: {query}")
    packets = stream_one_shot_answer(
        query_req=query_request, user=user, db_session=db_session
    )
@@ -9,9 +9,11 @@ import yaml
from sqlalchemy.orm import Session

from danswer.chat.models import LLMMetricsContainer
from danswer.configs.constants import MessageType
from danswer.db.engine import get_sqlalchemy_engine
from danswer.one_shot_answer.answer_question import get_one_shot_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.models import IndexFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RerankMetricsContainer
@@ -84,8 +86,10 @@ def get_answer_for_question(
        access_control_list=None,
    )

    messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]

    new_message_request = DirectQARequest(
        query=query,
        messages=messages,
        prompt_id=0,
        persona_id=0,
        retrieval_options=RetrievalDetails(