Combined Persona and Prompt API (#3690)

* Combined Persona and Prompt API

* quality

* added tests

* consolidated models and got rid of redundant fields

* tenant appreciation day

* reverted default
hagen-danswer
2025-01-17 12:21:20 -08:00
committed by GitHub
parent 880c42ad41
commit 1ad2128b2a
22 changed files with 626 additions and 761 deletions

View File

@@ -24,7 +24,7 @@ from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
from onyx.db.persona import upsert_persona
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.settings.models import Settings
from onyx.server.settings.store import store_settings as store_base_settings
@@ -57,7 +57,7 @@ class SeedConfiguration(BaseModel):
llms: list[LLMProviderUpsertRequest] | None = None
admin_user_emails: list[str] | None = None
seeded_logo_path: str | None = None
personas: list[CreatePersonaRequest] | None = None
personas: list[PersonaUpsertRequest] | None = None
settings: Settings | None = None
enterprise_settings: EnterpriseSettings | None = None
@@ -128,7 +128,7 @@ def _seed_llms(
)
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
if personas:
logger.notice("Seeding Personas")
for persona in personas:

View File

@@ -25,7 +25,7 @@ from onyx.db.models import Persona
from onyx.db.models import Prompt
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.persona import get_prompts_by_ids
from onyx.db.prompts import get_prompts_by_ids
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest

View File

@@ -1,12 +1,13 @@
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from onyx.context.search.models import InferenceChunk
from onyx.db.models import Persona
from onyx.db.persona import get_default_prompt__read_only
from onyx.db.prompts import get_default_prompt
from onyx.db.search_settings import get_multilingual_expansion
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
@@ -97,11 +98,12 @@ def compute_max_document_tokens(
def compute_max_document_tokens_for_persona(
db_session: Session,
persona: Persona,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session)
return compute_max_document_tokens(
prompt_config=PromptConfig.from_model(prompt),
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
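
With get_default_prompt__read_only gone, the fallback prompt is now loaded through the caller's session, so every call site has to hand in a Session. A minimal sketch of the updated call, assuming an open SQLAlchemy Session and a loaded Persona (the import path is not shown in the diff and is assumed):

from sqlalchemy.orm import Session

# import path assumed; the helper lives in the chat prompt-builder module
from onyx.chat.prompt_builder.citations_prompt import (
    compute_max_document_tokens_for_persona,
)
from onyx.db.models import Persona


def max_doc_tokens(db_session: Session, persona: Persona) -> int:
    # The Session is now required so the helper can fall back to
    # get_default_prompt(db_session) when the persona has no prompts attached.
    return compute_max_document_tokens_for_persona(
        db_session=db_session,
        persona=persona,
    )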

View File

@@ -7,26 +7,6 @@ from onyx.db.models import ChatMessage
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
def translate_onyx_msg_to_langchain(

View File

@@ -1,6 +1,5 @@
from collections.abc import Sequence
from datetime import datetime
from functools import lru_cache
from uuid import UUID
from fastapi import HTTPException
@@ -8,7 +7,6 @@ from sqlalchemy import delete
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
@@ -23,7 +21,6 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import DocumentSet
from onyx.db.models import Persona
from onyx.db.models import Persona__User
@@ -35,8 +32,8 @@ from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -107,9 +104,6 @@ def _add_user_filters(
return stmt.where(where_clause)
# fetch_persona_by_id is used to fetch a persona by its ID.
def fetch_persona_by_id_for_user(
db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
) -> Persona:
@@ -184,7 +178,7 @@ def make_persona_private(
def create_update_persona(
persona_id: int | None,
create_persona_request: CreatePersonaRequest,
create_persona_request: PersonaUpsertRequest,
user: User | None,
db_session: Session,
) -> PersonaSnapshot:
@@ -192,14 +186,36 @@ def create_update_persona(
# Permission to actually use these is checked later
try:
persona_data = {
"persona_id": persona_id,
"user": user,
"db_session": db_session,
**create_persona_request.model_dump(exclude={"users", "groups"}),
}
all_prompt_ids = create_persona_request.prompt_ids
persona = upsert_persona(**persona_data)
if not all_prompt_ids:
raise ValueError("No prompt IDs provided")
persona = upsert_persona(
persona_id=persona_id,
user=user,
db_session=db_session,
description=create_persona_request.description,
name=create_persona_request.name,
prompt_ids=all_prompt_ids,
document_set_ids=create_persona_request.document_set_ids,
tool_ids=create_persona_request.tool_ids,
is_public=create_persona_request.is_public,
recency_bias=create_persona_request.recency_bias,
llm_model_provider_override=create_persona_request.llm_model_provider_override,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
icon_color=create_persona_request.icon_color,
icon_shape=create_persona_request.icon_shape,
uploaded_image_id=create_persona_request.uploaded_image_id,
display_priority=create_persona_request.display_priority,
remove_image=create_persona_request.remove_image,
search_start_date=create_persona_request.search_start_date,
label_ids=create_persona_request.label_ids,
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
)
versioned_make_persona_private = fetch_versioned_implementation(
"onyx.db.persona", "make_persona_private"
@@ -265,24 +281,6 @@ def update_persona_public_status(
db_session.commit()
def get_prompts(
user_id: UUID | None,
db_session: Session,
include_default: bool = True,
include_deleted: bool = False,
) -> Sequence[Prompt]:
stmt = select(Prompt).where(
or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))
)
if not include_default:
stmt = stmt.where(Prompt.default_prompt.is_(False))
if not include_deleted:
stmt = stmt.where(Prompt.deleted.is_(False))
return db_session.scalars(stmt).all()
def get_personas_for_user(
# if user is `None` assume the user is an admin or auth is disabled
user: User | None,
@@ -374,65 +372,6 @@ def update_all_personas_display_priority(
db_session.commit()
def upsert_prompt(
user: User | None,
name: str,
description: str,
system_prompt: str,
task_prompt: str,
include_citations: bool,
datetime_aware: bool,
personas: list[Persona] | None,
db_session: Session,
prompt_id: int | None = None,
default_prompt: bool = True,
commit: bool = True,
) -> Prompt:
if prompt_id is not None:
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
else:
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
if prompt:
if not default_prompt and prompt.default_prompt:
raise ValueError("Cannot update default prompt with non-default.")
prompt.name = name
prompt.description = description
prompt.system_prompt = system_prompt
prompt.task_prompt = task_prompt
prompt.include_citations = include_citations
prompt.datetime_aware = datetime_aware
prompt.default_prompt = default_prompt
if personas is not None:
prompt.personas.clear()
prompt.personas = personas
else:
prompt = Prompt(
id=prompt_id,
user_id=user.id if user else None,
name=name,
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
datetime_aware=datetime_aware,
default_prompt=default_prompt,
personas=personas or [],
)
db_session.add(prompt)
if commit:
db_session.commit()
else:
# Flush the session so that the Prompt has an ID
db_session.flush()
return prompt
def upsert_persona(
user: User | None,
name: str,
@@ -477,6 +416,15 @@ def upsert_persona(
persona_name=name, user=user, db_session=db_session
)
if existing_persona:
# this checks if the user has permission to edit the persona
# will raise an Exception if the user does not have permission
existing_persona = fetch_persona_by_id_for_user(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
)
# Fetch and attach tools by IDs
tools = None
if tool_ids is not None:
@@ -522,15 +470,6 @@ def upsert_persona(
if existing_persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona
# will raise an Exception if the user does not have permission
existing_persona = fetch_persona_by_id_for_user(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
)
# The following update excludes `default`, `built-in`, and display priority.
# Display priority is handled separately in the `display-priority` endpoint.
# `default` and `built-in` properties can only be set when creating a persona.
@@ -619,16 +558,6 @@ def upsert_persona(
return persona
def mark_prompt_as_deleted(
prompt_id: int,
user: User | None,
db_session: Session,
) -> None:
prompt = get_prompt_by_id(prompt_id=prompt_id, user=user, db_session=db_session)
prompt.deleted = True
db_session.commit()
def delete_old_default_personas(
db_session: Session,
) -> None:
@@ -666,69 +595,6 @@ def validate_persona_tools(tools: list[Tool]) -> None:
)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return list(prompts)
def get_prompt_by_id(
prompt_id: int,
user: User | None,
db_session: Session,
include_deleted: bool = False,
) -> Prompt:
stmt = select(Prompt).where(Prompt.id == prompt_id)
# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(or_(Prompt.user_id == user.id, Prompt.user_id.is_(None)))
if not include_deleted:
stmt = stmt.where(Prompt.deleted.is_(False))
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise ValueError(
f"Prompt with ID {prompt_id} does not exist or does not belong to user"
)
return prompt
def _get_default_prompt(db_session: Session) -> Prompt:
stmt = select(Prompt).where(Prompt.id == 0)
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise RuntimeError("Default Prompt not found")
return prompt
def get_default_prompt(db_session: Session) -> Prompt:
return _get_default_prompt(db_session)
@lru_cache()
def get_default_prompt__read_only() -> Prompt:
"""Due to the way lru_cache / SQLAlchemy works, this can cause issues
when trying to attach the returned `Prompt` object to a `Persona`. If you are
doing anything other than reading, you should use the `get_default_prompt`
method instead."""
with Session(get_sqlalchemy_engine()) as db_session:
return _get_default_prompt(db_session)
# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
# a direct mapping indicating whether a user has access to a specific persona?
def get_persona_by_id(
@@ -800,22 +666,6 @@ def get_personas_by_ids(
return personas
def get_prompt_by_name(
prompt_name: str, user: User | None, db_session: Session
) -> Prompt | None:
stmt = select(Prompt).where(Prompt.name == prompt_name)
# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def delete_persona_by_name(
persona_name: str, db_session: Session, is_default: bool = True
) -> None:
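
create_update_persona now spells out every field of the request instead of splatting a model_dump, and it refuses to proceed without prompt IDs. A hedged sketch of calling it from a service layer, assuming an open Session (the wrapper function is illustrative):

from sqlalchemy.orm import Session

from onyx.db.models import User
from onyx.db.persona import create_update_persona
from onyx.server.features.persona.models import PersonaSnapshot, PersonaUpsertRequest


def save_persona(
    db_session: Session,
    user: User | None,
    request: PersonaUpsertRequest,
    persona_id: int | None = None,
) -> PersonaSnapshot:
    # Raises ValueError("No prompt IDs provided") when request.prompt_ids is
    # empty, so the API layer upserts the backing Prompt first and fills
    # prompt_ids in before calling this.
    return create_update_persona(
        persona_id=persona_id,  # None creates a new persona, an int updates one
        create_persona_request=request,
        user=user,
        db_session=db_session,
    )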

backend/onyx/db/prompts.py (new file, 119 lines)
View File

@@ -0,0 +1,119 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.models import Persona
from onyx.db.models import Prompt
from onyx.db.models import User
from onyx.utils.logger import setup_logger
# Note: As prompts are fairly innocuous/harmless, there are no protections
# to prevent users from messing with prompts of other users.
logger = setup_logger()
def _get_default_prompt(db_session: Session) -> Prompt:
stmt = select(Prompt).where(Prompt.id == 0)
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise RuntimeError("Default Prompt not found")
return prompt
def get_default_prompt(db_session: Session) -> Prompt:
return _get_default_prompt(db_session)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return list(prompts)
def get_prompt_by_name(
prompt_name: str, user: User | None, db_session: Session
) -> Prompt | None:
stmt = select(Prompt).where(Prompt.name == prompt_name)
# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def build_prompt_name_from_persona_name(persona_name: str) -> str:
return f"default-prompt__{persona_name}"
def upsert_prompt(
db_session: Session,
user: User | None,
name: str,
system_prompt: str,
task_prompt: str,
datetime_aware: bool,
prompt_id: int | None = None,
personas: list[Persona] | None = None,
include_citations: bool = False,
default_prompt: bool = True,
# Support backwards compatibility
description: str | None = None,
) -> Prompt:
if description is None:
description = f"Default prompt for {name}"
if prompt_id is not None:
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
else:
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
if prompt:
if not default_prompt and prompt.default_prompt:
raise ValueError("Cannot update default prompt with non-default.")
prompt.name = name
prompt.description = description
prompt.system_prompt = system_prompt
prompt.task_prompt = task_prompt
prompt.include_citations = include_citations
prompt.datetime_aware = datetime_aware
prompt.default_prompt = default_prompt
if personas is not None:
prompt.personas.clear()
prompt.personas = personas
else:
prompt = Prompt(
id=prompt_id,
user_id=user.id if user else None,
name=name,
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
datetime_aware=datetime_aware,
default_prompt=default_prompt,
personas=personas or [],
)
db_session.add(prompt)
# Flush the session so that the Prompt has an ID
db_session.flush()
return prompt
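
The relocated upsert_prompt only flushes so the Prompt gets an ID; committing is left to the caller, which is why the YAML loader below now commits once after its loop. A rough usage sketch, with the prompt name and text as placeholders:

from sqlalchemy.orm import Session

from onyx.db.prompts import get_default_prompt, upsert_prompt


def seed_example_prompt(db_session: Session) -> None:
    prompt = upsert_prompt(
        db_session=db_session,
        user=None,                            # no owner: a shared/default prompt
        name="example-default-prompt",        # placeholder name
        system_prompt="You are a helpful assistant.",
        task_prompt="Answer the user's question.",
        datetime_aware=True,
        default_prompt=True,
        # description is optional; it defaults to f"Default prompt for {name}"
    )
    assert prompt.id is not None              # available because of the flush
    db_session.commit()                       # the caller owns the commit

    # Prompt with ID 0 is the global default; if it is missing, RuntimeError is raised.
    default = get_default_prompt(db_session)
    print(default.name)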

View File

@@ -12,9 +12,9 @@ from onyx.db.models import Persona
from onyx.db.models import Persona__DocumentSet
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
from onyx.db.persona import get_default_prompt
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.prompts import get_default_prompt
from onyx.utils.errors import EERequiredError
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,

View File

@@ -64,7 +64,6 @@ from onyx.server.features.input_prompt.api import (
from onyx.server.features.notifications.api import router as notification_router
from onyx.server.features.persona.api import admin_router as admin_persona_router
from onyx.server.features.persona.api import basic_router as persona_router
from onyx.server.features.prompt.api import basic_router as prompt_router
from onyx.server.features.tool.api import admin_router as admin_tool_router
from onyx.server.features.tool.api import router as tool_router
from onyx.server.gpts.api import router as gpts_router
@@ -296,7 +295,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)
include_router_with_global_prefix_prepended(application, state_router)

View File

@@ -118,32 +118,6 @@ You should always get right to the point, and never use extraneous language.
"""
# This is only for visualization for the users to specify their own prompts
# The actual flow does not work like this
PARAMATERIZED_PROMPT = f"""
{{system_prompt}}
CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}
{{task_prompt}}
{QUESTION_PAT.upper()} {{user_query}}
RESPONSE:
""".strip()
PARAMATERIZED_PROMPT_WITHOUT_CONTEXT = f"""
{{system_prompt}}
{{task_prompt}}
{QUESTION_PAT.upper()} {{user_query}}
RESPONSE:
""".strip()
# CURRENTLY DISABLED, CANNOT USE THIS ONE
# Default chain-of-thought style json prompt which uses multiple docs
# This one has a section for the LLM to output some non-answer "thoughts"

View File

@@ -12,9 +12,9 @@ from onyx.db.models import DocumentSet as DocumentSetDBModel
from onyx.db.models import Persona
from onyx.db.models import Prompt as PromptDBModel
from onyx.db.models import Tool as ToolDBModel
from onyx.db.persona import get_prompt_by_name
from onyx.db.persona import upsert_persona
from onyx.db.persona import upsert_prompt
from onyx.db.prompts import get_prompt_by_name
from onyx.db.prompts import upsert_prompt
def load_prompts_from_yaml(
@@ -26,6 +26,7 @@ def load_prompts_from_yaml(
all_prompts = data.get("prompts", [])
for prompt in all_prompts:
upsert_prompt(
db_session=db_session,
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
@@ -36,9 +37,8 @@ def load_prompts_from_yaml(
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
db_session.commit()
def load_input_prompts_from_yaml(

View File

@@ -14,7 +14,6 @@ from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.chat.prompt_builder.utils import build_dummy_prompt
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NotificationType
@@ -36,19 +35,21 @@ from onyx.db.persona import update_persona_label
from onyx.db.persona import update_persona_public_status
from onyx.db.persona import update_persona_shared_users
from onyx.db.persona import update_persona_visibility
from onyx.db.prompts import build_prompt_name_from_persona_name
from onyx.db.prompts import upsert_prompt
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.secondary_llm_flows.starter_message_creation import (
generate_starter_messages,
)
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import PersonaLabelCreate
from onyx.server.features.persona.models import PersonaLabelResponse
from onyx.server.features.persona.models import PersonaSharedNotificationData
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PromptTemplateResponse
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.features.persona.models import PromptSnapshot
from onyx.server.models import DisplayPriorityRequest
from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
@@ -173,18 +174,37 @@ def upload_file(
@basic_router.post("")
def create_persona(
create_persona_request: CreatePersonaRequest,
persona_upsert_request: PersonaUpsertRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> PersonaSnapshot:
prompt_id = (
persona_upsert_request.prompt_ids[0]
if persona_upsert_request.prompt_ids
and len(persona_upsert_request.prompt_ids) > 0
else None
)
prompt = upsert_prompt(
db_session=db_session,
user=user,
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
system_prompt=persona_upsert_request.system_prompt,
task_prompt=persona_upsert_request.task_prompt,
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
datetime_aware=False,
include_citations=persona_upsert_request.include_citations,
prompt_id=prompt_id,
)
prompt_snapshot = PromptSnapshot.from_model(prompt)
persona_upsert_request.prompt_ids = [prompt.id]
persona_snapshot = create_update_persona(
persona_id=None,
create_persona_request=create_persona_request,
create_persona_request=persona_upsert_request,
user=user,
db_session=db_session,
)
persona_snapshot.prompts = [prompt_snapshot]
create_milestone_and_report(
user=user,
distinct_id=tenant_id or "N/A",
@@ -202,16 +222,37 @@ def create_persona(
@basic_router.patch("/{persona_id}")
def update_persona(
persona_id: int,
update_persona_request: CreatePersonaRequest,
persona_upsert_request: PersonaUpsertRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
return create_update_persona(
prompt_id = (
persona_upsert_request.prompt_ids[0]
if persona_upsert_request.prompt_ids
and len(persona_upsert_request.prompt_ids) > 0
else None
)
prompt = upsert_prompt(
db_session=db_session,
user=user,
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
datetime_aware=False,
system_prompt=persona_upsert_request.system_prompt,
task_prompt=persona_upsert_request.task_prompt,
include_citations=persona_upsert_request.include_citations,
prompt_id=prompt_id,
)
prompt_snapshot = PromptSnapshot.from_model(prompt)
persona_upsert_request.prompt_ids = [prompt.id]
persona_snapshot = create_update_persona(
persona_id=persona_id,
create_persona_request=update_persona_request,
create_persona_request=persona_upsert_request,
user=user,
db_session=db_session,
)
persona_snapshot.prompts = [prompt_snapshot]
return persona_snapshot
class PersonaLabelPatchRequest(BaseModel):
@@ -365,22 +406,6 @@ def get_persona(
)
@basic_router.get("/utils/prompt-explorer")
def build_final_template_prompt(
system_prompt: str,
task_prompt: str,
retrieval_disabled: bool = False,
_: User | None = Depends(current_user),
) -> PromptTemplateResponse:
return PromptTemplateResponse(
final_prompt_template=build_dummy_prompt(
system_prompt=system_prompt,
task_prompt=task_prompt,
retrieval_disabled=retrieval_disabled,
)
)
@basic_router.post("/assistant-prompt-refresh")
def build_assistant_prompts(
generate_persona_prompt_request: GenerateStarterMessageRequest,
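
With the standalone /prompt router gone, clients send the prompt text inline and the persona endpoints create or update the backing Prompt themselves (named default-prompt__<persona name>). A rough client-side sketch; the base URL, missing auth headers, and the recency_bias literal are assumptions:

import requests

API_SERVER_URL = "http://localhost:8080"  # placeholder

payload = {
    "name": "Support Assistant",
    "description": "Answers support questions",
    "system_prompt": "You are a support assistant.",       # was a separate /prompt call
    "task_prompt": "Answer using the provided documents.",
    "include_citations": True,
    "num_chunks": 10,
    "llm_relevance_filter": True,
    "llm_filter_extraction": False,
    "is_public": True,
    "recency_bias": "auto",        # a RecencyBiasSetting value; exact literal assumed
    "prompt_ids": [],              # server replaces this with the upserted prompt's ID
    "document_set_ids": [],
    "tool_ids": [],
    "users": [],
    "groups": [],
}

resp = requests.post(f"{API_SERVER_URL}/persona", json=payload)
resp.raise_for_status()
snapshot = resp.json()             # PersonaSnapshot, with `prompts` patched in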

View File

@@ -7,9 +7,9 @@ from pydantic import Field
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.models import Persona
from onyx.db.models import PersonaLabel
from onyx.db.models import Prompt
from onyx.db.models import StarterMessage
from onyx.server.features.document_set.models import DocumentSet
from onyx.server.features.prompt.models import PromptSnapshot
from onyx.server.features.tool.models import ToolSnapshot
from onyx.server.models import MinimalUserSnapshot
from onyx.utils.logger import setup_logger
@@ -18,6 +18,34 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class PromptSnapshot(BaseModel):
id: int
name: str
description: str
system_prompt: str
task_prompt: str
include_citations: bool
datetime_aware: bool
default_prompt: bool
# Not including persona info, not needed
@classmethod
def from_model(cls, prompt: Prompt) -> "PromptSnapshot":
if prompt.deleted:
raise ValueError("Prompt has been deleted")
return PromptSnapshot(
id=prompt.id,
name=prompt.name,
description=prompt.description,
system_prompt=prompt.system_prompt,
task_prompt=prompt.task_prompt,
include_citations=prompt.include_citations,
datetime_aware=prompt.datetime_aware,
default_prompt=prompt.default_prompt,
)
# More minimal request for generating a persona prompt
class GenerateStarterMessageRequest(BaseModel):
name: str
@@ -27,32 +55,35 @@ class GenerateStarterMessageRequest(BaseModel):
generation_count: int
class CreatePersonaRequest(BaseModel):
class PersonaUpsertRequest(BaseModel):
name: str
description: str
system_prompt: str
task_prompt: str
document_set_ids: list[int]
num_chunks: float
llm_relevance_filter: bool
include_citations: bool
is_public: bool
llm_filter_extraction: bool
recency_bias: RecencyBiasSetting
prompt_ids: list[int]
document_set_ids: list[int]
# e.g. ID of SearchTool or ImageGenerationTool or <USER_DEFINED_TOOL>
tool_ids: list[int]
llm_filter_extraction: bool
llm_relevance_filter: bool
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
starter_messages: list[StarterMessage] | None = None
# For Private Personas, who should be able to access these
users: list[UUID] = Field(default_factory=list)
groups: list[int] = Field(default_factory=list)
# e.g. ID of SearchTool or ImageGenerationTool or <USER_DEFINED_TOOL>
tool_ids: list[int]
icon_color: str | None = None
icon_shape: int | None = None
uploaded_image_id: str | None = None # New field for uploaded image
remove_image: bool | None = None
is_default_persona: bool = False
display_priority: int | None = None
uploaded_image_id: str | None = None # New field for uploaded image
search_start_date: datetime | None = None
label_ids: list[int] | None = None
is_default_persona: bool = False
display_priority: int | None = None
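
PersonaUpsertRequest now carries the prompt text alongside the persona fields. A minimal construction sketch; field values are illustrative and the RecencyBiasSetting member name is assumed:

from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import PersonaUpsertRequest

request = PersonaUpsertRequest(
    name="example-persona",
    description="Illustrative persona",
    system_prompt="You are a helpful assistant.",   # lives on the persona request now
    task_prompt="Answer concisely.",
    include_citations=True,
    num_chunks=10,
    llm_relevance_filter=True,
    llm_filter_extraction=False,
    is_public=True,
    recency_bias=RecencyBiasSetting.AUTO,           # member name assumed
    prompt_ids=[],                                  # filled in by the API layer
    document_set_ids=[],
    tool_ids=[],
)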
class PersonaSnapshot(BaseModel):

View File

@@ -1,152 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from starlette import status
from onyx.auth.users import current_user
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.persona import get_personas_by_ids
from onyx.db.persona import get_prompt_by_id
from onyx.db.persona import get_prompts
from onyx.db.persona import mark_prompt_as_deleted
from onyx.db.persona import upsert_prompt
from onyx.server.features.prompt.models import CreatePromptRequest
from onyx.server.features.prompt.models import PromptSnapshot
from onyx.utils.logger import setup_logger
# Note: As prompts are fairly innocuous/harmless, there are no protections
# to prevent users from messing with prompts of other users.
logger = setup_logger()
basic_router = APIRouter(prefix="/prompt")
def create_update_prompt(
prompt_id: int | None,
create_prompt_request: CreatePromptRequest,
user: User | None,
db_session: Session,
) -> PromptSnapshot:
personas = (
list(
get_personas_by_ids(
persona_ids=create_prompt_request.persona_ids,
db_session=db_session,
)
)
if create_prompt_request.persona_ids
else []
)
prompt = upsert_prompt(
prompt_id=prompt_id,
user=user,
name=create_prompt_request.name,
description=create_prompt_request.description,
system_prompt=create_prompt_request.system_prompt,
task_prompt=create_prompt_request.task_prompt,
include_citations=create_prompt_request.include_citations,
datetime_aware=create_prompt_request.datetime_aware,
personas=personas,
db_session=db_session,
)
return PromptSnapshot.from_model(prompt)
@basic_router.post("")
def create_prompt(
create_prompt_request: CreatePromptRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> PromptSnapshot:
try:
return create_update_prompt(
prompt_id=None,
create_prompt_request=create_prompt_request,
user=user,
db_session=db_session,
)
except ValueError as ve:
logger.exception(ve)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to create Persona, invalid info.",
)
except Exception as e:
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later.",
)
@basic_router.patch("/{prompt_id}")
def update_prompt(
prompt_id: int,
update_prompt_request: CreatePromptRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> PromptSnapshot:
try:
return create_update_prompt(
prompt_id=prompt_id,
create_prompt_request=update_prompt_request,
user=user,
db_session=db_session,
)
except ValueError as ve:
logger.exception(ve)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to create Persona, invalid info.",
)
except Exception as e:
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later.",
)
@basic_router.delete("/{prompt_id}")
def delete_prompt(
prompt_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
mark_prompt_as_deleted(
prompt_id=prompt_id,
user=user,
db_session=db_session,
)
@basic_router.get("")
def list_prompts(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[PromptSnapshot]:
user_id = user.id if user is not None else None
return [
PromptSnapshot.from_model(prompt)
for prompt in get_prompts(user_id=user_id, db_session=db_session)
]
@basic_router.get("/{prompt_id}")
def get_prompt(
prompt_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> PromptSnapshot:
return PromptSnapshot.from_model(
get_prompt_by_id(
prompt_id=prompt_id,
user=user,
db_session=db_session,
)
)

View File

@@ -1,41 +0,0 @@
from pydantic import BaseModel
from onyx.db.models import Prompt
class CreatePromptRequest(BaseModel):
name: str
description: str
system_prompt: str
task_prompt: str
include_citations: bool = False
datetime_aware: bool = False
persona_ids: list[int] | None = None
class PromptSnapshot(BaseModel):
id: int
name: str
description: str
system_prompt: str
task_prompt: str
include_citations: bool
datetime_aware: bool
default_prompt: bool
# Not including persona info, not needed
@classmethod
def from_model(cls, prompt: Prompt) -> "PromptSnapshot":
if prompt.deleted:
raise ValueError("Prompt has been deleted")
return PromptSnapshot(
id=prompt.id,
name=prompt.name,
description=prompt.description,
system_prompt=prompt.system_prompt,
task_prompt=prompt.task_prompt,
include_citations=prompt.include_citations,
datetime_aware=prompt.datetime_aware,
default_prompt=prompt.default_prompt,
)

View File

@@ -18,7 +18,7 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.persona import get_personas_for_user
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.persona import upsert_prompt
from onyx.db.prompts import upsert_prompt
from onyx.db.tools import get_tool_by_name
from onyx.utils.logger import setup_logger

View File

@@ -479,7 +479,10 @@ def get_max_document_tokens(
raise HTTPException(status_code=404, detail="Persona not found")
return MaxSelectedDocumentTokens(
max_tokens=compute_max_document_tokens_for_persona(persona),
max_tokens=compute_max_document_tokens_for_persona(
db_session=db_session,
persona=persona,
),
)

View File

@@ -1,9 +1,11 @@
from uuid import UUID
from uuid import uuid4
import requests
from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestPersona
@@ -16,6 +18,9 @@ class PersonaManager:
def create(
name: str | None = None,
description: str | None = None,
system_prompt: str | None = None,
task_prompt: str | None = None,
include_citations: bool = False,
num_chunks: float = 5,
llm_relevance_filter: bool = True,
is_public: bool = True,
@@ -28,32 +33,38 @@ class PersonaManager:
llm_model_version_override: str | None = None,
users: list[str] | None = None,
groups: list[int] | None = None,
category_id: int | None = None,
label_ids: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
name = name or f"test-persona-{uuid4()}"
description = description or f"Description for {name}"
system_prompt = system_prompt or f"System prompt for {name}"
task_prompt = task_prompt or f"Task prompt for {name}"
persona_creation_request = {
"name": name,
"description": description,
"num_chunks": num_chunks,
"llm_relevance_filter": llm_relevance_filter,
"is_public": is_public,
"llm_filter_extraction": llm_filter_extraction,
"recency_bias": recency_bias,
"prompt_ids": prompt_ids or [0],
"document_set_ids": document_set_ids or [],
"tool_ids": tool_ids or [],
"llm_model_provider_override": llm_model_provider_override,
"llm_model_version_override": llm_model_version_override,
"users": users or [],
"groups": groups or [],
}
persona_creation_request = PersonaUpsertRequest(
name=name,
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
num_chunks=num_chunks,
llm_relevance_filter=llm_relevance_filter,
is_public=is_public,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
prompt_ids=prompt_ids or [0],
document_set_ids=document_set_ids or [],
tool_ids=tool_ids or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
users=[UUID(user) for user in (users or [])],
groups=groups or [],
label_ids=label_ids or [],
)
response = requests.post(
f"{API_SERVER_URL}/persona",
json=persona_creation_request,
json=persona_creation_request.model_dump(),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
@@ -77,6 +88,7 @@ class PersonaManager:
llm_model_version_override=llm_model_version_override,
users=users or [],
groups=groups or [],
label_ids=label_ids or [],
)
@staticmethod
@@ -84,6 +96,9 @@ class PersonaManager:
persona: DATestPersona,
name: str | None = None,
description: str | None = None,
system_prompt: str | None = None,
task_prompt: str | None = None,
include_citations: bool = False,
num_chunks: float | None = None,
llm_relevance_filter: bool | None = None,
is_public: bool | None = None,
@@ -96,32 +111,38 @@ class PersonaManager:
llm_model_version_override: str | None = None,
users: list[str] | None = None,
groups: list[int] | None = None,
label_ids: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
persona_update_request = {
"name": name or persona.name,
"description": description or persona.description,
"num_chunks": num_chunks or persona.num_chunks,
"llm_relevance_filter": llm_relevance_filter
or persona.llm_relevance_filter,
"is_public": is_public or persona.is_public,
"llm_filter_extraction": llm_filter_extraction
system_prompt = system_prompt or f"System prompt for {persona.name}"
task_prompt = task_prompt or f"Task prompt for {persona.name}"
persona_update_request = PersonaUpsertRequest(
name=name or persona.name,
description=description or persona.description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
num_chunks=num_chunks or persona.num_chunks,
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
is_public=is_public or persona.is_public,
llm_filter_extraction=llm_filter_extraction
or persona.llm_filter_extraction,
"recency_bias": recency_bias or persona.recency_bias,
"prompt_ids": prompt_ids or persona.prompt_ids,
"document_set_ids": document_set_ids or persona.document_set_ids,
"tool_ids": tool_ids or persona.tool_ids,
"llm_model_provider_override": llm_model_provider_override
recency_bias=recency_bias or persona.recency_bias,
prompt_ids=prompt_ids or persona.prompt_ids,
document_set_ids=document_set_ids or persona.document_set_ids,
tool_ids=tool_ids or persona.tool_ids,
llm_model_provider_override=llm_model_provider_override
or persona.llm_model_provider_override,
"llm_model_version_override": llm_model_version_override
llm_model_version_override=llm_model_version_override
or persona.llm_model_version_override,
"users": users or persona.users,
"groups": groups or persona.groups,
}
users=[UUID(user) for user in (users or persona.users)],
groups=groups or persona.groups,
label_ids=label_ids or persona.label_ids,
)
response = requests.patch(
f"{API_SERVER_URL}/persona/{persona.id}",
json=persona_update_request,
json=persona_update_request.model_dump(),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
@@ -137,8 +158,8 @@ class PersonaManager:
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
is_public=updated_persona_data["is_public"],
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
recency_bias=updated_persona_data["recency_bias"],
prompt_ids=updated_persona_data["prompts"],
recency_bias=recency_bias or persona.recency_bias,
prompt_ids=[prompt["id"] for prompt in updated_persona_data["prompts"]],
document_set_ids=updated_persona_data["document_sets"],
tool_ids=updated_persona_data["tools"],
llm_model_provider_override=updated_persona_data[
@@ -149,6 +170,7 @@ class PersonaManager:
],
users=[user["email"] for user in updated_persona_data["users"]],
groups=updated_persona_data["groups"],
label_ids=updated_persona_data["labels"],
)
@staticmethod
@@ -164,12 +186,29 @@ class PersonaManager:
response.raise_for_status()
return [PersonaSnapshot(**persona) for persona in response.json()]
@staticmethod
def get_one(
persona_id: int,
user_performing_action: DATestUser | None = None,
) -> list[PersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/persona/{persona_id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return [PersonaSnapshot(**response.json())]
@staticmethod
def verify(
persona: DATestPersona,
user_performing_action: DATestUser | None = None,
) -> bool:
all_personas = PersonaManager.get_all(user_performing_action)
all_personas = PersonaManager.get_one(
persona_id=persona.id,
user_performing_action=user_performing_action,
)
for fetched_persona in all_personas:
if fetched_persona.id == persona.id:
return (
@@ -199,6 +238,7 @@ class PersonaManager:
and set(user.email for user in fetched_persona.users)
== set(persona.users)
and set(fetched_persona.groups) == set(persona.groups)
and set(fetched_persona.labels) == set(persona.label_ids)
)
return False
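
In integration tests the manager now takes the prompt text directly and builds a PersonaUpsertRequest under the hood. A hedged sketch of the new kwargs inside a test using the reset fixture:

from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.managers.user import UserManager


def test_create_persona_with_inline_prompt(reset: None) -> None:
    admin = UserManager.create(name="admin_user")   # first user becomes an admin

    persona = PersonaManager.create(
        name="docs-helper",
        description="Answers questions about internal docs",
        system_prompt="You are a documentation assistant.",   # new field
        task_prompt="Cite the relevant document sections.",   # new field
        include_citations=True,                                # new field
        user_performing_action=admin,
    )
    assert PersonaManager.verify(persona, user_performing_action=admin)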

View File

@@ -127,7 +127,7 @@ class DATestPersona(BaseModel):
llm_model_version_override: str | None
users: list[str]
groups: list[int]
category_id: int | None = None
label_ids: list[int]
#

View File

@@ -0,0 +1,175 @@
"""
This file tests the permissions for creating and editing personas for different user roles:
- Basic users can create personas and edit their own
- Curators can edit personas that belong exclusively to groups they curate
- Admins can edit all personas
"""
import pytest
from requests.exceptions import HTTPError
from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.managers.user import DATestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
def test_persona_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# Creating a curator user
curator: DATestUser = UserManager.create(name="curator")
# Creating a basic user
basic_user: DATestUser = UserManager.create(name="basic_user")
# Creating user groups
user_group_1 = UserGroupManager.create(
name="curated_user_group",
user_ids=[curator.id],
cc_pair_ids=[],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1], user_performing_action=admin_user
)
# Setting the user as a curator for the user group
UserGroupManager.set_curator_status(
test_user_group=user_group_1,
user_to_set_as_curator=curator,
user_performing_action=admin_user,
)
# Creating another user group that the user is not a curator of
user_group_2 = UserGroupManager.create(
name="uncurated_user_group",
user_ids=[curator.id],
cc_pair_ids=[],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2], user_performing_action=admin_user
)
"""Test that any user can create a persona"""
# Basic user creates a persona
basic_user_persona = PersonaManager.create(
name="basic_user_persona",
description="A persona created by basic user",
is_public=False,
groups=[],
user_performing_action=basic_user,
)
PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
# Curator creates a persona
curator_persona = PersonaManager.create(
name="curator_persona",
description="A persona created by curator",
is_public=False,
groups=[],
user_performing_action=curator,
)
PersonaManager.verify(curator_persona, user_performing_action=curator)
# Admin creates personas for different groups
admin_persona_group_1 = PersonaManager.create(
name="admin_persona_group_1",
description="A persona for group 1",
is_public=False,
groups=[user_group_1.id],
user_performing_action=admin_user,
)
admin_persona_group_2 = PersonaManager.create(
name="admin_persona_group_2",
description="A persona for group 2",
is_public=False,
groups=[user_group_2.id],
user_performing_action=admin_user,
)
admin_persona_both_groups = PersonaManager.create(
name="admin_persona_both_groups",
description="A persona for both groups",
is_public=False,
groups=[user_group_1.id, user_group_2.id],
user_performing_action=admin_user,
)
"""Test that users can edit their own personas"""
# Basic user can edit their own persona
PersonaManager.edit(
persona=basic_user_persona,
description="Updated description by basic user",
user_performing_action=basic_user,
)
PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
# Basic user cannot edit other's personas
with pytest.raises(HTTPError):
PersonaManager.edit(
persona=curator_persona,
description="Invalid edit by basic user",
user_performing_action=basic_user,
)
"""Test curator permissions"""
# Curator can edit personas that belong exclusively to groups they curate
PersonaManager.edit(
persona=admin_persona_group_1,
description="Updated by curator",
user_performing_action=curator,
)
PersonaManager.verify(admin_persona_group_1, user_performing_action=curator)
# Curator cannot edit personas in groups they don't curate
with pytest.raises(HTTPError):
PersonaManager.edit(
persona=admin_persona_group_2,
description="Invalid edit by curator",
user_performing_action=curator,
)
# Curator cannot edit personas that belong to multiple groups, even if they curate one
with pytest.raises(HTTPError):
PersonaManager.edit(
persona=admin_persona_both_groups,
description="Invalid edit by curator",
user_performing_action=curator,
)
"""Test admin permissions"""
# Admin can edit any persona
PersonaManager.edit(
persona=basic_user_persona,
description="Updated by admin",
user_performing_action=admin_user,
)
PersonaManager.verify(basic_user_persona, user_performing_action=admin_user)
PersonaManager.edit(
persona=curator_persona,
description="Updated by admin",
user_performing_action=admin_user,
)
PersonaManager.verify(curator_persona, user_performing_action=admin_user)
PersonaManager.edit(
persona=admin_persona_group_1,
description="Updated by admin",
user_performing_action=admin_user,
)
PersonaManager.verify(admin_persona_group_1, user_performing_action=admin_user)
PersonaManager.edit(
persona=admin_persona_group_2,
description="Updated by admin",
user_performing_action=admin_user,
)
PersonaManager.verify(admin_persona_group_2, user_performing_action=admin_user)
PersonaManager.edit(
persona=admin_persona_both_groups,
description="Updated by admin",
user_performing_action=admin_user,
)
PersonaManager.verify(admin_persona_both_groups, user_performing_action=admin_user)