diff --git a/backend/alembic/versions/47e5bef3a1d7_add_persona_categories.py b/backend/alembic/versions/47e5bef3a1d7_add_persona_categories.py
new file mode 100644
index 000000000..432e0ab42
--- /dev/null
+++ b/backend/alembic/versions/47e5bef3a1d7_add_persona_categories.py
@@ -0,0 +1,45 @@
+"""add persona categories
+
+Revision ID: 47e5bef3a1d7
+Revises: dfbe9e93d3c7
+Create Date: 2024-11-05 18:55:02.221064
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = "47e5bef3a1d7"
+down_revision = "dfbe9e93d3c7"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # Create the persona_category table
+ op.create_table(
+ "persona_category",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(), nullable=False),
+ sa.Column("description", sa.String(), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("name"),
+ )
+
+ # Add category_id to persona table
+ op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
+ op.create_foreign_key(
+ "fk_persona_category",
+ "persona",
+ "persona_category",
+ ["category_id"],
+ ["id"],
+ ondelete="SET NULL",
+ )
+
+
+def downgrade() -> None:
+ op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
+ op.drop_column("persona", "category_id")
+ op.drop_table("persona_category")
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 831b8eaf6..ec9e4f82e 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -1363,6 +1363,9 @@ class Persona(Base):
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
Enum(RecencyBiasSetting, native_enum=False)
)
+ category_id: Mapped[int | None] = mapped_column(
+ ForeignKey("persona_category.id"), nullable=True
+ )
# Allows the Persona to specify a different LLM version than is controlled
# globablly via env variables. For flexibility, validity is not currently enforced
# NOTE: only is applied on the actual response generation - is not used for things like
@@ -1434,6 +1437,9 @@ class Persona(Base):
secondary="persona__user_group",
viewonly=True,
)
+ category: Mapped["PersonaCategory"] = relationship(
+ "PersonaCategory", back_populates="personas"
+ )
# Default personas loaded via yaml cannot have the same name
__table_args__ = (
@@ -1446,6 +1452,17 @@ class Persona(Base):
)
+class PersonaCategory(Base):
+ __tablename__ = "persona_category"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str] = mapped_column(String, unique=True)
+ description: Mapped[str | None] = mapped_column(String, nullable=True)
+ personas: Mapped[list["Persona"]] = relationship(
+ "Persona", back_populates="category"
+ )
+
+
AllowedAnswerFilters = (
Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
)
diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py
index 364c2ccc1..6446e7382 100644
--- a/backend/danswer/db/persona.py
+++ b/backend/danswer/db/persona.py
@@ -26,6 +26,7 @@ from danswer.db.models import DocumentSet
from danswer.db.models import Persona
from danswer.db.models import Persona__User
from danswer.db.models import Persona__UserGroup
+from danswer.db.models import PersonaCategory
from danswer.db.models import Prompt
from danswer.db.models import StarterMessage
from danswer.db.models import Tool
@@ -417,6 +418,7 @@ def upsert_persona(
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool = False,
+ category_id: int | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
@@ -487,7 +489,7 @@ def upsert_persona(
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.is_default_persona = is_default_persona
-
+ persona.category_id = category_id
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@@ -528,6 +530,7 @@ def upsert_persona(
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona,
+ category_id=category_id,
)
db_session.add(persona)
@@ -744,3 +747,39 @@ def delete_persona_by_name(
db_session.execute(stmt)
db_session.commit()
+
+
+def get_assistant_categories(db_session: Session) -> list[PersonaCategory]:
+ return db_session.query(PersonaCategory).all()
+
+
+def create_assistant_category(
+ db_session: Session, name: str, description: str
+) -> PersonaCategory:
+ category = PersonaCategory(name=name, description=description)
+ db_session.add(category)
+ db_session.commit()
+ return category
+
+
+def update_persona_category(
+ category_id: int,
+ category_description: str,
+ category_name: str,
+ db_session: Session,
+) -> None:
+ persona_category = (
+ db_session.query(PersonaCategory)
+ .filter(PersonaCategory.id == category_id)
+ .one_or_none()
+ )
+ if persona_category is None:
+ raise ValueError(f"Persona category with ID {category_id} does not exist")
+ persona_category.description = category_description
+ persona_category.name = category_name
+ db_session.commit()
+
+
+def delete_persona_category(category_id: int, db_session: Session) -> None:
+ db_session.query(PersonaCategory).filter(PersonaCategory.id == category_id).delete()
+ db_session.commit()
diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py
index c7d3d79b1..c5cfa07ad 100644
--- a/backend/danswer/server/features/persona/api.py
+++ b/backend/danswer/server/features/persona/api.py
@@ -18,12 +18,16 @@ from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
+from danswer.db.persona import create_assistant_category
from danswer.db.persona import create_update_persona
+from danswer.db.persona import delete_persona_category
+from danswer.db.persona import get_assistant_categories
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import mark_persona_as_not_deleted
from danswer.db.persona import update_all_personas_display_priority
+from danswer.db.persona import update_persona_category
from danswer.db.persona import update_persona_public_status
from danswer.db.persona import update_persona_shared_users
from danswer.db.persona import update_persona_visibility
@@ -32,6 +36,8 @@ from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import ImageGenerationToolStatus
+from danswer.server.features.persona.models import PersonaCategoryCreate
+from danswer.server.features.persona.models import PersonaCategoryResponse
from danswer.server.features.persona.models import PersonaSharedNotificationData
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
@@ -39,6 +45,7 @@ from danswer.server.models import DisplayPriorityRequest
from danswer.tools.utils import is_image_generation_available
from danswer.utils.logger import setup_logger
+
logger = setup_logger()
@@ -184,6 +191,59 @@ def update_persona(
)
+class PersonaCategoryPatchRequest(BaseModel):
+ category_description: str
+ category_name: str
+
+
+@basic_router.get("/categories")
+def get_categories(
+ db: Session = Depends(get_session),
+ _: User | None = Depends(current_user),
+) -> list[PersonaCategoryResponse]:
+ return [
+ PersonaCategoryResponse.from_model(category)
+ for category in get_assistant_categories(db_session=db)
+ ]
+
+
+@admin_router.post("/categories")
+def create_category(
+ category: PersonaCategoryCreate,
+ db: Session = Depends(get_session),
+ _: User | None = Depends(current_admin_user),
+) -> PersonaCategoryResponse:
+ """Create a new assistant category"""
+ category_model = create_assistant_category(
+ name=category.name, description=category.description, db_session=db
+ )
+ return PersonaCategoryResponse.from_model(category_model)
+
+
+@admin_router.patch("/category/{category_id}")
+def patch_persona_category(
+ category_id: int,
+ persona_category_patch_request: PersonaCategoryPatchRequest,
+ _: User | None = Depends(current_admin_user),
+ db_session: Session = Depends(get_session),
+) -> None:
+ update_persona_category(
+ category_id=category_id,
+ category_description=persona_category_patch_request.category_description,
+ category_name=persona_category_patch_request.category_name,
+ db_session=db_session,
+ )
+
+
+@admin_router.delete("/category/{category_id}")
+def delete_category(
+ category_id: int,
+ _: User | None = Depends(current_admin_user),
+ db_session: Session = Depends(get_session),
+) -> None:
+ delete_persona_category(category_id=category_id, db_session=db_session)
+
+
class PersonaShareRequest(BaseModel):
user_ids: list[UUID]
diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py
index 5fa99952b..4331d0640 100644
--- a/backend/danswer/server/features/persona/models.py
+++ b/backend/danswer/server/features/persona/models.py
@@ -5,6 +5,7 @@ from pydantic import BaseModel
from pydantic import Field
from danswer.db.models import Persona
+from danswer.db.models import PersonaCategory
from danswer.db.models import StarterMessage
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet
@@ -41,6 +42,7 @@ class CreatePersonaRequest(BaseModel):
is_default_persona: bool = False
display_priority: int | None = None
search_start_date: datetime | None = None
+ category_id: int | None = None
class PersonaSnapshot(BaseModel):
@@ -68,6 +70,7 @@ class PersonaSnapshot(BaseModel):
uploaded_image_id: str | None = None
is_default_persona: bool
search_start_date: datetime | None = None
+ category_id: int | None = None
@classmethod
def from_model(
@@ -115,6 +118,7 @@ class PersonaSnapshot(BaseModel):
icon_shape=persona.icon_shape,
uploaded_image_id=persona.uploaded_image_id,
search_start_date=persona.search_start_date,
+ category_id=persona.category_id,
)
@@ -128,3 +132,22 @@ class PersonaSharedNotificationData(BaseModel):
class ImageGenerationToolStatus(BaseModel):
is_available: bool
+
+
+class PersonaCategoryCreate(BaseModel):
+ name: str
+ description: str
+
+
+class PersonaCategoryResponse(BaseModel):
+ id: int
+ name: str
+ description: str | None
+
+ @classmethod
+ def from_model(cls, category: PersonaCategory) -> "PersonaCategoryResponse":
+ return PersonaCategoryResponse(
+ id=category.id,
+ name=category.name,
+ description=category.description,
+ )
diff --git a/backend/tests/integration/common_utils/managers/persona.py b/backend/tests/integration/common_utils/managers/persona.py
index 4e8e58224..147e6bf0b 100644
--- a/backend/tests/integration/common_utils/managers/persona.py
+++ b/backend/tests/integration/common_utils/managers/persona.py
@@ -7,6 +7,7 @@ from danswer.server.features.persona.models import PersonaSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestPersona
+from tests.integration.common_utils.test_models import DATestPersonaCategory
from tests.integration.common_utils.test_models import DATestUser
@@ -27,6 +28,7 @@ class PersonaManager:
llm_model_version_override: str | None = None,
users: list[str] | None = None,
groups: list[int] | None = None,
+ category_id: int | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
name = name or f"test-persona-{uuid4()}"
@@ -212,3 +214,83 @@ class PersonaManager:
else GENERAL_HEADERS,
)
return response.ok
+
+
+class PersonaCategoryManager:
+ @staticmethod
+ def create(
+ category: DATestPersonaCategory,
+ user_performing_action: DATestUser | None = None,
+ ) -> DATestPersonaCategory:
+ response = requests.post(
+ f"{API_SERVER_URL}/admin/persona/categories",
+ json={
+ "name": category.name,
+ "description": category.description,
+ },
+ headers=user_performing_action.headers
+ if user_performing_action
+ else GENERAL_HEADERS,
+ )
+ response.raise_for_status()
+ response_data = response.json()
+ category.id = response_data["id"]
+ return category
+
+ @staticmethod
+ def get_all(
+ user_performing_action: DATestUser | None = None,
+ ) -> list[DATestPersonaCategory]:
+ response = requests.get(
+ f"{API_SERVER_URL}/persona/categories",
+ headers=user_performing_action.headers
+ if user_performing_action
+ else GENERAL_HEADERS,
+ )
+ response.raise_for_status()
+ return [DATestPersonaCategory(**category) for category in response.json()]
+
+ @staticmethod
+ def update(
+ category: DATestPersonaCategory,
+ user_performing_action: DATestUser | None = None,
+ ) -> DATestPersonaCategory:
+ response = requests.patch(
+ f"{API_SERVER_URL}/admin/persona/category/{category.id}",
+ json={
+ "category_name": category.name,
+ "category_description": category.description,
+ },
+ headers=user_performing_action.headers
+ if user_performing_action
+ else GENERAL_HEADERS,
+ )
+ response.raise_for_status()
+ return category
+
+ @staticmethod
+ def delete(
+ category: DATestPersonaCategory,
+ user_performing_action: DATestUser | None = None,
+ ) -> bool:
+ response = requests.delete(
+ f"{API_SERVER_URL}/admin/persona/category/{category.id}",
+ headers=user_performing_action.headers
+ if user_performing_action
+ else GENERAL_HEADERS,
+ )
+ return response.ok
+
+ @staticmethod
+ def verify(
+ category: DATestPersonaCategory,
+ user_performing_action: DATestUser | None = None,
+ ) -> bool:
+ all_categories = PersonaCategoryManager.get_all(user_performing_action)
+ for fetched_category in all_categories:
+ if fetched_category.id == category.id:
+ return (
+ fetched_category.name == category.name
+ and fetched_category.description == category.description
+ )
+ return False
diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py
index 2beddeac1..16156d8aa 100644
--- a/backend/tests/integration/common_utils/test_models.py
+++ b/backend/tests/integration/common_utils/test_models.py
@@ -38,6 +38,12 @@ class DATestUser(BaseModel):
headers: dict
+class DATestPersonaCategory(BaseModel):
+ id: int | None = None
+ name: str
+ description: str | None
+
+
class DATestCredential(BaseModel):
id: int
name: str
@@ -119,6 +125,7 @@ class DATestPersona(BaseModel):
llm_model_version_override: str | None
users: list[str]
groups: list[int]
+ category_id: int | None = None
#
diff --git a/backend/tests/integration/tests/personas/test_persona_categories.py b/backend/tests/integration/tests/personas/test_persona_categories.py
new file mode 100644
index 000000000..fdd0e6458
--- /dev/null
+++ b/backend/tests/integration/tests/personas/test_persona_categories.py
@@ -0,0 +1,92 @@
+from uuid import uuid4
+
+import pytest
+from requests.exceptions import HTTPError
+
+from tests.integration.common_utils.managers.persona import (
+ PersonaCategoryManager,
+)
+from tests.integration.common_utils.managers.user import UserManager
+from tests.integration.common_utils.test_models import DATestPersonaCategory
+from tests.integration.common_utils.test_models import DATestUser
+
+
+def test_persona_category_management(reset: None) -> None:
+ admin_user: DATestUser = UserManager.create(name="admin_user")
+
+ persona_category = DATestPersonaCategory(
+ id=None,
+ name=f"Test Category {uuid4()}",
+ description="A description for test category",
+ )
+ persona_category = PersonaCategoryManager.create(
+ category=persona_category,
+ user_performing_action=admin_user,
+ )
+ print(
+ f"Created persona category {persona_category.name} with id {persona_category.id}"
+ )
+
+ assert PersonaCategoryManager.verify(
+ category=persona_category,
+ user_performing_action=admin_user,
+ ), "Persona category was not found after creation"
+
+ regular_user: DATestUser = UserManager.create(name="regular_user")
+
+ updated_persona_category = DATestPersonaCategory(
+ id=persona_category.id,
+ name=f"Updated {persona_category.name}",
+ description="An updated description",
+ )
+ with pytest.raises(HTTPError) as exc_info:
+ PersonaCategoryManager.update(
+ category=updated_persona_category,
+ user_performing_action=regular_user,
+ )
+ assert exc_info.value.response.status_code == 403
+
+ assert PersonaCategoryManager.verify(
+ category=persona_category,
+ user_performing_action=admin_user,
+ ), "Persona category should not have been updated by non-admin user"
+
+ result = PersonaCategoryManager.delete(
+ category=persona_category,
+ user_performing_action=regular_user,
+ )
+ assert (
+ result is False
+ ), "Regular user should not be able to delete the persona category"
+
+ assert PersonaCategoryManager.verify(
+ category=persona_category,
+ user_performing_action=admin_user,
+ ), "Persona category should not have been deleted by non-admin user"
+
+ updated_persona_category.name = f"Updated {persona_category.name}"
+ updated_persona_category.description = "An updated description"
+ updated_persona_category = PersonaCategoryManager.update(
+ category=updated_persona_category,
+ user_performing_action=admin_user,
+ )
+ print(f"Updated persona category to {updated_persona_category.name}")
+
+ assert PersonaCategoryManager.verify(
+ category=updated_persona_category,
+ user_performing_action=admin_user,
+ ), "Persona category was not updated by admin"
+
+ success = PersonaCategoryManager.delete(
+ category=persona_category,
+ user_performing_action=admin_user,
+ )
+ assert success, "Admin user should be able to delete the persona category"
+ print(
+ f"Deleted persona category {persona_category.name} with id {persona_category.id}"
+ )
+
+ assert not PersonaCategoryManager.verify(
+ category=persona_category,
+ user_performing_action=admin_user,
+ ), "Persona category should not exist after deletion by admin"
diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx
index a4ed2fea2..6fe9301a7 100644
--- a/web/src/app/admin/assistants/AssistantEditor.tsx
+++ b/web/src/app/admin/assistants/AssistantEditor.tsx
@@ -23,8 +23,16 @@ import {
SelectorFormField,
TextFormField,
} from "@/components/admin/connectors/Field";
+
+import {
+ Card,
+ CardHeader,
+ CardTitle,
+ CardContent,
+ CardFooter,
+} from "@/components/ui/card";
import { usePopup } from "@/components/admin/connectors/Popup";
-import { getDisplayNameForModel } from "@/lib/hooks";
+import { getDisplayNameForModel, useCategories } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { Option } from "@/components/Dropdown";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
@@ -46,8 +54,14 @@ import * as Yup from "yup";
import { FullLLMProvider } from "../configuration/llm/interfaces";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
-import { Persona, StarterMessage } from "./interfaces";
-import { createPersona, updatePersona } from "./lib";
+import { Persona, PersonaCategory, StarterMessage } from "./interfaces";
+import {
+ createPersonaCategory,
+ createPersona,
+ deletePersonaCategory,
+ updatePersonaCategory,
+ updatePersona,
+} from "./lib";
import { Popover } from "@/components/popover/Popover";
import {
CameraIcon,
@@ -59,6 +73,8 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { buildImgUrl } from "@/app/chat/files/images/utils";
import { LlmList } from "@/components/llm/LLMList";
import { useAssistants } from "@/components/context/AssistantsContext";
+import { Input } from "@/components/ui/input";
+import { CategoryCard } from "./CategoryCard";
function findSearchTool(tools: ToolSnapshot[]) {
return tools.find((tool) => tool.in_code_tool_id === "SearchTool");
@@ -107,6 +123,7 @@ export function AssistantEditor({
const router = useRouter();
const { popup, setPopup } = usePopup();
+ const { data: categories, refreshCategories } = useCategories();
const colorOptions = [
"#FF6FBF",
@@ -119,6 +136,7 @@ export function AssistantEditor({
];
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
+ const [showPersonaCategory, setShowPersonaCategory] = useState(!admin);
// state to persist across formik reformatting
const [defautIconColor, _setDeafultIconColor] = useState(
@@ -211,6 +229,7 @@ export function AssistantEditor({
icon_color: existingPersona?.icon_color ?? defautIconColor,
icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
uploaded_image: null,
+ category_id: existingPersona?.category_id ?? null,
// EE Only
groups: existingPersona?.groups ?? [],
@@ -255,6 +274,7 @@ export function AssistantEditor({
icon_color: Yup.string(),
icon_shape: Yup.number(),
uploaded_image: Yup.mixed().nullable(),
+ category_id: Yup.number().nullable(),
// EE Only
groups: Yup.array().of(Yup.number()),
})
@@ -968,6 +988,189 @@ export function AssistantEditor({
)}
+
+ {admin && (
+
{personaCategory.description}
+{assistant.description}
Author: {assistant.owner?.email || "Danswer"}
@@ -137,6 +153,16 @@ export function AssistantGalleryCard({
{assistant.tools.length > 0 && (