Persona / prompt hardening (#3375)

* Persona / prompt hardening

* fix it
This commit is contained in:
Chris Weaver
2024-12-08 19:39:59 -08:00
committed by GitHub
parent 4a7bd5578e
commit 970320bd49
4 changed files with 74 additions and 46 deletions

View File

@ -453,9 +453,9 @@ def upsert_persona(
""" """
if persona_id is not None: if persona_id is not None:
persona = db_session.query(Persona).filter_by(id=persona_id).first() existing_persona = db_session.query(Persona).filter_by(id=persona_id).first()
else: else:
persona = _get_persona_by_name( existing_persona = _get_persona_by_name(
persona_name=name, user=user, db_session=db_session persona_name=name, user=user, db_session=db_session
) )
@ -481,62 +481,78 @@ def upsert_persona(
prompts = None prompts = None
if prompt_ids is not None: if prompt_ids is not None:
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all() prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
if not prompts and prompt_ids:
raise ValueError("prompts not found") if prompts is not None and len(prompts) == 0:
raise ValueError(
f"Invalid Persona config, no valid prompts "
f"specified. Specified IDs were: '{prompt_ids}'"
)
# ensure all specified tools are valid # ensure all specified tools are valid
if tools: if tools:
validate_persona_tools(tools) validate_persona_tools(tools)
if persona: if existing_persona:
# Built-in personas can only be updated through YAML configuration. # Built-in personas can only be updated through YAML configuration.
# This ensures that core system personas are not modified unintentionally. # This ensures that core system personas are not modified unintentionally.
if persona.builtin_persona and not builtin_persona: if existing_persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.") raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona # this checks if the user has permission to edit the persona
persona = fetch_persona_by_id( # will raise an Exception if the user does not have permission
db_session=db_session, persona_id=persona.id, user=user, get_editable=True existing_persona = fetch_persona_by_id(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
) )
# The following update excludes `default`, `built-in`, and display priority. # The following update excludes `default`, `built-in`, and display priority.
# Display priority is handled separately in the `display-priority` endpoint. # Display priority is handled separately in the `display-priority` endpoint.
# `default` and `built-in` properties can only be set when creating a persona. # `default` and `built-in` properties can only be set when creating a persona.
persona.name = name existing_persona.name = name
persona.description = description existing_persona.description = description
persona.num_chunks = num_chunks existing_persona.num_chunks = num_chunks
persona.chunks_above = chunks_above existing_persona.chunks_above = chunks_above
persona.chunks_below = chunks_below existing_persona.chunks_below = chunks_below
persona.llm_relevance_filter = llm_relevance_filter existing_persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction existing_persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias existing_persona.recency_bias = recency_bias
persona.llm_model_provider_override = llm_model_provider_override existing_persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override existing_persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages existing_persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted existing_persona.deleted = False # Un-delete if previously deleted
persona.is_public = is_public existing_persona.is_public = is_public
persona.icon_color = icon_color existing_persona.icon_color = icon_color
persona.icon_shape = icon_shape existing_persona.icon_shape = icon_shape
if remove_image or uploaded_image_id: if remove_image or uploaded_image_id:
persona.uploaded_image_id = uploaded_image_id existing_persona.uploaded_image_id = uploaded_image_id
persona.is_visible = is_visible existing_persona.is_visible = is_visible
persona.search_start_date = search_start_date existing_persona.search_start_date = search_start_date
persona.category_id = category_id existing_persona.category_id = category_id
# Do not delete any associations manually added unless # Do not delete any associations manually added unless
# a new updated list is provided # a new updated list is provided
if document_sets is not None: if document_sets is not None:
persona.document_sets.clear() existing_persona.document_sets.clear()
persona.document_sets = document_sets or [] existing_persona.document_sets = document_sets or []
if prompts is not None: if prompts is not None:
persona.prompts.clear() existing_persona.prompts.clear()
persona.prompts = prompts or [] existing_persona.prompts = prompts
if tools is not None: if tools is not None:
persona.tools = tools or [] existing_persona.tools = tools or []
persona = existing_persona
else: else:
persona = Persona( if not prompts:
raise ValueError(
"Invalid Persona config. "
"Must specify at least one prompt for a new persona."
)
new_persona = Persona(
id=persona_id, id=persona_id,
user_id=user.id if user else None, user_id=user.id if user else None,
is_public=is_public, is_public=is_public,
@ -549,7 +565,7 @@ def upsert_persona(
llm_filter_extraction=llm_filter_extraction, llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias, recency_bias=recency_bias,
builtin_persona=builtin_persona, builtin_persona=builtin_persona,
prompts=prompts or [], prompts=prompts,
document_sets=document_sets or [], document_sets=document_sets or [],
llm_model_provider_override=llm_model_provider_override, llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override, llm_model_version_override=llm_model_version_override,
@ -564,8 +580,8 @@ def upsert_persona(
is_default_persona=is_default_persona, is_default_persona=is_default_persona,
category_id=category_id, category_id=category_id,
) )
db_session.add(persona) db_session.add(new_persona)
persona = new_persona
if commit: if commit:
db_session.commit() db_session.commit()
else: else:

View File

@ -79,6 +79,9 @@ def load_personas_from_yaml(
if prompts: if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None] prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
if not prompt_ids:
raise ValueError("Invalid Persona config, no prompts exist")
p_id = persona.get("id") p_id = persona.get("id")
tool_ids = [] tool_ids = []
@ -123,12 +126,16 @@ def load_personas_from_yaml(
tool_ids=tool_ids, tool_ids=tool_ids,
builtin_persona=True, builtin_persona=True,
is_public=True, is_public=True,
display_priority=existing_persona.display_priority display_priority=(
existing_persona.display_priority
if existing_persona is not None if existing_persona is not None
else persona.get("display_priority"), else persona.get("display_priority")
is_visible=existing_persona.is_visible ),
is_visible=(
existing_persona.is_visible
if existing_persona is not None if existing_persona is not None
else persona.get("is_visible"), else persona.get("is_visible")
),
db_session=db_session, db_session=db_session,
) )

View File

@ -132,13 +132,18 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) ->
if personas: if personas:
logger.notice("Seeding Personas") logger.notice("Seeding Personas")
for persona in personas: for persona in personas:
if not persona.prompt_ids:
raise ValueError(
f"Invalid Persona with name {persona.name}; no prompts exist"
)
upsert_persona( upsert_persona(
user=None, # Seeding is done as admin user=None, # Seeding is done as admin
name=persona.name, name=persona.name,
description=persona.description, description=persona.description,
num_chunks=persona.num_chunks num_chunks=(
if persona.num_chunks is not None persona.num_chunks if persona.num_chunks is not None else 0.0
else 0.0, ),
llm_relevance_filter=persona.llm_relevance_filter, llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction, llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=RecencyBiasSetting.AUTO, recency_bias=RecencyBiasSetting.AUTO,

View File

@ -42,7 +42,7 @@ class PersonaManager:
"is_public": is_public, "is_public": is_public,
"llm_filter_extraction": llm_filter_extraction, "llm_filter_extraction": llm_filter_extraction,
"recency_bias": recency_bias, "recency_bias": recency_bias,
"prompt_ids": prompt_ids or [], "prompt_ids": prompt_ids or [0],
"document_set_ids": document_set_ids or [], "document_set_ids": document_set_ids or [],
"tool_ids": tool_ids or [], "tool_ids": tool_ids or [],
"llm_model_provider_override": llm_model_provider_override, "llm_model_provider_override": llm_model_provider_override,