improve model seeding (#2155)

This commit is contained in:
Chris Weaver 2024-08-16 18:30:13 -07:00 committed by GitHub
parent f8e0e6f015
commit efae24acd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,10 +4,8 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.db.engine import get_session_context_manager from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import update_default_provider from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider from danswer.db.llm import upsert_llm_provider
from danswer.db.persona import get_personas
from danswer.db.persona import upsert_persona from danswer.db.persona import upsert_persona
from danswer.search.enums import RecencyBiasSetting from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import CreatePersonaRequest
@ -50,50 +48,38 @@ def _parse_env() -> SeedConfiguration | None:
# NOTE(review): this span is a side-by-side diff of _seed_llms, not runnable
# code. Each line holds the removed (left) text followed by the added (right)
# text, and the original indentation has been lost.
#
# Removed version: calls fetch_existing_llm_providers(db_session) and returns
# early when the result is truthy, so nothing is seeded once any provider
# exists. Otherwise it logs "Seeding LLMs", upserts every request with
# upsert_llm_provider, and passes the first seeded provider's id to
# update_default_provider.
#
# Added version: drops the existing-provider lookup and guards on
# `if llm_upsert_requests:` instead; the log call, the upsert comprehension and
# the update_default_provider call are otherwise the same statements.
# Presumably all of them are nested under the new `if` (indentation is not
# recoverable from this view) -- TODO confirm against the repository.
def _seed_llms( def _seed_llms(
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest] db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None: ) -> None:
# don't seed LLMs if we've already done this if llm_upsert_requests:
existing_llms = fetch_existing_llm_providers(db_session) logger.info("Seeding LLMs")
if existing_llms: seeded_providers = [
return upsert_llm_provider(db_session, llm_upsert_request)
for llm_upsert_request in llm_upsert_requests
logger.info("Seeding LLMs") ]
seeded_providers = [ update_default_provider(db_session, seeded_providers[0].id)
upsert_llm_provider(db_session, llm_upsert_request)
for llm_upsert_request in llm_upsert_requests
]
update_default_provider(db_session, seeded_providers[0].id)
# NOTE(review): this span is a side-by-side diff of _seed_personas, not
# runnable code. Each line holds the removed (left) text followed by the added
# (right) text, and the original indentation has been lost.
#
# Removed version: fetches existing personas with get_personas(user_id=None,
# db_session=db_session, include_default=True, include_slack_bot_personas=True,
# include_deleted=False) and returns early when any are found. Otherwise it
# logs "Seeding Personas" and calls upsert_persona once per request.
#
# Added version: drops the get_personas lookup and guards on `if personas:`
# instead. The upsert_persona call passes the same keyword arguments as before:
# user=None, num_chunks falling back to 0.0 when persona.num_chunks is None,
# recency_bias fixed to RecencyBiasSetting.AUTO, and the remaining fields copied
# from the request. The only visible difference in the call is that the
# num_chunks expression is wrapped across three lines.
# Presumably the loop is nested under the new `if` (indentation is not
# recoverable from this view) -- TODO confirm against the repository.
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None: def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
# don't seed personas if we've already done this if personas:
existing_personas = get_personas( logger.info("Seeding Personas")
user_id=None, # Admin view for persona in personas:
db_session=db_session, upsert_persona(
include_default=True, user=None, # Seeding is done as admin
include_slack_bot_personas=True, name=persona.name,
include_deleted=False, description=persona.description,
) num_chunks=persona.num_chunks
if existing_personas: if persona.num_chunks is not None
return else 0.0,
llm_relevance_filter=persona.llm_relevance_filter,
logger.info("Seeding Personas") llm_filter_extraction=persona.llm_filter_extraction,
for persona in personas: recency_bias=RecencyBiasSetting.AUTO,
upsert_persona( prompt_ids=persona.prompt_ids,
user=None, # Seeding is done as admin document_set_ids=persona.document_set_ids,
name=persona.name, llm_model_provider_override=persona.llm_model_provider_override,
description=persona.description, llm_model_version_override=persona.llm_model_version_override,
num_chunks=persona.num_chunks if persona.num_chunks is not None else 0.0, starter_messages=persona.starter_messages,
llm_relevance_filter=persona.llm_relevance_filter, is_public=persona.is_public,
llm_filter_extraction=persona.llm_filter_extraction, db_session=db_session,
recency_bias=RecencyBiasSetting.AUTO, tool_ids=persona.tool_ids,
prompt_ids=persona.prompt_ids, )
document_set_ids=persona.document_set_ids,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
is_public=persona.is_public,
db_session=db_session,
tool_ids=persona.tool_ids,
)
def _seed_settings(settings: Settings) -> None: def _seed_settings(settings: Settings) -> None: