# Patch summary
# -------------
# backend/danswer/db/persona.py
#   * upsert_persona no longer takes the required `prompts` / `document_sets`
#     ORM-object lists. It takes optional `prompt_ids` / `document_set_ids`
#     (list[int] | None, default None) placed after `db_session`.
#   * upsert_persona resolves the IDs itself with
#     db_session.query(<Model>).filter(<Model>.id.in_(ids)).all() and raises
#     ValueError("document_sets not found") / ValueError("prompts not found")
#     when a non-empty ID list matches no rows. A None ID list leaves the
#     corresponding local (`document_sets` / `prompts`) as None.
#   * In the update branch, `persona.prompts` / `persona.tools` are assigned
#     `<value> or []` inside their existing `is not None` guards.
#   * create_update_persona forwards the request's prompt_ids /
#     document_set_ids directly; the get_document_sets_by_ids /
#     get_prompts_by_ids pre-fetch is removed.
#
# backend/danswer/chat/load_yamls.py
#   * load_personas_from_yaml builds `doc_set_ids` / `prompt_ids` from the rows
#     it looked up and passes those; each stays None when the corresponding
#     YAML list is empty. The `typing.cast` import is dropped.
#
# backend/danswer/db/slack_bot_config.py
#   * create_slack_bot_persona passes `[default_prompt.id]` and the incoming
#     `document_set_ids` instead of pre-fetched objects.
#
# backend/ee/danswer/server/seeding.py
#   * Re-enables the previously commented-out persona seeding:
#     SeedConfiguration.personas (now list[CreatePersonaRequest]),
#     _seed_personas, and its call in seed_db.
#   * _seed_personas returns without seeding when get_personas(...) returns
#     anything.
#
# Review notes
# ------------
# NOTE(review): the new not-found checks in upsert_persona only fire when *no*
#   row matches; an ID list that matches only some of its IDs is not detected
#   by the code shown here -- confirm that is acceptable.
# NOTE(review): _seed_personas passes RecencyBiasSetting.AUTO and no longer
#   passes `default_persona`, whereas the removed commented-out version passed
#   persona.recency_bias and persona.default_persona -- confirm intentional.
# NOTE(review): in load_personas_from_yaml, `else: doc_set_ids = None`
#   re-assigns the value `doc_set_ids` was initialised with, and `if prompts:`
#   is only reached when `prompt_set_names` was truthy.
# NOTE(review): indentation and blank context lines below should be verified
#   against the target files before applying; context/removed lines must match
#   them exactly.
#
diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py
index 4036730a9..342802e6c 100644
--- a/backend/danswer/chat/load_yamls.py
+++ b/backend/danswer/chat/load_yamls.py
@@ -1,5 +1,3 @@
-from typing import cast
-
 import yaml
 from sqlalchemy.orm import Session
 
