Chat with Context Backend (#441)

This commit is contained in:
Yuhong Sun 2023-09-15 12:17:05 -07:00 committed by GitHub
parent a16ce56f6b
commit e549d2bb4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 672 additions and 248 deletions

View File

@ -0,0 +1,40 @@
"""Chat Context Addition
Revision ID: 8e26726b7683
Revises: 5809c0787398
Create Date: 2023-09-13 18:34:31.327944
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8e26726b7683"
down_revision = "5809c0787398"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"persona",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("system_text", sa.Text(), nullable=True),
sa.Column("tools_text", sa.Text(), nullable=True),
sa.Column("hint_text", sa.Text(), nullable=True),
sa.Column("default_persona", sa.Boolean(), nullable=False),
sa.Column("deleted", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("chat_message", sa.Column("persona_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"fk_chat_message_persona_id", "chat_message", "persona", ["persona_id"], ["id"]
)
def downgrade() -> None:
op.drop_constraint("fk_chat_message_persona_id", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "persona_id")
op.drop_table("persona")

View File

@ -1,27 +1,155 @@
from collections.abc import Iterator
from uuid import UUID
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chat.chat_prompts import form_tool_followup_text
from danswer.chat.chat_prompts import form_user_prompt_text
from danswer.chat.tools import call_tool
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.llm.build import get_default_llm
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.text_processing import has_unescaped_quote
def _parse_embedded_json_streamed_response(
    tokens: Iterator[str],
) -> Iterator[DanswerAnswerPiece | DanswerChatModelOut]:
    final_answer = False
    just_start_stream = False
    model_output = ""
    hold = ""
    finding_end = 0
    for token in tokens:
        model_output += token
        hold += token
        if (
            final_answer is False
            and '"action":"finalanswer",' in model_output.lower().replace(" ", "")
        ):
            final_answer = True
        if final_answer and '"actioninput":"' in model_output.lower().replace(
            " ", ""
        ).replace("_", ""):
            if not just_start_stream:
                just_start_stream = True
                hold = ""
            if has_unescaped_quote(hold):
                finding_end += 1
                hold = hold[: hold.find('"')]
            if finding_end <= 1:
                if finding_end == 1:
                    finding_end += 1
                yield DanswerAnswerPiece(answer_piece=hold)
                hold = ""
    model_final = extract_embedded_json(model_output)
    if "action" not in model_final or "action_input" not in model_final:
        raise ValueError("Model did not provide all required action values")
    yield DanswerChatModelOut(
        model_raw=model_output,
        action=model_final["action"],
        action_input=model_final["action_input"],
    )
    return
def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]:
prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages]
return get_default_llm().stream(prompt)
def llm_contextual_chat_answer(
messages: list[ChatMessage],
persona: Persona,
user_id: UUID | None,
) -> Iterator[str]:
system_text = persona.system_text
tool_text = persona.tools_text
hint_text = persona.hint_text
last_message = messages[-1]
previous_messages = messages[:-1]
user_text = form_user_prompt_text(
query=last_message.message,
tool_text=tool_text,
hint_text=hint_text,
)
prompt: list[BaseMessage] = []
if system_text:
prompt.append(SystemMessage(content=system_text))
prompt.extend(
[translate_danswer_msg_to_langchain(msg) for msg in previous_messages]
)
prompt.append(HumanMessage(content=user_text))
tokens = get_default_llm().stream(prompt)
final_result: DanswerChatModelOut | None = None
final_answer_streamed = False
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
yield result.answer_piece
final_answer_streamed = True
if isinstance(result, DanswerChatModelOut):
final_result = result
break
if final_answer_streamed:
return
if final_result is None:
raise RuntimeError("Model output finished without final output parsing.")
tool_result_str = call_tool(final_result, user_id=user_id)
prompt.append(AIMessage(content=final_result.model_raw))
prompt.append(
HumanMessage(
content=form_tool_followup_text(tool_result_str, hint_text=hint_text)
)
)
tokens = get_default_llm().stream(prompt)
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
yield result.answer_piece
final_answer_streamed = True
if final_answer_streamed is False:
raise RuntimeError("LLM failed to produce a Final Answer")
def llm_chat_answer(
messages: list[ChatMessage], persona: Persona | None, user_id: UUID | None
) -> Iterator[str]:
    # TODO how to handle the model giving gibberish or failing on a particular message
    # TODO how to handle the model failing to choose the right tool
    # TODO how to handle the model giving a wrong format
if persona is None:
return llm_contextless_chat_answer(messages)
return llm_contextual_chat_answer(
messages=messages, persona=persona, user_id=user_id
)
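A minimal sketch of how `_parse_embedded_json_streamed_response` behaves on a token stream; the canned model output and the 4-character "tokens" are made up for illustration, and importing the module-private helper is purely for demonstration:

```python
from danswer.chat.chat_llm import _parse_embedded_json_streamed_response
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerChatModelOut

raw = '{"action": "Final Answer", "action_input": "Paris, France"}'
tokens = (raw[i : i + 4] for i in range(0, len(raw), 4))  # fake LLM token stream

for result in _parse_embedded_json_streamed_response(tokens):
    if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
        # "Paris, France" streams out piece by piece, closing quote stripped
        print(result.answer_piece, end="")
    elif isinstance(result, DanswerChatModelOut):
        print(f"\naction={result.action!r} input={result.action_input!r}")
```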

View File

@ -0,0 +1,125 @@
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import CODE_BLOCK_PAT
TOOL_TEMPLATE = """TOOLS
------
Assistant can ask the user to use tools to look up information that may be helpful in answering the user's \
original question. The tools the human can use are:
{}
RESPONSE FORMAT INSTRUCTIONS
----------------------------
When responding to me, please output a response in one of two formats:
**Option 1:**
Use this if you want the human to use a tool.
Markdown code snippet formatted in the following schema:
```json
{{
"action": string, \\ The action to take. Must be one of {}
"action_input": string \\ The input to the action
}}
```
**Option 2:**
Use this if you want to respond directly to the human. Markdown code snippet formatted in the following schema:
```json
{{
"action": "Final Answer",
"action_input": string \\ You should put what you want to return to use here
}}
```
"""
TOOL_LESS_PROMPT = """
Respond with a markdown code snippet in the following schema:
```json
{{
"action": "Final Answer",
"action_input": string \\ You should put what you want to return to use here
}}
```
"""
USER_INPUT = """
USER'S INPUT
--------------------
Here is the user's input \
(remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
{}
"""
TOOL_FOLLOWUP = """
TOOL RESPONSE:
---------------------
{}
USER'S INPUT
--------------------
Okay, so what is the response to my last comment? If using information obtained from the tools you must \
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
{}
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
"""
def form_user_prompt_text(
query: str,
tool_text: str | None,
hint_text: str | None,
user_input_prompt: str = USER_INPUT,
tool_less_prompt: str = TOOL_LESS_PROMPT,
) -> str:
user_prompt = tool_text or tool_less_prompt
user_prompt += user_input_prompt.format(query)
if hint_text:
if user_prompt[-1] != "\n":
user_prompt += "\n"
user_prompt += "Hint: " + hint_text
return user_prompt
def form_tool_section_text(
tools: list[dict[str, str]], template: str = TOOL_TEMPLATE
) -> str | None:
if not tools:
return None
tools_intro = []
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
tools_intro_text = "\n".join(tools_intro)
tool_names_text = ", ".join([tool["name"] for tool in tools])
return template.format(tools_intro_text, tool_names_text)
def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
return "\n".join(
f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(chunk.content)}"
for ind, chunk in enumerate(chunks, start=1)
)
def form_tool_followup_text(
tool_output: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_FOLLOWUP,
ignore_hint: bool = False,
) -> str:
if not ignore_hint and hint_text:
hint_text_spaced = f"\n{hint_text}\n"
return tool_followup_prompt.format(tool_output, hint_text_spaced)
return tool_followup_prompt.format(tool_output, "")
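To see how these templates compose, a quick sketch; the tool description, query, and hint strings are invented for illustration:

```python
from danswer.chat.chat_prompts import form_tool_section_text
from danswer.chat.chat_prompts import form_user_prompt_text

tools = [{"name": "Current Search", "description": "A search for up to date\ninformation."}]
tool_text = form_tool_section_text(tools)  # TOOLS list + response-format instructions

prompt = form_user_prompt_text(
    query="Where is the Eiffel Tower?",  # hypothetical user message
    tool_text=tool_text,
    hint_text="Ask precise natural language questions rather than keywords.",
)
print(prompt)  # tool section, then USER'S INPUT with the query, then "Hint: ..."
```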

View File

@ -0,0 +1,26 @@
import yaml
from sqlalchemy.orm import Session
from danswer.chat.chat_prompts import form_tool_section_text
from danswer.configs.app_configs import PERSONAS_YAML
from danswer.db.chat import create_persona
from danswer.db.engine import get_sqlalchemy_engine
def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
with open(personas_yaml, "r") as file:
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
for persona in all_personas:
tools = form_tool_section_text(persona["tools"])
create_persona(
persona_id=persona["id"],
name=persona["name"],
system_text=persona["system"],
tools_text=tools,
hint_text=persona["hint"],
default_persona=True,
db_session=db_session,
)

View File

@ -0,0 +1,14 @@
personas:
- id: 1
name: "Danswer"
system: |
You are a question answering system that is constantly learning and improving.
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
You have access to a search tool and should always use it anytime it might be useful.
Your responses are as INFORMATIVE and DETAILED as possible.
tools:
- name: "Current Search"
description: |
A search for up to date and proprietary information.
This tool can handle natural language questions so it is better to ask very precise questions over keywords.
hint: "Use the Current Search tool to help find information on any topic!"

View File

@ -0,0 +1,43 @@
from uuid import UUID
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.semantic_search import retrieve_ranked_documents
def call_tool(
model_actions: DanswerChatModelOut,
user_id: UUID | None,
) -> str:
if model_actions.action.lower() == "current search":
query = model_actions.action_input
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query,
user_id=user_id,
filters=None,
datastore=get_default_document_index(),
)
if not ranked_chunks:
return "No results found"
if unranked_chunks:
ranked_chunks.extend(unranked_chunks)
filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
)
return format_danswer_chunks_for_chat(usable_chunks)
raise ValueError("Invalid tool choice by LLM")

View File

@ -144,9 +144,12 @@ NUM_RERANKED_RESULTS = 15
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
)
NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (512 * 3)
)
# 1 edit per 2 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "10") # 10 seconds
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
@ -182,6 +185,7 @@ CROSS_ENCODER_PORT = 9000
#####
# Miscellaneous
#####
PERSONAS_YAML = "./danswer/chat/personas.yaml"
DYNAMIC_CONFIG_STORE = os.environ.get(
"DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
)

View File

@ -25,6 +25,20 @@ BOOST = "boost"
SCORE = "score"
DEFAULT_BOOST = 0
# Prompt building constants:
GENERAL_SEP_PAT = "\n-----\n"
CODE_BLOCK_PAT = "\n```\n{}\n```\n"
DOC_SEP_PAT = "---NEW DOCUMENT---"
DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
QUESTION_PAT = "Query:"
THOUGHT_PAT = "Thought:"
ANSWER_PAT = "Answer:"
FINAL_ANSWER_PAT = "Final Answer:"
UNCERTAINTY_PAT = "?"
QUOTE_PAT = "Quote:"
QUOTES_PAT_PLURAL = "Quotes:"
INVALID_PAT = "Invalid:"
class DocumentSource(str, Enum):
SLACK = "slack"

View File

@ -12,6 +12,7 @@ from danswer.configs.app_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import Persona
def fetch_chat_sessions_by_user(
@ -245,3 +246,47 @@ def set_latest_chat_message(
)
db_session.commit()
def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id)
result = db_session.execute(stmt)
persona = result.scalar_one_or_none()
if persona is None:
raise ValueError(f"Persona with ID {persona_id} does not exist")
return persona
def create_persona(
persona_id: int | None,
name: str,
system_text: str | None,
tools_text: str | None,
hint_text: str | None,
default_persona: bool,
db_session: Session,
) -> Persona:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
if persona:
persona.name = name
persona.system_text = system_text
persona.tools_text = tools_text
persona.hint_text = hint_text
persona.default_persona = default_persona
else:
persona = Persona(
id=persona_id,
name=name,
system_text=system_text,
tools_text=tools_text,
hint_text=hint_text,
default_persona=default_persona,
)
db_session.add(persona)
db_session.commit()
return persona
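`create_persona` acts as an upsert keyed on `persona_id`, which is what lets `load_personas_from_yaml` run safely on every startup. A sketch with hypothetical persona values:

```python
from sqlalchemy.orm import Session

from danswer.db.chat import create_persona
from danswer.db.chat import fetch_persona_by_id
from danswer.db.engine import get_sqlalchemy_engine

with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
    persona = create_persona(
        persona_id=2,  # hypothetical; re-running with the same id updates in place
        name="Summarizer",
        system_text="You summarize the provided documents.",
        tools_text=None,  # this persona exposes no tools
        hint_text=None,
        default_persona=False,
        db_session=db_session,
    )
    assert fetch_persona_by_id(persona.id, db_session).name == "Summarizer"
```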

View File

@ -333,6 +333,7 @@ class ChatSession(Base):
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
description: Mapped[str] = mapped_column(Text)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
# The following texts help build up the model's ability to use the context effectively
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@ -348,6 +349,19 @@ class ChatSession(Base):
)
class Persona(Base):
# TODO introduce user and group ownership for personas
__tablename__ = "persona"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String)
system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
tools_text: Mapped[str | None] = mapped_column(Text, nullable=True)
hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
# If it's updated and no longer latest (should no longer be shown), it is also considered deleted
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
class ChatMessage(Base):
__tablename__ = "chat_message"
@ -362,8 +376,12 @@ class ChatMessage(Base):
latest: Mapped[bool] = mapped_column(Boolean, default=True)
message: Mapped[str] = mapped_column(Text)
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
time_sent: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
chat_session: Mapped[ChatSession] = relationship("ChatSession")
persona: Mapped[Persona | None] = relationship("Persona")

View File

@ -10,6 +10,13 @@ class DanswerAnswer:
answer: str | None
@dataclass
class DanswerChatModelOut:
model_raw: str
action: str
action_input: str
@dataclass
class DanswerAnswerPiece:
"""A small piece of a complete answer. Used for streaming back answers."""

View File

@ -10,19 +10,18 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.configs.constants import GENERAL_SEP_PAT
from danswer.configs.constants import QUESTION_PAT
from danswer.configs.constants import THOUGHT_PAT
from danswer.configs.constants import UNCERTAINTY_PAT
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
from danswer.direct_qa.qa_prompts import EMPTY_SAMPLE_JSON
from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import QUESTION_PAT
from danswer.direct_qa.qa_prompts import SAMPLE_JSON_RESPONSE
from danswer.direct_qa.qa_prompts import THOUGHT_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
@ -193,7 +192,7 @@ class JsonChatQAUnshackledHandler(QAHandler):
"should be in JSON format and contain an answer and (optionally) quotes that help support the answer. "
"Your responses should be informative, detailed, and consider all possibilities and edge cases. "
f"If you don't know the answer, respond with '{complete_answer_not_found_response}'\n"
f"Sample response:\n\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
f"Sample response:\n\n{json.dumps(EMPTY_SAMPLE_JSON)}"
)
)
)

View File

@ -2,23 +2,17 @@ import abc
import json
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import ANSWER_PAT
from danswer.configs.constants import DOC_CONTENT_START_PAT
from danswer.configs.constants import DOC_SEP_PAT
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import GENERAL_SEP_PAT
from danswer.configs.constants import QUESTION_PAT
from danswer.configs.constants import QUOTE_PAT
from danswer.configs.constants import UNCERTAINTY_PAT
from danswer.connectors.factory import identify_connector_class
GENERAL_SEP_PAT = "\n-----\n"
CODE_BLOCK_PAT = "\n```\n{}\n```\n"
DOC_SEP_PAT = "---NEW DOCUMENT---"
DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
QUESTION_PAT = "Query:"
THOUGHT_PAT = "Thought:"
ANSWER_PAT = "Answer:"
FINAL_ANSWER_PAT = "Final Answer:"
UNCERTAINTY_PAT = "?"
QUOTE_PAT = "Quote:"
QUOTES_PAT_PLURAL = "Quotes:"
INVALID_PAT = "Invalid:"
BASE_PROMPT = (
"Answer the query based on provided documents and quote relevant sections. "
"Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
@ -26,16 +20,6 @@ BASE_PROMPT = (
"The quotes must be EXACT substrings from the documents."
)
SAMPLE_QUESTION = "Where is the Eiffel Tower?"
SAMPLE_JSON_RESPONSE = {
"answer": "The Eiffel Tower is located in Paris, France.",
"quotes": [
"The Eiffel Tower is an iconic symbol of Paris",
"located on the Champ de Mars in France.",
],
}
EMPTY_SAMPLE_JSON = {
"answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
"quotes": [
@ -44,16 +28,6 @@ EMPTY_SAMPLE_JSON = {
],
}
ANSWER_NOT_FOUND_JSON = '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
SAMPLE_RESPONSE_COT = (
"Let's think step by step. The user is asking for the "
"location of the Eiffel Tower. The first document describes the Eiffel Tower "
"as being an iconic symbol of Paris and that it is located on the Champ de Mars. "
"Since the Champ de Mars is in Paris, we know that the Eiffel Tower is in Paris."
f"\n\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
)
def _append_acknowledge_doc_messages(
current_messages: list[dict[str, str]], new_chunk_content: str
@ -152,7 +126,7 @@ class JsonProcessor(NonChatPromptProcessor):
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> str:
prompt = (
BASE_PROMPT + f" Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
BASE_PROMPT + f" Sample response:\n{json.dumps(EMPTY_SAMPLE_JSON)}\n\n"
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
)
@ -203,7 +177,7 @@ class JsonChatProcessor(ChatPromptProcessor):
f"{complete_answer_not_found_response}\n"
"If the query requires aggregating the number of documents, respond with "
'{"answer": "Aggregations not supported", "quotes": []}\n'
f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
f"Sample response:\n{json.dumps(EMPTY_SAMPLE_JSON)}"
)
messages = [{"role": "system", "content": intro_msg}]
for chunk in chunks:
@ -221,65 +195,6 @@ class JsonChatProcessor(ChatPromptProcessor):
return messages
class JsonCoTChatProcessor(ChatPromptProcessor):
"""Pros: improves performance slightly over the regular JsonChatProcessor.
Cons: Much slower.
"""
@property
def specifies_json_output(self) -> bool:
return True
@staticmethod
def fill_prompt(
question: str,
chunks: list[InferenceChunk],
include_metadata: bool = True,
) -> list[dict[str, str]]:
metadata_prompt_section = (
"with metadata and contents " if include_metadata else ""
)
intro_msg = (
f"You are a Question Answering assistant that answers queries "
f"based on the provided documents.\n"
f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
)
complete_answer_not_found_response = (
'{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
)
task_msg = (
"Now answer the user query based on documents above and quote relevant sections.\n"
"When answering, you should think step by step, and verbalize your thought process.\n"
"Then respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
"All quotes MUST be EXACT substrings from provided documents.\n"
"Your responses should be informative, detailed, and consider all possibilities and edge cases.\n"
"You MUST prioritize information from provided documents over internal knowledge.\n"
"If the query cannot be answered based on the documents, respond with "
f"{complete_answer_not_found_response}\n"
"If the query requires aggregating the number of documents, respond with "
'{"answer": "Aggregations not supported", "quotes": []}\n'
f"Sample response:\n\n{SAMPLE_RESPONSE_COT}"
)
messages = [{"role": "system", "content": intro_msg}]
for chunk in chunks:
full_context = ""
if include_metadata:
full_context = _add_metadata_section(
full_context, chunk, prepend_tab=False, include_sep=False
)
full_context += chunk.content
messages = _append_acknowledge_doc_messages(messages, full_context)
messages.append({"role": "user", "content": task_msg})
messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n\n"})
messages.append({"role": "user", "content": "Let's think step by step."})
return messages
class WeakModelFreeformProcessor(NonChatPromptProcessor):
"""Avoid using this one if the model is capable of using another prompt
Intended for models that can't follow complex instructions or have short context windows
@ -366,117 +281,3 @@ class FreeformProcessor(NonChatPromptProcessor):
prompt += f"{QUESTION_PAT}\n{question}\n"
prompt += f"{ANSWER_PAT}\n"
return prompt
class FreeformChatProcessor(ChatPromptProcessor):
@property
def specifies_json_output(self) -> bool:
return False
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> list[dict[str, str]]:
sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
role_msg = (
f"You are a Question Answering assistant that answers queries based on provided documents. "
f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
f"Answer the question directly and concisely. "
f"Each quote should be a single continuous segment from a document. "
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
f"An example of quote sections may look like:\n{sample_quote}"
)
messages = [
{"role": "system", "content": role_msg},
]
for chunk in chunks:
messages = _append_acknowledge_doc_messages(messages, chunk.content)
messages.append(
{
"role": "user",
"content": f"Please now answer the following query based on the previously provided "
f"documents and quote the relevant sections of the documents\n{question}",
},
)
return messages
class JsonCOTProcessor(NonChatPromptProcessor):
"""Chain of Thought allows model to explain out its reasoning to handle harder tests.
This prompt type works however has higher token cost (more expensive) and is slower.
Consider this one if users ask questions that require logical reasoning."""
@property
def specifies_json_output(self) -> bool:
return True
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> str:
prompt = (
f"Answer the query based on provided documents and quote relevant sections. "
f'Respond with a freeform reasoning section followed by "Final Answer:" with a '
f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n"
f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
)
for chunk in chunks:
prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}"
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
prompt += "Reasoning:\n"
return prompt
class JsonReflexionProcessor(NonChatPromptProcessor):
"""Reflexion prompting to attempt to have model evaluate its own answer.
This one seems largely useless when only given a single example
Model seems to take the one example of answering "Yes" and just does that too."""
@property
def specifies_json_output(self) -> bool:
return True
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> str:
reflexion_str = "Does this fully answer the user query?"
prompt = (
BASE_PROMPT
+ f'After each generated json, ask "{reflexion_str}" and respond Yes or No. '
f"If No, generate a better json response to the query.\n"
f"Sample question and response:\n"
f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n"
f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n"
f"{reflexion_str} Yes\n\n"
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
)
for chunk in chunks:
prompt += f"\n---NEW CONTEXT DOCUMENT---\n{chunk.content}"
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
return prompt
def get_json_chat_reflexion_msg() -> dict[str, str]:
"""With the models tried (curent as of Jul 2023), this has not been very useful.
Have not seen any answers improved based on this.
For models like gpt-3.5-turbo, it will often answer something like:
'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source
document. No revision is needed.'"""
reflexion_content = (
"Is the assistant response a valid json that fully answer the user query? "
"If the response needs to be fixed or if an improvement is possible, provide a revised json. "
"Otherwise, respond with the same json."
)
return {"role": "system", "content": reflexion_content}

View File

@ -13,6 +13,25 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage:
if (
msg.message_type == MessageType.SYSTEM
or msg.message_type == MessageType.DANSWER
):
# TODO save at least the Danswer responses to postgres
raise ValueError(
"System and Danswer messages are not currently part of history"
)
if msg.message_type == MessageType.ASSISTANT:
return AIMessage(content=msg.message)
if msg.message_type == MessageType.USER:
return HumanMessage(content=msg.message)
raise ValueError(f"New message type {msg.message_type} not handled")
def dict_based_prompt_to_langchain_prompt(

View File

@ -12,6 +12,7 @@ from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.auth.users import oauth_client
from danswer.chat.personas import load_personas_from_yaml
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import DISABLE_AUTH
@ -191,6 +192,9 @@ def get_application() -> FastAPI:
logger.info("Verifying public credential exists.")
create_initial_public_credential()
logger.info("Loading default Chat Personas")
load_personas_from_yaml()
logger.info("Verifying Document Index(s) is/are available.")
get_default_document_index().ensure_indices_exist()

View File

@ -1,10 +1,10 @@
from danswer.configs.constants import ANSWER_PAT
from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.configs.constants import GENERAL_SEP_PAT
from danswer.configs.constants import INVALID_PAT
from danswer.configs.constants import QUESTION_PAT
from danswer.configs.constants import THOUGHT_PAT
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
from danswer.direct_qa.qa_prompts import INVALID_PAT
from danswer.direct_qa.qa_prompts import QUESTION_PAT
from danswer.direct_qa.qa_prompts import THOUGHT_PAT
from danswer.llm.build import get_default_llm
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time

View File

@ -2,10 +2,10 @@ import re
from collections.abc import Iterator
from dataclasses import asdict
from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.configs.constants import GENERAL_SEP_PAT
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
from danswer.llm.build import get_default_llm
from danswer.server.models import QueryValidationResponse
from danswer.server.utils import get_json_line

View File

@ -16,6 +16,7 @@ from danswer.db.chat import fetch_chat_message
from danswer.db.chat import fetch_chat_messages_by_session
from danswer.db.chat import fetch_chat_session_by_id
from danswer.db.chat import fetch_chat_sessions_by_user
from danswer.db.chat import fetch_persona_by_id
from danswer.db.chat import set_latest_chat_message
from danswer.db.chat import update_chat_session
from danswer.db.chat import verify_parent_exists
@ -29,8 +30,9 @@ from danswer.server.models import ChatMessageIdentifier
from danswer.server.models import ChatRenameRequest
from danswer.server.models import ChatSessionDetailResponse
from danswer.server.models import ChatSessionIdsResponse
from danswer.server.models import CreateChatID
from danswer.server.models import CreateChatRequest
from danswer.server.models import CreateChatMessageRequest
from danswer.server.models import CreateChatSessionID
from danswer.server.models import RegenerateMessageRequest
from danswer.server.models import RenameChatSessionResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
@ -108,14 +110,16 @@ def get_chat_session_messages(
def create_new_chat_session(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> CreateChatID:
) -> CreateChatSessionID:
user_id = user.id if user is not None else None
new_chat_session = create_chat_session(
"", user_id, db_session # Leave the naming till later to prevent delay
"",
user_id,
db_session, # Leave the naming till later to prevent delay
)
return CreateChatID(chat_session_id=new_chat_session.id)
return CreateChatSessionID(chat_session_id=new_chat_session.id)
@router.put("/rename-chat-session")
@ -182,9 +186,19 @@ def _create_chat_chain(
return mainline_messages
def _return_one_if_any(str_1: str | None, str_2: str | None) -> str | None:
if str_1 is not None and str_2 is not None:
raise ValueError("Conflicting values, can only set one")
if str_1 is not None:
return str_1
if str_2 is not None:
return str_2
return None
@router.post("/send-message")
def handle_new_chat_message(
chat_message: CreateChatRequest,
chat_message: CreateChatMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
@ -198,6 +212,11 @@ def handle_new_chat_message(
user_id = user.id if user is not None else None
chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
persona = (
fetch_persona_by_id(chat_message.persona_id, db_session)
if chat_message.persona_id is not None
else None
)
if chat_session.deleted:
raise ValueError("Cannot send messages to a deleted chat session")
@ -248,7 +267,9 @@ def handle_new_chat_message(
@log_generator_function_time()
def stream_chat_tokens() -> Iterator[str]:
tokens = llm_chat_answer(mainline_messages)
tokens = llm_chat_answer(
messages=mainline_messages, persona=persona, user_id=user_id
)
llm_output = ""
for token in tokens:
llm_output += token
@ -268,7 +289,7 @@ def handle_new_chat_message(
@router.post("/regenerate-from-parent")
def regenerate_message_given_parent(
parent_message: ChatMessageIdentifier,
parent_message: RegenerateMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
@ -288,6 +309,11 @@ def regenerate_message_given_parent(
)
chat_session = chat_message.chat_session
persona = (
fetch_persona_by_id(parent_message.persona_id, db_session)
if parent_message.persona_id is not None
else None
)
if chat_session.deleted:
raise ValueError("Chat session has been deleted")
@ -317,7 +343,9 @@ def regenerate_message_given_parent(
@log_generator_function_time()
def stream_regenerate_tokens() -> Iterator[str]:
tokens = llm_chat_answer(mainline_messages)
tokens = llm_chat_answer(
messages=mainline_messages, persona=persona, user_id=user_id
)
llm_output = ""
for token in tokens:
llm_output += token

View File

@ -134,7 +134,7 @@ class SearchDoc(BaseModel):
match_highlights: list[str]
class CreateChatID(BaseModel):
class CreateChatSessionID(BaseModel):
chat_session_id: int
@ -159,11 +159,12 @@ class SearchFeedbackRequest(BaseModel):
search_feedback: SearchFeedbackType
class CreateChatRequest(BaseModel):
class CreateChatMessageRequest(BaseModel):
chat_session_id: int
message_number: int
parent_edit_number: int | None
message: str
persona_id: int | None
class ChatMessageIdentifier(BaseModel):
@ -172,6 +173,10 @@ class ChatMessageIdentifier(BaseModel):
edit_number: int
class RegenerateMessageRequest(ChatMessageIdentifier):
persona_id: int | None
class ChatRenameRequest(BaseModel):
chat_session_id: int
name: str | None

View File

@ -1,13 +1,29 @@
import json
import re
import bs4
from bs4 import BeautifulSoup
def has_unescaped_quote(s: str) -> bool:
pattern = r'(?<!\\)"'
return bool(re.search(pattern, s))
def escape_newlines(s: str) -> str:
return re.sub(r"(?<!\\)\n", "\\\\n", s)
def extract_embedded_json(s: str) -> dict:
first_brace_index = s.find("{")
last_brace_index = s.rfind("}")
if first_brace_index == -1 or last_brace_index == -1:
raise ValueError("No valid json found")
return json.loads(s[first_brace_index : last_brace_index + 1])
def clean_up_code_blocks(model_out_raw: str) -> str:
return model_out_raw.strip().strip("```").strip()
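A few doctest-style checks of the new helpers (illustrative only):

```python
from danswer.utils.text_processing import escape_newlines
from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.text_processing import has_unescaped_quote

assert has_unescaped_quote('he said "hi')  # bare quote found
assert not has_unescaped_quote('fully escaped \\"')  # backslash-escaped, ignored
assert escape_newlines("line1\nline2") == "line1\\nline2"
assert extract_embedded_json('noise {"action": "x"} trailing')["action"] == "x"
```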

View File

@ -46,7 +46,7 @@ safetensors==0.3.1
sentence-transformers==2.2.2
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.12
tensorflow==2.12.0
tensorflow==2.13.0
tiktoken==0.4.0
transformers==4.30.1
typesense==0.15.1

View File

@ -0,0 +1,88 @@
# This file is purely for development use, not included in any builds
# Use this to test the chat feature with and without context.
# With context refers to being able to call out to Danswer and other tools (currently no other tools)
# Without context refers to only knowing the chat's own history with no additional information
# This script does not allow for the branching logic that is supported by the backend APIs
# This script also does not allow for editing/regeneration of user/model messages
# Have the Danswer API server running to use this.
import argparse
import json
import requests
from danswer.configs.app_configs import APP_PORT
LOCAL_CHAT_ENDPOINT = f"http://127.0.0.1:{APP_PORT}/chat/"
def create_new_session() -> int:
response = requests.post(LOCAL_CHAT_ENDPOINT + "create-chat-session")
response.raise_for_status()
new_session_id = response.json()["chat_session_id"]
return new_session_id
def send_chat_message(
message: str,
chat_session_id: int,
message_number: int,
parent_edit_number: int | None,
persona_id: int | None,
) -> None:
data = {
"message": message,
"chat_session_id": chat_session_id,
"message_number": message_number,
"parent_edit_number": parent_edit_number,
"persona_id": persona_id,
}
with requests.post(
LOCAL_CHAT_ENDPOINT + "send-message", json=data, stream=True
) as r:
for json_response in r.iter_lines():
response_text = json.loads(json_response.decode())
new_token = response_text.get("answer_piece")
print(new_token, end="", flush=True)
print()
def run_chat(contextual: bool) -> None:
try:
new_session_id = create_new_session()
except requests.exceptions.ConnectionError:
print(
"Looks like you haven't started the Danswer Backend server, please run the FastAPI server"
)
exit()
persona_id = 1 if contextual else None
message_num = 0
parent_edit = None
while True:
new_message = input(
"\n\n----------------------------------\n"
"Please provide a new chat message:\n> "
)
send_chat_message(
new_message, new_session_id, message_num, parent_edit, persona_id
)
message_num += 2 # 1 for User message, 1 for AI response
parent_edit = 0 # Since no edits, the parent message is always the first edit of that message number
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-c",
"--contextual",
action="store_true",
help="If this flag is set, the chat is able to call tools.",
)
args = parser.parse_args()
contextual = args.contextual
run_chat(contextual)