Chat without tools should use a less complex prompt (#524)

This commit is contained in:
Yuhong Sun
2023-10-05 21:44:13 -07:00
committed by GitHub
parent 9c89ae78ba
commit 0632e92144
8 changed files with 262 additions and 32 deletions

View File

@@ -0,0 +1,32 @@
"""Store Tool Details
Revision ID: 904451035c9b
Revises: e0a68a81d434
Create Date: 2023-10-05 12:29:26.620000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
# `revision` is this migration's own id and `down_revision` is the id of the
# migration it follows (the "Revision ID" / "Revises" values in the module
# docstring above).
revision = "904451035c9b"
down_revision = "e0a68a81d434"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Swap the ``persona.tools_text`` column for a JSONB ``tools`` column.

    The nullable ``tools`` column is added first, then ``tools_text`` is
    dropped. No values are copied from the old column into the new one.
    """
    tools_column = sa.Column(
        "tools",
        postgresql.JSONB(astext_type=sa.Text()),
        nullable=True,
    )
    op.add_column("persona", tools_column)
    op.drop_column("persona", "tools_text")
def downgrade() -> None:
    """Restore the ``persona.tools_text`` column and drop ``persona.tools``.

    The nullable text column is re-created empty; nothing stored in the
    JSONB ``tools`` column is carried back before that column is dropped.
    """
    tools_text_column = sa.Column(
        "tools_text",
        sa.TEXT(),
        autoincrement=False,
        nullable=True,
    )
    op.add_column("persona", tools_text_column)
    op.drop_column("persona", "tools")

View File

@@ -9,9 +9,14 @@ from langchain.schema.messages import SystemMessage
from danswer.chat.chat_prompts import build_combined_query
from danswer.chat.chat_prompts import DANSWER_TOOL_NAME
from danswer.chat.chat_prompts import form_require_search_text
from danswer.chat.chat_prompts import form_tool_followup_text
from danswer.chat.chat_prompts import form_tool_less_followup_text
from danswer.chat.chat_prompts import form_tool_section_text
from danswer.chat.chat_prompts import form_user_prompt_text
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
from danswer.chat.chat_prompts import YES_SEARCH
from danswer.chat.tools import call_tool
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
@@ -175,8 +180,8 @@ def _drop_messages_history_overflow(
def llm_contextless_chat_answer(
messages: list[ChatMessage],
tokenizer: Callable | None = None,
system_text: str | None = None,
tokenizer: Callable | None = None,
) -> Iterator[str]:
try:
prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
@@ -213,11 +218,92 @@ def llm_contextual_chat_answer(
persona: Persona,
user_id: UUID | None,
tokenizer: Callable,
run_search_system_text: str = REQUIRE_DANSWER_SYSTEM_MSG,
) -> Iterator[str]:
last_message = messages[-1]
final_query_text = last_message.message
previous_messages = messages[:-1]
previous_msgs_as_basemessage = [
translate_danswer_msg_to_langchain(msg) for msg in previous_messages
]
try:
llm = get_default_llm()
if not final_query_text:
raise ValueError("User chat message is empty.")
# Determine if a search is necessary to answer the user query
user_req_search_text = form_require_search_text(last_message)
last_user_msg = HumanMessage(content=user_req_search_text)
previous_msg_token_counts = [msg.token_count for msg in previous_messages]
danswer_system_tokens = len(tokenizer(run_search_system_text))
last_user_msg_tokens = len(tokenizer(user_req_search_text))
need_search_prompt = _drop_messages_history_overflow(
system_msg=SystemMessage(content=run_search_system_text),
system_token_count=danswer_system_tokens,
history_msgs=previous_msgs_as_basemessage,
history_token_counts=previous_msg_token_counts,
final_msg=last_user_msg,
final_msg_token_count=last_user_msg_tokens,
)
# Good Debug/Breakpoint
model_out = llm.invoke(need_search_prompt)
# Model will output "Yes Search" if search is useful
# Be a little forgiving though, if we match yes, it's good enough
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
tool_result_str = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
user_id=user_id,
)
last_user_msg_text = form_tool_less_followup_text(
tool_output=tool_result_str,
query=last_message.message,
hint_text=persona.hint_text,
)
last_user_msg_tokens = len(tokenizer(last_user_msg_text))
last_user_msg = HumanMessage(content=last_user_msg_text)
else:
last_user_msg_tokens = len(tokenizer(final_query_text))
last_user_msg = HumanMessage(content=final_query_text)
system_text = persona.system_text
system_msg = SystemMessage(content=system_text) if system_text else None
system_tokens = len(tokenizer(system_text)) if system_text else 0
prompt = _drop_messages_history_overflow(
system_msg=system_msg,
system_token_count=system_tokens,
history_msgs=previous_msgs_as_basemessage,
history_token_counts=previous_msg_token_counts,
final_msg=last_user_msg,
final_msg_token_count=last_user_msg_tokens,
)
return llm.stream(prompt)
except Exception as e:
logger.error(f"LLM failed to produce valid chat message, error: {e}")
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
def llm_tools_enabled_chat_answer(
messages: list[ChatMessage],
persona: Persona,
user_id: UUID | None,
tokenizer: Callable,
) -> Iterator[str]:
retrieval_enabled = persona.retrieval_enabled
system_text = persona.system_text
tool_text = persona.tools_text
hint_text = persona.hint_text
tool_text = form_tool_section_text(persona.tools, persona.retrieval_enabled)
last_message = messages[-1]
previous_messages = messages[:-1]
@@ -351,11 +437,17 @@ def llm_chat_answer(
if persona is None:
return llm_contextless_chat_answer(messages)
elif persona.retrieval_enabled is False and persona.tools_text is None:
elif persona.retrieval_enabled is False and not persona.tools:
return llm_contextless_chat_answer(
messages, tokenizer, system_text=persona.system_text
messages, system_text=persona.system_text, tokenizer=tokenizer
)
return llm_contextual_chat_answer(
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
)
elif persona.retrieval_enabled and not persona.tools:
return llm_contextual_chat_answer(
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
)
else:
return llm_tools_enabled_chat_answer(
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
)

View File

@@ -4,7 +4,9 @@ from langchain.schema.messages import SystemMessage
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ToolInfo
from danswer.llm.utils import translate_danswer_msg_to_langchain
DANSWER_TOOL_NAME = "Current Search"
@@ -20,6 +22,22 @@ DANSWER_SYSTEM_MSG = (
"It is used for a natural language search."
)
# Exact replies the "is a search needed?" LLM call is instructed to produce.
YES_SEARCH = "Yes Search"
NO_SEARCH = "No Search"

# System prompt for the LLM call that only decides whether the external search
# tool should be run before answering the user's last message.
# Only the literals that interpolate YES_SEARCH / NO_SEARCH are f-strings.
REQUIRE_DANSWER_SYSTEM_MSG = (
    "You are a large language model whose only job is to determine if the system should call an external search tool "
    "to be able to answer the user's last message.\n"
    f'\nRespond with "{NO_SEARCH}" if:\n'
    "- there is sufficient information in chat history to fully answer the user query\n"
    "- there is enough knowledge in the LLM to fully answer the user query\n"
    "- the user query does not rely on any specific knowledge\n"
    f'\nRespond with "{YES_SEARCH}" if:\n'
    "- additional knowledge about entities, processes, problems, or anything else could lead to a better answer.\n"
    "- there is some uncertainty what the user is referring to\n\n"
    f'Respond with EXACTLY and ONLY "{YES_SEARCH}" or "{NO_SEARCH}"'
)
TOOL_TEMPLATE = """
TOOLS
------
@@ -37,7 +55,7 @@ Use this if you want to use a tool. Markdown code snippet formatted in the follo
```json
{{
"action": string, \\ The action to take. Must be one of {tool_names}
"action": string, \\ The action to take. {tool_names}
"action_input": string \\ The input to the action
}}
```
@@ -88,6 +106,21 @@ IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a s
"""
# Template for the final user message when answering from retrieved documents
# without the tool-calling prompt. Placeholders:
#   {context_str} - text placed under CONTEXT DOCUMENTS
#   {user_query}  - the user's final query
#   {hint_text}   - an already formatted hint line, or an empty string
# form_tool_less_followup_text fills this in and strips the surrounding
# whitespace from the result.
TOOL_LESS_FOLLOWUP = """
Refer to the following documents when responding to my final query. Ignore any documents that are not relevant.
CONTEXT DOCUMENTS:
---------------------
{context_str}
FINAL QUERY:
--------------------
{user_query}
{hint_text}
"""
def form_user_prompt_text(
query: str,
tool_text: str | None,
@@ -108,23 +141,30 @@ def form_user_prompt_text(
def form_tool_section_text(
tools: list[dict[str, str]], retrieval_enabled: bool, template: str = TOOL_TEMPLATE
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None:
if not tools and not retrieval_enabled:
return None
if retrieval_enabled:
if retrieval_enabled and tools:
tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
)
tools_intro = []
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
if tools:
num_tools = len(tools)
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
tools_intro_text = "\n".join(tools_intro)
tool_names_text = ", ".join([tool["name"] for tool in tools])
prefix = "Must be one of " if num_tools > 1 else "Must be "
tools_intro_text = "\n".join(tools_intro)
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
else:
return None
return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text
@@ -184,10 +224,48 @@ def build_combined_query(
content=(
"Help me rewrite this final message into a standalone query that takes into consideration the "
f"past messages of the conversation if relevant. This query is used with a semantic search engine to "
f"retrieve documents. You must ONLY return the rewritten query and nothing else."
f"retrieve documents. You must ONLY return the rewritten query and nothing else. "
f"Remember, the search engine does not have access to the conversation history!"
f"\n\nQuery:\n{query_message.message}"
)
)
)
return combined_query_msgs
def form_require_search_single_msg_text(
    query_message: ChatMessage,
    history: list[ChatMessage],
) -> str:
    """Flatten the chat history and the final query into one prompt string.

    Each history message is rendered as a fenced block labelled "AI" (for
    assistant messages) or "User" (for everything else). A MESSAGE_HISTORY
    header is emitted only when there is history. The final query is always
    appended under a FINAL QUERY header.
    """
    sections = ["MESSAGE_HISTORY\n---------------\n"] if history else []

    for past_msg in history:
        speaker = "AI" if past_msg.message_type == MessageType.ASSISTANT else "User"
        sections.append(f"{speaker}:\n```\n{past_msg.message}\n```\n\n")

    sections.append(f"\nFINAL QUERY:\n---------------\n{query_message.message}")
    return "".join(sections)
def form_require_search_text(query_message: ChatMessage) -> str:
    """Return the user's message followed by a hint restating the two allowed replies."""
    user_text = query_message.message
    reply_hint = f"\n\nHint: respond with EXACTLY {YES_SEARCH} or {NO_SEARCH}"
    return user_text + reply_hint
def form_tool_less_followup_text(
    tool_output: str,
    query: str,
    hint_text: str | None,
    tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
) -> str:
    """Fill the follow-up template with the tool output, the query and a hint.

    ``tool_output`` goes into the template's ``context_str`` slot and
    ``query`` into ``user_query``. When ``hint_text`` is empty or None the
    hint slot is filled with an empty string, otherwise with
    ``"Hint: <hint_text>"``. Leading and trailing whitespace of the
    formatted result is stripped.
    """
    if hint_text:
        hint_line = f"Hint: {hint_text}"
    else:
        hint_line = ""

    filled_prompt = tool_followup_prompt.format(
        context_str=tool_output, user_query=query, hint_text=hint_line
    )
    return filled_prompt.strip()

View File

@@ -1,10 +1,26 @@
from typing import Any
import yaml
from sqlalchemy.orm import Session
from danswer.chat.chat_prompts import form_tool_section_text
from danswer.configs.app_configs import PERSONAS_YAML
from danswer.db.chat import create_persona
from danswer.db.chat import upsert_persona
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import ToolInfo
def validate_tool_info(item: Any) -> ToolInfo:
    """Check one tool entry from the personas yaml and return it as a ToolInfo.

    The entry must be a dict whose "name" and "description" keys are both
    present and hold strings.

    Raises:
        ValueError: if the entry is not a dict or either field is missing
            or not a string.
    """
    required_fields = ("name", "description")
    # dict.get returns None for a missing key, which fails the str check too
    is_valid = isinstance(item, dict) and all(
        isinstance(item.get(field), str) for field in required_fields
    )
    if not is_valid:
        raise ValueError(
            "Invalid Persona configuration yaml Found, not all tools have name/description"
        )

    return ToolInfo(name=item["name"], description=item["description"])
def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
@@ -14,15 +30,14 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
for persona in all_personas:
tools = form_tool_section_text(
persona["tools"], persona["retrieval_enabled"]
)
create_persona(
tools = [validate_tool_info(tool) for tool in persona["tools"]]
upsert_persona(
persona_id=persona["id"],
name=persona["name"],
retrieval_enabled=persona["retrieval_enabled"],
system_text=persona["system"],
tools_text=tools,
tools=tools,
hint_text=persona["hint"],
default_persona=True,
db_session=db_session,

View File

@@ -7,7 +7,12 @@ personas:
Your responses are as INFORMATIVE and DETAILED as possible.
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
retrieval_enabled: true
# Each added tool needs to have a "name" and "description"
# Example of adding tools, it must follow this structure:
# tools:
# - name: "Calculator"
# description: "Use this tool to accurately process math equations, counting, etc."
# - name: "Current Time"
# description: "Call this to get the current date and time."
tools: []
# Short tip to pass near the end of the prompt to emphasize some requirement
hint: "Try to be as informative as possible!"

View File

@@ -13,6 +13,7 @@ from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import Persona
from danswer.db.models import ToolInfo
def fetch_chat_sessions_by_user(
@@ -261,12 +262,12 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
return persona
def create_persona(
def upsert_persona(
persona_id: int | None,
name: str,
retrieval_enabled: bool,
system_text: str | None,
tools_text: str | None,
tools: list[ToolInfo] | None,
hint_text: str | None,
default_persona: bool,
db_session: Session,
@@ -278,7 +279,7 @@ def create_persona(
persona.name = name
persona.retrieval_enabled = retrieval_enabled
persona.system_text = system_text
persona.tools_text = tools_text
persona.tools = tools
persona.hint_text = hint_text
persona.default_persona = default_persona
else:
@@ -287,7 +288,7 @@ def create_persona(
name=name,
retrieval_enabled=retrieval_enabled,
system_text=system_text,
tools_text=tools_text,
tools=tools,
hint_text=hint_text,
default_persona=default_persona,
)

View File

@@ -439,6 +439,11 @@ class ChatSession(Base):
)
class ToolInfo(TypedDict):
    """Name/description pair for one tool; Persona.tools holds a list of these."""

    # Tool name as listed in the tool section of the prompt
    name: str
    # Free-text description of the tool shown next to its name in the prompt
    description: str
class Persona(Base):
# TODO introduce user and group ownership for personas
__tablename__ = "persona"
@@ -447,7 +452,9 @@ class Persona(Base):
# Danswer retrieval, treated as a special tool
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
tools_text: Mapped[str | None] = mapped_column(Text, nullable=True)
tools: Mapped[list[ToolInfo] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
# Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)

View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.chat import create_persona
from danswer.db.chat import upsert_persona
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet
@@ -35,12 +35,12 @@ def _create_slack_bot_persona(
"""NOTE: does not commit changes"""
# create/update persona associated with the slack bot
persona_name = _build_persona_name(channel_names)
persona = create_persona(
persona = upsert_persona(
persona_id=existing_persona_id,
name=persona_name,
retrieval_enabled=True,
system_text=None,
tools_text=None,
tools=None,
hint_text=None,
default_persona=False,
db_session=db_session,