From 44d8e34b5a8325fca0a045181818a81857c87495 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 7 Aug 2024 10:44:33 -0700 Subject: [PATCH] Improve seeding (includes all enterprise features) (#2065) --- backend/ee/danswer/server/seeding.py | 42 ++++++++++++++++++---------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index 20c57facbe57..069ce093e20d 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -15,7 +15,9 @@ from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.server.settings.models import Settings from danswer.server.settings.store import store_settings as store_base_settings from danswer.utils.logger import setup_logger +from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload from ee.danswer.server.enterprise_settings.models import EnterpriseSettings +from ee.danswer.server.enterprise_settings.store import store_analytics_script from ee.danswer.server.enterprise_settings.store import ( store_settings as store_ee_settings, ) @@ -30,10 +32,11 @@ _SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION" class SeedConfiguration(BaseModel): llms: list[LLMProviderUpsertRequest] | None = None admin_user_emails: list[str] | None = None - seeded_name: str | None = None seeded_logo_path: str | None = None personas: list[CreatePersonaRequest] | None = None settings: Settings | None = None + enterprise_settings: EnterpriseSettings | None = None + analytics_script: AnalyticsScriptUpload | None = None def _parse_env() -> SeedConfiguration | None: @@ -103,6 +106,27 @@ def _seed_settings(settings: Settings) -> None: logger.error(f"Failed to seed Settings: {str(e)}") +def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None: + if seed_config.enterprise_settings is not None: + logger.info("Seeding enterprise settings") + store_ee_settings(seed_config.enterprise_settings) + + +def _seed_logo(db_session: Session, logo_path: str | None) -> None: + if logo_path: + logger.info("Uploading logo") + upload_logo(db_session=db_session, file=logo_path) + + +def _seed_analytics_script(seed_config: SeedConfiguration) -> None: + if seed_config.analytics_script is not None: + logger.info("Seeding analytics script") + try: + store_analytics_script(seed_config.analytics_script) + except ValueError as e: + logger.error(f"Failed to seed analytics script: {str(e)}") + + def get_seed_config() -> SeedConfiguration | None: return _parse_env() @@ -122,16 +146,6 @@ def seed_db() -> None: if seed_config.settings is not None: _seed_settings(seed_config.settings) - is_seeded_logo = ( - upload_logo(db_session=db_session, file=seed_config.seeded_logo_path) - if seed_config.seeded_logo_path - else False - ) - seeded_name = seed_config.seeded_name - - if is_seeded_logo or seeded_name: - logger.info("Seeding enterprise settings") - seeded_settings = EnterpriseSettings( - application_name=seeded_name, use_custom_logo=is_seeded_logo - ) - store_ee_settings(seeded_settings) + _seed_logo(db_session, seed_config.seeded_logo_path) + _seed_enterprise_settings(seed_config) + _seed_analytics_script(seed_config)