mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-18 00:31:36 +02:00
182 lines
6.8 KiB
Python
182 lines
6.8 KiB
Python
import yaml
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.configs.chat_configs import INPUT_PROMPT_YAML
|
|
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
|
from onyx.configs.chat_configs import PERSONAS_YAML
|
|
from onyx.configs.chat_configs import PROMPTS_YAML
|
|
from onyx.context.search.enums import RecencyBiasSetting
|
|
from onyx.db.document_set import get_or_create_document_set_by_name
|
|
from onyx.db.input_prompt import insert_input_prompt_if_not_exists
|
|
from onyx.db.models import DocumentSet as DocumentSetDBModel
|
|
from onyx.db.models import Persona
|
|
from onyx.db.models import Prompt as PromptDBModel
|
|
from onyx.db.models import Tool as ToolDBModel
|
|
from onyx.db.persona import upsert_persona
|
|
from onyx.db.prompts import get_prompt_by_name
|
|
from onyx.db.prompts import upsert_prompt
|
|
|
|
|
|
def load_prompts_from_yaml(
|
|
db_session: Session, prompts_yaml: str = PROMPTS_YAML
|
|
) -> None:
|
|
with open(prompts_yaml, "r") as file:
|
|
data = yaml.safe_load(file)
|
|
|
|
all_prompts = data.get("prompts", [])
|
|
for prompt in all_prompts:
|
|
upsert_prompt(
|
|
db_session=db_session,
|
|
user=None,
|
|
prompt_id=prompt.get("id"),
|
|
name=prompt["name"],
|
|
description=prompt["description"].strip(),
|
|
system_prompt=prompt["system"].strip(),
|
|
task_prompt=prompt["task"].strip(),
|
|
include_citations=prompt["include_citations"],
|
|
datetime_aware=prompt.get("datetime_aware", True),
|
|
default_prompt=True,
|
|
personas=None,
|
|
)
|
|
db_session.commit()
|
|
|
|
|
|
def load_input_prompts_from_yaml(
|
|
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
|
|
) -> None:
|
|
with open(input_prompts_yaml, "r") as file:
|
|
data = yaml.safe_load(file)
|
|
|
|
all_input_prompts = data.get("input_prompts", [])
|
|
for input_prompt in all_input_prompts:
|
|
# If these prompts are deleted (which is a hard delete in the DB), on server startup
|
|
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
|
|
|
|
insert_input_prompt_if_not_exists(
|
|
user=None,
|
|
input_prompt_id=input_prompt.get("id"),
|
|
prompt=input_prompt["prompt"],
|
|
content=input_prompt["content"],
|
|
is_public=input_prompt["is_public"],
|
|
active=input_prompt.get("active", True),
|
|
db_session=db_session,
|
|
commit=True,
|
|
)
|
|
|
|
|
|
def load_personas_from_yaml(
|
|
db_session: Session,
|
|
personas_yaml: str = PERSONAS_YAML,
|
|
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
|
) -> None:
|
|
with open(personas_yaml, "r") as file:
|
|
data = yaml.safe_load(file)
|
|
|
|
all_personas = data.get("personas", [])
|
|
|
|
for persona in all_personas:
|
|
doc_set_names = persona["document_sets"]
|
|
doc_sets: list[DocumentSetDBModel] = [
|
|
get_or_create_document_set_by_name(db_session, name)
|
|
for name in doc_set_names
|
|
]
|
|
|
|
# Assume if user hasn't set any document sets for the persona, the user may want
|
|
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
|
|
# the document sets for the persona
|
|
doc_set_ids: list[int] | None = None
|
|
if doc_sets:
|
|
doc_set_ids = [doc_set.id for doc_set in doc_sets]
|
|
else:
|
|
doc_set_ids = None
|
|
|
|
prompt_ids: list[int] | None = None
|
|
prompt_set_names = persona["prompts"]
|
|
if prompt_set_names:
|
|
prompts: list[PromptDBModel | None] = [
|
|
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
|
|
for prompt_name in prompt_set_names
|
|
]
|
|
if any([prompt is None for prompt in prompts]):
|
|
raise ValueError("Invalid Persona configs, not all prompts exist")
|
|
|
|
if prompts:
|
|
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
|
|
|
if not prompt_ids:
|
|
raise ValueError("Invalid Persona config, no prompts exist")
|
|
|
|
p_id = persona.get("id")
|
|
tool_ids = []
|
|
|
|
if persona.get("image_generation"):
|
|
image_gen_tool = (
|
|
db_session.query(ToolDBModel)
|
|
.filter(ToolDBModel.name == "ImageGenerationTool")
|
|
.first()
|
|
)
|
|
if image_gen_tool:
|
|
tool_ids.append(image_gen_tool.id)
|
|
|
|
llm_model_provider_override = persona.get("llm_model_provider_override")
|
|
llm_model_version_override = persona.get("llm_model_version_override")
|
|
|
|
# Set specific overrides for image generation persona
|
|
if persona.get("image_generation"):
|
|
llm_model_version_override = "gpt-4o"
|
|
|
|
existing_persona = (
|
|
db_session.query(Persona).filter(Persona.name == persona["name"]).first()
|
|
)
|
|
|
|
upsert_persona(
|
|
user=None,
|
|
persona_id=(-1 * p_id) if p_id is not None else None,
|
|
name=persona["name"],
|
|
description=persona["description"],
|
|
num_chunks=persona.get("num_chunks")
|
|
if persona.get("num_chunks") is not None
|
|
else default_chunks,
|
|
llm_relevance_filter=persona.get("llm_relevance_filter"),
|
|
starter_messages=persona.get("starter_messages", []),
|
|
llm_filter_extraction=persona.get("llm_filter_extraction"),
|
|
icon_shape=persona.get("icon_shape"),
|
|
icon_color=persona.get("icon_color"),
|
|
llm_model_provider_override=llm_model_provider_override,
|
|
llm_model_version_override=llm_model_version_override,
|
|
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
|
prompt_ids=prompt_ids,
|
|
document_set_ids=doc_set_ids,
|
|
tool_ids=tool_ids,
|
|
builtin_persona=True,
|
|
is_public=True,
|
|
display_priority=(
|
|
existing_persona.display_priority
|
|
if existing_persona is not None
|
|
and persona.get("display_priority") is None
|
|
else persona.get("display_priority")
|
|
),
|
|
is_visible=(
|
|
existing_persona.is_visible
|
|
if existing_persona is not None
|
|
else persona.get("is_visible")
|
|
),
|
|
db_session=db_session,
|
|
is_default_persona=(
|
|
existing_persona.is_default_persona
|
|
if existing_persona is not None
|
|
else persona.get("is_default_persona", False)
|
|
),
|
|
)
|
|
|
|
|
|
def load_chat_yamls(
|
|
db_session: Session,
|
|
prompt_yaml: str = PROMPTS_YAML,
|
|
personas_yaml: str = PERSONAS_YAML,
|
|
input_prompts_yaml: str = INPUT_PROMPT_YAML,
|
|
) -> None:
|
|
load_prompts_from_yaml(db_session, prompt_yaml)
|
|
load_personas_from_yaml(db_session, personas_yaml)
|
|
load_input_prompts_from_yaml(db_session, input_prompts_yaml)
|