From 0632e921448547a6627892cb942a485119885dd8 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 5 Oct 2023 21:44:13 -0700 Subject: [PATCH] Chat without tools should use a less complex prompt (#524) --- .../904451035c9b_store_tool_details.py | 32 ++++++ backend/danswer/chat/chat_llm.py | 106 ++++++++++++++++-- backend/danswer/chat/chat_prompts.py | 96 ++++++++++++++-- backend/danswer/chat/personas.py | 29 +++-- backend/danswer/chat/personas.yaml | 7 +- backend/danswer/db/chat.py | 9 +- backend/danswer/db/models.py | 9 +- backend/danswer/db/slack_bot_config.py | 6 +- 8 files changed, 262 insertions(+), 32 deletions(-) create mode 100644 backend/alembic/versions/904451035c9b_store_tool_details.py diff --git a/backend/alembic/versions/904451035c9b_store_tool_details.py b/backend/alembic/versions/904451035c9b_store_tool_details.py new file mode 100644 index 000000000000..28b21e8c691b --- /dev/null +++ b/backend/alembic/versions/904451035c9b_store_tool_details.py @@ -0,0 +1,32 @@ +"""Store Tool Details + +Revision ID: 904451035c9b +Revises: e0a68a81d434 +Create Date: 2023-10-05 12:29:26.620000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "904451035c9b" +down_revision = "e0a68a81d434" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "persona", + sa.Column("tools", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + ) + op.drop_column("persona", "tools_text") + + +def downgrade() -> None: + op.add_column( + "persona", + sa.Column("tools_text", sa.TEXT(), autoincrement=False, nullable=True), + ) + op.drop_column("persona", "tools") diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py index fe302a4c62e6..87d34c8b7630 100644 --- a/backend/danswer/chat/chat_llm.py +++ b/backend/danswer/chat/chat_llm.py @@ -9,9 +9,14 @@ from langchain.schema.messages import SystemMessage from danswer.chat.chat_prompts import build_combined_query from danswer.chat.chat_prompts import DANSWER_TOOL_NAME +from danswer.chat.chat_prompts import form_require_search_text from danswer.chat.chat_prompts import form_tool_followup_text +from danswer.chat.chat_prompts import form_tool_less_followup_text +from danswer.chat.chat_prompts import form_tool_section_text from danswer.chat.chat_prompts import form_user_prompt_text from danswer.chat.chat_prompts import format_danswer_chunks_for_chat +from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG +from danswer.chat.chat_prompts import YES_SEARCH from danswer.chat.tools import call_tool from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT from danswer.configs.constants import IGNORE_FOR_QA @@ -175,8 +180,8 @@ def _drop_messages_history_overflow( def llm_contextless_chat_answer( messages: list[ChatMessage], - tokenizer: Callable | None = None, system_text: str | None = None, + tokenizer: Callable | None = None, ) -> Iterator[str]: try: prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages] @@ -213,11 +218,92 @@ def llm_contextual_chat_answer( persona: Persona, user_id: UUID | None, tokenizer: Callable, + run_search_system_text: str = REQUIRE_DANSWER_SYSTEM_MSG, +) -> Iterator[str]: + last_message = messages[-1] + final_query_text = last_message.message + previous_messages = messages[:-1] + previous_msgs_as_basemessage = [ + translate_danswer_msg_to_langchain(msg) for msg in previous_messages + ] + + try: + llm = get_default_llm() + + if not final_query_text: + raise ValueError("User chat message is empty.") + + # Determine if a search is necessary to answer the user query + user_req_search_text = form_require_search_text(last_message) + last_user_msg = HumanMessage(content=user_req_search_text) + + previous_msg_token_counts = [msg.token_count for msg in previous_messages] + danswer_system_tokens = len(tokenizer(run_search_system_text)) + last_user_msg_tokens = len(tokenizer(user_req_search_text)) + + need_search_prompt = _drop_messages_history_overflow( + system_msg=SystemMessage(content=run_search_system_text), + system_token_count=danswer_system_tokens, + history_msgs=previous_msgs_as_basemessage, + history_token_counts=previous_msg_token_counts, + final_msg=last_user_msg, + final_msg_token_count=last_user_msg_tokens, + ) + + # Good Debug/Breakpoint + model_out = llm.invoke(need_search_prompt) + + # Model will output "Yes Search" if search is useful + # Be a little forgiving though, if we match yes, it's good enough + if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower(): + tool_result_str = danswer_chat_retrieval( + query_message=last_message, + history=previous_messages, + llm=llm, + user_id=user_id, + ) + last_user_msg_text = form_tool_less_followup_text( + tool_output=tool_result_str, + query=last_message.message, + hint_text=persona.hint_text, + ) + last_user_msg_tokens = len(tokenizer(last_user_msg_text)) + last_user_msg = HumanMessage(content=last_user_msg_text) + + else: + last_user_msg_tokens = len(tokenizer(final_query_text)) + last_user_msg = HumanMessage(content=final_query_text) + + system_text = persona.system_text + system_msg = SystemMessage(content=system_text) if system_text else None + system_tokens = len(tokenizer(system_text)) if system_text else 0 + + prompt = _drop_messages_history_overflow( + system_msg=system_msg, + system_token_count=system_tokens, + history_msgs=previous_msgs_as_basemessage, + history_token_counts=previous_msg_token_counts, + final_msg=last_user_msg, + final_msg_token_count=last_user_msg_tokens, + ) + + return llm.stream(prompt) + + except Exception as e: + logger.error(f"LLM failed to produce valid chat message, error: {e}") + return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator + + +def llm_tools_enabled_chat_answer( + messages: list[ChatMessage], + persona: Persona, + user_id: UUID | None, + tokenizer: Callable, ) -> Iterator[str]: retrieval_enabled = persona.retrieval_enabled system_text = persona.system_text - tool_text = persona.tools_text hint_text = persona.hint_text + tool_text = form_tool_section_text(persona.tools, persona.retrieval_enabled) last_message = messages[-1] previous_messages = messages[:-1] @@ -351,11 +437,17 @@ def llm_chat_answer( if persona is None: return llm_contextless_chat_answer(messages) - elif persona.retrieval_enabled is False and persona.tools_text is None: + elif persona.retrieval_enabled is False and not persona.tools: return llm_contextless_chat_answer( - messages, tokenizer, system_text=persona.system_text + messages, system_text=persona.system_text, tokenizer=tokenizer ) - return llm_contextual_chat_answer( - messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer - ) + elif persona.retrieval_enabled and not persona.tools: + return llm_contextual_chat_answer( + messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer + ) + + else: + return llm_tools_enabled_chat_answer( + messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer + ) diff --git a/backend/danswer/chat/chat_prompts.py b/backend/danswer/chat/chat_prompts.py index a8af1ce4cca2..05c29354e684 100644 --- a/backend/danswer/chat/chat_prompts.py +++ b/backend/danswer/chat/chat_prompts.py @@ -4,7 +4,9 @@ from langchain.schema.messages import SystemMessage from danswer.chunking.models import InferenceChunk from danswer.configs.constants import CODE_BLOCK_PAT +from danswer.configs.constants import MessageType from danswer.db.models import ChatMessage +from danswer.db.models import ToolInfo from danswer.llm.utils import translate_danswer_msg_to_langchain DANSWER_TOOL_NAME = "Current Search" @@ -20,6 +22,22 @@ DANSWER_SYSTEM_MSG = ( "It is used for a natural language search." ) + +YES_SEARCH = "Yes Search" +NO_SEARCH = "No Search" +REQUIRE_DANSWER_SYSTEM_MSG = ( + "You are a large language model whose only job is to determine if the system should call an external search tool " + "to be able to answer the user's last message.\n" + f'\nRespond with "{NO_SEARCH}" if:\n' + f"- there is sufficient information in chat history to fully answer the user query\n" + f"- there is enough knowledge in the LLM to fully answer the user query\n" + f"- the user query does not rely on any specific knowledge\n" + f'\nRespond with "{YES_SEARCH}" if:\n' + "- additional knowledge about entities, processes, problems, or anything else could lead to a better answer.\n" + "- there is some uncertainty what the user is referring to\n\n" + f'Respond with EXACTLY and ONLY "{YES_SEARCH}" or "{NO_SEARCH}"' +) + TOOL_TEMPLATE = """ TOOLS ------ @@ -37,7 +55,7 @@ Use this if you want to use a tool. Markdown code snippet formatted in the follo ```json {{ - "action": string, \\ The action to take. Must be one of {tool_names} + "action": string, \\ The action to take. {tool_names} "action_input": string \\ The input to the action }} ``` @@ -88,6 +106,21 @@ IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a s """ +TOOL_LESS_FOLLOWUP = """ +Refer to the following documents when responding to my final query. Ignore any documents that are not relevant. + +CONTEXT DOCUMENTS: +--------------------- +{context_str} + +FINAL QUERY: +-------------------- +{user_query} + +{hint_text} +""" + + def form_user_prompt_text( query: str, tool_text: str | None, @@ -108,23 +141,30 @@ def form_user_prompt_text( def form_tool_section_text( - tools: list[dict[str, str]], retrieval_enabled: bool, template: str = TOOL_TEMPLATE + tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE ) -> str | None: if not tools and not retrieval_enabled: return None - if retrieval_enabled: + if retrieval_enabled and tools: tools.append( {"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION} ) tools_intro = [] - for tool in tools: - description_formatted = tool["description"].replace("\n", " ") - tools_intro.append(f"> {tool['name']}: {description_formatted}") + if tools: + num_tools = len(tools) + for tool in tools: + description_formatted = tool["description"].replace("\n", " ") + tools_intro.append(f"> {tool['name']}: {description_formatted}") - tools_intro_text = "\n".join(tools_intro) - tool_names_text = ", ".join([tool["name"] for tool in tools]) + prefix = "Must be one of " if num_tools > 1 else "Must be " + + tools_intro_text = "\n".join(tools_intro) + tool_names_text = prefix + ", ".join([tool["name"] for tool in tools]) + + else: + return None return template.format( tool_overviews=tools_intro_text, tool_names=tool_names_text @@ -184,10 +224,48 @@ def build_combined_query( content=( "Help me rewrite this final message into a standalone query that takes into consideration the " f"past messages of the conversation if relevant. This query is used with a semantic search engine to " - f"retrieve documents. You must ONLY return the rewritten query and nothing else." + f"retrieve documents. You must ONLY return the rewritten query and nothing else. " + f"Remember, the search engine does not have access to the conversation history!" f"\n\nQuery:\n{query_message.message}" ) ) ) return combined_query_msgs + + +def form_require_search_single_msg_text( + query_message: ChatMessage, + history: list[ChatMessage], +) -> str: + prompt = "MESSAGE_HISTORY\n---------------\n" if history else "" + + for msg in history: + if msg.message_type == MessageType.ASSISTANT: + prefix = "AI" + else: + prefix = "User" + prompt += f"{prefix}:\n```\n{msg.message}\n```\n\n" + + prompt += f"\nFINAL QUERY:\n---------------\n{query_message.message}" + + return prompt + + +def form_require_search_text(query_message: ChatMessage) -> str: + return ( + query_message.message + + f"\n\nHint: respond with EXACTLY {YES_SEARCH} or {NO_SEARCH}" + ) + + +def form_tool_less_followup_text( + tool_output: str, + query: str, + hint_text: str | None, + tool_followup_prompt: str = TOOL_LESS_FOLLOWUP, +) -> str: + hint = f"Hint: {hint_text}" if hint_text else "" + return tool_followup_prompt.format( + context_str=tool_output, user_query=query, hint_text=hint + ).strip() diff --git a/backend/danswer/chat/personas.py b/backend/danswer/chat/personas.py index 9639c0c63bf2..227ce6123c90 100644 --- a/backend/danswer/chat/personas.py +++ b/backend/danswer/chat/personas.py @@ -1,10 +1,26 @@ +from typing import Any + import yaml from sqlalchemy.orm import Session -from danswer.chat.chat_prompts import form_tool_section_text from danswer.configs.app_configs import PERSONAS_YAML -from danswer.db.chat import create_persona +from danswer.db.chat import upsert_persona from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import ToolInfo + + +def validate_tool_info(item: Any) -> ToolInfo: + if not ( + isinstance(item, dict) + and "name" in item + and isinstance(item["name"], str) + and "description" in item + and isinstance(item["description"], str) + ): + raise ValueError( + "Invalid Persona configuration yaml Found, not all tools have name/description" + ) + return ToolInfo(name=item["name"], description=item["description"]) def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None: @@ -14,15 +30,14 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None: all_personas = data.get("personas", []) with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session: for persona in all_personas: - tools = form_tool_section_text( - persona["tools"], persona["retrieval_enabled"] - ) - create_persona( + tools = [validate_tool_info(tool) for tool in persona["tools"]] + + upsert_persona( persona_id=persona["id"], name=persona["name"], retrieval_enabled=persona["retrieval_enabled"], system_text=persona["system"], - tools_text=tools, + tools=tools, hint_text=persona["hint"], default_persona=True, db_session=db_session, diff --git a/backend/danswer/chat/personas.yaml b/backend/danswer/chat/personas.yaml index d0bb2f7d6a9d..4c956182bf69 100644 --- a/backend/danswer/chat/personas.yaml +++ b/backend/danswer/chat/personas.yaml @@ -7,7 +7,12 @@ personas: Your responses are as INFORMATIVE and DETAILED as possible. # Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled. retrieval_enabled: true - # Each added tool needs to have a "name" and "description" + # Example of adding tools, it must follow this structure: + # tools: + # - name: "Calculator" + # description: "Use this tool to accurately process math equations, counting, etc." + # - name: "Current Time" + # description: "Call this to get the current date and time." tools: [] # Short tip to pass near the end of the prompt to emphasize some requirement hint: "Try to be as informative as possible!" diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 1e2b57cccae7..616cffef76da 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -13,6 +13,7 @@ from danswer.configs.constants import MessageType from danswer.db.models import ChatMessage from danswer.db.models import ChatSession from danswer.db.models import Persona +from danswer.db.models import ToolInfo def fetch_chat_sessions_by_user( @@ -261,12 +262,12 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona: return persona -def create_persona( +def upsert_persona( persona_id: int | None, name: str, retrieval_enabled: bool, system_text: str | None, - tools_text: str | None, + tools: list[ToolInfo] | None, hint_text: str | None, default_persona: bool, db_session: Session, @@ -278,7 +279,7 @@ def create_persona( persona.name = name persona.retrieval_enabled = retrieval_enabled persona.system_text = system_text - persona.tools_text = tools_text + persona.tools = tools persona.hint_text = hint_text persona.default_persona = default_persona else: @@ -287,7 +288,7 @@ def create_persona( name=name, retrieval_enabled=retrieval_enabled, system_text=system_text, - tools_text=tools_text, + tools=tools, hint_text=hint_text, default_persona=default_persona, ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 5ad9d8b44cf3..07a2fb1dec17 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -439,6 +439,11 @@ class ChatSession(Base): ) +class ToolInfo(TypedDict): + name: str + description: str + + class Persona(Base): # TODO introduce user and group ownership for personas __tablename__ = "persona" @@ -447,7 +452,9 @@ class Persona(Base): # Danswer retrieval, treated as a special tool retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True) system_text: Mapped[str | None] = mapped_column(Text, nullable=True) - tools_text: Mapped[str | None] = mapped_column(Text, nullable=True) + tools: Mapped[list[ToolInfo] | None] = mapped_column( + postgresql.JSONB(), nullable=True + ) hint_text: Mapped[str | None] = mapped_column(Text, nullable=True) # Default personas are configured via backend during deployment # Treated specially (cannot be user edited etc.) diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index 3e03a1fc21fb..6ba73e87ccce 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from sqlalchemy import select from sqlalchemy.orm import Session -from danswer.db.chat import create_persona +from danswer.db.chat import upsert_persona from danswer.db.models import ChannelConfig from danswer.db.models import Persona from danswer.db.models import Persona__DocumentSet @@ -35,12 +35,12 @@ def _create_slack_bot_persona( """NOTE: does not commit changes""" # create/update persona associated with the slack bot persona_name = _build_persona_name(channel_names) - persona = create_persona( + persona = upsert_persona( persona_id=existing_persona_id, name=persona_name, retrieval_enabled=True, system_text=None, - tools_text=None, + tools=None, hint_text=None, default_persona=False, db_session=db_session,