mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-18 13:51:46 +01:00
782 lines
27 KiB
Python
782 lines
27 KiB
Python
from collections.abc import Sequence
|
|
from datetime import datetime
|
|
from uuid import UUID
|
|
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import delete
|
|
from sqlalchemy import exists
|
|
from sqlalchemy import func
|
|
from sqlalchemy import not_
|
|
from sqlalchemy import Select
|
|
from sqlalchemy import select
|
|
from sqlalchemy import update
|
|
from sqlalchemy.orm import aliased
|
|
from sqlalchemy.orm import joinedload
|
|
from sqlalchemy.orm import selectinload
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.auth.schemas import UserRole
|
|
from onyx.configs.app_configs import DISABLE_AUTH
|
|
from onyx.configs.chat_configs import BING_API_KEY
|
|
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
|
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
|
from onyx.configs.constants import NotificationType
|
|
from onyx.context.search.enums import RecencyBiasSetting
|
|
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
|
from onyx.db.models import DocumentSet
|
|
from onyx.db.models import Persona
|
|
from onyx.db.models import Persona__User
|
|
from onyx.db.models import Persona__UserGroup
|
|
from onyx.db.models import PersonaLabel
|
|
from onyx.db.models import Prompt
|
|
from onyx.db.models import StarterMessage
|
|
from onyx.db.models import Tool
|
|
from onyx.db.models import User
|
|
from onyx.db.models import User__UserGroup
|
|
from onyx.db.models import UserGroup
|
|
from onyx.db.notification import create_notification
|
|
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
|
from onyx.server.features.persona.models import PersonaSnapshot
|
|
from onyx.server.features.persona.models import PersonaUpsertRequest
|
|
from onyx.utils.logger import setup_logger
|
|
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
|
|
|
# Module-level logger used by the functions in this file (e.g. create_update_persona).
logger = setup_logger()
|
|
|
|
|
|
def _add_user_filters(
    stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
    """Restrict a SELECT over `Persona` to the Personas `user` may use or edit.

    Access is resolved through the relation
    User -> User__UserGroup -> Persona__UserGroup -> Persona,
    plus direct sharing (Persona__User) and ownership (Persona.user_id).

    Args:
        stmt: a SELECT whose primary entity is `Persona`.
        user: the requesting user. `None` while DISABLE_AUTH is set is treated
            as an admin; `None` otherwise is an anonymous user.
        get_editable: if True, keep only Personas the user may edit; if False,
            keep every Persona the user may use.

    Returns:
        `stmt` unchanged for admins (or auth disabled); otherwise a DISTINCT,
        outer-joined, filtered version of it.
    """
    # If user is None and auth is disabled, assume the user is an admin
    if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
        return stmt

    # The outer joins below can yield one row per group / direct share, so
    # de-duplicate the Personas.
    stmt = stmt.distinct()
    Persona__UG = aliased(Persona__UserGroup)
    User__UG = aliased(User__UserGroup)

    # Join Persona -> Persona__UserGroup -> User__UserGroup (group membership)
    # and, separately, Persona -> Persona__User (direct shares) so both can be
    # tested in the WHERE clause below.
    stmt = (
        stmt.outerjoin(Persona__UG)
        .outerjoin(
            User__UserGroup,
            User__UserGroup.user_group_id == Persona__UG.user_group_id,
        )
        .outerjoin(
            Persona__User,
            Persona__User.persona_id == Persona.id,
        )
    )

    # Filter Personas by:
    # - the user is a member of a user_group attached to the Persona
    # - when editing, a CURATOR must additionally be a curator of that group
    # - when editing, Personas also attached to any group the user is not in
    #   (or, for a CURATOR, does not curate) are excluded
    # - when not editing, public & visible Personas and Personas shared
    #   directly with the user (Persona__User) are also included
    # - in every case, Personas owned by the user are included

    # If user is None, this is an anonymous user and we should only show public Personas
    if user is None:
        where_clause = Persona.is_public == True  # noqa: E712
        return stmt.where(where_clause)

    where_clause = User__UserGroup.user_id == user.id
    if user.role == UserRole.CURATOR and get_editable:
        where_clause &= User__UserGroup.is_curator == True  # noqa: E712
    if get_editable:
        # Groups the user belongs to (curated groups only, for CURATORs).
        user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
        if user.role == UserRole.CURATOR:
            user_groups = user_groups.where(User__UG.is_curator == True)  # noqa: E712

        # Reject the Persona if it is attached to any group outside that set.
        where_clause &= (
            ~exists()
            .where(Persona__UG.persona_id == Persona.id)
            .where(~Persona__UG.user_group_id.in_(user_groups))
            .correlate(Persona)
        )
    else:
        # Group the public persona conditions
        public_condition = (Persona.is_public == True) & (  # noqa: E712
            Persona.is_visible == True  # noqa: E712
        )

        where_clause |= public_condition
        where_clause |= Persona__User.user_id == user.id

    # Owners can always see / edit their own Personas.
    where_clause |= Persona.user_id == user.id

    return stmt.where(where_clause)
|
|
|
|
|
|
def fetch_persona_by_id_for_user(
    db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
) -> Persona:
    """Load one Persona by ID, enforcing `user`'s access rights.

    Access is decided by `_add_user_filters`; with `get_editable` the user must
    be allowed to edit the Persona, otherwise merely to use it.

    Raises:
        HTTPException: status 403 if the Persona does not exist or the user is
            not allowed to access it.
    """
    base_query = select(Persona).where(Persona.id == persona_id).distinct()
    access_query = _add_user_filters(
        stmt=base_query, user=user, get_editable=get_editable
    )

    found = db_session.scalars(access_query).one_or_none()
    if found:
        return found

    raise HTTPException(
        status_code=403,
        detail=f"Persona with ID {persona_id} does not exist or user is not authorized to access it",
    )
|
|
|
|
|
|
def get_best_persona_id_for_user(
    db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
    """Choose the Persona ID `user` should chat with (e.g. from a Slack bot).

    If `persona_id` is given and usable by `user`, it is returned unchanged.
    Otherwise fall back to the non-deleted Persona with the highest
    `display_priority` that `user` can use.

    Returns:
        The chosen Persona ID, or None if the user can use no Persona.
    """
    if persona_id is not None:
        stmt = select(Persona).where(Persona.id == persona_id).distinct()
        stmt = _add_user_filters(
            stmt=stmt,
            user=user,
            # We don't want to filter by editable here, we just want to see if the
            # persona is usable by the user
            get_editable=False,
        )
        persona = db_session.scalars(stmt).one_or_none()
        if persona:
            return persona.id

    # If the persona is not found, or the slack bot is using doc sets instead of personas,
    # we need to find the best persona for the user.
    # This is the persona with the highest display priority that the user has access to.
    stmt = (
        select(Persona)
        .where(Persona.deleted.is_(False))
        # Postgres sorts NULLs first under DESC; keep unprioritized personas last.
        .order_by(Persona.display_priority.desc().nulls_last())
        .distinct()
    )
    # Usability (not editability) is what matters, matching the branch above.
    stmt = _add_user_filters(stmt=stmt, user=user, get_editable=False)
    # Many personas usually match; take the top-ranked one. one_or_none() would
    # raise MultipleResultsFound as soon as a second row exists.
    persona = db_session.scalars(stmt).first()
    return persona.id if persona else None
|
|
|
|
|
|
def _get_persona_by_name(
    persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
    """Find a Persona by exact name, or return None.

    Admins can match any Persona; other users only match Personas they own.
    A `user` of None is treated as an admin (or auth disabled).
    NOTE(review): `scalar_one_or_none` raises if several Personas share the
    name within the searched scope — confirm names are unique there.
    """
    name_query = select(Persona).where(Persona.name == persona_name)

    # Non-admins are limited to their own Personas.
    if user and user.role != UserRole.ADMIN:
        name_query = name_query.where(Persona.user_id == user.id)

    return db_session.execute(name_query).scalar_one_or_none()
|
|
|
|
|
|
def make_persona_private(
    persona_id: int,
    user_ids: list[UUID] | None,
    group_ids: list[int] | None,
    db_session: Session,
) -> None:
    """Replace the set of users a Persona is directly shared with (MIT edition).

    When `user_ids` is not None:
    - every existing Persona__User row for the Persona is deleted;
    - one row is added per listed user;
    - each listed user gets a PERSONA_SHARED notification;
    - the session is committed.

    Raises:
        NotImplementedError: if `group_ids` is non-empty (group sharing is not
            supported by this implementation). This check runs after the
            user changes above have been committed.
    """
    if user_ids is not None:
        previous_shares = db_session.query(Persona__User).filter(
            Persona__User.persona_id == persona_id
        )
        previous_shares.delete(synchronize_session="fetch")

        for shared_user_id in user_ids:
            db_session.add(Persona__User(persona_id=persona_id, user_id=shared_user_id))

            payload = PersonaSharedNotificationData(
                persona_id=persona_id,
            ).model_dump()
            create_notification(
                user_id=shared_user_id,
                notif_type=NotificationType.PERSONA_SHARED,
                db_session=db_session,
                additional_data=payload,
            )

        db_session.commit()

    # May cause error if someone switches down to MIT from EE
    if group_ids:
        raise NotImplementedError("Onyx MIT does not support private Personas")
|
|
|
|
|
|
def create_update_persona(
    persona_id: int | None,
    create_persona_request: PersonaUpsertRequest,
    user: User | None,
    db_session: Session,
) -> PersonaSnapshot:
    """Create or update a Persona from an API request, then apply its sharing.

    Higher level function than upsert_persona, although either is valid to use.
    Steps:
    1. Validate the prompt IDs and the default-persona rules.
    2. Call `upsert_persona`.
    3. Hand user/group sharing to the versioned `make_persona_private`
       implementation returned by `fetch_versioned_implementation`.

    Raises:
        HTTPException: status 400 wrapping any ValueError raised on the way.
    """
    request = create_persona_request
    # Permission to actually use these is checked later
    try:
        prompt_ids = request.prompt_ids
        if not prompt_ids:
            raise ValueError("No prompt IDs provided")

        # Curators may keep editing a default persona but may not make one:
        # for them the flag becomes None, which upsert_persona treats as
        # "keep the existing value" (or False for a new persona).
        default_flag: bool | None = request.is_default_persona
        if request.is_default_persona:
            if not request.is_public:
                raise ValueError("Cannot make a default persona non public")
            if user:
                if user.role in (UserRole.CURATOR, UserRole.GLOBAL_CURATOR):
                    default_flag = None
                elif user.role != UserRole.ADMIN:
                    raise ValueError("Only admins can make a default persona")

        persona = upsert_persona(
            persona_id=persona_id,
            user=user,
            db_session=db_session,
            description=request.description,
            name=request.name,
            prompt_ids=prompt_ids,
            document_set_ids=request.document_set_ids,
            tool_ids=request.tool_ids,
            is_public=request.is_public,
            recency_bias=request.recency_bias,
            llm_model_provider_override=request.llm_model_provider_override,
            llm_model_version_override=request.llm_model_version_override,
            starter_messages=request.starter_messages,
            icon_color=request.icon_color,
            icon_shape=request.icon_shape,
            uploaded_image_id=request.uploaded_image_id,
            display_priority=request.display_priority,
            remove_image=request.remove_image,
            search_start_date=request.search_start_date,
            label_ids=request.label_ids,
            num_chunks=request.num_chunks,
            llm_relevance_filter=request.llm_relevance_filter,
            llm_filter_extraction=request.llm_filter_extraction,
            is_default_persona=default_flag,
        )

        # Privatize Persona
        apply_sharing = fetch_versioned_implementation(
            "onyx.db.persona", "make_persona_private"
        )
        apply_sharing(
            persona_id=persona.id,
            user_ids=request.users,
            group_ids=request.groups,
            db_session=db_session,
        )
    except ValueError as e:
        logger.exception("Failed to create persona")
        raise HTTPException(status_code=400, detail=str(e))

    return PersonaSnapshot.from_model(persona)
|
|
|
|
|
|
def update_persona_shared_users(
    persona_id: int,
    user_ids: list[UUID],
    user: User | None,
    db_session: Session,
) -> None:
    """Replace the users a private Persona is shared with.

    Simplified version of `create_update_persona` which only touches the
    accessibility rather than any of the logic (e.g. prompt, connected data
    sources, etc.).

    Raises:
        HTTPException: 403 (from fetch_persona_by_id_for_user) if `user` cannot
            edit the Persona; 400 if the Persona is public.
    """
    target = fetch_persona_by_id_for_user(
        db_session=db_session, persona_id=persona_id, user=user, get_editable=True
    )
    if target.is_public:
        raise HTTPException(status_code=400, detail="Cannot share public persona")

    # Privatize Persona
    apply_sharing = fetch_versioned_implementation(
        "onyx.db.persona", "make_persona_private"
    )
    apply_sharing(
        persona_id=persona_id,
        user_ids=user_ids,
        group_ids=None,
        db_session=db_session,
    )
|
|
|
|
|
|
def update_persona_public_status(
    persona_id: int,
    is_public: bool,
    db_session: Session,
    user: User | None,
) -> None:
    """Set a Persona's `is_public` flag and commit.

    The caller must be able to edit the Persona; non-admin users must also
    own it.

    Raises:
        HTTPException: 403 (from fetch_persona_by_id_for_user) if not editable.
        ValueError: if a non-admin user does not own the Persona.
    """
    target = fetch_persona_by_id_for_user(
        db_session=db_session, persona_id=persona_id, user=user, get_editable=True
    )

    # Only the owner (or an admin / no user) may flip visibility.
    if user and user.role != UserRole.ADMIN:
        if target.user_id != user.id:
            raise ValueError("You don't have permission to modify this persona")

    target.is_public = is_public
    db_session.commit()
|
|
|
|
|
|
def get_personas_for_user(
    # if user is `None` assume the user is an admin or auth is disabled
    user: User | None,
    db_session: Session,
    get_editable: bool = True,
    include_default: bool = True,
    include_slack_bot_personas: bool = False,
    include_deleted: bool = False,
    joinedload_all: bool = False,
) -> Sequence[Persona]:
    """List the Personas `user` can use (or, with `get_editable`, edit).

    Args:
        include_default: keep built-in Personas (`builtin_persona` true).
        include_slack_bot_personas: keep Personas whose name starts with
            SLACK_BOT_PERSONA_PREFIX.
        include_deleted: keep soft-deleted Personas.
        joinedload_all: eagerly load prompts, tools, document sets, groups,
            users and labels (using `selectinload`, despite the flag's name).
    """
    stmt = _add_user_filters(select(Persona), user, get_editable)

    # Each filter is applied only when its "include" flag is off.
    optional_filters = (
        (include_default, Persona.builtin_persona.is_(False)),
        (
            include_slack_bot_personas,
            not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)),
        ),
        (include_deleted, Persona.deleted.is_(False)),
    )
    for keep_everything, condition in optional_filters:
        if not keep_everything:
            stmt = stmt.where(condition)

    if joinedload_all:
        eager_relationships = (
            Persona.prompts,
            Persona.tools,
            Persona.document_sets,
            Persona.groups,
            Persona.users,
            Persona.labels,
        )
        stmt = stmt.options(*[selectinload(rel) for rel in eager_relationships])

    return db_session.scalars(stmt).all()
|
|
|
|
|
|
def get_personas(db_session: Session) -> Sequence[Persona]:
    """Return every non-deleted, non-Slack-bot Persona, regardless of owner."""
    live_personas_query = (
        select(Persona)
        .distinct()
        .where(
            not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)),
            Persona.deleted.is_(False),
        )
    )
    result = db_session.execute(live_personas_query)
    return result.unique().scalars().all()
