diff --git a/backend/alembic/versions/0a2b51deb0b8_add_starter_prompts.py b/backend/alembic/versions/0a2b51deb0b8_add_starter_prompts.py new file mode 100644 index 000000000..2d7264339 --- /dev/null +++ b/backend/alembic/versions/0a2b51deb0b8_add_starter_prompts.py @@ -0,0 +1,31 @@ +"""Add starter prompts + +Revision ID: 0a2b51deb0b8 +Revises: 5f4b8568a221 +Create Date: 2024-03-02 23:23:49.960309 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "0a2b51deb0b8" +down_revision = "5f4b8568a221" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "persona", + sa.Column( + "starter_messages", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("persona", "starter_messages") diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index ce9798f17..d85def58d 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -89,6 +89,7 @@ def load_personas_from_yaml( if persona.get("num_chunks") is not None else default_chunks, llm_relevance_filter=persona.get("llm_relevance_filter"), + starter_messages=persona.get("starter_messages"), llm_filter_extraction=persona.get("llm_filter_extraction"), llm_model_version_override=None, recency_bias=RecencyBiasSetting(persona["recency_bias"]), diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index acb81f534..cc0800319 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -23,6 +23,7 @@ from danswer.db.models import Persona from danswer.db.models import Prompt from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc +from danswer.db.models import StarterMessage from danswer.search.models import RecencyBiasSetting from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc @@ -465,6 +466,7 @@ def upsert_persona( prompts: list[Prompt] | None, document_sets: list[DBDocumentSet] | None, llm_model_version_override: str | None, + starter_messages: list[StarterMessage] | None, shared: bool, db_session: Session, persona_id: int | None = None, @@ -490,6 +492,7 @@ def upsert_persona( persona.recency_bias = recency_bias persona.default_persona = default_persona persona.llm_model_version_override = llm_model_version_override + persona.starter_messages = starter_messages persona.deleted = False # Un-delete if previously deleted # Do not delete any associations manually added unless @@ -516,6 +519,7 @@ def upsert_persona( prompts=prompts or [], document_sets=document_sets or [], llm_model_version_override=llm_model_version_override, + starter_messages=starter_messages, ) db_session.add(persona) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index b338d3e7a..98430fb23 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -716,6 +716,15 @@ class Prompt(Base): ) +class StarterMessage(TypedDict): + """NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column + in Postgres""" + + name: str + description: str + message: str + + class Persona(Base): __tablename__ = "persona" @@ -744,6 +753,9 @@ class Persona(Base): llm_model_version_override: Mapped[str | None] = mapped_column( String, nullable=True ) + starter_messages: Mapped[list[StarterMessage] | None] = mapped_column( + postgresql.JSONB(), nullable=True + ) # Default personas are configured via backend during deployment # Treated specially (cannot be user edited etc.) default_persona: Mapped[bool] = mapped_column(Boolean, default=False) diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index bbf4ff0b6..82ed77e3f 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -59,6 +59,7 @@ def create_slack_bot_persona( prompts=None, document_sets=document_sets, llm_model_version_override=None, + starter_messages=None, shared=True, default_persona=False, db_session=db_session, diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index e439aa582..160665495 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -66,6 +66,7 @@ def create_update_persona( prompts=prompts, document_sets=document_sets, llm_model_version_override=create_persona_request.llm_model_version_override, + starter_messages=create_persona_request.starter_messages, shared=create_persona_request.shared, db_session=db_session, ) diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 1eca57f5a..4a36ad709 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -1,6 +1,7 @@ from pydantic import BaseModel from danswer.db.models import Persona +from danswer.db.models import StarterMessage from danswer.search.models import RecencyBiasSetting from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.prompt.models import PromptSnapshot @@ -17,6 +18,7 @@ class CreatePersonaRequest(BaseModel): prompt_ids: list[int] document_set_ids: list[int] llm_model_version_override: str | None = None + starter_messages: list[StarterMessage] | None = None class PersonaSnapshot(BaseModel): @@ -30,6 +32,7 @@ class PersonaSnapshot(BaseModel): llm_relevance_filter: bool llm_filter_extraction: bool llm_model_version_override: str | None + starter_messages: list[StarterMessage] | None default_persona: bool prompts: list[PromptSnapshot] document_sets: list[DocumentSet] @@ -50,6 +53,7 @@ class PersonaSnapshot(BaseModel): llm_relevance_filter=persona.llm_relevance_filter, llm_filter_extraction=persona.llm_filter_extraction, llm_model_version_override=persona.llm_model_version_override, + starter_messages=persona.starter_messages, default_persona=persona.default_persona, prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts], document_sets=[ diff --git a/web/src/app/admin/personas/HidableSection.tsx b/web/src/app/admin/personas/HidableSection.tsx new file mode 100644 index 000000000..714f2344c --- /dev/null +++ b/web/src/app/admin/personas/HidableSection.tsx @@ -0,0 +1,50 @@ +import { useState } from "react"; +import { FiChevronDown, FiChevronRight } from "react-icons/fi"; + +export function SectionHeader({ + children, + includeMargin = true, +}: { + children: string | JSX.Element; + includeMargin?: boolean; +}) { + return ( +