From 581cb827bb4ca5809c3286bf6f9419816cbbb03c Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Mon, 1 Jul 2024 15:22:17 -0700 Subject: [PATCH] added settings and persona seeding options (#1742) * added settings and persona seeding options * updated recency_bias * changed variable type * another fix * Update seeding.py * fixed mypy * push --- .vscode/env_template.txt | 5 --- backend/ee/danswer/server/seeding.py | 62 +++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/.vscode/env_template.txt b/.vscode/env_template.txt index bff9b79e5..f6d8c4345 100644 --- a/.vscode/env_template.txt +++ b/.vscode/env_template.txt @@ -25,11 +25,6 @@ OAUTH_CLIENT_SECRET= REQUIRE_EMAIL_VERIFICATION=False -# Toggles on/off the EE Features -# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE -ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False - - # Set these so if you wipe the DB, you don't end up having to go through the UI every time GEN_AI_API_KEY= # If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index 9245c20ca..7ba5af3f7 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -7,10 +7,17 @@ from danswer.db.engine import get_session_context_manager from danswer.db.llm import fetch_existing_llm_providers from danswer.db.llm import update_default_provider from danswer.db.llm import upsert_llm_provider +from danswer.db.models import Persona +from danswer.db.persona import get_personas +from danswer.db.persona import upsert_persona from danswer.server.manage.llm.models import LLMProviderUpsertRequest +from danswer.server.settings.models import Settings +from danswer.server.settings.store import store_settings as store_base_settings from danswer.utils.logger import setup_logger from ee.danswer.server.enterprise_settings.models import EnterpriseSettings -from ee.danswer.server.enterprise_settings.store import store_settings +from ee.danswer.server.enterprise_settings.store import ( + store_settings as store_ee_settings, +) from ee.danswer.server.enterprise_settings.store import upload_logo @@ -24,12 +31,15 @@ class SeedConfiguration(BaseModel): admin_user_emails: list[str] | None = None seeded_name: str | None = None seeded_logo_path: str | None = None + personas: list[Persona] | None = None + settings: Settings | None = None def _parse_env() -> SeedConfiguration | None: seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME) if not seed_config_str: return None + seed_config = SeedConfiguration.parse_raw(seed_config_str) return seed_config @@ -50,6 +60,50 @@ def _seed_llms( update_default_provider(db_session, seeded_providers[0].id) +def _seed_personas(db_session: Session, personas: list[Persona]) -> None: + # don't seed personas if we've already done this + existing_personas = get_personas( + user_id=None, # Admin view + db_session=db_session, + include_default=True, + include_slack_bot_personas=True, + include_deleted=False, + ) + if existing_personas: + return + + logger.info("Seeding Personas") + for persona in personas: + upsert_persona( + user=None, # Seeding is done as admin + name=persona.name, + description=persona.description, + num_chunks=persona.num_chunks if persona.num_chunks is not None else 0.0, + llm_relevance_filter=persona.llm_relevance_filter, + llm_filter_extraction=persona.llm_filter_extraction, + recency_bias=persona.recency_bias, + prompts=persona.prompts, + document_sets=persona.document_sets, + llm_model_provider_override=persona.llm_model_provider_override, + llm_model_version_override=persona.llm_model_version_override, + starter_messages=persona.starter_messages, + is_public=persona.is_public, + db_session=db_session, + tool_ids=[tool.id for tool in persona.tools], + default_persona=persona.default_persona, + ) + + +def _seed_settings(settings: Settings) -> None: + logger.info("Seeding Settings") + try: + settings.check_validity() + store_base_settings(settings) + logger.info("Successfully seeded Settings") + except ValueError as e: + logger.error(f"Failed to seed Settings: {str(e)}") + + def get_seed_config() -> SeedConfiguration | None: return _parse_env() @@ -64,6 +118,10 @@ def seed_db() -> None: with get_session_context_manager() as db_session: if seed_config.llms is not None: _seed_llms(db_session, seed_config.llms) + if seed_config.personas is not None: + _seed_personas(db_session, seed_config.personas) + if seed_config.settings is not None: + _seed_settings(seed_config.settings) is_seeded_logo = ( upload_logo(db_session=db_session, file=seed_config.seeded_logo_path) @@ -77,4 +135,4 @@ def seed_db() -> None: seeded_settings = EnterpriseSettings( application_name=seeded_name, use_custom_logo=is_seeded_logo ) - store_settings(seeded_settings) + store_ee_settings(seeded_settings)