diff --git a/backend/danswer/danswerbot/slack/blocks.py b/backend/danswer/danswerbot/slack/blocks.py
index 434d4dfee575..93356aed5336 100644
--- a/backend/danswer/danswerbot/slack/blocks.py
+++ b/backend/danswer/danswerbot/slack/blocks.py
@@ -185,6 +185,7 @@ def build_qa_response_blocks(
     source_filters: list[DocumentSource] | None,
     time_cutoff: datetime | None,
     favor_recent: bool,
+    skip_quotes: bool = False,
 ) -> list[Block]:
     quotes_blocks: list[Block] = []
 
@@ -232,8 +233,9 @@ def build_qa_response_blocks(
     if filter_block is not None:
         response_blocks.append(filter_block)
 
-    response_blocks.extend(
-        [answer_block, feedback_block] + quotes_blocks + [DividerBlock()]
-    )
+    response_blocks.extend([answer_block, feedback_block])
+    if not skip_quotes:
+        response_blocks.extend(quotes_blocks)
+    response_blocks.append(DividerBlock())
 
     return response_blocks
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py
index b611b78b4a5f..10157a1ab11c 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_message.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py
@@ -104,6 +104,8 @@ def handle_message(
             document_set.name for document_set in persona.document_sets
         ]
 
+    should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
+
     # List of user id to send message to, if None, send to everyone in channel
     send_to: list[str] | None = None
     respond_tag_only = False
@@ -257,7 +259,7 @@ def handle_message(
         logger.debug(answer.answer)
         return True
 
-    if not answer.top_documents:
+    if not answer.top_documents and not should_respond_even_with_no_docs:
         logger.error(f"Unable to answer question: '{msg}' - no documents found")
         # Optionally, respond in thread with the error message, Used primarily
         # for debugging purposes
@@ -288,6 +290,7 @@ def handle_message(
         source_filters=answer.source_type,
         time_cutoff=answer.time_cutoff,
         favor_recent=answer.favor_recent,
+        skip_quotes=persona is not None,  # currently Personas don't support quotes
     )
 
     # Get the chunks fed to the LLM only, then fill with other docs
@@ -298,9 +301,13 @@ def handle_message(
         doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
     ]
     priority_ordered_docs = llm_docs + remaining_docs
-    document_blocks = build_documents_blocks(
-        documents=priority_ordered_docs,
-        query_event_id=answer.query_event_id,
+    document_blocks = (
+        build_documents_blocks(
+            documents=priority_ordered_docs,
+            query_event_id=answer.query_event_id,
+        )
+        if priority_ordered_docs
+        else []
     )
 
     try:
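Note on the Slack changes above: quote blocks are now optional. A minimal illustrative sketch of the resulting assembly logic (not part of the patch; the block arguments stand in for the real Block objects built in blocks.py):

    # Mirrors the reworked tail of build_qa_response_blocks().
    def finish_response_blocks(
        response_blocks: list,
        answer_block,
        feedback_block,
        quotes_blocks: list,
        divider_block,
        skip_quotes: bool = False,
    ) -> list:
        response_blocks.extend([answer_block, feedback_block])
        if not skip_quotes:
            # quotes are only attached when the answer was not produced by a Persona
            response_blocks.extend(quotes_blocks)
        response_blocks.append(divider_block)
        return response_blocks

handle_message.py passes skip_quotes=persona is not None, so Persona answers are posted without quote blocks.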
diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py
index d3e1f057d100..270f805f1ee4 100644
--- a/backend/danswer/direct_qa/answer_question.py
+++ b/backend/danswer/direct_qa/answer_question.py
@@ -54,6 +54,13 @@ def _get_qa_model(persona: Persona | None) -> QAModel:
     return get_default_qa_model()
 
 
+def _dummy_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]:
+    """Mimics the interface of `full_chunk_search_generator` but returns empty lists
+    without actually running retrieval / re-ranking."""
+    yield cast(list[InferenceChunk], [])
+    yield cast(list[bool], [])
+
+
 @log_function_time()
 def answer_qa_query(
     new_message_request: NewMessageRequest,
@@ -91,6 +98,7 @@ def answer_qa_query(
         not persona.apply_llm_relevance_filter if persona else None
     )
     persona_num_chunks = persona.num_chunks if persona else None
+    persona_retrieval_disabled = persona.num_chunks == 0 if persona else False
     if persona:
         logger.info(f"Using persona: {persona.name}")
         logger.info(
@@ -113,14 +121,19 @@
     if disable_generative_answer:
         predicted_flow = QueryFlow.SEARCH
 
-    top_chunks, llm_chunk_selection = full_chunk_search(
-        query=retrieval_request,
-        document_index=get_default_document_index(),
-        retrieval_metrics_callback=retrieval_metrics_callback,
-        rerank_metrics_callback=rerank_metrics_callback,
-    )
+    if not persona_retrieval_disabled:
+        top_chunks, llm_chunk_selection = full_chunk_search(
+            query=retrieval_request,
+            document_index=get_default_document_index(),
+            retrieval_metrics_callback=retrieval_metrics_callback,
+            rerank_metrics_callback=rerank_metrics_callback,
+        )
 
-    top_docs = chunks_to_search_docs(top_chunks)
+        top_docs = chunks_to_search_docs(top_chunks)
+    else:
+        top_chunks = []
+        llm_chunk_selection = []
+        top_docs = []
 
     partial_response = partial(
         QAResponse,
@@ -133,7 +146,7 @@
         favor_recent=retrieval_request.favor_recent,
     )
 
-    if disable_generative_answer or not top_docs:
+    if disable_generative_answer or (not top_docs and not persona_retrieval_disabled):
        return partial_response(
            answer=None,
            quotes=None,
@@ -237,6 +250,7 @@ def answer_qa_query_stream(
         not persona.apply_llm_relevance_filter if persona else None
     )
     persona_num_chunks = persona.num_chunks if persona else None
+    persona_retrieval_disabled = persona.num_chunks == 0 if persona else False
     if persona:
         logger.info(f"Using persona: {persona.name}")
         logger.info(
@@ -245,6 +259,10 @@
             f"num_chunks: {persona_num_chunks}"
         )
 
+    # NOTE: it's not ideal that we're still doing `retrieval_preprocessing` even
+    # if `persona_retrieval_disabled == True`, but it's a bit tricky to separate this
+    # out. Since this flow is being re-worked shortly with the move to chat, leaving it
+    # like this for now.
     retrieval_request, predicted_search_type, predicted_flow = retrieval_preprocessing(
         new_message_request=new_message_request,
         user=user,
@@ -257,10 +275,13 @@
     if persona:
         predicted_flow = QueryFlow.QUESTION_ANSWER
 
-    search_generator = full_chunk_search_generator(
-        query=retrieval_request,
-        document_index=get_default_document_index(),
-    )
+    if not persona_retrieval_disabled:
+        search_generator = full_chunk_search_generator(
+            query=retrieval_request,
+            document_index=get_default_document_index(),
+        )
+    else:
+        search_generator = _dummy_search_generator()
 
     # first fetch and return to the UI the top chunks so the user can
     # immediately see some results
@@ -280,7 +301,9 @@
     ).dict()
     yield get_json_line(initial_response)
 
-    if not top_chunks:
+    # some personas intentionally don't retrieve any documents, so we should
+    # not return early here
+    if not top_chunks and not persona_retrieval_disabled:
         logger.debug("No Documents Found")
         return
 
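For context on _dummy_search_generator: it preserves the two-yield shape of full_chunk_search_generator, so the streaming flow can consume either one unchanged. A sketch only, assuming the usual chunks-then-selection consumption order:

    from danswer.direct_qa.answer_question import _dummy_search_generator

    gen = _dummy_search_generator()
    top_chunks = next(gen)           # [] -- no retrieval was run
    llm_chunk_selection = next(gen)  # [] -- nothing for the LLM filter to select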
diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py
index dac919178595..439f88a266d3 100644
--- a/backend/danswer/direct_qa/qa_block.py
+++ b/backend/danswer/direct_qa/qa_block.py
@@ -26,6 +26,7 @@ from danswer.prompts.direct_qa_prompts import COT_PROMPT
 from danswer.prompts.direct_qa_prompts import JSON_PROMPT
 from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
 from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
+from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
 from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_up_code_blocks
@@ -210,19 +211,31 @@ class PersonaBasedQAHandler(QAHandler):
     ) -> list[BaseMessage]:
         context_docs_str = build_context_str(context_chunks)
 
-        single_message = PARAMATERIZED_PROMPT.format(
-            context_docs_str=context_docs_str,
-            user_query=query,
-            system_prompt=self.system_prompt,
-            task_prompt=self.task_prompt,
-        ).strip()
+        if not context_chunks:
+            single_message = PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
+                user_query=query,
+                system_prompt=self.system_prompt,
+                task_prompt=self.task_prompt,
+            ).strip()
+        else:
+            single_message = PARAMATERIZED_PROMPT.format(
+                context_docs_str=context_docs_str,
+                user_query=query,
+                system_prompt=self.system_prompt,
+                task_prompt=self.task_prompt,
+            ).strip()
 
         prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
         return prompt
 
-    def build_dummy_prompt(
-        self,
-    ) -> str:
+    def build_dummy_prompt(self, retrieval_disabled: bool) -> str:
+        if retrieval_disabled:
+            return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
+                user_query="",
+                system_prompt=self.system_prompt,
+                task_prompt=self.task_prompt,
+            ).strip()
+
         return PARAMATERIZED_PROMPT.format(
             context_docs_str="",
             user_query="",
diff --git a/backend/danswer/prompts/direct_qa_prompts.py b/backend/danswer/prompts/direct_qa_prompts.py
index a6ab5908e47d..24c9226638a1 100644
--- a/backend/danswer/prompts/direct_qa_prompts.py
+++ b/backend/danswer/prompts/direct_qa_prompts.py
@@ -137,6 +137,15 @@ CONTEXT:
 RESPONSE:
 """.strip()
 
+PARAMATERIZED_PROMPT_WITHOUT_CONTEXT = f"""
+{{system_prompt}}
+
+{{task_prompt}}
+
+{QUESTION_PAT.upper()} {{user_query}}
+RESPONSE:
+""".strip()
+
 
 # User the following for easy viewing of prompts
 if __name__ == "__main__":
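A small usage sketch of the context-free prompt path added above (illustrative only; the prompt strings are invented, and the concrete QUESTION_PAT value comes from direct_qa_prompts.py):

    from danswer.direct_qa.qa_block import PersonaBasedQAHandler

    handler = PersonaBasedQAHandler(
        system_prompt="You are a concise assistant.",
        task_prompt="Answer in one sentence.",
    )

    # Rendered without a CONTEXT section -- this is what the prompt-explorer
    # endpoint returns when retrieval is disabled for the Persona.
    print(handler.build_dummy_prompt(retrieval_disabled=True))

    # Retrieval enabled: falls back to the original PARAMATERIZED_PROMPT
    # with an empty context_docs_str.
    print(handler.build_dummy_prompt(retrieval_disabled=False))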
diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py
index 919773120ea2..8656054ee35e 100644
--- a/backend/danswer/server/features/persona/api.py
+++ b/backend/danswer/server/features/persona/api.py
@@ -131,12 +131,13 @@ def get_persona(
 def build_final_template_prompt(
     system_prompt: str,
     task_prompt: str,
+    retrieval_disabled: bool = False,
     _: User | None = Depends(current_user),
 ) -> PromptTemplateResponse:
     return PromptTemplateResponse(
         final_prompt_template=PersonaBasedQAHandler(
             system_prompt=system_prompt, task_prompt=task_prompt
-        ).build_dummy_prompt()
+        ).build_dummy_prompt(retrieval_disabled=retrieval_disabled)
     )
 
 
diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py
index fed8503a8f38..5865e201e71b 100644
--- a/backend/danswer/server/features/persona/models.py
+++ b/backend/danswer/server/features/persona/models.py
@@ -21,6 +21,7 @@ class PersonaSnapshot(BaseModel):
     description: str
     system_prompt: str
     task_prompt: str
+    num_chunks: int | None
     document_sets: list[DocumentSet]
     llm_model_version_override: str | None
 
@@ -32,6 +33,7 @@ class PersonaSnapshot(BaseModel):
             description=persona.description or "",
             system_prompt=persona.system_text or "",
             task_prompt=persona.hint_text or "",
+            num_chunks=persona.num_chunks,
             document_sets=[
                 DocumentSet.from_model(document_set_model)
                 for document_set_model in persona.document_sets
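PersonaSnapshot now exposes num_chunks, with 0 acting as the "retrieval disabled" sentinel the rest of the patch keys off. A sketch of the convention as the editor below applies it (helper names are illustrative, not from the codebase):

    # num_chunks semantics used across this patch:
    #   None -> use the default (the editor assumes 5)
    #   0    -> retrieval disabled for this Persona
    #   n>0  -> feed n chunks to the LLM
    def retrieval_disabled(num_chunks: int | None) -> bool:
        return (num_chunks if num_chunks is not None else 5) == 0

    def num_chunks_to_persist(disable_retrieval: bool, num_chunks: int | None) -> int:
        return 0 if disable_retrieval else (num_chunks or 5)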
diff --git a/web/src/app/admin/personas/PersonaEditor.tsx b/web/src/app/admin/personas/PersonaEditor.tsx
index b34cee5ef679..a357fc4c1d70 100644
--- a/web/src/app/admin/personas/PersonaEditor.tsx
+++ b/web/src/app/admin/personas/PersonaEditor.tsx
@@ -2,7 +2,7 @@
 
 import { DocumentSet } from "@/lib/types";
 import { Button, Divider, Text } from "@tremor/react";
-import { ArrayHelpers, FieldArray, Form, Formik } from "formik";
+import { ArrayHelpers, ErrorMessage, FieldArray, Form, Formik } from "formik";
 
 import * as Yup from "yup";
 import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
@@ -13,6 +13,7 @@ import Link from "next/link";
 import { useEffect, useState } from "react";
 import {
   BooleanFormField,
+  ManualErrorMessage,
   SelectorFormField,
   TextFormField,
 } from "@/components/admin/connectors/Field";
@@ -46,12 +47,18 @@
   const { popup, setPopup } = usePopup();
 
   const [finalPrompt, setFinalPrompt] = useState("");
+  const [finalPromptError, setFinalPromptError] = useState("");
 
   const triggerFinalPromptUpdate = async (
     systemPrompt: string,
-    taskPrompt: string
+    taskPrompt: string,
+    retrievalDisabled: boolean
   ) => {
-    const response = await buildFinalPrompt(systemPrompt, taskPrompt);
+    const response = await buildFinalPrompt(
+      systemPrompt,
+      taskPrompt,
+      retrievalDisabled
+    );
     if (response.ok) {
       setFinalPrompt((await response.json()).final_prompt_template);
     }
@@ -63,7 +70,8 @@
     if (isUpdate) {
       triggerFinalPromptUpdate(
         existingPersona.system_prompt,
-        existingPersona.task_prompt
+        existingPersona.task_prompt,
+        existingPersona.num_chunks === 0
       );
     }
   }, []);
@@ -78,6 +86,7 @@
           description: existingPersona?.description ?? "",
           system_prompt: existingPersona?.system_prompt ?? "",
           task_prompt: existingPersona?.task_prompt ?? "",
+          disable_retrieval: (existingPersona?.num_chunks ?? 5) === 0,
           document_set_ids:
             existingPersona?.document_sets?.map(
              (documentSet) => documentSet.id
@@ -88,36 +97,68 @@
           llm_model_version_override:
             existingPersona?.llm_model_version_override ?? null,
         }}
-        validationSchema={Yup.object().shape({
-          name: Yup.string().required("Must give the Persona a name!"),
-          description: Yup.string().required(
-            "Must give the Persona a description!"
-          ),
-          system_prompt: Yup.string().required(
-            "Must give the Persona a system prompt!"
-          ),
-          task_prompt: Yup.string().required(
-            "Must give the Persona a task prompt!"
-          ),
-          document_set_ids: Yup.array().of(Yup.number()),
-          num_chunks: Yup.number().max(20).nullable(),
-          apply_llm_relevance_filter: Yup.boolean().required(),
-          llm_model_version_override: Yup.string().nullable(),
-        })}
+        validationSchema={Yup.object()
+          .shape({
+            name: Yup.string().required("Must give the Persona a name!"),
+            description: Yup.string().required(
+              "Must give the Persona a description!"
+            ),
+            system_prompt: Yup.string(),
+            task_prompt: Yup.string(),
+            disable_retrieval: Yup.boolean().required(),
+            document_set_ids: Yup.array().of(Yup.number()),
+            num_chunks: Yup.number().max(20).nullable(),
+            apply_llm_relevance_filter: Yup.boolean().required(),
+            llm_model_version_override: Yup.string().nullable(),
+          })
+          .test(
+            "system-prompt-or-task-prompt",
+            "Must provide at least one of System Prompt or Task Prompt",
+            (values) => {
+              const systemPromptSpecified = values.system_prompt
+                ? values.system_prompt.length > 0
+                : false;
+              const taskPromptSpecified = values.task_prompt
+                ? values.task_prompt.length > 0
+                : false;
+              if (systemPromptSpecified || taskPromptSpecified) {
+                setFinalPromptError("");
+                return true;
+              } // Return true if at least one field has a value
+
+              setFinalPromptError(
+                "Must provide at least one of System Prompt or Task Prompt"
+              );
+            }
+          )}
         onSubmit={async (values, formikHelpers) => {
+          if (finalPromptError) {
+            setPopup({
+              type: "error",
+              message: "Cannot submit while there are errors in the form!",
+            });
+            return;
+          }
+
           formikHelpers.setSubmitting(true);
 
+          // if disable_retrieval is set, set num_chunks to 0
+          // to tell the backend to not fetch any documents
+          const numChunks = values.disable_retrieval
+            ? 0
+            : values.num_chunks || 5;
+
           let response;
           if (isUpdate) {
             response = await updatePersona({
               id: existingPersona.id,
               ...values,
-              num_chunks: values.num_chunks || null,
+              num_chunks: numChunks,
             });
           } else {
             response = await createPersona({
               ...values,
-              num_chunks: values.num_chunks || null,
+              num_chunks: numChunks,
             });
           }
           if (response.ok) {
@@ -163,8 +204,13 @@
                 }
                 onChange={(e) => {
                   setFieldValue("system_prompt", e.target.value);
-                  triggerFinalPromptUpdate(e.target.value, values.task_prompt);
+                  triggerFinalPromptUpdate(
+                    e.target.value,
+                    values.task_prompt,
+                    values.disable_retrieval
+                  );
                 }}
+                error={finalPromptError}
               />
+
+
+              {
+                setFieldValue("disable_retrieval", e.target.checked);
+                triggerFinalPromptUpdate(
+                  values.system_prompt,
+                  values.task_prompt,
+                  e.target.checked
+                );
+              }}
+            />
@@ -195,41 +260,45 @@
-            What data should I have access to?
+          {!values.disable_retrieval && (
+            <>
+
+              What data should I have access to?
+
-              (
-
-
-                  <>
-                    Select which{" "}
-
-                      Document Sets
-                    {" "}
-                    that this Persona should search through. If none are
-                    specified, the Persona will search through all
-                    available documents in order to try and response to
-                    queries.
-
-
-
-                {documentSets.map((documentSet) => {
-                  const ind = values.document_set_ids.indexOf(
-                    documentSet.id
-                  );
-                  let isSelected = ind !== -1;
-                  return (
-
-                      (
+
+                (
+
+
+                    <>
+                      Select which{" "}
+
+                        Document Sets
+                      {" "}
+                      that this Persona should search through. If none
+                      are specified, the Persona will search through all
+                      available documents in order to try and response
+                      to queries.
+
+
+
+                  {documentSets.map((documentSet) => {
+                    const ind = values.document_set_ids.indexOf(
+                      documentSet.id
+                    );
+                    let isSelected = ind !== -1;
+                    return (
+
-                        {
-                          if (isSelected) {
-                            arrayHelpers.remove(ind);
-                          } else {
-                            arrayHelpers.push(documentSet.id);
-                          }
-                        }}
-                      >
-
-                        {documentSet.name}
-
-                      );
-                    })}
-
-
-                )}
-              />
+                          (isSelected
+                            ? " bg-gray-600"
+                            : " bg-gray-900 hover:bg-gray-700")
+                        }
+                        onClick={() => {
+                          if (isSelected) {
+                            arrayHelpers.remove(ind);
+                          } else {
+                            arrayHelpers.push(documentSet.id);
+                          }
+                        }}
+                      >
+
+                          {documentSet.name}
+
+
+                        );
+                      })}
+
+
+                  )}
-                />
+
+
+          )}
 
       {llmOverrideOptions.length > 0 && defaultLLM && (
         <>
@@ -296,41 +369,48 @@
-            [Advanced] Retrieval Customization
+          {!values.disable_retrieval && (
+            <>
+
+              [Advanced] Retrieval Customization
+
-
-
-                  How many chunks should we feed into the LLM when generating
-                  the final response? Each chunk is ~400 words long. If you
-                  are using gpt-3.5-turbo or other similar models, setting
-                  this to a value greater than 5 will result in errors at
-                  query time due to the model's input length limit.
-
-
-                  If unspecified, will use 5 chunks.
-
-                }
-                onChange={(e) => {
-                  const value = e.target.value;
-                  // Allow only integer values
-                  if (value === "" || /^[0-9]+$/.test(value)) {
-                    setFieldValue("num_chunks", value);
-                  }
-                }}
-              />
+
+
+                    How many chunks should we feed into the LLM when
+                    generating the final response? Each chunk is ~400 words
+                    long. If you are using gpt-3.5-turbo or other similar
+                    models, setting this to a value greater than 5 will
+                    result in errors at query time due to the model's
+                    input length limit.
+
+
+                    If unspecified, will use 5 chunks.
+
+                  }
+                  onChange={(e) => {
+                    const value = e.target.value;
+                    // Allow only integer values
+                    if (value === "" || /^[0-9]+$/.test(value)) {
+                      setFieldValue("num_chunks", value);
+                    }
+                  }}
+                />
-
+
-
+
+
+          )}
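The editor above also replaces the per-field required() checks on the prompts with a cross-field rule: a Persona is valid as long as at least one of the two prompts is non-empty, and submission is blocked (with finalPromptError set) otherwise. An illustrative restatement of that rule, assuming empty and missing strings are treated the same way the Yup .test() treats them:

    def prompts_valid(system_prompt: str | None, task_prompt: str | None) -> bool:
        # Mirrors the "system-prompt-or-task-prompt" test: at least one prompt
        # must be a non-empty string.
        return bool(system_prompt) or bool(task_prompt)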