From efae24acd0b4559366ba6cd5f4ca643bb894ca18 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Fri, 16 Aug 2024 18:30:13 -0700 Subject: [PATCH] improve model seeding (#2155) --- backend/ee/danswer/server/seeding.py | 72 +++++++++++----------------- 1 file changed, 29 insertions(+), 43 deletions(-) diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index 933a7ba82..9dc29e2ca 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -4,10 +4,8 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from danswer.db.engine import get_session_context_manager -from danswer.db.llm import fetch_existing_llm_providers from danswer.db.llm import update_default_provider from danswer.db.llm import upsert_llm_provider -from danswer.db.persona import get_personas from danswer.db.persona import upsert_persona from danswer.search.enums import RecencyBiasSetting from danswer.server.features.persona.models import CreatePersonaRequest @@ -50,50 +48,38 @@ def _parse_env() -> SeedConfiguration | None: def _seed_llms( db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest] ) -> None: - # don't seed LLMs if we've already done this - existing_llms = fetch_existing_llm_providers(db_session) - if existing_llms: - return - - logger.info("Seeding LLMs") - seeded_providers = [ - upsert_llm_provider(db_session, llm_upsert_request) - for llm_upsert_request in llm_upsert_requests - ] - update_default_provider(db_session, seeded_providers[0].id) + if llm_upsert_requests: + logger.info("Seeding LLMs") + seeded_providers = [ + upsert_llm_provider(db_session, llm_upsert_request) + for llm_upsert_request in llm_upsert_requests + ] + update_default_provider(db_session, seeded_providers[0].id) def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None: - # don't seed personas if we've already done this - existing_personas = get_personas( - user_id=None, # Admin view - db_session=db_session, - include_default=True, - include_slack_bot_personas=True, - include_deleted=False, - ) - if existing_personas: - return - - logger.info("Seeding Personas") - for persona in personas: - upsert_persona( - user=None, # Seeding is done as admin - name=persona.name, - description=persona.description, - num_chunks=persona.num_chunks if persona.num_chunks is not None else 0.0, - llm_relevance_filter=persona.llm_relevance_filter, - llm_filter_extraction=persona.llm_filter_extraction, - recency_bias=RecencyBiasSetting.AUTO, - prompt_ids=persona.prompt_ids, - document_set_ids=persona.document_set_ids, - llm_model_provider_override=persona.llm_model_provider_override, - llm_model_version_override=persona.llm_model_version_override, - starter_messages=persona.starter_messages, - is_public=persona.is_public, - db_session=db_session, - tool_ids=persona.tool_ids, - ) + if personas: + logger.info("Seeding Personas") + for persona in personas: + upsert_persona( + user=None, # Seeding is done as admin + name=persona.name, + description=persona.description, + num_chunks=persona.num_chunks + if persona.num_chunks is not None + else 0.0, + llm_relevance_filter=persona.llm_relevance_filter, + llm_filter_extraction=persona.llm_filter_extraction, + recency_bias=RecencyBiasSetting.AUTO, + prompt_ids=persona.prompt_ids, + document_set_ids=persona.document_set_ids, + llm_model_provider_override=persona.llm_model_provider_override, + llm_model_version_override=persona.llm_model_version_override, + starter_messages=persona.starter_messages, + is_public=persona.is_public, + db_session=db_session, + tool_ids=persona.tool_ids, + ) def _seed_settings(settings: Settings) -> None: