allowed arbitrary types to handle the sqlalchemy datatype (#1758)

* allowed arbitrary types to handle the sqlalchemy datatype

* changed persona_upsert to take in ids instead of objects
This commit is contained in:
hagen-danswer 2024-07-03 00:10:57 -07:00 committed by GitHub
parent 7f1bb67e52
commit a7da07afc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 77 additions and 78 deletions

View File

@ -1,5 +1,3 @@
from typing import cast
import yaml
from sqlalchemy.orm import Session
@ -50,7 +48,7 @@ def load_personas_from_yaml(
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] | None = [
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
@ -58,22 +56,24 @@ def load_personas_from_yaml(
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
if not doc_sets:
doc_sets = None
prompt_set_names = persona["prompts"]
if not prompt_set_names:
prompts: list[PromptDBModel | None] | None = None
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
prompts = [
doc_set_ids = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if not prompts:
prompts = None
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
upsert_persona(
@ -91,8 +91,8 @@ def load_personas_from_yaml(
llm_model_provider_override=None,
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompts=cast(list[PromptDBModel] | None, prompts),
document_sets=doc_sets,
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
default_persona=True,
is_public=True,
db_session=db_session,

View File

@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet
from danswer.db.models import Persona
@ -62,19 +61,6 @@ def create_update_persona(
) -> PersonaSnapshot:
"""Higher level function than upsert_persona, although either is valid to use."""
# Permission to actually use these is checked later
document_sets = list(
get_document_sets_by_ids(
document_set_ids=create_persona_request.document_set_ids,
db_session=db_session,
)
)
prompts = list(
get_prompts_by_ids(
prompt_ids=create_persona_request.prompt_ids,
db_session=db_session,
)
)
try:
persona = upsert_persona(
persona_id=persona_id,
@ -85,9 +71,9 @@ def create_update_persona(
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
recency_bias=create_persona_request.recency_bias,
prompts=prompts,
prompt_ids=create_persona_request.prompt_ids,
tool_ids=create_persona_request.tool_ids,
document_sets=document_sets,
document_set_ids=create_persona_request.document_set_ids,
llm_model_provider_override=create_persona_request.llm_model_provider_override,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
@ -330,13 +316,13 @@ def upsert_persona(
llm_relevance_filter: bool,
llm_filter_extraction: bool,
recency_bias: RecencyBiasSetting,
prompts: list[Prompt] | None,
document_sets: list[DocumentSet] | None,
llm_model_provider_override: str | None,
llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
is_public: bool,
db_session: Session,
prompt_ids: list[int] | None = None,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
persona_id: int | None = None,
default_persona: bool = False,
@ -356,6 +342,24 @@ def upsert_persona(
if not tools and tool_ids:
raise ValueError("Tools not found")
# Fetch and attach document_sets by IDs
document_sets = None
if document_set_ids is not None:
document_sets = (
db_session.query(DocumentSet)
.filter(DocumentSet.id.in_(document_set_ids))
.all()
)
if not document_sets and document_set_ids:
raise ValueError("document_sets not found")
# Fetch and attach prompts by IDs
prompts = None
if prompt_ids is not None:
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
if not prompts and prompt_ids:
raise ValueError("prompts not found")
if persona:
if not default_persona and persona.default_persona:
raise ValueError("Cannot update default persona with non-default.")
@ -383,10 +387,10 @@ def upsert_persona(
if prompts is not None:
persona.prompts.clear()
persona.prompts = prompts
persona.prompts = prompts or []
if tools is not None:
persona.tools = tools
persona.tools = tools or []
else:
persona = Persona(

View File

@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet
@ -42,12 +41,6 @@ def create_slack_bot_persona(
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> Persona:
"""NOTE: does not commit changes"""
document_sets = list(
get_document_sets_by_ids(
document_set_ids=document_set_ids,
db_session=db_session,
)
)
# create/update persona associated with the slack bot
persona_name = _build_persona_name(channel_names)
@ -61,8 +54,8 @@ def create_slack_bot_persona(
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
prompts=[default_prompt],
document_sets=document_sets,
prompt_ids=[default_prompt.id],
document_set_ids=document_set_ids,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,

View File

@ -7,6 +7,10 @@ from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.persona import get_personas
from danswer.db.persona import upsert_persona
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.settings.models import Settings
from danswer.server.settings.store import store_settings as store_base_settings
@ -28,7 +32,7 @@ class SeedConfiguration(BaseModel):
admin_user_emails: list[str] | None = None
seeded_name: str | None = None
seeded_logo_path: str | None = None
# personas: list[Persona] | None = None
personas: list[CreatePersonaRequest] | None = None
settings: Settings | None = None
@ -36,7 +40,6 @@ def _parse_env() -> SeedConfiguration | None:
seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
if not seed_config_str:
return None
seed_config = SeedConfiguration.parse_raw(seed_config_str)
return seed_config
@ -57,38 +60,37 @@ def _seed_llms(
update_default_provider(db_session, seeded_providers[0].id)
# def _seed_personas(db_session: Session, personas: list[Persona]) -> None:
# # don't seed personas if we've already done this
# existing_personas = get_personas(
# user_id=None, # Admin view
# db_session=db_session,
# include_default=True,
# include_slack_bot_personas=True,
# include_deleted=False,
# )
# if existing_personas:
# return
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
# don't seed personas if we've already done this
existing_personas = get_personas(
user_id=None, # Admin view
db_session=db_session,
include_default=True,
include_slack_bot_personas=True,
include_deleted=False,
)
if existing_personas:
return
# logger.info("Seeding Personas")
# for persona in personas:
# upsert_persona(
# user=None, # Seeding is done as admin
# name=persona.name,
# description=persona.description,
# num_chunks=persona.num_chunks if persona.num_chunks is not None else 0.0,
# llm_relevance_filter=persona.llm_relevance_filter,
# llm_filter_extraction=persona.llm_filter_extraction,
# recency_bias=persona.recency_bias,
# prompts=persona.prompts,
# document_sets=persona.document_sets,
# llm_model_provider_override=persona.llm_model_provider_override,
# llm_model_version_override=persona.llm_model_version_override,
# starter_messages=persona.starter_messages,
# is_public=persona.is_public,
# db_session=db_session,
# tool_ids=[tool.id for tool in persona.tools],
# default_persona=persona.default_persona,
# )
logger.info("Seeding Personas")
for persona in personas:
upsert_persona(
user=None, # Seeding is done as admin
name=persona.name,
description=persona.description,
num_chunks=persona.num_chunks if persona.num_chunks is not None else 0.0,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=RecencyBiasSetting.AUTO,
prompt_ids=persona.prompt_ids,
document_set_ids=persona.document_set_ids,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
is_public=persona.is_public,
db_session=db_session,
tool_ids=persona.tool_ids,
)
def _seed_settings(settings: Settings) -> None:
@ -115,8 +117,8 @@ def seed_db() -> None:
with get_session_context_manager() as db_session:
if seed_config.llms is not None:
_seed_llms(db_session, seed_config.llms)
# if seed_config.personas is not None:
# _seed_personas(db_session, seed_config.personas)
if seed_config.personas is not None:
_seed_personas(db_session, seed_config.personas)
if seed_config.settings is not None:
_seed_settings(seed_config.settings)