Add ability to add starter messages

This commit is contained in:
Weves 2024-03-02 23:29:01 -08:00 committed by Chris Weaver
parent 9051ebfed7
commit a8cc3d5a07
14 changed files with 579 additions and 253 deletions

View File

@ -0,0 +1,31 @@
"""Add starter prompts
Revision ID: 0a2b51deb0b8
Revises: 5f4b8568a221
Create Date: 2024-03-02 23:23:49.960309
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "0a2b51deb0b8"
down_revision = "5f4b8568a221"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column(
"starter_messages",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("persona", "starter_messages")

View File

@ -89,6 +89,7 @@ def load_personas_from_yaml(
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),

View File

@ -23,6 +23,7 @@ from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import StarterMessage
from danswer.search.models import RecencyBiasSetting
from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
@ -465,6 +466,7 @@ def upsert_persona(
prompts: list[Prompt] | None,
document_sets: list[DBDocumentSet] | None,
llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
shared: bool,
db_session: Session,
persona_id: int | None = None,
@ -490,6 +492,7 @@ def upsert_persona(
persona.recency_bias = recency_bias
persona.default_persona = default_persona
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted
# Do not delete any associations manually added unless
@ -516,6 +519,7 @@ def upsert_persona(
prompts=prompts or [],
document_sets=document_sets or [],
llm_model_version_override=llm_model_version_override,
starter_messages=starter_messages,
)
db_session.add(persona)

View File

@ -716,6 +716,15 @@ class Prompt(Base):
)
class StarterMessage(TypedDict):
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
in Postgres"""
name: str
description: str
message: str
class Persona(Base):
__tablename__ = "persona"
@ -744,6 +753,9 @@ class Persona(Base):
llm_model_version_override: Mapped[str | None] = mapped_column(
String, nullable=True
)
starter_messages: Mapped[list[StarterMessage] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)

View File

