Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-08 11:58:34 +02:00)

Special Danswer flow for Chat (#459)

This commit is contained in:
parent 3641102672
commit 32eee88628
@@ -0,0 +1,26 @@
"""Danswer Custom Tool Flow

Revision ID: dba7f71618f5
Revises: d5645c915d0e
Create Date: 2023-09-18 15:18:37.370972

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "dba7f71618f5"
down_revision = "d5645c915d0e"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "persona", sa.Column("retrieval_enabled", sa.Boolean(), nullable=False)
    )


def downgrade() -> None:
    op.drop_column("persona", "retrieval_enabled")
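Note: because retrieval_enabled is added as NOT NULL with no server_default, this upgrade can fail if the persona table already contains rows. A minimal variant of upgrade() for that case, assuming a default of "enabled" is acceptable (the server_default below is an editorial suggestion, not part of this commit):

def upgrade() -> None:
    op.add_column(
        "persona",
        sa.Column(
            "retrieval_enabled",
            sa.Boolean(),
            nullable=False,
            # assumed default so existing persona rows satisfy NOT NULL; not in this commit
            server_default=sa.true(),
        ),
    )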
@@ -6,18 +6,30 @@ from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.chat.chat_prompts import build_combined_query
from danswer.chat.chat_prompts import DANSWER_TOOL_NAME
from danswer.chat.chat_prompts import form_tool_followup_text
from danswer.chat.chat_prompts import form_user_prompt_text
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.chat.tools import call_tool
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.llm.build import get_default_llm
from danswer.llm.llm import LLM
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.text_processing import has_unescaped_quote

logger = setup_logger()


def _parse_embedded_json_streamed_response(
    tokens: Iterator[str],
@@ -55,6 +67,8 @@ def _parse_embedded_json_streamed_response(
                yield DanswerAnswerPiece(answer_piece=hold)
                hold = ""

    logger.debug(model_output)

    model_final = extract_embedded_json(model_output)
    if "action" not in model_final or "action_input" not in model_final:
        raise ValueError("Model did not provide all required action values")
@@ -67,6 +81,44 @@ def _parse_embedded_json_streamed_response(
    return


def danswer_chat_retrieval(
    query_message: ChatMessage,
    history: list[ChatMessage],
    llm: LLM,
    user_id: UUID | None,
) -> str:
    if history:
        query_combination_msgs = build_combined_query(query_message, history)
        reworded_query = llm.invoke(query_combination_msgs)
    else:
        reworded_query = query_message.message

    # Good Debug/Breakpoint
    ranked_chunks, unranked_chunks = retrieve_ranked_documents(
        reworded_query,
        user_id=user_id,
        filters=None,
        datastore=get_default_document_index(),
    )
    if not ranked_chunks:
        return "No results found"

    if unranked_chunks:
        ranked_chunks.extend(unranked_chunks)

    filtered_ranked_chunks = [
        chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
    ]

    # get all chunks that fit into the token limit
    usable_chunks = get_usable_chunks(
        chunks=filtered_ranked_chunks,
        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
    )

    return format_danswer_chunks_for_chat(usable_chunks)


def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]:
    prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages]
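The branch at the top of danswer_chat_retrieval decides whether to search with the raw user message or with an LLM-reworded standalone question. A toy sketch of that decision (plain Python, not repo code; the rewrite string is only a stand-in for llm.invoke(build_combined_query(...))):

def reword_for_retrieval(message: str, history: list[str]) -> str:
    # no history: search on the message as-is, like the else branch above
    if not history:
        return message
    # with history: the real flow asks the LLM for a standalone rewrite via
    # build_combined_query + llm.invoke; this string is just a stand-in
    return f"standalone rewrite of {message!r} given {len(history)} prior messages"


print(reword_for_retrieval("what about the persona table?", []))
print(reword_for_retrieval("what about the persona table?", ["Tell me about Danswer chat."]))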
@@ -78,11 +130,16 @@ def llm_contextual_chat_answer(
    persona: Persona,
    user_id: UUID | None,
) -> Iterator[str]:
    retrieval_enabled = persona.retrieval_enabled
    system_text = persona.system_text
    tool_text = persona.tools_text
    hint_text = persona.hint_text

    last_message = messages[-1]

    if not last_message.message:
        raise ValueError("User chat message is empty.")

    previous_messages = messages[:-1]

    user_text = form_user_prompt_text(
@@ -102,7 +159,10 @@ def llm_contextual_chat_answer(

    prompt.append(HumanMessage(content=user_text))

    tokens = get_default_llm().stream(prompt)
    llm = get_default_llm()

    # Good Debug/Breakpoint
    tokens = llm.stream(prompt)

    final_result: DanswerChatModelOut | None = None
    final_answer_streamed = False
@@ -121,16 +181,29 @@ def llm_contextual_chat_answer(
    if final_result is None:
        raise RuntimeError("Model output finished without final output parsing.")

    tool_result_str = call_tool(final_result, user_id=user_id)
    if retrieval_enabled and final_result.action.lower() == DANSWER_TOOL_NAME.lower():
        tool_result_str = danswer_chat_retrieval(
            query_message=last_message,
            history=previous_messages,
            llm=llm,
            user_id=user_id,
        )
    else:
        tool_result_str = call_tool(final_result, user_id=user_id)

    prompt.append(AIMessage(content=final_result.model_raw))
    prompt.append(
        HumanMessage(
            content=form_tool_followup_text(tool_result_str, hint_text=hint_text)
            content=form_tool_followup_text(
                tool_output=tool_result_str,
                query=last_message.message,
                hint_text=hint_text,
            )
        )
    )

    tokens = get_default_llm().stream(prompt)
    # Good Debug/Breakpoint
    tokens = llm.stream(prompt)

    for result in _parse_embedded_json_streamed_response(tokens):
        if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
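The dispatch above keys off the parsed model output: the model is instructed (see the RESPONSE FORMAT INSTRUCTIONS template below) to reply with a markdown-fenced JSON blob containing "action" and "action_input". A self-contained sketch of that contract; json.loads stands in for danswer.utils.text_processing.extract_embedded_json, and the query text is invented:

import json

model_raw = """```json
{
    "action": "Current Search",
    "action_input": "persona retrieval_enabled column"
}
```"""

# pull the blob out of the fenced snippet; the real code uses extract_embedded_json
blob = json.loads(model_raw.split("```json")[1].split("```")[0])

# with retrieval enabled, the "Current Search" action is routed to danswer_chat_retrieval
# (which searches on the reworded user query); any other action goes to call_tool with
# blob["action_input"] as its input
print(blob["action"], "->", blob["action_input"])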
@@ -1,31 +1,48 @@
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.db.models import ChatMessage
from danswer.llm.utils import translate_danswer_msg_to_langchain

TOOL_TEMPLATE = """TOOLS
DANSWER_TOOL_NAME = "Current Search"
DANSWER_TOOL_DESCRIPTION = (
    "A search tool that can find information on any topic "
    "including up to date and proprietary knowledge."
)

DANSWER_SYSTEM_MSG = (
    "Given a conversation (between Human and Assistant) and a final message from Human, "
    "rewrite the last message to be a standalone question that captures required/relevant context from the previous "
    "conversation messages."
)

TOOL_TEMPLATE = """
TOOLS
------
Assistant can ask the user to use tools to look up information that may be helpful in answering the users \
original question. The tools the human can use are:
You can use tools to look up information that may be helpful in answering the user's \
original question. The available tools are:

{}
{tool_overviews}

RESPONSE FORMAT INSTRUCTIONS
----------------------------

When responding to me, please output a response in one of two formats:

**Option 1:**
Use this if you want the human to use a tool.
Markdown code snippet formatted in the following schema:
Use this if you want to use a tool. Markdown code snippet formatted in the following schema:

```json
{{
    "action": string, \\ The action to take. Must be one of {}
    "action": string, \\ The action to take. Must be one of {tool_names}
    "action_input": string \\ The input to the action
}}
```

**Option #2:**
Use this if you want to respond directly to the human. Markdown code snippet formatted in the following schema:
Use this if you want to respond directly to the user. Markdown code snippet formatted in the following schema:

```json
{{
@@ -52,19 +69,19 @@ USER'S INPUT
Here is the user's input \
(remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):

{}
{user_input}
"""

TOOL_FOLLOWUP = """
TOOL RESPONSE:
---------------------
{}
{tool_output}

USER'S INPUT
--------------------
Okay, so what is the response to my last comment? If using information obtained from the tools you must \
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
{}
{optional_reminder}{hint}
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
"""
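The templates now use named str.format placeholders ({tool_overviews}, {tool_names}, {user_input}, {tool_output}, {optional_reminder}, {hint}) instead of positional {}, and the literal JSON braces are doubled so format() leaves them intact. A toy illustration of both points (not code from this commit; the tool text is invented):

template = (
    "The available tools are:\n"
    "{tool_overviews}\n"
    "\n"
    # "{{" and "}}" render as literal braces, so only the named fields are substituted
    '{{"action": "one of {tool_names}", "action_input": "..."}}\n'
)

print(
    template.format(
        tool_overviews="Current Search: finds up to date and proprietary information",
        tool_names="Current Search",
    )
)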
@@ -78,22 +95,27 @@ def form_user_prompt_text(
) -> str:
    user_prompt = tool_text or tool_less_prompt

    user_prompt += user_input_prompt.format(query)
    user_prompt += user_input_prompt.format(user_input=query)

    if hint_text:
        if user_prompt[-1] != "\n":
            user_prompt += "\n"
        user_prompt += "Hint: " + hint_text

    return user_prompt
    return user_prompt.strip()


def form_tool_section_text(
    tools: list[dict[str, str]], template: str = TOOL_TEMPLATE
    tools: list[dict[str, str]], retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None:
    if not tools:
    if not tools and not retrieval_enabled:
        return None

    if retrieval_enabled:
        tools.append(
            {"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
        )

    tools_intro = []
    for tool in tools:
        description_formatted = tool["description"].replace("\n", " ")
@@ -102,7 +124,9 @@ def form_tool_section_text(
    tools_intro_text = "\n".join(tools_intro)
    tool_names_text = ", ".join([tool["name"] for tool in tools])

    return template.format(tools_intro_text, tool_names_text)
    return template.format(
        tool_overviews=tools_intro_text, tool_names=tool_names_text
    ).strip()


def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
@@ -114,12 +138,53 @@ def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:

def form_tool_followup_text(
    tool_output: str,
    query: str,
    hint_text: str | None,
    tool_followup_prompt: str = TOOL_FOLLOWUP,
    ignore_hint: bool = False,
) -> str:
    if not ignore_hint and hint_text:
        hint_text_spaced = f"\n{hint_text}\n"
        return tool_followup_prompt.format(tool_output, hint_text_spaced)
    # If multi-line query, it likely confuses the model more than helps
    if "\n" not in query:
        optional_reminder = f"As a reminder, my query was: {query}\n"
    else:
        optional_reminder = ""

    return tool_followup_prompt.format(tool_output, "")
    if not ignore_hint and hint_text:
        hint_text_spaced = f"{hint_text}\n"
    else:
        hint_text_spaced = ""

    return tool_followup_prompt.format(
        tool_output=tool_output,
        optional_reminder=optional_reminder,
        hint=hint_text_spaced,
    ).strip()


def build_combined_query(
    query_message: ChatMessage,
    history: list[ChatMessage],
) -> list[BaseMessage]:
    user_query = query_message.message
    combined_query_msgs: list[BaseMessage] = []

    if not user_query:
        raise ValueError("Can't rephrase/search an empty query")

    combined_query_msgs.append(SystemMessage(content=DANSWER_SYSTEM_MSG))

    combined_query_msgs.extend(
        [translate_danswer_msg_to_langchain(msg) for msg in history]
    )

    combined_query_msgs.append(
        HumanMessage(
            content=(
                "Help me rewrite this final query into a standalone question that takes into consideration the "
                f"past messages of the conversation. You must ONLY return the rewritten query and nothing else."
                f"\n\nQuery:\n{query_message.message}"
            )
        )
    )

    return combined_query_msgs
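A minimal sketch of how the new form_tool_followup_text signature is meant to be called, assuming the danswer package from this commit is importable; the tool output and hint strings are invented. Single-line queries get the "As a reminder..." echo, multi-line queries do not:

from danswer.chat.chat_prompts import form_tool_followup_text

followup = form_tool_followup_text(
    tool_output="Doc chunk: the persona table now has a retrieval_enabled column.",
    query="What changed in the persona table?",  # single-line query, so the reminder is included
    hint_text="Remember to be informative!",
)
# the result contains the TOOL RESPONSE block, the query reminder, the hint,
# and the "markdown code snippet of a json blob" instruction, outer whitespace stripped
print(followup)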
@@ -14,10 +14,13 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
    all_personas = data.get("personas", [])
    with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
        for persona in all_personas:
            tools = form_tool_section_text(persona["tools"])
            tools = form_tool_section_text(
                persona["tools"], persona["retrieval_enabled"]
            )
            create_persona(
                persona_id=persona["id"],
                name=persona["name"],
                retrieval_enabled=persona["retrieval_enabled"],
                system_text=persona["system"],
                tools_text=tools,
                hint_text=persona["hint"],
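A sketch of the YAML shape this loader now expects, built around the keys it reads (id, name, system, retrieval_enabled, tools, hint); the values below are placeholders, not the repo's personas.yaml:

import yaml

PERSONAS_SNIPPET = """
personas:
  - id: 1
    name: "Danswer"
    system: "You are a question answering system that is constantly learning and improving."
    retrieval_enabled: true
    tools: []
    hint: ""
"""

for persona in yaml.safe_load(PERSONAS_SNIPPET).get("personas", []):
    # mirrors the form_tool_section_text(persona["tools"], persona["retrieval_enabled"]) call above
    print(persona["id"], persona["name"], persona["retrieval_enabled"], persona["tools"])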
@@ -4,11 +4,11 @@ personas:
    system: |
      You are a question answering system that is constantly learning and improving.
      You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
      You have access to a search tool and should always use it anytime it might be useful.
      Your responses are as INFORMATIVE and DETAILED as possible.
    tools:
      - name: "Current Search"
        description: |
          A search for up to date and proprietary information.
          This tool can handle natural language questions so it is better to ask very precise questions over keywords.
    hint: "Use the Current Search tool to help find information on any topic!"
    # Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
    retrieval_enabled: true
    # Each added tool needs to have a "name" and "description"
    tools: []
    # Short tip to pass near the end of the prompt to emphasize some requirement
    # Such as "Remember to be informative!"
    hint: ""
@@ -1,43 +1,10 @@
from uuid import UUID

from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.semantic_search import retrieve_ranked_documents


def call_tool(
    model_actions: DanswerChatModelOut,
    user_id: UUID | None,
) -> str:
    if model_actions.action.lower() == "current search":
        query = model_actions.action_input

        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
            query,
            user_id=user_id,
            filters=None,
            datastore=get_default_document_index(),
        )
        if not ranked_chunks:
            return "No results found"

        if unranked_chunks:
            ranked_chunks.extend(unranked_chunks)

        filtered_ranked_chunks = [
            chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
        ]

        # get all chunks that fit into the token limit
        usable_chunks = get_usable_chunks(
            chunks=filtered_ranked_chunks,
            token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
        )

        return format_danswer_chunks_for_chat(usable_chunks)

    raise ValueError("Invalid tool choice by LLM")
    raise NotImplementedError("There are no additional tool integrations right now")
@@ -262,6 +262,7 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
def create_persona(
    persona_id: int | None,
    name: str,
    retrieval_enabled: bool,
    system_text: str | None,
    tools_text: str | None,
    hint_text: str | None,
@@ -272,6 +273,7 @@ def create_persona(

    if persona:
        persona.name = name
        persona.retrieval_enabled = retrieval_enabled
        persona.system_text = system_text
        persona.tools_text = tools_text
        persona.hint_text = hint_text
@@ -280,6 +282,7 @@ def create_persona(
        persona = Persona(
            id=persona_id,
            name=name,
            retrieval_enabled=retrieval_enabled,
            system_text=system_text,
            tools_text=tools_text,
            hint_text=hint_text,
@@ -354,9 +354,13 @@ class Persona(Base):
    __tablename__ = "persona"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String)
    # Danswer retrieval, treated as a special tool
    retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
    tools_text: Mapped[str | None] = mapped_column(Text, nullable=True)
    hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Default personas are configured via backend during deployment
    # Treated specially (cannot be user edited etc.)
    default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
    # If it's updated and no longer latest (should no longer be shown), it is also considered deleted
    deleted: Mapped[bool] = mapped_column(Boolean, default=False)