Chat without tools should use a less complex prompt (#524)

This commit is contained in:
Yuhong Sun
2023-10-05 21:44:13 -07:00
committed by GitHub
parent 9c89ae78ba
commit 0632e92144
8 changed files with 262 additions and 32 deletions

View File

@@ -0,0 +1,32 @@
"""Store Tool Details
Revision ID: 904451035c9b
Revises: e0a68a81d434
Create Date: 2023-10-05 12:29:26.620000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "904451035c9b"
down_revision = "e0a68a81d434"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column("tools", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
)
op.drop_column("persona", "tools_text")
def downgrade() -> None:
op.add_column(
"persona",
sa.Column("tools_text", sa.TEXT(), autoincrement=False, nullable=True),
)
op.drop_column("persona", "tools")

View File

@@ -9,9 +9,14 @@ from langchain.schema.messages import SystemMessage
from danswer.chat.chat_prompts import build_combined_query from danswer.chat.chat_prompts import build_combined_query
from danswer.chat.chat_prompts import DANSWER_TOOL_NAME from danswer.chat.chat_prompts import DANSWER_TOOL_NAME
from danswer.chat.chat_prompts import form_require_search_text
from danswer.chat.chat_prompts import form_tool_followup_text from danswer.chat.chat_prompts import form_tool_followup_text
from danswer.chat.chat_prompts import form_tool_less_followup_text
from danswer.chat.chat_prompts import form_tool_section_text
from danswer.chat.chat_prompts import form_user_prompt_text from danswer.chat.chat_prompts import form_user_prompt_text
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
from danswer.chat.chat_prompts import YES_SEARCH
from danswer.chat.tools import call_tool from danswer.chat.tools import call_tool
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA from danswer.configs.constants import IGNORE_FOR_QA
@@ -175,8 +180,8 @@ def _drop_messages_history_overflow(
def llm_contextless_chat_answer( def llm_contextless_chat_answer(
messages: list[ChatMessage], messages: list[ChatMessage],
tokenizer: Callable | None = None,
system_text: str | None = None, system_text: str | None = None,
tokenizer: Callable | None = None,
) -> Iterator[str]: ) -> Iterator[str]:
try: try:
prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages] prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
@@ -213,11 +218,92 @@ def llm_contextual_chat_answer(
persona: Persona, persona: Persona,
user_id: UUID | None, user_id: UUID | None,
tokenizer: Callable, tokenizer: Callable,
run_search_system_text: str = REQUIRE_DANSWER_SYSTEM_MSG,
) -> Iterator[str]:
last_message = messages[-1]
final_query_text = last_message.message
previous_messages = messages[:-1]
previous_msgs_as_basemessage = [
translate_danswer_msg_to_langchain(msg) for msg in previous_messages
]
try:
llm = get_default_llm()
if not final_query_text:
raise ValueError("User chat message is empty.")
# Determine if a search is necessary to answer the user query
user_req_search_text = form_require_search_text(last_message)
last_user_msg = HumanMessage(content=user_req_search_text)
previous_msg_token_counts = [msg.token_count for msg in previous_messages]
danswer_system_tokens = len(tokenizer(run_search_system_text))
last_user_msg_tokens = len(tokenizer(user_req_search_text))
need_search_prompt = _drop_messages_history_overflow(
system_msg=SystemMessage(content=run_search_system_text),
system_token_count=danswer_system_tokens,
history_msgs=previous_msgs_as_basemessage,
history_token_counts=previous_msg_token_counts,
final_msg=last_user_msg,
final_msg_token_count=last_user_msg_tokens,
)
# Good Debug/Breakpoint
model_out = llm.invoke(need_search_prompt)
# Model will output "Yes Search" if search is useful
# Be a little forgiving though, if we match yes, it's good enough
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
tool_result_str = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
user_id=user_id,
)
last_user_msg_text = form_tool_less_followup_text(
tool_output=tool_result_str,
query=last_message.message,
hint_text=persona.hint_text,
)
last_user_msg_tokens = len(tokenizer(last_user_msg_text))
last_user_msg = HumanMessage(content=last_user_msg_text)
else:
last_user_msg_tokens = len(tokenizer(final_query_text))
last_user_msg = HumanMessage(content=final_query_text)
system_text = persona.system_text
system_msg = SystemMessage(content=system_text) if system_text else None
system_tokens = len(tokenizer(system_text)) if system_text else 0
prompt = _drop_messages_history_overflow(
system_msg=system_msg,
system_token_count=system_tokens,
history_msgs=previous_msgs_as_basemessage,
history_token_counts=previous_msg_token_counts,
final_msg=last_user_msg,
final_msg_token_count=last_user_msg_tokens,
)
return llm.stream(prompt)
except Exception as e:
logger.error(f"LLM failed to produce valid chat message, error: {e}")
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
def llm_tools_enabled_chat_answer(
messages: list[ChatMessage],
persona: Persona,
user_id: UUID | None,
tokenizer: Callable,
) -> Iterator[str]: ) -> Iterator[str]:
retrieval_enabled = persona.retrieval_enabled retrieval_enabled = persona.retrieval_enabled
system_text = persona.system_text system_text = persona.system_text
tool_text = persona.tools_text
hint_text = persona.hint_text hint_text = persona.hint_text
tool_text = form_tool_section_text(persona.tools, persona.retrieval_enabled)
last_message = messages[-1] last_message = messages[-1]
previous_messages = messages[:-1] previous_messages = messages[:-1]
@@ -351,11 +437,17 @@ def llm_chat_answer(
if persona is None: if persona is None:
return llm_contextless_chat_answer(messages) return llm_contextless_chat_answer(messages)
elif persona.retrieval_enabled is False and persona.tools_text is None: elif persona.retrieval_enabled is False and not persona.tools:
return llm_contextless_chat_answer( return llm_contextless_chat_answer(
messages, tokenizer, system_text=persona.system_text messages, system_text=persona.system_text, tokenizer=tokenizer
) )
return llm_contextual_chat_answer( elif persona.retrieval_enabled and not persona.tools:
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer return llm_contextual_chat_answer(
) messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
)
else:
return llm_tools_enabled_chat_answer(
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
)

View File

@@ -4,7 +4,9 @@ from langchain.schema.messages import SystemMessage
from danswer.chunking.models import InferenceChunk from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import CODE_BLOCK_PAT from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage from danswer.db.models import ChatMessage
from danswer.db.models import ToolInfo
from danswer.llm.utils import translate_danswer_msg_to_langchain from danswer.llm.utils import translate_danswer_msg_to_langchain
DANSWER_TOOL_NAME = "Current Search" DANSWER_TOOL_NAME = "Current Search"
@@ -20,6 +22,22 @@ DANSWER_SYSTEM_MSG = (
"It is used for a natural language search." "It is used for a natural language search."
) )
YES_SEARCH = "Yes Search"
NO_SEARCH = "No Search"
REQUIRE_DANSWER_SYSTEM_MSG = (
"You are a large language model whose only job is to determine if the system should call an external search tool "
"to be able to answer the user's last message.\n"
f'\nRespond with "{NO_SEARCH}" if:\n'
f"- there is sufficient information in chat history to fully answer the user query\n"
f"- there is enough knowledge in the LLM to fully answer the user query\n"
f"- the user query does not rely on any specific knowledge\n"
f'\nRespond with "{YES_SEARCH}" if:\n'
"- additional knowledge about entities, processes, problems, or anything else could lead to a better answer.\n"
"- there is some uncertainty what the user is referring to\n\n"
f'Respond with EXACTLY and ONLY "{YES_SEARCH}" or "{NO_SEARCH}"'
)
TOOL_TEMPLATE = """ TOOL_TEMPLATE = """
TOOLS TOOLS
------ ------
@@ -37,7 +55,7 @@ Use this if you want to use a tool. Markdown code snippet formatted in the follo
```json ```json
{{ {{
"action": string, \\ The action to take. Must be one of {tool_names} "action": string, \\ The action to take. {tool_names}
"action_input": string \\ The input to the action "action_input": string \\ The input to the action
}} }}
``` ```
@@ -88,6 +106,21 @@ IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a s
""" """
TOOL_LESS_FOLLOWUP = """
Refer to the following documents when responding to my final query. Ignore any documents that are not relevant.
CONTEXT DOCUMENTS:
---------------------
{context_str}
FINAL QUERY:
--------------------
{user_query}
{hint_text}
"""
def form_user_prompt_text( def form_user_prompt_text(
query: str, query: str,
tool_text: str | None, tool_text: str | None,
@@ -108,23 +141,30 @@ def form_user_prompt_text(
def form_tool_section_text( def form_tool_section_text(
tools: list[dict[str, str]], retrieval_enabled: bool, template: str = TOOL_TEMPLATE tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None: ) -> str | None:
if not tools and not retrieval_enabled: if not tools and not retrieval_enabled:
return None return None
if retrieval_enabled: if retrieval_enabled and tools:
tools.append( tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION} {"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
) )
tools_intro = [] tools_intro = []
for tool in tools: if tools:
description_formatted = tool["description"].replace("\n", " ") num_tools = len(tools)
tools_intro.append(f"> {tool['name']}: {description_formatted}") for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
tools_intro_text = "\n".join(tools_intro) prefix = "Must be one of " if num_tools > 1 else "Must be "
tool_names_text = ", ".join([tool["name"] for tool in tools])
tools_intro_text = "\n".join(tools_intro)
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
else:
return None
return template.format( return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text tool_overviews=tools_intro_text, tool_names=tool_names_text
@@ -184,10 +224,48 @@ def build_combined_query(
content=( content=(
"Help me rewrite this final message into a standalone query that takes into consideration the " "Help me rewrite this final message into a standalone query that takes into consideration the "
f"past messages of the conversation if relevant. This query is used with a semantic search engine to " f"past messages of the conversation if relevant. This query is used with a semantic search engine to "
f"retrieve documents. You must ONLY return the rewritten query and nothing else." f"retrieve documents. You must ONLY return the rewritten query and nothing else. "
f"Remember, the search engine does not have access to the conversation history!"
f"\n\nQuery:\n{query_message.message}" f"\n\nQuery:\n{query_message.message}"
) )
) )
) )
return combined_query_msgs return combined_query_msgs
def form_require_search_single_msg_text(
query_message: ChatMessage,
history: list[ChatMessage],
) -> str:
prompt = "MESSAGE_HISTORY\n---------------\n" if history else ""
for msg in history:
if msg.message_type == MessageType.ASSISTANT:
prefix = "AI"
else:
prefix = "User"
prompt += f"{prefix}:\n```\n{msg.message}\n```\n\n"
prompt += f"\nFINAL QUERY:\n---------------\n{query_message.message}"
return prompt
def form_require_search_text(query_message: ChatMessage) -> str:
return (
query_message.message
+ f"\n\nHint: respond with EXACTLY {YES_SEARCH} or {NO_SEARCH}"
)
def form_tool_less_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
) -> str:
hint = f"Hint: {hint_text}" if hint_text else ""
return tool_followup_prompt.format(
context_str=tool_output, user_query=query, hint_text=hint
).strip()

View File

@@ -1,10 +1,26 @@
from typing import Any
import yaml import yaml
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.chat.chat_prompts import form_tool_section_text
from danswer.configs.app_configs import PERSONAS_YAML from danswer.configs.app_configs import PERSONAS_YAML
from danswer.db.chat import create_persona from danswer.db.chat import upsert_persona
from danswer.db.engine import get_sqlalchemy_engine from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import ToolInfo
def validate_tool_info(item: Any) -> ToolInfo:
if not (
isinstance(item, dict)
and "name" in item
and isinstance(item["name"], str)
and "description" in item
and isinstance(item["description"], str)
):
raise ValueError(
"Invalid Persona configuration yaml Found, not all tools have name/description"
)
return ToolInfo(name=item["name"], description=item["description"])
def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None: def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
@@ -14,15 +30,14 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
all_personas = data.get("personas", []) all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session: with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
for persona in all_personas: for persona in all_personas:
tools = form_tool_section_text( tools = [validate_tool_info(tool) for tool in persona["tools"]]
persona["tools"], persona["retrieval_enabled"]
) upsert_persona(
create_persona(
persona_id=persona["id"], persona_id=persona["id"],
name=persona["name"], name=persona["name"],
retrieval_enabled=persona["retrieval_enabled"], retrieval_enabled=persona["retrieval_enabled"],
system_text=persona["system"], system_text=persona["system"],
tools_text=tools, tools=tools,
hint_text=persona["hint"], hint_text=persona["hint"],
default_persona=True, default_persona=True,
db_session=db_session, db_session=db_session,

View File

@@ -7,7 +7,12 @@ personas:
Your responses are as INFORMATIVE and DETAILED as possible. Your responses are as INFORMATIVE and DETAILED as possible.
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled. # Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
retrieval_enabled: true retrieval_enabled: true
# Each added tool needs to have a "name" and "description" # Example of adding tools, it must follow this structure:
# tools:
# - name: "Calculator"
# description: "Use this tool to accurately process math equations, counting, etc."
# - name: "Current Time"
# description: "Call this to get the current date and time."
tools: [] tools: []
# Short tip to pass near the end of the prompt to emphasize some requirement # Short tip to pass near the end of the prompt to emphasize some requirement
hint: "Try to be as informative as possible!" hint: "Try to be as informative as possible!"

View File

@@ -13,6 +13,7 @@ from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession from danswer.db.models import ChatSession
from danswer.db.models import Persona from danswer.db.models import Persona
from danswer.db.models import ToolInfo
def fetch_chat_sessions_by_user( def fetch_chat_sessions_by_user(
@@ -261,12 +262,12 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
return persona return persona
def create_persona( def upsert_persona(
persona_id: int | None, persona_id: int | None,
name: str, name: str,
retrieval_enabled: bool, retrieval_enabled: bool,
system_text: str | None, system_text: str | None,
tools_text: str | None, tools: list[ToolInfo] | None,
hint_text: str | None, hint_text: str | None,
default_persona: bool, default_persona: bool,
db_session: Session, db_session: Session,
@@ -278,7 +279,7 @@ def create_persona(
persona.name = name persona.name = name
persona.retrieval_enabled = retrieval_enabled persona.retrieval_enabled = retrieval_enabled
persona.system_text = system_text persona.system_text = system_text
persona.tools_text = tools_text persona.tools = tools
persona.hint_text = hint_text persona.hint_text = hint_text
persona.default_persona = default_persona persona.default_persona = default_persona
else: else:
@@ -287,7 +288,7 @@ def create_persona(
name=name, name=name,
retrieval_enabled=retrieval_enabled, retrieval_enabled=retrieval_enabled,
system_text=system_text, system_text=system_text,
tools_text=tools_text, tools=tools,
hint_text=hint_text, hint_text=hint_text,
default_persona=default_persona, default_persona=default_persona,
) )

View File

@@ -439,6 +439,11 @@ class ChatSession(Base):
) )
class ToolInfo(TypedDict):
name: str
description: str
class Persona(Base): class Persona(Base):
# TODO introduce user and group ownership for personas # TODO introduce user and group ownership for personas
__tablename__ = "persona" __tablename__ = "persona"
@@ -447,7 +452,9 @@ class Persona(Base):
# Danswer retrieval, treated as a special tool # Danswer retrieval, treated as a special tool
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True) retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
system_text: Mapped[str | None] = mapped_column(Text, nullable=True) system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
tools_text: Mapped[str | None] = mapped_column(Text, nullable=True) tools: Mapped[list[ToolInfo] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
hint_text: Mapped[str | None] = mapped_column(Text, nullable=True) hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
# Default personas are configured via backend during deployment # Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.) # Treated specially (cannot be user edited etc.)

View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.db.chat import create_persona from danswer.db.chat import upsert_persona
from danswer.db.models import ChannelConfig from danswer.db.models import ChannelConfig
from danswer.db.models import Persona from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet from danswer.db.models import Persona__DocumentSet
@@ -35,12 +35,12 @@ def _create_slack_bot_persona(
"""NOTE: does not commit changes""" """NOTE: does not commit changes"""
# create/update persona associated with the slack bot # create/update persona associated with the slack bot
persona_name = _build_persona_name(channel_names) persona_name = _build_persona_name(channel_names)
persona = create_persona( persona = upsert_persona(
persona_id=existing_persona_id, persona_id=existing_persona_id,
name=persona_name, name=persona_name,
retrieval_enabled=True, retrieval_enabled=True,
system_text=None, system_text=None,
tools_text=None, tools=None,
hint_text=None, hint_text=None,
default_persona=False, default_persona=False,
db_session=db_session, db_session=db_session,