mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
allowed arbitrary types to handle the sqlalchemy datatype (#1758)
* Allowed arbitrary types to handle the SQLAlchemy datatype
* Changed `upsert_persona` to take in IDs instead of objects
This commit is contained in:
parent
7f1bb67e52
commit
a7da07afc0
@ -1,5 +1,3 @@
|
||||
from typing import cast
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -50,7 +48,7 @@ def load_personas_from_yaml(
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for persona in all_personas:
|
||||
doc_set_names = persona["document_sets"]
|
||||
doc_sets: list[DocumentSetDBModel] | None = [
|
||||
doc_sets: list[DocumentSetDBModel] = [
|
||||
get_or_create_document_set_by_name(db_session, name)
|
||||
for name in doc_set_names
|
||||
]
|
||||
@ -58,22 +56,24 @@ def load_personas_from_yaml(
|
||||
# Assume if user hasn't set any document sets for the persona, the user may want
|
||||
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
|
||||
# the document sets for the persona
|
||||
if not doc_sets:
|
||||
doc_sets = None
|
||||
|
||||
prompt_set_names = persona["prompts"]
|
||||
if not prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] | None = None
|
||||
doc_set_ids: list[int] | None = None
|
||||
if doc_sets:
|
||||
doc_set_ids = [doc_set.id for doc_set in doc_sets]
|
||||
else:
|
||||
prompts = [
|
||||
doc_set_ids = None
|
||||
|
||||
prompt_ids: list[int] | None = None
|
||||
prompt_set_names = persona["prompts"]
|
||||
if prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] = [
|
||||
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
|
||||
for prompt_name in prompt_set_names
|
||||
]
|
||||
if any([prompt is None for prompt in prompts]):
|
||||
raise ValueError("Invalid Persona configs, not all prompts exist")
|
||||
|
||||
if not prompts:
|
||||
prompts = None
|
||||
if prompts:
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
|
||||
p_id = persona.get("id")
|
||||
upsert_persona(
|
||||
@ -91,8 +91,8 @@ def load_personas_from_yaml(
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompts=cast(list[PromptDBModel] | None, prompts),
|
||||
document_sets=doc_sets,
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
default_persona=True,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
|
@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import Persona
|
||||
@ -62,19 +61,6 @@ def create_update_persona(
|
||||
) -> PersonaSnapshot:
|
||||
"""Higher level function than upsert_persona, although either is valid to use."""
|
||||
# Permission to actually use these is checked later
|
||||
document_sets = list(
|
||||
get_document_sets_by_ids(
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
prompts = list(
|
||||
get_prompts_by_ids(
|
||||
prompt_ids=create_persona_request.prompt_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
persona = upsert_persona(
|
||||
persona_id=persona_id,
|
||||
@ -85,9 +71,9 @@ def create_update_persona(
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
prompts=prompts,
|
||||
prompt_ids=create_persona_request.prompt_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
document_sets=document_sets,
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
@ -330,13 +316,13 @@ def upsert_persona(
|
||||
llm_relevance_filter: bool,
|
||||
llm_filter_extraction: bool,
|
||||
recency_bias: RecencyBiasSetting,
|
||||
prompts: list[Prompt] | None,
|
||||
document_sets: list[DocumentSet] | None,
|
||||
llm_model_provider_override: str | None,
|
||||
llm_model_version_override: str | None,
|
||||
starter_messages: list[StarterMessage] | None,
|
||||
is_public: bool,
|
||||
db_session: Session,
|
||||
prompt_ids: list[int] | None = None,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
@ -356,6 +342,24 @@ def upsert_persona(
|
||||
if not tools and tool_ids:
|
||||
raise ValueError("Tools not found")
|
||||
|
||||
# Fetch and attach document_sets by IDs
|
||||
document_sets = None
|
||||
if document_set_ids is not None:
|
||||
document_sets = (
|
||||
db_session.query(DocumentSet)
|
||||
.filter(DocumentSet.id.in_(document_set_ids))
|
||||
.all()
|
||||
)
|
||||
if not document_sets and document_set_ids:
|
||||
raise ValueError("document_sets not found")
|
||||
|
||||
# Fetch and attach prompts by IDs
|
||||
prompts = None
|
||||
if prompt_ids is not None:
|
||||
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
|
||||
if not prompts and prompt_ids:
|
||||
raise ValueError("prompts not found")
|
||||
|
||||
if persona:
|
||||
if not default_persona and persona.default_persona:
|
||||
raise ValueError("Cannot update default persona with non-default.")
|
||||
@ -383,10 +387,10 @@ def upsert_persona(
|
||||
|
||||
if prompts is not None:
|
||||
persona.prompts.clear()
|
||||
persona.prompts = prompts
|
||||
persona.prompts = prompts or []
|
||||
|
||||
if tools is not None:
|
||||
persona.tools = tools
|
||||
persona.tools = tools or []
|
||||
|
||||
else:
|
||||
persona = Persona(
|
||||
|
@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Persona__DocumentSet
|
||||
@ -42,12 +41,6 @@ def create_slack_bot_persona(
|
||||
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
document_sets = list(
|
||||
get_document_sets_by_ids(
|
||||
document_set_ids=document_set_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
# create/update persona associated with the slack bot
|
||||
persona_name = _build_persona_name(channel_names)
|
||||
@ -61,8 +54,8 @@ def create_slack_bot_persona(
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
prompts=[default_prompt],
|
||||
document_sets=document_sets,
|
||||
prompt_ids=[default_prompt.id],
|
||||
document_set_ids=document_set_ids,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
@ -7,6 +7,10 @@ from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.llm import update_default_provider
|
||||
from danswer.db.llm import upsert_llm_provider
|
||||
from danswer.db.persona import get_personas
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.server.features.persona.models import CreatePersonaRequest
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.server.settings.models import Settings
|
||||
from danswer.server.settings.store import store_settings as store_base_settings
|
||||
@ -28,7 +32,7 @@ class SeedConfiguration(BaseModel):
|
||||
admin_user_emails: list[str] | None = None
|
||||
seeded_name: str | None = None
|
||||
seeded_logo_path: str | None = None
|
||||
# personas: list[Persona] | None = None
|
||||
personas: list[CreatePersonaRequest] | None = None
|
||||
settings: Settings | None = None
|
||||
|
||||
|
||||
@ -36,7 +40,6 @@ def _parse_env() -> SeedConfiguration | None:
|
||||
seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
|
||||
if not seed_config_str:
|
||||
return None
|
||||
|
||||
seed_config = SeedConfiguration.parse_raw(seed_config_str)
|
||||
return seed_config
|
||||
|
||||
@ -57,38 +60,37 @@ def _seed_llms(
|
||||
update_default_provider(db_session, seeded_providers[0].id)
|
||||
|
||||
|
||||
# def _seed_personas(db_session: Session, personas: list[Persona]) -> None:
|
||||
# # don't seed personas if we've already done this
|
||||
# existing_personas = get_personas(
|
||||
# user_id=None, # Admin view
|
||||
# db_session=db_session,
|
||||
# include_default=True,
|
||||
# include_slack_bot_personas=True,
|
||||
# include_deleted=False,
|
||||
# )
|
||||
# if existing_personas:
|
||||
# return
|
||||
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
|
||||
# don't seed personas if we've already done this
|
||||
existing_personas = get_personas(
|
||||
user_id=None, # Admin view
|
||||
db_session=db_session,
|
||||
include_default=True,
|
||||
include_slack_bot_personas=True,
|
||||
include_deleted=False,
|
||||
)
|
||||
if existing_personas:
|
||||
return
|
||||
|
||||
# logger.info("Seeding Personas")
|
||||
# for persona in personas:
|
||||
# upsert_persona(
|
||||
# user=None, # Seeding is done as admin
|
||||
# name=persona.name,
|
||||
# description=persona.description,
|
||||
# num_chunks=persona.num_chunks if persona.num_chunks is not None else 0.0,
|
||||
# llm_relevance_filter=persona.llm_relevance_filter,
|
||||
# llm_filter_extraction=persona.llm_filter_extraction,
|
||||
# recency_bias=persona.recency_bias,
|
||||
# prompts=persona.prompts,
|
||||
# document_sets=persona.document_sets,
|
||||
# llm_model_provider_override=persona.llm_model_provider_override,
|
||||
# llm_model_version_override=persona.llm_model_version_override,
|
||||
# starter_messages=persona.starter_messages,
|
||||
# is_public=persona.is_public,
|
||||
# db_session=db_session,
|
||||
# tool_ids=[tool.id for tool in persona.tools],
|
||||
# default_persona=persona.default_persona,
|
||||
# )
|
||||
logger.info("Seeding Personas")
|
||||
for persona in personas:
|
||||
upsert_persona(
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks if persona.num_chunks is not None else 0.0,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
prompt_ids=persona.prompt_ids,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
starter_messages=persona.starter_messages,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
tool_ids=persona.tool_ids,
|
||||
)
|
||||
|
||||
|
||||
def _seed_settings(settings: Settings) -> None:
|
||||
@ -115,8 +117,8 @@ def seed_db() -> None:
|
||||
with get_session_context_manager() as db_session:
|
||||
if seed_config.llms is not None:
|
||||
_seed_llms(db_session, seed_config.llms)
|
||||
# if seed_config.personas is not None:
|
||||
# _seed_personas(db_session, seed_config.personas)
|
||||
if seed_config.personas is not None:
|
||||
_seed_personas(db_session, seed_config.personas)
|
||||
if seed_config.settings is not None:
|
||||
_seed_settings(seed_config.settings)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user