Special Danswer flow for Chat (#459)

This commit is contained in:
Yuhong Sun
2023-09-18 21:10:20 -07:00
committed by GitHub
parent 3641102672
commit 32eee88628
8 changed files with 208 additions and 67 deletions

View File

@@ -0,0 +1,26 @@
"""Danswer Custom Tool Flow
Revision ID: dba7f71618f5
Revises: d5645c915d0e
Create Date: 2023-09-18 15:18:37.370972
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "dba7f71618f5"
down_revision = "d5645c915d0e"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the non-nullable ``retrieval_enabled`` flag to the persona table.

    A ``server_default`` is required: adding a ``NOT NULL`` column with no
    default fails with an integrity error on any database where ``persona``
    already contains rows. ``sa.true()`` matches the ORM-level default
    (``Persona.retrieval_enabled`` defaults to ``True``).
    """
    op.add_column(
        "persona",
        sa.Column(
            "retrieval_enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.true(),
        ),
    )
def downgrade() -> None:
    # Reverse of upgrade(): drop the retrieval_enabled flag from persona.
    op.drop_column("persona", "retrieval_enabled")

View File

@@ -6,18 +6,30 @@ from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage from langchain.schema.messages import SystemMessage
from danswer.chat.chat_prompts import build_combined_query
from danswer.chat.chat_prompts import DANSWER_TOOL_NAME
from danswer.chat.chat_prompts import form_tool_followup_text from danswer.chat.chat_prompts import form_tool_followup_text
from danswer.chat.chat_prompts import form_user_prompt_text from danswer.chat.chat_prompts import form_user_prompt_text
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.chat.tools import call_tool from danswer.chat.tools import call_tool
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.db.models import ChatMessage from danswer.db.models import ChatMessage
from danswer.db.models import Persona from danswer.db.models import Persona
from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerChatModelOut from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.llm.build import get_default_llm from danswer.llm.build import get_default_llm
from danswer.llm.llm import LLM
from danswer.llm.utils import translate_danswer_msg_to_langchain from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import extract_embedded_json from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.text_processing import has_unescaped_quote from danswer.utils.text_processing import has_unescaped_quote
logger = setup_logger()
def _parse_embedded_json_streamed_response( def _parse_embedded_json_streamed_response(
tokens: Iterator[str], tokens: Iterator[str],
@@ -55,6 +67,8 @@ def _parse_embedded_json_streamed_response(
yield DanswerAnswerPiece(answer_piece=hold) yield DanswerAnswerPiece(answer_piece=hold)
hold = "" hold = ""
logger.debug(model_output)
model_final = extract_embedded_json(model_output) model_final = extract_embedded_json(model_output)
if "action" not in model_final or "action_input" not in model_final: if "action" not in model_final or "action_input" not in model_final:
raise ValueError("Model did not provide all required action values") raise ValueError("Model did not provide all required action values")
@@ -67,6 +81,44 @@ def _parse_embedded_json_streamed_response(
return return
def danswer_chat_retrieval(
    query_message: ChatMessage,
    history: list[ChatMessage],
    llm: LLM,
    user_id: UUID | None,
) -> str:
    """Run Danswer's built-in retrieval "tool" for one chat turn.

    When prior conversation history exists, the final user message is first
    rewritten by the LLM into a standalone query (prompt built by
    build_combined_query); otherwise the raw message text is searched as-is.

    Returns the retrieved document chunks formatted as a single string for
    inclusion in the follow-up prompt, or "No results found" when retrieval
    yields nothing.
    """
    if history:
        query_combination_msgs = build_combined_query(query_message, history)
        # LLM rewrites the last message into a context-free standalone query.
        reworded_query = llm.invoke(query_combination_msgs)
    else:
        reworded_query = query_message.message
    # Good Debug/Breakpoint
    ranked_chunks, unranked_chunks = retrieve_ranked_documents(
        reworded_query,
        user_id=user_id,
        filters=None,
        datastore=get_default_document_index(),
    )
    if not ranked_chunks:
        return "No results found"
    # Unranked chunks are appended after the ranked ones, preserving order.
    if unranked_chunks:
        ranked_chunks.extend(unranked_chunks)
    # Drop chunks explicitly flagged as unusable for question answering.
    filtered_ranked_chunks = [
        chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
    ]
    # get all chunks that fit into the token limit
    usable_chunks = get_usable_chunks(
        chunks=filtered_ranked_chunks,
        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
    )
    return format_danswer_chunks_for_chat(usable_chunks)
def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]: def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]:
prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages] prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages]
@@ -78,11 +130,16 @@ def llm_contextual_chat_answer(
persona: Persona, persona: Persona,
user_id: UUID | None, user_id: UUID | None,
) -> Iterator[str]: ) -> Iterator[str]:
retrieval_enabled = persona.retrieval_enabled
system_text = persona.system_text system_text = persona.system_text
tool_text = persona.tools_text tool_text = persona.tools_text
hint_text = persona.hint_text hint_text = persona.hint_text
last_message = messages[-1] last_message = messages[-1]
if not last_message.message:
raise ValueError("User chat message is empty.")
previous_messages = messages[:-1] previous_messages = messages[:-1]
user_text = form_user_prompt_text( user_text = form_user_prompt_text(
@@ -102,7 +159,10 @@ def llm_contextual_chat_answer(
prompt.append(HumanMessage(content=user_text)) prompt.append(HumanMessage(content=user_text))
tokens = get_default_llm().stream(prompt) llm = get_default_llm()
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
final_result: DanswerChatModelOut | None = None final_result: DanswerChatModelOut | None = None
final_answer_streamed = False final_answer_streamed = False
@@ -121,16 +181,29 @@ def llm_contextual_chat_answer(
if final_result is None: if final_result is None:
raise RuntimeError("Model output finished without final output parsing.") raise RuntimeError("Model output finished without final output parsing.")
tool_result_str = call_tool(final_result, user_id=user_id) if retrieval_enabled and final_result.action.lower() == DANSWER_TOOL_NAME.lower():
tool_result_str = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
user_id=user_id,
)
else:
tool_result_str = call_tool(final_result, user_id=user_id)
prompt.append(AIMessage(content=final_result.model_raw)) prompt.append(AIMessage(content=final_result.model_raw))
prompt.append( prompt.append(
HumanMessage( HumanMessage(
content=form_tool_followup_text(tool_result_str, hint_text=hint_text) content=form_tool_followup_text(
tool_output=tool_result_str,
query=last_message.message,
hint_text=hint_text,
)
) )
) )
tokens = get_default_llm().stream(prompt) # Good Debug/Breakpoint
tokens = llm.stream(prompt)
for result in _parse_embedded_json_streamed_response(tokens): for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece: if isinstance(result, DanswerAnswerPiece) and result.answer_piece:

View File

@@ -1,31 +1,48 @@
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chunking.models import InferenceChunk from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import CODE_BLOCK_PAT from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.db.models import ChatMessage
from danswer.llm.utils import translate_danswer_msg_to_langchain
TOOL_TEMPLATE = """TOOLS DANSWER_TOOL_NAME = "Current Search"
DANSWER_TOOL_DESCRIPTION = (
"A search tool that can find information on any topic "
"including up to date and proprietary knowledge."
)
DANSWER_SYSTEM_MSG = (
"Given a conversation (between Human and Assistant) and a final message from Human, "
"rewrite the last message to be a standalone question that captures required/relevant context from the previous "
"conversation messages."
)
TOOL_TEMPLATE = """
TOOLS
------ ------
Assistant can ask the user to use tools to look up information that may be helpful in answering the users \ You can use tools to look up information that may be helpful in answering the user's \
original question. The tools the human can use are: original question. The available tools are:
{} {tool_overviews}
RESPONSE FORMAT INSTRUCTIONS RESPONSE FORMAT INSTRUCTIONS
---------------------------- ----------------------------
When responding to me, please output a response in one of two formats: When responding to me, please output a response in one of two formats:
**Option 1:** **Option 1:**
Use this if you want the human to use a tool. Use this if you want to use a tool. Markdown code snippet formatted in the following schema:
Markdown code snippet formatted in the following schema:
```json ```json
{{ {{
"action": string, \\ The action to take. Must be one of {} "action": string, \\ The action to take. Must be one of {tool_names}
"action_input": string \\ The input to the action "action_input": string \\ The input to the action
}} }}
``` ```
**Option #2:** **Option #2:**
Use this if you want to respond directly to the human. Markdown code snippet formatted in the following schema: Use this if you want to respond directly to the user. Markdown code snippet formatted in the following schema:
```json ```json
{{ {{
@@ -52,19 +69,19 @@ USER'S INPUT
Here is the user's input \ Here is the user's input \
(remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else): (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
{} {user_input}
""" """
TOOL_FOLLOWUP = """ TOOL_FOLLOWUP = """
TOOL RESPONSE: TOOL RESPONSE:
--------------------- ---------------------
{} {tool_output}
USER'S INPUT USER'S INPUT
-------------------- --------------------
Okay, so what is the response to my last comment? If using information obtained from the tools you must \ Okay, so what is the response to my last comment? If using information obtained from the tools you must \
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
{} {optional_reminder}{hint}
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else. IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
""" """
@@ -78,22 +95,27 @@ def form_user_prompt_text(
) -> str: ) -> str:
user_prompt = tool_text or tool_less_prompt user_prompt = tool_text or tool_less_prompt
user_prompt += user_input_prompt.format(query) user_prompt += user_input_prompt.format(user_input=query)
if hint_text: if hint_text:
if user_prompt[-1] != "\n": if user_prompt[-1] != "\n":
user_prompt += "\n" user_prompt += "\n"
user_prompt += "Hint: " + hint_text user_prompt += "Hint: " + hint_text
return user_prompt return user_prompt.strip()
def form_tool_section_text( def form_tool_section_text(
tools: list[dict[str, str]], template: str = TOOL_TEMPLATE tools: list[dict[str, str]], retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None: ) -> str | None:
if not tools: if not tools and not retrieval_enabled:
return None return None
if retrieval_enabled:
tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
)
tools_intro = [] tools_intro = []
for tool in tools: for tool in tools:
description_formatted = tool["description"].replace("\n", " ") description_formatted = tool["description"].replace("\n", " ")
@@ -102,7 +124,9 @@ def form_tool_section_text(
tools_intro_text = "\n".join(tools_intro) tools_intro_text = "\n".join(tools_intro)
tool_names_text = ", ".join([tool["name"] for tool in tools]) tool_names_text = ", ".join([tool["name"] for tool in tools])
return template.format(tools_intro_text, tool_names_text) return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text
).strip()
def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str: def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
@@ -114,12 +138,53 @@ def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
def form_tool_followup_text( def form_tool_followup_text(
tool_output: str, tool_output: str,
query: str,
hint_text: str | None, hint_text: str | None,
tool_followup_prompt: str = TOOL_FOLLOWUP, tool_followup_prompt: str = TOOL_FOLLOWUP,
ignore_hint: bool = False, ignore_hint: bool = False,
) -> str: ) -> str:
if not ignore_hint and hint_text: # If multi-line query, it likely confuses the model more than helps
hint_text_spaced = f"\n{hint_text}\n" if "\n" not in query:
return tool_followup_prompt.format(tool_output, hint_text_spaced) optional_reminder = f"As a reminder, my query was: {query}\n"
else:
optional_reminder = ""
return tool_followup_prompt.format(tool_output, "") if not ignore_hint and hint_text:
hint_text_spaced = f"{hint_text}\n"
else:
hint_text_spaced = ""
return tool_followup_prompt.format(
tool_output=tool_output,
optional_reminder=optional_reminder,
hint=hint_text_spaced,
).strip()
def build_combined_query(
    query_message: ChatMessage,
    history: list[ChatMessage],
) -> list[BaseMessage]:
    """Build the prompt that asks the LLM to rewrite the final user message
    into a standalone question incorporating the prior conversation.

    The result is a langchain message list: the DANSWER_SYSTEM_MSG system
    prompt, the translated history, then a human message carrying the
    rewrite instruction plus the original query text.

    Raises:
        ValueError: if the final query message has no text.
    """
    user_query = query_message.message
    combined_query_msgs: list[BaseMessage] = []
    if not user_query:
        raise ValueError("Can't rephrase/search an empty query")
    combined_query_msgs.append(SystemMessage(content=DANSWER_SYSTEM_MSG))
    combined_query_msgs.extend(
        [translate_danswer_msg_to_langchain(msg) for msg in history]
    )
    combined_query_msgs.append(
        HumanMessage(
            content=(
                "Help me rewrite this final query into a standalone question that takes into consideration the "
                f"past messages of the conversation. You must ONLY return the rewritten query and nothing else."
                f"\n\nQuery:\n{query_message.message}"
            )
        )
    )
    return combined_query_msgs

View File

@@ -14,10 +14,13 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
all_personas = data.get("personas", []) all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session: with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
for persona in all_personas: for persona in all_personas:
tools = form_tool_section_text(persona["tools"]) tools = form_tool_section_text(
persona["tools"], persona["retrieval_enabled"]
)
create_persona( create_persona(
persona_id=persona["id"], persona_id=persona["id"],
name=persona["name"], name=persona["name"],
retrieval_enabled=persona["retrieval_enabled"],
system_text=persona["system"], system_text=persona["system"],
tools_text=tools, tools_text=tools,
hint_text=persona["hint"], hint_text=persona["hint"],

View File

@@ -4,11 +4,11 @@ personas:
system: | system: |
You are a question answering system that is constantly learning and improving. You are a question answering system that is constantly learning and improving.
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries. You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
You have access to a search tool and should always use it anytime it might be useful.
Your responses are as INFORMATIVE and DETAILED as possible. Your responses are as INFORMATIVE and DETAILED as possible.
tools: # Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
- name: "Current Search" retrieval_enabled: true
description: | # Each added tool needs to have a "name" and "description"
A search for up to date and proprietary information. tools: []
This tool can handle natural language questions so it is better to ask very precise questions over keywords. # Short tip to pass near the end of the prompt to emphasize some requirement
hint: "Use the Current Search tool to help find information on any topic!" # Such as "Remember to be informative!"
hint: ""

View File

@@ -1,43 +1,10 @@
from uuid import UUID from uuid import UUID
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.direct_qa.interfaces import DanswerChatModelOut from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.semantic_search import retrieve_ranked_documents
def call_tool( def call_tool(
model_actions: DanswerChatModelOut, model_actions: DanswerChatModelOut,
user_id: UUID | None, user_id: UUID | None,
) -> str: ) -> str:
if model_actions.action.lower() == "current search": raise NotImplementedError("There are no additional tool integrations right now")
query = model_actions.action_input
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query,
user_id=user_id,
filters=None,
datastore=get_default_document_index(),
)
if not ranked_chunks:
return "No results found"
if unranked_chunks:
ranked_chunks.extend(unranked_chunks)
filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
)
return format_danswer_chunks_for_chat(usable_chunks)
raise ValueError("Invalid tool choice by LLM")

View File

@@ -262,6 +262,7 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
def create_persona( def create_persona(
persona_id: int | None, persona_id: int | None,
name: str, name: str,
retrieval_enabled: bool,
system_text: str | None, system_text: str | None,
tools_text: str | None, tools_text: str | None,
hint_text: str | None, hint_text: str | None,
@@ -272,6 +273,7 @@ def create_persona(
if persona: if persona:
persona.name = name persona.name = name
persona.retrieval_enabled = retrieval_enabled
persona.system_text = system_text persona.system_text = system_text
persona.tools_text = tools_text persona.tools_text = tools_text
persona.hint_text = hint_text persona.hint_text = hint_text
@@ -280,6 +282,7 @@ def create_persona(
persona = Persona( persona = Persona(
id=persona_id, id=persona_id,
name=name, name=name,
retrieval_enabled=retrieval_enabled,
system_text=system_text, system_text=system_text,
tools_text=tools_text, tools_text=tools_text,
hint_text=hint_text, hint_text=hint_text,

View File

@@ -354,9 +354,13 @@ class Persona(Base):
__tablename__ = "persona" __tablename__ = "persona"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String) name: Mapped[str] = mapped_column(String)
# Danswer retrieval, treated as a special tool
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
system_text: Mapped[str | None] = mapped_column(Text, nullable=True) system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
tools_text: Mapped[str | None] = mapped_column(Text, nullable=True) tools_text: Mapped[str | None] = mapped_column(Text, nullable=True)
hint_text: Mapped[str | None] = mapped_column(Text, nullable=True) hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
# Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False) default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
# If it's updated and no longer latest (should no longer be shown), it is also considered deleted # If it's updated and no longer latest (should no longer be shown), it is also considered deleted
deleted: Mapped[bool] = mapped_column(Boolean, default=False) deleted: Mapped[bool] = mapped_column(Boolean, default=False)