From 31d3ae0e3ec098feb02666663bc732d17ad67a35 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 29 Feb 2024 13:53:37 -0800 Subject: [PATCH] Fix Slack Document Only Persona (#1150) --- backend/danswer/chat/chat_utils.py | 12 +++++++++--- backend/danswer/db/chat.py | 16 ++++++++++++++++ backend/danswer/prompts/chat_prompts.py | 2 ++ backend/danswer/prompts/token_counts.py | 7 ++++++- 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index 26533d704..b839ea3bf 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -21,6 +21,7 @@ from danswer.configs.constants import IGNORE_FOR_QA from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.db.chat import get_chat_messages_by_session +from danswer.db.chat import get_default_prompt from danswer.db.models import ChatMessage from danswer.db.models import Persona from danswer.db.models import Prompt @@ -30,6 +31,7 @@ from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import get_default_llm_version from danswer.llm.utils import get_max_input_tokens from danswer.llm.utils import tokenizer_trim_content +from danswer.prompts.chat_prompts import ADDITIONAL_INFO from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT from danswer.prompts.chat_prompts import CHAT_USER_PROMPT from danswer.prompts.chat_prompts import CITATION_REMINDER @@ -40,6 +42,7 @@ from danswer.prompts.constants import CODE_BLOCK_PAT from danswer.prompts.constants import TRIPLE_BACKTICK from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT from danswer.prompts.prompt_utils import get_current_llm_day_time +from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT from danswer.prompts.token_counts import ( CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, ) @@ -129,8 +132,8 @@ def build_chat_system_message( system_prompt += no_citation_line if prompt.datetime_aware: if system_prompt: - system_prompt += ( - f"\n\nAdditional Information:\n\t- {get_current_llm_day_time()}." + system_prompt += ADDITIONAL_INFO.format( + datetime_info=get_current_llm_day_time() ) else: system_prompt = get_current_llm_day_time() @@ -566,6 +569,7 @@ def extract_citations_from_stream( def get_prompt_tokens(prompt: Prompt) -> int: + # Note: currently custom prompts do not allow datetime aware, only default prompts return ( check_number_of_tokens(prompt.system_prompt) + check_number_of_tokens(prompt.task_prompt) @@ -573,6 +577,7 @@ def get_prompt_tokens(prompt: Prompt) -> int: + CITATION_STATEMENT_TOKEN_CNT + CITATION_REMINDER_TOKEN_CNT + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0) + + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0) ) @@ -611,7 +616,8 @@ def compute_max_document_tokens( # TODO this may not always be the first prompt prompt_tokens = get_prompt_tokens(persona.prompts[0]) else: - raise RuntimeError("Persona has no prompts - this should never happen") + prompt_tokens = get_prompt_tokens(get_default_prompt()) + user_input_tokens = ( check_number_of_tokens(actual_user_input) if actual_user_input is not None diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 50c895999..acb81f534 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from functools import lru_cache from uuid import UUID from sqlalchemy import delete @@ -14,6 +15,7 @@ from sqlalchemy.orm import Session from danswer.configs.chat_configs import HARD_DELETE_CHATS from danswer.configs.constants import MessageType from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX +from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import ChatMessage from danswer.db.models import ChatSession from danswer.db.models import DocumentSet as DBDocumentSet @@ -303,6 +305,20 @@ def get_prompt_by_id( return prompt +@lru_cache() +def get_default_prompt() -> Prompt: + with Session(get_sqlalchemy_engine()) as db_session: + stmt = select(Prompt).where(Prompt.id == 0) + + result = db_session.execute(stmt) + prompt = result.scalar_one_or_none() + + if prompt is None: + raise RuntimeError("Default Prompt not found") + + return prompt + + def get_persona_by_id( persona_id: int, # if user_id is `None` assume the user is an admin or auth is disabled diff --git a/backend/danswer/prompts/chat_prompts.py b/backend/danswer/prompts/chat_prompts.py index bdb938090..d83970a37 100644 --- a/backend/danswer/prompts/chat_prompts.py +++ b/backend/danswer/prompts/chat_prompts.py @@ -14,6 +14,8 @@ CITATION_REMINDER = """ Remember to provide inline citations in the format [1], [2], [3], etc. """ +ADDITIONAL_INFO = "\n\nAdditional Information:\n\t- {datetime_info}." + DEFAULT_IGNORE_STATEMENT = " Ignore any context documents that are not relevant." diff --git a/backend/danswer/prompts/token_counts.py b/backend/danswer/prompts/token_counts.py index 35d082b8d..1cf0f80e5 100644 --- a/backend/danswer/prompts/token_counts.py +++ b/backend/danswer/prompts/token_counts.py @@ -1,10 +1,11 @@ from danswer.llm.utils import check_number_of_tokens +from danswer.prompts.chat_prompts import ADDITIONAL_INFO from danswer.prompts.chat_prompts import CHAT_USER_PROMPT from danswer.prompts.chat_prompts import CITATION_REMINDER from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT - +from danswer.prompts.prompt_utils import get_current_llm_day_time # tokens outside of the actual persona's "user_prompt" that make up the end # user message @@ -22,3 +23,7 @@ CITATION_STATEMENT_TOKEN_CNT = check_number_of_tokens(REQUIRE_CITATION_STATEMENT CITATION_REMINDER_TOKEN_CNT = check_number_of_tokens(CITATION_REMINDER) LANGUAGE_HINT_TOKEN_CNT = check_number_of_tokens(LANGUAGE_HINT) + +ADDITIONAL_INFO_TOKEN_CNT = check_number_of_tokens( + ADDITIONAL_INFO.format(datetime_info=get_current_llm_day_time()) +)