Add assistant categories (#3064)

* add assistant categories v1

* functionality finalized

* finalize

* update assistant category display

* nit

* add tests

* post rebase update

* minor update to tests

* update typing

* finalize

* typing

* nit

* alembic

* alembic (once again)
This commit is contained in:
pablodanswer 2024-11-18 12:33:48 -08:00 committed by GitHub
parent 33ee899408
commit a7d95661b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 919 additions and 31 deletions

View File

@ -0,0 +1,45 @@
"""add persona categories
Revision ID: 47e5bef3a1d7
Revises: dfbe9e93d3c7
Create Date: 2024-11-05 18:55:02.221064
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "47e5bef3a1d7"
down_revision = "dfbe9e93d3c7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the persona_category table
op.create_table(
"persona_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add category_id to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"fk_persona_category",
"persona",
"persona_category",
["category_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

View File

@ -1363,6 +1363,9 @@ class Persona(Base):
recency_bias: Mapped[RecencyBiasSetting] = mapped_column( recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
Enum(RecencyBiasSetting, native_enum=False) Enum(RecencyBiasSetting, native_enum=False)
) )
category_id: Mapped[int | None] = mapped_column(
ForeignKey("persona_category.id"), nullable=True
)
# Allows the Persona to specify a different LLM version than is controlled # Allows the Persona to specify a different LLM version than is controlled
# globablly via env variables. For flexibility, validity is not currently enforced # globablly via env variables. For flexibility, validity is not currently enforced
# NOTE: only is applied on the actual response generation - is not used for things like # NOTE: only is applied on the actual response generation - is not used for things like
@ -1434,6 +1437,9 @@ class Persona(Base):
secondary="persona__user_group", secondary="persona__user_group",
viewonly=True, viewonly=True,
) )
category: Mapped["PersonaCategory"] = relationship(
"PersonaCategory", back_populates="personas"
)
# Default personas loaded via yaml cannot have the same name # Default personas loaded via yaml cannot have the same name
__table_args__ = ( __table_args__ = (
@ -1446,6 +1452,17 @@ class Persona(Base):
) )
class PersonaCategory(Base):
__tablename__ = "persona_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
description: Mapped[str | None] = mapped_column(String, nullable=True)
personas: Mapped[list["Persona"]] = relationship(
"Persona", back_populates="category"
)
AllowedAnswerFilters = ( AllowedAnswerFilters = (
Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"] Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
) )

View File

@ -26,6 +26,7 @@ from danswer.db.models import DocumentSet
from danswer.db.models import Persona from danswer.db.models import Persona
from danswer.db.models import Persona__User from danswer.db.models import Persona__User
from danswer.db.models import Persona__UserGroup from danswer.db.models import Persona__UserGroup
from danswer.db.models import PersonaCategory
from danswer.db.models import Prompt from danswer.db.models import Prompt
from danswer.db.models import StarterMessage from danswer.db.models import StarterMessage
from danswer.db.models import Tool from danswer.db.models import Tool
@ -417,6 +418,7 @@ def upsert_persona(
search_start_date: datetime | None = None, search_start_date: datetime | None = None,
builtin_persona: bool = False, builtin_persona: bool = False,
is_default_persona: bool = False, is_default_persona: bool = False,
category_id: int | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE, chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW, chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona: ) -> Persona:
@ -487,7 +489,7 @@ def upsert_persona(
persona.is_visible = is_visible persona.is_visible = is_visible
persona.search_start_date = search_start_date persona.search_start_date = search_start_date
persona.is_default_persona = is_default_persona persona.is_default_persona = is_default_persona
persona.category_id = category_id
# Do not delete any associations manually added unless # Do not delete any associations manually added unless
# a new updated list is provided # a new updated list is provided
if document_sets is not None: if document_sets is not None:
@ -528,6 +530,7 @@ def upsert_persona(
is_visible=is_visible, is_visible=is_visible,
search_start_date=search_start_date, search_start_date=search_start_date,
is_default_persona=is_default_persona, is_default_persona=is_default_persona,
category_id=category_id,
) )
db_session.add(persona) db_session.add(persona)
@ -744,3 +747,39 @@ def delete_persona_by_name(
db_session.execute(stmt) db_session.execute(stmt)
db_session.commit() db_session.commit()
def get_assistant_categories(db_session: Session) -> list[PersonaCategory]:
return db_session.query(PersonaCategory).all()
def create_assistant_category(
db_session: Session, name: str, description: str
) -> PersonaCategory:
category = PersonaCategory(name=name, description=description)
db_session.add(category)
db_session.commit()
return category
def update_persona_category(
category_id: int,
category_description: str,
category_name: str,
db_session: Session,
) -> None:
persona_category = (
db_session.query(PersonaCategory)
.filter(PersonaCategory.id == category_id)
.one_or_none()
)
if persona_category is None:
raise ValueError(f"Persona category with ID {category_id} does not exist")
persona_category.description = category_description
persona_category.name = category_name
db_session.commit()
def delete_persona_category(category_id: int, db_session: Session) -> None:
db_session.query(PersonaCategory).filter(PersonaCategory.id == category_id).delete()
db_session.commit()

View File

@ -18,12 +18,16 @@ from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session from danswer.db.engine import get_session
from danswer.db.models import User from danswer.db.models import User
from danswer.db.notification import create_notification from danswer.db.notification import create_notification
from danswer.db.persona import create_assistant_category
from danswer.db.persona import create_update_persona from danswer.db.persona import create_update_persona
from danswer.db.persona import delete_persona_category
from danswer.db.persona import get_assistant_categories
from danswer.db.persona import get_persona_by_id from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas from danswer.db.persona import get_personas
from danswer.db.persona import mark_persona_as_deleted from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import mark_persona_as_not_deleted from danswer.db.persona import mark_persona_as_not_deleted
from danswer.db.persona import update_all_personas_display_priority from danswer.db.persona import update_all_personas_display_priority
from danswer.db.persona import update_persona_category
from danswer.db.persona import update_persona_public_status from danswer.db.persona import update_persona_public_status
from danswer.db.persona import update_persona_shared_users from danswer.db.persona import update_persona_shared_users
from danswer.db.persona import update_persona_visibility from danswer.db.persona import update_persona_visibility
@ -32,6 +36,8 @@ from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import ImageGenerationToolStatus from danswer.server.features.persona.models import ImageGenerationToolStatus
from danswer.server.features.persona.models import PersonaCategoryCreate
from danswer.server.features.persona.models import PersonaCategoryResponse
from danswer.server.features.persona.models import PersonaSharedNotificationData from danswer.server.features.persona.models import PersonaSharedNotificationData
from danswer.server.features.persona.models import PersonaSnapshot from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse from danswer.server.features.persona.models import PromptTemplateResponse
@ -39,6 +45,7 @@ from danswer.server.models import DisplayPriorityRequest
from danswer.tools.utils import is_image_generation_available from danswer.tools.utils import is_image_generation_available
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@ -184,6 +191,59 @@ def update_persona(
) )
class PersonaCategoryPatchRequest(BaseModel):
category_description: str
category_name: str
@basic_router.get("/categories")
def get_categories(
db: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> list[PersonaCategoryResponse]:
return [
PersonaCategoryResponse.from_model(category)
for category in get_assistant_categories(db_session=db)
]
@admin_router.post("/categories")
def create_category(
category: PersonaCategoryCreate,
db: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> PersonaCategoryResponse:
"""Create a new assistant category"""
category_model = create_assistant_category(
name=category.name, description=category.description, db_session=db
)
return PersonaCategoryResponse.from_model(category_model)
@admin_router.patch("/category/{category_id}")
def patch_persona_category(
category_id: int,
persona_category_patch_request: PersonaCategoryPatchRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_category(
category_id=category_id,
category_description=persona_category_patch_request.category_description,
category_name=persona_category_patch_request.category_name,
db_session=db_session,
)
@admin_router.delete("/category/{category_id}")
def delete_category(
category_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
delete_persona_category(category_id=category_id, db_session=db_session)
class PersonaShareRequest(BaseModel): class PersonaShareRequest(BaseModel):
user_ids: list[UUID] user_ids: list[UUID]

View File

@ -5,6 +5,7 @@ from pydantic import BaseModel
from pydantic import Field from pydantic import Field
from danswer.db.models import Persona from danswer.db.models import Persona
from danswer.db.models import PersonaCategory
from danswer.db.models import StarterMessage from danswer.db.models import StarterMessage
from danswer.search.enums import RecencyBiasSetting from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.document_set.models import DocumentSet
@ -41,6 +42,7 @@ class CreatePersonaRequest(BaseModel):
is_default_persona: bool = False is_default_persona: bool = False
display_priority: int | None = None display_priority: int | None = None
search_start_date: datetime | None = None search_start_date: datetime | None = None
category_id: int | None = None
class PersonaSnapshot(BaseModel): class PersonaSnapshot(BaseModel):
@ -68,6 +70,7 @@ class PersonaSnapshot(BaseModel):
uploaded_image_id: str | None = None uploaded_image_id: str | None = None
is_default_persona: bool is_default_persona: bool
search_start_date: datetime | None = None search_start_date: datetime | None = None
category_id: int | None = None
@classmethod @classmethod
def from_model( def from_model(
@ -115,6 +118,7 @@ class PersonaSnapshot(BaseModel):
icon_shape=persona.icon_shape, icon_shape=persona.icon_shape,
uploaded_image_id=persona.uploaded_image_id, uploaded_image_id=persona.uploaded_image_id,
search_start_date=persona.search_start_date, search_start_date=persona.search_start_date,
category_id=persona.category_id,
) )
@ -128,3 +132,22 @@ class PersonaSharedNotificationData(BaseModel):
class ImageGenerationToolStatus(BaseModel): class ImageGenerationToolStatus(BaseModel):
is_available: bool is_available: bool
class PersonaCategoryCreate(BaseModel):
name: str
description: str
class PersonaCategoryResponse(BaseModel):
id: int
name: str
description: str | None
@classmethod
def from_model(cls, category: PersonaCategory) -> "PersonaCategoryResponse":
return PersonaCategoryResponse(
id=category.id,
name=category.name,
description=category.description,
)

View File

@ -7,6 +7,7 @@ from danswer.server.features.persona.models import PersonaSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestPersona from tests.integration.common_utils.test_models import DATestPersona
from tests.integration.common_utils.test_models import DATestPersonaCategory
from tests.integration.common_utils.test_models import DATestUser from tests.integration.common_utils.test_models import DATestUser
@ -27,6 +28,7 @@ class PersonaManager:
llm_model_version_override: str | None = None, llm_model_version_override: str | None = None,
users: list[str] | None = None, users: list[str] | None = None,
groups: list[int] | None = None, groups: list[int] | None = None,
category_id: int | None = None,
user_performing_action: DATestUser | None = None, user_performing_action: DATestUser | None = None,
) -> DATestPersona: ) -> DATestPersona:
name = name or f"test-persona-{uuid4()}" name = name or f"test-persona-{uuid4()}"
@ -212,3 +214,83 @@ class PersonaManager:
else GENERAL_HEADERS, else GENERAL_HEADERS,
) )
return response.ok return response.ok
class PersonaCategoryManager:
@staticmethod
def create(
category: DATestPersonaCategory,
user_performing_action: DATestUser | None = None,
) -> DATestPersonaCategory:
response = requests.post(
f"{API_SERVER_URL}/admin/persona/categories",
json={
"name": category.name,
"description": category.description,
},
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
response_data = response.json()
category.id = response_data["id"]
return category
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[DATestPersonaCategory]:
response = requests.get(
f"{API_SERVER_URL}/persona/categories",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return [DATestPersonaCategory(**category) for category in response.json()]
@staticmethod
def update(
category: DATestPersonaCategory,
user_performing_action: DATestUser | None = None,
) -> DATestPersonaCategory:
response = requests.patch(
f"{API_SERVER_URL}/admin/persona/category/{category.id}",
json={
"category_name": category.name,
"category_description": category.description,
},
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return category
@staticmethod
def delete(
category: DATestPersonaCategory,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/admin/persona/category/{category.id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
return response.ok
@staticmethod
def verify(
category: DATestPersonaCategory,
user_performing_action: DATestUser | None = None,
) -> bool:
all_categories = PersonaCategoryManager.get_all(user_performing_action)
for fetched_category in all_categories:
if fetched_category.id == category.id:
return (
fetched_category.name == category.name
and fetched_category.description == category.description
)
return False

View File

@ -38,6 +38,12 @@ class DATestUser(BaseModel):
headers: dict headers: dict
class DATestPersonaCategory(BaseModel):
id: int | None = None
name: str
description: str | None
class DATestCredential(BaseModel): class DATestCredential(BaseModel):
id: int id: int
name: str name: str
@ -119,6 +125,7 @@ class DATestPersona(BaseModel):
llm_model_version_override: str | None llm_model_version_override: str | None
users: list[str] users: list[str]
groups: list[int] groups: list[int]
category_id: int | None = None
# #

View File

@ -0,0 +1,92 @@
from uuid import uuid4
import pytest
from requests.exceptions import HTTPError
from tests.integration.common_utils.managers.persona import (
PersonaCategoryManager,
)
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestPersonaCategory
from tests.integration.common_utils.test_models import DATestUser
def test_persona_category_management(reset: None) -> None:
admin_user: DATestUser = UserManager.create(name="admin_user")
persona_category = DATestPersonaCategory(
id=None,
name=f"Test Category {uuid4()}",
description="A description for test category",
)
persona_category = PersonaCategoryManager.create(
category=persona_category,
user_performing_action=admin_user,
)
print(
f"Created persona category {persona_category.name} with id {persona_category.id}"
)
assert PersonaCategoryManager.verify(
category=persona_category,
user_performing_action=admin_user,
), "Persona category was not found after creation"
regular_user: DATestUser = UserManager.create(name="regular_user")
updated_persona_category = DATestPersonaCategory(
id=persona_category.id,
name=f"Updated {persona_category.name}",
description="An updated description",
)
with pytest.raises(HTTPError) as exc_info:
PersonaCategoryManager.update(
category=updated_persona_category,
user_performing_action=regular_user,
)
assert exc_info.value.response.status_code == 403
assert PersonaCategoryManager.verify(
category=persona_category,
user_performing_action=admin_user,
), "Persona category should not have been updated by non-admin user"
result = PersonaCategoryManager.delete(
category=persona_category,
user_performing_action=regular_user,
)
assert (
result is False
), "Regular user should not be able to delete the persona category"
assert PersonaCategoryManager.verify(
category=persona_category,
user_performing_action=admin_user,
), "Persona category should not have been deleted by non-admin user"
updated_persona_category.name = f"Updated {persona_category.name}"
updated_persona_category.description = "An updated description"
updated_persona_category = PersonaCategoryManager.update(
category=updated_persona_category,
user_performing_action=admin_user,
)
print(f"Updated persona category to {updated_persona_category.name}")
assert PersonaCategoryManager.verify(
category=updated_persona_category,
user_performing_action=admin_user,
), "Persona category was not updated by admin"
success = PersonaCategoryManager.delete(
category=persona_category,
user_performing_action=admin_user,
)
assert success, "Admin user should be able to delete the persona category"
print(
f"Deleted persona category {persona_category.name} with id {persona_category.id}"
)
assert not PersonaCategoryManager.verify(
category=persona_category,
user_performing_action=admin_user,
), "Persona category should not exist after deletion by admin"

View File

@ -23,8 +23,16 @@ import {
SelectorFormField, SelectorFormField,
TextFormField, TextFormField,
} from "@/components/admin/connectors/Field"; } from "@/components/admin/connectors/Field";
import {
Card,
CardHeader,
CardTitle,
CardContent,
CardFooter,
} from "@/components/ui/card";
import { usePopup } from "@/components/admin/connectors/Popup"; import { usePopup } from "@/components/admin/connectors/Popup";
import { getDisplayNameForModel } from "@/lib/hooks"; import { getDisplayNameForModel, useCategories } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { Option } from "@/components/Dropdown"; import { Option } from "@/components/Dropdown";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences"; import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
@ -46,8 +54,14 @@ import * as Yup from "yup";
import { FullLLMProvider } from "../configuration/llm/interfaces"; import { FullLLMProvider } from "../configuration/llm/interfaces";
import CollapsibleSection from "./CollapsibleSection"; import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums"; import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { Persona, StarterMessage } from "./interfaces"; import { Persona, PersonaCategory, StarterMessage } from "./interfaces";
import { createPersona, updatePersona } from "./lib"; import {
createPersonaCategory,
createPersona,
deletePersonaCategory,
updatePersonaCategory,
updatePersona,
} from "./lib";
import { Popover } from "@/components/popover/Popover"; import { Popover } from "@/components/popover/Popover";
import { import {
CameraIcon, CameraIcon,
@ -59,6 +73,8 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { buildImgUrl } from "@/app/chat/files/images/utils"; import { buildImgUrl } from "@/app/chat/files/images/utils";
import { LlmList } from "@/components/llm/LLMList"; import { LlmList } from "@/components/llm/LLMList";
import { useAssistants } from "@/components/context/AssistantsContext"; import { useAssistants } from "@/components/context/AssistantsContext";
import { Input } from "@/components/ui/input";
import { CategoryCard } from "./CategoryCard";
function findSearchTool(tools: ToolSnapshot[]) { function findSearchTool(tools: ToolSnapshot[]) {
return tools.find((tool) => tool.in_code_tool_id === "SearchTool"); return tools.find((tool) => tool.in_code_tool_id === "SearchTool");
@ -107,6 +123,7 @@ export function AssistantEditor({
const router = useRouter(); const router = useRouter();
const { popup, setPopup } = usePopup(); const { popup, setPopup } = usePopup();
const { data: categories, refreshCategories } = useCategories();
const colorOptions = [ const colorOptions = [
"#FF6FBF", "#FF6FBF",
@ -119,6 +136,7 @@ export function AssistantEditor({
]; ];
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false); const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
const [showPersonaCategory, setShowPersonaCategory] = useState(!admin);
// state to persist across formik reformatting // state to persist across formik reformatting
const [defautIconColor, _setDeafultIconColor] = useState( const [defautIconColor, _setDeafultIconColor] = useState(
@ -211,6 +229,7 @@ export function AssistantEditor({
icon_color: existingPersona?.icon_color ?? defautIconColor, icon_color: existingPersona?.icon_color ?? defautIconColor,
icon_shape: existingPersona?.icon_shape ?? defaultIconShape, icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
uploaded_image: null, uploaded_image: null,
category_id: existingPersona?.category_id ?? null,
// EE Only // EE Only
groups: existingPersona?.groups ?? [], groups: existingPersona?.groups ?? [],
@ -255,6 +274,7 @@ export function AssistantEditor({
icon_color: Yup.string(), icon_color: Yup.string(),
icon_shape: Yup.number(), icon_shape: Yup.number(),
uploaded_image: Yup.mixed().nullable(), uploaded_image: Yup.mixed().nullable(),
category_id: Yup.number().nullable(),
// EE Only // EE Only
groups: Yup.array().of(Yup.number()), groups: Yup.array().of(Yup.number()),
}) })
@ -968,6 +988,189 @@ export function AssistantEditor({
)} )}
</div> </div>
</div> </div>
{admin && (
<AdvancedOptionsToggle
title="Categories"
showAdvancedOptions={showPersonaCategory}
setShowAdvancedOptions={setShowPersonaCategory}
/>
)}
{showPersonaCategory && (
<>
{categories && categories.length > 0 && (
<div className="my-2">
<div className="flex gap-x-2 items-center">
<div className="block font-medium text-base">
Category
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
Group similar assistants together by category
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<SelectorFormField
includeReset
name="category_id"
options={categories.map((category) => ({
name: category.name,
value: category.id,
}))}
/>
</div>
)}
{admin && (
<>
<div className="my-2">
<div className="flex gap-x-2 items-center mb-2">
<div className="block font-medium text-base">
Create New Category
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
Create a new category to group similar
assistants together
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<div className="grid grid-cols-[1fr,3fr,auto] gap-4">
<TextFormField
fontSize="sm"
name="newCategoryName"
label="Category Name"
placeholder="e.g. Development"
/>
<TextFormField
fontSize="sm"
name="newCategoryDescription"
label="Category Description"
placeholder="e.g. Assistants for software development"
/>
<div className="flex items-end">
<Button
type="button"
onClick={async () => {
const name = values.newCategoryName;
const description =
values.newCategoryDescription;
if (!name || !description) return;
try {
const response = await createPersonaCategory(
name,
description
);
if (response.ok) {
setPopup({
message: `Category "${name}" created successfully`,
type: "success",
});
} else {
throw new Error(await response.text());
}
} catch (error) {
setPopup({
message: `Failed to create category - ${error}`,
type: "error",
});
}
await refreshCategories();
setFieldValue("newCategoryName", "");
setFieldValue("newCategoryDescription", "");
}}
>
Create
</Button>
</div>
</div>
</div>
{categories && categories.length > 0 && (
<div className="my-2 w-full">
<div className="flex gap-x-2 items-center mb-2">
<div className="block font-medium text-base">
Manage categories
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
Manage existing categories or create new ones
to group similar assistants
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<div className="gap-4 w-full flex-wrap flex">
{categories &&
categories.map((category: PersonaCategory) => (
<CategoryCard
setPopup={setPopup}
key={category.id}
category={category}
onUpdate={async (id, name, description) => {
const response =
await updatePersonaCategory(
id,
name,
description
);
if (response?.ok) {
setPopup({
message: `Category "${name}" updated successfully`,
type: "success",
});
} else {
setPopup({
message: `Failed to update category - ${await response.text()}`,
type: "error",
});
}
}}
onDelete={async (id) => {
const response =
await deletePersonaCategory(id);
if (response?.ok) {
setPopup({
message: `Category deleted successfully`,
type: "success",
});
} else {
setPopup({
message: `Failed to delete category - ${await response.text()}`,
type: "error",
});
}
}}
refreshCategories={refreshCategories}
/>
))}
</div>
</div>
)}
</>
)}
</>
)}
<Separator /> <Separator />
<AdvancedOptionsToggle <AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions} showAdvancedOptions={showAdvancedOptions}

View File

@ -0,0 +1,105 @@
import { useState } from "react";
import {
Card,
CardHeader,
CardTitle,
CardContent,
CardFooter,
} from "@/components/ui/card";
import { Input } from "@/components/ui/input";
import { Textarea } from "@/components/ui/textarea";
import { Button } from "@/components/ui/button";
import { PersonaCategory } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
interface CategoryCardProps {
category: PersonaCategory;
onUpdate: (id: number, name: string, description: string) => void;
onDelete: (id: number) => void;
refreshCategories: () => Promise<void>;
setPopup: (popup: PopupSpec) => void;
}
export function CategoryCard({
category,
onUpdate,
onDelete,
refreshCategories,
}: CategoryCardProps) {
const [isEditing, setIsEditing] = useState(false);
const [name, setName] = useState(category.name);
const [description, setDescription] = useState(category.description);
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
await onUpdate(category.id, name, description);
await refreshCategories();
setIsEditing(false);
};
const handleEdit = (e: React.MouseEvent) => {
e.preventDefault();
setIsEditing(true);
};
return (
<Card key={category.id} className="w-full max-w-sm">
<CardHeader className="w-full">
<CardTitle className="text-2xl font-bold">
{isEditing ? (
<Input
value={name}
onChange={(e) => setName(e.target.value)}
className="text-lg font-semibold"
/>
) : (
<span>{category.name}</span>
)}
</CardTitle>
</CardHeader>
<CardContent className="w-full">
{isEditing ? (
<Textarea
value={description}
onChange={(e) => setDescription(e.target.value)}
className="resize-none w-full"
/>
) : (
<p className="text-sm text-gray-600">{category.description}</p>
)}
</CardContent>
<CardFooter className="flex justify-end space-x-2">
{isEditing ? (
<>
<Button type="button" variant="outline" onClick={handleSubmit}>
Save
</Button>
<Button
type="button"
onClick={() => setIsEditing(false)}
variant="default"
>
Cancel
</Button>
</>
) : (
<>
<Button type="button" onClick={handleEdit} variant="outline">
Edit
</Button>
<Button
type="button"
variant="destructive"
onClick={async (e) => {
e.preventDefault();
await onDelete(category.id);
await refreshCategories();
}}
>
Delete
</Button>
</>
)}
</CardFooter>
</Card>
);
}

View File

@ -42,4 +42,11 @@ export interface Persona {
icon_shape?: number; icon_shape?: number;
icon_color?: string; icon_color?: string;
uploaded_image_id?: string; uploaded_image_id?: string;
category_id?: number | null;
}
export interface PersonaCategory {
id: number;
name: string;
description: string;
} }

View File

@ -23,6 +23,7 @@ interface PersonaCreationRequest {
uploaded_image: File | null; uploaded_image: File | null;
search_start_date: Date | null; search_start_date: Date | null;
is_default_persona: boolean; is_default_persona: boolean;
category_id: number | null;
} }
interface PersonaUpdateRequest { interface PersonaUpdateRequest {
@ -48,6 +49,7 @@ interface PersonaUpdateRequest {
remove_image: boolean; remove_image: boolean;
uploaded_image: File | null; uploaded_image: File | null;
search_start_date: Date | null; search_start_date: Date | null;
category_id: number | null;
} }
function promptNameFromPersonaName(personaName: string) { function promptNameFromPersonaName(personaName: string) {
@ -108,6 +110,42 @@ function updatePrompt({
}); });
} }
export const createPersonaCategory = (name: string, description: string) => {
return fetch("/api/admin/persona/categories", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ name, description }),
});
};
export const deletePersonaCategory = (categoryId: number) => {
return fetch(`/api/admin/persona/category/${categoryId}`, {
method: "DELETE",
headers: {
"Content-Type": "application/json",
},
});
};
export const updatePersonaCategory = (
id: number,
name: string,
description: string
) => {
return fetch(`/api/admin/persona/category/${id}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
category_name: name,
category_description: description,
}),
});
};
function buildPersonaAPIBody( function buildPersonaAPIBody(
creationRequest: PersonaCreationRequest | PersonaUpdateRequest, creationRequest: PersonaCreationRequest | PersonaUpdateRequest,
promptId: number, promptId: number,
@ -127,6 +165,7 @@ function buildPersonaAPIBody(
icon_shape, icon_shape,
remove_image, remove_image,
search_start_date, search_start_date,
category_id,
} = creationRequest; } = creationRequest;
const is_default_persona = const is_default_persona =
@ -156,6 +195,7 @@ function buildPersonaAPIBody(
remove_image, remove_image,
search_start_date, search_start_date,
is_default_persona, is_default_persona,
category_id,
}; };
} }

View File

@ -108,7 +108,7 @@ function ToolForm({
placeholder="Enter your OpenAPI schema here" placeholder="Enter your OpenAPI schema here"
isTextArea={true} isTextArea={true}
defaultHeight="h-96" defaultHeight="h-96"
fontSize="text-sm" fontSize="sm"
isCode isCode
hideError hideError
/> />

View File

@ -0,0 +1,28 @@
import { Badge } from "@/components/ui/badge";
import { PersonaCategory as PersonaCategoryType } from "../admin/assistants/interfaces";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
export default function PersonaCategory({
personaCategory,
}: {
personaCategory: PersonaCategoryType;
}) {
return (
<TooltipProvider>
<Tooltip>
<TooltipTrigger className="cursor-help">
<Badge variant="purple">{personaCategory.name}</Badge>
</TooltipTrigger>
<TooltipContent>
<p>{personaCategory.description}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
);
}

View File

@ -1,6 +1,9 @@
"use client"; "use client";
import { Persona } from "@/app/admin/assistants/interfaces"; import {
Persona,
PersonaCategory as PersonaCategoryType,
} from "@/app/admin/assistants/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon"; import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { User } from "@/lib/types"; import { User } from "@/lib/types";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
@ -17,6 +20,18 @@ import { AssistantTools } from "../ToolsDisplay";
import { classifyAssistants } from "@/lib/assistants/utils"; import { classifyAssistants } from "@/lib/assistants/utils";
import { useAssistants } from "@/components/context/AssistantsContext"; import { useAssistants } from "@/components/context/AssistantsContext";
import { useUser } from "@/components/user/UserProvider"; import { useUser } from "@/components/user/UserProvider";
import PersonaCategory from "../PersonaCategory";
import { useCategories } from "@/lib/hooks";
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectLabel,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
export function AssistantGalleryCard({ export function AssistantGalleryCard({
assistant, assistant,
user, user,
@ -28,8 +43,10 @@ export function AssistantGalleryCard({
setPopup: (popup: PopupSpec) => void; setPopup: (popup: PopupSpec) => void;
selectedAssistant: boolean; selectedAssistant: boolean;
}) { }) {
const { data: categories } = useCategories();
const { refreshUser } = useUser(); const { refreshUser } = useUser();
const router = useRouter();
return ( return (
<div <div
key={assistant.id} key={assistant.id}
@ -129,7 +146,6 @@ export function AssistantGalleryCard({
</div> </div>
)} )}
</div> </div>
<p className="text-sm mt-2">{assistant.description}</p> <p className="text-sm mt-2">{assistant.description}</p>
<p className="text-subtle text-sm my-2"> <p className="text-subtle text-sm my-2">
Author: {assistant.owner?.email || "Danswer"} Author: {assistant.owner?.email || "Danswer"}
@ -137,6 +153,16 @@ export function AssistantGalleryCard({
{assistant.tools.length > 0 && ( {assistant.tools.length > 0 && (
<AssistantTools list assistant={assistant} /> <AssistantTools list assistant={assistant} />
)} )}
{assistant.category_id && categories && (
<PersonaCategory
personaCategory={
categories?.find(
(category: PersonaCategoryType) =>
category.id === assistant.category_id
)!
}
/>
)}
</div> </div>
); );
} }
@ -144,10 +170,12 @@ export function AssistantsGallery() {
const { assistants } = useAssistants(); const { assistants } = useAssistants();
const { user } = useUser(); const { user } = useUser();
const { data: categories } = useCategories();
const router = useRouter(); const router = useRouter();
const [searchQuery, setSearchQuery] = useState(""); const [searchQuery, setSearchQuery] = useState("");
const { popup, setPopup } = usePopup(); const { popup, setPopup } = usePopup();
const [selectedCategory, setSelectedCategory] = useState<number | null>(null);
const { visibleAssistants, hiddenAssistants: _ } = classifyAssistants( const { visibleAssistants, hiddenAssistants: _ } = classifyAssistants(
user, user,
@ -158,16 +186,24 @@ export function AssistantsGallery() {
.filter((assistant) => assistant.is_default_persona) .filter((assistant) => assistant.is_default_persona)
.filter( .filter(
(assistant) => (assistant) =>
assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) || (assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
assistant.description.toLowerCase().includes(searchQuery.toLowerCase()) assistant.description
.toLowerCase()
.includes(searchQuery.toLowerCase())) &&
(selectedCategory === null ||
selectedCategory === assistant.category_id)
); );
const nonDefaultAssistants = assistants const nonDefaultAssistants = assistants
.filter((assistant) => !assistant.is_default_persona) .filter((assistant) => !assistant.is_default_persona)
.filter( .filter(
(assistant) => (assistant) =>
assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) || (assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
assistant.description.toLowerCase().includes(searchQuery.toLowerCase()) assistant.description
.toLowerCase()
.includes(searchQuery.toLowerCase())) &&
(selectedCategory === null ||
selectedCategory === assistant.category_id)
); );
return ( return (
@ -196,7 +232,7 @@ export function AssistantsGallery() {
</Button> </Button>
</div> </div>
<div className="mt-4 mb-12"> <div className="mt-4 mb-6">
<div className="relative"> <div className="relative">
<input <input
type="text" type="text"
@ -238,6 +274,58 @@ export function AssistantsGallery() {
</div> </div>
</div> </div>
{categories && categories?.length > 0 && (
<div className="mb-8">
<Select
value={selectedCategory?.toString() || "all"}
onValueChange={(value) =>
setSelectedCategory(value === "all" ? null : parseInt(value))
}
>
<SelectTrigger
className="
w-[240px]
bg-background
border-2
border-background-strong
text-text-500
rounded-lg
shadow-sm
hover:bg-background-emphasis
hover:border-primary-500/50
hover:text-primary-500
transition-all
duration-200
"
>
<SelectValue placeholder="Filter by category..." />
</SelectTrigger>
<SelectContent className="bg-background border-background-strong">
<SelectGroup>
<SelectLabel className="text-sm font-medium text-text-400">
Categories
</SelectLabel>
<SelectItem
value="all"
className="cursor-pointer hover:bg-background-emphasis"
>
All Categories
</SelectItem>
{categories.map((category) => (
<SelectItem
key={category.id}
value={category.id.toString()}
className="cursor-pointer hover:bg-background-emphasis"
>
{category.name}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</div>
)}
{defaultAssistants.length == 0 && {defaultAssistants.length == 0 &&
nonDefaultAssistants.length == 0 && nonDefaultAssistants.length == 0 &&
assistants.length != 0 && ( assistants.length != 0 && (

View File

@ -5,24 +5,24 @@ import { FiChevronDown, FiChevronRight } from "react-icons/fi";
interface AdvancedOptionsToggleProps { interface AdvancedOptionsToggleProps {
showAdvancedOptions: boolean; showAdvancedOptions: boolean;
setShowAdvancedOptions: (show: boolean) => void; setShowAdvancedOptions: (show: boolean) => void;
title?: string;
} }
export function AdvancedOptionsToggle({ export function AdvancedOptionsToggle({
showAdvancedOptions, showAdvancedOptions,
setShowAdvancedOptions, setShowAdvancedOptions,
title,
}: AdvancedOptionsToggleProps) { }: AdvancedOptionsToggleProps) {
return ( return (
<div> <Button
<Button type="button"
type="button" variant="link"
variant="link" size="sm"
size="sm" icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight} onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)} className="text-xs !p-0 text-text-950 hover:text-text-500"
className="text-xs text-text-950 hover:text-text-500" >
> {title || "Advanced Options"}
Advanced Options </Button>
</Button>
</div>
); );
} }

View File

@ -161,7 +161,7 @@ export function TextFormField({
error?: string; error?: string;
defaultHeight?: string; defaultHeight?: string;
isCode?: boolean; isCode?: boolean;
fontSize?: "text-sm" | "text-base" | "text-lg"; fontSize?: "sm" | "md" | "lg";
hideError?: boolean; hideError?: boolean;
tooltip?: string; tooltip?: string;
explanationText?: string; explanationText?: string;
@ -187,12 +187,36 @@ export function TextFormField({
onChange(e as React.ChangeEvent<HTMLInputElement>); onChange(e as React.ChangeEvent<HTMLInputElement>);
} }
}; };
const textSizeClasses = {
sm: {
label: "text-sm",
input: "text-sm",
placeholder: "text-sm",
},
md: {
label: "text-base",
input: "text-base",
placeholder: "text-base",
},
lg: {
label: "text-lg",
input: "text-lg",
placeholder: "text-lg",
},
};
const sizeClass = textSizeClasses[fontSize || "md"];
return ( return (
<div className={`w-full ${width}`}> <div className={`w-full ${width}`}>
<div className="flex gap-x-2 items-center"> <div className="flex gap-x-2 items-center">
{!removeLabel && ( {!removeLabel && (
<Label className="text-text-950" small={small}> <Label
className={`${
small ? "text-text-950" : "text-text-700 font-normal"
} ${sizeClass.label}`}
small={small}
>
{label} {label}
</Label> </Label>
)} )}
@ -221,7 +245,7 @@ export function TextFormField({
name={name} name={name}
id={name} id={name}
className={` className={`
${small && "text-sm"} ${small && sizeClass.input}
border border
border-border border-border
rounded-md rounded-md
@ -230,10 +254,10 @@ export function TextFormField({
px-3 px-3
mt-1 mt-1
placeholder:font-description placeholder:font-description
placeholder:text-base placeholder:${sizeClass.placeholder}
placeholder:text-text-400 placeholder:text-text-400
${heightString} ${heightString}
${fontSize} ${sizeClass.input}
${disabled ? " bg-background-strong" : " bg-white"} ${disabled ? " bg-background-strong" : " bg-white"}
${isCode ? " font-mono" : ""} ${isCode ? " font-mono" : ""}
`} `}
@ -585,6 +609,7 @@ interface SelectorFormFieldProps {
onSelect?: (selected: string | number | null) => void; onSelect?: (selected: string | number | null) => void;
defaultValue?: string; defaultValue?: string;
tooltip?: string; tooltip?: string;
includeReset?: boolean;
} }
export function SelectorFormField({ export function SelectorFormField({
@ -597,6 +622,7 @@ export function SelectorFormField({
onSelect, onSelect,
defaultValue, defaultValue,
tooltip, tooltip,
includeReset = false,
}: SelectorFormFieldProps) { }: SelectorFormFieldProps) {
const [field] = useField<string>(name); const [field] = useField<string>(name);
const { setFieldValue } = useFormikContext(); const { setFieldValue } = useFormikContext();
@ -619,7 +645,11 @@ export function SelectorFormField({
<Select <Select
value={field.value || defaultValue} value={field.value || defaultValue}
onValueChange={ onValueChange={
onSelect || ((selected) => setFieldValue(name, selected)) onSelect ||
((selected) =>
selected == "__none__"
? setFieldValue(name, null)
: setFieldValue(name, selected))
} }
defaultValue={defaultValue} defaultValue={defaultValue}
> >
@ -649,6 +679,14 @@ export function SelectorFormField({
</SelectItem> </SelectItem>
)) ))
)} )}
{includeReset && (
<SelectItem
value={"__none__"}
onSelect={() => setFieldValue(name, null)}
>
None
</SelectItem>
)}
</SelectContent> </SelectContent>
)} )}
</Select> </Select>

View File

@ -15,6 +15,7 @@ import { ChatSession } from "@/app/chat/interfaces";
import { UsersResponse } from "./users/interfaces"; import { UsersResponse } from "./users/interfaces";
import { Credential } from "./connectors/credentials"; import { Credential } from "./connectors/credentials";
import { SettingsContext } from "@/components/settings/SettingsProvider"; import { SettingsContext } from "@/components/settings/SettingsProvider";
import { PersonaCategory } from "@/app/admin/assistants/interfaces";
const CREDENTIAL_URL = "/api/manage/admin/credential"; const CREDENTIAL_URL = "/api/manage/admin/credential";
@ -83,6 +84,19 @@ export const useConnectorCredentialIndexingStatus = (
}; };
}; };
export const useCategories = () => {
const { mutate } = useSWRConfig();
const swrResponse = useSWR<PersonaCategory[]>(
"/api/persona/categories",
errorHandlingFetcher
);
return {
...swrResponse,
refreshCategories: () => mutate("/api/persona/categories"),
};
};
export const useTimeRange = (initialValue?: DateRangePickerValue) => { export const useTimeRange = (initialValue?: DateRangePickerValue) => {
return useState<DateRangePickerValue | null>(null); return useState<DateRangePickerValue | null>(null);
}; };