From a8cc3d5a0796ccd21db60902217a3748d78fbde9 Mon Sep 17 00:00:00 2001 From: Weves Date: Sat, 2 Mar 2024 23:29:01 -0800 Subject: [PATCH] Add ability to add starter messages --- .../0a2b51deb0b8_add_starter_prompts.py | 31 + backend/danswer/chat/load_yamls.py | 1 + backend/danswer/db/chat.py | 4 + backend/danswer/db/models.py | 12 + backend/danswer/db/slack_bot_config.py | 1 + .../danswer/server/features/persona/api.py | 1 + .../danswer/server/features/persona/models.py | 4 + web/src/app/admin/personas/HidableSection.tsx | 50 ++ web/src/app/admin/personas/PersonaEditor.tsx | 615 +++++++++++------- web/src/app/admin/personas/interfaces.ts | 7 + web/src/app/admin/personas/lib.ts | 5 +- web/src/app/chat/Chat.tsx | 78 ++- web/src/app/chat/ChatIntro.tsx | 2 +- web/src/app/chat/StarterMessage.tsx | 21 + 14 files changed, 579 insertions(+), 253 deletions(-) create mode 100644 backend/alembic/versions/0a2b51deb0b8_add_starter_prompts.py create mode 100644 web/src/app/admin/personas/HidableSection.tsx create mode 100644 web/src/app/chat/StarterMessage.tsx diff --git a/backend/alembic/versions/0a2b51deb0b8_add_starter_prompts.py b/backend/alembic/versions/0a2b51deb0b8_add_starter_prompts.py new file mode 100644 index 000000000..2d7264339 --- /dev/null +++ b/backend/alembic/versions/0a2b51deb0b8_add_starter_prompts.py @@ -0,0 +1,31 @@ +"""Add starter prompts + +Revision ID: 0a2b51deb0b8 +Revises: 5f4b8568a221 +Create Date: 2024-03-02 23:23:49.960309 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "0a2b51deb0b8" +down_revision = "5f4b8568a221" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "persona", + sa.Column( + "starter_messages", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("persona", "starter_messages") diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index ce9798f17..d85def58d 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -89,6 +89,7 @@ def load_personas_from_yaml( if persona.get("num_chunks") is not None else default_chunks, llm_relevance_filter=persona.get("llm_relevance_filter"), + starter_messages=persona.get("starter_messages"), llm_filter_extraction=persona.get("llm_filter_extraction"), llm_model_version_override=None, recency_bias=RecencyBiasSetting(persona["recency_bias"]), diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index acb81f534..cc0800319 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -23,6 +23,7 @@ from danswer.db.models import Persona from danswer.db.models import Prompt from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc +from danswer.db.models import StarterMessage from danswer.search.models import RecencyBiasSetting from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc @@ -465,6 +466,7 @@ def upsert_persona( prompts: list[Prompt] | None, document_sets: list[DBDocumentSet] | None, llm_model_version_override: str | None, + starter_messages: list[StarterMessage] | None, shared: bool, db_session: Session, persona_id: int | None = None, @@ -490,6 +492,7 @@ def upsert_persona( persona.recency_bias = recency_bias persona.default_persona = default_persona persona.llm_model_version_override = llm_model_version_override + persona.starter_messages = starter_messages persona.deleted = False # Un-delete if previously deleted # Do not delete any associations manually added unless @@ -516,6 +519,7 @@ def upsert_persona( prompts=prompts or [], document_sets=document_sets or [], llm_model_version_override=llm_model_version_override, + starter_messages=starter_messages, ) db_session.add(persona) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index b338d3e7a..98430fb23 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -716,6 +716,15 @@ class Prompt(Base): ) +class StarterMessage(TypedDict): + """NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column + in Postgres""" + + name: str + description: str + message: str + + class Persona(Base): __tablename__ = "persona" @@ -744,6 +753,9 @@ class Persona(Base): llm_model_version_override: Mapped[str | None] = mapped_column( String, nullable=True ) + starter_messages: Mapped[list[StarterMessage] | None] = mapped_column( + postgresql.JSONB(), nullable=True + ) # Default personas are configured via backend during deployment # Treated specially (cannot be user edited etc.) default_persona: Mapped[bool] = mapped_column(Boolean, default=False) diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index bbf4ff0b6..82ed77e3f 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -59,6 +59,7 @@ def create_slack_bot_persona( prompts=None, document_sets=document_sets, llm_model_version_override=None, + starter_messages=None, shared=True, default_persona=False, db_session=db_session, diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index e439aa582..160665495 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -66,6 +66,7 @@ def create_update_persona( prompts=prompts, document_sets=document_sets, llm_model_version_override=create_persona_request.llm_model_version_override, + starter_messages=create_persona_request.starter_messages, shared=create_persona_request.shared, db_session=db_session, ) diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 1eca57f5a..4a36ad709 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -1,6 +1,7 @@ from pydantic import BaseModel from danswer.db.models import Persona +from danswer.db.models import StarterMessage from danswer.search.models import RecencyBiasSetting from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.prompt.models import PromptSnapshot @@ -17,6 +18,7 @@ class CreatePersonaRequest(BaseModel): prompt_ids: list[int] document_set_ids: list[int] llm_model_version_override: str | None = None + starter_messages: list[StarterMessage] | None = None class PersonaSnapshot(BaseModel): @@ -30,6 +32,7 @@ class PersonaSnapshot(BaseModel): llm_relevance_filter: bool llm_filter_extraction: bool llm_model_version_override: str | None + starter_messages: list[StarterMessage] | None default_persona: bool prompts: list[PromptSnapshot] document_sets: list[DocumentSet] @@ -50,6 +53,7 @@ class PersonaSnapshot(BaseModel): llm_relevance_filter=persona.llm_relevance_filter, llm_filter_extraction=persona.llm_filter_extraction, llm_model_version_override=persona.llm_model_version_override, + starter_messages=persona.starter_messages, default_persona=persona.default_persona, prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts], document_sets=[ diff --git a/web/src/app/admin/personas/HidableSection.tsx b/web/src/app/admin/personas/HidableSection.tsx new file mode 100644 index 000000000..714f2344c --- /dev/null +++ b/web/src/app/admin/personas/HidableSection.tsx @@ -0,0 +1,50 @@ +import { useState } from "react"; +import { FiChevronDown, FiChevronRight } from "react-icons/fi"; + +export function SectionHeader({ + children, + includeMargin = true, +}: { + children: string | JSX.Element; + includeMargin?: boolean; +}) { + return ( +
+ {children} +
+ ); +} + +export function HidableSection({ + children, + sectionTitle, + defaultHidden = false, +}: { + children: string | JSX.Element; + sectionTitle: string | JSX.Element; + defaultHidden?: boolean; +}) { + const [isHidden, setIsHidden] = useState(defaultHidden); + + return ( +
+
setIsHidden(!isHidden)} + > + {sectionTitle} +
+ {isHidden ? ( + + ) : ( + + )} +
+
+ + {!isHidden &&
{children}
} +
+ ); +} diff --git a/web/src/app/admin/personas/PersonaEditor.tsx b/web/src/app/admin/personas/PersonaEditor.tsx index 42106108c..be9344e43 100644 --- a/web/src/app/admin/personas/PersonaEditor.tsx +++ b/web/src/app/admin/personas/PersonaEditor.tsx @@ -2,7 +2,14 @@ import { DocumentSet } from "@/lib/types"; import { Button, Divider, Text } from "@tremor/react"; -import { ArrayHelpers, FieldArray, Form, Formik } from "formik"; +import { + ArrayHelpers, + ErrorMessage, + Field, + FieldArray, + Form, + Formik, +} from "formik"; import * as Yup from "yup"; import { buildFinalPrompt, createPersona, updatePersona } from "./lib"; @@ -16,10 +23,8 @@ import { SelectorFormField, TextFormField, } from "@/components/admin/connectors/Field"; - -function SectionHeader({ children }: { children: string | JSX.Element }) { - return
{children}
; -} +import { HidableSection } from "./HidableSection"; +import { FiPlus, FiX } from "react-icons/fi"; function Label({ children }: { children: string | JSX.Element }) { return ( @@ -97,6 +102,7 @@ export function PersonaEditor({ llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false, llm_model_version_override: existingPersona?.llm_model_version_override ?? null, + starter_messages: existingPersona?.starter_messages ?? null, }} validationSchema={Yup.object() .shape({ @@ -112,6 +118,13 @@ export function PersonaEditor({ include_citations: Yup.boolean().required(), llm_relevance_filter: Yup.boolean().required(), llm_model_version_override: Yup.string().nullable(), + starter_messages: Yup.array().of( + Yup.object().shape({ + name: Yup.string().required(), + description: Yup.string().required(), + message: Yup.string().required(), + }) + ), }) .test( "system-prompt-or-task-prompt", @@ -188,171 +201,175 @@ export function PersonaEditor({ {({ isSubmitting, values, setFieldValue }) => (
- Who am I? + + <> + - - - + + + - Customize my response style + + <> + { + setFieldValue("system_prompt", e.target.value); + triggerFinalPromptUpdate( + e.target.value, + values.task_prompt, + values.disable_retrieval + ); + }} + error={finalPromptError} + /> - { - setFieldValue("system_prompt", e.target.value); - triggerFinalPromptUpdate( - e.target.value, - values.task_prompt, - values.disable_retrieval - ); - }} - error={finalPromptError} - /> + { + setFieldValue("task_prompt", e.target.value); + triggerFinalPromptUpdate( + values.system_prompt, + e.target.value, + values.disable_retrieval + ); + }} + error={finalPromptError} + /> - { - setFieldValue("task_prompt", e.target.value); - triggerFinalPromptUpdate( - values.system_prompt, - e.target.value, - values.disable_retrieval - ); - }} - error={finalPromptError} - /> + {!values.disable_retrieval && ( + + )} - {!values.disable_retrieval && ( - - )} + { + setFieldValue("disable_retrieval", e.target.checked); + triggerFinalPromptUpdate( + values.system_prompt, + values.task_prompt, + e.target.checked + ); + }} + /> - { - setFieldValue("disable_retrieval", e.target.checked); - triggerFinalPromptUpdate( - values.system_prompt, - values.task_prompt, - e.target.checked - ); - }} - /> + - - - {finalPrompt ? ( -
-                  {finalPrompt}
-                
- ) : ( - "-" - )} + {finalPrompt ? ( +
+                      {finalPrompt}
+                    
+ ) : ( + "-" + )} + +
{!values.disable_retrieval && ( <> - - What data should I have access to? - - - ( -
-
- - <> - Select which{" "} - - Document Sets - {" "} - that this Persona should search through. If none - are specified, the Persona will search through all - available documents in order to try and response - to queries. - - -
-
- {documentSets.map((documentSet) => { - const ind = values.document_set_ids.indexOf( - documentSet.id - ); - let isSelected = ind !== -1; - return ( -
{ - if (isSelected) { - arrayHelpers.remove(ind); - } else { - arrayHelpers.push(documentSet.id); - } - }} - > -
- {documentSet.name} -
-
- ); - })} -
-
- )} - /> + + <> + ( +
+
+ + <> + Select which{" "} + + Document Sets + {" "} + that this Persona should search through. If + none are specified, the Persona will search + through all available documents in order to + try and response to queries. + + +
+
+ {documentSets.map((documentSet) => { + const ind = values.document_set_ids.indexOf( + documentSet.id + ); + let isSelected = ind !== -1; + return ( +
{ + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push(documentSet.id); + } + }} + > +
+ {documentSet.name} +
+
+ ); + })} +
+
+ )} + /> + +
@@ -360,84 +377,230 @@ export function PersonaEditor({ {llmOverrideOptions.length > 0 && defaultLLM && ( <> - [Advanced] Model Selection - - - Pick which LLM to use for this Persona. If left as Default, - will use {defaultLLM}. -
-
- For more information on the different LLMs, checkout the{" "} - - OpenAI docs - - . -
- -
- { - return { - name: llmOption, - value: llmOption, - }; - })} - includeDefault={true} - /> -
- - )} - - - - {!values.disable_retrieval && ( - <> - - [Advanced] Retrieval Customization - - - - How many chunks should we feed into the LLM when - generating the final response? Each chunk is ~400 words - long. If you are using gpt-3.5-turbo or other similar - models, setting this to a value greater than 5 will - result in errors at query time due to the model's - input length limit. + + <> + + Pick which LLM to use for this Persona. If left as + Default, will use {defaultLLM} + .

- If unspecified, will use 10 chunks. + For more information on the different LLMs, checkout the{" "} + + OpenAI docs + + . +
+ +
+ { + return { + name: llmOption, + value: llmOption, + }; + })} + includeDefault={true} + />
- } - onChange={(e) => { - const value = e.target.value; - // Allow only integer values - if (value === "" || /^[0-9]+$/.test(value)) { - setFieldValue("num_chunks", value); - } - }} - /> - - + +
)} + {!values.disable_retrieval && ( + <> + + <> + + How many chunks should we feed into the LLM when + generating the final response? Each chunk is ~400 + words long. If you are using gpt-3.5-turbo or other + similar models, setting this to a value greater than + 5 will result in errors at query time due to the + model's input length limit. +
+
+ If unspecified, will use 10 chunks. +
+ } + onChange={(e) => { + const value = e.target.value; + // Allow only integer values + if (value === "" || /^[0-9]+$/.test(value)) { + setFieldValue("num_chunks", value); + } + }} + /> + + + + + + + + )} + + + <> +
+ + Starter Messages help guide users to use this Persona. + They are shown to the user as clickable options when they + select this Persona. When selected, the specified message + is sent to the LLM as the initial user message. + +
+ + ( +
+ {values.starter_messages && + values.starter_messages.length > 0 && + values.starter_messages.map((_, index) => ( +
+
+
+
+ + + Shows up as the "title" for this + Starter Message. For example, "Write + an email". + + + +
+ +
+ + + A description which tells the user what + they might want to use this Starter + Message for. For example "to a client + about a new feature" + + + +
+ +
+ + + The actual message to be sent as the + initial user message if a user selects + this starter prompt. For example, + "Write me an email to a client about + a new billing feature we just + released." + + + +
+
+
+ arrayHelpers.remove(index)} + /> +
+
+
+ ))} + + +
+ )} + /> + +
+ + +
@@ -716,30 +744,30 @@ export const Chat = ({ ref={textareaRef} autoFocus className={` - opacity-100 - w-full - shrink - border - border-border - rounded-lg - outline-none - placeholder-gray-400 - pl-4 - pr-12 - py-4 - overflow-hidden - h-14 - ${ - (textareaRef?.current?.scrollHeight || 0) > - MAX_INPUT_HEIGHT - ? "overflow-y-auto" - : "" - } - whitespace-normal - break-word - overscroll-contain - resize-none - `} + opacity-100 + w-full + shrink + border + border-border + rounded-lg + outline-none + placeholder-gray-400 + pl-4 + pr-12 + py-4 + overflow-hidden + h-14 + ${ + (textareaRef?.current?.scrollHeight || 0) > + MAX_INPUT_HEIGHT + ? "overflow-y-auto" + : "" + } + whitespace-normal + break-word + overscroll-contain + resize-none + `} style={{ scrollbarWidth: "thin" }} role="textarea" aria-multiline diff --git a/web/src/app/chat/ChatIntro.tsx b/web/src/app/chat/ChatIntro.tsx index 84ae38f23..adb17a383 100644 --- a/web/src/app/chat/ChatIntro.tsx +++ b/web/src/app/chat/ChatIntro.tsx @@ -20,7 +20,7 @@ function HelperItemDisplay({ description: string; }) { return ( -
+
{title}
{description}
diff --git a/web/src/app/chat/StarterMessage.tsx b/web/src/app/chat/StarterMessage.tsx new file mode 100644 index 000000000..3c83dd3c6 --- /dev/null +++ b/web/src/app/chat/StarterMessage.tsx @@ -0,0 +1,21 @@ +import { StarterMessage } from "../admin/personas/interfaces"; + +export function StarterMessage({ + starterMessage, + onClick, +}: { + starterMessage: StarterMessage; + onClick: () => void; +}) { + return ( +
+

{starterMessage.name}

+

{starterMessage.description}

+
+ ); +}