Add ability to add starter messages

This commit is contained in:
Weves
2024-03-02 23:29:01 -08:00
committed by Chris Weaver
parent 9051ebfed7
commit a8cc3d5a07
14 changed files with 579 additions and 253 deletions

View File

@ -0,0 +1,31 @@
"""Add starter prompts
Revision ID: 0a2b51deb0b8
Revises: 5f4b8568a221
Create Date: 2024-03-02 23:23:49.960309
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "0a2b51deb0b8"
down_revision = "5f4b8568a221"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column(
"starter_messages",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("persona", "starter_messages")

View File

@ -89,6 +89,7 @@ def load_personas_from_yaml(
if persona.get("num_chunks") is not None if persona.get("num_chunks") is not None
else default_chunks, else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"), llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"), llm_filter_extraction=persona.get("llm_filter_extraction"),
llm_model_version_override=None, llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]), recency_bias=RecencyBiasSetting(persona["recency_bias"]),

View File

@ -23,6 +23,7 @@ from danswer.db.models import Persona
from danswer.db.models import Prompt from danswer.db.models import Prompt
from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import StarterMessage
from danswer.search.models import RecencyBiasSetting from danswer.search.models import RecencyBiasSetting
from danswer.search.models import RetrievalDocs from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc from danswer.search.models import SavedSearchDoc
@ -465,6 +466,7 @@ def upsert_persona(
prompts: list[Prompt] | None, prompts: list[Prompt] | None,
document_sets: list[DBDocumentSet] | None, document_sets: list[DBDocumentSet] | None,
llm_model_version_override: str | None, llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
shared: bool, shared: bool,
db_session: Session, db_session: Session,
persona_id: int | None = None, persona_id: int | None = None,
@ -490,6 +492,7 @@ def upsert_persona(
persona.recency_bias = recency_bias persona.recency_bias = recency_bias
persona.default_persona = default_persona persona.default_persona = default_persona
persona.llm_model_version_override = llm_model_version_override persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted persona.deleted = False # Un-delete if previously deleted
# Do not delete any associations manually added unless # Do not delete any associations manually added unless
@ -516,6 +519,7 @@ def upsert_persona(
prompts=prompts or [], prompts=prompts or [],
document_sets=document_sets or [], document_sets=document_sets or [],
llm_model_version_override=llm_model_version_override, llm_model_version_override=llm_model_version_override,
starter_messages=starter_messages,
) )
db_session.add(persona) db_session.add(persona)

View File

@ -716,6 +716,15 @@ class Prompt(Base):
) )
class StarterMessage(TypedDict):
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
in Postgres"""
name: str
description: str
message: str
class Persona(Base): class Persona(Base):
__tablename__ = "persona" __tablename__ = "persona"
@ -744,6 +753,9 @@ class Persona(Base):
llm_model_version_override: Mapped[str | None] = mapped_column( llm_model_version_override: Mapped[str | None] = mapped_column(
String, nullable=True String, nullable=True
) )
starter_messages: Mapped[list[StarterMessage] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Default personas are configured via backend during deployment # Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.) # Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False) default_persona: Mapped[bool] = mapped_column(Boolean, default=False)

View File

@ -59,6 +59,7 @@ def create_slack_bot_persona(
prompts=None, prompts=None,
document_sets=document_sets, document_sets=document_sets,
llm_model_version_override=None, llm_model_version_override=None,
starter_messages=None,
shared=True, shared=True,
default_persona=False, default_persona=False,
db_session=db_session, db_session=db_session,

View File

@ -66,6 +66,7 @@ def create_update_persona(
prompts=prompts, prompts=prompts,
document_sets=document_sets, document_sets=document_sets,
llm_model_version_override=create_persona_request.llm_model_version_override, llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
shared=create_persona_request.shared, shared=create_persona_request.shared,
db_session=db_session, db_session=db_session,
) )

View File

@ -1,6 +1,7 @@
from pydantic import BaseModel from pydantic import BaseModel
from danswer.db.models import Persona from danswer.db.models import Persona
from danswer.db.models import StarterMessage
from danswer.search.models import RecencyBiasSetting from danswer.search.models import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.prompt.models import PromptSnapshot from danswer.server.features.prompt.models import PromptSnapshot
@ -17,6 +18,7 @@ class CreatePersonaRequest(BaseModel):
prompt_ids: list[int] prompt_ids: list[int]
document_set_ids: list[int] document_set_ids: list[int]
llm_model_version_override: str | None = None llm_model_version_override: str | None = None
starter_messages: list[StarterMessage] | None = None
class PersonaSnapshot(BaseModel): class PersonaSnapshot(BaseModel):
@ -30,6 +32,7 @@ class PersonaSnapshot(BaseModel):
llm_relevance_filter: bool llm_relevance_filter: bool
llm_filter_extraction: bool llm_filter_extraction: bool
llm_model_version_override: str | None llm_model_version_override: str | None
starter_messages: list[StarterMessage] | None
default_persona: bool default_persona: bool
prompts: list[PromptSnapshot] prompts: list[PromptSnapshot]
document_sets: list[DocumentSet] document_sets: list[DocumentSet]
@ -50,6 +53,7 @@ class PersonaSnapshot(BaseModel):
llm_relevance_filter=persona.llm_relevance_filter, llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction, llm_filter_extraction=persona.llm_filter_extraction,
llm_model_version_override=persona.llm_model_version_override, llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
default_persona=persona.default_persona, default_persona=persona.default_persona,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts], prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
document_sets=[ document_sets=[

View File

@ -0,0 +1,50 @@
import { useState } from "react";
import { FiChevronDown, FiChevronRight } from "react-icons/fi";
export function SectionHeader({
children,
includeMargin = true,
}: {
children: string | JSX.Element;
includeMargin?: boolean;
}) {
return (
<div
className={"font-bold text-xl my-auto" + (includeMargin ? " mb-4" : "")}
>
{children}
</div>
);
}
export function HidableSection({
children,
sectionTitle,
defaultHidden = false,
}: {
children: string | JSX.Element;
sectionTitle: string | JSX.Element;
defaultHidden?: boolean;
}) {
const [isHidden, setIsHidden] = useState(defaultHidden);
return (
<div>
<div
className="flex hover:bg-hover-light rounded cursor-pointer p-2"
onClick={() => setIsHidden(!isHidden)}
>
<SectionHeader includeMargin={false}>{sectionTitle}</SectionHeader>
<div className="my-auto ml-auto p-1">
{isHidden ? (
<FiChevronRight size={24} />
) : (
<FiChevronDown size={24} />
)}
</div>
</div>
{!isHidden && <div className="mx-2 mt-2">{children}</div>}
</div>
);
}

View File

@ -2,7 +2,14 @@
import { DocumentSet } from "@/lib/types"; import { DocumentSet } from "@/lib/types";
import { Button, Divider, Text } from "@tremor/react"; import { Button, Divider, Text } from "@tremor/react";
import { ArrayHelpers, FieldArray, Form, Formik } from "formik"; import {
ArrayHelpers,
ErrorMessage,
Field,
FieldArray,
Form,
Formik,
} from "formik";
import * as Yup from "yup"; import * as Yup from "yup";
import { buildFinalPrompt, createPersona, updatePersona } from "./lib"; import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
@ -16,10 +23,8 @@ import {
SelectorFormField, SelectorFormField,
TextFormField, TextFormField,
} from "@/components/admin/connectors/Field"; } from "@/components/admin/connectors/Field";
import { HidableSection } from "./HidableSection";
function SectionHeader({ children }: { children: string | JSX.Element }) { import { FiPlus, FiX } from "react-icons/fi";
return <div className="mb-4 font-bold text-lg">{children}</div>;
}
function Label({ children }: { children: string | JSX.Element }) { function Label({ children }: { children: string | JSX.Element }) {
return ( return (
@ -97,6 +102,7 @@ export function PersonaEditor({
llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false, llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false,
llm_model_version_override: llm_model_version_override:
existingPersona?.llm_model_version_override ?? null, existingPersona?.llm_model_version_override ?? null,
starter_messages: existingPersona?.starter_messages ?? null,
}} }}
validationSchema={Yup.object() validationSchema={Yup.object()
.shape({ .shape({
@ -112,6 +118,13 @@ export function PersonaEditor({
include_citations: Yup.boolean().required(), include_citations: Yup.boolean().required(),
llm_relevance_filter: Yup.boolean().required(), llm_relevance_filter: Yup.boolean().required(),
llm_model_version_override: Yup.string().nullable(), llm_model_version_override: Yup.string().nullable(),
starter_messages: Yup.array().of(
Yup.object().shape({
name: Yup.string().required(),
description: Yup.string().required(),
message: Yup.string().required(),
})
),
}) })
.test( .test(
"system-prompt-or-task-prompt", "system-prompt-or-task-prompt",
@ -188,8 +201,8 @@ export function PersonaEditor({
{({ isSubmitting, values, setFieldValue }) => ( {({ isSubmitting, values, setFieldValue }) => (
<Form> <Form>
<div className="pb-6"> <div className="pb-6">
<SectionHeader>Who am I?</SectionHeader> <HidableSection sectionTitle="Who am I?">
<>
<TextFormField <TextFormField
name="name" name="name"
label="Name" label="Name"
@ -202,11 +215,13 @@ export function PersonaEditor({
label="Description" label="Description"
subtext="Provide a short descriptions which gives users a hint as to what they should use this Persona for." subtext="Provide a short descriptions which gives users a hint as to what they should use this Persona for."
/> />
</>
</HidableSection>
<Divider /> <Divider />
<SectionHeader>Customize my response style</SectionHeader> <HidableSection sectionTitle="Customize my response style">
<>
<TextFormField <TextFormField
name="system_prompt" name="system_prompt"
label="System Prompt" label="System Prompt"
@ -281,15 +296,15 @@ export function PersonaEditor({
) : ( ) : (
"-" "-"
)} )}
</>
</HidableSection>
<Divider /> <Divider />
{!values.disable_retrieval && ( {!values.disable_retrieval && (
<> <>
<SectionHeader> <HidableSection sectionTitle="What data should I have access to?">
What data should I have access to? <>
</SectionHeader>
<FieldArray <FieldArray
name="document_set_ids" name="document_set_ids"
render={(arrayHelpers: ArrayHelpers) => ( render={(arrayHelpers: ArrayHelpers) => (
@ -305,10 +320,10 @@ export function PersonaEditor({
> >
Document Sets Document Sets
</Link>{" "} </Link>{" "}
that this Persona should search through. If none that this Persona should search through. If
are specified, the Persona will search through all none are specified, the Persona will search
available documents in order to try and response through all available documents in order to
to queries. try and response to queries.
</> </>
</SubLabel> </SubLabel>
</div> </div>
@ -353,6 +368,8 @@ export function PersonaEditor({
</div> </div>
)} )}
/> />
</>
</HidableSection>
<Divider /> <Divider />
</> </>
@ -360,11 +377,12 @@ export function PersonaEditor({
{llmOverrideOptions.length > 0 && defaultLLM && ( {llmOverrideOptions.length > 0 && defaultLLM && (
<> <>
<SectionHeader>[Advanced] Model Selection</SectionHeader> <HidableSection sectionTitle="[Advanced] Model Selection">
<>
<Text> <Text>
Pick which LLM to use for this Persona. If left as Default, Pick which LLM to use for this Persona. If left as
will use <b className="italic">{defaultLLM}</b>. Default, will use <b className="italic">{defaultLLM}</b>
.
<br /> <br />
<br /> <br />
For more information on the different LLMs, checkout the{" "} For more information on the different LLMs, checkout the{" "}
@ -391,27 +409,27 @@ export function PersonaEditor({
/> />
</div> </div>
</> </>
)} </HidableSection>
<Divider /> <Divider />
</>
)}
{!values.disable_retrieval && ( {!values.disable_retrieval && (
<> <>
<SectionHeader> <HidableSection sectionTitle="[Advanced] Retrieval Customization">
[Advanced] Retrieval Customization <>
</SectionHeader>
<TextFormField <TextFormField
name="num_chunks" name="num_chunks"
label="Number of Chunks" label="Number of Chunks"
subtext={ subtext={
<div> <div>
How many chunks should we feed into the LLM when How many chunks should we feed into the LLM when
generating the final response? Each chunk is ~400 words generating the final response? Each chunk is ~400
long. If you are using gpt-3.5-turbo or other similar words long. If you are using gpt-3.5-turbo or other
models, setting this to a value greater than 5 will similar models, setting this to a value greater than
result in errors at query time due to the model&apos;s 5 will result in errors at query time due to the
input length limit. model&apos;s input length limit.
<br /> <br />
<br /> <br />
If unspecified, will use 10 chunks. If unspecified, will use 10 chunks.
@ -433,11 +451,156 @@ export function PersonaEditor({
"If enabled, the LLM will filter out chunks that are not relevant to the user query." "If enabled, the LLM will filter out chunks that are not relevant to the user query."
} }
/> />
</>
</HidableSection>
<Divider /> <Divider />
</> </>
)} )}
<HidableSection sectionTitle="[Advanced] Starter Messages">
<>
<div className="mb-4">
<SubLabel>
Starter Messages help guide users to use this Persona.
They are shown to the user as clickable options when they
select this Persona. When selected, the specified message
is sent to the LLM as the initial user message.
</SubLabel>
</div>
<FieldArray
name="starter_messages"
render={(arrayHelpers: ArrayHelpers) => (
<div>
{values.starter_messages &&
values.starter_messages.length > 0 &&
values.starter_messages.map((_, index) => (
<div
key={index}
className={index === 0 ? "mt-2" : "mt-6"}
>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Name</Label>
<SubLabel>
Shows up as the &quot;title&quot; for this
Starter Message. For example, &quot;Write
an email&quot;.
</SubLabel>
<Field
name={`starter_messages.${index}.name`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`starter_messages.${index}.name`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
<div className="mt-3">
<Label>Description</Label>
<SubLabel>
A description which tells the user what
they might want to use this Starter
Message for. For example &quot;to a client
about a new feature&quot;
</SubLabel>
<Field
name={`starter_messages.${index}.description`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`starter_messages.${index}.description`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
<div className="mt-3">
<Label>Message</Label>
<SubLabel>
The actual message to be sent as the
initial user message if a user selects
this starter prompt. For example,
&quot;Write me an email to a client about
a new billing feature we just
released.&quot;
</SubLabel>
<Field
name={`starter_messages.${index}.message`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
as="textarea"
autoComplete="off"
/>
<ErrorMessage
name={`starter_messages.${index}.message`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/>
</div>
</div>
</div>
))}
<Button
onClick={() => {
arrayHelpers.push("");
}}
className="mt-3"
color="green"
size="xs"
type="button"
icon={FiPlus}
>
Add New
</Button>
</div>
)}
/>
</>
</HidableSection>
<Divider />
<div className="flex"> <div className="flex">
<Button <Button
className="mx-auto" className="mx-auto"

View File

@ -1,5 +1,11 @@
import { DocumentSet } from "@/lib/types"; import { DocumentSet } from "@/lib/types";
export interface StarterMessage {
name: string;
description: string | null;
message: string;
}
export interface Prompt { export interface Prompt {
id: number; id: number;
name: string; name: string;
@ -25,5 +31,6 @@ export interface Persona {
llm_relevance_filter?: boolean; llm_relevance_filter?: boolean;
llm_filter_extraction?: boolean; llm_filter_extraction?: boolean;
llm_model_version_override?: string; llm_model_version_override?: string;
starter_messages: StarterMessage[] | null;
default_persona: boolean; default_persona: boolean;
} }

View File

@ -1,4 +1,4 @@
import { Persona, Prompt } from "./interfaces"; import { Persona, Prompt, StarterMessage } from "./interfaces";
interface PersonaCreationRequest { interface PersonaCreationRequest {
name: string; name: string;
@ -10,6 +10,7 @@ interface PersonaCreationRequest {
include_citations: boolean; include_citations: boolean;
llm_relevance_filter: boolean | null; llm_relevance_filter: boolean | null;
llm_model_version_override: string | null; llm_model_version_override: string | null;
starter_messages: StarterMessage[] | null;
} }
interface PersonaUpdateRequest { interface PersonaUpdateRequest {
@ -24,6 +25,7 @@ interface PersonaUpdateRequest {
include_citations: boolean; include_citations: boolean;
llm_relevance_filter: boolean | null; llm_relevance_filter: boolean | null;
llm_model_version_override: string | null; llm_model_version_override: string | null;
starter_messages: StarterMessage[] | null;
} }
function promptNameFromPersonaName(personaName: string) { function promptNameFromPersonaName(personaName: string) {
@ -109,6 +111,7 @@ function buildPersonaAPIBody(
prompt_ids: [promptId], prompt_ids: [promptId],
document_set_ids, document_set_ids,
llm_model_version_override: creationRequest.llm_model_version_override, llm_model_version_override: creationRequest.llm_model_version_override,
starter_messages: creationRequest.starter_messages,
}; };
} }

View File

@ -43,6 +43,7 @@ import { ChatIntro } from "./ChatIntro";
import { HEADER_PADDING } from "@/lib/constants"; import { HEADER_PADDING } from "@/lib/constants";
import { computeAvailableFilters } from "@/lib/filters"; import { computeAvailableFilters } from "@/lib/filters";
import { useDocumentSelection } from "./useDocumentSelection"; import { useDocumentSelection } from "./useDocumentSelection";
import { StarterMessage } from "./StarterMessage";
const MAX_INPUT_HEIGHT = 200; const MAX_INPUT_HEIGHT = 200;
@ -290,10 +291,12 @@ export const Chat = ({
const onSubmit = async ({ const onSubmit = async ({
messageIdToResend, messageIdToResend,
messageOverride,
queryOverride, queryOverride,
forceSearch, forceSearch,
}: { }: {
messageIdToResend?: number; messageIdToResend?: number;
messageOverride?: string;
queryOverride?: string; queryOverride?: string;
forceSearch?: boolean; forceSearch?: boolean;
} = {}) => { } = {}) => {
@ -321,7 +324,10 @@ export const Chat = ({
return; return;
} }
const currMessage = messageToResend ? messageToResend.message : message; let currMessage = messageToResend ? messageToResend.message : message;
if (messageOverride) {
currMessage = messageOverride;
}
const currMessageHistory = const currMessageHistory =
messageToResendIndex !== null messageToResendIndex !== null
? messageHistory.slice(0, messageToResendIndex) ? messageHistory.slice(0, messageToResendIndex)
@ -685,6 +691,28 @@ export const Chat = ({
{/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/} {/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/}
<div className={`min-h-[30px] w-full`}></div> <div className={`min-h-[30px] w-full`}></div>
{livePersona &&
livePersona.starter_messages &&
livePersona.starter_messages.length > 0 &&
selectedPersona &&
messageHistory.length === 0 &&
!isFetchingChatMessages && (
<div className="mx-auto px-4 w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar grid gap-4 grid-cols-1 grid-rows-1 mt-4 md:grid-cols-2">
{livePersona.starter_messages.map((starterMessage, i) => (
<div key={i} className="w-full">
<StarterMessage
starterMessage={starterMessage}
onClick={() =>
onSubmit({
messageOverride: starterMessage.message,
})
}
/>
</div>
))}
</div>
)}
<div ref={endDivRef} /> <div ref={endDivRef} />
</div> </div>
</div> </div>

View File

@ -20,7 +20,7 @@ function HelperItemDisplay({
description: string; description: string;
}) { }) {
return ( return (
<div className="cursor-default hover:bg-hover-light border border-border rounded py-2 px-4"> <div className="cursor-pointer hover:bg-hover-light border border-border rounded py-2 px-4">
<div className="text-emphasis font-bold text-lg flex">{title}</div> <div className="text-emphasis font-bold text-lg flex">{title}</div>
<div className="text-sm">{description}</div> <div className="text-sm">{description}</div>
</div> </div>

View File

@ -0,0 +1,21 @@
import { StarterMessage } from "../admin/personas/interfaces";
export function StarterMessage({
starterMessage,
onClick,
}: {
starterMessage: StarterMessage;
onClick: () => void;
}) {
return (
<div
className={
"py-2 px-3 rounded border border-border bg-white cursor-pointer hover:bg-hover-light h-full"
}
onClick={onClick}
>
<p className="font-medium text-neutral-700">{starterMessage.name}</p>
<p className="text-neutral-500 text-sm">{starterMessage.description}</p>
</div>
);
}