mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-29 05:15:12 +02:00
Combined Persona and Prompt API (#3690)
* Combined Persona and Prompt API * quality * added tests * consolidated models and got rid of redundant fields * tenant appreciation day * reverted default
This commit is contained in:
@@ -24,7 +24,7 @@ from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.features.persona.models import CreatePersonaRequest
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.store import store_settings as store_base_settings
|
||||
@@ -57,7 +57,7 @@ class SeedConfiguration(BaseModel):
|
||||
llms: list[LLMProviderUpsertRequest] | None = None
|
||||
admin_user_emails: list[str] | None = None
|
||||
seeded_logo_path: str | None = None
|
||||
personas: list[CreatePersonaRequest] | None = None
|
||||
personas: list[PersonaUpsertRequest] | None = None
|
||||
settings: Settings | None = None
|
||||
enterprise_settings: EnterpriseSettings | None = None
|
||||
|
||||
@@ -128,7 +128,7 @@ def _seed_llms(
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
if personas:
|
||||
logger.notice("Seeding Personas")
|
||||
for persona in personas:
|
||||
|
@@ -25,7 +25,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_prompts_by_ids
|
||||
from onyx.db.prompts import get_prompts_by_ids
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
|
@@ -1,12 +1,13 @@
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.persona import get_default_prompt__read_only
|
||||
from onyx.db.prompts import get_default_prompt
|
||||
from onyx.db.search_settings import get_multilingual_expansion
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_main_llm_from_tuple
|
||||
@@ -97,11 +98,12 @@ def compute_max_document_tokens(
|
||||
|
||||
|
||||
def compute_max_document_tokens_for_persona(
|
||||
db_session: Session,
|
||||
persona: Persona,
|
||||
actual_user_input: str | None = None,
|
||||
max_llm_token_override: int | None = None,
|
||||
) -> int:
|
||||
prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
|
||||
prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session)
|
||||
return compute_max_document_tokens(
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
|
||||
|
@@ -7,26 +7,6 @@ from onyx.db.models import ChatMessage
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
|
||||
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
|
||||
|
||||
|
||||
def build_dummy_prompt(
|
||||
system_prompt: str, task_prompt: str, retrieval_disabled: bool
|
||||
) -> str:
|
||||
if retrieval_disabled:
|
||||
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
return PARAMATERIZED_PROMPT.format(
|
||||
context_docs_str="<CONTEXT_DOCS>",
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
|
||||
def translate_onyx_msg_to_langchain(
|
||||
|
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -8,7 +7,6 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import not_
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
@@ -23,7 +21,6 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
@@ -35,8 +32,8 @@ from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.server.features.persona.models import CreatePersonaRequest
|
||||
from onyx.server.features.persona.models import PersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -107,9 +104,6 @@ def _add_user_filters(
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
# fetch_persona_by_id is used to fetch a persona by its ID. It is used to fetch a persona by its ID.
|
||||
|
||||
|
||||
def fetch_persona_by_id_for_user(
|
||||
db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
|
||||
) -> Persona:
|
||||
@@ -184,7 +178,7 @@ def make_persona_private(
|
||||
|
||||
def create_update_persona(
|
||||
persona_id: int | None,
|
||||
create_persona_request: CreatePersonaRequest,
|
||||
create_persona_request: PersonaUpsertRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> PersonaSnapshot:
|
||||
@@ -192,14 +186,36 @@ def create_update_persona(
|
||||
# Permission to actually use these is checked later
|
||||
|
||||
try:
|
||||
persona_data = {
|
||||
"persona_id": persona_id,
|
||||
"user": user,
|
||||
"db_session": db_session,
|
||||
**create_persona_request.model_dump(exclude={"users", "groups"}),
|
||||
}
|
||||
all_prompt_ids = create_persona_request.prompt_ids
|
||||
|
||||
persona = upsert_persona(**persona_data)
|
||||
if not all_prompt_ids:
|
||||
raise ValueError("No prompt IDs provided")
|
||||
|
||||
persona = upsert_persona(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
description=create_persona_request.description,
|
||||
name=create_persona_request.name,
|
||||
prompt_ids=all_prompt_ids,
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
is_public=create_persona_request.is_public,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
icon_color=create_persona_request.icon_color,
|
||||
icon_shape=create_persona_request.icon_shape,
|
||||
uploaded_image_id=create_persona_request.uploaded_image_id,
|
||||
display_priority=create_persona_request.display_priority,
|
||||
remove_image=create_persona_request.remove_image,
|
||||
search_start_date=create_persona_request.search_start_date,
|
||||
label_ids=create_persona_request.label_ids,
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "make_persona_private"
|
||||
@@ -265,24 +281,6 @@ def update_persona_public_status(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_prompts(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
include_default: bool = True,
|
||||
include_deleted: bool = False,
|
||||
) -> Sequence[Prompt]:
|
||||
stmt = select(Prompt).where(
|
||||
or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))
|
||||
)
|
||||
|
||||
if not include_default:
|
||||
stmt = stmt.where(Prompt.default_prompt.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Prompt.deleted.is_(False))
|
||||
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def get_personas_for_user(
|
||||
# if user is `None` assume the user is an admin or auth is disabled
|
||||
user: User | None,
|
||||
@@ -374,65 +372,6 @@ def update_all_personas_display_priority(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def upsert_prompt(
|
||||
user: User | None,
|
||||
name: str,
|
||||
description: str,
|
||||
system_prompt: str,
|
||||
task_prompt: str,
|
||||
include_citations: bool,
|
||||
datetime_aware: bool,
|
||||
personas: list[Persona] | None,
|
||||
db_session: Session,
|
||||
prompt_id: int | None = None,
|
||||
default_prompt: bool = True,
|
||||
commit: bool = True,
|
||||
) -> Prompt:
|
||||
if prompt_id is not None:
|
||||
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
|
||||
else:
|
||||
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
|
||||
|
||||
if prompt:
|
||||
if not default_prompt and prompt.default_prompt:
|
||||
raise ValueError("Cannot update default prompt with non-default.")
|
||||
|
||||
prompt.name = name
|
||||
prompt.description = description
|
||||
prompt.system_prompt = system_prompt
|
||||
prompt.task_prompt = task_prompt
|
||||
prompt.include_citations = include_citations
|
||||
prompt.datetime_aware = datetime_aware
|
||||
prompt.default_prompt = default_prompt
|
||||
|
||||
if personas is not None:
|
||||
prompt.personas.clear()
|
||||
prompt.personas = personas
|
||||
|
||||
else:
|
||||
prompt = Prompt(
|
||||
id=prompt_id,
|
||||
user_id=user.id if user else None,
|
||||
name=name,
|
||||
description=description,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
include_citations=include_citations,
|
||||
datetime_aware=datetime_aware,
|
||||
default_prompt=default_prompt,
|
||||
personas=personas or [],
|
||||
)
|
||||
db_session.add(prompt)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
# Flush the session so that the Prompt has an ID
|
||||
db_session.flush()
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -477,6 +416,15 @@ def upsert_persona(
|
||||
persona_name=name, user=user, db_session=db_session
|
||||
)
|
||||
|
||||
if existing_persona:
|
||||
# this checks if the user has permission to edit the persona
|
||||
# will raise an Exception if the user does not have permission
|
||||
existing_persona = fetch_persona_by_id_for_user(
|
||||
db_session=db_session,
|
||||
persona_id=existing_persona.id,
|
||||
user=user,
|
||||
get_editable=True,
|
||||
)
|
||||
# Fetch and attach tools by IDs
|
||||
tools = None
|
||||
if tool_ids is not None:
|
||||
@@ -522,15 +470,6 @@ def upsert_persona(
|
||||
if existing_persona.builtin_persona and not builtin_persona:
|
||||
raise ValueError("Cannot update builtin persona with non-builtin.")
|
||||
|
||||
# this checks if the user has permission to edit the persona
|
||||
# will raise an Exception if the user does not have permission
|
||||
existing_persona = fetch_persona_by_id_for_user(
|
||||
db_session=db_session,
|
||||
persona_id=existing_persona.id,
|
||||
user=user,
|
||||
get_editable=True,
|
||||
)
|
||||
|
||||
# The following update excludes `default`, `built-in`, and display priority.
|
||||
# Display priority is handled separately in the `display-priority` endpoint.
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
@@ -619,16 +558,6 @@ def upsert_persona(
|
||||
return persona
|
||||
|
||||
|
||||
def mark_prompt_as_deleted(
|
||||
prompt_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
prompt = get_prompt_by_id(prompt_id=prompt_id, user=user, db_session=db_session)
|
||||
prompt.deleted = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_old_default_personas(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
@@ -666,69 +595,6 @@ def validate_persona_tools(tools: list[Tool]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
|
||||
"""Unsafe, can fetch prompts from all users"""
|
||||
if not prompt_ids:
|
||||
return []
|
||||
prompts = db_session.scalars(
|
||||
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
|
||||
).all()
|
||||
|
||||
return list(prompts)
|
||||
|
||||
|
||||
def get_prompt_by_id(
|
||||
prompt_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
) -> Prompt:
|
||||
stmt = select(Prompt).where(Prompt.id == prompt_id)
|
||||
|
||||
# if user is not specified OR they are an admin, they should
|
||||
# have access to all prompts, so this where clause is not needed
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
stmt = stmt.where(or_(Prompt.user_id == user.id, Prompt.user_id.is_(None)))
|
||||
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Prompt.deleted.is_(False))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
prompt = result.scalar_one_or_none()
|
||||
|
||||
if prompt is None:
|
||||
raise ValueError(
|
||||
f"Prompt with ID {prompt_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def _get_default_prompt(db_session: Session) -> Prompt:
|
||||
stmt = select(Prompt).where(Prompt.id == 0)
|
||||
result = db_session.execute(stmt)
|
||||
prompt = result.scalar_one_or_none()
|
||||
|
||||
if prompt is None:
|
||||
raise RuntimeError("Default Prompt not found")
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def get_default_prompt(db_session: Session) -> Prompt:
|
||||
return _get_default_prompt(db_session)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_default_prompt__read_only() -> Prompt:
|
||||
"""Due to the way lru_cache / SQLAlchemy works, this can cause issues
|
||||
when trying to attach the returned `Prompt` object to a `Persona`. If you are
|
||||
doing anything other than reading, you should use the `get_default_prompt`
|
||||
method instead."""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
return _get_default_prompt(db_session)
|
||||
|
||||
|
||||
# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
|
||||
# a direct mapping indicating whether a user has access to a specific persona?
|
||||
def get_persona_by_id(
|
||||
@@ -800,22 +666,6 @@ def get_personas_by_ids(
|
||||
return personas
|
||||
|
||||
|
||||
def get_prompt_by_name(
|
||||
prompt_name: str, user: User | None, db_session: Session
|
||||
) -> Prompt | None:
|
||||
stmt = select(Prompt).where(Prompt.name == prompt_name)
|
||||
|
||||
# if user is not specified OR they are an admin, they should
|
||||
# have access to all prompts, so this where clause is not needed
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
stmt = stmt.where(Prompt.user_id == user.id)
|
||||
|
||||
# Order by ID to ensure consistent result when multiple prompts exist
|
||||
stmt = stmt.order_by(Prompt.id).limit(1)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
|
||||
|
||||
def delete_persona_by_name(
|
||||
persona_name: str, db_session: Session, is_default: bool = True
|
||||
) -> None:
|
||||
|
119
backend/onyx/db/prompts.py
Normal file
119
backend/onyx/db/prompts.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
# Note: As prompts are fairly innocuous/harmless, there are no protections
|
||||
# to prevent users from messing with prompts of other users.
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_default_prompt(db_session: Session) -> Prompt:
|
||||
stmt = select(Prompt).where(Prompt.id == 0)
|
||||
result = db_session.execute(stmt)
|
||||
prompt = result.scalar_one_or_none()
|
||||
|
||||
if prompt is None:
|
||||
raise RuntimeError("Default Prompt not found")
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def get_default_prompt(db_session: Session) -> Prompt:
|
||||
return _get_default_prompt(db_session)
|
||||
|
||||
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
|
||||
"""Unsafe, can fetch prompts from all users"""
|
||||
if not prompt_ids:
|
||||
return []
|
||||
prompts = db_session.scalars(
|
||||
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
|
||||
).all()
|
||||
|
||||
return list(prompts)
|
||||
|
||||
|
||||
def get_prompt_by_name(
|
||||
prompt_name: str, user: User | None, db_session: Session
|
||||
) -> Prompt | None:
|
||||
stmt = select(Prompt).where(Prompt.name == prompt_name)
|
||||
|
||||
# if user is not specified OR they are an admin, they should
|
||||
# have access to all prompts, so this where clause is not needed
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
stmt = stmt.where(Prompt.user_id == user.id)
|
||||
|
||||
# Order by ID to ensure consistent result when multiple prompts exist
|
||||
stmt = stmt.order_by(Prompt.id).limit(1)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
|
||||
|
||||
def build_prompt_name_from_persona_name(persona_name: str) -> str:
|
||||
return f"default-prompt__{persona_name}"
|
||||
|
||||
|
||||
def upsert_prompt(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
name: str,
|
||||
system_prompt: str,
|
||||
task_prompt: str,
|
||||
datetime_aware: bool,
|
||||
prompt_id: int | None = None,
|
||||
personas: list[Persona] | None = None,
|
||||
include_citations: bool = False,
|
||||
default_prompt: bool = True,
|
||||
# Support backwards compatibility
|
||||
description: str | None = None,
|
||||
) -> Prompt:
|
||||
if description is None:
|
||||
description = f"Default prompt for {name}"
|
||||
|
||||
if prompt_id is not None:
|
||||
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
|
||||
else:
|
||||
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
|
||||
|
||||
if prompt:
|
||||
if not default_prompt and prompt.default_prompt:
|
||||
raise ValueError("Cannot update default prompt with non-default.")
|
||||
|
||||
prompt.name = name
|
||||
prompt.description = description
|
||||
prompt.system_prompt = system_prompt
|
||||
prompt.task_prompt = task_prompt
|
||||
prompt.include_citations = include_citations
|
||||
prompt.datetime_aware = datetime_aware
|
||||
prompt.default_prompt = default_prompt
|
||||
|
||||
if personas is not None:
|
||||
prompt.personas.clear()
|
||||
prompt.personas = personas
|
||||
|
||||
else:
|
||||
prompt = Prompt(
|
||||
id=prompt_id,
|
||||
user_id=user.id if user else None,
|
||||
name=name,
|
||||
description=description,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
include_citations=include_citations,
|
||||
datetime_aware=datetime_aware,
|
||||
default_prompt=default_prompt,
|
||||
personas=personas or [],
|
||||
)
|
||||
db_session.add(prompt)
|
||||
|
||||
# Flush the session so that the Prompt has an ID
|
||||
db_session.flush()
|
||||
|
||||
return prompt
|
@@ -12,9 +12,9 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__DocumentSet
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_default_prompt
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.db.prompts import get_default_prompt
|
||||
from onyx.utils.errors import EERequiredError
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
|
@@ -64,7 +64,6 @@ from onyx.server.features.input_prompt.api import (
|
||||
from onyx.server.features.notifications.api import router as notification_router
|
||||
from onyx.server.features.persona.api import admin_router as admin_persona_router
|
||||
from onyx.server.features.persona.api import basic_router as persona_router
|
||||
from onyx.server.features.prompt.api import basic_router as prompt_router
|
||||
from onyx.server.features.tool.api import admin_router as admin_tool_router
|
||||
from onyx.server.features.tool.api import router as tool_router
|
||||
from onyx.server.gpts.api import router as gpts_router
|
||||
@@ -296,7 +295,6 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, persona_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_persona_router)
|
||||
include_router_with_global_prefix_prepended(application, notification_router)
|
||||
include_router_with_global_prefix_prepended(application, prompt_router)
|
||||
include_router_with_global_prefix_prepended(application, tool_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_tool_router)
|
||||
include_router_with_global_prefix_prepended(application, state_router)
|
||||
|
@@ -118,32 +118,6 @@ You should always get right to the point, and never use extraneous language.
|
||||
"""
|
||||
|
||||
|
||||
# This is only for visualization for the users to specify their own prompts
|
||||
# The actual flow does not work like this
|
||||
PARAMATERIZED_PROMPT = f"""
|
||||
{{system_prompt}}
|
||||
|
||||
CONTEXT:
|
||||
{GENERAL_SEP_PAT}
|
||||
{{context_docs_str}}
|
||||
{GENERAL_SEP_PAT}
|
||||
|
||||
{{task_prompt}}
|
||||
|
||||
{QUESTION_PAT.upper()} {{user_query}}
|
||||
RESPONSE:
|
||||
""".strip()
|
||||
|
||||
PARAMATERIZED_PROMPT_WITHOUT_CONTEXT = f"""
|
||||
{{system_prompt}}
|
||||
|
||||
{{task_prompt}}
|
||||
|
||||
{QUESTION_PAT.upper()} {{user_query}}
|
||||
RESPONSE:
|
||||
""".strip()
|
||||
|
||||
|
||||
# CURRENTLY DISABLED, CANNOT USE THIS ONE
|
||||
# Default chain-of-thought style json prompt which uses multiple docs
|
||||
# This one has a section for the LLM to output some non-answer "thoughts"
|
||||
|
@@ -12,9 +12,9 @@ from onyx.db.models import DocumentSet as DocumentSetDBModel
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Prompt as PromptDBModel
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.db.persona import get_prompt_by_name
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.db.persona import upsert_prompt
|
||||
from onyx.db.prompts import get_prompt_by_name
|
||||
from onyx.db.prompts import upsert_prompt
|
||||
|
||||
|
||||
def load_prompts_from_yaml(
|
||||
@@ -26,6 +26,7 @@ def load_prompts_from_yaml(
|
||||
all_prompts = data.get("prompts", [])
|
||||
for prompt in all_prompts:
|
||||
upsert_prompt(
|
||||
db_session=db_session,
|
||||
user=None,
|
||||
prompt_id=prompt.get("id"),
|
||||
name=prompt["name"],
|
||||
@@ -36,9 +37,8 @@ def load_prompts_from_yaml(
|
||||
datetime_aware=prompt.get("datetime_aware", True),
|
||||
default_prompt=True,
|
||||
personas=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def load_input_prompts_from_yaml(
|
||||
|
@@ -14,7 +14,6 @@ from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.prompt_builder.utils import build_dummy_prompt
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NotificationType
|
||||
@@ -36,19 +35,21 @@ from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared_users
|
||||
from onyx.db.persona import update_persona_visibility
|
||||
from onyx.db.prompts import build_prompt_name_from_persona_name
|
||||
from onyx.db.prompts import upsert_prompt
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.secondary_llm_flows.starter_message_creation import (
|
||||
generate_starter_messages,
|
||||
)
|
||||
from onyx.server.features.persona.models import CreatePersonaRequest
|
||||
from onyx.server.features.persona.models import GenerateStarterMessageRequest
|
||||
from onyx.server.features.persona.models import ImageGenerationToolStatus
|
||||
from onyx.server.features.persona.models import PersonaLabelCreate
|
||||
from onyx.server.features.persona.models import PersonaLabelResponse
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
from onyx.server.features.persona.models import PersonaSnapshot
|
||||
from onyx.server.features.persona.models import PromptTemplateResponse
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.features.persona.models import PromptSnapshot
|
||||
from onyx.server.models import DisplayPriorityRequest
|
||||
from onyx.tools.utils import is_image_generation_available
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -173,18 +174,37 @@ def upload_file(
|
||||
|
||||
@basic_router.post("")
|
||||
def create_persona(
|
||||
create_persona_request: CreatePersonaRequest,
|
||||
persona_upsert_request: PersonaUpsertRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> PersonaSnapshot:
|
||||
prompt_id = (
|
||||
persona_upsert_request.prompt_ids[0]
|
||||
if persona_upsert_request.prompt_ids
|
||||
and len(persona_upsert_request.prompt_ids) > 0
|
||||
else None
|
||||
)
|
||||
prompt = upsert_prompt(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
|
||||
system_prompt=persona_upsert_request.system_prompt,
|
||||
task_prompt=persona_upsert_request.task_prompt,
|
||||
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
|
||||
datetime_aware=False,
|
||||
include_citations=persona_upsert_request.include_citations,
|
||||
prompt_id=prompt_id,
|
||||
)
|
||||
prompt_snapshot = PromptSnapshot.from_model(prompt)
|
||||
persona_upsert_request.prompt_ids = [prompt.id]
|
||||
persona_snapshot = create_update_persona(
|
||||
persona_id=None,
|
||||
create_persona_request=create_persona_request,
|
||||
create_persona_request=persona_upsert_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
persona_snapshot.prompts = [prompt_snapshot]
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=tenant_id or "N/A",
|
||||
@@ -202,16 +222,37 @@ def create_persona(
|
||||
@basic_router.patch("/{persona_id}")
|
||||
def update_persona(
|
||||
persona_id: int,
|
||||
update_persona_request: CreatePersonaRequest,
|
||||
persona_upsert_request: PersonaUpsertRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PersonaSnapshot:
|
||||
return create_update_persona(
|
||||
prompt_id = (
|
||||
persona_upsert_request.prompt_ids[0]
|
||||
if persona_upsert_request.prompt_ids
|
||||
and len(persona_upsert_request.prompt_ids) > 0
|
||||
else None
|
||||
)
|
||||
prompt = upsert_prompt(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
|
||||
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
|
||||
datetime_aware=False,
|
||||
system_prompt=persona_upsert_request.system_prompt,
|
||||
task_prompt=persona_upsert_request.task_prompt,
|
||||
include_citations=persona_upsert_request.include_citations,
|
||||
prompt_id=prompt_id,
|
||||
)
|
||||
prompt_snapshot = PromptSnapshot.from_model(prompt)
|
||||
persona_upsert_request.prompt_ids = [prompt.id]
|
||||
persona_snapshot = create_update_persona(
|
||||
persona_id=persona_id,
|
||||
create_persona_request=update_persona_request,
|
||||
create_persona_request=persona_upsert_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
persona_snapshot.prompts = [prompt_snapshot]
|
||||
return persona_snapshot
|
||||
|
||||
|
||||
class PersonaLabelPatchRequest(BaseModel):
|
||||
@@ -365,22 +406,6 @@ def get_persona(
|
||||
)
|
||||
|
||||
|
||||
@basic_router.get("/utils/prompt-explorer")
|
||||
def build_final_template_prompt(
|
||||
system_prompt: str,
|
||||
task_prompt: str,
|
||||
retrieval_disabled: bool = False,
|
||||
_: User | None = Depends(current_user),
|
||||
) -> PromptTemplateResponse:
|
||||
return PromptTemplateResponse(
|
||||
final_prompt_template=build_dummy_prompt(
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
retrieval_disabled=retrieval_disabled,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@basic_router.post("/assistant-prompt-refresh")
|
||||
def build_assistant_prompts(
|
||||
generate_persona_prompt_request: GenerateStarterMessageRequest,
|
||||
|
@@ -7,9 +7,9 @@ from pydantic import Field
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import PersonaLabel
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import StarterMessage
|
||||
from onyx.server.features.document_set.models import DocumentSet
|
||||
from onyx.server.features.prompt.models import PromptSnapshot
|
||||
from onyx.server.features.tool.models import ToolSnapshot
|
||||
from onyx.server.models import MinimalUserSnapshot
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -18,6 +18,34 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class PromptSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
include_citations: bool
|
||||
datetime_aware: bool
|
||||
default_prompt: bool
|
||||
# Not including persona info, not needed
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, prompt: Prompt) -> "PromptSnapshot":
|
||||
if prompt.deleted:
|
||||
raise ValueError("Prompt has been deleted")
|
||||
|
||||
return PromptSnapshot(
|
||||
id=prompt.id,
|
||||
name=prompt.name,
|
||||
description=prompt.description,
|
||||
system_prompt=prompt.system_prompt,
|
||||
task_prompt=prompt.task_prompt,
|
||||
include_citations=prompt.include_citations,
|
||||
datetime_aware=prompt.datetime_aware,
|
||||
default_prompt=prompt.default_prompt,
|
||||
)
|
||||
|
||||
|
||||
# More minimal request for generating a persona prompt
|
||||
class GenerateStarterMessageRequest(BaseModel):
|
||||
name: str
|
||||
@@ -27,32 +55,35 @@ class GenerateStarterMessageRequest(BaseModel):
|
||||
generation_count: int
|
||||
|
||||
|
||||
class CreatePersonaRequest(BaseModel):
|
||||
class PersonaUpsertRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
document_set_ids: list[int]
|
||||
num_chunks: float
|
||||
llm_relevance_filter: bool
|
||||
include_citations: bool
|
||||
is_public: bool
|
||||
llm_filter_extraction: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
prompt_ids: list[int]
|
||||
document_set_ids: list[int]
|
||||
# e.g. ID of SearchTool or ImageGenerationTool or <USER_DEFINED_TOOL>
|
||||
tool_ids: list[int]
|
||||
llm_filter_extraction: bool
|
||||
llm_relevance_filter: bool
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
# For Private Personas, who should be able to access these
|
||||
users: list[UUID] = Field(default_factory=list)
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
# e.g. ID of SearchTool or ImageGenerationTool or <USER_DEFINED_TOOL>
|
||||
tool_ids: list[int]
|
||||
icon_color: str | None = None
|
||||
icon_shape: int | None = None
|
||||
uploaded_image_id: str | None = None # New field for uploaded image
|
||||
remove_image: bool | None = None
|
||||
is_default_persona: bool = False
|
||||
display_priority: int | None = None
|
||||
uploaded_image_id: str | None = None # New field for uploaded image
|
||||
search_start_date: datetime | None = None
|
||||
label_ids: list[int] | None = None
|
||||
is_default_persona: bool = False
|
||||
display_priority: int | None = None
|
||||
|
||||
|
||||
class PersonaSnapshot(BaseModel):
|
||||
|
@@ -1,152 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette import status
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.persona import get_prompt_by_id
|
||||
from onyx.db.persona import get_prompts
|
||||
from onyx.db.persona import mark_prompt_as_deleted
|
||||
from onyx.db.persona import upsert_prompt
|
||||
from onyx.server.features.prompt.models import CreatePromptRequest
|
||||
from onyx.server.features.prompt.models import PromptSnapshot
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
# Note: As prompts are fairly innocuous/harmless, there are no protections
|
||||
# to prevent users from messing with prompts of other users.
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
basic_router = APIRouter(prefix="/prompt")
|
||||
|
||||
|
||||
def create_update_prompt(
|
||||
prompt_id: int | None,
|
||||
create_prompt_request: CreatePromptRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> PromptSnapshot:
|
||||
personas = (
|
||||
list(
|
||||
get_personas_by_ids(
|
||||
persona_ids=create_prompt_request.persona_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
if create_prompt_request.persona_ids
|
||||
else []
|
||||
)
|
||||
|
||||
prompt = upsert_prompt(
|
||||
prompt_id=prompt_id,
|
||||
user=user,
|
||||
name=create_prompt_request.name,
|
||||
description=create_prompt_request.description,
|
||||
system_prompt=create_prompt_request.system_prompt,
|
||||
task_prompt=create_prompt_request.task_prompt,
|
||||
include_citations=create_prompt_request.include_citations,
|
||||
datetime_aware=create_prompt_request.datetime_aware,
|
||||
personas=personas,
|
||||
db_session=db_session,
|
||||
)
|
||||
return PromptSnapshot.from_model(prompt)
|
||||
|
||||
|
||||
@basic_router.post("")
|
||||
def create_prompt(
|
||||
create_prompt_request: CreatePromptRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PromptSnapshot:
|
||||
try:
|
||||
return create_update_prompt(
|
||||
prompt_id=None,
|
||||
create_prompt_request=create_prompt_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError as ve:
|
||||
logger.exception(ve)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to create Persona, invalid info.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
@basic_router.patch("/{prompt_id}")
|
||||
def update_prompt(
|
||||
prompt_id: int,
|
||||
update_prompt_request: CreatePromptRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PromptSnapshot:
|
||||
try:
|
||||
return create_update_prompt(
|
||||
prompt_id=prompt_id,
|
||||
create_prompt_request=update_prompt_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError as ve:
|
||||
logger.exception(ve)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to create Persona, invalid info.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
@basic_router.delete("/{prompt_id}")
|
||||
def delete_prompt(
|
||||
prompt_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
mark_prompt_as_deleted(
|
||||
prompt_id=prompt_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@basic_router.get("")
|
||||
def list_prompts(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[PromptSnapshot]:
|
||||
user_id = user.id if user is not None else None
|
||||
return [
|
||||
PromptSnapshot.from_model(prompt)
|
||||
for prompt in get_prompts(user_id=user_id, db_session=db_session)
|
||||
]
|
||||
|
||||
|
||||
@basic_router.get("/{prompt_id}")
|
||||
def get_prompt(
|
||||
prompt_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PromptSnapshot:
|
||||
return PromptSnapshot.from_model(
|
||||
get_prompt_by_id(
|
||||
prompt_id=prompt_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
@@ -1,41 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.models import Prompt
|
||||
|
||||
|
||||
class CreatePromptRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
include_citations: bool = False
|
||||
datetime_aware: bool = False
|
||||
persona_ids: list[int] | None = None
|
||||
|
||||
|
||||
class PromptSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
include_citations: bool
|
||||
datetime_aware: bool
|
||||
default_prompt: bool
|
||||
# Not including persona info, not needed
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, prompt: Prompt) -> "PromptSnapshot":
|
||||
if prompt.deleted:
|
||||
raise ValueError("Prompt has been deleted")
|
||||
|
||||
return PromptSnapshot(
|
||||
id=prompt.id,
|
||||
name=prompt.name,
|
||||
description=prompt.description,
|
||||
system_prompt=prompt.system_prompt,
|
||||
task_prompt=prompt.task_prompt,
|
||||
include_citations=prompt.include_citations,
|
||||
datetime_aware=prompt.datetime_aware,
|
||||
default_prompt=prompt.default_prompt,
|
||||
)
|
@@ -18,7 +18,7 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.persona import get_personas_for_user
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.db.persona import upsert_prompt
|
||||
from onyx.db.prompts import upsert_prompt
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
@@ -479,7 +479,10 @@ def get_max_document_tokens(
|
||||
raise HTTPException(status_code=404, detail="Persona not found")
|
||||
|
||||
return MaxSelectedDocumentTokens(
|
||||
max_tokens=compute_max_document_tokens_for_persona(persona),
|
||||
max_tokens=compute_max_document_tokens_for_persona(
|
||||
db_session=db_session,
|
||||
persona=persona,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1,9 +1,11 @@
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.server.features.persona.models import PersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestPersona
|
||||
@@ -16,6 +18,9 @@ class PersonaManager:
|
||||
def create(
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
include_citations: bool = False,
|
||||
num_chunks: float = 5,
|
||||
llm_relevance_filter: bool = True,
|
||||
is_public: bool = True,
|
||||
@@ -28,32 +33,38 @@ class PersonaManager:
|
||||
llm_model_version_override: str | None = None,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
category_id: int | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
description = description or f"Description for {name}"
|
||||
system_prompt = system_prompt or f"System prompt for {name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {name}"
|
||||
|
||||
persona_creation_request = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"num_chunks": num_chunks,
|
||||
"llm_relevance_filter": llm_relevance_filter,
|
||||
"is_public": is_public,
|
||||
"llm_filter_extraction": llm_filter_extraction,
|
||||
"recency_bias": recency_bias,
|
||||
"prompt_ids": prompt_ids or [0],
|
||||
"document_set_ids": document_set_ids or [],
|
||||
"tool_ids": tool_ids or [],
|
||||
"llm_model_provider_override": llm_model_provider_override,
|
||||
"llm_model_version_override": llm_model_version_override,
|
||||
"users": users or [],
|
||||
"groups": groups or [],
|
||||
}
|
||||
persona_creation_request = PersonaUpsertRequest(
|
||||
name=name,
|
||||
description=description,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
include_citations=include_citations,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
prompt_ids=prompt_ids or [0],
|
||||
document_set_ids=document_set_ids or [],
|
||||
tool_ids=tool_ids or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
users=[UUID(user) for user in (users or [])],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
json=persona_creation_request,
|
||||
json=persona_creation_request.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
@@ -77,6 +88,7 @@ class PersonaManager:
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -84,6 +96,9 @@ class PersonaManager:
|
||||
persona: DATestPersona,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
include_citations: bool = False,
|
||||
num_chunks: float | None = None,
|
||||
llm_relevance_filter: bool | None = None,
|
||||
is_public: bool | None = None,
|
||||
@@ -96,32 +111,38 @@ class PersonaManager:
|
||||
llm_model_version_override: str | None = None,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestPersona:
|
||||
persona_update_request = {
|
||||
"name": name or persona.name,
|
||||
"description": description or persona.description,
|
||||
"num_chunks": num_chunks or persona.num_chunks,
|
||||
"llm_relevance_filter": llm_relevance_filter
|
||||
or persona.llm_relevance_filter,
|
||||
"is_public": is_public or persona.is_public,
|
||||
"llm_filter_extraction": llm_filter_extraction
|
||||
system_prompt = system_prompt or f"System prompt for {persona.name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {persona.name}"
|
||||
persona_update_request = PersonaUpsertRequest(
|
||||
name=name or persona.name,
|
||||
description=description or persona.description,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
include_citations=include_citations,
|
||||
num_chunks=num_chunks or persona.num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
|
||||
is_public=is_public or persona.is_public,
|
||||
llm_filter_extraction=llm_filter_extraction
|
||||
or persona.llm_filter_extraction,
|
||||
"recency_bias": recency_bias or persona.recency_bias,
|
||||
"prompt_ids": prompt_ids or persona.prompt_ids,
|
||||
"document_set_ids": document_set_ids or persona.document_set_ids,
|
||||
"tool_ids": tool_ids or persona.tool_ids,
|
||||
"llm_model_provider_override": llm_model_provider_override
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
prompt_ids=prompt_ids or persona.prompt_ids,
|
||||
document_set_ids=document_set_ids or persona.document_set_ids,
|
||||
tool_ids=tool_ids or persona.tool_ids,
|
||||
llm_model_provider_override=llm_model_provider_override
|
||||
or persona.llm_model_provider_override,
|
||||
"llm_model_version_override": llm_model_version_override
|
||||
llm_model_version_override=llm_model_version_override
|
||||
or persona.llm_model_version_override,
|
||||
"users": users or persona.users,
|
||||
"groups": groups or persona.groups,
|
||||
}
|
||||
users=[UUID(user) for user in (users or persona.users)],
|
||||
groups=groups or persona.groups,
|
||||
label_ids=label_ids or persona.label_ids,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=persona_update_request,
|
||||
json=persona_update_request.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
@@ -137,8 +158,8 @@ class PersonaManager:
|
||||
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
|
||||
is_public=updated_persona_data["is_public"],
|
||||
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
|
||||
recency_bias=updated_persona_data["recency_bias"],
|
||||
prompt_ids=updated_persona_data["prompts"],
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
prompt_ids=[prompt["id"] for prompt in updated_persona_data["prompts"]],
|
||||
document_set_ids=updated_persona_data["document_sets"],
|
||||
tool_ids=updated_persona_data["tools"],
|
||||
llm_model_provider_override=updated_persona_data[
|
||||
@@ -149,6 +170,7 @@ class PersonaManager:
|
||||
],
|
||||
users=[user["email"] for user in updated_persona_data["users"]],
|
||||
groups=updated_persona_data["groups"],
|
||||
label_ids=updated_persona_data["labels"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -164,12 +186,29 @@ class PersonaManager:
|
||||
response.raise_for_status()
|
||||
return [PersonaSnapshot(**persona) for persona in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def get_one(
|
||||
persona_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[PersonaSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona_id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [PersonaSnapshot(**response.json())]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
persona: DATestPersona,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
all_personas = PersonaManager.get_all(user_performing_action)
|
||||
all_personas = PersonaManager.get_one(
|
||||
persona_id=persona.id,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
for fetched_persona in all_personas:
|
||||
if fetched_persona.id == persona.id:
|
||||
return (
|
||||
@@ -199,6 +238,7 @@ class PersonaManager:
|
||||
and set(user.email for user in fetched_persona.users)
|
||||
== set(persona.users)
|
||||
and set(fetched_persona.groups) == set(persona.groups)
|
||||
and set(fetched_persona.labels) == set(persona.label_ids)
|
||||
)
|
||||
return False
|
||||
|
||||
|
@@ -127,7 +127,7 @@ class DATestPersona(BaseModel):
|
||||
llm_model_version_override: str | None
|
||||
users: list[str]
|
||||
groups: list[int]
|
||||
category_id: int | None = None
|
||||
label_ids: list[int]
|
||||
|
||||
|
||||
#
|
||||
|
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
This file tests the permissions for creating and editing personas for different user roles:
|
||||
- Basic users can create personas and edit their own
|
||||
- Curators can edit personas that belong exclusively to groups they curate
|
||||
- Admins can edit all personas
|
||||
"""
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.managers.user import DATestUser
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
def test_persona_permissions(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Creating a curator user
|
||||
curator: DATestUser = UserManager.create(name="curator")
|
||||
|
||||
# Creating a basic user
|
||||
basic_user: DATestUser = UserManager.create(name="basic_user")
|
||||
|
||||
# Creating user groups
|
||||
user_group_1 = UserGroupManager.create(
|
||||
name="curated_user_group",
|
||||
user_ids=[curator.id],
|
||||
cc_pair_ids=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[user_group_1], user_performing_action=admin_user
|
||||
)
|
||||
# Setting the user as a curator for the user group
|
||||
UserGroupManager.set_curator_status(
|
||||
test_user_group=user_group_1,
|
||||
user_to_set_as_curator=curator,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Creating another user group that the user is not a curator of
|
||||
user_group_2 = UserGroupManager.create(
|
||||
name="uncurated_user_group",
|
||||
user_ids=[curator.id],
|
||||
cc_pair_ids=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[user_group_2], user_performing_action=admin_user
|
||||
)
|
||||
|
||||
"""Test that any user can create a persona"""
|
||||
# Basic user creates a persona
|
||||
basic_user_persona = PersonaManager.create(
|
||||
name="basic_user_persona",
|
||||
description="A persona created by basic user",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
|
||||
|
||||
# Curator creates a persona
|
||||
curator_persona = PersonaManager.create(
|
||||
name="curator_persona",
|
||||
description="A persona created by curator",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
user_performing_action=curator,
|
||||
)
|
||||
PersonaManager.verify(curator_persona, user_performing_action=curator)
|
||||
|
||||
# Admin creates personas for different groups
|
||||
admin_persona_group_1 = PersonaManager.create(
|
||||
name="admin_persona_group_1",
|
||||
description="A persona for group 1",
|
||||
is_public=False,
|
||||
groups=[user_group_1.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
admin_persona_group_2 = PersonaManager.create(
|
||||
name="admin_persona_group_2",
|
||||
description="A persona for group 2",
|
||||
is_public=False,
|
||||
groups=[user_group_2.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
admin_persona_both_groups = PersonaManager.create(
|
||||
name="admin_persona_both_groups",
|
||||
description="A persona for both groups",
|
||||
is_public=False,
|
||||
groups=[user_group_1.id, user_group_2.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
"""Test that users can edit their own personas"""
|
||||
# Basic user can edit their own persona
|
||||
PersonaManager.edit(
|
||||
persona=basic_user_persona,
|
||||
description="Updated description by basic user",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
|
||||
|
||||
# Basic user cannot edit other's personas
|
||||
with pytest.raises(HTTPError):
|
||||
PersonaManager.edit(
|
||||
persona=curator_persona,
|
||||
description="Invalid edit by basic user",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
"""Test curator permissions"""
|
||||
# Curator can edit personas that belong exclusively to groups they curate
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_group_1,
|
||||
description="Updated by curator",
|
||||
user_performing_action=curator,
|
||||
)
|
||||
PersonaManager.verify(admin_persona_group_1, user_performing_action=curator)
|
||||
|
||||
# Curator cannot edit personas in groups they don't curate
|
||||
with pytest.raises(HTTPError):
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_group_2,
|
||||
description="Invalid edit by curator",
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curator cannot edit personas that belong to multiple groups, even if they curate one
|
||||
with pytest.raises(HTTPError):
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_both_groups,
|
||||
description="Invalid edit by curator",
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
"""Test admin permissions"""
|
||||
# Admin can edit any persona
|
||||
PersonaManager.edit(
|
||||
persona=basic_user_persona,
|
||||
description="Updated by admin",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(basic_user_persona, user_performing_action=admin_user)
|
||||
|
||||
PersonaManager.edit(
|
||||
persona=curator_persona,
|
||||
description="Updated by admin",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(curator_persona, user_performing_action=admin_user)
|
||||
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_group_1,
|
||||
description="Updated by admin",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(admin_persona_group_1, user_performing_action=admin_user)
|
||||
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_group_2,
|
||||
description="Updated by admin",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(admin_persona_group_2, user_performing_action=admin_user)
|
||||
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_both_groups,
|
||||
description="Updated by admin",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(admin_persona_both_groups, user_performing_action=admin_user)
|
Reference in New Issue
Block a user