Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-19 12:03:54 +02:00)

Persona enhancements
@@ -185,6 +185,7 @@ def build_qa_response_blocks(
     source_filters: list[DocumentSource] | None,
     time_cutoff: datetime | None,
     favor_recent: bool,
+    skip_quotes: bool = False,
 ) -> list[Block]:
     quotes_blocks: list[Block] = []

@@ -232,8 +233,9 @@ def build_qa_response_blocks(
     if filter_block is not None:
         response_blocks.append(filter_block)

-    response_blocks.extend(
-        [answer_block, feedback_block] + quotes_blocks + [DividerBlock()]
-    )
+    response_blocks.extend([answer_block, feedback_block])
+    if not skip_quotes:
+        response_blocks.extend(quotes_blocks)
+    response_blocks.append(DividerBlock())

     return response_blocks

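The two hunks above add a skip_quotes flag to build_qa_response_blocks and split the single extend call so quote blocks can be dropped. A minimal sketch of the resulting control flow, with plain strings standing in for the real Slack Block objects (the helper below is illustrative, not the actual danswer function):

    def build_response_sketch(
        answer: str, quotes: list[str], skip_quotes: bool = False
    ) -> list[str]:
        # Answer and feedback blocks are always emitted; quote blocks only when the
        # caller has not opted out; a divider always closes the message.
        blocks = [f"answer: {answer}", "feedback"]
        if not skip_quotes:
            blocks.extend(f"quote: {q}" for q in quotes)
        blocks.append("divider")
        return blocks

    # Personas currently skip quotes, so a persona-driven caller passes skip_quotes=True.
    assert build_response_sketch("42", ["doc-1"], skip_quotes=True) == ["answer: 42", "feedback", "divider"]
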
@@ -104,6 +104,8 @@ def handle_message(
             document_set.name for document_set in persona.document_sets
         ]

+    should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
+
     # List of user id to send message to, if None, send to everyone in channel
     send_to: list[str] | None = None
     respond_tag_only = False

@@ -257,7 +259,7 @@ def handle_message(
         logger.debug(answer.answer)
         return True

-    if not answer.top_documents:
+    if not answer.top_documents and not should_respond_even_with_no_docs:
         logger.error(f"Unable to answer question: '{msg}' - no documents found")
         # Optionally, respond in thread with the error message, Used primarily
         # for debugging purposes

@@ -288,6 +290,7 @@ def handle_message(
         source_filters=answer.source_type,
         time_cutoff=answer.time_cutoff,
         favor_recent=answer.favor_recent,
+        skip_quotes=persona is not None,  # currently Personas don't support quotes
     )

     # Get the chunks fed to the LLM only, then fill with other docs

@@ -298,9 +301,13 @@ def handle_message(
         doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
     ]
     priority_ordered_docs = llm_docs + remaining_docs
-    document_blocks = build_documents_blocks(
-        documents=priority_ordered_docs,
-        query_event_id=answer.query_event_id,
-    )
+    document_blocks = (
+        build_documents_blocks(
+            documents=priority_ordered_docs,
+            query_event_id=answer.query_event_id,
+        )
+        if priority_ordered_docs
+        else []
+    )

     try:

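Taken together, the handle_message hunks treat persona.num_chunks == 0 as "this persona intentionally retrieves nothing": the bot still responds when no documents come back, and document blocks are only built when there is something to render. A rough sketch of that gating, using hypothetical stand-in types rather than the real danswer models:

    from dataclasses import dataclass

    @dataclass
    class PersonaSketch:  # illustrative stand-in for the real Persona model
        num_chunks: int | None = None  # 0 is the sentinel for "retrieval disabled"

    def should_respond_even_with_no_docs(persona: PersonaSketch | None) -> bool:
        return persona.num_chunks == 0 if persona else False

    def build_documents_blocks_sketch(docs: list[str]) -> list[str]:
        # Mirrors the guarded call above: an empty doc list produces no blocks at all.
        return [f"doc: {d}" for d in docs] if docs else []

    assert should_respond_even_with_no_docs(PersonaSketch(num_chunks=0)) is True
    assert should_respond_even_with_no_docs(None) is False
    assert build_documents_blocks_sketch([]) == []
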
@@ -54,6 +54,13 @@ def _get_qa_model(persona: Persona | None) -> QAModel:
     return get_default_qa_model()


+def _dummy_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]:
+    """Mimics the interface of `full_chunk_search_generator` but returns empty lists
+    without actually running retrieval / re-ranking."""
+    yield cast(list[InferenceChunk], [])
+    yield cast(list[bool], [])
+
+
 @log_function_time()
 def answer_qa_query(
     new_message_request: NewMessageRequest,

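_dummy_search_generator preserves the two-step yield protocol of full_chunk_search_generator (retrieved chunks first, then the LLM relevance selection), so downstream code can consume either generator the same way. A hedged sketch of that consumption pattern, with illustrative names in place of the real types:

    from typing import Iterator

    def dummy_search_generator() -> Iterator[list[str] | list[bool]]:
        # Same two-step shape as the real search generator: chunks first, then the
        # boolean mask of chunks the LLM relevance filter would keep.
        yield []  # no chunks retrieved
        yield []  # no relevance selection

    search_generator = dummy_search_generator()
    top_chunks = next(search_generator)     # []: the UI still receives an (empty) initial payload
    llm_selection = next(search_generator)  # []: nothing for downstream filtering to keep
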
@@ -91,6 +98,7 @@ def answer_qa_query(
         not persona.apply_llm_relevance_filter if persona else None
     )
     persona_num_chunks = persona.num_chunks if persona else None
+    persona_retrieval_disabled = persona.num_chunks == 0 if persona else False
     if persona:
         logger.info(f"Using persona: {persona.name}")
         logger.info(

@@ -113,14 +121,19 @@ def answer_qa_query(
     if disable_generative_answer:
         predicted_flow = QueryFlow.SEARCH

-    top_chunks, llm_chunk_selection = full_chunk_search(
-        query=retrieval_request,
-        document_index=get_default_document_index(),
-        retrieval_metrics_callback=retrieval_metrics_callback,
-        rerank_metrics_callback=rerank_metrics_callback,
-    )
+    if not persona_retrieval_disabled:
+        top_chunks, llm_chunk_selection = full_chunk_search(
+            query=retrieval_request,
+            document_index=get_default_document_index(),
+            retrieval_metrics_callback=retrieval_metrics_callback,
+            rerank_metrics_callback=rerank_metrics_callback,
+        )

-    top_docs = chunks_to_search_docs(top_chunks)
+        top_docs = chunks_to_search_docs(top_chunks)
+    else:
+        top_chunks = []
+        llm_chunk_selection = []
+        top_docs = []

     partial_response = partial(
         QAResponse,

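The surrounding context keeps using functools.partial to pre-fill the QAResponse fields shared by every exit path, so the new persona branch and the existing paths build their responses the same way. A small illustrative sketch of that pattern (the dataclass below is a stand-in, not the real QAResponse):

    from dataclasses import dataclass
    from functools import partial

    @dataclass
    class QAResponseSketch:  # illustrative stand-in for the real QAResponse
        answer: str | None
        quotes: list[str] | None
        top_documents: list[str]

    # Pre-fill the fields shared by every exit path; each path only supplies what differs.
    partial_response = partial(QAResponseSketch, top_documents=[])
    no_answer = partial_response(answer=None, quotes=None)  # the "no docs / generation disabled" path
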
@@ -133,7 +146,7 @@ def answer_qa_query(
         favor_recent=retrieval_request.favor_recent,
     )

-    if disable_generative_answer or not top_docs:
+    if disable_generative_answer or (not top_docs and not persona_retrieval_disabled):
         return partial_response(
             answer=None,
             quotes=None,

@@ -237,6 +250,7 @@ def answer_qa_query_stream(
         not persona.apply_llm_relevance_filter if persona else None
     )
     persona_num_chunks = persona.num_chunks if persona else None
+    persona_retrieval_disabled = persona.num_chunks == 0 if persona else False
     if persona:
         logger.info(f"Using persona: {persona.name}")
         logger.info(

@@ -245,6 +259,10 @@ def answer_qa_query_stream(
             f"num_chunks: {persona_num_chunks}"
         )

+    # NOTE: it's not ideal that we're still doing `retrieval_preprocessing` even
+    # if `persona_retrieval_disabled == True`, but it's a bit tricky to separate this
+    # out. Since this flow is being re-worked shortly with the move to chat, leaving it
+    # like this for now.
     retrieval_request, predicted_search_type, predicted_flow = retrieval_preprocessing(
         new_message_request=new_message_request,
         user=user,

@@ -257,10 +275,13 @@ def answer_qa_query_stream(
     if persona:
         predicted_flow = QueryFlow.QUESTION_ANSWER

-    search_generator = full_chunk_search_generator(
-        query=retrieval_request,
-        document_index=get_default_document_index(),
-    )
+    if not persona_retrieval_disabled:
+        search_generator = full_chunk_search_generator(
+            query=retrieval_request,
+            document_index=get_default_document_index(),
+        )
+    else:
+        search_generator = _dummy_search_generator()

     # first fetch and return to the UI the top chunks so the user can
     # immediately see some results

@@ -280,7 +301,9 @@ def answer_qa_query_stream(
     ).dict()
     yield get_json_line(initial_response)

-    if not top_chunks:
+    # some personas intentionally don't retrieve any documents, so we should
+    # not return early here
+    if not top_chunks and not persona_retrieval_disabled:
         logger.debug("No Documents Found")
         return

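In the streaming path, the initial document payload is always yielded and the generator now only returns early when retrieval was actually attempted and came back empty. A hedged sketch of that shape, assuming get_json_line emits one JSON object per line (all names below are illustrative):

    import json
    from typing import Iterator

    def get_json_line_sketch(payload: dict) -> str:
        # Assumption: each streamed item is a JSON object terminated by a newline.
        return json.dumps(payload) + "\n"

    def qa_stream_sketch(top_chunks: list[str], persona_retrieval_disabled: bool) -> Iterator[str]:
        # The (possibly empty) document payload is always sent to the UI first.
        yield get_json_line_sketch({"top_documents": top_chunks})
        # Only bail out when retrieval was expected and found nothing.
        if not top_chunks and not persona_retrieval_disabled:
            return
        yield get_json_line_sketch({"answer": "..."})

    print(list(qa_stream_sketch([], persona_retrieval_disabled=True)))  # still streams an answer line
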
@@ -26,6 +26,7 @@ from danswer.prompts.direct_qa_prompts import COT_PROMPT
 from danswer.prompts.direct_qa_prompts import JSON_PROMPT
 from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
 from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
+from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
 from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_up_code_blocks

@@ -210,19 +211,31 @@ class PersonaBasedQAHandler(QAHandler):
     ) -> list[BaseMessage]:
         context_docs_str = build_context_str(context_chunks)

-        single_message = PARAMATERIZED_PROMPT.format(
-            context_docs_str=context_docs_str,
-            user_query=query,
-            system_prompt=self.system_prompt,
-            task_prompt=self.task_prompt,
-        ).strip()
+        if not context_chunks:
+            single_message = PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
+                user_query=query,
+                system_prompt=self.system_prompt,
+                task_prompt=self.task_prompt,
+            ).strip()
+        else:
+            single_message = PARAMATERIZED_PROMPT.format(
+                context_docs_str=context_docs_str,
+                user_query=query,
+                system_prompt=self.system_prompt,
+                task_prompt=self.task_prompt,
+            ).strip()

         prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
         return prompt

-    def build_dummy_prompt(
-        self,
-    ) -> str:
+    def build_dummy_prompt(self, retrieval_disabled: bool) -> str:
+        if retrieval_disabled:
+            return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
+                user_query="<USER_QUERY>",
+                system_prompt=self.system_prompt,
+                task_prompt=self.task_prompt,
+            ).strip()
+
         return PARAMATERIZED_PROMPT.format(
             context_docs_str="<CONTEXT_DOCS>",
             user_query="<USER_QUERY>",

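build_qa_prompt now falls back to a context-free template whenever no chunks were retrieved, and build_dummy_prompt mirrors the same switch with placeholder values for the prompt preview. A compact sketch of the selection logic with simplified templates (not the real PARAMATERIZED_PROMPT constants):

    # Simplified templates, for illustration only.
    PROMPT_WITH_CONTEXT = "{system_prompt}\n\nCONTEXT:\n{context_docs_str}\n\nQUERY: {user_query}"
    PROMPT_WITHOUT_CONTEXT = "{system_prompt}\n\nQUERY: {user_query}"

    def build_single_message(query: str, context_docs_str: str, system_prompt: str) -> str:
        # No retrieved chunks -> use the context-free template.
        if not context_docs_str:
            return PROMPT_WITHOUT_CONTEXT.format(system_prompt=system_prompt, user_query=query).strip()
        return PROMPT_WITH_CONTEXT.format(
            system_prompt=system_prompt, context_docs_str=context_docs_str, user_query=query
        ).strip()

    print(build_single_message("What is Danswer?", context_docs_str="", system_prompt="You are helpful."))
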
@@ -137,6 +137,15 @@ CONTEXT:
 RESPONSE:
 """.strip()

+PARAMATERIZED_PROMPT_WITHOUT_CONTEXT = f"""
+{{system_prompt}}
+
+{{task_prompt}}
+
+{QUESTION_PAT.upper()} {{user_query}}
+RESPONSE:
+""".strip()
+

 # User the following for easy viewing of prompts
 if __name__ == "__main__":

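The new constant is an f-string, so {QUESTION_PAT.upper()} is substituted once at import time while the doubled braces survive as {system_prompt}, {task_prompt}, and {user_query} placeholders for the later .format(...) call. A tiny sketch of that two-stage substitution (the QUESTION_PAT value below is assumed for illustration):

    QUESTION_PAT = "Query:"  # assumed value, for illustration only

    # The f-string expands {QUESTION_PAT.upper()} immediately; the doubled braces survive
    # as {system_prompt} / {user_query} placeholders for the later .format() call.
    TEMPLATE = f"{{system_prompt}}\n\n{QUESTION_PAT.upper()} {{user_query}}\nRESPONSE:"

    print(TEMPLATE.format(system_prompt="You are a helpful assistant.", user_query="What is Danswer?"))
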
@@ -131,12 +131,13 @@ def get_persona(
 def build_final_template_prompt(
     system_prompt: str,
     task_prompt: str,
+    retrieval_disabled: bool = False,
     _: User | None = Depends(current_user),
 ) -> PromptTemplateResponse:
     return PromptTemplateResponse(
         final_prompt_template=PersonaBasedQAHandler(
             system_prompt=system_prompt, task_prompt=task_prompt
-        ).build_dummy_prompt()
+        ).build_dummy_prompt(retrieval_disabled=retrieval_disabled)
     )

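build_final_template_prompt gains a retrieval_disabled query parameter (defaulting to False) that is forwarded to build_dummy_prompt. A hedged FastAPI sketch of the same pattern, with a hypothetical route path and a simplified handler rather than the actual danswer endpoint:

    from fastapi import FastAPI
    from pydantic import BaseModel

    app = FastAPI()

    class PromptTemplateResponseSketch(BaseModel):
        final_prompt_template: str

    @app.get("/prompt-template")  # hypothetical path, for illustration only
    def build_final_template_prompt_sketch(
        system_prompt: str,
        task_prompt: str,
        retrieval_disabled: bool = False,  # scalar keyword args on a GET route arrive as query params
    ) -> PromptTemplateResponseSketch:
        template = f"{system_prompt}\n\n{task_prompt}\n\nQUERY: <USER_QUERY>"
        if not retrieval_disabled:
            template = "CONTEXT: <CONTEXT_DOCS>\n\n" + template
        return PromptTemplateResponseSketch(final_prompt_template=template)
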
@@ -21,6 +21,7 @@ class PersonaSnapshot(BaseModel):
     description: str
     system_prompt: str
     task_prompt: str
+    num_chunks: int | None
     document_sets: list[DocumentSet]
     llm_model_version_override: str | None

@@ -32,6 +33,7 @@ class PersonaSnapshot(BaseModel):
             description=persona.description or "",
             system_prompt=persona.system_text or "",
             task_prompt=persona.hint_text or "",
+            num_chunks=persona.num_chunks,
             document_sets=[
                 DocumentSet.from_model(document_set_model)
                 for document_set_model in persona.document_sets

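Exposing num_chunks on PersonaSnapshot lets the admin UI tell when a persona has retrieval turned off (num_chunks == 0). A minimal Pydantic sketch of an optional field populated from an ORM-style object, using simplified stand-ins for the real danswer models:

    from pydantic import BaseModel

    class PersonaOrmSketch:  # stand-in for the SQLAlchemy Persona model
        def __init__(self, name: str, num_chunks: int | None) -> None:
            self.name = name
            self.num_chunks = num_chunks

    class PersonaSnapshotSketch(BaseModel):
        name: str
        num_chunks: int | None  # None = server default; 0 = retrieval disabled for this persona

        @classmethod
        def from_model(cls, persona: PersonaOrmSketch) -> "PersonaSnapshotSketch":
            return cls(name=persona.name, num_chunks=persona.num_chunks)

    snapshot = PersonaSnapshotSketch.from_model(PersonaOrmSketch(name="support-bot", num_chunks=0))
    assert snapshot.num_chunks == 0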