Fix Slack Document Only Persona (#1150)

This commit is contained in:
Yuhong Sun 2024-02-29 13:53:37 -08:00 committed by GitHub
parent 10cb4ab1d2
commit 31d3ae0e3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 33 additions and 4 deletions

View File

@ -21,6 +21,7 @@ from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_default_prompt
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
@ -30,6 +31,7 @@ from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
@ -40,6 +42,7 @@ from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.prompt_utils import get_current_llm_day_time
from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from danswer.prompts.token_counts import (
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
@ -129,8 +132,8 @@ def build_chat_system_message(
system_prompt += no_citation_line
if prompt.datetime_aware:
if system_prompt:
system_prompt += (
f"\n\nAdditional Information:\n\t- {get_current_llm_day_time()}."
system_prompt += ADDITIONAL_INFO.format(
datetime_info=get_current_llm_day_time()
)
else:
system_prompt = get_current_llm_day_time()
@ -566,6 +569,7 @@ def extract_citations_from_stream(
def get_prompt_tokens(prompt: Prompt) -> int:
# Note: currently custom prompts do not allow datetime aware, only default prompts
return (
check_number_of_tokens(prompt.system_prompt)
+ check_number_of_tokens(prompt.task_prompt)
@ -573,6 +577,7 @@ def get_prompt_tokens(prompt: Prompt) -> int:
+ CITATION_STATEMENT_TOKEN_CNT
+ CITATION_REMINDER_TOKEN_CNT
+ (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
+ (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0)
)
@ -611,7 +616,8 @@ def compute_max_document_tokens(
# TODO this may not always be the first prompt
prompt_tokens = get_prompt_tokens(persona.prompts[0])
else:
raise RuntimeError("Persona has no prompts - this should never happen")
prompt_tokens = get_prompt_tokens(get_default_prompt())
user_input_tokens = (
check_number_of_tokens(actual_user_input)
if actual_user_input is not None

View File

@ -1,4 +1,5 @@
from collections.abc import Sequence
from functools import lru_cache
from uuid import UUID
from sqlalchemy import delete
@ -14,6 +15,7 @@ from sqlalchemy.orm import Session
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import DocumentSet as DBDocumentSet
@ -303,6 +305,20 @@ def get_prompt_by_id(
return prompt
@lru_cache()
def get_default_prompt() -> Prompt:
    """Return the default prompt, i.e. the ``Prompt`` row whose id is 0.

    The lookup opens its own short-lived session on the shared engine
    instead of taking one from the caller. Because of ``lru_cache`` a
    successful lookup is memoized, so later calls return the same object
    without querying again; a failed lookup raises and is not cached.

    NOTE(review): the returned instance was loaded in a session that is
    already closed by the time the caller sees it. It looks safe for plain
    reads of loaded columns, but lazy-loading relationships or attaching it
    to another session may fail - confirm before using it for more than
    reading.

    Raises:
        RuntimeError: if no prompt with id 0 exists.
    """
    default_prompt_query = select(Prompt).where(Prompt.id == 0)

    with Session(get_sqlalchemy_engine()) as db_session:
        default_prompt = db_session.execute(
            default_prompt_query
        ).scalar_one_or_none()

    if default_prompt is None:
        raise RuntimeError("Default Prompt not found")

    return default_prompt
def get_persona_by_id(
persona_id: int,
# if user_id is `None` assume the user is an admin or auth is disabled

View File

@ -14,6 +14,8 @@ CITATION_REMINDER = """
Remember to provide inline citations in the format [1], [2], [3], etc.
"""
# Suffix template that carries the current date/time; the `{datetime_info}`
# placeholder is substituted with str.format().
ADDITIONAL_INFO = "\n\nAdditional Information:\n\t- {datetime_info}."
# NOTE(review): presumably appended to a prompt so the model skips unrelated
# context documents - usage is not visible here, confirm against callers.
DEFAULT_IGNORE_STATEMENT = " Ignore any context documents that are not relevant."

View File

@ -1,10 +1,11 @@
from danswer.llm.utils import check_number_of_tokens
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.prompt_utils import get_current_llm_day_time
# tokens outside of the actual persona's "user_prompt" that make up the end
# user message
@ -22,3 +23,7 @@ CITATION_STATEMENT_TOKEN_CNT = check_number_of_tokens(REQUIRE_CITATION_STATEMENT
# Token counts of fixed prompt fragments, computed once when these
# module-level assignments run.
CITATION_REMINDER_TOKEN_CNT = check_number_of_tokens(CITATION_REMINDER)
LANGUAGE_HINT_TOKEN_CNT = check_number_of_tokens(LANGUAGE_HINT)
# Measured against the date/time string as rendered when this assignment runs.
# NOTE(review): a date/time string rendered later may tokenize to a slightly
# different length, so this is presumably meant as an estimate - confirm.
ADDITIONAL_INFO_TOKEN_CNT = check_number_of_tokens(
ADDITIONAL_INFO.format(datetime_info=get_current_llm_day_time())
)