@@ -50,7 +48,7 @@ def load_personas_from_yaml(
     with Session(get_sqlalchemy_engine()) as db_session:
         for persona in all_personas:
             doc_set_names = persona["document_sets"]
-            doc_sets: list[DocumentSetDBModel] | None = [
+            doc_sets: list[DocumentSetDBModel] = [
                 get_or_create_document_set_by_name(db_session, name)
                 for name in doc_set_names
             ]
@@ -58,22 +56,24 @@ def load_personas_from_yaml(
             # Assume if user hasn't set any document sets for the persona, the user may want
             # to later attach document sets to the persona manually, therefore, don't overwrite/reset
             # the document sets for the persona
-            if not doc_sets:
-                doc_sets = None
-
-            prompt_set_names = persona["prompts"]
-            if not prompt_set_names:
-                prompts: list[PromptDBModel | None] | None = None
+            doc_set_ids: list[int] | None = None
+            if doc_sets:
+                doc_set_ids = [doc_set.id for doc_set in doc_sets]
             else:
-                prompts = [
+                doc_set_ids = None
+
+            prompt_ids: list[int] | None = None
+            prompt_set_names = persona["prompts"]
+            if prompt_set_names:
+                prompts: list[PromptDBModel | None] = [
                     get_prompt_by_name(prompt_name, user=None, db_session=db_session)
                     for prompt_name in prompt_set_names
                 ]
                 if any([prompt is None for prompt in prompts]):
                     raise ValueError("Invalid Persona configs, not all prompts exist")
 
-                if not prompts:
-                    prompts = None
+                if prompts:
+                    prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
 
             p_id = persona.get("id")
             upsert_persona(
@@ -91,8 +91,8 @@ def load_personas_from_yaml(
                 llm_model_provider_override=None,
                 llm_model_version_override=None,
                 recency_bias=RecencyBiasSetting(persona["recency_bias"]),
-                prompts=cast(list[PromptDBModel] | None, prompts),
-                document_sets=doc_sets,
+                prompt_ids=prompt_ids,
+                document_set_ids=doc_set_ids,
                 default_persona=True,
                 is_public=True,
                 db_session=db_session,
diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py
index d8b3f5552..4726cf426 100644
--- a/backend/danswer/db/persona.py
+++ b/backend/danswer/db/persona.py
@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
 
 from danswer.auth.schemas import UserRole
 from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
-from danswer.db.document_set import get_document_sets_by_ids
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import DocumentSet
 from danswer.db.models import Persona
@@ -62,19 +61,6 @@ def create_update_persona(
 ) -> PersonaSnapshot:
     """Higher level function than upsert_persona, although either is valid to use."""
     # Permission to actually use these is checked later
-    document_sets = list(
-        get_document_sets_by_ids(
-            document_set_ids=create_persona_request.document_set_ids,
-            db_session=db_session,
-        )
-    )
-    prompts = list(
-        get_prompts_by_ids(
-            prompt_ids=create_persona_request.prompt_ids,
-            db_session=db_session,
-        )
-    )
-
     try:
         persona = upsert_persona(
             persona_id=persona_id,
@@ -85,9 +71,9 @@ def create_update_persona(
             llm_relevance_filter=create_persona_request.llm_relevance_filter,
             llm_filter_extraction=create_persona_request.llm_filter_extraction,
             recency_bias=create_persona_request.recency_bias,
-            prompts=prompts,
+            prompt_ids=create_persona_request.prompt_ids,
             tool_ids=create_persona_request.tool_ids,
-            document_sets=document_sets,
+            document_set_ids=create_persona_request.document_set_ids,
             llm_model_provider_override=create_persona_request.llm_model_provider_override,
             llm_model_version_override=create_persona_request.llm_model_version_override,
             starter_messages=create_persona_request.starter_messages,
@@ -330,13 +316,13 @@ def upsert_persona(
     llm_relevance_filter: bool,
     llm_filter_extraction: bool,
     recency_bias: RecencyBiasSetting,
-    prompts: list[Prompt] | None,
-    document_sets: list[DocumentSet] | None,
     llm_model_provider_override: str | None,
     llm_model_version_override: str | None,
     starter_messages: list[StarterMessage] | None,
     is_public: bool,
     db_session: Session,
+    prompt_ids: list[int] | None = None,
+    document_set_ids: list[int] | None = None,
     tool_ids: list[int] | None = None,
     persona_id: int | None = None,
     default_persona: bool = False,
@@ -356,6 +342,24 @@ def upsert_persona(
         if not tools and tool_ids:
             raise ValueError("Tools not found")
 
+    # Fetch and attach document_sets by IDs
+    document_sets = None
+    if document_set_ids is not None:
+        document_sets = (
+            db_session.query(DocumentSet)
+            .filter(DocumentSet.id.in_(document_set_ids))
+            .all()
+        )
+        if not document_sets and document_set_ids:
+            raise ValueError("document_sets not found")
+
+    # Fetch and attach prompts by IDs
+    prompts = None
+    if prompt_ids is not None:
+        prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
+        if not prompts and prompt_ids:
+            raise ValueError("prompts not found")
+
     if persona:
         if not default_persona and persona.default_persona:
             raise ValueError("Cannot update default persona with non-default.")
@@ -383,10 +387,10 @@ def upsert_persona(
 
         if prompts is not None:
             persona.prompts.clear()
-            persona.prompts = prompts
+            persona.prompts = prompts or []
 
         if tools is not None:
-            persona.tools = tools
+            persona.tools = tools or []
 
     else:
         persona = Persona(
diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py
index 8683b3222..c0a98bc93 100644
--- a/backend/danswer/db/slack_bot_config.py
+++ b/backend/danswer/db/slack_bot_config.py
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
 
 from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
-from danswer.db.document_set import get_document_sets_by_ids
 from danswer.db.models import ChannelConfig
 from danswer.db.models import Persona
 from danswer.db.models import Persona__DocumentSet
@@ -42,12 +41,6 @@ def create_slack_bot_persona(
     num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
 ) -> Persona:
     """NOTE: does not commit changes"""
-    document_sets = list(
-        get_document_sets_by_ids(
-            document_set_ids=document_set_ids,
-            db_session=db_session,
-        )
-    )
 
     # create/update persona associated with the slack bot
     persona_name = _build_persona_name(channel_names)
@@ -61,8 +54,8 @@ def create_slack_bot_persona(
         llm_relevance_filter=True,
         llm_filter_extraction=True,
         recency_bias=RecencyBiasSetting.AUTO,
-        prompts=[default_prompt],
-        document_sets=document_sets,
+        prompt_ids=[default_prompt.id],
+        document_set_ids=document_set_ids,
         llm_model_provider_override=None,
         llm_model_version_override=None,
         starter_messages=None,
diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py
index 3c79b0600..20c57facb 100644
--- a/backend/ee/danswer/server/seeding.py
+++ b/backend/ee/danswer/server/seeding.py
@@ -7,6 +7,10 @@ from danswer.db.engine import get_session_context_manager
 from danswer.db.llm import fetch_existing_llm_providers
 from danswer.db.llm import update_default_provider
 from danswer.db.llm import upsert_llm_provider
+from danswer.db.persona import get_personas
+from danswer.db.persona import upsert_persona
+from danswer.search.enums import RecencyBiasSetting
+from danswer.server.features.persona.models import CreatePersonaRequest
 from danswer.server.manage.llm.models import LLMProviderUpsertRequest
 from danswer.server.settings.models import Settings
 from danswer.server.settings.store import store_settings as store_base_settings
@@ -28,7 +32,7 @@ class SeedConfiguration(BaseModel):
     admin_user_emails: list[str] | None = None
     seeded_name: str | None = None
     seeded_logo_path: str | None = None
-    # personas: list[Persona] | None = None
+    personas: list[CreatePersonaRequest] | None = None
     settings: Settings | None = None
 
 
@@ -36,7 +40,6 @@ def _parse_env() -> SeedConfiguration | None:
     seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
     if not seed_config_str:
         return None
-
     seed_config = SeedConfiguration.parse_raw(seed_config_str)
     return seed_config
 
@@ -57,38 +60,37 @@ def _seed_llms(
         update_default_provider(db_session, seeded_providers[0].id)
 
 
-# def _seed_personas(db_session: Session, personas: list[Persona]) -> None:
-#     # don't seed personas if we've already done this
-#     existing_personas = get_personas(
-#         user_id=None,  # Admin view
-#         db_session=db_session,
-#         include_default=True,
-#         include_slack_bot_personas=True,
-#         include_deleted=False,
-#     )
-#     if existing_personas:
-#         return
+def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
+    # don't seed personas if we've already done this
+    existing_personas = get_personas(
+        user_id=None,  # Admin view
+        db_session=db_session,
+        include_default=True,
+        include_slack_bot_personas=True,
+        include_deleted=False,
+    )
+    if existing_personas:
+        return
 
-#     logger.info("Seeding Personas")
-#     for persona in personas:
-#         upsert_persona(
-#             user=None,  # Seeding is done as admin
-#             name=persona.name,
-#             description=persona.description,
-#             num_chunks=persona.num_chunks if persona.num_chunks is not None else 0.0,
-#             llm_relevance_filter=persona.llm_relevance_filter,
-#             llm_filter_extraction=persona.llm_filter_extraction,
-#             recency_bias=persona.recency_bias,
-#             prompts=persona.prompts,
-#             document_sets=persona.document_sets,
-#             llm_model_provider_override=persona.llm_model_provider_override,
-#             llm_model_version_override=persona.llm_model_version_override,
-#             starter_messages=persona.starter_messages,
-#             is_public=persona.is_public,
-#             db_session=db_session,
-#             tool_ids=[tool.id for tool in persona.tools],
-#             default_persona=persona.default_persona,
-#         )
+    logger.info("Seeding Personas")
+    for persona in personas:
+        upsert_persona(
+            user=None,  # Seeding is done as admin
+            name=persona.name,
+            description=persona.description,
+            num_chunks=persona.num_chunks if persona.num_chunks is not None else 0.0,
+            llm_relevance_filter=persona.llm_relevance_filter,
+            llm_filter_extraction=persona.llm_filter_extraction,
+            recency_bias=RecencyBiasSetting.AUTO,
+            prompt_ids=persona.prompt_ids,
+            document_set_ids=persona.document_set_ids,
+            llm_model_provider_override=persona.llm_model_provider_override,
+            llm_model_version_override=persona.llm_model_version_override,
+            starter_messages=persona.starter_messages,
+            is_public=persona.is_public,
+            db_session=db_session,
+            tool_ids=persona.tool_ids,
+        )
 
 
 def _seed_settings(settings: Settings) -> None:
@@ -115,8 +117,8 @@ def seed_db() -> None:
     with get_session_context_manager() as db_session:
         if seed_config.llms is not None:
             _seed_llms(db_session, seed_config.llms)
-        # if seed_config.personas is not None:
-        #     _seed_personas(db_session, seed_config.personas)
+        if seed_config.personas is not None:
+            _seed_personas(db_session, seed_config.personas)
 
     if seed_config.settings is not None:
         _seed_settings(seed_config.settings)