diff --git a/backend/alembic/versions/47e5bef3a1d7_add_persona_categories.py b/backend/alembic/versions/47e5bef3a1d7_add_persona_categories.py new file mode 100644 index 000000000..432e0ab42 --- /dev/null +++ b/backend/alembic/versions/47e5bef3a1d7_add_persona_categories.py @@ -0,0 +1,45 @@ +"""add persona categories + +Revision ID: 47e5bef3a1d7 +Revises: dfbe9e93d3c7 +Create Date: 2024-11-05 18:55:02.221064 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "47e5bef3a1d7" +down_revision = "dfbe9e93d3c7" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Create the persona_category table + op.create_table( + "persona_category", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), + ) + + # Add category_id to persona table + op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True)) + op.create_foreign_key( + "fk_persona_category", + "persona", + "persona_category", + ["category_id"], + ["id"], + ondelete="SET NULL", + ) + + +def downgrade() -> None: + op.drop_constraint("fk_persona_category", "persona", type_="foreignkey") + op.drop_column("persona", "category_id") + op.drop_table("persona_category") diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 831b8eaf6..ec9e4f82e 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -1363,6 +1363,9 @@ class Persona(Base): recency_bias: Mapped[RecencyBiasSetting] = mapped_column( Enum(RecencyBiasSetting, native_enum=False) ) + category_id: Mapped[int | None] = mapped_column( + ForeignKey("persona_category.id"), nullable=True + ) # Allows the Persona to specify a different LLM version than is controlled # globablly via env variables. For flexibility, validity is not currently enforced # NOTE: only is applied on the actual response generation - is not used for things like @@ -1434,6 +1437,9 @@ class Persona(Base): secondary="persona__user_group", viewonly=True, ) + category: Mapped["PersonaCategory"] = relationship( + "PersonaCategory", back_populates="personas" + ) # Default personas loaded via yaml cannot have the same name __table_args__ = ( @@ -1446,6 +1452,17 @@ class Persona(Base): ) +class PersonaCategory(Base): + __tablename__ = "persona_category" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String, unique=True) + description: Mapped[str | None] = mapped_column(String, nullable=True) + personas: Mapped[list["Persona"]] = relationship( + "Persona", back_populates="category" + ) + + AllowedAnswerFilters = ( Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"] ) diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index 364c2ccc1..6446e7382 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -26,6 +26,7 @@ from danswer.db.models import DocumentSet from danswer.db.models import Persona from danswer.db.models import Persona__User from danswer.db.models import Persona__UserGroup +from danswer.db.models import PersonaCategory from danswer.db.models import Prompt from danswer.db.models import StarterMessage from danswer.db.models import Tool @@ -417,6 +418,7 @@ def upsert_persona( search_start_date: datetime | None = None, builtin_persona: bool = False, is_default_persona: bool = False, + category_id: int | None = None, chunks_above: int = CONTEXT_CHUNKS_ABOVE, chunks_below: int = CONTEXT_CHUNKS_BELOW, ) -> Persona: @@ -487,7 +489,7 @@ def upsert_persona( persona.is_visible = is_visible persona.search_start_date = search_start_date persona.is_default_persona = is_default_persona - + persona.category_id = category_id # Do not delete any associations manually added unless # a new updated list is provided if document_sets is not None: @@ -528,6 +530,7 @@ def upsert_persona( is_visible=is_visible, search_start_date=search_start_date, is_default_persona=is_default_persona, + category_id=category_id, ) db_session.add(persona) @@ -744,3 +747,39 @@ def delete_persona_by_name( db_session.execute(stmt) db_session.commit() + + +def get_assistant_categories(db_session: Session) -> list[PersonaCategory]: + return db_session.query(PersonaCategory).all() + + +def create_assistant_category( + db_session: Session, name: str, description: str +) -> PersonaCategory: + category = PersonaCategory(name=name, description=description) + db_session.add(category) + db_session.commit() + return category + + +def update_persona_category( + category_id: int, + category_description: str, + category_name: str, + db_session: Session, +) -> None: + persona_category = ( + db_session.query(PersonaCategory) + .filter(PersonaCategory.id == category_id) + .one_or_none() + ) + if persona_category is None: + raise ValueError(f"Persona category with ID {category_id} does not exist") + persona_category.description = category_description + persona_category.name = category_name + db_session.commit() + + +def delete_persona_category(category_id: int, db_session: Session) -> None: + db_session.query(PersonaCategory).filter(PersonaCategory.id == category_id).delete() + db_session.commit() diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index c7d3d79b1..c5cfa07ad 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -18,12 +18,16 @@ from danswer.configs.constants import NotificationType from danswer.db.engine import get_session from danswer.db.models import User from danswer.db.notification import create_notification +from danswer.db.persona import create_assistant_category from danswer.db.persona import create_update_persona +from danswer.db.persona import delete_persona_category +from danswer.db.persona import get_assistant_categories from danswer.db.persona import get_persona_by_id from danswer.db.persona import get_personas from danswer.db.persona import mark_persona_as_deleted from danswer.db.persona import mark_persona_as_not_deleted from danswer.db.persona import update_all_personas_display_priority +from danswer.db.persona import update_persona_category from danswer.db.persona import update_persona_public_status from danswer.db.persona import update_persona_shared_users from danswer.db.persona import update_persona_visibility @@ -32,6 +36,8 @@ from danswer.file_store.models import ChatFileType from danswer.llm.answering.prompts.utils import build_dummy_prompt from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import ImageGenerationToolStatus +from danswer.server.features.persona.models import PersonaCategoryCreate +from danswer.server.features.persona.models import PersonaCategoryResponse from danswer.server.features.persona.models import PersonaSharedNotificationData from danswer.server.features.persona.models import PersonaSnapshot from danswer.server.features.persona.models import PromptTemplateResponse @@ -39,6 +45,7 @@ from danswer.server.models import DisplayPriorityRequest from danswer.tools.utils import is_image_generation_available from danswer.utils.logger import setup_logger + logger = setup_logger() @@ -184,6 +191,59 @@ def update_persona( ) +class PersonaCategoryPatchRequest(BaseModel): + category_description: str + category_name: str + + +@basic_router.get("/categories") +def get_categories( + db: Session = Depends(get_session), + _: User | None = Depends(current_user), +) -> list[PersonaCategoryResponse]: + return [ + PersonaCategoryResponse.from_model(category) + for category in get_assistant_categories(db_session=db) + ] + + +@admin_router.post("/categories") +def create_category( + category: PersonaCategoryCreate, + db: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> PersonaCategoryResponse: + """Create a new assistant category""" + category_model = create_assistant_category( + name=category.name, description=category.description, db_session=db + ) + return PersonaCategoryResponse.from_model(category_model) + + +@admin_router.patch("/category/{category_id}") +def patch_persona_category( + category_id: int, + persona_category_patch_request: PersonaCategoryPatchRequest, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + update_persona_category( + category_id=category_id, + category_description=persona_category_patch_request.category_description, + category_name=persona_category_patch_request.category_name, + db_session=db_session, + ) + + +@admin_router.delete("/category/{category_id}") +def delete_category( + category_id: int, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + delete_persona_category(category_id=category_id, db_session=db_session) + + class PersonaShareRequest(BaseModel): user_ids: list[UUID] diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 5fa99952b..4331d0640 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from pydantic import Field from danswer.db.models import Persona +from danswer.db.models import PersonaCategory from danswer.db.models import StarterMessage from danswer.search.enums import RecencyBiasSetting from danswer.server.features.document_set.models import DocumentSet @@ -41,6 +42,7 @@ class CreatePersonaRequest(BaseModel): is_default_persona: bool = False display_priority: int | None = None search_start_date: datetime | None = None + category_id: int | None = None class PersonaSnapshot(BaseModel): @@ -68,6 +70,7 @@ class PersonaSnapshot(BaseModel): uploaded_image_id: str | None = None is_default_persona: bool search_start_date: datetime | None = None + category_id: int | None = None @classmethod def from_model( @@ -115,6 +118,7 @@ class PersonaSnapshot(BaseModel): icon_shape=persona.icon_shape, uploaded_image_id=persona.uploaded_image_id, search_start_date=persona.search_start_date, + category_id=persona.category_id, ) @@ -128,3 +132,22 @@ class PersonaSharedNotificationData(BaseModel): class ImageGenerationToolStatus(BaseModel): is_available: bool + + +class PersonaCategoryCreate(BaseModel): + name: str + description: str + + +class PersonaCategoryResponse(BaseModel): + id: int + name: str + description: str | None + + @classmethod + def from_model(cls, category: PersonaCategory) -> "PersonaCategoryResponse": + return PersonaCategoryResponse( + id=category.id, + name=category.name, + description=category.description, + ) diff --git a/backend/tests/integration/common_utils/managers/persona.py b/backend/tests/integration/common_utils/managers/persona.py index 4e8e58224..147e6bf0b 100644 --- a/backend/tests/integration/common_utils/managers/persona.py +++ b/backend/tests/integration/common_utils/managers/persona.py @@ -7,6 +7,7 @@ from danswer.server.features.persona.models import PersonaSnapshot from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.test_models import DATestPersona +from tests.integration.common_utils.test_models import DATestPersonaCategory from tests.integration.common_utils.test_models import DATestUser @@ -27,6 +28,7 @@ class PersonaManager: llm_model_version_override: str | None = None, users: list[str] | None = None, groups: list[int] | None = None, + category_id: int | None = None, user_performing_action: DATestUser | None = None, ) -> DATestPersona: name = name or f"test-persona-{uuid4()}" @@ -212,3 +214,83 @@ class PersonaManager: else GENERAL_HEADERS, ) return response.ok + + +class PersonaCategoryManager: + @staticmethod + def create( + category: DATestPersonaCategory, + user_performing_action: DATestUser | None = None, + ) -> DATestPersonaCategory: + response = requests.post( + f"{API_SERVER_URL}/admin/persona/categories", + json={ + "name": category.name, + "description": category.description, + }, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + response_data = response.json() + category.id = response_data["id"] + return category + + @staticmethod + def get_all( + user_performing_action: DATestUser | None = None, + ) -> list[DATestPersonaCategory]: + response = requests.get( + f"{API_SERVER_URL}/persona/categories", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [DATestPersonaCategory(**category) for category in response.json()] + + @staticmethod + def update( + category: DATestPersonaCategory, + user_performing_action: DATestUser | None = None, + ) -> DATestPersonaCategory: + response = requests.patch( + f"{API_SERVER_URL}/admin/persona/category/{category.id}", + json={ + "category_name": category.name, + "category_description": category.description, + }, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return category + + @staticmethod + def delete( + category: DATestPersonaCategory, + user_performing_action: DATestUser | None = None, + ) -> bool: + response = requests.delete( + f"{API_SERVER_URL}/admin/persona/category/{category.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + return response.ok + + @staticmethod + def verify( + category: DATestPersonaCategory, + user_performing_action: DATestUser | None = None, + ) -> bool: + all_categories = PersonaCategoryManager.get_all(user_performing_action) + for fetched_category in all_categories: + if fetched_category.id == category.id: + return ( + fetched_category.name == category.name + and fetched_category.description == category.description + ) + return False diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py index 2beddeac1..16156d8aa 100644 --- a/backend/tests/integration/common_utils/test_models.py +++ b/backend/tests/integration/common_utils/test_models.py @@ -38,6 +38,12 @@ class DATestUser(BaseModel): headers: dict +class DATestPersonaCategory(BaseModel): + id: int | None = None + name: str + description: str | None + + class DATestCredential(BaseModel): id: int name: str @@ -119,6 +125,7 @@ class DATestPersona(BaseModel): llm_model_version_override: str | None users: list[str] groups: list[int] + category_id: int | None = None # diff --git a/backend/tests/integration/tests/personas/test_persona_categories.py b/backend/tests/integration/tests/personas/test_persona_categories.py new file mode 100644 index 000000000..fdd0e6458 --- /dev/null +++ b/backend/tests/integration/tests/personas/test_persona_categories.py @@ -0,0 +1,92 @@ +from uuid import uuid4 + +import pytest +from requests.exceptions import HTTPError + +from tests.integration.common_utils.managers.persona import ( + PersonaCategoryManager, +) +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.test_models import DATestPersonaCategory +from tests.integration.common_utils.test_models import DATestUser + + +def test_persona_category_management(reset: None) -> None: + admin_user: DATestUser = UserManager.create(name="admin_user") + + persona_category = DATestPersonaCategory( + id=None, + name=f"Test Category {uuid4()}", + description="A description for test category", + ) + persona_category = PersonaCategoryManager.create( + category=persona_category, + user_performing_action=admin_user, + ) + print( + f"Created persona category {persona_category.name} with id {persona_category.id}" + ) + + assert PersonaCategoryManager.verify( + category=persona_category, + user_performing_action=admin_user, + ), "Persona category was not found after creation" + + regular_user: DATestUser = UserManager.create(name="regular_user") + + updated_persona_category = DATestPersonaCategory( + id=persona_category.id, + name=f"Updated {persona_category.name}", + description="An updated description", + ) + with pytest.raises(HTTPError) as exc_info: + PersonaCategoryManager.update( + category=updated_persona_category, + user_performing_action=regular_user, + ) + assert exc_info.value.response.status_code == 403 + + assert PersonaCategoryManager.verify( + category=persona_category, + user_performing_action=admin_user, + ), "Persona category should not have been updated by non-admin user" + + result = PersonaCategoryManager.delete( + category=persona_category, + user_performing_action=regular_user, + ) + assert ( + result is False + ), "Regular user should not be able to delete the persona category" + + assert PersonaCategoryManager.verify( + category=persona_category, + user_performing_action=admin_user, + ), "Persona category should not have been deleted by non-admin user" + + updated_persona_category.name = f"Updated {persona_category.name}" + updated_persona_category.description = "An updated description" + updated_persona_category = PersonaCategoryManager.update( + category=updated_persona_category, + user_performing_action=admin_user, + ) + print(f"Updated persona category to {updated_persona_category.name}") + + assert PersonaCategoryManager.verify( + category=updated_persona_category, + user_performing_action=admin_user, + ), "Persona category was not updated by admin" + + success = PersonaCategoryManager.delete( + category=persona_category, + user_performing_action=admin_user, + ) + assert success, "Admin user should be able to delete the persona category" + print( + f"Deleted persona category {persona_category.name} with id {persona_category.id}" + ) + + assert not PersonaCategoryManager.verify( + category=persona_category, + user_performing_action=admin_user, + ), "Persona category should not exist after deletion by admin" diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index a4ed2fea2..6fe9301a7 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -23,8 +23,16 @@ import { SelectorFormField, TextFormField, } from "@/components/admin/connectors/Field"; + +import { + Card, + CardHeader, + CardTitle, + CardContent, + CardFooter, +} from "@/components/ui/card"; import { usePopup } from "@/components/admin/connectors/Popup"; -import { getDisplayNameForModel } from "@/lib/hooks"; +import { getDisplayNameForModel, useCategories } from "@/lib/hooks"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; import { Option } from "@/components/Dropdown"; import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences"; @@ -46,8 +54,14 @@ import * as Yup from "yup"; import { FullLLMProvider } from "../configuration/llm/interfaces"; import CollapsibleSection from "./CollapsibleSection"; import { SuccessfulPersonaUpdateRedirectType } from "./enums"; -import { Persona, StarterMessage } from "./interfaces"; -import { createPersona, updatePersona } from "./lib"; +import { Persona, PersonaCategory, StarterMessage } from "./interfaces"; +import { + createPersonaCategory, + createPersona, + deletePersonaCategory, + updatePersonaCategory, + updatePersona, +} from "./lib"; import { Popover } from "@/components/popover/Popover"; import { CameraIcon, @@ -59,6 +73,8 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle"; import { buildImgUrl } from "@/app/chat/files/images/utils"; import { LlmList } from "@/components/llm/LLMList"; import { useAssistants } from "@/components/context/AssistantsContext"; +import { Input } from "@/components/ui/input"; +import { CategoryCard } from "./CategoryCard"; function findSearchTool(tools: ToolSnapshot[]) { return tools.find((tool) => tool.in_code_tool_id === "SearchTool"); @@ -107,6 +123,7 @@ export function AssistantEditor({ const router = useRouter(); const { popup, setPopup } = usePopup(); + const { data: categories, refreshCategories } = useCategories(); const colorOptions = [ "#FF6FBF", @@ -119,6 +136,7 @@ export function AssistantEditor({ ]; const [showAdvancedOptions, setShowAdvancedOptions] = useState(false); + const [showPersonaCategory, setShowPersonaCategory] = useState(!admin); // state to persist across formik reformatting const [defautIconColor, _setDeafultIconColor] = useState( @@ -211,6 +229,7 @@ export function AssistantEditor({ icon_color: existingPersona?.icon_color ?? defautIconColor, icon_shape: existingPersona?.icon_shape ?? defaultIconShape, uploaded_image: null, + category_id: existingPersona?.category_id ?? null, // EE Only groups: existingPersona?.groups ?? [], @@ -255,6 +274,7 @@ export function AssistantEditor({ icon_color: Yup.string(), icon_shape: Yup.number(), uploaded_image: Yup.mixed().nullable(), + category_id: Yup.number().nullable(), // EE Only groups: Yup.array().of(Yup.number()), }) @@ -968,6 +988,189 @@ export function AssistantEditor({ )} + + {admin && ( + + )} + + {showPersonaCategory && ( + <> + {categories && categories.length > 0 && ( +
+
+
+ Category +
+ + + + + + + Group similar assistants together by category + + + +
+ ({ + name: category.name, + value: category.id, + }))} + /> +
+ )} + + {admin && ( + <> +
+
+
+ Create New Category +
+ + + + + + + Create a new category to group similar + assistants together + + + +
+ +
+ + +
+ +
+
+
+ + {categories && categories.length > 0 && ( +
+
+
+ Manage categories +
+ + + + + + + Manage existing categories or create new ones + to group similar assistants + + + +
+
+ {categories && + categories.map((category: PersonaCategory) => ( + { + const response = + await updatePersonaCategory( + id, + name, + description + ); + if (response?.ok) { + setPopup({ + message: `Category "${name}" updated successfully`, + type: "success", + }); + } else { + setPopup({ + message: `Failed to update category - ${await response.text()}`, + type: "error", + }); + } + }} + onDelete={async (id) => { + const response = + await deletePersonaCategory(id); + if (response?.ok) { + setPopup({ + message: `Category deleted successfully`, + type: "success", + }); + } else { + setPopup({ + message: `Failed to delete category - ${await response.text()}`, + type: "error", + }); + } + }} + refreshCategories={refreshCategories} + /> + ))} +
+
+ )} + + )} + + )} + void; + onDelete: (id: number) => void; + refreshCategories: () => Promise; + setPopup: (popup: PopupSpec) => void; +} + +export function CategoryCard({ + category, + onUpdate, + onDelete, + refreshCategories, +}: CategoryCardProps) { + const [isEditing, setIsEditing] = useState(false); + const [name, setName] = useState(category.name); + const [description, setDescription] = useState(category.description); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + await onUpdate(category.id, name, description); + await refreshCategories(); + setIsEditing(false); + }; + const handleEdit = (e: React.MouseEvent) => { + e.preventDefault(); + setIsEditing(true); + }; + + return ( + + + + {isEditing ? ( + setName(e.target.value)} + className="text-lg font-semibold" + /> + ) : ( + {category.name} + )} + + + + {isEditing ? ( +