|
|
|
|
|
|
def mark_persona_as_deleted(
    persona_id: int,
    user: User | None,
    db_session: Session,
) -> None:
    """Soft-delete a Persona (set `deleted = True`) and commit.

    Access is checked by `get_persona_by_id` with its default edit semantics;
    it raises ValueError if the Persona is missing or not accessible.
    """
    doomed = get_persona_by_id(
        persona_id=persona_id,
        user=user,
        db_session=db_session,
    )
    doomed.deleted = True
    db_session.commit()
|
|
|
|
|
|
def mark_persona_as_not_deleted(
    persona_id: int,
    user: User | None,
    db_session: Session,
) -> None:
    """Restore a soft-deleted Persona and commit.

    Raises:
        ValueError: if the Persona is missing or inaccessible (from
            get_persona_by_id), or if it is not currently deleted.
    """
    target = get_persona_by_id(
        persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
    )
    if not target.deleted:
        raise ValueError(f"Persona with ID {persona_id} is not deleted.")

    target.deleted = False
    db_session.commit()
|
|
|
|
|
|
def mark_delete_persona_by_name(
    persona_name: str, db_session: Session, is_default: bool = True
) -> None:
    """Soft-delete every Persona with this name whose `builtin_persona` equals `is_default`."""
    name_matches = Persona.name == persona_name
    kind_matches = Persona.builtin_persona == is_default
    soft_delete = update(Persona).where(name_matches, kind_matches).values(deleted=True)

    db_session.execute(soft_delete)
    db_session.commit()
