From 64ee5ffff555fd2d3991994a4867e4ffd6bc5fa9 Mon Sep 17 00:00:00 2001 From: Weves Date: Mon, 10 Jun 2024 11:28:05 -0700 Subject: [PATCH] Fix slack bot creation with document sets --- backend/danswer/db/persona.py | 31 +++++++++++++------ backend/danswer/db/slack_bot_config.py | 2 +- .../llm/answering/prompts/citations_prompt.py | 4 +-- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index 88a72cc93..538b8441b 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -506,18 +506,29 @@ def get_prompt_by_id( return prompt +def _get_default_prompt(db_session: Session) -> Prompt: + stmt = select(Prompt).where(Prompt.id == 0) + result = db_session.execute(stmt) + prompt = result.scalar_one_or_none() + + if prompt is None: + raise RuntimeError("Default Prompt not found") + + return prompt + + +def get_default_prompt(db_session: Session) -> Prompt: + return _get_default_prompt(db_session) + + @lru_cache() -def get_default_prompt() -> Prompt: +def get_default_prompt__read_only() -> Prompt: + """Due to the way lru_cache / SQLAlchemy works, this can cause issues + when trying to attach the returned `Prompt` object to a `Persona`. If you are + doing anything other than reading, you should use the `get_default_prompt` + method instead.""" with Session(get_sqlalchemy_engine()) as db_session: - stmt = select(Prompt).where(Prompt.id == 0) - - result = db_session.execute(stmt) - prompt = result.scalar_one_or_none() - - if prompt is None: - raise RuntimeError("Default Prompt not found") - - return prompt + return _get_default_prompt(db_session) def get_persona_by_id( diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index 43418f621..8683b3222 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -51,7 +51,7 @@ def create_slack_bot_persona( # create/update persona associated with the slack bot persona_name = _build_persona_name(channel_names) - default_prompt = get_default_prompt() + default_prompt = get_default_prompt(db_session) persona = upsert_persona( user=None, # Slack Bot Personas are not attached to users persona_id=existing_persona_id, diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py index 0a6e2c75e..6c4ebfc40 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -5,7 +5,7 @@ from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.db.models import Persona -from danswer.db.persona import get_default_prompt +from danswer.db.persona import get_default_prompt__read_only from danswer.file_store.utils import InMemoryChatFile from danswer.llm.answering.models import PromptConfig from danswer.llm.factory import get_llm_for_persona @@ -96,7 +96,7 @@ def compute_max_document_tokens_for_persona( actual_user_input: str | None = None, max_llm_token_override: int | None = None, ) -> int: - prompt = persona.prompts[0] if persona.prompts else get_default_prompt() + prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only() return compute_max_document_tokens( prompt_config=PromptConfig.from_model(prompt), llm_config=get_llm_for_persona(persona).config,