From 32eee886283ad09c50574428925476c1714466f3 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Mon, 18 Sep 2023 21:10:20 -0700 Subject: [PATCH] Special Danswer flow for Chat (#459) --- .../dba7f71618f5_danswer_custom_tool_flow.py | 26 +++++ backend/danswer/chat/chat_llm.py | 81 ++++++++++++- backend/danswer/chat/chat_prompts.py | 107 ++++++++++++++---- backend/danswer/chat/personas.py | 5 +- backend/danswer/chat/personas.yaml | 14 +-- backend/danswer/chat/tools.py | 35 +----- backend/danswer/db/chat.py | 3 + backend/danswer/db/models.py | 4 + 8 files changed, 208 insertions(+), 67 deletions(-) create mode 100644 backend/alembic/versions/dba7f71618f5_danswer_custom_tool_flow.py diff --git a/backend/alembic/versions/dba7f71618f5_danswer_custom_tool_flow.py b/backend/alembic/versions/dba7f71618f5_danswer_custom_tool_flow.py new file mode 100644 index 000000000..100f37148 --- /dev/null +++ b/backend/alembic/versions/dba7f71618f5_danswer_custom_tool_flow.py @@ -0,0 +1,26 @@ +"""Danswer Custom Tool Flow + +Revision ID: dba7f71618f5 +Revises: d5645c915d0e +Create Date: 2023-09-18 15:18:37.370972 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "dba7f71618f5" +down_revision = "d5645c915d0e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "persona", sa.Column("retrieval_enabled", sa.Boolean(), nullable=False) + ) + + +def downgrade() -> None: + op.drop_column("persona", "retrieval_enabled") diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py index 246cd3255..2839d50f3 100644 --- a/backend/danswer/chat/chat_llm.py +++ b/backend/danswer/chat/chat_llm.py @@ -6,18 +6,30 @@ from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage +from danswer.chat.chat_prompts import build_combined_query +from danswer.chat.chat_prompts import DANSWER_TOOL_NAME from danswer.chat.chat_prompts import form_tool_followup_text from danswer.chat.chat_prompts import form_user_prompt_text +from danswer.chat.chat_prompts import format_danswer_chunks_for_chat from danswer.chat.tools import call_tool +from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT +from danswer.configs.constants import IGNORE_FOR_QA +from danswer.datastores.document_index import get_default_document_index from danswer.db.models import ChatMessage from danswer.db.models import Persona from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import DanswerChatModelOut +from danswer.direct_qa.qa_utils import get_usable_chunks from danswer.llm.build import get_default_llm +from danswer.llm.llm import LLM from danswer.llm.utils import translate_danswer_msg_to_langchain +from danswer.search.semantic_search import retrieve_ranked_documents +from danswer.utils.logger import setup_logger from danswer.utils.text_processing import extract_embedded_json from danswer.utils.text_processing import has_unescaped_quote +logger = setup_logger() + def _parse_embedded_json_streamed_response( tokens: Iterator[str], @@ -55,6 +67,8 @@ def _parse_embedded_json_streamed_response( yield DanswerAnswerPiece(answer_piece=hold) hold = "" + logger.debug(model_output) + model_final = extract_embedded_json(model_output) if "action" not in model_final or "action_input" not in model_final: raise ValueError("Model did not provide all required action values") @@ -67,6 +81,44 @@ def _parse_embedded_json_streamed_response( return +def danswer_chat_retrieval( + query_message: ChatMessage, + history: list[ChatMessage], + llm: LLM, + user_id: UUID | None, +) -> str: + if history: + query_combination_msgs = build_combined_query(query_message, history) + reworded_query = llm.invoke(query_combination_msgs) + else: + reworded_query = query_message.message + + # Good Debug/Breakpoint + ranked_chunks, unranked_chunks = retrieve_ranked_documents( + reworded_query, + user_id=user_id, + filters=None, + datastore=get_default_document_index(), + ) + if not ranked_chunks: + return "No results found" + + if unranked_chunks: + ranked_chunks.extend(unranked_chunks) + + filtered_ranked_chunks = [ + chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA) + ] + + # get all chunks that fit into the token limit + usable_chunks = get_usable_chunks( + chunks=filtered_ranked_chunks, + token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT, + ) + + return format_danswer_chunks_for_chat(usable_chunks) + + def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]: prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages] @@ -78,11 +130,16 @@ def llm_contextual_chat_answer( persona: Persona, user_id: UUID | None, ) -> Iterator[str]: + retrieval_enabled = persona.retrieval_enabled system_text = persona.system_text tool_text = persona.tools_text hint_text = persona.hint_text last_message = messages[-1] + + if not last_message.message: + raise ValueError("User chat message is empty.") + previous_messages = messages[:-1] user_text = form_user_prompt_text( @@ -102,7 +159,10 @@ def llm_contextual_chat_answer( prompt.append(HumanMessage(content=user_text)) - tokens = get_default_llm().stream(prompt) + llm = get_default_llm() + + # Good Debug/Breakpoint + tokens = llm.stream(prompt) final_result: DanswerChatModelOut | None = None final_answer_streamed = False @@ -121,16 +181,29 @@ def llm_contextual_chat_answer( if final_result is None: raise RuntimeError("Model output finished without final output parsing.") - tool_result_str = call_tool(final_result, user_id=user_id) + if retrieval_enabled and final_result.action.lower() == DANSWER_TOOL_NAME.lower(): + tool_result_str = danswer_chat_retrieval( + query_message=last_message, + history=previous_messages, + llm=llm, + user_id=user_id, + ) + else: + tool_result_str = call_tool(final_result, user_id=user_id) prompt.append(AIMessage(content=final_result.model_raw)) prompt.append( HumanMessage( - content=form_tool_followup_text(tool_result_str, hint_text=hint_text) + content=form_tool_followup_text( + tool_output=tool_result_str, + query=last_message.message, + hint_text=hint_text, + ) ) ) - tokens = get_default_llm().stream(prompt) + # Good Debug/Breakpoint + tokens = llm.stream(prompt) for result in _parse_embedded_json_streamed_response(tokens): if isinstance(result, DanswerAnswerPiece) and result.answer_piece: diff --git a/backend/danswer/chat/chat_prompts.py b/backend/danswer/chat/chat_prompts.py index d4478f119..27c9d31d3 100644 --- a/backend/danswer/chat/chat_prompts.py +++ b/backend/danswer/chat/chat_prompts.py @@ -1,31 +1,48 @@ +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage + from danswer.chunking.models import InferenceChunk from danswer.configs.constants import CODE_BLOCK_PAT +from danswer.db.models import ChatMessage +from danswer.llm.utils import translate_danswer_msg_to_langchain -TOOL_TEMPLATE = """TOOLS +DANSWER_TOOL_NAME = "Current Search" +DANSWER_TOOL_DESCRIPTION = ( + "A search tool that can find information on any topic " + "including up to date and proprietary knowledge." +) + +DANSWER_SYSTEM_MSG = ( + "Given a conversation (between Human and Assistant) and a final message from Human, " + "rewrite the last message to be a standalone question that captures required/relevant context from the previous " + "conversation messages." +) + +TOOL_TEMPLATE = """ +TOOLS ------ -Assistant can ask the user to use tools to look up information that may be helpful in answering the users \ -original question. The tools the human can use are: +You can use tools to look up information that may be helpful in answering the user's \ +original question. The available tools are: -{} +{tool_overviews} RESPONSE FORMAT INSTRUCTIONS ---------------------------- - When responding to me, please output a response in one of two formats: **Option 1:** -Use this if you want the human to use a tool. -Markdown code snippet formatted in the following schema: +Use this if you want to use a tool. Markdown code snippet formatted in the following schema: ```json {{ - "action": string, \\ The action to take. Must be one of {} + "action": string, \\ The action to take. Must be one of {tool_names} "action_input": string \\ The input to the action }} ``` **Option #2:** -Use this if you want to respond directly to the human. Markdown code snippet formatted in the following schema: +Use this if you want to respond directly to the user. Markdown code snippet formatted in the following schema: ```json {{ @@ -52,19 +69,19 @@ USER'S INPUT Here is the user's input \ (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else): -{} +{user_input} """ TOOL_FOLLOWUP = """ TOOL RESPONSE: --------------------- -{} +{tool_output} USER'S INPUT -------------------- Okay, so what is the response to my last comment? If using information obtained from the tools you must \ mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! -{} +{optional_reminder}{hint} IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else. """ @@ -78,22 +95,27 @@ def form_user_prompt_text( ) -> str: user_prompt = tool_text or tool_less_prompt - user_prompt += user_input_prompt.format(query) + user_prompt += user_input_prompt.format(user_input=query) if hint_text: if user_prompt[-1] != "\n": user_prompt += "\n" user_prompt += "Hint: " + hint_text - return user_prompt + return user_prompt.strip() def form_tool_section_text( - tools: list[dict[str, str]], template: str = TOOL_TEMPLATE + tools: list[dict[str, str]], retrieval_enabled: bool, template: str = TOOL_TEMPLATE ) -> str | None: - if not tools: + if not tools and not retrieval_enabled: return None + if retrieval_enabled: + tools.append( + {"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION} + ) + tools_intro = [] for tool in tools: description_formatted = tool["description"].replace("\n", " ") @@ -102,7 +124,9 @@ def form_tool_section_text( tools_intro_text = "\n".join(tools_intro) tool_names_text = ", ".join([tool["name"] for tool in tools]) - return template.format(tools_intro_text, tool_names_text) + return template.format( + tool_overviews=tools_intro_text, tool_names=tool_names_text + ).strip() def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str: @@ -114,12 +138,53 @@ def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str: def form_tool_followup_text( tool_output: str, + query: str, hint_text: str | None, tool_followup_prompt: str = TOOL_FOLLOWUP, ignore_hint: bool = False, ) -> str: - if not ignore_hint and hint_text: - hint_text_spaced = f"\n{hint_text}\n" - return tool_followup_prompt.format(tool_output, hint_text_spaced) + # If multi-line query, it likely confuses the model more than helps + if "\n" not in query: + optional_reminder = f"As a reminder, my query was: {query}\n" + else: + optional_reminder = "" - return tool_followup_prompt.format(tool_output, "") + if not ignore_hint and hint_text: + hint_text_spaced = f"{hint_text}\n" + else: + hint_text_spaced = "" + + return tool_followup_prompt.format( + tool_output=tool_output, + optional_reminder=optional_reminder, + hint=hint_text_spaced, + ).strip() + + +def build_combined_query( + query_message: ChatMessage, + history: list[ChatMessage], +) -> list[BaseMessage]: + user_query = query_message.message + combined_query_msgs: list[BaseMessage] = [] + + if not user_query: + raise ValueError("Can't rephrase/search an empty query") + + combined_query_msgs.append(SystemMessage(content=DANSWER_SYSTEM_MSG)) + + combined_query_msgs.extend( + [translate_danswer_msg_to_langchain(msg) for msg in history] + ) + + combined_query_msgs.append( + HumanMessage( + content=( + "Help me rewrite this final query into a standalone question that takes into consideration the " + f"past messages of the conversation. You must ONLY return the rewritten query and nothing else." + f"\n\nQuery:\n{query_message.message}" + ) + ) + ) + + return combined_query_msgs diff --git a/backend/danswer/chat/personas.py b/backend/danswer/chat/personas.py index a823c1917..9639c0c63 100644 --- a/backend/danswer/chat/personas.py +++ b/backend/danswer/chat/personas.py @@ -14,10 +14,13 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None: all_personas = data.get("personas", []) with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session: for persona in all_personas: - tools = form_tool_section_text(persona["tools"]) + tools = form_tool_section_text( + persona["tools"], persona["retrieval_enabled"] + ) create_persona( persona_id=persona["id"], name=persona["name"], + retrieval_enabled=persona["retrieval_enabled"], system_text=persona["system"], tools_text=tools, hint_text=persona["hint"], diff --git a/backend/danswer/chat/personas.yaml b/backend/danswer/chat/personas.yaml index fe12d8a0d..f6c8cb638 100644 --- a/backend/danswer/chat/personas.yaml +++ b/backend/danswer/chat/personas.yaml @@ -4,11 +4,11 @@ personas: system: | You are a question answering system that is constantly learning and improving. You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries. - You have access to a search tool and should always use it anytime it might be useful. Your responses are as INFORMATIVE and DETAILED as possible. - tools: - - name: "Current Search" - description: | - A search for up to date and proprietary information. - This tool can handle natural language questions so it is better to ask very precise questions over keywords. - hint: "Use the Current Search tool to help find information on any topic!" + # Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled. + retrieval_enabled: true + # Each added tool needs to have a "name" and "description" + tools: [] + # Short tip to pass near the end of the prompt to emphasize some requirement + # Such as "Remember to be informative!" + hint: "" diff --git a/backend/danswer/chat/tools.py b/backend/danswer/chat/tools.py index 94829c076..0928b5a50 100644 --- a/backend/danswer/chat/tools.py +++ b/backend/danswer/chat/tools.py @@ -1,43 +1,10 @@ from uuid import UUID -from danswer.chat.chat_prompts import format_danswer_chunks_for_chat -from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT -from danswer.configs.constants import IGNORE_FOR_QA -from danswer.datastores.document_index import get_default_document_index from danswer.direct_qa.interfaces import DanswerChatModelOut -from danswer.direct_qa.qa_utils import get_usable_chunks -from danswer.search.semantic_search import retrieve_ranked_documents def call_tool( model_actions: DanswerChatModelOut, user_id: UUID | None, ) -> str: - if model_actions.action.lower() == "current search": - query = model_actions.action_input - - ranked_chunks, unranked_chunks = retrieve_ranked_documents( - query, - user_id=user_id, - filters=None, - datastore=get_default_document_index(), - ) - if not ranked_chunks: - return "No results found" - - if unranked_chunks: - ranked_chunks.extend(unranked_chunks) - - filtered_ranked_chunks = [ - chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA) - ] - - # get all chunks that fit into the token limit - usable_chunks = get_usable_chunks( - chunks=filtered_ranked_chunks, - token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT, - ) - - return format_danswer_chunks_for_chat(usable_chunks) - - raise ValueError("Invalid tool choice by LLM") + raise NotImplementedError("There are no additional tool integrations right now") diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 73ec16991..ceb8ca4c2 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -262,6 +262,7 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona: def create_persona( persona_id: int | None, name: str, + retrieval_enabled: bool, system_text: str | None, tools_text: str | None, hint_text: str | None, @@ -272,6 +273,7 @@ def create_persona( if persona: persona.name = name + persona.retrieval_enabled = retrieval_enabled persona.system_text = system_text persona.tools_text = tools_text persona.hint_text = hint_text @@ -280,6 +282,7 @@ def create_persona( persona = Persona( id=persona_id, name=name, + retrieval_enabled=retrieval_enabled, system_text=system_text, tools_text=tools_text, hint_text=hint_text, diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 7712c250d..3c7782092 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -354,9 +354,13 @@ class Persona(Base): __tablename__ = "persona" id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String) + # Danswer retrieval, treated as a special tool + retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True) system_text: Mapped[str | None] = mapped_column(Text, nullable=True) tools_text: Mapped[str | None] = mapped_column(Text, nullable=True) hint_text: Mapped[str | None] = mapped_column(Text, nullable=True) + # Default personas are configured via backend during deployment + # Treated specially (cannot be user edited etc.) default_persona: Mapped[bool] = mapped_column(Boolean, default=False) # If it's updated and no longer latest (should no longer be shown), it is also considered deleted deleted: Mapped[bool] = mapped_column(Boolean, default=False)