Persona enhancements

Weves
2023-12-07 10:45:59 -08:00
committed by Chris Weaver
parent ddf3f99da4
commit d5658ce477
12 changed files with 304 additions and 149 deletions

View File

@@ -185,6 +185,7 @@ def build_qa_response_blocks(
     source_filters: list[DocumentSource] | None,
     time_cutoff: datetime | None,
     favor_recent: bool,
+    skip_quotes: bool = False,
 ) -> list[Block]:
     quotes_blocks: list[Block] = []
@@ -232,8 +233,9 @@ def build_qa_response_blocks(
     if filter_block is not None:
         response_blocks.append(filter_block)
-    response_blocks.extend(
-        [answer_block, feedback_block] + quotes_blocks + [DividerBlock()]
-    )
+    response_blocks.extend([answer_block, feedback_block])
+    if not skip_quotes:
+        response_blocks.extend(quotes_blocks)
+    response_blocks.append(DividerBlock())

     return response_blocks
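The effect of the new `skip_quotes` flag is easiest to see in isolation. Below is a minimal, self-contained sketch of the new assembly order (plain strings stand in for Danswer's real `Block` objects, and `assemble_blocks` is a hypothetical helper, not the real function): quote blocks are appended only when `skip_quotes` is false, and the divider always comes last.

```python
# Minimal sketch of the skip_quotes behavior; strings stand in for Block objects.
def assemble_blocks(
    answer_block: str,
    feedback_block: str,
    quotes_blocks: list[str],
    skip_quotes: bool = False,
) -> list[str]:
    blocks = [answer_block, feedback_block]
    if not skip_quotes:
        blocks.extend(quotes_blocks)  # quotes are now optional
    blocks.append("divider")  # divider is always appended last
    return blocks

assert assemble_blocks("a", "f", ["q1"]) == ["a", "f", "q1", "divider"]
assert assemble_blocks("a", "f", ["q1"], skip_quotes=True) == ["a", "f", "divider"]
```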

View File

@@ -104,6 +104,8 @@ def handle_message(
             document_set.name for document_set in persona.document_sets
         ]
+    should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
+
     # List of user ids to send the message to; if None, send to everyone in the channel
     send_to: list[str] | None = None
     respond_tag_only = False
@@ -257,7 +259,7 @@ def handle_message(
         logger.debug(answer.answer)
         return True

-    if not answer.top_documents:
+    if not answer.top_documents and not should_respond_even_with_no_docs:
         logger.error(f"Unable to answer question: '{msg}' - no documents found")
         # Optionally, respond in thread with the error message. Used primarily
         # for debugging purposes.
@@ -288,6 +290,7 @@ def handle_message(
         source_filters=answer.source_type,
         time_cutoff=answer.time_cutoff,
         favor_recent=answer.favor_recent,
+        skip_quotes=persona is not None,  # currently Personas don't support quotes
     )

     # Get the chunks fed to the LLM only, then fill with other docs
@@ -298,9 +301,13 @@ def handle_message(
         doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
     ]
     priority_ordered_docs = llm_docs + remaining_docs

-    document_blocks = build_documents_blocks(
-        documents=priority_ordered_docs,
-        query_event_id=answer.query_event_id,
+    document_blocks = (
+        build_documents_blocks(
+            documents=priority_ordered_docs,
+            query_event_id=answer.query_event_id,
+        )
+        if priority_ordered_docs
+        else []
     )
     try:
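Both hunks above lean on the same convention: a persona configured with `num_chunks == 0` intentionally skips retrieval, so an empty result set is expected rather than an error. A hedged sketch of that rule (a simplified stand-in function, not the real handler):

```python
# Sketch of the "respond even with no docs" rule: num_chunks == 0 is the
# persona's way of saying "never retrieve", so empty results are expected.
def should_respond(top_documents: list[str], persona_num_chunks: int | None) -> bool:
    should_respond_even_with_no_docs = persona_num_chunks == 0
    return bool(top_documents) or should_respond_even_with_no_docs

assert should_respond([], persona_num_chunks=0) is True    # retrieval disabled
assert should_respond([], persona_num_chunks=10) is False  # expected docs, found none
assert should_respond(["doc"], persona_num_chunks=None) is True
```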

View File

@@ -54,6 +54,13 @@ def _get_qa_model(persona: Persona | None) -> QAModel:
     return get_default_qa_model()


+def _dummy_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]:
+    """Mimics the interface of `full_chunk_search_generator` but returns empty lists
+    without actually running retrieval / re-ranking."""
+    yield cast(list[InferenceChunk], [])
+    yield cast(list[bool], [])
+
+
 @log_function_time()
 def answer_qa_query(
     new_message_request: NewMessageRequest,
@@ -91,6 +98,7 @@ def answer_qa_query(
         not persona.apply_llm_relevance_filter if persona else None
     )
     persona_num_chunks = persona.num_chunks if persona else None
+    persona_retrieval_disabled = persona.num_chunks == 0 if persona else False
     if persona:
         logger.info(f"Using persona: {persona.name}")
         logger.info(
@@ -113,14 +121,19 @@ def answer_qa_query(
     if disable_generative_answer:
         predicted_flow = QueryFlow.SEARCH

-    top_chunks, llm_chunk_selection = full_chunk_search(
-        query=retrieval_request,
-        document_index=get_default_document_index(),
-        retrieval_metrics_callback=retrieval_metrics_callback,
-        rerank_metrics_callback=rerank_metrics_callback,
-    )
-
-    top_docs = chunks_to_search_docs(top_chunks)
+    if not persona_retrieval_disabled:
+        top_chunks, llm_chunk_selection = full_chunk_search(
+            query=retrieval_request,
+            document_index=get_default_document_index(),
+            retrieval_metrics_callback=retrieval_metrics_callback,
+            rerank_metrics_callback=rerank_metrics_callback,
+        )
+
+        top_docs = chunks_to_search_docs(top_chunks)
+    else:
+        top_chunks = []
+        llm_chunk_selection = []
+        top_docs = []

     partial_response = partial(
         QAResponse,
@@ -133,7 +146,7 @@ def answer_qa_query(
         favor_recent=retrieval_request.favor_recent,
     )

-    if disable_generative_answer or not top_docs:
+    if disable_generative_answer or (not top_docs and not persona_retrieval_disabled):
         return partial_response(
             answer=None,
             quotes=None,
@@ -237,6 +250,7 @@ def answer_qa_query_stream(
         not persona.apply_llm_relevance_filter if persona else None
     )
     persona_num_chunks = persona.num_chunks if persona else None
+    persona_retrieval_disabled = persona.num_chunks == 0 if persona else False
     if persona:
         logger.info(f"Using persona: {persona.name}")
         logger.info(
@@ -245,6 +259,10 @@ def answer_qa_query_stream(
             f"num_chunks: {persona_num_chunks}"
         )

+    # NOTE: it's not ideal that we're still doing `retrieval_preprocessing` even
+    # if `persona_retrieval_disabled == True`, but it's a bit tricky to separate this
+    # out. Since this flow is being re-worked shortly with the move to chat, leaving it
+    # like this for now.
     retrieval_request, predicted_search_type, predicted_flow = retrieval_preprocessing(
         new_message_request=new_message_request,
         user=user,
@@ -257,10 +275,13 @@ def answer_qa_query_stream(
     if persona:
         predicted_flow = QueryFlow.QUESTION_ANSWER

-    search_generator = full_chunk_search_generator(
-        query=retrieval_request,
-        document_index=get_default_document_index(),
-    )
+    if not persona_retrieval_disabled:
+        search_generator = full_chunk_search_generator(
+            query=retrieval_request,
+            document_index=get_default_document_index(),
+        )
+    else:
+        search_generator = _dummy_search_generator()

     # first fetch and return to the UI the top chunks so the user can
     # immediately see some results
@@ -280,7 +301,9 @@ def answer_qa_query_stream(
     ).dict()
     yield get_json_line(initial_response)

-    if not top_chunks:
+    # some personas intentionally don't retrieve any documents, so we should
+    # not return early here
+    if not top_chunks and not persona_retrieval_disabled:
         logger.debug("No Documents Found")
         return
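The dummy generator works because the streaming path consumes `full_chunk_search_generator` in exactly two steps: first the retrieved chunks, then the LLM chunk-selection mask. A simplified sketch of that two-yield protocol (the consumer below is illustrative, not the real streaming code):

```python
from collections.abc import Iterator

# Stand-in for the real generator when a persona disables retrieval: it must
# yield exactly twice, matching full_chunk_search_generator's protocol.
def _dummy_search_generator() -> Iterator[list]:
    yield []  # step 1: top chunks (empty, since nothing was retrieved)
    yield []  # step 2: LLM chunk-selection mask (nothing to select)

gen = _dummy_search_generator()
top_chunks = next(gen)           # UI still receives an (empty) initial result set
llm_chunk_selection = next(gen)  # downstream code can iterate safely
assert top_chunks == [] and llm_chunk_selection == []
```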

View File

@@ -26,6 +26,7 @@ from danswer.prompts.direct_qa_prompts import COT_PROMPT
 from danswer.prompts.direct_qa_prompts import JSON_PROMPT
 from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
 from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
+from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
 from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_up_code_blocks
@@ -210,19 +211,31 @@ class PersonaBasedQAHandler(QAHandler):
     ) -> list[BaseMessage]:
         context_docs_str = build_context_str(context_chunks)

-        single_message = PARAMATERIZED_PROMPT.format(
-            context_docs_str=context_docs_str,
-            user_query=query,
-            system_prompt=self.system_prompt,
-            task_prompt=self.task_prompt,
-        ).strip()
+        if not context_chunks:
+            single_message = PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
+                user_query=query,
+                system_prompt=self.system_prompt,
+                task_prompt=self.task_prompt,
+            ).strip()
+        else:
+            single_message = PARAMATERIZED_PROMPT.format(
+                context_docs_str=context_docs_str,
+                user_query=query,
+                system_prompt=self.system_prompt,
+                task_prompt=self.task_prompt,
+            ).strip()

         prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
         return prompt

-    def build_dummy_prompt(
-        self,
-    ) -> str:
+    def build_dummy_prompt(self, retrieval_disabled: bool) -> str:
+        if retrieval_disabled:
+            return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
+                user_query="<USER_QUERY>",
+                system_prompt=self.system_prompt,
+                task_prompt=self.task_prompt,
+            ).strip()
+
         return PARAMATERIZED_PROMPT.format(
             context_docs_str="<CONTEXT_DOCS>",
             user_query="<USER_QUERY>",
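A hedged sketch of the branching this adds (simplified one-line templates, not the real prompts): when no context chunks are available, the context section drops out of the prompt entirely instead of rendering an empty CONTEXT block.

```python
# Simplified stand-ins for PARAMATERIZED_PROMPT / ..._WITHOUT_CONTEXT.
PROMPT = "{system_prompt}\nCONTEXT: {context_docs_str}\nQUERY: {user_query}"
PROMPT_WITHOUT_CONTEXT = "{system_prompt}\nQUERY: {user_query}"

def build_single_message(query: str, context_chunks: list[str], system_prompt: str) -> str:
    if not context_chunks:  # retrieval disabled, or nothing was retrieved
        return PROMPT_WITHOUT_CONTEXT.format(system_prompt=system_prompt, user_query=query)
    return PROMPT.format(
        system_prompt=system_prompt,
        context_docs_str="\n".join(context_chunks),
        user_query=query,
    )

assert "CONTEXT" not in build_single_message("hi", [], "Be terse.")
assert "CONTEXT" in build_single_message("hi", ["chunk one"], "Be terse.")
```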

View File

@@ -137,6 +137,15 @@ CONTEXT:
 RESPONSE:
 """.strip()


+PARAMATERIZED_PROMPT_WITHOUT_CONTEXT = f"""
+{{system_prompt}}
+
+{{task_prompt}}
+
+{QUESTION_PAT.upper()} {{user_query}}
+
+RESPONSE:
+""".strip()


 # Use the following for easy viewing of prompts
 if __name__ == "__main__":
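For a quick look at what the new template produces once formatted, here is a sketch; `QUESTION_PAT` is assumed to be "Query:" for illustration (the real constant is defined elsewhere in this module):

```python
QUESTION_PAT = "Query:"  # assumed value for this sketch

PARAMATERIZED_PROMPT_WITHOUT_CONTEXT = f"""
{{system_prompt}}

{{task_prompt}}

{QUESTION_PAT.upper()} {{user_query}}

RESPONSE:
""".strip()

print(
    PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
        system_prompt="You are a concise assistant.",
        task_prompt="Answer in one sentence.",
        user_query="What changed in this commit?",
    )
)
# Note the rendered prompt has no CONTEXT section at all, only
# "QUERY: What changed in this commit?" followed by "RESPONSE:".
```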

View File

@@ -131,12 +131,13 @@ def get_persona(
def build_final_template_prompt(
system_prompt: str,
task_prompt: str,
retrieval_disabled: bool = False,
_: User | None = Depends(current_user),
) -> PromptTemplateResponse:
return PromptTemplateResponse(
final_prompt_template=PersonaBasedQAHandler(
system_prompt=system_prompt, task_prompt=task_prompt
).build_dummy_prompt()
).build_dummy_prompt(retrieval_disabled=retrieval_disabled)
)
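Since FastAPI turns a function parameter with a default into an optional query parameter, the new `retrieval_disabled: bool = False` should be settable per request. A minimal sketch under that assumption (the route path and response shape below are hypothetical, not the real endpoint):

```python
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/prompt-template")  # hypothetical path for illustration
def build_final_template_prompt(
    system_prompt: str,
    task_prompt: str,
    retrieval_disabled: bool = False,  # optional query parameter, defaults to False
) -> dict:
    template = "NO-CONTEXT" if retrieval_disabled else "WITH-CONTEXT"
    return {"final_prompt_template": f"{template}:{system_prompt}|{task_prompt}"}

client = TestClient(app)
resp = client.get(
    "/prompt-template",
    params={"system_prompt": "s", "task_prompt": "t", "retrieval_disabled": True},
)
assert resp.json()["final_prompt_template"].startswith("NO-CONTEXT")
```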

View File

@@ -21,6 +21,7 @@ class PersonaSnapshot(BaseModel):
     description: str
     system_prompt: str
     task_prompt: str
+    num_chunks: int | None
     document_sets: list[DocumentSet]
     llm_model_version_override: str | None
@@ -32,6 +33,7 @@ class PersonaSnapshot(BaseModel):
             description=persona.description or "",
             system_prompt=persona.system_text or "",
             task_prompt=persona.hint_text or "",
+            num_chunks=persona.num_chunks,
             document_sets=[
                 DocumentSet.from_model(document_set_model)
                 for document_set_model in persona.document_sets
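Exposing `num_chunks` on the snapshot lets API consumers distinguish "retrieval disabled" (0) from "use the default" (None). A minimal Pydantic sketch of that contract, with the model trimmed to the relevant fields (names simplified, not the real class):

```python
from pydantic import BaseModel

# Trimmed-down sketch of the snapshot contract around num_chunks.
class PersonaSnapshotSketch(BaseModel):
    name: str
    num_chunks: int | None  # 0 -> retrieval disabled, None -> default chunk count

no_retrieval = PersonaSnapshotSketch(name="prompt-only-bot", num_chunks=0)
default = PersonaSnapshotSketch(name="search-bot", num_chunks=None)
assert no_retrieval.num_chunks == 0 and default.num_chunks is None
```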