Special Danswer flow for Chat (#459)

This commit is contained in:
Yuhong Sun 2023-09-18 21:10:20 -07:00 committed by GitHub
parent 3641102672
commit 32eee88628
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 208 additions and 67 deletions

View File

@ -0,0 +1,26 @@
"""Danswer Custom Tool Flow
Revision ID: dba7f71618f5
Revises: d5645c915d0e
Create Date: 2023-09-18 15:18:37.370972
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "dba7f71618f5"
down_revision = "d5645c915d0e"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add persona.retrieval_enabled, backfilling existing rows.

    A plain ``nullable=False`` add_column fails on databases where the
    persona table already contains rows; supplying a ``server_default``
    backfills them. ``sa.true()`` matches the ORM-level default
    (``retrieval_enabled`` defaults to True on the Persona model).
    """
    op.add_column(
        "persona",
        sa.Column(
            "retrieval_enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.true(),
        ),
    )
def downgrade() -> None:
    """Remove the retrieval_enabled column added by this revision."""
    op.drop_column("persona", "retrieval_enabled")

View File

@ -6,18 +6,30 @@ from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chat.chat_prompts import build_combined_query
from danswer.chat.chat_prompts import DANSWER_TOOL_NAME
from danswer.chat.chat_prompts import form_tool_followup_text
from danswer.chat.chat_prompts import form_user_prompt_text
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.chat.tools import call_tool
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.llm.build import get_default_llm
from danswer.llm.llm import LLM
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.text_processing import has_unescaped_quote
logger = setup_logger()
def _parse_embedded_json_streamed_response(
tokens: Iterator[str],
@ -55,6 +67,8 @@ def _parse_embedded_json_streamed_response(
yield DanswerAnswerPiece(answer_piece=hold)
hold = ""
logger.debug(model_output)
model_final = extract_embedded_json(model_output)
if "action" not in model_final or "action_input" not in model_final:
raise ValueError("Model did not provide all required action values")
@ -67,6 +81,44 @@ def _parse_embedded_json_streamed_response(
return
def danswer_chat_retrieval(
    query_message: ChatMessage,
    history: list[ChatMessage],
    llm: LLM,
    user_id: UUID | None,
) -> str:
    """Run Danswer's built-in retrieval "tool" for a chat turn.

    When prior history exists, the latest user message is first rewritten
    by the LLM into a standalone query; retrieved chunks are then filtered
    (IGNORE_FOR_QA) and truncated to the chat document-token budget before
    being formatted for prompt injection.
    """
    if not history:
        search_query = query_message.message
    else:
        rephrase_msgs = build_combined_query(query_message, history)
        search_query = llm.invoke(rephrase_msgs)

    # Good Debug/Breakpoint
    top_chunks, remaining_chunks = retrieve_ranked_documents(
        search_query,
        user_id=user_id,
        filters=None,
        datastore=get_default_document_index(),
    )
    if not top_chunks:
        return "No results found"

    if remaining_chunks:
        top_chunks.extend(remaining_chunks)

    qa_chunks = [
        chunk for chunk in top_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
    ]

    # Keep only as many chunks as fit within the chat token budget
    fitted_chunks = get_usable_chunks(
        chunks=qa_chunks,
        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
    )

    return format_danswer_chunks_for_chat(fitted_chunks)
def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]:
prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages]
@ -78,11 +130,16 @@ def llm_contextual_chat_answer(
persona: Persona,
user_id: UUID | None,
) -> Iterator[str]:
retrieval_enabled = persona.retrieval_enabled
system_text = persona.system_text
tool_text = persona.tools_text
hint_text = persona.hint_text
last_message = messages[-1]
if not last_message.message:
raise ValueError("User chat message is empty.")
previous_messages = messages[:-1]
user_text = form_user_prompt_text(
@ -102,7 +159,10 @@ def llm_contextual_chat_answer(
prompt.append(HumanMessage(content=user_text))
tokens = get_default_llm().stream(prompt)
llm = get_default_llm()
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
final_result: DanswerChatModelOut | None = None
final_answer_streamed = False
@ -121,16 +181,29 @@ def llm_contextual_chat_answer(
if final_result is None:
raise RuntimeError("Model output finished without final output parsing.")
tool_result_str = call_tool(final_result, user_id=user_id)
if retrieval_enabled and final_result.action.lower() == DANSWER_TOOL_NAME.lower():
tool_result_str = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
user_id=user_id,
)
else:
tool_result_str = call_tool(final_result, user_id=user_id)
prompt.append(AIMessage(content=final_result.model_raw))
prompt.append(
HumanMessage(
content=form_tool_followup_text(tool_result_str, hint_text=hint_text)
content=form_tool_followup_text(
tool_output=tool_result_str,
query=last_message.message,
hint_text=hint_text,
)
)
)
tokens = get_default_llm().stream(prompt)
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:

View File

@ -1,31 +1,48 @@
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.db.models import ChatMessage
from danswer.llm.utils import translate_danswer_msg_to_langchain
TOOL_TEMPLATE = """TOOLS
# Reserved name for Danswer's built-in retrieval tool; the model's chosen
# action is compared against this name (case-insensitively) to decide
# whether to run retrieval instead of a user-defined tool.
DANSWER_TOOL_NAME = "Current Search"
# Description of the built-in retrieval tool as shown to the model in the
# tools section of the prompt.
DANSWER_TOOL_DESCRIPTION = (
"A search tool that can find information on any topic "
"including up to date and proprietary knowledge."
)
# System message used by build_combined_query to have the LLM rewrite the
# last user message into a standalone question.
DANSWER_SYSTEM_MSG = (
"Given a conversation (between Human and Assistant) and a final message from Human, "
"rewrite the last message to be a standalone question that captures required/relevant context from the previous "
"conversation messages."
)
TOOL_TEMPLATE = """
TOOLS
------
Assistant can ask the user to use tools to look up information that may be helpful in answering the users \
original question. The tools the human can use are:
You can use tools to look up information that may be helpful in answering the user's \
original question. The available tools are:
{}
{tool_overviews}
RESPONSE FORMAT INSTRUCTIONS
----------------------------
When responding to me, please output a response in one of two formats:
**Option 1:**
Use this if you want the human to use a tool.
Markdown code snippet formatted in the following schema:
Use this if you want to use a tool. Markdown code snippet formatted in the following schema:
```json
{{
"action": string, \\ The action to take. Must be one of {}
"action": string, \\ The action to take. Must be one of {tool_names}
"action_input": string \\ The input to the action
}}
```
**Option #2:**
Use this if you want to respond directly to the human. Markdown code snippet formatted in the following schema:
Use this if you want to respond directly to the user. Markdown code snippet formatted in the following schema:
```json
{{
@ -52,19 +69,19 @@ USER'S INPUT
Here is the user's input \
(remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
{}
{user_input}
"""
TOOL_FOLLOWUP = """
TOOL RESPONSE:
---------------------
{}
{tool_output}
USER'S INPUT
--------------------
Okay, so what is the response to my last comment? If using information obtained from the tools you must \
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
{}
{optional_reminder}{hint}
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
"""
@ -78,22 +95,27 @@ def form_user_prompt_text(
) -> str:
user_prompt = tool_text or tool_less_prompt
user_prompt += user_input_prompt.format(query)
user_prompt += user_input_prompt.format(user_input=query)
if hint_text:
if user_prompt[-1] != "\n":
user_prompt += "\n"
user_prompt += "Hint: " + hint_text
return user_prompt
return user_prompt.strip()
def form_tool_section_text(
tools: list[dict[str, str]], template: str = TOOL_TEMPLATE
tools: list[dict[str, str]], retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None:
if not tools:
if not tools and not retrieval_enabled:
return None
if retrieval_enabled:
tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
)
tools_intro = []
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
@ -102,7 +124,9 @@ def form_tool_section_text(
tools_intro_text = "\n".join(tools_intro)
tool_names_text = ", ".join([tool["name"] for tool in tools])
return template.format(tools_intro_text, tool_names_text)
return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text
).strip()
def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
@ -114,12 +138,53 @@ def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
def form_tool_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_FOLLOWUP,
ignore_hint: bool = False,
) -> str:
if not ignore_hint and hint_text:
hint_text_spaced = f"\n{hint_text}\n"
return tool_followup_prompt.format(tool_output, hint_text_spaced)
# If multi-line query, it likely confuses the model more than helps
if "\n" not in query:
optional_reminder = f"As a reminder, my query was: {query}\n"
else:
optional_reminder = ""
return tool_followup_prompt.format(tool_output, "")
if not ignore_hint and hint_text:
hint_text_spaced = f"{hint_text}\n"
else:
hint_text_spaced = ""
return tool_followup_prompt.format(
tool_output=tool_output,
optional_reminder=optional_reminder,
hint=hint_text_spaced,
).strip()
def build_combined_query(
    query_message: ChatMessage,
    history: list[ChatMessage],
) -> list[BaseMessage]:
    """Build the LLM prompt that rewrites the final user message into a
    standalone question incorporating the prior conversation.

    Raises ValueError if the final message is empty.
    """
    if not query_message.message:
        raise ValueError("Can't rephrase/search an empty query")

    rephrase_msgs: list[BaseMessage] = [SystemMessage(content=DANSWER_SYSTEM_MSG)]
    rephrase_msgs.extend(
        translate_danswer_msg_to_langchain(msg) for msg in history
    )
    rephrase_msgs.append(
        HumanMessage(
            content=(
                "Help me rewrite this final query into a standalone question that takes into consideration the "
                f"past messages of the conversation. You must ONLY return the rewritten query and nothing else."
                f"\n\nQuery:\n{query_message.message}"
            )
        )
    )
    return rephrase_msgs

View File

@ -14,10 +14,13 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
for persona in all_personas:
tools = form_tool_section_text(persona["tools"])
tools = form_tool_section_text(
persona["tools"], persona["retrieval_enabled"]
)
create_persona(
persona_id=persona["id"],
name=persona["name"],
retrieval_enabled=persona["retrieval_enabled"],
system_text=persona["system"],
tools_text=tools,
hint_text=persona["hint"],

View File

@ -4,11 +4,11 @@ personas:
system: |
You are a question answering system that is constantly learning and improving.
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
You have access to a search tool and should always use it anytime it might be useful.
Your responses are as INFORMATIVE and DETAILED as possible.
tools:
- name: "Current Search"
description: |
A search for up to date and proprietary information.
This tool can handle natural language questions so it is better to ask very precise questions over keywords.
hint: "Use the Current Search tool to help find information on any topic!"
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
retrieval_enabled: true
# Each added tool needs to have a "name" and "description"
tools: []
# Short tip to pass near the end of the prompt to emphasize some requirement
# Such as "Remember to be informative!"
hint: ""

View File

@ -1,43 +1,10 @@
from uuid import UUID
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.semantic_search import retrieve_ranked_documents
def call_tool(
model_actions: DanswerChatModelOut,
user_id: UUID | None,
) -> str:
if model_actions.action.lower() == "current search":
query = model_actions.action_input
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query,
user_id=user_id,
filters=None,
datastore=get_default_document_index(),
)
if not ranked_chunks:
return "No results found"
if unranked_chunks:
ranked_chunks.extend(unranked_chunks)
filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
)
return format_danswer_chunks_for_chat(usable_chunks)
raise ValueError("Invalid tool choice by LLM")
raise NotImplementedError("There are no additional tool integrations right now")

View File

@ -262,6 +262,7 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
def create_persona(
persona_id: int | None,
name: str,
retrieval_enabled: bool,
system_text: str | None,
tools_text: str | None,
hint_text: str | None,
@ -272,6 +273,7 @@ def create_persona(
if persona:
persona.name = name
persona.retrieval_enabled = retrieval_enabled
persona.system_text = system_text
persona.tools_text = tools_text
persona.hint_text = hint_text
@ -280,6 +282,7 @@ def create_persona(
persona = Persona(
id=persona_id,
name=name,
retrieval_enabled=retrieval_enabled,
system_text=system_text,
tools_text=tools_text,
hint_text=hint_text,

View File

@ -354,9 +354,13 @@ class Persona(Base):
__tablename__ = "persona"
id: Mapped[int] = mapped_column(primary_key=True)
# Display name of the persona
name: Mapped[str] = mapped_column(String)
# Danswer retrieval, treated as a special tool
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
# System prompt text for this persona (optional)
system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
# Pre-rendered tools section of the prompt (optional)
tools_text: Mapped[str | None] = mapped_column(Text, nullable=True)
# Short hint appended near the end of the prompt (optional)
hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
# Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
# If it's updated and no longer latest (should no longer be shown), it is also considered deleted
deleted: Mapped[bool] = mapped_column(Boolean, default=False)