mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-05 12:39:33 +02:00
Add assistant categories (#3064)
* add assistant categories v1 * functionality finalized * finalize * update assistant category display * nit * add tests * post rebase update * minor update to tests * update typing * finalize * typing * nit * alembic * alembic (once again)
This commit is contained in:
parent
33ee899408
commit
a7d95661b3
@ -0,0 +1,45 @@
|
|||||||
|
"""add persona categories
|
||||||
|
|
||||||
|
Revision ID: 47e5bef3a1d7
|
||||||
|
Revises: dfbe9e93d3c7
|
||||||
|
Create Date: 2024-11-05 18:55:02.221064
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "47e5bef3a1d7"
|
||||||
|
down_revision = "dfbe9e93d3c7"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Create the persona_category table
|
||||||
|
op.create_table(
|
||||||
|
"persona_category",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("name", sa.String(), nullable=False),
|
||||||
|
sa.Column("description", sa.String(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add category_id to persona table
|
||||||
|
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
|
||||||
|
op.create_foreign_key(
|
||||||
|
"fk_persona_category",
|
||||||
|
"persona",
|
||||||
|
"persona_category",
|
||||||
|
["category_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
|
||||||
|
op.drop_column("persona", "category_id")
|
||||||
|
op.drop_table("persona_category")
|
@ -1363,6 +1363,9 @@ class Persona(Base):
|
|||||||
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
|
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
|
||||||
Enum(RecencyBiasSetting, native_enum=False)
|
Enum(RecencyBiasSetting, native_enum=False)
|
||||||
)
|
)
|
||||||
|
category_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("persona_category.id"), nullable=True
|
||||||
|
)
|
||||||
# Allows the Persona to specify a different LLM version than is controlled
|
# Allows the Persona to specify a different LLM version than is controlled
|
||||||
# globablly via env variables. For flexibility, validity is not currently enforced
|
# globablly via env variables. For flexibility, validity is not currently enforced
|
||||||
# NOTE: only is applied on the actual response generation - is not used for things like
|
# NOTE: only is applied on the actual response generation - is not used for things like
|
||||||
@ -1434,6 +1437,9 @@ class Persona(Base):
|
|||||||
secondary="persona__user_group",
|
secondary="persona__user_group",
|
||||||
viewonly=True,
|
viewonly=True,
|
||||||
)
|
)
|
||||||
|
category: Mapped["PersonaCategory"] = relationship(
|
||||||
|
"PersonaCategory", back_populates="personas"
|
||||||
|
)
|
||||||
|
|
||||||
# Default personas loaded via yaml cannot have the same name
|
# Default personas loaded via yaml cannot have the same name
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
@ -1446,6 +1452,17 @@ class Persona(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PersonaCategory(Base):
|
||||||
|
__tablename__ = "persona_category"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String, unique=True)
|
||||||
|
description: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||||
|
personas: Mapped[list["Persona"]] = relationship(
|
||||||
|
"Persona", back_populates="category"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
AllowedAnswerFilters = (
|
AllowedAnswerFilters = (
|
||||||
Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
|
Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
|
||||||
)
|
)
|
||||||
|
@ -26,6 +26,7 @@ from danswer.db.models import DocumentSet
|
|||||||
from danswer.db.models import Persona
|
from danswer.db.models import Persona
|
||||||
from danswer.db.models import Persona__User
|
from danswer.db.models import Persona__User
|
||||||
from danswer.db.models import Persona__UserGroup
|
from danswer.db.models import Persona__UserGroup
|
||||||
|
from danswer.db.models import PersonaCategory
|
||||||
from danswer.db.models import Prompt
|
from danswer.db.models import Prompt
|
||||||
from danswer.db.models import StarterMessage
|
from danswer.db.models import StarterMessage
|
||||||
from danswer.db.models import Tool
|
from danswer.db.models import Tool
|
||||||
@ -417,6 +418,7 @@ def upsert_persona(
|
|||||||
search_start_date: datetime | None = None,
|
search_start_date: datetime | None = None,
|
||||||
builtin_persona: bool = False,
|
builtin_persona: bool = False,
|
||||||
is_default_persona: bool = False,
|
is_default_persona: bool = False,
|
||||||
|
category_id: int | None = None,
|
||||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||||
) -> Persona:
|
) -> Persona:
|
||||||
@ -487,7 +489,7 @@ def upsert_persona(
|
|||||||
persona.is_visible = is_visible
|
persona.is_visible = is_visible
|
||||||
persona.search_start_date = search_start_date
|
persona.search_start_date = search_start_date
|
||||||
persona.is_default_persona = is_default_persona
|
persona.is_default_persona = is_default_persona
|
||||||
|
persona.category_id = category_id
|
||||||
# Do not delete any associations manually added unless
|
# Do not delete any associations manually added unless
|
||||||
# a new updated list is provided
|
# a new updated list is provided
|
||||||
if document_sets is not None:
|
if document_sets is not None:
|
||||||
@ -528,6 +530,7 @@ def upsert_persona(
|
|||||||
is_visible=is_visible,
|
is_visible=is_visible,
|
||||||
search_start_date=search_start_date,
|
search_start_date=search_start_date,
|
||||||
is_default_persona=is_default_persona,
|
is_default_persona=is_default_persona,
|
||||||
|
category_id=category_id,
|
||||||
)
|
)
|
||||||
db_session.add(persona)
|
db_session.add(persona)
|
||||||
|
|
||||||
@ -744,3 +747,39 @@ def delete_persona_by_name(
|
|||||||
|
|
||||||
db_session.execute(stmt)
|
db_session.execute(stmt)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def get_assistant_categories(db_session: Session) -> list[PersonaCategory]:
|
||||||
|
return db_session.query(PersonaCategory).all()
|
||||||
|
|
||||||
|
|
||||||
|
def create_assistant_category(
|
||||||
|
db_session: Session, name: str, description: str
|
||||||
|
) -> PersonaCategory:
|
||||||
|
category = PersonaCategory(name=name, description=description)
|
||||||
|
db_session.add(category)
|
||||||
|
db_session.commit()
|
||||||
|
return category
|
||||||
|
|
||||||
|
|
||||||
|
def update_persona_category(
|
||||||
|
category_id: int,
|
||||||
|
category_description: str,
|
||||||
|
category_name: str,
|
||||||
|
db_session: Session,
|
||||||
|
) -> None:
|
||||||
|
persona_category = (
|
||||||
|
db_session.query(PersonaCategory)
|
||||||
|
.filter(PersonaCategory.id == category_id)
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
|
if persona_category is None:
|
||||||
|
raise ValueError(f"Persona category with ID {category_id} does not exist")
|
||||||
|
persona_category.description = category_description
|
||||||
|
persona_category.name = category_name
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def delete_persona_category(category_id: int, db_session: Session) -> None:
|
||||||
|
db_session.query(PersonaCategory).filter(PersonaCategory.id == category_id).delete()
|
||||||
|
db_session.commit()
|
||||||
|
@ -18,12 +18,16 @@ from danswer.configs.constants import NotificationType
|
|||||||
from danswer.db.engine import get_session
|
from danswer.db.engine import get_session
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.db.notification import create_notification
|
from danswer.db.notification import create_notification
|
||||||
|
from danswer.db.persona import create_assistant_category
|
||||||
from danswer.db.persona import create_update_persona
|
from danswer.db.persona import create_update_persona
|
||||||
|
from danswer.db.persona import delete_persona_category
|
||||||
|
from danswer.db.persona import get_assistant_categories
|
||||||
from danswer.db.persona import get_persona_by_id
|
from danswer.db.persona import get_persona_by_id
|
||||||
from danswer.db.persona import get_personas
|
from danswer.db.persona import get_personas
|
||||||
from danswer.db.persona import mark_persona_as_deleted
|
from danswer.db.persona import mark_persona_as_deleted
|
||||||
from danswer.db.persona import mark_persona_as_not_deleted
|
from danswer.db.persona import mark_persona_as_not_deleted
|
||||||
from danswer.db.persona import update_all_personas_display_priority
|
from danswer.db.persona import update_all_personas_display_priority
|
||||||
|
from danswer.db.persona import update_persona_category
|
||||||
from danswer.db.persona import update_persona_public_status
|
from danswer.db.persona import update_persona_public_status
|
||||||
from danswer.db.persona import update_persona_shared_users
|
from danswer.db.persona import update_persona_shared_users
|
||||||
from danswer.db.persona import update_persona_visibility
|
from danswer.db.persona import update_persona_visibility
|
||||||
@ -32,6 +36,8 @@ from danswer.file_store.models import ChatFileType
|
|||||||
from danswer.llm.answering.prompts.utils import build_dummy_prompt
|
from danswer.llm.answering.prompts.utils import build_dummy_prompt
|
||||||
from danswer.server.features.persona.models import CreatePersonaRequest
|
from danswer.server.features.persona.models import CreatePersonaRequest
|
||||||
from danswer.server.features.persona.models import ImageGenerationToolStatus
|
from danswer.server.features.persona.models import ImageGenerationToolStatus
|
||||||
|
from danswer.server.features.persona.models import PersonaCategoryCreate
|
||||||
|
from danswer.server.features.persona.models import PersonaCategoryResponse
|
||||||
from danswer.server.features.persona.models import PersonaSharedNotificationData
|
from danswer.server.features.persona.models import PersonaSharedNotificationData
|
||||||
from danswer.server.features.persona.models import PersonaSnapshot
|
from danswer.server.features.persona.models import PersonaSnapshot
|
||||||
from danswer.server.features.persona.models import PromptTemplateResponse
|
from danswer.server.features.persona.models import PromptTemplateResponse
|
||||||
@ -39,6 +45,7 @@ from danswer.server.models import DisplayPriorityRequest
|
|||||||
from danswer.tools.utils import is_image_generation_available
|
from danswer.tools.utils import is_image_generation_available
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
@ -184,6 +191,59 @@ def update_persona(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PersonaCategoryPatchRequest(BaseModel):
|
||||||
|
category_description: str
|
||||||
|
category_name: str
|
||||||
|
|
||||||
|
|
||||||
|
@basic_router.get("/categories")
|
||||||
|
def get_categories(
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
_: User | None = Depends(current_user),
|
||||||
|
) -> list[PersonaCategoryResponse]:
|
||||||
|
return [
|
||||||
|
PersonaCategoryResponse.from_model(category)
|
||||||
|
for category in get_assistant_categories(db_session=db)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/categories")
|
||||||
|
def create_category(
|
||||||
|
category: PersonaCategoryCreate,
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
_: User | None = Depends(current_admin_user),
|
||||||
|
) -> PersonaCategoryResponse:
|
||||||
|
"""Create a new assistant category"""
|
||||||
|
category_model = create_assistant_category(
|
||||||
|
name=category.name, description=category.description, db_session=db
|
||||||
|
)
|
||||||
|
return PersonaCategoryResponse.from_model(category_model)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.patch("/category/{category_id}")
|
||||||
|
def patch_persona_category(
|
||||||
|
category_id: int,
|
||||||
|
persona_category_patch_request: PersonaCategoryPatchRequest,
|
||||||
|
_: User | None = Depends(current_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> None:
|
||||||
|
update_persona_category(
|
||||||
|
category_id=category_id,
|
||||||
|
category_description=persona_category_patch_request.category_description,
|
||||||
|
category_name=persona_category_patch_request.category_name,
|
||||||
|
db_session=db_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.delete("/category/{category_id}")
|
||||||
|
def delete_category(
|
||||||
|
category_id: int,
|
||||||
|
_: User | None = Depends(current_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> None:
|
||||||
|
delete_persona_category(category_id=category_id, db_session=db_session)
|
||||||
|
|
||||||
|
|
||||||
class PersonaShareRequest(BaseModel):
|
class PersonaShareRequest(BaseModel):
|
||||||
user_ids: list[UUID]
|
user_ids: list[UUID]
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ from pydantic import BaseModel
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from danswer.db.models import Persona
|
from danswer.db.models import Persona
|
||||||
|
from danswer.db.models import PersonaCategory
|
||||||
from danswer.db.models import StarterMessage
|
from danswer.db.models import StarterMessage
|
||||||
from danswer.search.enums import RecencyBiasSetting
|
from danswer.search.enums import RecencyBiasSetting
|
||||||
from danswer.server.features.document_set.models import DocumentSet
|
from danswer.server.features.document_set.models import DocumentSet
|
||||||
@ -41,6 +42,7 @@ class CreatePersonaRequest(BaseModel):
|
|||||||
is_default_persona: bool = False
|
is_default_persona: bool = False
|
||||||
display_priority: int | None = None
|
display_priority: int | None = None
|
||||||
search_start_date: datetime | None = None
|
search_start_date: datetime | None = None
|
||||||
|
category_id: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class PersonaSnapshot(BaseModel):
|
class PersonaSnapshot(BaseModel):
|
||||||
@ -68,6 +70,7 @@ class PersonaSnapshot(BaseModel):
|
|||||||
uploaded_image_id: str | None = None
|
uploaded_image_id: str | None = None
|
||||||
is_default_persona: bool
|
is_default_persona: bool
|
||||||
search_start_date: datetime | None = None
|
search_start_date: datetime | None = None
|
||||||
|
category_id: int | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_model(
|
def from_model(
|
||||||
@ -115,6 +118,7 @@ class PersonaSnapshot(BaseModel):
|
|||||||
icon_shape=persona.icon_shape,
|
icon_shape=persona.icon_shape,
|
||||||
uploaded_image_id=persona.uploaded_image_id,
|
uploaded_image_id=persona.uploaded_image_id,
|
||||||
search_start_date=persona.search_start_date,
|
search_start_date=persona.search_start_date,
|
||||||
|
category_id=persona.category_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -128,3 +132,22 @@ class PersonaSharedNotificationData(BaseModel):
|
|||||||
|
|
||||||
class ImageGenerationToolStatus(BaseModel):
|
class ImageGenerationToolStatus(BaseModel):
|
||||||
is_available: bool
|
is_available: bool
|
||||||
|
|
||||||
|
|
||||||
|
class PersonaCategoryCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class PersonaCategoryResponse(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
description: str | None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_model(cls, category: PersonaCategory) -> "PersonaCategoryResponse":
|
||||||
|
return PersonaCategoryResponse(
|
||||||
|
id=category.id,
|
||||||
|
name=category.name,
|
||||||
|
description=category.description,
|
||||||
|
)
|
||||||
|
@ -7,6 +7,7 @@ from danswer.server.features.persona.models import PersonaSnapshot
|
|||||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||||
from tests.integration.common_utils.test_models import DATestPersona
|
from tests.integration.common_utils.test_models import DATestPersona
|
||||||
|
from tests.integration.common_utils.test_models import DATestPersonaCategory
|
||||||
from tests.integration.common_utils.test_models import DATestUser
|
from tests.integration.common_utils.test_models import DATestUser
|
||||||
|
|
||||||
|
|
||||||
@ -27,6 +28,7 @@ class PersonaManager:
|
|||||||
llm_model_version_override: str | None = None,
|
llm_model_version_override: str | None = None,
|
||||||
users: list[str] | None = None,
|
users: list[str] | None = None,
|
||||||
groups: list[int] | None = None,
|
groups: list[int] | None = None,
|
||||||
|
category_id: int | None = None,
|
||||||
user_performing_action: DATestUser | None = None,
|
user_performing_action: DATestUser | None = None,
|
||||||
) -> DATestPersona:
|
) -> DATestPersona:
|
||||||
name = name or f"test-persona-{uuid4()}"
|
name = name or f"test-persona-{uuid4()}"
|
||||||
@ -212,3 +214,83 @@ class PersonaManager:
|
|||||||
else GENERAL_HEADERS,
|
else GENERAL_HEADERS,
|
||||||
)
|
)
|
||||||
return response.ok
|
return response.ok
|
||||||
|
|
||||||
|
|
||||||
|
class PersonaCategoryManager:
|
||||||
|
@staticmethod
|
||||||
|
def create(
|
||||||
|
category: DATestPersonaCategory,
|
||||||
|
user_performing_action: DATestUser | None = None,
|
||||||
|
) -> DATestPersonaCategory:
|
||||||
|
response = requests.post(
|
||||||
|
f"{API_SERVER_URL}/admin/persona/categories",
|
||||||
|
json={
|
||||||
|
"name": category.name,
|
||||||
|
"description": category.description,
|
||||||
|
},
|
||||||
|
headers=user_performing_action.headers
|
||||||
|
if user_performing_action
|
||||||
|
else GENERAL_HEADERS,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
response_data = response.json()
|
||||||
|
category.id = response_data["id"]
|
||||||
|
return category
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all(
|
||||||
|
user_performing_action: DATestUser | None = None,
|
||||||
|
) -> list[DATestPersonaCategory]:
|
||||||
|
response = requests.get(
|
||||||
|
f"{API_SERVER_URL}/persona/categories",
|
||||||
|
headers=user_performing_action.headers
|
||||||
|
if user_performing_action
|
||||||
|
else GENERAL_HEADERS,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return [DATestPersonaCategory(**category) for category in response.json()]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update(
|
||||||
|
category: DATestPersonaCategory,
|
||||||
|
user_performing_action: DATestUser | None = None,
|
||||||
|
) -> DATestPersonaCategory:
|
||||||
|
response = requests.patch(
|
||||||
|
f"{API_SERVER_URL}/admin/persona/category/{category.id}",
|
||||||
|
json={
|
||||||
|
"category_name": category.name,
|
||||||
|
"category_description": category.description,
|
||||||
|
},
|
||||||
|
headers=user_performing_action.headers
|
||||||
|
if user_performing_action
|
||||||
|
else GENERAL_HEADERS,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return category
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete(
|
||||||
|
category: DATestPersonaCategory,
|
||||||
|
user_performing_action: DATestUser | None = None,
|
||||||
|
) -> bool:
|
||||||
|
response = requests.delete(
|
||||||
|
f"{API_SERVER_URL}/admin/persona/category/{category.id}",
|
||||||
|
headers=user_performing_action.headers
|
||||||
|
if user_performing_action
|
||||||
|
else GENERAL_HEADERS,
|
||||||
|
)
|
||||||
|
return response.ok
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def verify(
|
||||||
|
category: DATestPersonaCategory,
|
||||||
|
user_performing_action: DATestUser | None = None,
|
||||||
|
) -> bool:
|
||||||
|
all_categories = PersonaCategoryManager.get_all(user_performing_action)
|
||||||
|
for fetched_category in all_categories:
|
||||||
|
if fetched_category.id == category.id:
|
||||||
|
return (
|
||||||
|
fetched_category.name == category.name
|
||||||
|
and fetched_category.description == category.description
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
@ -38,6 +38,12 @@ class DATestUser(BaseModel):
|
|||||||
headers: dict
|
headers: dict
|
||||||
|
|
||||||
|
|
||||||
|
class DATestPersonaCategory(BaseModel):
|
||||||
|
id: int | None = None
|
||||||
|
name: str
|
||||||
|
description: str | None
|
||||||
|
|
||||||
|
|
||||||
class DATestCredential(BaseModel):
|
class DATestCredential(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
@ -119,6 +125,7 @@ class DATestPersona(BaseModel):
|
|||||||
llm_model_version_override: str | None
|
llm_model_version_override: str | None
|
||||||
users: list[str]
|
users: list[str]
|
||||||
groups: list[int]
|
groups: list[int]
|
||||||
|
category_id: int | None = None
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
|
@ -0,0 +1,92 @@
|
|||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
from tests.integration.common_utils.managers.persona import (
|
||||||
|
PersonaCategoryManager,
|
||||||
|
)
|
||||||
|
from tests.integration.common_utils.managers.user import UserManager
|
||||||
|
from tests.integration.common_utils.test_models import DATestPersonaCategory
|
||||||
|
from tests.integration.common_utils.test_models import DATestUser
|
||||||
|
|
||||||
|
|
||||||
|
def test_persona_category_management(reset: None) -> None:
|
||||||
|
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||||
|
|
||||||
|
persona_category = DATestPersonaCategory(
|
||||||
|
id=None,
|
||||||
|
name=f"Test Category {uuid4()}",
|
||||||
|
description="A description for test category",
|
||||||
|
)
|
||||||
|
persona_category = PersonaCategoryManager.create(
|
||||||
|
category=persona_category,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Created persona category {persona_category.name} with id {persona_category.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert PersonaCategoryManager.verify(
|
||||||
|
category=persona_category,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
), "Persona category was not found after creation"
|
||||||
|
|
||||||
|
regular_user: DATestUser = UserManager.create(name="regular_user")
|
||||||
|
|
||||||
|
updated_persona_category = DATestPersonaCategory(
|
||||||
|
id=persona_category.id,
|
||||||
|
name=f"Updated {persona_category.name}",
|
||||||
|
description="An updated description",
|
||||||
|
)
|
||||||
|
with pytest.raises(HTTPError) as exc_info:
|
||||||
|
PersonaCategoryManager.update(
|
||||||
|
category=updated_persona_category,
|
||||||
|
user_performing_action=regular_user,
|
||||||
|
)
|
||||||
|
assert exc_info.value.response.status_code == 403
|
||||||
|
|
||||||
|
assert PersonaCategoryManager.verify(
|
||||||
|
category=persona_category,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
), "Persona category should not have been updated by non-admin user"
|
||||||
|
|
||||||
|
result = PersonaCategoryManager.delete(
|
||||||
|
category=persona_category,
|
||||||
|
user_performing_action=regular_user,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
result is False
|
||||||
|
), "Regular user should not be able to delete the persona category"
|
||||||
|
|
||||||
|
assert PersonaCategoryManager.verify(
|
||||||
|
category=persona_category,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
), "Persona category should not have been deleted by non-admin user"
|
||||||
|
|
||||||
|
updated_persona_category.name = f"Updated {persona_category.name}"
|
||||||
|
updated_persona_category.description = "An updated description"
|
||||||
|
updated_persona_category = PersonaCategoryManager.update(
|
||||||
|
category=updated_persona_category,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
print(f"Updated persona category to {updated_persona_category.name}")
|
||||||
|
|
||||||
|
assert PersonaCategoryManager.verify(
|
||||||
|
category=updated_persona_category,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
), "Persona category was not updated by admin"
|
||||||
|
|
||||||
|
success = PersonaCategoryManager.delete(
|
||||||
|
category=persona_category,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
assert success, "Admin user should be able to delete the persona category"
|
||||||
|
print(
|
||||||
|
f"Deleted persona category {persona_category.name} with id {persona_category.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not PersonaCategoryManager.verify(
|
||||||
|
category=persona_category,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
), "Persona category should not exist after deletion by admin"
|
@ -23,8 +23,16 @@ import {
|
|||||||
SelectorFormField,
|
SelectorFormField,
|
||||||
TextFormField,
|
TextFormField,
|
||||||
} from "@/components/admin/connectors/Field";
|
} from "@/components/admin/connectors/Field";
|
||||||
|
|
||||||
|
import {
|
||||||
|
Card,
|
||||||
|
CardHeader,
|
||||||
|
CardTitle,
|
||||||
|
CardContent,
|
||||||
|
CardFooter,
|
||||||
|
} from "@/components/ui/card";
|
||||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||||
import { getDisplayNameForModel } from "@/lib/hooks";
|
import { getDisplayNameForModel, useCategories } from "@/lib/hooks";
|
||||||
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
|
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
|
||||||
import { Option } from "@/components/Dropdown";
|
import { Option } from "@/components/Dropdown";
|
||||||
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
|
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
|
||||||
@ -46,8 +54,14 @@ import * as Yup from "yup";
|
|||||||
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
||||||
import CollapsibleSection from "./CollapsibleSection";
|
import CollapsibleSection from "./CollapsibleSection";
|
||||||
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
|
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
|
||||||
import { Persona, StarterMessage } from "./interfaces";
|
import { Persona, PersonaCategory, StarterMessage } from "./interfaces";
|
||||||
import { createPersona, updatePersona } from "./lib";
|
import {
|
||||||
|
createPersonaCategory,
|
||||||
|
createPersona,
|
||||||
|
deletePersonaCategory,
|
||||||
|
updatePersonaCategory,
|
||||||
|
updatePersona,
|
||||||
|
} from "./lib";
|
||||||
import { Popover } from "@/components/popover/Popover";
|
import { Popover } from "@/components/popover/Popover";
|
||||||
import {
|
import {
|
||||||
CameraIcon,
|
CameraIcon,
|
||||||
@ -59,6 +73,8 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
|||||||
import { buildImgUrl } from "@/app/chat/files/images/utils";
|
import { buildImgUrl } from "@/app/chat/files/images/utils";
|
||||||
import { LlmList } from "@/components/llm/LLMList";
|
import { LlmList } from "@/components/llm/LLMList";
|
||||||
import { useAssistants } from "@/components/context/AssistantsContext";
|
import { useAssistants } from "@/components/context/AssistantsContext";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { CategoryCard } from "./CategoryCard";
|
||||||
|
|
||||||
function findSearchTool(tools: ToolSnapshot[]) {
|
function findSearchTool(tools: ToolSnapshot[]) {
|
||||||
return tools.find((tool) => tool.in_code_tool_id === "SearchTool");
|
return tools.find((tool) => tool.in_code_tool_id === "SearchTool");
|
||||||
@ -107,6 +123,7 @@ export function AssistantEditor({
|
|||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
|
|
||||||
const { popup, setPopup } = usePopup();
|
const { popup, setPopup } = usePopup();
|
||||||
|
const { data: categories, refreshCategories } = useCategories();
|
||||||
|
|
||||||
const colorOptions = [
|
const colorOptions = [
|
||||||
"#FF6FBF",
|
"#FF6FBF",
|
||||||
@ -119,6 +136,7 @@ export function AssistantEditor({
|
|||||||
];
|
];
|
||||||
|
|
||||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||||
|
const [showPersonaCategory, setShowPersonaCategory] = useState(!admin);
|
||||||
|
|
||||||
// state to persist across formik reformatting
|
// state to persist across formik reformatting
|
||||||
const [defautIconColor, _setDeafultIconColor] = useState(
|
const [defautIconColor, _setDeafultIconColor] = useState(
|
||||||
@ -211,6 +229,7 @@ export function AssistantEditor({
|
|||||||
icon_color: existingPersona?.icon_color ?? defautIconColor,
|
icon_color: existingPersona?.icon_color ?? defautIconColor,
|
||||||
icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
|
icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
|
||||||
uploaded_image: null,
|
uploaded_image: null,
|
||||||
|
category_id: existingPersona?.category_id ?? null,
|
||||||
|
|
||||||
// EE Only
|
// EE Only
|
||||||
groups: existingPersona?.groups ?? [],
|
groups: existingPersona?.groups ?? [],
|
||||||
@ -255,6 +274,7 @@ export function AssistantEditor({
|
|||||||
icon_color: Yup.string(),
|
icon_color: Yup.string(),
|
||||||
icon_shape: Yup.number(),
|
icon_shape: Yup.number(),
|
||||||
uploaded_image: Yup.mixed().nullable(),
|
uploaded_image: Yup.mixed().nullable(),
|
||||||
|
category_id: Yup.number().nullable(),
|
||||||
// EE Only
|
// EE Only
|
||||||
groups: Yup.array().of(Yup.number()),
|
groups: Yup.array().of(Yup.number()),
|
||||||
})
|
})
|
||||||
@ -968,6 +988,189 @@ export function AssistantEditor({
|
|||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{admin && (
|
||||||
|
<AdvancedOptionsToggle
|
||||||
|
title="Categories"
|
||||||
|
showAdvancedOptions={showPersonaCategory}
|
||||||
|
setShowAdvancedOptions={setShowPersonaCategory}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{showPersonaCategory && (
|
||||||
|
<>
|
||||||
|
{categories && categories.length > 0 && (
|
||||||
|
<div className="my-2">
|
||||||
|
<div className="flex gap-x-2 items-center">
|
||||||
|
<div className="block font-medium text-base">
|
||||||
|
Category
|
||||||
|
</div>
|
||||||
|
<TooltipProvider>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger>
|
||||||
|
<FiInfo size={12} />
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent side="top" align="center">
|
||||||
|
Group similar assistants together by category
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
</TooltipProvider>
|
||||||
|
</div>
|
||||||
|
<SelectorFormField
|
||||||
|
includeReset
|
||||||
|
name="category_id"
|
||||||
|
options={categories.map((category) => ({
|
||||||
|
name: category.name,
|
||||||
|
value: category.id,
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{admin && (
|
||||||
|
<>
|
||||||
|
<div className="my-2">
|
||||||
|
<div className="flex gap-x-2 items-center mb-2">
|
||||||
|
<div className="block font-medium text-base">
|
||||||
|
Create New Category
|
||||||
|
</div>
|
||||||
|
<TooltipProvider>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger>
|
||||||
|
<FiInfo size={12} />
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent side="top" align="center">
|
||||||
|
Create a new category to group similar
|
||||||
|
assistants together
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
</TooltipProvider>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="grid grid-cols-[1fr,3fr,auto] gap-4">
|
||||||
|
<TextFormField
|
||||||
|
fontSize="sm"
|
||||||
|
name="newCategoryName"
|
||||||
|
label="Category Name"
|
||||||
|
placeholder="e.g. Development"
|
||||||
|
/>
|
||||||
|
<TextFormField
|
||||||
|
fontSize="sm"
|
||||||
|
name="newCategoryDescription"
|
||||||
|
label="Category Description"
|
||||||
|
placeholder="e.g. Assistants for software development"
|
||||||
|
/>
|
||||||
|
<div className="flex items-end">
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
onClick={async () => {
|
||||||
|
const name = values.newCategoryName;
|
||||||
|
const description =
|
||||||
|
values.newCategoryDescription;
|
||||||
|
if (!name || !description) return;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await createPersonaCategory(
|
||||||
|
name,
|
||||||
|
description
|
||||||
|
);
|
||||||
|
if (response.ok) {
|
||||||
|
setPopup({
|
||||||
|
message: `Category "${name}" created successfully`,
|
||||||
|
type: "success",
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
throw new Error(await response.text());
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
setPopup({
|
||||||
|
message: `Failed to create category - ${error}`,
|
||||||
|
type: "error",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
await refreshCategories();
|
||||||
|
|
||||||
|
setFieldValue("newCategoryName", "");
|
||||||
|
setFieldValue("newCategoryDescription", "");
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Create
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{categories && categories.length > 0 && (
|
||||||
|
<div className="my-2 w-full">
|
||||||
|
<div className="flex gap-x-2 items-center mb-2">
|
||||||
|
<div className="block font-medium text-base">
|
||||||
|
Manage categories
|
||||||
|
</div>
|
||||||
|
<TooltipProvider>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger>
|
||||||
|
<FiInfo size={12} />
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent side="top" align="center">
|
||||||
|
Manage existing categories or create new ones
|
||||||
|
to group similar assistants
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
</TooltipProvider>
|
||||||
|
</div>
|
||||||
|
<div className="gap-4 w-full flex-wrap flex">
|
||||||
|
{categories &&
|
||||||
|
categories.map((category: PersonaCategory) => (
|
||||||
|
<CategoryCard
|
||||||
|
setPopup={setPopup}
|
||||||
|
key={category.id}
|
||||||
|
category={category}
|
||||||
|
onUpdate={async (id, name, description) => {
|
||||||
|
const response =
|
||||||
|
await updatePersonaCategory(
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
description
|
||||||
|
);
|
||||||
|
if (response?.ok) {
|
||||||
|
setPopup({
|
||||||
|
message: `Category "${name}" updated successfully`,
|
||||||
|
type: "success",
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
setPopup({
|
||||||
|
message: `Failed to update category - ${await response.text()}`,
|
||||||
|
type: "error",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
onDelete={async (id) => {
|
||||||
|
const response =
|
||||||
|
await deletePersonaCategory(id);
|
||||||
|
if (response?.ok) {
|
||||||
|
setPopup({
|
||||||
|
message: `Category deleted successfully`,
|
||||||
|
type: "success",
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
setPopup({
|
||||||
|
message: `Failed to delete category - ${await response.text()}`,
|
||||||
|
type: "error",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
refreshCategories={refreshCategories}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
<Separator />
|
<Separator />
|
||||||
<AdvancedOptionsToggle
|
<AdvancedOptionsToggle
|
||||||
showAdvancedOptions={showAdvancedOptions}
|
showAdvancedOptions={showAdvancedOptions}
|
||||||
|
105
web/src/app/admin/assistants/CategoryCard.tsx
Normal file
105
web/src/app/admin/assistants/CategoryCard.tsx
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
import { useState } from "react";
|
||||||
|
import {
|
||||||
|
Card,
|
||||||
|
CardHeader,
|
||||||
|
CardTitle,
|
||||||
|
CardContent,
|
||||||
|
CardFooter,
|
||||||
|
} from "@/components/ui/card";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { Textarea } from "@/components/ui/textarea";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { PersonaCategory } from "./interfaces";
|
||||||
|
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||||
|
|
||||||
|
interface CategoryCardProps {
|
||||||
|
category: PersonaCategory;
|
||||||
|
onUpdate: (id: number, name: string, description: string) => void;
|
||||||
|
onDelete: (id: number) => void;
|
||||||
|
refreshCategories: () => Promise<void>;
|
||||||
|
setPopup: (popup: PopupSpec) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function CategoryCard({
|
||||||
|
category,
|
||||||
|
onUpdate,
|
||||||
|
onDelete,
|
||||||
|
refreshCategories,
|
||||||
|
}: CategoryCardProps) {
|
||||||
|
const [isEditing, setIsEditing] = useState(false);
|
||||||
|
const [name, setName] = useState(category.name);
|
||||||
|
const [description, setDescription] = useState(category.description);
|
||||||
|
|
||||||
|
const handleSubmit = async (e: React.FormEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
await onUpdate(category.id, name, description);
|
||||||
|
await refreshCategories();
|
||||||
|
setIsEditing(false);
|
||||||
|
};
|
||||||
|
const handleEdit = (e: React.MouseEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
setIsEditing(true);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Card key={category.id} className="w-full max-w-sm">
|
||||||
|
<CardHeader className="w-full">
|
||||||
|
<CardTitle className="text-2xl font-bold">
|
||||||
|
{isEditing ? (
|
||||||
|
<Input
|
||||||
|
value={name}
|
||||||
|
onChange={(e) => setName(e.target.value)}
|
||||||
|
className="text-lg font-semibold"
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<span>{category.name}</span>
|
||||||
|
)}
|
||||||
|
</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent className="w-full">
|
||||||
|
{isEditing ? (
|
||||||
|
<Textarea
|
||||||
|
value={description}
|
||||||
|
onChange={(e) => setDescription(e.target.value)}
|
||||||
|
className="resize-none w-full"
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<p className="text-sm text-gray-600">{category.description}</p>
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
<CardFooter className="flex justify-end space-x-2">
|
||||||
|
{isEditing ? (
|
||||||
|
<>
|
||||||
|
<Button type="button" variant="outline" onClick={handleSubmit}>
|
||||||
|
Save
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setIsEditing(false)}
|
||||||
|
variant="default"
|
||||||
|
>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<Button type="button" onClick={handleEdit} variant="outline">
|
||||||
|
Edit
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="destructive"
|
||||||
|
onClick={async (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
await onDelete(category.id);
|
||||||
|
await refreshCategories();
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Delete
|
||||||
|
</Button>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</CardFooter>
|
||||||
|
</Card>
|
||||||
|
);
|
||||||
|
}
|
@ -42,4 +42,11 @@ export interface Persona {
|
|||||||
icon_shape?: number;
|
icon_shape?: number;
|
||||||
icon_color?: string;
|
icon_color?: string;
|
||||||
uploaded_image_id?: string;
|
uploaded_image_id?: string;
|
||||||
|
category_id?: number | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PersonaCategory {
|
||||||
|
id: number;
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,7 @@ interface PersonaCreationRequest {
|
|||||||
uploaded_image: File | null;
|
uploaded_image: File | null;
|
||||||
search_start_date: Date | null;
|
search_start_date: Date | null;
|
||||||
is_default_persona: boolean;
|
is_default_persona: boolean;
|
||||||
|
category_id: number | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface PersonaUpdateRequest {
|
interface PersonaUpdateRequest {
|
||||||
@ -48,6 +49,7 @@ interface PersonaUpdateRequest {
|
|||||||
remove_image: boolean;
|
remove_image: boolean;
|
||||||
uploaded_image: File | null;
|
uploaded_image: File | null;
|
||||||
search_start_date: Date | null;
|
search_start_date: Date | null;
|
||||||
|
category_id: number | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
function promptNameFromPersonaName(personaName: string) {
|
function promptNameFromPersonaName(personaName: string) {
|
||||||
@ -108,6 +110,42 @@ function updatePrompt({
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const createPersonaCategory = (name: string, description: string) => {
|
||||||
|
return fetch("/api/admin/persona/categories", {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ name, description }),
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const deletePersonaCategory = (categoryId: number) => {
|
||||||
|
return fetch(`/api/admin/persona/category/${categoryId}`, {
|
||||||
|
method: "DELETE",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const updatePersonaCategory = (
|
||||||
|
id: number,
|
||||||
|
name: string,
|
||||||
|
description: string
|
||||||
|
) => {
|
||||||
|
return fetch(`/api/admin/persona/category/${id}`, {
|
||||||
|
method: "PATCH",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
category_name: name,
|
||||||
|
category_description: description,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
function buildPersonaAPIBody(
|
function buildPersonaAPIBody(
|
||||||
creationRequest: PersonaCreationRequest | PersonaUpdateRequest,
|
creationRequest: PersonaCreationRequest | PersonaUpdateRequest,
|
||||||
promptId: number,
|
promptId: number,
|
||||||
@ -127,6 +165,7 @@ function buildPersonaAPIBody(
|
|||||||
icon_shape,
|
icon_shape,
|
||||||
remove_image,
|
remove_image,
|
||||||
search_start_date,
|
search_start_date,
|
||||||
|
category_id,
|
||||||
} = creationRequest;
|
} = creationRequest;
|
||||||
|
|
||||||
const is_default_persona =
|
const is_default_persona =
|
||||||
@ -156,6 +195,7 @@ function buildPersonaAPIBody(
|
|||||||
remove_image,
|
remove_image,
|
||||||
search_start_date,
|
search_start_date,
|
||||||
is_default_persona,
|
is_default_persona,
|
||||||
|
category_id,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +108,7 @@ function ToolForm({
|
|||||||
placeholder="Enter your OpenAPI schema here"
|
placeholder="Enter your OpenAPI schema here"
|
||||||
isTextArea={true}
|
isTextArea={true}
|
||||||
defaultHeight="h-96"
|
defaultHeight="h-96"
|
||||||
fontSize="text-sm"
|
fontSize="sm"
|
||||||
isCode
|
isCode
|
||||||
hideError
|
hideError
|
||||||
/>
|
/>
|
||||||
|
28
web/src/app/assistants/PersonaCategory.tsx
Normal file
28
web/src/app/assistants/PersonaCategory.tsx
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
import { PersonaCategory as PersonaCategoryType } from "../admin/assistants/interfaces";
|
||||||
|
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
TooltipContent,
|
||||||
|
TooltipProvider,
|
||||||
|
TooltipTrigger,
|
||||||
|
} from "@/components/ui/tooltip";
|
||||||
|
|
||||||
|
export default function PersonaCategory({
|
||||||
|
personaCategory,
|
||||||
|
}: {
|
||||||
|
personaCategory: PersonaCategoryType;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<TooltipProvider>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger className="cursor-help">
|
||||||
|
<Badge variant="purple">{personaCategory.name}</Badge>
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent>
|
||||||
|
<p>{personaCategory.description}</p>
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
</TooltipProvider>
|
||||||
|
);
|
||||||
|
}
|
@ -1,6 +1,9 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
import {
|
||||||
|
Persona,
|
||||||
|
PersonaCategory as PersonaCategoryType,
|
||||||
|
} from "@/app/admin/assistants/interfaces";
|
||||||
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
||||||
import { User } from "@/lib/types";
|
import { User } from "@/lib/types";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
@ -17,6 +20,18 @@ import { AssistantTools } from "../ToolsDisplay";
|
|||||||
import { classifyAssistants } from "@/lib/assistants/utils";
|
import { classifyAssistants } from "@/lib/assistants/utils";
|
||||||
import { useAssistants } from "@/components/context/AssistantsContext";
|
import { useAssistants } from "@/components/context/AssistantsContext";
|
||||||
import { useUser } from "@/components/user/UserProvider";
|
import { useUser } from "@/components/user/UserProvider";
|
||||||
|
import PersonaCategory from "../PersonaCategory";
|
||||||
|
import { useCategories } from "@/lib/hooks";
|
||||||
|
import {
|
||||||
|
Select,
|
||||||
|
SelectContent,
|
||||||
|
SelectGroup,
|
||||||
|
SelectItem,
|
||||||
|
SelectLabel,
|
||||||
|
SelectTrigger,
|
||||||
|
SelectValue,
|
||||||
|
} from "@/components/ui/select";
|
||||||
|
|
||||||
export function AssistantGalleryCard({
|
export function AssistantGalleryCard({
|
||||||
assistant,
|
assistant,
|
||||||
user,
|
user,
|
||||||
@ -28,8 +43,10 @@ export function AssistantGalleryCard({
|
|||||||
setPopup: (popup: PopupSpec) => void;
|
setPopup: (popup: PopupSpec) => void;
|
||||||
selectedAssistant: boolean;
|
selectedAssistant: boolean;
|
||||||
}) {
|
}) {
|
||||||
|
const { data: categories } = useCategories();
|
||||||
|
|
||||||
const { refreshUser } = useUser();
|
const { refreshUser } = useUser();
|
||||||
const router = useRouter();
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
key={assistant.id}
|
key={assistant.id}
|
||||||
@ -129,7 +146,6 @@ export function AssistantGalleryCard({
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<p className="text-sm mt-2">{assistant.description}</p>
|
<p className="text-sm mt-2">{assistant.description}</p>
|
||||||
<p className="text-subtle text-sm my-2">
|
<p className="text-subtle text-sm my-2">
|
||||||
Author: {assistant.owner?.email || "Danswer"}
|
Author: {assistant.owner?.email || "Danswer"}
|
||||||
@ -137,6 +153,16 @@ export function AssistantGalleryCard({
|
|||||||
{assistant.tools.length > 0 && (
|
{assistant.tools.length > 0 && (
|
||||||
<AssistantTools list assistant={assistant} />
|
<AssistantTools list assistant={assistant} />
|
||||||
)}
|
)}
|
||||||
|
{assistant.category_id && categories && (
|
||||||
|
<PersonaCategory
|
||||||
|
personaCategory={
|
||||||
|
categories?.find(
|
||||||
|
(category: PersonaCategoryType) =>
|
||||||
|
category.id === assistant.category_id
|
||||||
|
)!
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -144,10 +170,12 @@ export function AssistantsGallery() {
|
|||||||
const { assistants } = useAssistants();
|
const { assistants } = useAssistants();
|
||||||
const { user } = useUser();
|
const { user } = useUser();
|
||||||
|
|
||||||
|
const { data: categories } = useCategories();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
|
|
||||||
const [searchQuery, setSearchQuery] = useState("");
|
const [searchQuery, setSearchQuery] = useState("");
|
||||||
const { popup, setPopup } = usePopup();
|
const { popup, setPopup } = usePopup();
|
||||||
|
const [selectedCategory, setSelectedCategory] = useState<number | null>(null);
|
||||||
|
|
||||||
const { visibleAssistants, hiddenAssistants: _ } = classifyAssistants(
|
const { visibleAssistants, hiddenAssistants: _ } = classifyAssistants(
|
||||||
user,
|
user,
|
||||||
@ -158,16 +186,24 @@ export function AssistantsGallery() {
|
|||||||
.filter((assistant) => assistant.is_default_persona)
|
.filter((assistant) => assistant.is_default_persona)
|
||||||
.filter(
|
.filter(
|
||||||
(assistant) =>
|
(assistant) =>
|
||||||
assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
|
(assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
|
||||||
assistant.description.toLowerCase().includes(searchQuery.toLowerCase())
|
assistant.description
|
||||||
|
.toLowerCase()
|
||||||
|
.includes(searchQuery.toLowerCase())) &&
|
||||||
|
(selectedCategory === null ||
|
||||||
|
selectedCategory === assistant.category_id)
|
||||||
);
|
);
|
||||||
|
|
||||||
const nonDefaultAssistants = assistants
|
const nonDefaultAssistants = assistants
|
||||||
.filter((assistant) => !assistant.is_default_persona)
|
.filter((assistant) => !assistant.is_default_persona)
|
||||||
.filter(
|
.filter(
|
||||||
(assistant) =>
|
(assistant) =>
|
||||||
assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
|
(assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
|
||||||
assistant.description.toLowerCase().includes(searchQuery.toLowerCase())
|
assistant.description
|
||||||
|
.toLowerCase()
|
||||||
|
.includes(searchQuery.toLowerCase())) &&
|
||||||
|
(selectedCategory === null ||
|
||||||
|
selectedCategory === assistant.category_id)
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -196,7 +232,7 @@ export function AssistantsGallery() {
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="mt-4 mb-12">
|
<div className="mt-4 mb-6">
|
||||||
<div className="relative">
|
<div className="relative">
|
||||||
<input
|
<input
|
||||||
type="text"
|
type="text"
|
||||||
@ -238,6 +274,58 @@ export function AssistantsGallery() {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{categories && categories?.length > 0 && (
|
||||||
|
<div className="mb-8">
|
||||||
|
<Select
|
||||||
|
value={selectedCategory?.toString() || "all"}
|
||||||
|
onValueChange={(value) =>
|
||||||
|
setSelectedCategory(value === "all" ? null : parseInt(value))
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<SelectTrigger
|
||||||
|
className="
|
||||||
|
w-[240px]
|
||||||
|
bg-background
|
||||||
|
border-2
|
||||||
|
border-background-strong
|
||||||
|
text-text-500
|
||||||
|
rounded-lg
|
||||||
|
shadow-sm
|
||||||
|
hover:bg-background-emphasis
|
||||||
|
hover:border-primary-500/50
|
||||||
|
hover:text-primary-500
|
||||||
|
transition-all
|
||||||
|
duration-200
|
||||||
|
"
|
||||||
|
>
|
||||||
|
<SelectValue placeholder="Filter by category..." />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent className="bg-background border-background-strong">
|
||||||
|
<SelectGroup>
|
||||||
|
<SelectLabel className="text-sm font-medium text-text-400">
|
||||||
|
Categories
|
||||||
|
</SelectLabel>
|
||||||
|
<SelectItem
|
||||||
|
value="all"
|
||||||
|
className="cursor-pointer hover:bg-background-emphasis"
|
||||||
|
>
|
||||||
|
All Categories
|
||||||
|
</SelectItem>
|
||||||
|
{categories.map((category) => (
|
||||||
|
<SelectItem
|
||||||
|
key={category.id}
|
||||||
|
value={category.id.toString()}
|
||||||
|
className="cursor-pointer hover:bg-background-emphasis"
|
||||||
|
>
|
||||||
|
{category.name}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectGroup>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{defaultAssistants.length == 0 &&
|
{defaultAssistants.length == 0 &&
|
||||||
nonDefaultAssistants.length == 0 &&
|
nonDefaultAssistants.length == 0 &&
|
||||||
assistants.length != 0 && (
|
assistants.length != 0 && (
|
||||||
|
@ -5,24 +5,24 @@ import { FiChevronDown, FiChevronRight } from "react-icons/fi";
|
|||||||
interface AdvancedOptionsToggleProps {
|
interface AdvancedOptionsToggleProps {
|
||||||
showAdvancedOptions: boolean;
|
showAdvancedOptions: boolean;
|
||||||
setShowAdvancedOptions: (show: boolean) => void;
|
setShowAdvancedOptions: (show: boolean) => void;
|
||||||
|
title?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function AdvancedOptionsToggle({
|
export function AdvancedOptionsToggle({
|
||||||
showAdvancedOptions,
|
showAdvancedOptions,
|
||||||
setShowAdvancedOptions,
|
setShowAdvancedOptions,
|
||||||
|
title,
|
||||||
}: AdvancedOptionsToggleProps) {
|
}: AdvancedOptionsToggleProps) {
|
||||||
return (
|
return (
|
||||||
<div>
|
<Button
|
||||||
<Button
|
type="button"
|
||||||
type="button"
|
variant="link"
|
||||||
variant="link"
|
size="sm"
|
||||||
size="sm"
|
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
|
||||||
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
|
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
|
||||||
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
|
className="text-xs !p-0 text-text-950 hover:text-text-500"
|
||||||
className="text-xs text-text-950 hover:text-text-500"
|
>
|
||||||
>
|
{title || "Advanced Options"}
|
||||||
Advanced Options
|
</Button>
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -161,7 +161,7 @@ export function TextFormField({
|
|||||||
error?: string;
|
error?: string;
|
||||||
defaultHeight?: string;
|
defaultHeight?: string;
|
||||||
isCode?: boolean;
|
isCode?: boolean;
|
||||||
fontSize?: "text-sm" | "text-base" | "text-lg";
|
fontSize?: "sm" | "md" | "lg";
|
||||||
hideError?: boolean;
|
hideError?: boolean;
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
explanationText?: string;
|
explanationText?: string;
|
||||||
@ -187,12 +187,36 @@ export function TextFormField({
|
|||||||
onChange(e as React.ChangeEvent<HTMLInputElement>);
|
onChange(e as React.ChangeEvent<HTMLInputElement>);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
const textSizeClasses = {
|
||||||
|
sm: {
|
||||||
|
label: "text-sm",
|
||||||
|
input: "text-sm",
|
||||||
|
placeholder: "text-sm",
|
||||||
|
},
|
||||||
|
md: {
|
||||||
|
label: "text-base",
|
||||||
|
input: "text-base",
|
||||||
|
placeholder: "text-base",
|
||||||
|
},
|
||||||
|
lg: {
|
||||||
|
label: "text-lg",
|
||||||
|
input: "text-lg",
|
||||||
|
placeholder: "text-lg",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const sizeClass = textSizeClasses[fontSize || "md"];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={`w-full ${width}`}>
|
<div className={`w-full ${width}`}>
|
||||||
<div className="flex gap-x-2 items-center">
|
<div className="flex gap-x-2 items-center">
|
||||||
{!removeLabel && (
|
{!removeLabel && (
|
||||||
<Label className="text-text-950" small={small}>
|
<Label
|
||||||
|
className={`${
|
||||||
|
small ? "text-text-950" : "text-text-700 font-normal"
|
||||||
|
} ${sizeClass.label}`}
|
||||||
|
small={small}
|
||||||
|
>
|
||||||
{label}
|
{label}
|
||||||
</Label>
|
</Label>
|
||||||
)}
|
)}
|
||||||
@ -221,7 +245,7 @@ export function TextFormField({
|
|||||||
name={name}
|
name={name}
|
||||||
id={name}
|
id={name}
|
||||||
className={`
|
className={`
|
||||||
${small && "text-sm"}
|
${small && sizeClass.input}
|
||||||
border
|
border
|
||||||
border-border
|
border-border
|
||||||
rounded-md
|
rounded-md
|
||||||
@ -230,10 +254,10 @@ export function TextFormField({
|
|||||||
px-3
|
px-3
|
||||||
mt-1
|
mt-1
|
||||||
placeholder:font-description
|
placeholder:font-description
|
||||||
placeholder:text-base
|
placeholder:${sizeClass.placeholder}
|
||||||
placeholder:text-text-400
|
placeholder:text-text-400
|
||||||
${heightString}
|
${heightString}
|
||||||
${fontSize}
|
${sizeClass.input}
|
||||||
${disabled ? " bg-background-strong" : " bg-white"}
|
${disabled ? " bg-background-strong" : " bg-white"}
|
||||||
${isCode ? " font-mono" : ""}
|
${isCode ? " font-mono" : ""}
|
||||||
`}
|
`}
|
||||||
@ -585,6 +609,7 @@ interface SelectorFormFieldProps {
|
|||||||
onSelect?: (selected: string | number | null) => void;
|
onSelect?: (selected: string | number | null) => void;
|
||||||
defaultValue?: string;
|
defaultValue?: string;
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
|
includeReset?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function SelectorFormField({
|
export function SelectorFormField({
|
||||||
@ -597,6 +622,7 @@ export function SelectorFormField({
|
|||||||
onSelect,
|
onSelect,
|
||||||
defaultValue,
|
defaultValue,
|
||||||
tooltip,
|
tooltip,
|
||||||
|
includeReset = false,
|
||||||
}: SelectorFormFieldProps) {
|
}: SelectorFormFieldProps) {
|
||||||
const [field] = useField<string>(name);
|
const [field] = useField<string>(name);
|
||||||
const { setFieldValue } = useFormikContext();
|
const { setFieldValue } = useFormikContext();
|
||||||
@ -619,7 +645,11 @@ export function SelectorFormField({
|
|||||||
<Select
|
<Select
|
||||||
value={field.value || defaultValue}
|
value={field.value || defaultValue}
|
||||||
onValueChange={
|
onValueChange={
|
||||||
onSelect || ((selected) => setFieldValue(name, selected))
|
onSelect ||
|
||||||
|
((selected) =>
|
||||||
|
selected == "__none__"
|
||||||
|
? setFieldValue(name, null)
|
||||||
|
: setFieldValue(name, selected))
|
||||||
}
|
}
|
||||||
defaultValue={defaultValue}
|
defaultValue={defaultValue}
|
||||||
>
|
>
|
||||||
@ -649,6 +679,14 @@ export function SelectorFormField({
|
|||||||
</SelectItem>
|
</SelectItem>
|
||||||
))
|
))
|
||||||
)}
|
)}
|
||||||
|
{includeReset && (
|
||||||
|
<SelectItem
|
||||||
|
value={"__none__"}
|
||||||
|
onSelect={() => setFieldValue(name, null)}
|
||||||
|
>
|
||||||
|
None
|
||||||
|
</SelectItem>
|
||||||
|
)}
|
||||||
</SelectContent>
|
</SelectContent>
|
||||||
)}
|
)}
|
||||||
</Select>
|
</Select>
|
||||||
|
@ -15,6 +15,7 @@ import { ChatSession } from "@/app/chat/interfaces";
|
|||||||
import { UsersResponse } from "./users/interfaces";
|
import { UsersResponse } from "./users/interfaces";
|
||||||
import { Credential } from "./connectors/credentials";
|
import { Credential } from "./connectors/credentials";
|
||||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||||
|
import { PersonaCategory } from "@/app/admin/assistants/interfaces";
|
||||||
|
|
||||||
const CREDENTIAL_URL = "/api/manage/admin/credential";
|
const CREDENTIAL_URL = "/api/manage/admin/credential";
|
||||||
|
|
||||||
@ -83,6 +84,19 @@ export const useConnectorCredentialIndexingStatus = (
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const useCategories = () => {
|
||||||
|
const { mutate } = useSWRConfig();
|
||||||
|
const swrResponse = useSWR<PersonaCategory[]>(
|
||||||
|
"/api/persona/categories",
|
||||||
|
errorHandlingFetcher
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
...swrResponse,
|
||||||
|
refreshCategories: () => mutate("/api/persona/categories"),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
export const useTimeRange = (initialValue?: DateRangePickerValue) => {
|
export const useTimeRange = (initialValue?: DateRangePickerValue) => {
|
||||||
return useState<DateRangePickerValue | null>(null);
|
return useState<DateRangePickerValue | null>(null);
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user