DanswerBot Chat (#831)

This commit is contained in:
Yuhong Sun
2023-12-17 18:18:48 -08:00
committed by GitHub
parent c7a91b1819
commit 5957b888a5
23 changed files with 526 additions and 385 deletions

View File

@@ -34,9 +34,9 @@ MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
# Cross Encoder Settings
# The following setting is for non-real-time flows
SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
# This one is for real-time (streaming) flows
ENABLE_RERANKING_ASYNC_FLOW = (
os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
)
ENABLE_RERANKING_REAL_TIME_FLOW = (
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
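
For quick reference, a minimal sketch of the boolean env-var pattern both rerank flags follow (the _env_flag helper below is illustrative only, not part of the codebase):

import os

def _env_flag(name: str) -> bool:
    # Only the literal string "true" (any casing) enables the flag; unset or anything else is False.
    return os.environ.get(name, "").lower() == "true"

ENABLE_RERANKING_ASYNC_FLOW = _env_flag("ENABLE_RERANKING_ASYNC_FLOW")
ENABLE_RERANKING_REAL_TIME_FLOW = _env_flag("ENABLE_RERANKING_REAL_TIME_FLOW")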

View File

@@ -214,7 +214,9 @@ def build_qa_response_blocks(
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
)
else:
answer_block = SectionBlock(text=remove_slack_text_interactions(answer))
answer_processed = remove_slack_text_interactions(answer)
answer_processed = answer_processed.encode("utf-8").decode("unicode_escape")
answer_block = SectionBlock(text=answer_processed)
if quotes:
quotes_blocks = build_quotes_block(quotes)
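
A small illustration of what the added unicode_escape round-trip does, assuming the answer can arrive with literal backslash-escaped sequences that Slack would otherwise render verbatim (the example string is invented):

# Hypothetical answer containing a literal "\n" two-character sequence rather than a real newline.
raw_answer = "Danswer supports Slack.\\nSee the setup docs for details."

processed = raw_answer.encode("utf-8").decode("unicode_escape")
print(processed)
# Danswer supports Slack.
# See the setup docs for details.

Note that unicode_escape treats non-ASCII bytes as Latin-1, so this is a best-effort cleanup of escaped whitespace rather than a general-purpose decoder.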

View File

@@ -86,8 +86,14 @@ def handle_message(
A query thrown out by the filters due to the channel config does not count as a failure that should be notified.
Danswer failing to answer or to retrieve docs does count and should be notified.
"""
msg = message_info.msg_content
channel = message_info.channel_to_respond
logger = cast(
logging.Logger,
ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
)
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
sender_id = message_info.sender
bipass_filters = message_info.bipass_filters
@@ -95,11 +101,6 @@ def handle_message(
engine = get_sqlalchemy_engine()
logger = cast(
logging.Logger,
ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
)
document_set_names: list[str] | None = None
persona = channel_config.persona if channel_config else None
prompt = None
@@ -133,7 +134,7 @@ def handle_message(
if (
"questionmark_prefilter" in channel_conf["answer_filters"]
and "?" not in msg
and "?" not in messages[-1].message
):
logger.info(
"Skipping message since it does not contain a question mark"
@@ -223,7 +224,7 @@ def handle_message(
# This includes throwing out answer via reflexion
answer = _get_answer(
DirectQARequest(
query=msg,
messages=messages,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
@@ -275,7 +276,9 @@ def handle_message(
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(f"Unable to answer question: '{msg}' - no documents found")
logger.error(
f"Unable to answer question: '{answer.rephrase}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
@@ -296,7 +299,7 @@ def handle_message(
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(msg, is_bot_msg)
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,

View File

@@ -1,4 +1,3 @@
import re
import time
from threading import Event
from typing import Any
@@ -10,9 +9,10 @@ from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.handlers.handle_feedback import handle_slack_feedback
@@ -22,9 +22,13 @@ from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import decompose_block_id
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.search_nlp_models import warm_up_models
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
@@ -63,7 +67,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
return False
if event_type == "message":
bot_tag_id = client.web_client.auth_test().get("user_id")
bot_tag_id = get_danswer_bot_app_id(client.web_client)
# DMs with the bot don't pick up the @DanswerBot tag, so we have to keep the
# events caught via the events_api
if bot_tag_id and bot_tag_id in msg and event.get("channel_type") != "im":
@@ -87,8 +91,14 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
message_ts = event.get("ts")
thread_ts = event.get("thread_ts")
# Pick the root of the thread (if a thread exists)
if thread_ts and message_ts != thread_ts:
channel_specific_logger.info(
# Can respond in thread if it's an "im" directly to Danswer or @DanswerBot is tagged
if (
thread_ts
and message_ts != thread_ts
and event_type != "app_mention"
and event.get("channel_type") != "im"
):
channel_specific_logger.debug(
"Skipping message since it is not the root of a thread"
)
return False
@@ -156,18 +166,25 @@ def build_request_details(
tagged = event.get("type") == "app_mention"
message_ts = event.get("ts")
thread_ts = event.get("thread_ts")
bot_tag_id = client.web_client.auth_test().get("user_id")
# Might exist even if not tagged, specifically in the case of @DanswerBot
# in DanswerBot DM channel
msg = re.sub(rf"<@{bot_tag_id}>\s", "", msg)
msg = remove_danswer_bot_tag(msg, client=client.web_client)
if tagged:
logger.info("User tagged DanswerBot")
if thread_ts != message_ts and thread_ts is not None:
thread_messages = read_slack_thread(
channel=channel, thread=thread_ts, client=client.web_client
)
else:
thread_messages = [
ThreadMessage(message=msg, sender=None, role=MessageType.USER)
]
return SlackMessageInfo(
msg_content=msg,
thread_messages=thread_messages,
channel_to_respond=channel,
msg_to_respond=cast(str, thread_ts or message_ts),
msg_to_respond=cast(str, message_ts or thread_ts),
sender=event.get("user") or None,
bipass_filters=tagged,
is_bot_msg=False,
@@ -178,8 +195,10 @@ def build_request_details(
msg = req.payload["text"]
sender = req.payload["user_id"]
single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
return SlackMessageInfo(
msg_content=msg,
thread_messages=[single_msg],
channel_to_respond=channel,
msg_to_respond=None,
sender=sender,
@@ -297,7 +316,7 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None:
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
warm_up_models(skip_cross_encoders=SKIP_RERANKING)
warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW)
slack_bot_tokens: SlackBotTokens | None = None
socket_client: SocketModeClient | None = None

View File

@@ -1,8 +1,10 @@
from pydantic import BaseModel
from danswer.one_shot_answer.models import ThreadMessage
class SlackMessageInfo(BaseModel):
msg_content: str
thread_messages: list[ThreadMessage]
channel_to_respond: str
msg_to_respond: str | None
sender: str | None
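
A hedged sketch of how the updated model might be populated for a tagged thread reply; all field values below are invented:

from danswer.configs.constants import MessageType
from danswer.one_shot_answer.models import ThreadMessage

info = SlackMessageInfo(
    thread_messages=[
        ThreadMessage(message="How do I add a connector?", sender="Alice", role=MessageType.USER),
        ThreadMessage(message="Connectors are configured in the admin panel.", sender="Assistant", role=MessageType.ASSISTANT),
        ThreadMessage(message="@DanswerBot can I also do it via the API?", sender="Alice", role=MessageType.USER),
    ],
    channel_to_respond="C0123456789",
    msg_to_respond="1702845000.000200",
    sender="U0123456789",
    bipass_filters=True,  # tagged messages bypass the channel filters
    is_bot_msg=False,
)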

View File

@@ -13,17 +13,34 @@ from slack_sdk.models.blocks import Block
from slack_sdk.models.metadata import Metadata
from danswer.configs.constants import ID_SEPARATOR
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.one_shot_answer.models import ThreadMessage
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import replace_whitespaces_w_space
logger = setup_logger()
DANSWER_BOT_APP_ID: str | None = None
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
global DANSWER_BOT_APP_ID
if DANSWER_BOT_APP_ID is None:
DANSWER_BOT_APP_ID = web_client.auth_test().get("user_id")
return DANSWER_BOT_APP_ID
def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
bot_tag_id = get_danswer_bot_app_id(web_client=client)
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
class ChannelIdAdapter(logging.LoggerAdapter):
"""This is used to add the channel ID to all log messages
emitted in this file"""
@@ -199,3 +216,57 @@ def fetch_userids_from_emails(user_emails: list[str], client: WebClient) -> list
)
return user_ids
def fetch_user_semantic_id_from_id(user_id: str, client: WebClient) -> str | None:
response = client.users_info(user=user_id)
if not response["ok"]:
return None
user: dict = cast(dict[Any, dict], response.data).get("user", {})
return (
user.get("real_name")
or user.get("name")
or user.get("profile", {}).get("email")
)
def read_slack_thread(
channel: str, thread: str, client: WebClient
) -> list[ThreadMessage]:
thread_messages: list[ThreadMessage] = []
response = client.conversations_replies(channel=channel, ts=thread)
replies = cast(dict, response.data).get("messages", [])
for reply in replies:
if "user" in reply and "bot_id" not in reply:
message = remove_danswer_bot_tag(reply["text"], client=client)
user_sem_id = fetch_user_semantic_id_from_id(reply["user"], client)
message_type = MessageType.USER
else:
self_app_id = get_danswer_bot_app_id(client)
# Only include bot messages from Danswer; other bots are not used as context
if self_app_id != reply.get("user"):
continue
blocks = reply["blocks"]
if len(blocks) <= 1:
continue
# The useful block is the second one, following the header block that says "AI Answer"
message = reply["blocks"][1]["text"]["text"]
if message.startswith("_Filters"):
if len(blocks) <= 2:
continue
message = reply["blocks"][2]["text"]["text"]
user_sem_id = "Assistant"
message_type = MessageType.ASSISTANT
thread_messages.append(
ThreadMessage(message=message, sender=user_sem_id, role=message_type)
)
return thread_messages
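
A hedged usage sketch, assuming a Slack WebClient with the bot token is at hand; the channel and thread timestamps are placeholders:

from slack_sdk import WebClient

client = WebClient(token="xoxb-placeholder-token")
thread_messages = read_slack_thread(
    channel="C0123456789",
    thread="1702845000.000100",  # ts of the thread root
    client=client,
)
for msg in thread_messages:
    print(f"{msg.role.value.upper()} ({msg.sender}): {msg.message[:60]}")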

View File

@@ -47,7 +47,6 @@ from danswer.db.credentials import create_initial_public_credential
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.factory import get_default_document_index
from danswer.llm.factory import get_default_llm
from danswer.one_shot_answer.factory import get_default_qa_model
from danswer.search.search_nlp_models import warm_up_models
from danswer.server.danswer_api.ingestion import get_danswer_api_key
from danswer.server.danswer_api.ingestion import router as danswer_api_router
@@ -261,7 +260,6 @@ def get_application() -> FastAPI:
# This is for the LLM, most LLMs will not need warming up
get_default_llm().log_model_configs()
get_default_qa_model().warm_up_model()
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)

View File

@@ -28,6 +28,8 @@ from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.factory import get_question_answer_model
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SavedSearchDoc
@@ -35,6 +37,7 @@ from danswer.search.request_preprocessing import retrieval_preprocessing
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import full_chunk_search_generator
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
@@ -58,7 +61,8 @@ def stream_answer_objects(
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> Iterator[
QADocsResponse
QueryRephrase
| QADocsResponse
| LLMRelevanceFilterResponse
| DanswerAnswerPiece
| DanswerQuotes
@@ -73,6 +77,8 @@ def stream_answer_objects(
4. [always] Details on the final AI response message that is created
"""
user_id = user.id if user is not None else None
query_msg = query_req.messages[-1]
history = query_req.messages[:-1]
chat_session = create_chat_session(
db_session=db_session,
@@ -90,24 +96,20 @@ def stream_answer_objects(
chat_session_id=chat_session.id, db_session=db_session
)
# Create the first User query message
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=query_req.prompt_id,
message=query_req.query,
token_count=len(llm_tokenizer(query_req.query)),
message_type=MessageType.USER,
db_session=db_session,
commit=True,
history_str = combine_message_thread(history)
rephrased_query = thread_based_query_rephrase(
user_query=query_msg.message,
history_str=history_str,
)
yield QueryRephrase(rephrased_query=rephrased_query)
(
retrieval_request,
predicted_search_type,
predicted_flow,
) = retrieval_preprocessing(
query=query_req.query,
query=rephrased_query,
retrieval_details=query_req.retrieval_options,
persona=chat_session.persona,
user=user,
@@ -190,9 +192,25 @@ def stream_answer_objects(
llm_version=llm_override,
)
full_prompt_str = qa_model.build_prompt(
query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
)
# Create the first User query message
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=query_req.prompt_id,
message=full_prompt_str,
token_count=len(llm_tokenizer(full_prompt_str)),
message_type=MessageType.USER,
db_session=db_session,
commit=True,
)
response_packets = qa_model.answer_question_stream(
query=query_req.query,
context_docs=llm_chunks,
prompt=full_prompt_str,
llm_context_docs=llm_chunks,
metrics_callback=llm_metrics_callback,
)
@@ -272,6 +290,8 @@ def get_one_shot_answer(
answer = ""
for packet in results:
if isinstance(packet, QueryRephrase):
qa_response.rephrase = packet.rephrased_query
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
@@ -289,6 +309,11 @@ def get_one_shot_answer(
qa_response.answer = answer
if enable_reflexion:
qa_response.answer_valid = get_answer_validity(query_req.query, answer)
# Because follow-up messages are explicitly tagged, we don't need to verify the answer
if len(query_req.messages) == 1:
first_query = query_req.messages[0].message
qa_response.answer_valid = get_answer_validity(first_query, answer)
else:
qa_response.answer_valid = True
return qa_response
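
A hedged sketch of how a caller might fold the new packet stream into a final answer, mirroring the accumulation in get_one_shot_answer above; obtaining `packets` from stream_answer_objects (with a real request, user, and db session) is assumed and omitted here:

from danswer.chat.models import DanswerAnswerPiece
from danswer.one_shot_answer.models import QueryRephrase

rephrase: str | None = None
answer_pieces: list[str] = []

for packet in packets:  # packets assumed to come from stream_answer_objects(...)
    if isinstance(packet, QueryRephrase):
        rephrase = packet.rephrased_query
    elif isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
        answer_pieces.append(packet.answer_piece)

print("Rephrased query:", rephrase)
print("Answer:", "".join(answer_pieces))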

View File

@@ -3,98 +3,42 @@ from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.db.models import Prompt
from danswer.llm.factory import get_default_llm
from danswer.one_shot_answer.interfaces import QAModel
from danswer.one_shot_answer.qa_block import PromptBasedQAHandler
from danswer.one_shot_answer.qa_block import QABlock
from danswer.one_shot_answer.qa_block import QAHandler
from danswer.one_shot_answer.qa_block import SingleMessageQAHandler
from danswer.one_shot_answer.qa_block import SingleMessageScratchpadHandler
from danswer.one_shot_answer.qa_block import WeakLLMQAHandler
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_default_qa_handler(
chain_of_thought: bool = False,
user_selection: str | None = QA_PROMPT_OVERRIDE,
) -> QAHandler:
if user_selection:
if user_selection.lower() == "default":
return SingleMessageQAHandler()
if user_selection.lower() == "cot":
return SingleMessageScratchpadHandler()
if user_selection.lower() == "weak":
return WeakLLMQAHandler()
raise ValueError("Invalid Question-Answering prompt selected")
if chain_of_thought:
return SingleMessageScratchpadHandler()
return SingleMessageQAHandler()
def get_default_qa_model(
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
chain_of_thought: bool = False,
) -> QAModel:
llm = get_default_llm(api_key=api_key, timeout=timeout)
qa_handler = get_default_qa_handler(chain_of_thought=chain_of_thought)
return QABlock(
llm=llm,
qa_handler=qa_handler,
)
def get_prompt_qa_model(
prompt: Prompt,
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
llm_version: str | None = None,
) -> QAModel:
return QABlock(
llm=get_default_llm(
api_key=api_key,
timeout=timeout,
gen_ai_model_version_override=llm_version,
),
qa_handler=PromptBasedQAHandler(
system_prompt=prompt.system_prompt, task_prompt=prompt.task_prompt
),
)
def get_question_answer_model(
prompt: Prompt | None,
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
chain_of_thought: bool = False,
llm_version: str | None = None,
qa_model_version: str | None = QA_PROMPT_OVERRIDE,
) -> QAModel:
if prompt is None and llm_version is not None:
raise RuntimeError(
"Cannot specify llm version for QA model without providing prompt. "
"This flow is only intended for flows with a specified Persona/Prompt."
)
if chain_of_thought:
raise NotImplementedError("COT has been disabled")
if prompt is not None and chain_of_thought:
raise RuntimeError(
"Cannot choose COT prompt with a customized Prompt object. "
"User can prompt the model to output COT themselves if they want."
)
system_prompt = prompt.system_prompt if prompt is not None else None
task_prompt = prompt.task_prompt if prompt is not None else None
if prompt is not None:
return get_prompt_qa_model(
prompt=prompt,
api_key=api_key,
timeout=timeout,
llm_version=llm_version,
)
return get_default_qa_model(
llm = get_default_llm(
api_key=api_key,
timeout=timeout,
chain_of_thought=chain_of_thought,
gen_ai_model_version_override=llm_version,
)
if qa_model_version == "weak":
qa_handler: QAHandler = WeakLLMQAHandler(
system_prompt=system_prompt, task_prompt=task_prompt
)
else:
qa_handler = SingleMessageQAHandler(
system_prompt=system_prompt, task_prompt=task_prompt
)
return QABlock(llm=llm, qa_handler=qa_handler)

View File

@@ -1,37 +1,26 @@
import abc
from collections.abc import Callable
from danswer.chat.models import AnswerQuestionReturn
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import LLMMetricsContainer
from danswer.indexing.models import InferenceChunk
class QAModel:
@property
def requires_api_key(self) -> bool:
"""Is this model protected by security features
Does it need an api key to access the model for inference"""
return True
def warm_up_model(self) -> None:
"""This is called during server start up to load the models into memory
pass if model is accessed via API"""
@abc.abstractmethod
def answer_question(
def build_prompt(
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionReturn:
history_str: str,
context_chunks: list[InferenceChunk],
) -> str:
raise NotImplementedError
@abc.abstractmethod
def answer_question_stream(
self,
query: str,
context_docs: list[InferenceChunk],
prompt: str,
llm_context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionStreamReturn:
raise NotImplementedError

View File

@@ -5,11 +5,22 @@ from pydantic import root_validator
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.search.models import RetrievalDetails
class QueryRephrase(BaseModel):
rephrased_query: str
class ThreadMessage(BaseModel):
message: str
sender: str | None
role: MessageType
class DirectQARequest(BaseModel):
query: str
messages: list[ThreadMessage]
prompt_id: int | None
persona_id: int
retrieval_options: RetrievalDetails
@@ -35,6 +46,7 @@ class DirectQARequest(BaseModel):
class OneShotQAResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
rephrase: str | None = None
quotes: DanswerQuotes | None = None
docs: QADocsResponse | None = None
llm_chunks_indices: list[int] | None = None
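
With query: str replaced by messages: list[ThreadMessage], a request now carries the whole thread rather than a single question. A hedged construction example (IDs are placeholders, and the retrieval options assume the model's defaults are acceptable):

from danswer.configs.constants import MessageType
from danswer.search.models import RetrievalDetails

request = DirectQARequest(
    messages=[
        ThreadMessage(message="What connectors does Danswer support?", sender=None, role=MessageType.USER),
        ThreadMessage(message="It supports Slack, Confluence, GitHub, and more.", sender="Assistant", role=MessageType.ASSISTANT),
        ThreadMessage(message="How do I enable the Slack one?", sender=None, role=MessageType.USER),
    ],
    prompt_id=None,
    persona_id=0,
    retrieval_options=RetrievalDetails(),  # assuming defaults suffice for this sketch
)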

View File

@@ -4,11 +4,7 @@ from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from danswer.chat.chat_utils import build_context_str
from danswer.chat.models import AnswerQuestionReturn
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
@@ -21,16 +17,21 @@ from danswer.indexing.models import InferenceChunk
from danswer.llm.interfaces import LLM
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_token_encode
from danswer.llm.utils import tokenizer_trim_chunks
from danswer.one_shot_answer.interfaces import QAModel
from danswer.one_shot_answer.qa_utils import process_answer
from danswer.one_shot_answer.qa_utils import process_model_tokens
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import COT_PROMPT
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.direct_qa_prompts import ONE_SHOT_SYSTEM_PROMPT
from danswer.prompts.direct_qa_prompts import ONE_SHOT_TASK_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_MODEL_SYSTEM_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_MODEL_TASK_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_up_code_blocks
from danswer.utils.text_processing import escape_newlines
@@ -39,12 +40,6 @@ logger = setup_logger()
class QAHandler(abc.ABC):
@abc.abstractmethod
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
raise NotImplementedError
@property
@abc.abstractmethod
def is_json_output(self) -> bool:
@@ -54,12 +49,14 @@ class QAHandler(abc.ABC):
finetuned to recognize this."""
raise NotImplementedError
def process_llm_output(
self, model_output: str, context_chunks: list[InferenceChunk]
) -> tuple[DanswerAnswer, DanswerQuotes]:
return process_answer(
model_output, context_chunks, is_json_prompt=self.is_json_output
)
@abc.abstractmethod
def build_prompt(
self,
query: str,
history_str: str,
context_chunks: list[InferenceChunk],
) -> str:
raise NotImplementedError
def process_llm_token_stream(
self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
@@ -78,70 +75,137 @@ class WeakLLMQAHandler(QAHandler):
output format.
"""
def __init__(
self,
system_prompt: str | None,
task_prompt: str | None,
) -> None:
if not system_prompt and not task_prompt:
self.system_prompt = WEAK_MODEL_SYSTEM_PROMPT
self.task_prompt = WEAK_MODEL_TASK_PROMPT
else:
self.system_prompt = system_prompt or ""
self.task_prompt = task_prompt or ""
self.task_prompt = self.task_prompt.rstrip()
if self.task_prompt and self.task_prompt[0] != "\n":
self.task_prompt = "\n" + self.task_prompt
@property
def is_json_output(self) -> bool:
return False
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
message = WEAK_LLM_PROMPT.format(
user_query=query, single_reference_doc=context_chunks[0].content
)
self,
query: str,
history_str: str,
context_chunks: list[InferenceChunk],
) -> str:
context_block = ""
if context_chunks:
context_block = CONTEXT_BLOCK.format(
context_docs_str=context_chunks[0].content
)
return [HumanMessage(content=message)]
prompt_str = WEAK_LLM_PROMPT.format(
system_prompt=self.system_prompt,
context_block=context_block,
task_prompt=self.task_prompt,
user_query=query,
)
return prompt_str
class SingleMessageQAHandler(QAHandler):
def __init__(
self,
system_prompt: str | None,
task_prompt: str | None,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> None:
self.use_language_hint = use_language_hint
if not system_prompt and not task_prompt:
self.system_prompt = ONE_SHOT_SYSTEM_PROMPT
self.task_prompt = ONE_SHOT_TASK_PROMPT
else:
self.system_prompt = system_prompt or ""
self.task_prompt = task_prompt or ""
self.task_prompt = self.task_prompt.rstrip()
if self.task_prompt and self.task_prompt[0] != "\n":
self.task_prompt = "\n" + self.task_prompt
@property
def is_json_output(self) -> bool:
return True
def build_prompt(
self,
query: str,
context_chunks: list[InferenceChunk],
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> list[BaseMessage]:
context_docs_str = build_context_str(
cast(list[LlmDoc | InferenceChunk], context_chunks)
)
self, query: str, history_str: str, context_chunks: list[InferenceChunk]
) -> str:
context_block = ""
if context_chunks:
context_docs_str = build_context_str(
cast(list[LlmDoc | InferenceChunk], context_chunks)
)
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
single_message = JSON_PROMPT.format(
context_docs_str=context_docs_str,
history_block = ""
if history_str:
history_block = HISTORY_BLOCK.format(history_str=history_str)
full_prompt = JSON_PROMPT.format(
system_prompt=self.system_prompt,
context_block=context_block,
history_block=history_block,
task_prompt=self.task_prompt,
user_query=query,
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
language_hint_or_none=LANGUAGE_HINT.strip()
if self.use_language_hint
else "",
).strip()
prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
return prompt
return full_prompt
# This one isn't used, currently only streaming prompts are used
class SingleMessageScratchpadHandler(QAHandler):
def __init__(
self,
system_prompt: str | None,
task_prompt: str | None,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> None:
self.use_language_hint = use_language_hint
if not system_prompt and not task_prompt:
self.system_prompt = ONE_SHOT_SYSTEM_PROMPT
self.task_prompt = ONE_SHOT_TASK_PROMPT
else:
self.system_prompt = system_prompt or ""
self.task_prompt = task_prompt or ""
self.task_prompt = self.task_prompt.rstrip()
if self.task_prompt and self.task_prompt[0] != "\n":
self.task_prompt = "\n" + self.task_prompt
@property
def is_json_output(self) -> bool:
# Even though the full LLM output isn't valid JSON,
# only the valid JSON portion is kept and passed along,
# so it is treated as JSON output
return True
def build_prompt(
self,
query: str,
context_chunks: list[InferenceChunk],
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> list[BaseMessage]:
self, query: str, history_str: str, context_chunks: list[InferenceChunk]
) -> str:
context_docs_str = build_context_str(
cast(list[LlmDoc | InferenceChunk], context_chunks)
)
single_message = COT_PROMPT.format(
# Outdated
prompt = COT_PROMPT.format(
context_docs_str=context_docs_str,
user_query=query,
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
language_hint_or_none=LANGUAGE_HINT.strip()
if self.use_language_hint
else "",
).strip()
prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
return prompt
def process_llm_output(
@@ -170,68 +234,22 @@ class SingleMessageScratchpadHandler(QAHandler):
)
class PromptBasedQAHandler(QAHandler):
def __init__(self, system_prompt: str, task_prompt: str) -> None:
self.system_prompt = system_prompt
self.task_prompt = task_prompt
@property
def is_json_output(self) -> bool:
return False
def build_prompt(
self,
query: str,
context_chunks: list[InferenceChunk],
) -> list[BaseMessage]:
context_docs_str = build_context_str(
cast(list[LlmDoc | InferenceChunk], context_chunks)
)
if not context_chunks:
single_message = PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query=query,
system_prompt=self.system_prompt,
task_prompt=self.task_prompt,
).strip()
else:
single_message = PARAMATERIZED_PROMPT.format(
context_docs_str=context_docs_str,
user_query=query,
system_prompt=self.system_prompt,
task_prompt=self.task_prompt,
).strip()
prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
return prompt
def build_dummy_prompt(self, retrieval_disabled: bool) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=self.system_prompt,
task_prompt=self.task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=self.system_prompt,
task_prompt=self.task_prompt,
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
def process_llm_output(
self, model_output: str, context_chunks: list[InferenceChunk]
) -> tuple[DanswerAnswer, DanswerQuotes]:
return DanswerAnswer(answer=model_output), DanswerQuotes(quotes=[])
def process_llm_token_stream(
self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
for token in tokens:
yield DanswerAnswerPiece(answer_piece=token)
yield DanswerQuotes(quotes=[])
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
class QABlock(QAModel):
@@ -239,64 +257,30 @@ class QABlock(QAModel):
self._llm = llm
self._qa_handler = qa_handler
@property
def requires_api_key(self) -> bool:
return self._llm.requires_api_key
def warm_up_model(self) -> None:
"""This is called during server start up to load the models into memory
in case the chosen LLM is not accessed via API"""
if self._llm.requires_warm_up:
logger.info("Warming up LLM with a first inference")
self._llm.invoke("Ignore this!")
def answer_question(
def build_prompt(
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionReturn:
trimmed_context_docs = tokenizer_trim_chunks(context_docs)
prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
model_out = self._llm.invoke(prompt)
if metrics_callback is not None:
prompt_tokens = sum(
[
check_number_of_tokens(
text=str(p.content), encode_fn=get_default_llm_token_encode()
)
for p in prompt
]
)
response_tokens = check_number_of_tokens(
text=model_out, encode_fn=get_default_llm_token_encode()
)
metrics_callback(
LLMMetricsContainer(
prompt_tokens=prompt_tokens, response_tokens=response_tokens
)
)
return self._qa_handler.process_llm_output(model_out, trimmed_context_docs)
history_str: str,
context_chunks: list[InferenceChunk],
) -> str:
prompt = self._qa_handler.build_prompt(
query=query, history_str=history_str, context_chunks=context_chunks
)
return prompt
def answer_question_stream(
self,
query: str,
context_docs: list[InferenceChunk],
prompt: str,
llm_context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionStreamReturn:
trimmed_context_docs = tokenizer_trim_chunks(context_docs)
prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
tokens_stream = self._llm.stream(prompt)
captured_tokens = []
try:
for answer_piece in self._qa_handler.process_llm_token_stream(
iter(tokens_stream), trimmed_context_docs
iter(tokens_stream), llm_context_docs
):
if (
isinstance(answer_piece, DanswerAnswerPiece)
@@ -309,13 +293,8 @@ class QABlock(QAModel):
yield StreamingError(error=str(e))
if metrics_callback is not None:
prompt_tokens = sum(
[
check_number_of_tokens(
text=str(p.content), encode_fn=get_default_llm_token_encode()
)
for p in prompt
]
prompt_tokens = check_number_of_tokens(
text=str(prompt), encode_fn=get_default_llm_token_encode()
)
response_tokens = check_number_of_tokens(

View File

@@ -1,5 +1,6 @@
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from json.decoder import JSONDecodeError
@@ -13,7 +14,11 @@ from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import ThreadMessage
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.prompts.constants import UNCERTAINTY_PAT
@@ -270,3 +275,41 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
"""Mock streaming by generating the passed in model output, character by character"""
for token in model_out:
yield token
def combine_message_thread(
messages: list[ThreadMessage],
token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
llm_tokenizer: Callable | None = None,
) -> str:
"""Used to create a single combined message context from threads"""
message_strs: list[str] = []
total_token_count = 0
if llm_tokenizer is None:
llm_tokenizer = get_default_llm_token_encode()
for message in reversed(messages):
if message.role == MessageType.USER:
role_str = message.role.value.upper()
if message.sender:
role_str += " " + message.sender
else:
# Since other messages might carry the user's identifying information,
# it is better to use Unknown here for symmetry
role_str += " Unknown"
else:
role_str = message.role.value.upper()
msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer(msg_str))
if (
token_limit is not None
and total_token_count + message_token_count > token_limit
):
break
message_strs.insert(0, msg_str)
total_token_count += message_token_count
return "\n\n".join(message_strs)

View File

@@ -1,6 +1,7 @@
GENERAL_SEP_PAT = "--------------" # Same length as Langchain's separator
CODE_BLOCK_PAT = "```\n{}\n```"
QUESTION_PAT = "Query:"
FINAL_QUERY_PAT = "Final Query:"
THOUGHT_PAT = "Thought:"
ANSWER_PAT = "Answer:"
ANSWERABLE_PAT = "Answerable:"

View File

@@ -2,24 +2,36 @@
# It is used also for the one shot direct QA flow
import json
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import FINAL_QUERY_PAT
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.prompts.constants import QUESTION_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.prompts.constants import THOUGHT_PAT
from danswer.prompts.constants import UNCERTAINTY_PAT
QA_HEADER = """
ONE_SHOT_SYSTEM_PROMPT = """
You are a question answering system that is constantly learning and improving.
You can process and comprehend vast amounts of text and utilize this knowledge to provide \
accurate and detailed answers to diverse queries.
""".strip()
ONE_SHOT_TASK_PROMPT = """
Answer the final query below taking into account the context above where relevant. \
Ignore any provided context that is not relevant to the query.
""".strip()
WEAK_MODEL_SYSTEM_PROMPT = """
Respond to the user query using the following reference document.
""".lstrip()
WEAK_MODEL_TASK_PROMPT = """
Answer the user query below based on the reference document above.
"""
REQUIRE_JSON = """
You ALWAYS responds with only a json containing an answer and quotes that support the answer.
Your responses are as INFORMATIVE and DETAILED as possible.
You ALWAYS responds with ONLY a JSON containing an answer and quotes that support the answer.
""".strip()
@@ -34,6 +46,21 @@ IMPORTANT: Respond in the same language as my query!
"""
CONTEXT_BLOCK = f"""
REFERENCE DOCUMENTS:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}
"""
HISTORY_BLOCK = f"""
CONVERSATION HISTORY:
{GENERAL_SEP_PAT}
{{history_str}}
{GENERAL_SEP_PAT}
"""
# This has to be doubly escaped because JSON contains { }, which are also used by format strings
EMPTY_SAMPLE_JSON = {
"answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
@@ -48,44 +75,22 @@ ANSWER_NOT_FOUND_RESPONSE = f'{{"answer": "{UNCERTAINTY_PAT}", "quotes": []}}'
# Default json prompt which can reference multiple docs and provide answer + quotes
# system_like_header is similar to a system message; it can be user provided or defaults to QA_HEADER
# context/history blocks are for context documents and conversation history; they can be blank
# task prompt is the task message of the prompt; it can be blank and has no default
JSON_PROMPT = f"""
{QA_HEADER}
{{system_prompt}}
{REQUIRE_JSON}
{{context_block}}{{history_block}}{{task_prompt}}
CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}
SAMPLE_RESPONSE:
SAMPLE RESPONSE:
```
{{{json.dumps(EMPTY_SAMPLE_JSON)}}}
```
{QUESTION_PAT.upper()} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()
{FINAL_QUERY_PAT.upper()}
{{user_query}}
# Default chain-of-thought style json prompt which uses multiple docs
# This one has a section for the LLM to output some non-answer "thoughts"
# COT (chain-of-thought) flow basically
COT_PROMPT = f"""
{QA_HEADER}
CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}
You MUST respond in the following format:
```
{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.
{{{json.dumps(EMPTY_SAMPLE_JSON)}}}
```
{QUESTION_PAT.upper()} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()
@@ -94,35 +99,17 @@ You MUST respond in the following format:
# For a weak LLM which only takes one chunk and cannot output JSON
# Quotes are also not required, as they tend not to work
WEAK_LLM_PROMPT = f"""
Respond to the user query using the following reference document.
Reference Document:
{GENERAL_SEP_PAT}
{{single_reference_doc}}
{GENERAL_SEP_PAT}
Answer the user query below based on the reference document above.
{{system_prompt}}
{{context_block}}
{{task_prompt}}
{QUESTION_PAT.upper()}
{{user_query}}
""".strip()
# For weak CHAT LLM which takes one chunk and cannot output json
# The next message should have the user query
# Note, no flow/config currently uses this one
WEAK_CHAT_LLM_PROMPT = f"""
You are a question answering assistant
Respond to the user query with an "{ANSWER_PAT}" section and \
as many "{QUOTE_PAT}" sections as needed to support the answer.
Answer the user query based on the following document:
{{first_chunk_content}}
""".strip()
# Parameterized prompt which allows the user to specify their
# own system / task prompt
# This is only for visualization for the users to specify their own prompts
# The actual flow does not work like this
PARAMATERIZED_PROMPT = f"""
{{system_prompt}}
@@ -147,6 +134,31 @@ RESPONSE:
""".strip()
# CURRENTLY DISABLED, CANNOT USE THIS ONE
# Default chain-of-thought style json prompt which uses multiple docs
# This one has a section for the LLM to output some non-answer "thoughts"
# COT (chain-of-thought) flow basically
COT_PROMPT = f"""
{ONE_SHOT_SYSTEM_PROMPT}
CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}
You MUST respond in the following format:
```
{THOUGHT_PAT} Use this section as a scratchpad to reason through the answer.
{{{json.dumps(EMPTY_SAMPLE_JSON)}}}
```
{QUESTION_PAT.upper()} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()
# Use the following for easy viewing of prompts
if __name__ == "__main__":
print(JSON_PROMPT) # Default prompt used in the Danswer UI flow
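
To see how the new blocks compose, a hedged sketch that fills JSON_PROMPT the same way SingleMessageQAHandler.build_prompt does; the document and history strings are invented:

from danswer.prompts.direct_qa_prompts import (
    CONTEXT_BLOCK,
    HISTORY_BLOCK,
    JSON_PROMPT,
    ONE_SHOT_SYSTEM_PROMPT,
    ONE_SHOT_TASK_PROMPT,
)

context_block = CONTEXT_BLOCK.format(context_docs_str="Doc 1: Danswer ships a Slack bot...")
history_block = HISTORY_BLOCK.format(history_str="USER:\nWhat connectors exist?")

full_prompt = JSON_PROMPT.format(
    system_prompt=ONE_SHOT_SYSTEM_PROMPT,
    context_block=context_block,
    history_block=history_block,
    task_prompt=ONE_SHOT_TASK_PROMPT,
    user_query="How do I enable the Slack bot?",
    language_hint_or_none="",
).strip()
print(full_prompt)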

View File

@@ -3,8 +3,8 @@ from sqlalchemy.orm import Session
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.search.access_filters import build_access_filters_for_user
@@ -31,7 +31,7 @@ def retrieval_preprocessing(
bypass_acl: bool = False,
include_query_intent: bool = True,
skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
skip_rerank_non_realtime: bool = SKIP_RERANKING,
skip_rerank_non_realtime: bool = not ENABLE_RERANKING_ASYNC_FLOW,
disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,

View File

@@ -59,6 +59,22 @@ def multilingual_query_expansion(
return query_rephrases
def get_contextual_rephrase_messages(
question: str,
history_str: str,
) -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": HISTORY_QUERY_REPHRASE.format(
question=question, chat_history=history_str
),
},
]
return messages
def history_based_query_rephrase(
query_message: ChatMessage,
history: list[ChatMessage],
@@ -66,21 +82,6 @@ def history_based_query_rephrase(
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
) -> str:
def _get_history_rephrase_messages(
question: str,
history_str: str,
) -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": HISTORY_QUERY_REPHRASE.format(
question=question, chat_history=history_str
),
},
]
return messages
user_query = cast(str, query_message.message)
if not user_query:
@@ -98,7 +99,38 @@ def history_based_query_rephrase(
history_str = combine_message_chain(history)
prompt_msgs = _get_history_rephrase_messages(
prompt_msgs = get_contextual_rephrase_messages(
question=user_query, history_str=history_str
)
if llm is None:
llm = get_default_llm()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
rephrased_query = llm.invoke(filled_llm_prompt)
logger.debug(f"Rephrased combined query: {rephrased_query}")
return rephrased_query
def thread_based_query_rephrase(
user_query: str,
history_str: str,
llm: LLM | None = None,
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
) -> str:
if not history_str:
return user_query
if len(user_query) >= size_heuristic:
return user_query
if count_punctuation(user_query) >= punctuation_heuristic:
return user_query
prompt_msgs = get_contextual_rephrase_messages(
question=user_query, history_str=history_str
)
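
A hedged sketch of the heuristics above: a short follow-up is rephrased against the history via the LLM, while an empty history (or a long or punctuation-heavy query) is passed through unchanged. The strings are invented:

history = (
    "USER:\nWhat connectors does Danswer support?\n\n"
    "ASSISTANT:\nSlack, Confluence, GitHub, and more."
)

# Short follow-up: goes through the LLM rephrase path.
rephrased = thread_based_query_rephrase(user_query="how do I set up the first one?", history_str=history)

# No history: returned as-is, no LLM call is made.
unchanged = thread_based_query_rephrase(user_query="How do I set up the Slack connector?", history_str="")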

View File

@@ -15,7 +15,7 @@ from danswer.db.chat import upsert_persona
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.one_shot_answer.qa_block import PromptBasedQAHandler
from danswer.one_shot_answer.qa_block import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
@@ -149,9 +149,11 @@ def build_final_template_prompt(
_: User | None = Depends(current_user),
) -> PromptTemplateResponse:
return PromptTemplateResponse(
final_prompt_template=PromptBasedQAHandler(
system_prompt=system_prompt, task_prompt=task_prompt
).build_dummy_prompt(retrieval_disabled=retrieval_disabled)
final_prompt_template=build_dummy_prompt(
system_prompt=system_prompt,
task_prompt=task_prompt,
retrieval_disabled=retrieval_disabled,
)
)

View File

@@ -163,9 +163,8 @@ def get_answer_with_quote(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
logger.info(
f"Received query for one shot answer with quotes: {query_request.query}"
)
query = query_request.messages[0].message
logger.info(f"Received query for one shot answer with quotes: {query}")
packets = stream_one_shot_answer(
query_req=query_request, user=user, db_session=db_session
)

View File

@@ -9,9 +9,11 @@ import yaml
from sqlalchemy.orm import Session
from danswer.chat.models import LLMMetricsContainer
from danswer.configs.constants import MessageType
from danswer.db.engine import get_sqlalchemy_engine
from danswer.one_shot_answer.answer_question import get_one_shot_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.models import IndexFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RerankMetricsContainer
@@ -84,8 +86,10 @@ def get_answer_for_question(
access_control_list=None,
)
messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
new_message_request = DirectQARequest(
query=query,
messages=messages,
prompt_id=0,
persona_id=0,
retrieval_options=RetrievalDetails(