mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Chat without tools should use a less complex prompt (#524)
This commit is contained in:
32
backend/alembic/versions/904451035c9b_store_tool_details.py
Normal file
32
backend/alembic/versions/904451035c9b_store_tool_details.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Store Tool Details
|
||||
|
||||
Revision ID: 904451035c9b
|
||||
Revises: e0a68a81d434
|
||||
Create Date: 2023-10-05 12:29:26.620000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "904451035c9b"
|
||||
down_revision = "e0a68a81d434"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("tools", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
)
|
||||
op.drop_column("persona", "tools_text")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("tools_text", sa.TEXT(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.drop_column("persona", "tools")
|
@@ -9,9 +9,14 @@ from langchain.schema.messages import SystemMessage
|
||||
|
||||
from danswer.chat.chat_prompts import build_combined_query
|
||||
from danswer.chat.chat_prompts import DANSWER_TOOL_NAME
|
||||
from danswer.chat.chat_prompts import form_require_search_text
|
||||
from danswer.chat.chat_prompts import form_tool_followup_text
|
||||
from danswer.chat.chat_prompts import form_tool_less_followup_text
|
||||
from danswer.chat.chat_prompts import form_tool_section_text
|
||||
from danswer.chat.chat_prompts import form_user_prompt_text
|
||||
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
|
||||
from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
|
||||
from danswer.chat.chat_prompts import YES_SEARCH
|
||||
from danswer.chat.tools import call_tool
|
||||
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
@@ -175,8 +180,8 @@ def _drop_messages_history_overflow(
|
||||
|
||||
def llm_contextless_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
tokenizer: Callable | None = None,
|
||||
system_text: str | None = None,
|
||||
tokenizer: Callable | None = None,
|
||||
) -> Iterator[str]:
|
||||
try:
|
||||
prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
|
||||
@@ -213,11 +218,92 @@ def llm_contextual_chat_answer(
|
||||
persona: Persona,
|
||||
user_id: UUID | None,
|
||||
tokenizer: Callable,
|
||||
run_search_system_text: str = REQUIRE_DANSWER_SYSTEM_MSG,
|
||||
) -> Iterator[str]:
|
||||
last_message = messages[-1]
|
||||
final_query_text = last_message.message
|
||||
previous_messages = messages[:-1]
|
||||
previous_msgs_as_basemessage = [
|
||||
translate_danswer_msg_to_langchain(msg) for msg in previous_messages
|
||||
]
|
||||
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
|
||||
if not final_query_text:
|
||||
raise ValueError("User chat message is empty.")
|
||||
|
||||
# Determine if a search is necessary to answer the user query
|
||||
user_req_search_text = form_require_search_text(last_message)
|
||||
last_user_msg = HumanMessage(content=user_req_search_text)
|
||||
|
||||
previous_msg_token_counts = [msg.token_count for msg in previous_messages]
|
||||
danswer_system_tokens = len(tokenizer(run_search_system_text))
|
||||
last_user_msg_tokens = len(tokenizer(user_req_search_text))
|
||||
|
||||
need_search_prompt = _drop_messages_history_overflow(
|
||||
system_msg=SystemMessage(content=run_search_system_text),
|
||||
system_token_count=danswer_system_tokens,
|
||||
history_msgs=previous_msgs_as_basemessage,
|
||||
history_token_counts=previous_msg_token_counts,
|
||||
final_msg=last_user_msg,
|
||||
final_msg_token_count=last_user_msg_tokens,
|
||||
)
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
model_out = llm.invoke(need_search_prompt)
|
||||
|
||||
# Model will output "Yes Search" if search is useful
|
||||
# Be a little forgiving though, if we match yes, it's good enough
|
||||
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
|
||||
tool_result_str = danswer_chat_retrieval(
|
||||
query_message=last_message,
|
||||
history=previous_messages,
|
||||
llm=llm,
|
||||
user_id=user_id,
|
||||
)
|
||||
last_user_msg_text = form_tool_less_followup_text(
|
||||
tool_output=tool_result_str,
|
||||
query=last_message.message,
|
||||
hint_text=persona.hint_text,
|
||||
)
|
||||
last_user_msg_tokens = len(tokenizer(last_user_msg_text))
|
||||
last_user_msg = HumanMessage(content=last_user_msg_text)
|
||||
|
||||
else:
|
||||
last_user_msg_tokens = len(tokenizer(final_query_text))
|
||||
last_user_msg = HumanMessage(content=final_query_text)
|
||||
|
||||
system_text = persona.system_text
|
||||
system_msg = SystemMessage(content=system_text) if system_text else None
|
||||
system_tokens = len(tokenizer(system_text)) if system_text else 0
|
||||
|
||||
prompt = _drop_messages_history_overflow(
|
||||
system_msg=system_msg,
|
||||
system_token_count=system_tokens,
|
||||
history_msgs=previous_msgs_as_basemessage,
|
||||
history_token_counts=previous_msg_token_counts,
|
||||
final_msg=last_user_msg,
|
||||
final_msg_token_count=last_user_msg_tokens,
|
||||
)
|
||||
|
||||
return llm.stream(prompt)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM failed to produce valid chat message, error: {e}")
|
||||
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
|
||||
|
||||
|
||||
def llm_tools_enabled_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
persona: Persona,
|
||||
user_id: UUID | None,
|
||||
tokenizer: Callable,
|
||||
) -> Iterator[str]:
|
||||
retrieval_enabled = persona.retrieval_enabled
|
||||
system_text = persona.system_text
|
||||
tool_text = persona.tools_text
|
||||
hint_text = persona.hint_text
|
||||
tool_text = form_tool_section_text(persona.tools, persona.retrieval_enabled)
|
||||
|
||||
last_message = messages[-1]
|
||||
previous_messages = messages[:-1]
|
||||
@@ -351,11 +437,17 @@ def llm_chat_answer(
|
||||
if persona is None:
|
||||
return llm_contextless_chat_answer(messages)
|
||||
|
||||
elif persona.retrieval_enabled is False and persona.tools_text is None:
|
||||
elif persona.retrieval_enabled is False and not persona.tools:
|
||||
return llm_contextless_chat_answer(
|
||||
messages, tokenizer, system_text=persona.system_text
|
||||
messages, system_text=persona.system_text, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
return llm_contextual_chat_answer(
|
||||
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
|
||||
)
|
||||
elif persona.retrieval_enabled and not persona.tools:
|
||||
return llm_contextual_chat_answer(
|
||||
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
else:
|
||||
return llm_tools_enabled_chat_answer(
|
||||
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
|
||||
)
|
||||
|
@@ -4,7 +4,9 @@ from langchain.schema.messages import SystemMessage
|
||||
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.constants import CODE_BLOCK_PAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ToolInfo
|
||||
from danswer.llm.utils import translate_danswer_msg_to_langchain
|
||||
|
||||
DANSWER_TOOL_NAME = "Current Search"
|
||||
@@ -20,6 +22,22 @@ DANSWER_SYSTEM_MSG = (
|
||||
"It is used for a natural language search."
|
||||
)
|
||||
|
||||
|
||||
YES_SEARCH = "Yes Search"
|
||||
NO_SEARCH = "No Search"
|
||||
REQUIRE_DANSWER_SYSTEM_MSG = (
|
||||
"You are a large language model whose only job is to determine if the system should call an external search tool "
|
||||
"to be able to answer the user's last message.\n"
|
||||
f'\nRespond with "{NO_SEARCH}" if:\n'
|
||||
f"- there is sufficient information in chat history to fully answer the user query\n"
|
||||
f"- there is enough knowledge in the LLM to fully answer the user query\n"
|
||||
f"- the user query does not rely on any specific knowledge\n"
|
||||
f'\nRespond with "{YES_SEARCH}" if:\n'
|
||||
"- additional knowledge about entities, processes, problems, or anything else could lead to a better answer.\n"
|
||||
"- there is some uncertainty what the user is referring to\n\n"
|
||||
f'Respond with EXACTLY and ONLY "{YES_SEARCH}" or "{NO_SEARCH}"'
|
||||
)
|
||||
|
||||
TOOL_TEMPLATE = """
|
||||
TOOLS
|
||||
------
|
||||
@@ -37,7 +55,7 @@ Use this if you want to use a tool. Markdown code snippet formatted in the follo
|
||||
|
||||
```json
|
||||
{{
|
||||
"action": string, \\ The action to take. Must be one of {tool_names}
|
||||
"action": string, \\ The action to take. {tool_names}
|
||||
"action_input": string \\ The input to the action
|
||||
}}
|
||||
```
|
||||
@@ -88,6 +106,21 @@ IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a s
|
||||
"""
|
||||
|
||||
|
||||
TOOL_LESS_FOLLOWUP = """
|
||||
Refer to the following documents when responding to my final query. Ignore any documents that are not relevant.
|
||||
|
||||
CONTEXT DOCUMENTS:
|
||||
---------------------
|
||||
{context_str}
|
||||
|
||||
FINAL QUERY:
|
||||
--------------------
|
||||
{user_query}
|
||||
|
||||
{hint_text}
|
||||
"""
|
||||
|
||||
|
||||
def form_user_prompt_text(
|
||||
query: str,
|
||||
tool_text: str | None,
|
||||
@@ -108,23 +141,30 @@ def form_user_prompt_text(
|
||||
|
||||
|
||||
def form_tool_section_text(
|
||||
tools: list[dict[str, str]], retrieval_enabled: bool, template: str = TOOL_TEMPLATE
|
||||
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
|
||||
) -> str | None:
|
||||
if not tools and not retrieval_enabled:
|
||||
return None
|
||||
|
||||
if retrieval_enabled:
|
||||
if retrieval_enabled and tools:
|
||||
tools.append(
|
||||
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
|
||||
)
|
||||
|
||||
tools_intro = []
|
||||
for tool in tools:
|
||||
description_formatted = tool["description"].replace("\n", " ")
|
||||
tools_intro.append(f"> {tool['name']}: {description_formatted}")
|
||||
if tools:
|
||||
num_tools = len(tools)
|
||||
for tool in tools:
|
||||
description_formatted = tool["description"].replace("\n", " ")
|
||||
tools_intro.append(f"> {tool['name']}: {description_formatted}")
|
||||
|
||||
tools_intro_text = "\n".join(tools_intro)
|
||||
tool_names_text = ", ".join([tool["name"] for tool in tools])
|
||||
prefix = "Must be one of " if num_tools > 1 else "Must be "
|
||||
|
||||
tools_intro_text = "\n".join(tools_intro)
|
||||
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
return template.format(
|
||||
tool_overviews=tools_intro_text, tool_names=tool_names_text
|
||||
@@ -184,10 +224,48 @@ def build_combined_query(
|
||||
content=(
|
||||
"Help me rewrite this final message into a standalone query that takes into consideration the "
|
||||
f"past messages of the conversation if relevant. This query is used with a semantic search engine to "
|
||||
f"retrieve documents. You must ONLY return the rewritten query and nothing else."
|
||||
f"retrieve documents. You must ONLY return the rewritten query and nothing else. "
|
||||
f"Remember, the search engine does not have access to the conversation history!"
|
||||
f"\n\nQuery:\n{query_message.message}"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return combined_query_msgs
|
||||
|
||||
|
||||
def form_require_search_single_msg_text(
|
||||
query_message: ChatMessage,
|
||||
history: list[ChatMessage],
|
||||
) -> str:
|
||||
prompt = "MESSAGE_HISTORY\n---------------\n" if history else ""
|
||||
|
||||
for msg in history:
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
prefix = "AI"
|
||||
else:
|
||||
prefix = "User"
|
||||
prompt += f"{prefix}:\n```\n{msg.message}\n```\n\n"
|
||||
|
||||
prompt += f"\nFINAL QUERY:\n---------------\n{query_message.message}"
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def form_require_search_text(query_message: ChatMessage) -> str:
|
||||
return (
|
||||
query_message.message
|
||||
+ f"\n\nHint: respond with EXACTLY {YES_SEARCH} or {NO_SEARCH}"
|
||||
)
|
||||
|
||||
|
||||
def form_tool_less_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
|
||||
) -> str:
|
||||
hint = f"Hint: {hint_text}" if hint_text else ""
|
||||
return tool_followup_prompt.format(
|
||||
context_str=tool_output, user_query=query, hint_text=hint
|
||||
).strip()
|
||||
|
@@ -1,10 +1,26 @@
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_prompts import form_tool_section_text
|
||||
from danswer.configs.app_configs import PERSONAS_YAML
|
||||
from danswer.db.chat import create_persona
|
||||
from danswer.db.chat import upsert_persona
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import ToolInfo
|
||||
|
||||
|
||||
def validate_tool_info(item: Any) -> ToolInfo:
|
||||
if not (
|
||||
isinstance(item, dict)
|
||||
and "name" in item
|
||||
and isinstance(item["name"], str)
|
||||
and "description" in item
|
||||
and isinstance(item["description"], str)
|
||||
):
|
||||
raise ValueError(
|
||||
"Invalid Persona configuration yaml Found, not all tools have name/description"
|
||||
)
|
||||
return ToolInfo(name=item["name"], description=item["description"])
|
||||
|
||||
|
||||
def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
|
||||
@@ -14,15 +30,14 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
|
||||
all_personas = data.get("personas", [])
|
||||
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
|
||||
for persona in all_personas:
|
||||
tools = form_tool_section_text(
|
||||
persona["tools"], persona["retrieval_enabled"]
|
||||
)
|
||||
create_persona(
|
||||
tools = [validate_tool_info(tool) for tool in persona["tools"]]
|
||||
|
||||
upsert_persona(
|
||||
persona_id=persona["id"],
|
||||
name=persona["name"],
|
||||
retrieval_enabled=persona["retrieval_enabled"],
|
||||
system_text=persona["system"],
|
||||
tools_text=tools,
|
||||
tools=tools,
|
||||
hint_text=persona["hint"],
|
||||
default_persona=True,
|
||||
db_session=db_session,
|
||||
|
@@ -7,7 +7,12 @@ personas:
|
||||
Your responses are as INFORMATIVE and DETAILED as possible.
|
||||
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
|
||||
retrieval_enabled: true
|
||||
# Each added tool needs to have a "name" and "description"
|
||||
# Example of adding tools, it must follow this structure:
|
||||
# tools:
|
||||
# - name: "Calculator"
|
||||
# description: "Use this tool to accurately process math equations, counting, etc."
|
||||
# - name: "Current Time"
|
||||
# description: "Call this to get the current date and time."
|
||||
tools: []
|
||||
# Short tip to pass near the end of the prompt to emphasize some requirement
|
||||
hint: "Try to be as informative as possible!"
|
||||
|
@@ -13,6 +13,7 @@ from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatSession
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import ToolInfo
|
||||
|
||||
|
||||
def fetch_chat_sessions_by_user(
|
||||
@@ -261,12 +262,12 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
|
||||
return persona
|
||||
|
||||
|
||||
def create_persona(
|
||||
def upsert_persona(
|
||||
persona_id: int | None,
|
||||
name: str,
|
||||
retrieval_enabled: bool,
|
||||
system_text: str | None,
|
||||
tools_text: str | None,
|
||||
tools: list[ToolInfo] | None,
|
||||
hint_text: str | None,
|
||||
default_persona: bool,
|
||||
db_session: Session,
|
||||
@@ -278,7 +279,7 @@ def create_persona(
|
||||
persona.name = name
|
||||
persona.retrieval_enabled = retrieval_enabled
|
||||
persona.system_text = system_text
|
||||
persona.tools_text = tools_text
|
||||
persona.tools = tools
|
||||
persona.hint_text = hint_text
|
||||
persona.default_persona = default_persona
|
||||
else:
|
||||
@@ -287,7 +288,7 @@ def create_persona(
|
||||
name=name,
|
||||
retrieval_enabled=retrieval_enabled,
|
||||
system_text=system_text,
|
||||
tools_text=tools_text,
|
||||
tools=tools,
|
||||
hint_text=hint_text,
|
||||
default_persona=default_persona,
|
||||
)
|
||||
|
@@ -439,6 +439,11 @@ class ChatSession(Base):
|
||||
)
|
||||
|
||||
|
||||
class ToolInfo(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class Persona(Base):
|
||||
# TODO introduce user and group ownership for personas
|
||||
__tablename__ = "persona"
|
||||
@@ -447,7 +452,9 @@ class Persona(Base):
|
||||
# Danswer retrieval, treated as a special tool
|
||||
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
tools_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
tools: Mapped[list[ToolInfo] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# Default personas are configured via backend during deployment
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
|
@@ -3,7 +3,7 @@ from collections.abc import Sequence
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.chat import create_persona
|
||||
from danswer.db.chat import upsert_persona
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Persona__DocumentSet
|
||||
@@ -35,12 +35,12 @@ def _create_slack_bot_persona(
|
||||
"""NOTE: does not commit changes"""
|
||||
# create/update persona associated with the slack bot
|
||||
persona_name = _build_persona_name(channel_names)
|
||||
persona = create_persona(
|
||||
persona = upsert_persona(
|
||||
persona_id=existing_persona_id,
|
||||
name=persona_name,
|
||||
retrieval_enabled=True,
|
||||
system_text=None,
|
||||
tools_text=None,
|
||||
tools=None,
|
||||
hint_text=None,
|
||||
default_persona=False,
|
||||
db_session=db_session,
|
||||
|
Reference in New Issue
Block a user