@ -59,6 +59,7 @@ def create_slack_bot_persona(
prompts=None,
document_sets=document_sets,
llm_model_version_override=None,
starter_messages=None,
shared=True,
default_persona=False,
db_session=db_session,

View File

@ -66,6 +66,7 @@ def create_update_persona(
prompts=prompts,
document_sets=document_sets,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
shared=create_persona_request.shared,
db_session=db_session,
)

View File

@ -1,6 +1,7 @@
from pydantic import BaseModel
from danswer.db.models import Persona
from danswer.db.models import StarterMessage
from danswer.search.models import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.prompt.models import PromptSnapshot
@ -17,6 +18,7 @@ class CreatePersonaRequest(BaseModel):
prompt_ids: list[int]
document_set_ids: list[int]
llm_model_version_override: str | None = None
starter_messages: list[StarterMessage] | None = None
class PersonaSnapshot(BaseModel):
@ -30,6 +32,7 @@ class PersonaSnapshot(BaseModel):
llm_relevance_filter: bool
llm_filter_extraction: bool
llm_model_version_override: str | None
starter_messages: list[StarterMessage] | None
default_persona: bool
prompts: list[PromptSnapshot]
document_sets: list[DocumentSet]
@ -50,6 +53,7 @@ class PersonaSnapshot(BaseModel):
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
default_persona=persona.default_persona,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
document_sets=[

View File

@ -0,0 +1,50 @@
import { useState } from "react";
import { FiChevronDown, FiChevronRight } from "react-icons/fi";
export function SectionHeader({
children,
includeMargin = true,
}: {
children: string | JSX.Element;
includeMargin?: boolean;
}) {
return (
<div
className={"font-bold text-xl my-auto" + (includeMargin ? " mb-4" : "")}
>
{children}
</div>
);
}
export function HidableSection({
children,
sectionTitle,
defaultHidden = false,
}: {
children: string | JSX.Element;
sectionTitle: string | JSX.Element;
defaultHidden?: boolean;
}) {
const [isHidden, setIsHidden] = useState(defaultHidden);
return (
<div>
<div
className="flex hover:bg-hover-light rounded cursor-pointer p-2"
onClick={() => setIsHidden(!isHidden)}
>
<SectionHeader includeMargin={false}>{sectionTitle}</SectionHeader>
<div className="my-auto ml-auto p-1">
{isHidden ? (
<FiChevronRight size={24} />
) : (
<FiChevronDown size={24} />
)}
</div>
</div>
{!isHidden && <div className="mx-2 mt-2">{children}</div>}
</div>
);
}

View File

@ -2,7 +2,14 @@
import { DocumentSet } from "@/lib/types";
import { Button, Divider, Text } from "@tremor/react";
import { ArrayHelpers, FieldArray, Form, Formik } from "formik";
import {
ArrayHelpers,
ErrorMessage,
Field,
FieldArray,
Form,
Formik,
} from "formik";
import * as Yup from "yup";
import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
@ -16,10 +23,8 @@ import {
SelectorFormField,
TextFormField,
} from "@/components/admin/connectors/Field";
function SectionHeader({ children }: { children: string | JSX.Element }) {
return <div className="mb-4 font-bold text-lg">{children}</div>;
}
import { HidableSection } from "./HidableSection";
import { FiPlus, FiX } from "react-icons/fi";
function Label({ children }: { children: string | JSX.Element }) {
return (
@ -97,6 +102,7 @@ export function PersonaEditor({
llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false,
llm_model_version_override:
existingPersona?.llm_model_version_override ?? null,
starter_messages: existingPersona?.starter_messages ?? null,
}}
validationSchema={Yup.object()
.shape({
@ -112,6 +118,13 @@ export function PersonaEditor({
include_citations: Yup.boolean().required(),
llm_relevance_filter: Yup.boolean().required(),
llm_model_version_override: Yup.string().nullable(),
starter_messages: Yup.array().of(
Yup.object().shape({
name: Yup.string().required(),
description: Yup.string().required(),
message: Yup.string().required(),
})
),
})
.test(
"system-prompt-or-task-prompt",
@ -188,8 +201,8 @@ export function PersonaEditor({
{({ isSubmitting, values, setFieldValue }) => (
<Form>
<div className="pb-6">
<SectionHeader>Who am I?</SectionHeader>
<HidableSection sectionTitle="Who am I?">
<>
<TextFormField
name="name"
label="Name"
@ -202,11 +215,13 @@ export function PersonaEditor({
label="Description"
subtext="Provide a short descriptions which gives users a hint as to what they should use this Persona for."
/>
</>
</HidableSection>
<Divider />
<SectionHeader>Customize my response style</SectionHeader>
<HidableSection sectionTitle="Customize my response style">
<>
<TextFormField
name="system_prompt"
label="System Prompt"
@ -281,15 +296,15 @@ export function PersonaEditor({
) : (
"-"
)}
</>
</HidableSection>
<Divider />
{!values.disable_retrieval && (
<>
<SectionHeader>
What data should I have access to?
</SectionHeader>
<HidableSection sectionTitle="What data should I have access to?">
<>
<FieldArray
name="document_set_ids"
render={(arrayHelpers: ArrayHelpers) => (
@ -305,10 +320,10 @@ export function PersonaEditor({
>
Document Sets
</Link>{" "}
that this Persona should search through. If none
are specified, the Persona will search through all
available documents in order to try and response
to queries.
that this Persona should search through. If
none are specified, the Persona will search
through all available documents in order to
try and response to queries.
</>
</SubLabel>
</div>
@ -353,6 +368,8 @@ export function PersonaEditor({
</div>
)}
/>
</>
</HidableSection>
<Divider />
</>
@ -360,11 +377,12 @@ export function PersonaEditor({
{llmOverrideOptions.length > 0 && defaultLLM && (
<>
<SectionHeader>[Advanced] Model Selection</SectionHeader>
<HidableSection sectionTitle="[Advanced] Model Selection">
<>
<Text>
Pick which LLM to use for this Persona. If left as Default,
will use <b className="italic">{defaultLLM}</b>.
Pick which LLM to use for this Persona. If left as
Default, will use <b className="italic">{defaultLLM}</b>
.
<br />
<br />
For more information on the different LLMs, checkout the{" "}
@ -391,27 +409,27 @@ export function PersonaEditor({
/>
</div>
</>
)}
</HidableSection>
<Divider />
</>
)}
{!values.disable_retrieval && (
<>
<SectionHeader>
[Advanced] Retrieval Customization
</SectionHeader>
<HidableSection sectionTitle="[Advanced] Retrieval Customization">
<>
<TextFormField
name="num_chunks"
label="Number of Chunks"
subtext={
<div>
How many chunks should we feed into the LLM when
generating the final response? Each chunk is ~400 words
long. If you are using gpt-3.5-turbo or other similar
models, setting this to a value greater than 5 will
result in errors at query time due to the model&apos;s
input length limit.
generating the final response? Each chunk is ~400
words long. If you are using gpt-3.5-turbo or other
similar models, setting this to a value greater than
5 will result in errors at query time due to the
model&apos;s input length limit.
<br />
<br />
If unspecified, will use 10 chunks.
@ -433,11 +451,156 @@ export function PersonaEditor({
"If enabled, the LLM will filter out chunks that are not relevant to the user query."
}
/>
</>
</HidableSection>
<Divider />
</>
)}
<HidableSection sectionTitle="[Advanced] Starter Messages">
<>
<div className="mb-4">
<SubLabel>
Starter Messages help guide users to use this Persona.
They are shown to the user as clickable options when they
select this Persona. When selected, the specified message
is sent to the LLM as the initial user message.
</SubLabel>
</div>
<FieldArray
name="starter_messages"
render={(arrayHelpers: ArrayHelpers) => (
<div>
{values.starter_messages &&
values.starter_messages.length > 0 &&
values.starter_messages.map((_, index) => (
<div
key={index}
className={index === 0 ? "mt-2" : "mt-6"}
>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Name</Label>
<SubLabel>
Shows up as the &quot;title&quot; for this
Starter Message. For example, &quot;Write
an email&quot;.
</SubLabel>
<Field
name={`starter_messages.${index}.name`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`starter_messages.${index}.name`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
<div className="mt-3">
<Label>Description</Label>
<SubLabel>
A description which tells the user what
they might want to use this Starter
Message for. For example &quot;to a client
about a new feature&quot;
</SubLabel>
<Field
name={`starter_messages.${index}.description`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`starter_messages.${index}.description`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
<div className="mt-3">
<Label>Message</Label>
<SubLabel>
The actual message to be sent as the
initial user message if a user selects
this starter prompt. For example,
&quot;Write me an email to a client about
a new billing feature we just
released.&quot;
</SubLabel>
<Field
name={`starter_messages.${index}.message`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
as="textarea"
autoComplete="off"
/>
<ErrorMessage
name={`starter_messages.${index}.message`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/>
</div>
</div>
</div>
))}
<Button
onClick={() => {
arrayHelpers.push("");
}}
className="mt-3"
color="green"
size="xs"
type="button"
icon={FiPlus}
>
Add New
</Button>
</div>
)}
/>
</>
</HidableSection>
<Divider />
<div className="flex">
<Button
className="mx-auto"

View File

@ -1,5 +1,11 @@
import { DocumentSet } from "@/lib/types";
export interface StarterMessage {
name: string;
description: string | null;
message: string;
}
export interface Prompt {
id: number;
name: string;
@ -25,5 +31,6 @@ export interface Persona {
llm_relevance_filter?: boolean;
llm_filter_extraction?: boolean;
llm_model_version_override?: string;
starter_messages: StarterMessage[] | null;
default_persona: boolean;
}

View File

@ -1,4 +1,4 @@
import { Persona, Prompt } from "./interfaces";
import { Persona, Prompt, StarterMessage } from "./interfaces";
interface PersonaCreationRequest {
name: string;
@ -10,6 +10,7 @@ interface PersonaCreationRequest {
include_citations: boolean;
llm_relevance_filter: boolean | null;
llm_model_version_override: string | null;
starter_messages: StarterMessage[] | null;
}
interface PersonaUpdateRequest {
@ -24,6 +25,7 @@ interface PersonaUpdateRequest {
include_citations: boolean;
llm_relevance_filter: boolean | null;
llm_model_version_override: string | null;
starter_messages: StarterMessage[] | null;
}
function promptNameFromPersonaName(personaName: string) {
@ -109,6 +111,7 @@ function buildPersonaAPIBody(
prompt_ids: [promptId],
document_set_ids,
llm_model_version_override: creationRequest.llm_model_version_override,
starter_messages: creationRequest.starter_messages,
};
}

View File

@ -43,6 +43,7 @@ import { ChatIntro } from "./ChatIntro";
import { HEADER_PADDING } from "@/lib/constants";
import { computeAvailableFilters } from "@/lib/filters";
import { useDocumentSelection } from "./useDocumentSelection";
import { StarterMessage } from "./StarterMessage";
const MAX_INPUT_HEIGHT = 200;
@ -290,10 +291,12 @@ export const Chat = ({
const onSubmit = async ({
messageIdToResend,
messageOverride,
queryOverride,
forceSearch,
}: {
messageIdToResend?: number;
messageOverride?: string;
queryOverride?: string;
forceSearch?: boolean;
} = {}) => {
@ -321,7 +324,10 @@ export const Chat = ({
return;
}
const currMessage = messageToResend ? messageToResend.message : message;
let currMessage = messageToResend ? messageToResend.message : message;
if (messageOverride) {
currMessage = messageOverride;
}
const currMessageHistory =
messageToResendIndex !== null
? messageHistory.slice(0, messageToResendIndex)
@ -685,6 +691,28 @@ export const Chat = ({
{/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/}
<div className={`min-h-[30px] w-full`}></div>
{livePersona &&
livePersona.starter_messages &&
livePersona.starter_messages.length > 0 &&
selectedPersona &&
messageHistory.length === 0 &&
!isFetchingChatMessages && (
<div className="mx-auto px-4 w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar grid gap-4 grid-cols-1 grid-rows-1 mt-4 md:grid-cols-2">
{livePersona.starter_messages.map((starterMessage, i) => (
<div key={i} className="w-full">
<StarterMessage
starterMessage={starterMessage}
onClick={() =>
onSubmit({
messageOverride: starterMessage.message,
})
}
/>
</div>
))}
</div>
)}
<div ref={endDivRef} />
</div>
</div>

View File

@ -20,7 +20,7 @@ function HelperItemDisplay({
description: string;
}) {
return (
<div className="cursor-default hover:bg-hover-light border border-border rounded py-2 px-4">
<div className="cursor-pointer hover:bg-hover-light border border-border rounded py-2 px-4">
<div className="text-emphasis font-bold text-lg flex">{title}</div>
<div className="text-sm">{description}</div>
</div>

View File

@ -0,0 +1,21 @@
import { StarterMessage } from "../admin/personas/interfaces";
export function StarterMessage({
starterMessage,
onClick,
}: {
starterMessage: StarterMessage;
onClick: () => void;
}) {
return (
<div
className={
"py-2 px-3 rounded border border-border bg-white cursor-pointer hover:bg-hover-light h-full"
}
onClick={onClick}
>
<p className="font-medium text-neutral-700">{starterMessage.name}</p>
<p className="text-neutral-500 text-sm">{starterMessage.description}</p>
</div>
);
}