Persona enhancements

Weves
2023-12-07 10:45:59 -08:00
committed by Chris Weaver
parent ddf3f99da4
commit d5658ce477
12 changed files with 304 additions and 149 deletions

View File

@@ -185,6 +185,7 @@ def build_qa_response_blocks(
source_filters: list[DocumentSource] | None,
time_cutoff: datetime | None,
favor_recent: bool,
skip_quotes: bool = False,
) -> list[Block]:
quotes_blocks: list[Block] = []
@@ -232,8 +233,9 @@ def build_qa_response_blocks(
if filter_block is not None:
response_blocks.append(filter_block)
response_blocks.extend(
[answer_block, feedback_block] + quotes_blocks + [DividerBlock()]
)
response_blocks.extend([answer_block, feedback_block])
if not skip_quotes:
response_blocks.extend(quotes_blocks)
response_blocks.append(DividerBlock())
return response_blocks
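
The staged appends above let the quote blocks be dropped while keeping the trailing divider. A standalone sketch of the resulting assembly logic, with the Block classes stubbed out since the real Slack block types aren't shown in this diff:

class Block: ...
class DividerBlock(Block): ...

def assemble_response_blocks(
    answer_block: Block,
    feedback_block: Block,
    quotes_blocks: list[Block],
    skip_quotes: bool = False,
) -> list[Block]:
    blocks = [answer_block, feedback_block]
    if not skip_quotes:
        blocks.extend(quotes_blocks)  # omitted for Persona responses
    blocks.append(DividerBlock())  # divider is kept either way
    return blocks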

View File

@@ -104,6 +104,8 @@ def handle_message(
document_set.name for document_set in persona.document_sets
]
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
# List of user ids to send the message to; if None, send to everyone in the channel
send_to: list[str] | None = None
respond_tag_only = False
@@ -257,7 +259,7 @@ def handle_message(
logger.debug(answer.answer)
return True
if not answer.top_documents:
if not answer.top_documents and not should_respond_even_with_no_docs:
logger.error(f"Unable to answer question: '{msg}' - no documents found")
# Optionally, respond in thread with the error message; used primarily
# for debugging purposes
@@ -288,6 +290,7 @@ def handle_message(
source_filters=answer.source_type,
time_cutoff=answer.time_cutoff,
favor_recent=answer.favor_recent,
skip_quotes=persona is not None, # currently Personas don't support quotes
)
# Get the chunks fed to the LLM only, then fill with other docs
@@ -298,10 +301,14 @@ def handle_message(
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = build_documents_blocks(
document_blocks = (
build_documents_blocks(
documents=priority_ordered_docs,
query_event_id=answer.query_event_id,
)
if priority_ordered_docs
else []
)
try:
respond_in_thread(
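
The new guard treats "no documents" as expected rather than an error when the persona has retrieval disabled. A self-contained sketch of that predicate; FakePersona is a hypothetical stand-in for the real persona model:

from dataclasses import dataclass

@dataclass
class FakePersona:
    num_chunks: int | None

def should_respond_without_docs(persona: FakePersona | None) -> bool:
    # num_chunks == 0 is the sentinel this commit uses for "retrieval disabled"
    return persona.num_chunks == 0 if persona else False

assert should_respond_without_docs(FakePersona(num_chunks=0)) is True
assert should_respond_without_docs(FakePersona(num_chunks=5)) is False
assert should_respond_without_docs(None) is False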

View File

@@ -54,6 +54,13 @@ def _get_qa_model(persona: Persona | None) -> QAModel:
return get_default_qa_model()
def _dummy_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]:
"""Mimics the interface of `full_chunk_search_generator` but returns empty lists
without actually running retrieval / re-ranking."""
yield cast(list[InferenceChunk], [])
yield cast(list[bool], [])
@log_function_time()
def answer_qa_query(
new_message_request: NewMessageRequest,
@@ -91,6 +98,7 @@ def answer_qa_query(
not persona.apply_llm_relevance_filter if persona else None
)
persona_num_chunks = persona.num_chunks if persona else None
persona_retrieval_disabled = persona.num_chunks == 0 if persona else False
if persona:
logger.info(f"Using persona: {persona.name}")
logger.info(
@@ -113,6 +121,7 @@ def answer_qa_query(
if disable_generative_answer:
predicted_flow = QueryFlow.SEARCH
if not persona_retrieval_disabled:
top_chunks, llm_chunk_selection = full_chunk_search(
query=retrieval_request,
document_index=get_default_document_index(),
@@ -121,6 +130,10 @@ def answer_qa_query(
)
top_docs = chunks_to_search_docs(top_chunks)
else:
top_chunks = []
llm_chunk_selection = []
top_docs = []
partial_response = partial(
QAResponse,
@@ -133,7 +146,7 @@ def answer_qa_query(
favor_recent=retrieval_request.favor_recent,
)
if disable_generative_answer or not top_docs:
if disable_generative_answer or (not top_docs and not persona_retrieval_disabled):
return partial_response(
answer=None,
quotes=None,
@@ -237,6 +250,7 @@ def answer_qa_query_stream(
not persona.apply_llm_relevance_filter if persona else None
)
persona_num_chunks = persona.num_chunks if persona else None
persona_retrieval_disabled = persona.num_chunks == 0 if persona else False
if persona:
logger.info(f"Using persona: {persona.name}")
logger.info(
@@ -245,6 +259,10 @@ def answer_qa_query_stream(
f"num_chunks: {persona_num_chunks}"
)
# NOTE: it's not ideal that we're still doing `retrieval_preprocessing` even
# if `persona_retrieval_disabled == True`, but it's a bit tricky to separate this
# out. Since this flow is being re-worked shortly with the move to chat, leaving it
# like this for now.
retrieval_request, predicted_search_type, predicted_flow = retrieval_preprocessing(
new_message_request=new_message_request,
user=user,
@@ -257,10 +275,13 @@ def answer_qa_query_stream(
if persona:
predicted_flow = QueryFlow.QUESTION_ANSWER
if not persona_retrieval_disabled:
search_generator = full_chunk_search_generator(
query=retrieval_request,
document_index=get_default_document_index(),
)
else:
search_generator = _dummy_search_generator()
# first fetch and return to the UI the top chunks so the user can
# immediately see some results
@@ -280,7 +301,9 @@ def answer_qa_query_stream(
).dict()
yield get_json_line(initial_response)
if not top_chunks:
# some personas intentionally don't retrieve any documents, so we should
# not return early here
if not top_chunks and not persona_retrieval_disabled:
logger.debug("No Documents Found")
return
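
The `_dummy_search_generator` swap works because the downstream code pulls results with `next()` and never checks whether retrieval actually ran. A stripped-down sketch of the pattern, with the chunk types simplified away:

from collections.abc import Iterator

def dummy_search_generator() -> Iterator[list]:
    # Same yield order as the real full_chunk_search_generator:
    # first the retrieved chunks, then the LLM chunk-selection mask
    yield []  # top_chunks
    yield []  # llm_chunk_selection

search_generator = dummy_search_generator()
top_chunks = next(search_generator)  # [] -> UI shows no documents
llm_chunk_selection = next(search_generator)  # [] -> no context fed to the LLM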

View File

@@ -26,6 +26,7 @@ from danswer.prompts.direct_qa_prompts import COT_PROMPT
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_up_code_blocks
@@ -210,6 +211,13 @@ class PersonaBasedQAHandler(QAHandler):
) -> list[BaseMessage]:
context_docs_str = build_context_str(context_chunks)
if not context_chunks:
single_message = PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query=query,
system_prompt=self.system_prompt,
task_prompt=self.task_prompt,
).strip()
else:
single_message = PARAMATERIZED_PROMPT.format(
context_docs_str=context_docs_str,
user_query=query,
@@ -220,9 +228,14 @@ class PersonaBasedQAHandler(QAHandler):
prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
return prompt
def build_dummy_prompt(
self,
) -> str:
def build_dummy_prompt(self, retrieval_disabled: bool) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=self.system_prompt,
task_prompt=self.task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
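
Both `build_prompt` and `build_dummy_prompt` now reduce to picking one of two format strings. A simplified sketch; the inline templates are illustrative stand-ins for the real PARAMATERIZED_PROMPT constants:

WITH_CONTEXT = "{system_prompt}\n\nCONTEXT:\n{context_docs_str}\n\n{task_prompt}\n\nQUESTION: {user_query}\nRESPONSE:"
WITHOUT_CONTEXT = "{system_prompt}\n\n{task_prompt}\n\nQUESTION: {user_query}\nRESPONSE:"

def build_single_message(
    query: str, system_prompt: str, task_prompt: str, context_docs_str: str = ""
) -> str:
    template = WITH_CONTEXT if context_docs_str else WITHOUT_CONTEXT
    # str.format ignores unused keyword arguments, so one call covers both cases
    return template.format(
        system_prompt=system_prompt,
        task_prompt=task_prompt,
        user_query=query,
        context_docs_str=context_docs_str,
    ).strip()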

View File

@@ -137,6 +137,15 @@ CONTEXT:
RESPONSE:
""".strip()
PARAMATERIZED_PROMPT_WITHOUT_CONTEXT = f"""
{{system_prompt}}
{{task_prompt}}
{QUESTION_PAT.upper()} {{user_query}}
RESPONSE:
""".strip()
# Use the following for easy viewing of prompts
if __name__ == "__main__":

View File

@@ -131,12 +131,13 @@ def get_persona(
def build_final_template_prompt(
system_prompt: str,
task_prompt: str,
retrieval_disabled: bool = False,
_: User | None = Depends(current_user),
) -> PromptTemplateResponse:
return PromptTemplateResponse(
final_prompt_template=PersonaBasedQAHandler(
system_prompt=system_prompt, task_prompt=task_prompt
).build_dummy_prompt()
).build_dummy_prompt(retrieval_disabled=retrieval_disabled)
)
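
With `retrieval_disabled` exposed as a query parameter (the `Depends(current_user)` signature suggests a FastAPI route), the frontend can preview the no-context template. A hedged example of calling it; the base URL and path are assumptions for illustration, since the hunk shows only the handler, not where it is mounted:

import requests

resp = requests.get(
    "http://localhost:8080/persona/utils/prompt-explorer",  # hypothetical path
    params={
        "system_prompt": "You are a helpful assistant.",
        "task_prompt": "Answer concisely.",
        "retrieval_disabled": "true",  # FastAPI coerces this to a Python bool
    },
)
print(resp.json()["final_prompt_template"])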

View File

@@ -21,6 +21,7 @@ class PersonaSnapshot(BaseModel):
description: str
system_prompt: str
task_prompt: str
num_chunks: int | None
document_sets: list[DocumentSet]
llm_model_version_override: str | None
@@ -32,6 +33,7 @@ class PersonaSnapshot(BaseModel):
description=persona.description or "",
system_prompt=persona.system_text or "",
task_prompt=persona.hint_text or "",
num_chunks=persona.num_chunks,
document_sets=[
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
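
Exposing `num_chunks` in the snapshot is what lets the editor UI detect retrieval-disabled personas. A minimal sketch of the convention, using a simplified model rather than the full PersonaSnapshot:

from pydantic import BaseModel

class PersonaSnapshotSketch(BaseModel):
    # Convention used throughout this commit:
    #   None -> default chunk count (5), 0 -> retrieval disabled
    name: str
    num_chunks: int | None = None

snap = PersonaSnapshotSketch(name="prompt-only-bot", num_chunks=0)
disable_retrieval = snap.num_chunks == 0  # what the editor form derives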

View File

@@ -2,7 +2,7 @@
import { DocumentSet } from "@/lib/types";
import { Button, Divider, Text } from "@tremor/react";
import { ArrayHelpers, FieldArray, Form, Formik } from "formik";
import { ArrayHelpers, ErrorMessage, FieldArray, Form, Formik } from "formik";
import * as Yup from "yup";
import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
@@ -13,6 +13,7 @@ import Link from "next/link";
import { useEffect, useState } from "react";
import {
BooleanFormField,
ManualErrorMessage,
SelectorFormField,
TextFormField,
} from "@/components/admin/connectors/Field";
@@ -46,12 +47,18 @@ export function PersonaEditor({
const { popup, setPopup } = usePopup();
const [finalPrompt, setFinalPrompt] = useState<string | null>("");
const [finalPromptError, setFinalPromptError] = useState<string>("");
const triggerFinalPromptUpdate = async (
systemPrompt: string,
taskPrompt: string
taskPrompt: string,
retrievalDisabled: boolean
) => {
const response = await buildFinalPrompt(systemPrompt, taskPrompt);
const response = await buildFinalPrompt(
systemPrompt,
taskPrompt,
retrievalDisabled
);
if (response.ok) {
setFinalPrompt((await response.json()).final_prompt_template);
}
@@ -63,7 +70,8 @@ export function PersonaEditor({
if (isUpdate) {
triggerFinalPromptUpdate(
existingPersona.system_prompt,
existingPersona.task_prompt
existingPersona.task_prompt,
existingPersona.num_chunks === 0
);
}
}, []);
@@ -78,6 +86,7 @@ export function PersonaEditor({
description: existingPersona?.description ?? "",
system_prompt: existingPersona?.system_prompt ?? "",
task_prompt: existingPersona?.task_prompt ?? "",
disable_retrieval: (existingPersona?.num_chunks ?? 5) === 0,
document_set_ids:
existingPersona?.document_sets?.map(
(documentSet) => documentSet.id
@@ -88,36 +97,68 @@ export function PersonaEditor({
llm_model_version_override:
existingPersona?.llm_model_version_override ?? null,
}}
validationSchema={Yup.object().shape({
validationSchema={Yup.object()
.shape({
name: Yup.string().required("Must give the Persona a name!"),
description: Yup.string().required(
"Must give the Persona a description!"
),
system_prompt: Yup.string().required(
"Must give the Persona a system prompt!"
),
task_prompt: Yup.string().required(
"Must give the Persona a task prompt!"
),
system_prompt: Yup.string(),
task_prompt: Yup.string(),
disable_retrieval: Yup.boolean().required(),
document_set_ids: Yup.array().of(Yup.number()),
num_chunks: Yup.number().max(20).nullable(),
apply_llm_relevance_filter: Yup.boolean().required(),
llm_model_version_override: Yup.string().nullable(),
})}
})
.test(
"system-prompt-or-task-prompt",
"Must provide at least one of System Prompt or Task Prompt",
(values) => {
const systemPromptSpecified = values.system_prompt
? values.system_prompt.length > 0
: false;
const taskPromptSpecified = values.task_prompt
? values.task_prompt.length > 0
: false;
if (systemPromptSpecified || taskPromptSpecified) {
setFinalPromptError("");
return true;
} // Return true if at least one field has a value
setFinalPromptError(
"Must provide at least one of System Prompt or Task Prompt"
);
return false;
}
)}
onSubmit={async (values, formikHelpers) => {
if (finalPromptError) {
setPopup({
type: "error",
message: "Cannot submit while there are errors in the form!",
});
return;
}
formikHelpers.setSubmitting(true);
// if disable_retrieval is set, set num_chunks to 0
// to tell the backend to not fetch any documents
const numChunks = values.disable_retrieval
? 0
: values.num_chunks || 5;
let response;
if (isUpdate) {
response = await updatePersona({
id: existingPersona.id,
...values,
num_chunks: values.num_chunks || null,
num_chunks: numChunks,
});
} else {
response = await createPersona({
...values,
num_chunks: values.num_chunks || null,
num_chunks: numChunks,
});
}
if (response.ok) {
@@ -163,8 +204,13 @@ export function PersonaEditor({
}
onChange={(e) => {
setFieldValue("system_prompt", e.target.value);
triggerFinalPromptUpdate(e.target.value, values.task_prompt);
triggerFinalPromptUpdate(
e.target.value,
values.task_prompt,
values.disable_retrieval
);
}}
error={finalPromptError}
/>
<TextFormField
@@ -178,7 +224,26 @@ export function PersonaEditor({
setFieldValue("task_prompt", e.target.value);
triggerFinalPromptUpdate(
values.system_prompt,
e.target.value
e.target.value,
values.disable_retrieval
);
}}
error={finalPromptError}
/>
<BooleanFormField
name="disable_retrieval"
label="Disable Retrieval"
subtext={`
If set, the Persona will not fetch any context documents to aid in the response.
Instead, it will only use the supplied system and task prompts plus the user
query in order to generate a response`}
onChange={(e) => {
setFieldValue("disable_retrieval", e.target.checked);
triggerFinalPromptUpdate(
values.system_prompt,
values.task_prompt,
e.target.checked
);
}}
/>
@@ -195,7 +260,11 @@ export function PersonaEditor({
<Divider />
<SectionHeader>What data should I have access to?</SectionHeader>
{!values.disable_retrieval && (
<>
<SectionHeader>
What data should I have access to?
</SectionHeader>
<FieldArray
name="document_set_ids"
@@ -212,10 +281,10 @@ export function PersonaEditor({
>
Document Sets
</Link>{" "}
that this Persona should search through. If none are
specified, the Persona will search through all
available documents in order to try and response to
queries.
that this Persona should search through. If none
are specified, the Persona will search through all
available documents in order to try to respond
to queries.
</>
</SubLabel>
</div>
@@ -250,7 +319,9 @@ export function PersonaEditor({
}
}}
>
<div className="my-auto">{documentSet.name}</div>
<div className="my-auto">
{documentSet.name}
</div>
</div>
);
})}
@@ -260,6 +331,8 @@ export function PersonaEditor({
/>
<Divider />
</>
)}
{llmOverrideOptions.length > 0 && defaultLLM && (
<>
@@ -296,18 +369,23 @@ export function PersonaEditor({
<Divider />
<SectionHeader>[Advanced] Retrieval Customization</SectionHeader>
{!values.disable_retrieval && (
<>
<SectionHeader>
[Advanced] Retrieval Customization
</SectionHeader>
<TextFormField
name="num_chunks"
label="Number of Chunks"
subtext={
<div>
How many chunks should we feed into the LLM when generating
the final response? Each chunk is ~400 words long. If you
are using gpt-3.5-turbo or other similar models, setting
this to a value greater than 5 will result in errors at
query time due to the model&apos;s input length limit.
How many chunks should we feed into the LLM when
generating the final response? Each chunk is ~400 words
long. If you are using gpt-3.5-turbo or other similar
models, setting this to a value greater than 5 will
result in errors at query time due to the model&apos;s
input length limit.
<br />
<br />
If unspecified, will use 5 chunks.
@@ -331,6 +409,8 @@ export function PersonaEditor({
/>
<Divider />
</>
)}
<div className="flex">
<Button
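
On submit, the checkbox folds back into `num_chunks`. The mapping from the handler above, restated as a small Python sketch for clarity:

def resolve_num_chunks(disable_retrieval: bool, num_chunks: int | None) -> int:
    # 0 is the "retrieval disabled" sentinel; a blank field falls back to 5
    return 0 if disable_retrieval else (num_chunks or 5)

assert resolve_num_chunks(True, 10) == 0  # the checkbox wins
assert resolve_num_chunks(False, None) == 5
assert resolve_num_chunks(False, 8) == 8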

View File

@@ -46,10 +46,15 @@ export function deletePersona(personaId: number) {
});
}
export function buildFinalPrompt(systemPrompt: string, taskPrompt: string) {
export function buildFinalPrompt(
systemPrompt: string,
taskPrompt: string,
retrievalDisabled: boolean
) {
let queryString = Object.entries({
system_prompt: systemPrompt,
task_prompt: taskPrompt,
retrieval_disabled: retrievalDisabled,
})
.map(
([key, value]) =>
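
For reference, the equivalent query-string construction in Python; the value is pre-stringified as "true" to match what the TypeScript template literal produces:

from urllib.parse import urlencode

query_string = urlencode({
    "system_prompt": "You are a helpful assistant.",
    "task_prompt": "Answer concisely.",
    "retrieval_disabled": "true",
})
# -> "system_prompt=You+are+a+helpful+assistant.&task_prompt=...&retrieval_disabled=true"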

View File

@@ -30,6 +30,10 @@ export function SubLabel({ children }: { children: string | JSX.Element }) {
return <div className="text-sm text-gray-300 mb-2">{children}</div>;
}
export function ManualErrorMessage({ children }: { children: string }) {
return <div className="text-red-500 text-sm mt-1">{children}</div>;
}
export function TextFormField({
name,
label,
@@ -40,6 +44,7 @@ export function TextFormField({
isTextArea = false,
disabled = false,
autoCompleteDisabled = true,
error,
}: {
name: string;
label: string;
@@ -50,6 +55,7 @@ export function TextFormField({
isTextArea?: boolean;
disabled?: boolean;
autoCompleteDisabled?: boolean;
error?: string;
}) {
return (
<div className="mb-4">
@@ -78,11 +84,15 @@ export function TextFormField({
autoComplete={autoCompleteDisabled ? "off" : undefined}
{...(onChange ? { onChange } : {})}
/>
{error ? (
<ManualErrorMessage>{error}</ManualErrorMessage>
) : (
<ErrorMessage
name={name}
component="div"
className="text-red-500 text-sm mt-1"
/>
)}
</div>
);
}
@@ -91,12 +101,14 @@ interface BooleanFormFieldProps {
name: string;
label: string;
subtext?: string;
onChange?: (e: React.ChangeEvent<HTMLInputElement>) => void;
}
export const BooleanFormField = ({
name,
label,
subtext,
onChange,
}: BooleanFormFieldProps) => {
return (
<div className="mb-4">
@@ -105,6 +117,7 @@ export const BooleanFormField = ({
name={name}
type="checkbox"
className="mx-3 px-5 w-3.5 h-3.5 my-auto"
{...(onChange ? { onChange } : {})}
/>
<div>
<Label>{label}</Label>

View File

@@ -86,7 +86,8 @@ export const SearchResultsDisplay = ({
if (
answer === null &&
(documents === null || documents.length === 0) &&
quotes === null
quotes === null &&
!isFetching
) {
return (
<div className="mt-4">
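
The added `!isFetching` check stops a brief "no results" flash while results are still streaming in. The predicate, restated as a language-shifted sketch:

def show_empty_state(answer, documents, quotes, is_fetching: bool) -> bool:
    # only declare "no results" once the fetch has actually settled
    return (
        answer is None
        and (documents is None or len(documents) == 0)
        and quotes is None
        and not is_fetching
    )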

View File

@@ -43,7 +43,6 @@ export const questionValidationStreamed = async <T>({
let previousPartialChunk: string | null = null;
while (true) {
const rawChunk = await reader?.read();
console.log(rawChunk);
if (!rawChunk) {
throw new Error("Unable to process chunk");
}