|
|
|
|
|
|
def update_all_personas_display_priority(
    display_priority_map: dict[int, int],
    db_session: Session,
) -> None:
    """Update the display priority of every live Persona.

    "Live" means the Personas returned by `get_personas`: not deleted and
    not Slack-bot Personas.

    Args:
        display_priority_map: Persona ID -> new display priority. Its keys
            must be exactly the set of live Persona IDs.

    Raises:
        ValueError: if the keys differ from the live Persona IDs.
    """
    live_personas = get_personas(db_session=db_session)
    live_ids = {p.id for p in live_personas}
    if set(display_priority_map) != live_ids:
        raise ValueError("Invalid persona IDs provided")

    for p in live_personas:
        p.display_priority = display_priority_map[p.id]
    db_session.commit()
|
|
|
|
|
|
def upsert_persona(
    user: User | None,
    name: str,
    description: str,
    num_chunks: float,
    llm_relevance_filter: bool,
    llm_filter_extraction: bool,
    recency_bias: RecencyBiasSetting,
    llm_model_provider_override: str | None,
    llm_model_version_override: str | None,
    starter_messages: list[StarterMessage] | None,
    is_public: bool,
    db_session: Session,
    prompt_ids: list[int] | None = None,
    document_set_ids: list[int] | None = None,
    tool_ids: list[int] | None = None,
    persona_id: int | None = None,
    commit: bool = True,
    icon_color: str | None = None,
    icon_shape: int | None = None,
    uploaded_image_id: str | None = None,
    display_priority: int | None = None,
    is_visible: bool = True,
    remove_image: bool | None = None,
    search_start_date: datetime | None = None,
    builtin_persona: bool = False,
    is_default_persona: bool | None = None,
    label_ids: list[int] | None = None,
    chunks_above: int = CONTEXT_CHUNKS_ABOVE,
    chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
    """Create a Persona, or update the existing one found by ID or by name.

    The existing Persona is looked up by `persona_id` if given, else by `name`
    (see `_get_persona_by_name`). If one is found, `user` must be allowed to
    edit it (`fetch_persona_by_id_for_user` raises HTTPException otherwise).

    NOTE: when updating an existing Persona, some core options are preserved:
    - `builtin_persona` is never changed;
    - `display_priority` is only set if it is currently None;
    - `is_default_persona` is only changed when a non-None value is passed.

    Relationship lists (`document_set_ids`, `prompt_ids`, `tool_ids`) that are
    None leave the existing associations untouched on update.

    Args:
        commit: commit the session if True; otherwise only flush so the
            Persona gets an ID.

    Returns:
        The created or updated Persona.

    Raises:
        ValueError: if:
        - no requested tools or document sets exist;
        - `prompt_ids` matches no prompts;
        - a new Persona has no prompts;
        - a builtin Persona is updated with `builtin_persona=False`;
        - `validate_persona_tools` rejects a tool.
    """

    if persona_id is not None:
        existing_persona = db_session.query(Persona).filter_by(id=persona_id).first()
    else:
        existing_persona = _get_persona_by_name(
            persona_name=name, user=user, db_session=db_session
        )

    if existing_persona:
        # this checks if the user has permission to edit the persona
        # will raise an Exception if the user does not have permission
        existing_persona = fetch_persona_by_id_for_user(
            db_session=db_session,
            persona_id=existing_persona.id,
            user=user,
            get_editable=True,
        )
    # Fetch and attach tools by IDs
    tools = None
    if tool_ids is not None:
        tools = db_session.query(Tool).filter(Tool.id.in_(tool_ids)).all()
        # NOTE(review): only raises when *none* of the IDs exist; individual
        # missing IDs are silently dropped.
        if not tools and tool_ids:
            raise ValueError("Tools not found")

    # Fetch and attach document_sets by IDs
    document_sets = None
    if document_set_ids is not None:
        document_sets = (
            db_session.query(DocumentSet)
            .filter(DocumentSet.id.in_(document_set_ids))
            .all()
        )
        # NOTE(review): same partial-match behavior as tools above.
        if not document_sets and document_set_ids:
            raise ValueError("document_sets not found")

    # Fetch and attach prompts by IDs
    prompts = None
    if prompt_ids is not None:
        prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()

    # An explicit (possibly empty) prompt_ids list must resolve to >= 1 prompt.
    if prompts is not None and len(prompts) == 0:
        raise ValueError(
            f"Invalid Persona config, no valid prompts "
            f"specified. Specified IDs were: '{prompt_ids}'"
        )

    labels = None
    if label_ids is not None:
        labels = (
            db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
        )

    # ensure all specified tools are valid
    if tools:
        validate_persona_tools(tools)

    if existing_persona:
        # Built-in personas can only be updated through YAML configuration.
        # This ensures that core system personas are not modified unintentionally.
        if existing_persona.builtin_persona and not builtin_persona:
            raise ValueError("Cannot update builtin persona with non-builtin.")

        # The following update never touches `builtin_persona`. Display priority
        # is only filled in when unset (normally handled by the `display-priority`
        # endpoint), and `is_default_persona` only changes when a value is given.
        existing_persona.name = name
        existing_persona.description = description
        existing_persona.num_chunks = num_chunks
        existing_persona.chunks_above = chunks_above
        existing_persona.chunks_below = chunks_below
        existing_persona.llm_relevance_filter = llm_relevance_filter
        existing_persona.llm_filter_extraction = llm_filter_extraction
        existing_persona.recency_bias = recency_bias
        existing_persona.llm_model_provider_override = llm_model_provider_override
        existing_persona.llm_model_version_override = llm_model_version_override
        existing_persona.starter_messages = starter_messages
        existing_persona.deleted = False  # Un-delete if previously deleted
        existing_persona.is_public = is_public
        existing_persona.icon_color = icon_color
        existing_persona.icon_shape = icon_shape
        # Only overwrite the image when removing it or providing a new one.
        if remove_image or uploaded_image_id:
            existing_persona.uploaded_image_id = uploaded_image_id
        existing_persona.is_visible = is_visible
        existing_persona.search_start_date = search_start_date
        # NOTE(review): unlike document sets/prompts/tools below, labels are
        # replaced even when label_ids is None (cleared to []).
        existing_persona.labels = labels or []
        existing_persona.is_default_persona = (
            is_default_persona
            if is_default_persona is not None
            else existing_persona.is_default_persona
        )
        # Do not delete any associations manually added unless
        # a new updated list is provided
        if document_sets is not None:
            existing_persona.document_sets.clear()
            existing_persona.document_sets = document_sets or []

        if prompts is not None:
            existing_persona.prompts.clear()
            existing_persona.prompts = prompts

        if tools is not None:
            existing_persona.tools = tools or []

        # We should only update display priority if it is not already set
        if existing_persona.display_priority is None:
            existing_persona.display_priority = display_priority

        persona = existing_persona

    else:
        if not prompts:
            raise ValueError(
                "Invalid Persona config. "
                "Must specify at least one prompt for a new persona."
            )

        new_persona = Persona(
            id=persona_id,
            user_id=user.id if user else None,
            is_public=is_public,
            name=name,
            description=description,
            num_chunks=num_chunks,
            chunks_above=chunks_above,
            chunks_below=chunks_below,
            llm_relevance_filter=llm_relevance_filter,
            llm_filter_extraction=llm_filter_extraction,
            recency_bias=recency_bias,
            builtin_persona=builtin_persona,
            prompts=prompts,
            document_sets=document_sets or [],
            llm_model_provider_override=llm_model_provider_override,
            llm_model_version_override=llm_model_version_override,
            starter_messages=starter_messages,
            tools=tools or [],
            icon_shape=icon_shape,
            icon_color=icon_color,
            uploaded_image_id=uploaded_image_id,
            display_priority=display_priority,
            is_visible=is_visible,
            search_start_date=search_start_date,
            is_default_persona=is_default_persona
            if is_default_persona is not None
            else False,
            labels=labels or [],
        )
        db_session.add(new_persona)
        persona = new_persona
    if commit:
        db_session.commit()
    else:
        # flush the session so that the persona has an ID
        db_session.flush()

    return persona
|
|
|
|
|
|
def delete_old_default_personas(
    db_session: Session,
) -> None:
    """Soft-delete every builtin Persona with a positive ID and suffix its name with "_old".

    Note, this locks out the Summarize and Paraphrase personas for now.
    Need a more graceful fix later or those need to never have IDs.
    """
    retire_builtins = (
        update(Persona)
        .where(Persona.builtin_persona)
        .where(Persona.id > 0)
        .values(deleted=True, name=func.concat(Persona.name, "_old"))
    )
    db_session.execute(retire_builtins)
    db_session.commit()
|
|
|
|
|
|
def update_persona_is_default(
    persona_id: int,
    is_default: bool,
    db_session: Session,
    user: User | None = None,
) -> None:
    """Set a Persona's `is_default_persona` flag and commit.

    Also forces the Persona public (a private Persona is made public first).
    Raises HTTPException (403) if `user` cannot edit the Persona.
    """
    target = fetch_persona_by_id_for_user(
        db_session=db_session,
        persona_id=persona_id,
        user=user,
        get_editable=True,
    )

    if not target.is_public:
        target.is_public = True
    target.is_default_persona = is_default

    db_session.commit()
|
|
|
|
|
|
def update_persona_visibility(
    persona_id: int,
    is_visible: bool,
    db_session: Session,
    user: User | None = None,
) -> None:
    """Set a Persona's `is_visible` flag and commit.

    Raises HTTPException (403) if `user` cannot edit the Persona.
    """
    target = fetch_persona_by_id_for_user(
        db_session=db_session,
        persona_id=persona_id,
        user=user,
        get_editable=True,
    )
    target.is_visible = is_visible
    db_session.commit()
|
|
|
|
|
|
def validate_persona_tools(tools: list[Tool]) -> None:
    """Reject tool selections that cannot work in this deployment.

    Raises:
        ValueError: if "InternetSearchTool" is requested but BING_API_KEY is
            not configured.
    """
    wants_web_search = any(tool.name == "InternetSearchTool" for tool in tools)
    if wants_web_search and not BING_API_KEY:
        raise ValueError(
            "Bing API key not found, please contact your Onyx admin to get it added!"
        )
|
|
|
|
|
|
# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
# a direct mapping indicating whether a user has access to a specific persona?
def get_persona_by_id(
    persona_id: int,
    # if user is `None` assume the user is an admin or auth is disabled
    user: User | None,
    db_session: Session,
    include_deleted: bool = False,
    is_for_edit: bool = True,  # NOTE: assume true for safety
) -> Persona:
    """Fetch a Persona by ID, applying a role-based access check.

    - No user, or an ADMIN: any Persona with this ID.
    - Everyone else: Personas they own or with `user_id` None, plus:
      - not `is_for_edit`: Personas shared via one of their groups, shared
        with them directly, or public;
      - `is_for_edit` and GLOBAL_CURATOR: Personas of groups they belong to;
      - `is_for_edit` and CURATOR: Personas of groups they curate.

    Soft-deleted Personas are excluded unless `include_deleted` is set.

    Raises:
        ValueError: if no Persona matches (or `user` may not access it).
    """
    persona_stmt = (
        select(Persona)
        .distinct()
        # Outer-join groups, directly shared users and the groups' memberships
        # so the access conditions below can reference `User` and (presumably
        # via UserGroup.user_group_relationships) `User__UserGroup` — verify.
        .outerjoin(Persona.groups)
        .outerjoin(Persona.users)
        .outerjoin(UserGroup.user_group_relationships)
        .where(Persona.id == persona_id)
    )

    if not include_deleted:
        persona_stmt = persona_stmt.where(Persona.deleted.is_(False))

    # Admins (and "no user") bypass the access conditions entirely.
    if not user or user.role == UserRole.ADMIN:
        result = db_session.execute(persona_stmt)
        persona = result.scalar_one_or_none()
        if persona is None:
            raise ValueError(f"Persona with ID {persona_id} does not exist")
        return persona

    # or check if user owns persona
    or_conditions = Persona.user_id == user.id
    # allow access if persona user id is None
    # NOTE(review): this applies even when is_for_edit is True.
    or_conditions |= Persona.user_id == None  # noqa: E711
    if not is_for_edit:
        # if the user is in a group related to the persona
        or_conditions |= User__UserGroup.user_id == user.id
        # if the user is in the .users of the persona
        or_conditions |= User.id == user.id
        or_conditions |= Persona.is_public == True  # noqa: E712
    elif user.role == UserRole.GLOBAL_CURATOR:
        # global curators can edit personas for the groups they are in
        or_conditions |= User__UserGroup.user_id == user.id
    elif user.role == UserRole.CURATOR:
        # curators can edit personas for the groups they are curators of
        or_conditions |= (User__UserGroup.user_id == user.id) & (
            User__UserGroup.is_curator == True  # noqa: E712
        )

    persona_stmt = persona_stmt.where(or_conditions)
    result = db_session.execute(persona_stmt)
    persona = result.scalar_one_or_none()
    if persona is None:
        raise ValueError(
            f"Persona with ID {persona_id} does not exist or does not belong to user"
        )
    return persona
|
|
|
|
|
|
def get_personas_by_ids(
    persona_ids: list[int], db_session: Session
) -> Sequence[Persona]:
    """Return the Personas with the given IDs.

    Unsafe, can fetch personas from all users: no access check is applied
    and deleted Personas are included.
    """
    if not persona_ids:
        return []

    by_ids = select(Persona).where(Persona.id.in_(persona_ids))
    return db_session.scalars(by_ids).all()
|
|
|
|
|
|
def delete_persona_by_name(
    persona_name: str, db_session: Session, is_default: bool = True
) -> None:
    """Hard-delete every Persona with this name whose `builtin_persona` equals `is_default`."""
    name_matches = Persona.name == persona_name
    kind_matches = Persona.builtin_persona == is_default

    db_session.execute(delete(Persona).where(name_matches, kind_matches))
    db_session.commit()
|
|
|
|
|
|
def get_assistant_labels(db_session: Session) -> list[PersonaLabel]:
    """Return every PersonaLabel, unfiltered and in no particular order."""
    label_query = db_session.query(PersonaLabel)
    return label_query.all()
|
|
|
|
|
|
def create_assistant_label(db_session: Session, name: str) -> PersonaLabel:
    """Insert a new PersonaLabel called `name`, commit, and return it."""
    new_label = PersonaLabel(name=name)
    db_session.add(new_label)
    db_session.commit()
    return new_label
|
|
|
|
|
|
def update_persona_label(
    label_id: int,
    label_name: str,
    db_session: Session,
) -> None:
    """Rename the PersonaLabel `label_id` to `label_name` and commit.

    Raises:
        ValueError: if no label with that ID exists.
    """
    label = db_session.query(PersonaLabel).filter_by(id=label_id).one_or_none()
    if label is None:
        raise ValueError(f"Persona label with ID {label_id} does not exist")

    label.name = label_name
    db_session.commit()
|
|
|
|
|
|
def delete_persona_label(label_id: int, db_session: Session) -> None:
    """Delete the PersonaLabel with `label_id` (a no-op if absent) and commit."""
    matching_labels = db_session.query(PersonaLabel).filter_by(id=label_id)
    matching_labels.delete()
    db_session.commit()
|
|
|
|
|
|
def persona_has_search_tool(persona_id: int, db_session: Session) -> bool:
    """Report whether the Persona has a tool whose `in_code_tool_id` is "run_search".

    Raises:
        ValueError: if no Persona with `persona_id` exists.
    """
    persona = (
        db_session.query(Persona)
        .options(joinedload(Persona.tools))
        .filter_by(id=persona_id)
        .one_or_none()
    )
    if persona is None:
        raise ValueError(f"Persona with ID {persona_id} does not exist")

    attached_tool_ids = (tool.in_code_tool_id for tool in persona.tools)
    return "run_search" in attached_tool_ids
|