Add assistant categories (#3064)

* add assistant categories v1

* functionality finalized

* finalize

* update assistant category display

* nit

* add tests

* post rebase update

* minor update to tests

* update typing

* finalize

* typing

* nit

* alembic

* alembic (once again)
This commit is contained in:
pablodanswer 2024-11-18 12:33:48 -08:00 committed by GitHub
parent 33ee899408
commit a7d95661b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 919 additions and 31 deletions

View File

@ -0,0 +1,45 @@
"""add persona categories
Revision ID: 47e5bef3a1d7
Revises: dfbe9e93d3c7
Create Date: 2024-11-05 18:55:02.221064
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "47e5bef3a1d7"
down_revision = "dfbe9e93d3c7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the persona_category table
op.create_table(
"persona_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add category_id to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"fk_persona_category",
"persona",
"persona_category",
["category_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

View File

@ -1363,6 +1363,9 @@ class Persona(Base):
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
Enum(RecencyBiasSetting, native_enum=False)
)
category_id: Mapped[int | None] = mapped_column(
ForeignKey("persona_category.id"), nullable=True
)
# Allows the Persona to specify a different LLM version than is controlled
# globablly via env variables. For flexibility, validity is not currently enforced
# NOTE: only is applied on the actual response generation - is not used for things like
@ -1434,6 +1437,9 @@ class Persona(Base):
secondary="persona__user_group",
viewonly=True,
)
category: Mapped["PersonaCategory"] = relationship(
"PersonaCategory", back_populates="personas"
)
# Default personas loaded via yaml cannot have the same name
__table_args__ = (
@ -1446,6 +1452,17 @@ class Persona(Base):
)
class PersonaCategory(Base):
__tablename__ = "persona_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
description: Mapped[str | None] = mapped_column(String, nullable=True)
personas: Mapped[list["Persona"]] = relationship(
"Persona", back_populates="category"
)
AllowedAnswerFilters = (
Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
)

View File

@ -26,6 +26,7 @@ from danswer.db.models import DocumentSet
from danswer.db.models import Persona
from danswer.db.models import Persona__User
from danswer.db.models import Persona__UserGroup
from danswer.db.models import PersonaCategory
from danswer.db.models import Prompt
from danswer.db.models import StarterMessage
from danswer.db.models import Tool
@ -417,6 +418,7 @@ def upsert_persona(
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool = False,
category_id: int | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
@ -487,7 +489,7 @@ def upsert_persona(
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.is_default_persona = is_default_persona
persona.category_id = category_id
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@ -528,6 +530,7 @@ def upsert_persona(
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona,
category_id=category_id,
)
db_session.add(persona)
@ -744,3 +747,39 @@ def delete_persona_by_name(
db_session.execute(stmt)
db_session.commit()
def get_assistant_categories(db_session: Session) -> list[PersonaCategory]:
return db_session.query(PersonaCategory).all()
def create_assistant_category(
db_session: Session, name: str, description: str
) -> PersonaCategory:
category = PersonaCategory(name=name, description=description)
db_session.add(category)
db_session.commit()
return category
def update_persona_category(
category_id: int,
category_description: str,
category_name: str,
db_session: Session,
) -> None:
persona_category = (
db_session.query(PersonaCategory)
.filter(PersonaCategory.id == category_id)
.one_or_none()
)
if persona_category is None:
raise ValueError(f"Persona category with ID {category_id} does not exist")
persona_category.description = category_description
persona_category.name = category_name
db_session.commit()
def delete_persona_category(category_id: int, db_session: Session) -> None:
db_session.query(PersonaCategory).filter(PersonaCategory.id == category_id).delete()
db_session.commit()

View File

@ -18,12 +18,16 @@ from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.persona import create_assistant_category
from danswer.db.persona import create_update_persona
from danswer.db.persona import delete_persona_category
from danswer.db.persona import get_assistant_categories
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import mark_persona_as_not_deleted
from danswer.db.persona import update_all_personas_display_priority
from danswer.db.persona import update_persona_category
from danswer.db.persona import update_persona_public_status
from danswer.db.persona import update_persona_shared_users
from danswer.db.persona import update_persona_visibility
@ -32,6 +36,8 @@ from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import ImageGenerationToolStatus
from danswer.server.features.persona.models import PersonaCategoryCreate
from danswer.server.features.persona.models import PersonaCategoryResponse
from danswer.server.features.persona.models import PersonaSharedNotificationData
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
@ -39,6 +45,7 @@ from danswer.server.models import DisplayPriorityRequest
from danswer.tools.utils import is_image_generation_available
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -184,6 +191,59 @@ def update_persona(
)
class PersonaCategoryPatchRequest(BaseModel):
category_description: str
category_name: str
@basic_router.get("/categories")
def get_categories(
db: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> list[PersonaCategoryResponse]:
return [
PersonaCategoryResponse.from_model(category)
for category in get_assistant_categories(db_session=db)
]
@admin_router.post("/categories")
def create_category(
category: PersonaCategoryCreate,
db: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> PersonaCategoryResponse:
"""Create a new assistant category"""
category_model = create_assistant_category(
name=category.name, description=category.description, db_session=db
)
return PersonaCategoryResponse.from_model(category_model)
@admin_router.patch("/category/{category_id}")
def patch_persona_category(
category_id: int,
persona_category_patch_request: PersonaCategoryPatchRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_category(
category_id=category_id,
category_description=persona_category_patch_request.category_description,
category_name=persona_category_patch_request.category_name,
db_session=db_session,
)
@admin_router.delete("/category/{category_id}")
def delete_category(
category_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
delete_persona_category(category_id=category_id, db_session=db_session)
class PersonaShareRequest(BaseModel):
user_ids: list[UUID]

View File

@ -5,6 +5,7 @@ from pydantic import BaseModel
from pydantic import Field
from danswer.db.models import Persona
from danswer.db.models import PersonaCategory
from danswer.db.models import StarterMessage
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet
@ -41,6 +42,7 @@ class CreatePersonaRequest(BaseModel):
is_default_persona: bool = False
display_priority: int | None = None
search_start_date: datetime | None = None
category_id: int | None = None
class PersonaSnapshot(BaseModel):
@ -68,6 +70,7 @@ class PersonaSnapshot(BaseModel):
uploaded_image_id: str | None = None
is_default_persona: bool
search_start_date: datetime | None = None
category_id: int | None = None
@classmethod
def from_model(
@ -115,6 +118,7 @@ class PersonaSnapshot(BaseModel):
icon_shape=persona.icon_shape,
uploaded_image_id=persona.uploaded_image_id,
search_start_date=persona.search_start_date,
category_id=persona.category_id,
)
@ -128,3 +132,22 @@ class PersonaSharedNotificationData(BaseModel):
class ImageGenerationToolStatus(BaseModel):
is_available: bool
class PersonaCategoryCreate(BaseModel):
name: str
description: str
class PersonaCategoryResponse(BaseModel):
id: int
name: str
description: str | None
@classmethod
def from_model(cls, category: PersonaCategory) -> "PersonaCategoryResponse":
return PersonaCategoryResponse(
id=category.id,
name=category.name,
description=category.description,
)

View File

@ -7,6 +7,7 @@ from danswer.server.features.persona.models import PersonaSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestPersona
from tests.integration.common_utils.test_models import DATestPersonaCategory
from tests.integration.common_utils.test_models import DATestUser
@ -27,6 +28,7 @@ class PersonaManager:
llm_model_version_override: str | None = None,
users: list[str] | None = None,
groups: list[int] | None = None,
category_id: int | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
name = name or f"test-persona-{uuid4()}"
@ -212,3 +214,83 @@ class PersonaManager:
else GENERAL_HEADERS,
)
return response.ok
class PersonaCategoryManager:
@staticmethod
def create(
category: DATestPersonaCategory,
user_performing_action: DATestUser | None = None,
) -> DATestPersonaCategory:
response = requests.post(
f"{API_SERVER_URL}/admin/persona/categories",
json={
"name": category.name,
"description": category.description,
},
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
response_data = response.json()
category.id = response_data["id"]
return category
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[DATestPersonaCategory]:
response = requests.get(
f"{API_SERVER_URL}/persona/categories",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return [DATestPersonaCategory(**category) for category in response.json()]
@staticmethod
def update(
category: DATestPersonaCategory,
user_performing_action: DATestUser | None = None,
) -> DATestPersonaCategory:
response = requests.patch(
f"{API_SERVER_URL}/admin/persona/category/{category.id}",
json={
"category_name": category.name,
"category_description": category.description,
},
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return category
@staticmethod
def delete(
category: DATestPersonaCategory,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/admin/persona/category/{category.id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
return response.ok
@staticmethod
def verify(
category: DATestPersonaCategory,
user_performing_action: DATestUser | None = None,
) -> bool:
all_categories = PersonaCategoryManager.get_all(user_performing_action)
for fetched_category in all_categories:
if fetched_category.id == category.id:
return (
fetched_category.name == category.name
and fetched_category.description == category.description
)
return False

View File

@ -38,6 +38,12 @@ class DATestUser(BaseModel):
headers: dict
class DATestPersonaCategory(BaseModel):
id: int | None = None
name: str
description: str | None
class DATestCredential(BaseModel):
id: int
name: str
@ -119,6 +125,7 @@ class DATestPersona(BaseModel):
llm_model_version_override: str | None
users: list[str]
groups: list[int]
category_id: int | None = None
#

View File

@ -0,0 +1,92 @@
from uuid import uuid4
import pytest
from requests.exceptions import HTTPError
from tests.integration.common_utils.managers.persona import (
PersonaCategoryManager,
)
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestPersonaCategory
from tests.integration.common_utils.test_models import DATestUser
def test_persona_category_management(reset: None) -> None:
admin_user: DATestUser = UserManager.create(name="admin_user")
persona_category = DATestPersonaCategory(
id=None,
name=f"Test Category {uuid4()}",
description="A description for test category",
)
persona_category = PersonaCategoryManager.create(
category=persona_category,
user_performing_action=admin_user,
)
print(
f"Created persona category {persona_category.name} with id {persona_category.id}"
)
assert PersonaCategoryManager.verify(
category=persona_category,
user_performing_action=admin_user,
), "Persona category was not found after creation"
regular_user: DATestUser = UserManager.create(name="regular_user")
updated_persona_category = DATestPersonaCategory(
id=persona_category.id,
name=f"Updated {persona_category.name}",
description="An updated description",
)
with pytest.raises(HTTPError) as exc_info:
PersonaCategoryManager.update(
category=updated_persona_category,
user_performing_action=regular_user,
)
assert exc_info.value.response.status_code == 403
assert PersonaCategoryManager.verify(
category=persona_category,
user_performing_action=admin_user,
), "Persona category should not have been updated by non-admin user"
result = PersonaCategoryManager.delete(
category=persona_category,
user_performing_action=regular_user,
)
assert (
result is False
), "Regular user should not be able to delete the persona category"
assert PersonaCategoryManager.verify(
category=persona_category,
user_performing_action=admin_user,
), "Persona category should not have been deleted by non-admin user"
updated_persona_category.name = f"Updated {persona_category.name}"
updated_persona_category.description = "An updated description"
updated_persona_category = PersonaCategoryManager.update(
category=updated_persona_category,
user_performing_action=admin_user,
)
print(f"Updated persona category to {updated_persona_category.name}")
assert PersonaCategoryManager.verify(
category=updated_persona_category,
user_performing_action=admin_user,
), "Persona category was not updated by admin"
success = PersonaCategoryManager.delete(
category=persona_category,
user_performing_action=admin_user,
)
assert success, "Admin user should be able to delete the persona category"
print(
f"Deleted persona category {persona_category.name} with id {persona_category.id}"
)
assert not PersonaCategoryManager.verify(
category=persona_category,
user_performing_action=admin_user,
), "Persona category should not exist after deletion by admin"

View File

@ -23,8 +23,16 @@ import {
SelectorFormField,
TextFormField,
} from "@/components/admin/connectors/Field";
import {
Card,
CardHeader,
CardTitle,
CardContent,
CardFooter,
} from "@/components/ui/card";
import { usePopup } from "@/components/admin/connectors/Popup";
import { getDisplayNameForModel } from "@/lib/hooks";
import { getDisplayNameForModel, useCategories } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { Option } from "@/components/Dropdown";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
@ -46,8 +54,14 @@ import * as Yup from "yup";
import { FullLLMProvider } from "../configuration/llm/interfaces";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { Persona, StarterMessage } from "./interfaces";
import { createPersona, updatePersona } from "./lib";
import { Persona, PersonaCategory, StarterMessage } from "./interfaces";
import {
createPersonaCategory,
createPersona,
deletePersonaCategory,
updatePersonaCategory,
updatePersona,
} from "./lib";
import { Popover } from "@/components/popover/Popover";
import {
CameraIcon,
@ -59,6 +73,8 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { buildImgUrl } from "@/app/chat/files/images/utils";
import { LlmList } from "@/components/llm/LLMList";
import { useAssistants } from "@/components/context/AssistantsContext";
import { Input } from "@/components/ui/input";
import { CategoryCard } from "./CategoryCard";
function findSearchTool(tools: ToolSnapshot[]) {
return tools.find((tool) => tool.in_code_tool_id === "SearchTool");
@ -107,6 +123,7 @@ export function AssistantEditor({
const router = useRouter();
const { popup, setPopup } = usePopup();
const { data: categories, refreshCategories } = useCategories();
const colorOptions = [
"#FF6FBF",
@ -119,6 +136,7 @@ export function AssistantEditor({
];
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
const [showPersonaCategory, setShowPersonaCategory] = useState(!admin);
// state to persist across formik reformatting
const [defautIconColor, _setDeafultIconColor] = useState(
@ -211,6 +229,7 @@ export function AssistantEditor({
icon_color: existingPersona?.icon_color ?? defautIconColor,
icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
uploaded_image: null,
category_id: existingPersona?.category_id ?? null,
// EE Only
groups: existingPersona?.groups ?? [],
@ -255,6 +274,7 @@ export function AssistantEditor({
icon_color: Yup.string(),
icon_shape: Yup.number(),
uploaded_image: Yup.mixed().nullable(),
category_id: Yup.number().nullable(),
// EE Only
groups: Yup.array().of(Yup.number()),
})
@ -968,6 +988,189 @@ export function AssistantEditor({
)}
</div>
</div>
{admin && (
<AdvancedOptionsToggle
title="Categories"
showAdvancedOptions={showPersonaCategory}
setShowAdvancedOptions={setShowPersonaCategory}
/>
)}
{showPersonaCategory && (
<>
{categories && categories.length > 0 && (
<div className="my-2">
<div className="flex gap-x-2 items-center">
<div className="block font-medium text-base">
Category
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
Group similar assistants together by category
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<SelectorFormField
includeReset
name="category_id"
options={categories.map((category) => ({
name: category.name,
value: category.id,
}))}
/>
</div>
)}
{admin && (
<>
<div className="my-2">
<div className="flex gap-x-2 items-center mb-2">
<div className="block font-medium text-base">
Create New Category
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
Create a new category to group similar
assistants together
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<div className="grid grid-cols-[1fr,3fr,auto] gap-4">
<TextFormField
fontSize="sm"
name="newCategoryName"
label="Category Name"
placeholder="e.g. Development"
/>
<TextFormField
fontSize="sm"
name="newCategoryDescription"
label="Category Description"
placeholder="e.g. Assistants for software development"
/>
<div className="flex items-end">
<Button
type="button"
onClick={async () => {
const name = values.newCategoryName;
const description =
values.newCategoryDescription;
if (!name || !description) return;
try {
const response = await createPersonaCategory(
name,
description
);
if (response.ok) {
setPopup({
message: `Category "${name}" created successfully`,
type: "success",
});
} else {
throw new Error(await response.text());
}
} catch (error) {
setPopup({
message: `Failed to create category - ${error}`,
type: "error",
});
}
await refreshCategories();
setFieldValue("newCategoryName", "");
setFieldValue("newCategoryDescription", "");
}}
>
Create
</Button>
</div>
</div>
</div>
{categories && categories.length > 0 && (
<div className="my-2 w-full">
<div className="flex gap-x-2 items-center mb-2">
<div className="block font-medium text-base">
Manage categories
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
Manage existing categories or create new ones
to group similar assistants
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<div className="gap-4 w-full flex-wrap flex">
{categories &&
categories.map((category: PersonaCategory) => (
<CategoryCard
setPopup={setPopup}
key={category.id}
category={category}
onUpdate={async (id, name, description) => {
const response =
await updatePersonaCategory(
id,
name,
description
);
if (response?.ok) {
setPopup({
message: `Category "${name}" updated successfully`,
type: "success",
});
} else {
setPopup({
message: `Failed to update category - ${await response.text()}`,
type: "error",
});
}
}}
onDelete={async (id) => {
const response =
await deletePersonaCategory(id);
if (response?.ok) {
setPopup({
message: `Category deleted successfully`,
type: "success",
});
} else {
setPopup({
message: `Failed to delete category - ${await response.text()}`,
type: "error",
});
}
}}
refreshCategories={refreshCategories}
/>
))}
</div>
</div>
)}
</>
)}
</>
)}
<Separator />
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}

View File

@ -0,0 +1,105 @@
import { useState } from "react";
import {
Card,
CardHeader,
CardTitle,
CardContent,
CardFooter,
} from "@/components/ui/card";
import { Input } from "@/components/ui/input";
import { Textarea } from "@/components/ui/textarea";
import { Button } from "@/components/ui/button";
import { PersonaCategory } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
interface CategoryCardProps {
category: PersonaCategory;
onUpdate: (id: number, name: string, description: string) => void;
onDelete: (id: number) => void;
refreshCategories: () => Promise<void>;
setPopup: (popup: PopupSpec) => void;
}
export function CategoryCard({
category,
onUpdate,
onDelete,
refreshCategories,
}: CategoryCardProps) {
const [isEditing, setIsEditing] = useState(false);
const [name, setName] = useState(category.name);
const [description, setDescription] = useState(category.description);
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
await onUpdate(category.id, name, description);
await refreshCategories();
setIsEditing(false);
};
const handleEdit = (e: React.MouseEvent) => {
e.preventDefault();
setIsEditing(true);
};
return (
<Card key={category.id} className="w-full max-w-sm">
<CardHeader className="w-full">
<CardTitle className="text-2xl font-bold">
{isEditing ? (
<Input
value={name}
onChange={(e) => setName(e.target.value)}
className="text-lg font-semibold"
/>
) : (
<span>{category.name}</span>
)}
</CardTitle>
</CardHeader>
<CardContent className="w-full">
{isEditing ? (
<Textarea
value={description}
onChange={(e) => setDescription(e.target.value)}
className="resize-none w-full"
/>
) : (
<p className="text-sm text-gray-600">{category.description}</p>
)}
</CardContent>
<CardFooter className="flex justify-end space-x-2">
{isEditing ? (
<>
<Button type="button" variant="outline" onClick={handleSubmit}>
Save
</Button>
<Button
type="button"
onClick={() => setIsEditing(false)}
variant="default"
>
Cancel
</Button>
</>
) : (
<>
<Button type="button" onClick={handleEdit} variant="outline">
Edit
</Button>
<Button
type="button"
variant="destructive"
onClick={async (e) => {
e.preventDefault();
await onDelete(category.id);
await refreshCategories();
}}
>
Delete
</Button>
</>
)}
</CardFooter>
</Card>
);
}

View File

@ -42,4 +42,11 @@ export interface Persona {
icon_shape?: number;
icon_color?: string;
uploaded_image_id?: string;
category_id?: number | null;
}
export interface PersonaCategory {
id: number;
name: string;
description: string;
}

View File

@ -23,6 +23,7 @@ interface PersonaCreationRequest {
uploaded_image: File | null;
search_start_date: Date | null;
is_default_persona: boolean;
category_id: number | null;
}
interface PersonaUpdateRequest {
@ -48,6 +49,7 @@ interface PersonaUpdateRequest {
remove_image: boolean;
uploaded_image: File | null;
search_start_date: Date | null;
category_id: number | null;
}
function promptNameFromPersonaName(personaName: string) {
@ -108,6 +110,42 @@ function updatePrompt({
});
}
export const createPersonaCategory = (name: string, description: string) => {
return fetch("/api/admin/persona/categories", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ name, description }),
});
};
export const deletePersonaCategory = (categoryId: number) => {
return fetch(`/api/admin/persona/category/${categoryId}`, {
method: "DELETE",
headers: {
"Content-Type": "application/json",
},
});
};
export const updatePersonaCategory = (
id: number,
name: string,
description: string
) => {
return fetch(`/api/admin/persona/category/${id}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
category_name: name,
category_description: description,
}),
});
};
function buildPersonaAPIBody(
creationRequest: PersonaCreationRequest | PersonaUpdateRequest,
promptId: number,
@ -127,6 +165,7 @@ function buildPersonaAPIBody(
icon_shape,
remove_image,
search_start_date,
category_id,
} = creationRequest;
const is_default_persona =
@ -156,6 +195,7 @@ function buildPersonaAPIBody(
remove_image,
search_start_date,
is_default_persona,
category_id,
};
}

View File

@ -108,7 +108,7 @@ function ToolForm({
placeholder="Enter your OpenAPI schema here"
isTextArea={true}
defaultHeight="h-96"
fontSize="text-sm"
fontSize="sm"
isCode
hideError
/>

View File

@ -0,0 +1,28 @@
import { Badge } from "@/components/ui/badge";
import { PersonaCategory as PersonaCategoryType } from "../admin/assistants/interfaces";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
export default function PersonaCategory({
personaCategory,
}: {
personaCategory: PersonaCategoryType;
}) {
return (
<TooltipProvider>
<Tooltip>
<TooltipTrigger className="cursor-help">
<Badge variant="purple">{personaCategory.name}</Badge>
</TooltipTrigger>
<TooltipContent>
<p>{personaCategory.description}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
);
}

View File

@ -1,6 +1,9 @@
"use client";
import { Persona } from "@/app/admin/assistants/interfaces";
import {
Persona,
PersonaCategory as PersonaCategoryType,
} from "@/app/admin/assistants/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { User } from "@/lib/types";
import { Button } from "@/components/ui/button";
@ -17,6 +20,18 @@ import { AssistantTools } from "../ToolsDisplay";
import { classifyAssistants } from "@/lib/assistants/utils";
import { useAssistants } from "@/components/context/AssistantsContext";
import { useUser } from "@/components/user/UserProvider";
import PersonaCategory from "../PersonaCategory";
import { useCategories } from "@/lib/hooks";
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectLabel,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
export function AssistantGalleryCard({
assistant,
user,
@ -28,8 +43,10 @@ export function AssistantGalleryCard({
setPopup: (popup: PopupSpec) => void;
selectedAssistant: boolean;
}) {
const { data: categories } = useCategories();
const { refreshUser } = useUser();
const router = useRouter();
return (
<div
key={assistant.id}
@ -129,7 +146,6 @@ export function AssistantGalleryCard({
</div>
)}
</div>
<p className="text-sm mt-2">{assistant.description}</p>
<p className="text-subtle text-sm my-2">
Author: {assistant.owner?.email || "Danswer"}
@ -137,6 +153,16 @@ export function AssistantGalleryCard({
{assistant.tools.length > 0 && (
<AssistantTools list assistant={assistant} />
)}
{assistant.category_id && categories && (
<PersonaCategory
personaCategory={
categories?.find(
(category: PersonaCategoryType) =>
category.id === assistant.category_id
)!
}
/>
)}
</div>
);
}
@ -144,10 +170,12 @@ export function AssistantsGallery() {
const { assistants } = useAssistants();
const { user } = useUser();
const { data: categories } = useCategories();
const router = useRouter();
const [searchQuery, setSearchQuery] = useState("");
const { popup, setPopup } = usePopup();
const [selectedCategory, setSelectedCategory] = useState<number | null>(null);
const { visibleAssistants, hiddenAssistants: _ } = classifyAssistants(
user,
@ -158,16 +186,24 @@ export function AssistantsGallery() {
.filter((assistant) => assistant.is_default_persona)
.filter(
(assistant) =>
assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
assistant.description.toLowerCase().includes(searchQuery.toLowerCase())
(assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
assistant.description
.toLowerCase()
.includes(searchQuery.toLowerCase())) &&
(selectedCategory === null ||
selectedCategory === assistant.category_id)
);
const nonDefaultAssistants = assistants
.filter((assistant) => !assistant.is_default_persona)
.filter(
(assistant) =>
assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
assistant.description.toLowerCase().includes(searchQuery.toLowerCase())
(assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
assistant.description
.toLowerCase()
.includes(searchQuery.toLowerCase())) &&
(selectedCategory === null ||
selectedCategory === assistant.category_id)
);
return (
@ -196,7 +232,7 @@ export function AssistantsGallery() {
</Button>
</div>
<div className="mt-4 mb-12">
<div className="mt-4 mb-6">
<div className="relative">
<input
type="text"
@ -238,6 +274,58 @@ export function AssistantsGallery() {
</div>
</div>
{categories && categories?.length > 0 && (
<div className="mb-8">
<Select
value={selectedCategory?.toString() || "all"}
onValueChange={(value) =>
setSelectedCategory(value === "all" ? null : parseInt(value))
}
>
<SelectTrigger
className="
w-[240px]
bg-background
border-2
border-background-strong
text-text-500
rounded-lg
shadow-sm
hover:bg-background-emphasis
hover:border-primary-500/50
hover:text-primary-500
transition-all
duration-200
"
>
<SelectValue placeholder="Filter by category..." />
</SelectTrigger>
<SelectContent className="bg-background border-background-strong">
<SelectGroup>
<SelectLabel className="text-sm font-medium text-text-400">
Categories
</SelectLabel>
<SelectItem
value="all"
className="cursor-pointer hover:bg-background-emphasis"
>
All Categories
</SelectItem>
{categories.map((category) => (
<SelectItem
key={category.id}
value={category.id.toString()}
className="cursor-pointer hover:bg-background-emphasis"
>
{category.name}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</div>
)}
{defaultAssistants.length == 0 &&
nonDefaultAssistants.length == 0 &&
assistants.length != 0 && (

View File

@ -5,24 +5,24 @@ import { FiChevronDown, FiChevronRight } from "react-icons/fi";
interface AdvancedOptionsToggleProps {
showAdvancedOptions: boolean;
setShowAdvancedOptions: (show: boolean) => void;
title?: string;
}
export function AdvancedOptionsToggle({
showAdvancedOptions,
setShowAdvancedOptions,
title,
}: AdvancedOptionsToggleProps) {
return (
<div>
<Button
type="button"
variant="link"
size="sm"
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
className="text-xs text-text-950 hover:text-text-500"
>
Advanced Options
</Button>
</div>
<Button
type="button"
variant="link"
size="sm"
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
className="text-xs !p-0 text-text-950 hover:text-text-500"
>
{title || "Advanced Options"}
</Button>
);
}

View File

@ -161,7 +161,7 @@ export function TextFormField({
error?: string;
defaultHeight?: string;
isCode?: boolean;
fontSize?: "text-sm" | "text-base" | "text-lg";
fontSize?: "sm" | "md" | "lg";
hideError?: boolean;
tooltip?: string;
explanationText?: string;
@ -187,12 +187,36 @@ export function TextFormField({
onChange(e as React.ChangeEvent<HTMLInputElement>);
}
};
const textSizeClasses = {
sm: {
label: "text-sm",
input: "text-sm",
placeholder: "text-sm",
},
md: {
label: "text-base",
input: "text-base",
placeholder: "text-base",
},
lg: {
label: "text-lg",
input: "text-lg",
placeholder: "text-lg",
},
};
const sizeClass = textSizeClasses[fontSize || "md"];
return (
<div className={`w-full ${width}`}>
<div className="flex gap-x-2 items-center">
{!removeLabel && (
<Label className="text-text-950" small={small}>
<Label
className={`${
small ? "text-text-950" : "text-text-700 font-normal"
} ${sizeClass.label}`}
small={small}
>
{label}
</Label>
)}
@ -221,7 +245,7 @@ export function TextFormField({
name={name}
id={name}
className={`
${small && "text-sm"}
${small && sizeClass.input}
border
border-border
rounded-md
@ -230,10 +254,10 @@ export function TextFormField({
px-3
mt-1
placeholder:font-description
placeholder:text-base
placeholder:${sizeClass.placeholder}
placeholder:text-text-400
${heightString}
${fontSize}
${sizeClass.input}
${disabled ? " bg-background-strong" : " bg-white"}
${isCode ? " font-mono" : ""}
`}
@ -585,6 +609,7 @@ interface SelectorFormFieldProps {
onSelect?: (selected: string | number | null) => void;
defaultValue?: string;
tooltip?: string;
includeReset?: boolean;
}
export function SelectorFormField({
@ -597,6 +622,7 @@ export function SelectorFormField({
onSelect,
defaultValue,
tooltip,
includeReset = false,
}: SelectorFormFieldProps) {
const [field] = useField<string>(name);
const { setFieldValue } = useFormikContext();
@ -619,7 +645,11 @@ export function SelectorFormField({
<Select
value={field.value || defaultValue}
onValueChange={
onSelect || ((selected) => setFieldValue(name, selected))
onSelect ||
((selected) =>
selected == "__none__"
? setFieldValue(name, null)
: setFieldValue(name, selected))
}
defaultValue={defaultValue}
>
@ -649,6 +679,14 @@ export function SelectorFormField({
</SelectItem>
))
)}
{includeReset && (
<SelectItem
value={"__none__"}
onSelect={() => setFieldValue(name, null)}
>
None
</SelectItem>
)}
</SelectContent>
)}
</Select>

View File

@ -15,6 +15,7 @@ import { ChatSession } from "@/app/chat/interfaces";
import { UsersResponse } from "./users/interfaces";
import { Credential } from "./connectors/credentials";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { PersonaCategory } from "@/app/admin/assistants/interfaces";
const CREDENTIAL_URL = "/api/manage/admin/credential";
@ -83,6 +84,19 @@ export const useConnectorCredentialIndexingStatus = (
};
};
export const useCategories = () => {
const { mutate } = useSWRConfig();
const swrResponse = useSWR<PersonaCategory[]>(
"/api/persona/categories",
errorHandlingFetcher
);
return {
...swrResponse,
refreshCategories: () => mutate("/api/persona/categories"),
};
};
export const useTimeRange = (initialValue?: DateRangePickerValue) => {
return useState<DateRangePickerValue | null>(null);
};