From e549d2bb4a7e68a36e53207864b01dafc53841ec Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Fri, 15 Sep 2023 12:17:05 -0700 Subject: [PATCH] Chat with Context Backend (#441) --- .../8e26726b7683_chat_context_addition.py | 40 ++++ backend/danswer/chat/chat_llm.py | 152 +++++++++++- backend/danswer/chat/chat_prompts.py | 125 ++++++++++ backend/danswer/chat/personas.py | 26 +++ backend/danswer/chat/personas.yaml | 14 ++ backend/danswer/chat/tools.py | 43 ++++ backend/danswer/configs/app_configs.py | 6 +- backend/danswer/configs/constants.py | 14 ++ backend/danswer/db/chat.py | 45 ++++ backend/danswer/db/models.py | 18 ++ backend/danswer/direct_qa/interfaces.py | 7 + backend/danswer/direct_qa/qa_block.py | 13 +- backend/danswer/direct_qa/qa_prompts.py | 217 +----------------- backend/danswer/llm/utils.py | 19 ++ backend/danswer/main.py | 4 + .../secondary_llm_flows/answer_validation.py | 12 +- .../secondary_llm_flows/query_validation.py | 4 +- backend/danswer/server/chat_backend.py | 46 +++- backend/danswer/server/models.py | 9 +- backend/danswer/utils/text_processing.py | 16 ++ backend/requirements/default.txt | 2 +- backend/scripts/simulate_chat_frontend.py | 88 +++++++ 22 files changed, 672 insertions(+), 248 deletions(-) create mode 100644 backend/alembic/versions/8e26726b7683_chat_context_addition.py create mode 100644 backend/danswer/chat/chat_prompts.py create mode 100644 backend/danswer/chat/personas.py create mode 100644 backend/danswer/chat/personas.yaml create mode 100644 backend/danswer/chat/tools.py create mode 100644 backend/scripts/simulate_chat_frontend.py diff --git a/backend/alembic/versions/8e26726b7683_chat_context_addition.py b/backend/alembic/versions/8e26726b7683_chat_context_addition.py new file mode 100644 index 0000000000..b70fd1d4b1 --- /dev/null +++ b/backend/alembic/versions/8e26726b7683_chat_context_addition.py @@ -0,0 +1,40 @@ +"""Chat Context Addition + +Revision ID: 8e26726b7683 +Revises: 5809c0787398 +Create Date: 2023-09-13 18:34:31.327944 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = "8e26726b7683" +down_revision = "5809c0787398" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "persona", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("system_text", sa.Text(), nullable=True), + sa.Column("tools_text", sa.Text(), nullable=True), + sa.Column("hint_text", sa.Text(), nullable=True), + sa.Column("default_persona", sa.Boolean(), nullable=False), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.add_column("chat_message", sa.Column("persona_id", sa.Integer(), nullable=True)) + op.create_foreign_key( + "fk_chat_message_persona_id", "chat_message", "persona", ["persona_id"], ["id"] + ) + + +def downgrade() -> None: + op.drop_constraint("fk_chat_message_persona_id", "chat_message", type_="foreignkey") + op.drop_column("chat_message", "persona_id") + op.drop_table("persona") diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py index e496afe15a..246cd3255b 100644 --- a/backend/danswer/chat/chat_llm.py +++ b/backend/danswer/chat/chat_llm.py @@ -1,27 +1,155 @@ from collections.abc import Iterator +from uuid import UUID from langchain.schema.messages import AIMessage from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage -from danswer.configs.constants import MessageType +from danswer.chat.chat_prompts import form_tool_followup_text +from danswer.chat.chat_prompts import form_user_prompt_text +from danswer.chat.tools import call_tool from danswer.db.models import ChatMessage +from danswer.db.models import Persona +from danswer.direct_qa.interfaces import DanswerAnswerPiece +from danswer.direct_qa.interfaces import DanswerChatModelOut from danswer.llm.build import get_default_llm +from danswer.llm.utils import translate_danswer_msg_to_langchain +from danswer.utils.text_processing import extract_embedded_json +from danswer.utils.text_processing import has_unescaped_quote -def llm_chat_answer(previous_messages: list[ChatMessage]) -> Iterator[str]: - prompt: list[BaseMessage] = [] - for msg in previous_messages: - content = msg.message - if msg.message_type == MessageType.SYSTEM: - prompt.append(SystemMessage(content=content)) - if msg.message_type == MessageType.ASSISTANT: - prompt.append(AIMessage(content=content)) +def _parse_embedded_json_streamed_response( + tokens: Iterator[str], +) -> Iterator[DanswerAnswerPiece | DanswerChatModelOut]: + final_answer = False + just_start_stream = False + model_output = "" + hold = "" + finding_end = 0 + for token in tokens: + model_output += token + hold += token + if ( - msg.message_type == MessageType.USER - or msg.message_type == MessageType.DANSWER # consider using FunctionMessage + final_answer is False + and '"action":"finalanswer",' in model_output.lower().replace(" ", "") ): - prompt.append(HumanMessage(content=content)) + final_answer = True + + if final_answer and '"actioninput":"' in model_output.lower().replace( + " ", "" + ).replace("_", ""): + if not just_start_stream: + just_start_stream = True + hold = "" + + if has_unescaped_quote(hold): + finding_end += 1 + hold = hold[: hold.find('"')] + + if finding_end <= 1: + if finding_end == 1: + finding_end += 1 + + yield DanswerAnswerPiece(answer_piece=hold) + hold = "" + + model_final = extract_embedded_json(model_output) + if "action" not in model_final or "action_input" not in model_final: + raise 
ValueError("Model did not provide all required action values") + + yield DanswerChatModelOut( + model_raw=model_output, + action=model_final["action"], + action_input=model_final["action_input"], + ) + return + + +def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]: + prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages] return get_default_llm().stream(prompt) + + +def llm_contextual_chat_answer( + messages: list[ChatMessage], + persona: Persona, + user_id: UUID | None, +) -> Iterator[str]: + system_text = persona.system_text + tool_text = persona.tools_text + hint_text = persona.hint_text + + last_message = messages[-1] + previous_messages = messages[:-1] + + user_text = form_user_prompt_text( + query=last_message.message, + tool_text=tool_text, + hint_text=hint_text, + ) + + prompt: list[BaseMessage] = [] + + if system_text: + prompt.append(SystemMessage(content=system_text)) + + prompt.extend( + [translate_danswer_msg_to_langchain(msg) for msg in previous_messages] + ) + + prompt.append(HumanMessage(content=user_text)) + + tokens = get_default_llm().stream(prompt) + + final_result: DanswerChatModelOut | None = None + final_answer_streamed = False + for result in _parse_embedded_json_streamed_response(tokens): + if isinstance(result, DanswerAnswerPiece) and result.answer_piece: + yield result.answer_piece + final_answer_streamed = True + + if isinstance(result, DanswerChatModelOut): + final_result = result + break + + if final_answer_streamed: + return + + if final_result is None: + raise RuntimeError("Model output finished without final output parsing.") + + tool_result_str = call_tool(final_result, user_id=user_id) + + prompt.append(AIMessage(content=final_result.model_raw)) + prompt.append( + HumanMessage( + content=form_tool_followup_text(tool_result_str, hint_text=hint_text) + ) + ) + + tokens = get_default_llm().stream(prompt) + + for result in _parse_embedded_json_streamed_response(tokens): + if isinstance(result, DanswerAnswerPiece) and result.answer_piece: + yield result.answer_piece + final_answer_streamed = True + + if final_answer_streamed is False: + raise RuntimeError("LLM failed to produce a Final Answer") + + +def llm_chat_answer( + messages: list[ChatMessage], persona: Persona | None, user_id: UUID | None +) -> Iterator[str]: + # TODO how to handle model giving jibberish or fail on a particular message + # TODO how to handle model failing to choose the right tool + # TODO how to handle model gives wrong format + if persona is None: + return llm_contextless_chat_answer(messages) + + return llm_contextual_chat_answer( + messages=messages, persona=persona, user_id=user_id + ) diff --git a/backend/danswer/chat/chat_prompts.py b/backend/danswer/chat/chat_prompts.py new file mode 100644 index 0000000000..d4478f1190 --- /dev/null +++ b/backend/danswer/chat/chat_prompts.py @@ -0,0 +1,125 @@ +from danswer.chunking.models import InferenceChunk +from danswer.configs.constants import CODE_BLOCK_PAT + +TOOL_TEMPLATE = """TOOLS +------ +Assistant can ask the user to use tools to look up information that may be helpful in answering the users \ +original question. The tools the human can use are: + +{} + +RESPONSE FORMAT INSTRUCTIONS +---------------------------- + +When responding to me, please output a response in one of two formats: + +**Option 1:** +Use this if you want the human to use a tool. +Markdown code snippet formatted in the following schema: + +```json +{{ + "action": string, \\ The action to take. 
Must be one of {} + "action_input": string \\ The input to the action +}} +``` + +**Option #2:** +Use this if you want to respond directly to the human. Markdown code snippet formatted in the following schema: + +```json +{{ + "action": "Final Answer", + "action_input": string \\ You should put what you want to return to use here +}} +``` +""" + +TOOL_LESS_PROMPT = """ +Respond with a markdown code snippet in the following schema: + +```json +{{ + "action": "Final Answer", + "action_input": string \\ You should put what you want to return to use here +}} +``` +""" + +USER_INPUT = """ +USER'S INPUT +-------------------- +Here is the user's input \ +(remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else): + +{} +""" + +TOOL_FOLLOWUP = """ +TOOL RESPONSE: +--------------------- +{} + +USER'S INPUT +-------------------- +Okay, so what is the response to my last comment? If using information obtained from the tools you must \ +mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! +{} +IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else. +""" + + +def form_user_prompt_text( + query: str, + tool_text: str | None, + hint_text: str | None, + user_input_prompt: str = USER_INPUT, + tool_less_prompt: str = TOOL_LESS_PROMPT, +) -> str: + user_prompt = tool_text or tool_less_prompt + + user_prompt += user_input_prompt.format(query) + + if hint_text: + if user_prompt[-1] != "\n": + user_prompt += "\n" + user_prompt += "Hint: " + hint_text + + return user_prompt + + +def form_tool_section_text( + tools: list[dict[str, str]], template: str = TOOL_TEMPLATE +) -> str | None: + if not tools: + return None + + tools_intro = [] + for tool in tools: + description_formatted = tool["description"].replace("\n", " ") + tools_intro.append(f"> {tool['name']}: {description_formatted}") + + tools_intro_text = "\n".join(tools_intro) + tool_names_text = ", ".join([tool["name"] for tool in tools]) + + return template.format(tools_intro_text, tool_names_text) + + +def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str: + return "\n".join( + f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(chunk.content)}" + for ind, chunk in enumerate(chunks, start=1) + ) + + +def form_tool_followup_text( + tool_output: str, + hint_text: str | None, + tool_followup_prompt: str = TOOL_FOLLOWUP, + ignore_hint: bool = False, +) -> str: + if not ignore_hint and hint_text: + hint_text_spaced = f"\n{hint_text}\n" + return tool_followup_prompt.format(tool_output, hint_text_spaced) + + return tool_followup_prompt.format(tool_output, "") diff --git a/backend/danswer/chat/personas.py b/backend/danswer/chat/personas.py new file mode 100644 index 0000000000..a823c1917b --- /dev/null +++ b/backend/danswer/chat/personas.py @@ -0,0 +1,26 @@ +import yaml +from sqlalchemy.orm import Session + +from danswer.chat.chat_prompts import form_tool_section_text +from danswer.configs.app_configs import PERSONAS_YAML +from danswer.db.chat import create_persona +from danswer.db.engine import get_sqlalchemy_engine + + +def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None: + with open(personas_yaml, "r") as file: + data = yaml.safe_load(file) + + all_personas = data.get("personas", []) + with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session: + for persona in all_personas: + tools = form_tool_section_text(persona["tools"]) + create_persona( + persona_id=persona["id"], + 
name=persona["name"], + system_text=persona["system"], + tools_text=tools, + hint_text=persona["hint"], + default_persona=True, + db_session=db_session, + ) diff --git a/backend/danswer/chat/personas.yaml b/backend/danswer/chat/personas.yaml new file mode 100644 index 0000000000..fe12d8a0d0 --- /dev/null +++ b/backend/danswer/chat/personas.yaml @@ -0,0 +1,14 @@ +personas: + - id: 1 + name: "Danswer" + system: | + You are a question answering system that is constantly learning and improving. + You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries. + You have access to a search tool and should always use it anytime it might be useful. + Your responses are as INFORMATIVE and DETAILED as possible. + tools: + - name: "Current Search" + description: | + A search for up to date and proprietary information. + This tool can handle natural language questions so it is better to ask very precise questions over keywords. + hint: "Use the Current Search tool to help find information on any topic!" diff --git a/backend/danswer/chat/tools.py b/backend/danswer/chat/tools.py new file mode 100644 index 0000000000..94829c076c --- /dev/null +++ b/backend/danswer/chat/tools.py @@ -0,0 +1,43 @@ +from uuid import UUID + +from danswer.chat.chat_prompts import format_danswer_chunks_for_chat +from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT +from danswer.configs.constants import IGNORE_FOR_QA +from danswer.datastores.document_index import get_default_document_index +from danswer.direct_qa.interfaces import DanswerChatModelOut +from danswer.direct_qa.qa_utils import get_usable_chunks +from danswer.search.semantic_search import retrieve_ranked_documents + + +def call_tool( + model_actions: DanswerChatModelOut, + user_id: UUID | None, +) -> str: + if model_actions.action.lower() == "current search": + query = model_actions.action_input + + ranked_chunks, unranked_chunks = retrieve_ranked_documents( + query, + user_id=user_id, + filters=None, + datastore=get_default_document_index(), + ) + if not ranked_chunks: + return "No results found" + + if unranked_chunks: + ranked_chunks.extend(unranked_chunks) + + filtered_ranked_chunks = [ + chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA) + ] + + # get all chunks that fit into the token limit + usable_chunks = get_usable_chunks( + chunks=filtered_ranked_chunks, + token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT, + ) + + return format_danswer_chunks_for_chat(usable_chunks) + + raise ValueError("Invalid tool choice by LLM") diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index e52d29d688..889d823e52 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -144,9 +144,12 @@ NUM_RERANKED_RESULTS = 15 NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int( os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5) ) +NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int( + os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (512 * 3) +) # 1 edit per 2 characters, currently unused due to fuzzy match being too slow QUOTE_ALLOWED_ERROR_PERCENT = 0.05 -QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "10") # 10 seconds +QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds # Include additional document/chunk metadata in prompt to GenerativeAI INCLUDE_METADATA = False HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false" @@ -182,6 
+185,7 @@ CROSS_ENCODER_PORT = 9000 ##### # Miscellaneous ##### +PERSONAS_YAML = "./danswer/chat/personas.yaml" DYNAMIC_CONFIG_STORE = os.environ.get( "DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore" ) diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index d99cded4c5..5b641699f6 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -25,6 +25,20 @@ BOOST = "boost" SCORE = "score" DEFAULT_BOOST = 0 +# Prompt building constants: +GENERAL_SEP_PAT = "\n-----\n" +CODE_BLOCK_PAT = "\n```\n{}\n```\n" +DOC_SEP_PAT = "---NEW DOCUMENT---" +DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n" +QUESTION_PAT = "Query:" +THOUGHT_PAT = "Thought:" +ANSWER_PAT = "Answer:" +FINAL_ANSWER_PAT = "Final Answer:" +UNCERTAINTY_PAT = "?" +QUOTE_PAT = "Quote:" +QUOTES_PAT_PLURAL = "Quotes:" +INVALID_PAT = "Invalid:" + class DocumentSource(str, Enum): SLACK = "slack" diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 13230929be..73ec16991e 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -12,6 +12,7 @@ from danswer.configs.app_configs import HARD_DELETE_CHATS from danswer.configs.constants import MessageType from danswer.db.models import ChatMessage from danswer.db.models import ChatSession +from danswer.db.models import Persona def fetch_chat_sessions_by_user( @@ -245,3 +246,47 @@ def set_latest_chat_message( ) db_session.commit() + + +def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona: + stmt = select(Persona).where(Persona.id == persona_id) + result = db_session.execute(stmt) + persona = result.scalar_one_or_none() + + if persona is None: + raise ValueError(f"Persona with ID {persona_id} does not exist") + + return persona + + +def create_persona( + persona_id: int | None, + name: str, + system_text: str | None, + tools_text: str | None, + hint_text: str | None, + default_persona: bool, + db_session: Session, +) -> Persona: + persona = db_session.query(Persona).filter_by(id=persona_id).first() + + if persona: + persona.name = name + persona.system_text = system_text + persona.tools_text = tools_text + persona.hint_text = hint_text + persona.default_persona = default_persona + else: + persona = Persona( + id=persona_id, + name=name, + system_text=system_text, + tools_text=tools_text, + hint_text=hint_text, + default_persona=default_persona, + ) + db_session.add(persona) + + db_session.commit() + + return persona diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 9f5ca1fc8f..7712c250da 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -333,6 +333,7 @@ class ChatSession(Base): user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) description: Mapped[str] = mapped_column(Text) deleted: Mapped[bool] = mapped_column(Boolean, default=False) + # The following texts help build up the model's ability to use the context effectively time_updated: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -348,6 +349,19 @@ class ChatSession(Base): ) +class Persona(Base): + # TODO introduce user and group ownership for personas + __tablename__ = "persona" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String) + system_text: Mapped[str | None] = mapped_column(Text, nullable=True) + tools_text: Mapped[str | None] = mapped_column(Text, nullable=True) + hint_text: Mapped[str | None] = mapped_column(Text, 
nullable=True) + default_persona: Mapped[bool] = mapped_column(Boolean, default=False) + # If it's updated and no longer latest (should no longer be shown), it is also considered deleted + deleted: Mapped[bool] = mapped_column(Boolean, default=False) + + class ChatMessage(Base): __tablename__ = "chat_message" @@ -362,8 +376,12 @@ class ChatMessage(Base): latest: Mapped[bool] = mapped_column(Boolean, default=True) message: Mapped[str] = mapped_column(Text) message_type: Mapped[MessageType] = mapped_column(Enum(MessageType)) + persona_id: Mapped[int | None] = mapped_column( + ForeignKey("persona.id"), nullable=True + ) time_sent: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) chat_session: Mapped[ChatSession] = relationship("ChatSession") + persona: Mapped[Persona | None] = relationship("Persona") diff --git a/backend/danswer/direct_qa/interfaces.py b/backend/danswer/direct_qa/interfaces.py index 1fa5e3ea2d..9bd3cfc449 100644 --- a/backend/danswer/direct_qa/interfaces.py +++ b/backend/danswer/direct_qa/interfaces.py @@ -10,6 +10,13 @@ class DanswerAnswer: answer: str | None +@dataclass +class DanswerChatModelOut: + model_raw: str + action: str + action_input: str + + @dataclass class DanswerAnswerPiece: """A small piece of a complete answer. Used for streaming back answers.""" diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py index f6aee0f5ea..3cc2d5d62f 100644 --- a/backend/danswer/direct_qa/qa_block.py +++ b/backend/danswer/direct_qa/qa_block.py @@ -10,19 +10,18 @@ from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage from danswer.chunking.models import InferenceChunk +from danswer.configs.constants import CODE_BLOCK_PAT +from danswer.configs.constants import GENERAL_SEP_PAT +from danswer.configs.constants import QUESTION_PAT +from danswer.configs.constants import THOUGHT_PAT +from danswer.configs.constants import UNCERTAINTY_PAT from danswer.direct_qa.interfaces import AnswerQuestionReturn from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn from danswer.direct_qa.interfaces import DanswerAnswer from danswer.direct_qa.interfaces import DanswerQuotes from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT from danswer.direct_qa.qa_prompts import EMPTY_SAMPLE_JSON -from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT from danswer.direct_qa.qa_prompts import JsonChatProcessor -from danswer.direct_qa.qa_prompts import QUESTION_PAT -from danswer.direct_qa.qa_prompts import SAMPLE_JSON_RESPONSE -from danswer.direct_qa.qa_prompts import THOUGHT_PAT -from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor from danswer.direct_qa.qa_utils import process_answer from danswer.direct_qa.qa_utils import process_model_tokens @@ -193,7 +192,7 @@ class JsonChatQAUnshackledHandler(QAHandler): "should be in JSON format and contain an answer and (optionally) quotes that help support the answer. " "Your responses should be informative, detailed, and consider all possibilities and edge cases. 
" f"If you don't know the answer, respond with '{complete_answer_not_found_response}'\n" - f"Sample response:\n\n{json.dumps(SAMPLE_JSON_RESPONSE)}" + f"Sample response:\n\n{json.dumps(EMPTY_SAMPLE_JSON)}" ) ) ) diff --git a/backend/danswer/direct_qa/qa_prompts.py b/backend/danswer/direct_qa/qa_prompts.py index 5eb6813fbd..aa7d5d01d5 100644 --- a/backend/danswer/direct_qa/qa_prompts.py +++ b/backend/danswer/direct_qa/qa_prompts.py @@ -2,23 +2,17 @@ import abc import json from danswer.chunking.models import InferenceChunk +from danswer.configs.constants import ANSWER_PAT +from danswer.configs.constants import DOC_CONTENT_START_PAT +from danswer.configs.constants import DOC_SEP_PAT from danswer.configs.constants import DocumentSource +from danswer.configs.constants import GENERAL_SEP_PAT +from danswer.configs.constants import QUESTION_PAT +from danswer.configs.constants import QUOTE_PAT +from danswer.configs.constants import UNCERTAINTY_PAT from danswer.connectors.factory import identify_connector_class -GENERAL_SEP_PAT = "\n-----\n" -CODE_BLOCK_PAT = "\n```\n{}\n```\n" -DOC_SEP_PAT = "---NEW DOCUMENT---" -DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n" -QUESTION_PAT = "Query:" -THOUGHT_PAT = "Thought:" -ANSWER_PAT = "Answer:" -FINAL_ANSWER_PAT = "Final Answer:" -UNCERTAINTY_PAT = "?" -QUOTE_PAT = "Quote:" -QUOTES_PAT_PLURAL = "Quotes:" -INVALID_PAT = "Invalid:" - BASE_PROMPT = ( "Answer the query based on provided documents and quote relevant sections. " "Respond with a json containing a concise answer and up to three most relevant quotes from the documents. " @@ -26,16 +20,6 @@ BASE_PROMPT = ( "The quotes must be EXACT substrings from the documents." ) -SAMPLE_QUESTION = "Where is the Eiffel Tower?" - -SAMPLE_JSON_RESPONSE = { - "answer": "The Eiffel Tower is located in Paris, France.", - "quotes": [ - "The Eiffel Tower is an iconic symbol of Paris", - "located on the Champ de Mars in France.", - ], -} - EMPTY_SAMPLE_JSON = { "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.", "quotes": [ @@ -44,16 +28,6 @@ EMPTY_SAMPLE_JSON = { ], } -ANSWER_NOT_FOUND_JSON = '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}' - -SAMPLE_RESPONSE_COT = ( - "Let's think step by step. The user is asking for the " - "location of the Eiffel Tower. The first document describes the Eiffel Tower " - "as being an iconic symbol of Paris and that it is located on the Champ de Mars. " - "Since the Champ de Mars is in Paris, we know that the Eiffel Tower is in Paris." 
- f"\n\n{json.dumps(SAMPLE_JSON_RESPONSE)}" -) - def _append_acknowledge_doc_messages( current_messages: list[dict[str, str]], new_chunk_content: str @@ -152,7 +126,7 @@ class JsonProcessor(NonChatPromptProcessor): question: str, chunks: list[InferenceChunk], include_metadata: bool = False ) -> str: prompt = ( - BASE_PROMPT + f" Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n" + BASE_PROMPT + f" Sample response:\n{json.dumps(EMPTY_SAMPLE_JSON)}\n\n" f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n' ) @@ -203,7 +177,7 @@ class JsonChatProcessor(ChatPromptProcessor): f"{complete_answer_not_found_response}\n" "If the query requires aggregating the number of documents, respond with " '{"answer": "Aggregations not supported", "quotes": []}\n' - f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}" + f"Sample response:\n{json.dumps(EMPTY_SAMPLE_JSON)}" ) messages = [{"role": "system", "content": intro_msg}] for chunk in chunks: @@ -221,65 +195,6 @@ class JsonChatProcessor(ChatPromptProcessor): return messages -class JsonCoTChatProcessor(ChatPromptProcessor): - """Pros: improves performance slightly over the regular JsonChatProcessor. - Cons: Much slower. - """ - - @property - def specifies_json_output(self) -> bool: - return True - - @staticmethod - def fill_prompt( - question: str, - chunks: list[InferenceChunk], - include_metadata: bool = True, - ) -> list[dict[str, str]]: - metadata_prompt_section = ( - "with metadata and contents " if include_metadata else "" - ) - intro_msg = ( - f"You are a Question Answering assistant that answers queries " - f"based on the provided documents.\n" - f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".' - ) - - complete_answer_not_found_response = ( - '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}' - ) - task_msg = ( - "Now answer the user query based on documents above and quote relevant sections.\n" - "When answering, you should think step by step, and verbalize your thought process.\n" - "Then respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n" - "All quotes MUST be EXACT substrings from provided documents.\n" - "Your responses should be informative, detailed, and consider all possibilities and edge cases.\n" - "You MUST prioritize information from provided documents over internal knowledge.\n" - "If the query cannot be answered based on the documents, respond with " - f"{complete_answer_not_found_response}\n" - "If the query requires aggregating the number of documents, respond with " - '{"answer": "Aggregations not supported", "quotes": []}\n' - f"Sample response:\n\n{SAMPLE_RESPONSE_COT}" - ) - messages = [{"role": "system", "content": intro_msg}] - - for chunk in chunks: - full_context = "" - if include_metadata: - full_context = _add_metadata_section( - full_context, chunk, prepend_tab=False, include_sep=False - ) - full_context += chunk.content - messages = _append_acknowledge_doc_messages(messages, full_context) - messages.append({"role": "user", "content": task_msg}) - - messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n\n"}) - - messages.append({"role": "user", "content": "Let's think step by step."}) - - return messages - - class WeakModelFreeformProcessor(NonChatPromptProcessor): """Avoid using this one if the model is capable of using another prompt Intended for models that can't follow complex instructions or have short context windows @@ -366,117 +281,3 @@ class 
FreeformProcessor(NonChatPromptProcessor): prompt += f"{QUESTION_PAT}\n{question}\n" prompt += f"{ANSWER_PAT}\n" return prompt - - -class FreeformChatProcessor(ChatPromptProcessor): - @property - def specifies_json_output(self) -> bool: - return False - - @staticmethod - def fill_prompt( - question: str, chunks: list[InferenceChunk], include_metadata: bool = False - ) -> list[dict[str, str]]: - sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each." - role_msg = ( - f"You are a Question Answering assistant that answers queries based on provided documents. " - f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and ' - f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. ' - f"Answer the question directly and concisely. " - f"Each quote should be a single continuous segment from a document. " - f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". ' - f"An example of quote sections may look like:\n{sample_quote}" - ) - - messages = [ - {"role": "system", "content": role_msg}, - ] - for chunk in chunks: - messages = _append_acknowledge_doc_messages(messages, chunk.content) - - messages.append( - { - "role": "user", - "content": f"Please now answer the following query based on the previously provided " - f"documents and quote the relevant sections of the documents\n{question}", - }, - ) - - return messages - - -class JsonCOTProcessor(NonChatPromptProcessor): - """Chain of Thought allows model to explain out its reasoning to handle harder tests. - This prompt type works however has higher token cost (more expensive) and is slower. - Consider this one if users ask questions that require logical reasoning.""" - - @property - def specifies_json_output(self) -> bool: - return True - - @staticmethod - def fill_prompt( - question: str, chunks: list[InferenceChunk], include_metadata: bool = False - ) -> str: - prompt = ( - f"Answer the query based on provided documents and quote relevant sections. " - f'Respond with a freeform reasoning section followed by "Final Answer:" with a ' - f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n" - f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n" - f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n' - ) - - for chunk in chunks: - prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}" - - prompt += "\n\n---\n\n" - prompt += f"{QUESTION_PAT}\n{question}\n" - prompt += "Reasoning:\n" - return prompt - - -class JsonReflexionProcessor(NonChatPromptProcessor): - """Reflexion prompting to attempt to have model evaluate its own answer. - This one seems largely useless when only given a single example - Model seems to take the one example of answering "Yes" and just does that too.""" - - @property - def specifies_json_output(self) -> bool: - return True - - @staticmethod - def fill_prompt( - question: str, chunks: list[InferenceChunk], include_metadata: bool = False - ) -> str: - reflexion_str = "Does this fully answer the user query?" - prompt = ( - BASE_PROMPT - + f'After each generated json, ask "{reflexion_str}" and respond Yes or No. 
' - f"If No, generate a better json response to the query.\n" - f"Sample question and response:\n" - f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n" - f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n" - f"{reflexion_str} Yes\n\n" - f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n' - ) - - for chunk in chunks: - prompt += f"\n---NEW CONTEXT DOCUMENT---\n{chunk.content}" - - prompt += "\n\n---\n\n" - prompt += f"{QUESTION_PAT}\n{question}\n" - return prompt - - -def get_json_chat_reflexion_msg() -> dict[str, str]: - """With the models tried (curent as of Jul 2023), this has not been very useful. - Have not seen any answers improved based on this. - For models like gpt-3.5-turbo, it will often answer something like: - 'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source - document. No revision is needed.'""" - reflexion_content = ( - "Is the assistant response a valid json that fully answer the user query? " - "If the response needs to be fixed or if an improvement is possible, provide a revised json. " - "Otherwise, respond with the same json." - ) - return {"role": "system", "content": reflexion_content} diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index fba529a5dc..f3ddd384a0 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -13,6 +13,25 @@ from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage from danswer.configs.app_configs import LOG_LEVEL +from danswer.configs.constants import MessageType +from danswer.db.models import ChatMessage + + +def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage: + if ( + msg.message_type == MessageType.SYSTEM + or msg.message_type == MessageType.DANSWER + ): + # TODO save at least the Danswer responses to postgres + raise ValueError( + "System and Danswer messages are not currently part of history" + ) + if msg.message_type == MessageType.ASSISTANT: + return AIMessage(content=msg.message) + if msg.message_type == MessageType.USER: + return HumanMessage(content=msg.message) + + raise ValueError(f"New message type {msg.message_type} not handled") def dict_based_prompt_to_langchain_prompt( diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 845fc95bfb..3430f9f009 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -12,6 +12,7 @@ from danswer.auth.schemas import UserUpdate from danswer.auth.users import auth_backend from danswer.auth.users import fastapi_users from danswer.auth.users import oauth_client +from danswer.chat.personas import load_personas_from_yaml from danswer.configs.app_configs import APP_HOST from danswer.configs.app_configs import APP_PORT from danswer.configs.app_configs import DISABLE_AUTH @@ -191,6 +192,9 @@ def get_application() -> FastAPI: logger.info("Verifying public credential exists.") create_initial_public_credential() + logger.info("Loading default Chat Personas") + load_personas_from_yaml() + logger.info("Verifying Document Index(s) is/are available.") get_default_document_index().ensure_indices_exist() diff --git a/backend/danswer/secondary_llm_flows/answer_validation.py b/backend/danswer/secondary_llm_flows/answer_validation.py index 6fe74aa4a2..345530b11d 100644 --- a/backend/danswer/secondary_llm_flows/answer_validation.py +++ b/backend/danswer/secondary_llm_flows/answer_validation.py @@ -1,10 +1,10 @@ +from danswer.configs.constants import ANSWER_PAT +from danswer.configs.constants import CODE_BLOCK_PAT +from 
danswer.configs.constants import GENERAL_SEP_PAT +from danswer.configs.constants import INVALID_PAT +from danswer.configs.constants import QUESTION_PAT +from danswer.configs.constants import THOUGHT_PAT from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt -from danswer.direct_qa.qa_prompts import ANSWER_PAT -from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT -from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT -from danswer.direct_qa.qa_prompts import INVALID_PAT -from danswer.direct_qa.qa_prompts import QUESTION_PAT -from danswer.direct_qa.qa_prompts import THOUGHT_PAT from danswer.llm.build import get_default_llm from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index 07f908c4b8..8c85cf512e 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -2,10 +2,10 @@ import re from collections.abc import Iterator from dataclasses import asdict +from danswer.configs.constants import CODE_BLOCK_PAT +from danswer.configs.constants import GENERAL_SEP_PAT from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt -from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT -from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT from danswer.llm.build import get_default_llm from danswer.server.models import QueryValidationResponse from danswer.server.utils import get_json_line diff --git a/backend/danswer/server/chat_backend.py b/backend/danswer/server/chat_backend.py index b74d083de1..820854eb6e 100644 --- a/backend/danswer/server/chat_backend.py +++ b/backend/danswer/server/chat_backend.py @@ -16,6 +16,7 @@ from danswer.db.chat import fetch_chat_message from danswer.db.chat import fetch_chat_messages_by_session from danswer.db.chat import fetch_chat_session_by_id from danswer.db.chat import fetch_chat_sessions_by_user +from danswer.db.chat import fetch_persona_by_id from danswer.db.chat import set_latest_chat_message from danswer.db.chat import update_chat_session from danswer.db.chat import verify_parent_exists @@ -29,8 +30,9 @@ from danswer.server.models import ChatMessageIdentifier from danswer.server.models import ChatRenameRequest from danswer.server.models import ChatSessionDetailResponse from danswer.server.models import ChatSessionIdsResponse -from danswer.server.models import CreateChatID -from danswer.server.models import CreateChatRequest +from danswer.server.models import CreateChatMessageRequest +from danswer.server.models import CreateChatSessionID +from danswer.server.models import RegenerateMessageRequest from danswer.server.models import RenameChatSessionResponse from danswer.server.utils import get_json_line from danswer.utils.logger import setup_logger @@ -108,14 +110,16 @@ def get_chat_session_messages( def create_new_chat_session( user: User | None = Depends(current_user), db_session: Session = Depends(get_session), -) -> CreateChatID: +) -> CreateChatSessionID: user_id = user.id if user is not None else None new_chat_session = create_chat_session( - "", user_id, db_session # Leave the naming till later to prevent delay + "", + user_id, + db_session, # Leave the naming till later to prevent delay ) - return CreateChatID(chat_session_id=new_chat_session.id) + return CreateChatSessionID(chat_session_id=new_chat_session.id) 
@router.put("/rename-chat-session") @@ -182,9 +186,19 @@ def _create_chat_chain( return mainline_messages +def _return_one_if_any(str_1: str | None, str_2: str | None) -> str | None: + if str_1 is not None and str_2 is not None: + raise ValueError("Conflicting values, can only set one") + if str_1 is not None: + return str_1 + if str_2 is not None: + return str_2 + return None + + @router.post("/send-message") def handle_new_chat_message( - chat_message: CreateChatRequest, + chat_message: CreateChatMessageRequest, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> StreamingResponse: @@ -198,6 +212,11 @@ def handle_new_chat_message( user_id = user.id if user is not None else None chat_session = fetch_chat_session_by_id(chat_session_id, db_session) + persona = ( + fetch_persona_by_id(chat_message.persona_id, db_session) + if chat_message.persona_id is not None + else None + ) if chat_session.deleted: raise ValueError("Cannot send messages to a deleted chat session") @@ -248,7 +267,9 @@ def handle_new_chat_message( @log_generator_function_time() def stream_chat_tokens() -> Iterator[str]: - tokens = llm_chat_answer(mainline_messages) + tokens = llm_chat_answer( + messages=mainline_messages, persona=persona, user_id=user_id + ) llm_output = "" for token in tokens: llm_output += token @@ -268,7 +289,7 @@ def handle_new_chat_message( @router.post("/regenerate-from-parent") def regenerate_message_given_parent( - parent_message: ChatMessageIdentifier, + parent_message: RegenerateMessageRequest, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> StreamingResponse: @@ -288,6 +309,11 @@ def regenerate_message_given_parent( ) chat_session = chat_message.chat_session + persona = ( + fetch_persona_by_id(parent_message.persona_id, db_session) + if parent_message.persona_id is not None + else None + ) if chat_session.deleted: raise ValueError("Chat session has been deleted") @@ -317,7 +343,9 @@ def regenerate_message_given_parent( @log_generator_function_time() def stream_regenerate_tokens() -> Iterator[str]: - tokens = llm_chat_answer(mainline_messages) + tokens = llm_chat_answer( + messages=mainline_messages, persona=persona, user_id=user_id + ) llm_output = "" for token in tokens: llm_output += token diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 6f14c9a3bd..6cca836ff9 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -134,7 +134,7 @@ class SearchDoc(BaseModel): match_highlights: list[str] -class CreateChatID(BaseModel): +class CreateChatSessionID(BaseModel): chat_session_id: int @@ -159,11 +159,12 @@ class SearchFeedbackRequest(BaseModel): search_feedback: SearchFeedbackType -class CreateChatRequest(BaseModel): +class CreateChatMessageRequest(BaseModel): chat_session_id: int message_number: int parent_edit_number: int | None message: str + persona_id: int | None class ChatMessageIdentifier(BaseModel): @@ -172,6 +173,10 @@ class ChatMessageIdentifier(BaseModel): edit_number: int +class RegenerateMessageRequest(ChatMessageIdentifier): + persona_id: int | None + + class ChatRenameRequest(BaseModel): chat_session_id: int name: str | None diff --git a/backend/danswer/utils/text_processing.py b/backend/danswer/utils/text_processing.py index 696f13391e..c5e28ff2d5 100644 --- a/backend/danswer/utils/text_processing.py +++ b/backend/danswer/utils/text_processing.py @@ -1,13 +1,29 @@ +import json import re import bs4 from bs4 import BeautifulSoup +def 
has_unescaped_quote(s: str) -> bool:
+    pattern = r'(?<!\\)"'
+    return bool(re.search(pattern, s))
+
+
 def escape_newlines(s: str) -> str:
     return re.sub(r"(?<!\\)\n", "\\\\n", s)
 
 
+def extract_embedded_json(s: str) -> dict:
+    first_brace_index = s.find("{")
+    last_brace_index = s.rfind("}")
+
+    if first_brace_index == -1 or last_brace_index == -1:
+        raise ValueError("No valid json found")
+
+    return json.loads(s[first_brace_index : last_brace_index + 1])
+
+
 def clean_up_code_blocks(model_out_raw: str) -> str:
     return model_out_raw.strip().strip("```").strip()
diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index 245153f2a9..5921b4888a 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -46,7 +46,7 @@ safetensors==0.3.1
 sentence-transformers==2.2.2
 slack-sdk==3.20.2
 SQLAlchemy[mypy]==2.0.12
-tensorflow==2.12.0
+tensorflow==2.13.0
 tiktoken==0.4.0
 transformers==4.30.1
 typesense==0.15.1
diff --git a/backend/scripts/simulate_chat_frontend.py b/backend/scripts/simulate_chat_frontend.py
new file mode 100644
index 0000000000..f671dff7a8
--- /dev/null
+++ b/backend/scripts/simulate_chat_frontend.py
@@ -0,0 +1,88 @@
+# This file is purely for development use, not included in any builds
+# Use this to test the chat feature with and without context.
+# With context refers to being able to call out to Danswer and other tools (currently no other tools)
+# Without context refers to only knowing the chat's own history with no additional information
+# This script does not allow for branching logic that is supported by the backend APIs
+# This script also does not allow for editing/regeneration of user/model messages
+# Have the Danswer API server running to use this.
+import argparse
+import json
+
+import requests
+
+from danswer.configs.app_configs import APP_PORT
+
+LOCAL_CHAT_ENDPOINT = f"http://127.0.0.1:{APP_PORT}/chat/"
+
+
+def create_new_session() -> int:
+    response = requests.post(LOCAL_CHAT_ENDPOINT + "create-chat-session")
+    response.raise_for_status()
+    new_session_id = response.json()["chat_session_id"]
+    return new_session_id
+
+
+def send_chat_message(
+    message: str,
+    chat_session_id: int,
+    message_number: int,
+    parent_edit_number: int | None,
+    persona_id: int | None,
+) -> None:
+    data = {
+        "message": message,
+        "chat_session_id": chat_session_id,
+        "message_number": message_number,
+        "parent_edit_number": parent_edit_number,
+        "persona_id": persona_id,
+    }
+
+    with requests.post(
+        LOCAL_CHAT_ENDPOINT + "send-message", json=data, stream=True
+    ) as r:
+        for json_response in r.iter_lines():
+            response_text = json.loads(json_response.decode())
+            new_token = response_text.get("answer_piece")
+            print(new_token, end="", flush=True)
+        print()
+
+
+def run_chat(contextual: bool) -> None:
+    try:
+        new_session_id = create_new_session()
+    except requests.exceptions.ConnectionError:
+        print(
+            "Looks like you haven't started the Danswer backend server. Please run the FastAPI server first."
+        )
+        exit()
+
+    persona_id = 1 if contextual else None
+
+    message_num = 0
+    parent_edit = None
+    while True:
+        new_message = input(
+            "\n\n----------------------------------\n"
+            "Please provide a new chat message:\n> "
+        )
+
+        send_chat_message(
+            new_message, new_session_id, message_num, parent_edit, persona_id
+        )
+
+        message_num += 2  # 1 for User message, 1 for AI response
+        parent_edit = 0  # Since no edits, the parent message is always the first edit of that message number
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c",
+        "--contextual",
+        action="store_true",
+        help="If this flag is set, the chat is able to call tools.",
+    )
+    args = parser.parse_args()
+
+    contextual = args.contextual
+    run_chat(contextual)
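
---

Reviewer's note, not part of the patch: the streaming parser added in chat_llm.py can be smoke-tested without a live LLM by feeding _parse_embedded_json_streamed_response a canned model response one character at a time. The sketch below only assumes the backend package is importable; the canned text mirrors the "Final Answer" schema defined in chat_prompts.py.

    # Sketch: drive the embedded-JSON stream parser with fake one-char "tokens".
    from danswer.chat.chat_llm import _parse_embedded_json_streamed_response
    from danswer.direct_qa.interfaces import DanswerAnswerPiece
    from danswer.direct_qa.interfaces import DanswerChatModelOut

    raw_model_response = (
        '```json\n'
        '{\n'
        '    "action": "Final Answer",\n'
        '    "action_input": "The Eiffel Tower is in Paris."\n'
        '}\n'
        '```'
    )

    # Iterating a string yields single characters, a worst-case token stream.
    for piece in _parse_embedded_json_streamed_response(iter(raw_model_response)):
        if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece:
            print(piece.answer_piece, end="", flush=True)  # streamed answer text
        elif isinstance(piece, DanswerChatModelOut):
            # Full parse of the embedded JSON after the stream is exhausted.
            print("\naction:", piece.action)
            print("action_input:", piece.action_input)

For the end-to-end tool-calling path, pair this with the new script above: python backend/scripts/simulate_chat_frontend.py --contextual, which sends persona_id 1, the default "Danswer" persona from personas.yaml.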