Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-26 16:01:09 +02:00)

Commit e549d2bb4a: Chat with Context Backend (#441)
Parent commit: a16ce56f6b
New Alembic migration, revision 8e26726b7683 (@@ -0,0 +1,40 @@):

```python
"""Chat Context Addition

Revision ID: 8e26726b7683
Revises: 5809c0787398
Create Date: 2023-09-13 18:34:31.327944

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "8e26726b7683"
down_revision = "5809c0787398"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "persona",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("system_text", sa.Text(), nullable=True),
        sa.Column("tools_text", sa.Text(), nullable=True),
        sa.Column("hint_text", sa.Text(), nullable=True),
        sa.Column("default_persona", sa.Boolean(), nullable=False),
        sa.Column("deleted", sa.Boolean(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    op.add_column("chat_message", sa.Column("persona_id", sa.Integer(), nullable=True))
    op.create_foreign_key(
        "fk_chat_message_persona_id", "chat_message", "persona", ["persona_id"], ["id"]
    )


def downgrade() -> None:
    op.drop_constraint("fk_chat_message_persona_id", "chat_message", type_="foreignkey")
    op.drop_column("chat_message", "persona_id")
    op.drop_table("persona")
```
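As with other Alembic revisions in this project, the migration would presumably be applied with `alembic upgrade head` from the backend directory; both the new `persona` table and the nullable `chat_message.persona_id` foreign key are reversible via the `downgrade` above.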
Chat LLM module (presumably backend/danswer/chat/chat_llm.py); the previous contextless-only `llm_chat_answer` is replaced by a persona-aware entrypoint plus a streamed-JSON response parser:

```diff
@@ -1,27 +1,155 @@
 from collections.abc import Iterator
+from uuid import UUID

 from langchain.schema.messages import AIMessage
 from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage

-from danswer.configs.constants import MessageType
+from danswer.chat.chat_prompts import form_tool_followup_text
+from danswer.chat.chat_prompts import form_user_prompt_text
+from danswer.chat.tools import call_tool
 from danswer.db.models import ChatMessage
+from danswer.db.models import Persona
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
+from danswer.direct_qa.interfaces import DanswerChatModelOut
 from danswer.llm.build import get_default_llm
+from danswer.llm.utils import translate_danswer_msg_to_langchain
+from danswer.utils.text_processing import extract_embedded_json
+from danswer.utils.text_processing import has_unescaped_quote


-def llm_chat_answer(previous_messages: list[ChatMessage]) -> Iterator[str]:
-    prompt: list[BaseMessage] = []
-    for msg in previous_messages:
-        content = msg.message
-        if msg.message_type == MessageType.SYSTEM:
-            prompt.append(SystemMessage(content=content))
-        if msg.message_type == MessageType.ASSISTANT:
-            prompt.append(AIMessage(content=content))
-        if (
-            msg.message_type == MessageType.USER
-            or msg.message_type == MessageType.DANSWER  # consider using FunctionMessage
-        ):
-            prompt.append(HumanMessage(content=content))
+def _parse_embedded_json_streamed_response(
+    tokens: Iterator[str],
+) -> Iterator[DanswerAnswerPiece | DanswerChatModelOut]:
+    final_answer = False
+    just_start_stream = False
+    model_output = ""
+    hold = ""
+    finding_end = 0
+    for token in tokens:
+        model_output += token
+        hold += token
+
+        if (
+            final_answer is False
+            and '"action":"finalanswer",' in model_output.lower().replace(" ", "")
+        ):
+            final_answer = True
+
+        if final_answer and '"actioninput":"' in model_output.lower().replace(
+            " ", ""
+        ).replace("_", ""):
+            if not just_start_stream:
+                just_start_stream = True
+                hold = ""
+
+            if has_unescaped_quote(hold):
+                finding_end += 1
+                hold = hold[: hold.find('"')]
+
+            if finding_end <= 1:
+                if finding_end == 1:
+                    finding_end += 1
+
+                yield DanswerAnswerPiece(answer_piece=hold)
+                hold = ""
+
+    model_final = extract_embedded_json(model_output)
+    if "action" not in model_final or "action_input" not in model_final:
+        raise ValueError("Model did not provide all required action values")
+
+    yield DanswerChatModelOut(
+        model_raw=model_output,
+        action=model_final["action"],
+        action_input=model_final["action_input"],
+    )
+    return
+
+
+def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]:
+    prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages]

     return get_default_llm().stream(prompt)
+
+
+def llm_contextual_chat_answer(
+    messages: list[ChatMessage],
+    persona: Persona,
+    user_id: UUID | None,
+) -> Iterator[str]:
+    system_text = persona.system_text
+    tool_text = persona.tools_text
+    hint_text = persona.hint_text
+
+    last_message = messages[-1]
+    previous_messages = messages[:-1]
+
+    user_text = form_user_prompt_text(
+        query=last_message.message,
+        tool_text=tool_text,
+        hint_text=hint_text,
+    )
+
+    prompt: list[BaseMessage] = []
+
+    if system_text:
+        prompt.append(SystemMessage(content=system_text))
+
+    prompt.extend(
+        [translate_danswer_msg_to_langchain(msg) for msg in previous_messages]
+    )
+
+    prompt.append(HumanMessage(content=user_text))
+
+    tokens = get_default_llm().stream(prompt)
+
+    final_result: DanswerChatModelOut | None = None
+    final_answer_streamed = False
+    for result in _parse_embedded_json_streamed_response(tokens):
+        if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
+            yield result.answer_piece
+            final_answer_streamed = True
+
+        if isinstance(result, DanswerChatModelOut):
+            final_result = result
+            break
+
+    if final_answer_streamed:
+        return
+
+    if final_result is None:
+        raise RuntimeError("Model output finished without final output parsing.")
+
+    tool_result_str = call_tool(final_result, user_id=user_id)
+
+    prompt.append(AIMessage(content=final_result.model_raw))
+    prompt.append(
+        HumanMessage(
+            content=form_tool_followup_text(tool_result_str, hint_text=hint_text)
+        )
+    )
+
+    tokens = get_default_llm().stream(prompt)
+
+    for result in _parse_embedded_json_streamed_response(tokens):
+        if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
+            yield result.answer_piece
+            final_answer_streamed = True
+
+    if final_answer_streamed is False:
+        raise RuntimeError("LLM failed to produce a Final Answer")
+
+
+def llm_chat_answer(
+    messages: list[ChatMessage], persona: Persona | None, user_id: UUID | None
+) -> Iterator[str]:
+    # TODO how to handle model giving jibberish or fail on a particular message
+    # TODO how to handle model failing to choose the right tool
+    # TODO how to handle model gives wrong format
+    if persona is None:
+        return llm_contextless_chat_answer(messages)

+    return llm_contextual_chat_answer(
+        messages=messages, persona=persona, user_id=user_id
+    )
```
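A minimal sketch (not part of the commit) of how the streamed-JSON parser behaves; the canned token list stands in for a live LLM stream, and the module path is an assumption based on the imports above:

```python
from danswer.chat.chat_llm import _parse_embedded_json_streamed_response  # assumed path

# Tokens roughly as an LLM might emit them for a "Final Answer" action.
tokens = iter(['{"action": "Final ', 'Answer", "action_input": "', 'Paris is lovely."}'])

for result in _parse_embedded_json_streamed_response(tokens):
    # DanswerAnswerPiece chunks stream the action_input as it arrives after the
    # "action_input" marker; one DanswerChatModelOut with the fully parsed
    # action is yielded last, once the complete JSON can be extracted.
    print(result)
```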
backend/danswer/chat/chat_prompts.py (new file, @@ -0,0 +1,125 @@):

````python
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import CODE_BLOCK_PAT

TOOL_TEMPLATE = """TOOLS
------
Assistant can ask the user to use tools to look up information that may be helpful in answering the users \
original question. The tools the human can use are:

{}

RESPONSE FORMAT INSTRUCTIONS
----------------------------

When responding to me, please output a response in one of two formats:

**Option 1:**
Use this if you want the human to use a tool.
Markdown code snippet formatted in the following schema:

```json
{{
    "action": string, \\ The action to take. Must be one of {}
    "action_input": string \\ The input to the action
}}
```

**Option #2:**
Use this if you want to respond directly to the human. Markdown code snippet formatted in the following schema:

```json
{{
    "action": "Final Answer",
    "action_input": string \\ You should put what you want to return to use here
}}
```
"""

TOOL_LESS_PROMPT = """
Respond with a markdown code snippet in the following schema:

```json
{{
    "action": "Final Answer",
    "action_input": string \\ You should put what you want to return to use here
}}
```
"""

USER_INPUT = """
USER'S INPUT
--------------------
Here is the user's input \
(remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):

{}
"""

TOOL_FOLLOWUP = """
TOOL RESPONSE:
---------------------
{}

USER'S INPUT
--------------------
Okay, so what is the response to my last comment? If using information obtained from the tools you must \
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
{}
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
"""


def form_user_prompt_text(
    query: str,
    tool_text: str | None,
    hint_text: str | None,
    user_input_prompt: str = USER_INPUT,
    tool_less_prompt: str = TOOL_LESS_PROMPT,
) -> str:
    user_prompt = tool_text or tool_less_prompt

    user_prompt += user_input_prompt.format(query)

    if hint_text:
        if user_prompt[-1] != "\n":
            user_prompt += "\n"
        user_prompt += "Hint: " + hint_text

    return user_prompt


def form_tool_section_text(
    tools: list[dict[str, str]], template: str = TOOL_TEMPLATE
) -> str | None:
    if not tools:
        return None

    tools_intro = []
    for tool in tools:
        description_formatted = tool["description"].replace("\n", " ")
        tools_intro.append(f"> {tool['name']}: {description_formatted}")

    tools_intro_text = "\n".join(tools_intro)
    tool_names_text = ", ".join([tool["name"] for tool in tools])

    return template.format(tools_intro_text, tool_names_text)


def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
    return "\n".join(
        f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(chunk.content)}"
        for ind, chunk in enumerate(chunks, start=1)
    )


def form_tool_followup_text(
    tool_output: str,
    hint_text: str | None,
    tool_followup_prompt: str = TOOL_FOLLOWUP,
    ignore_hint: bool = False,
) -> str:
    if not ignore_hint and hint_text:
        hint_text_spaced = f"\n{hint_text}\n"
        return tool_followup_prompt.format(tool_output, hint_text_spaced)

    return tool_followup_prompt.format(tool_output, "")
````
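A quick sketch (not in the commit) of how these helpers compose the final user prompt; the tool dict mirrors the shape provided by personas.yaml below:

```python
from danswer.chat.chat_prompts import form_tool_section_text, form_user_prompt_text

tools = [{"name": "Current Search", "description": "A search for up to date information."}]
tool_text = form_tool_section_text(tools)  # renders TOOL_TEMPLATE with intro + tool names

user_prompt = form_user_prompt_text(
    query="Where is the Eiffel Tower?",
    tool_text=tool_text,
    hint_text="Use the Current Search tool!",
)
# user_prompt now contains the TOOLS section, the USER'S INPUT block with the
# query substituted in, and a trailing "Hint: ..." line.
```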
backend/danswer/chat/personas.py (new file, @@ -0,0 +1,26 @@):

```python
import yaml
from sqlalchemy.orm import Session

from danswer.chat.chat_prompts import form_tool_section_text
from danswer.configs.app_configs import PERSONAS_YAML
from danswer.db.chat import create_persona
from danswer.db.engine import get_sqlalchemy_engine


def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
    with open(personas_yaml, "r") as file:
        data = yaml.safe_load(file)

    all_personas = data.get("personas", [])
    with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
        for persona in all_personas:
            tools = form_tool_section_text(persona["tools"])
            create_persona(
                persona_id=persona["id"],
                name=persona["name"],
                system_text=persona["system"],
                tools_text=tools,
                hint_text=persona["hint"],
                default_persona=True,
                db_session=db_session,
            )
```
backend/danswer/chat/personas.yaml (new file, @@ -0,0 +1,14 @@):

```yaml
personas:
  - id: 1
    name: "Danswer"
    system: |
      You are a question answering system that is constantly learning and improving.
      You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
      You have access to a search tool and should always use it anytime it might be useful.
      Your responses are as INFORMATIVE and DETAILED as possible.
    tools:
      - name: "Current Search"
        description: |
          A search for up to date and proprietary information.
          This tool can handle natural language questions so it is better to ask very precise questions over keywords.
    hint: "Use the Current Search tool to help find information on any topic!"
```
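For reference, a sketch (not from the commit) of the structure `yaml.safe_load` produces for this file, which `load_personas_from_yaml` then walks; strings are abridged:

```python
# Shape of the parsed YAML (values abridged); the keys map one-to-one onto the
# create_persona(...) arguments used in load_personas_from_yaml.
data = {
    "personas": [
        {
            "id": 1,
            "name": "Danswer",
            "system": "You are a question answering system...",
            "tools": [
                {"name": "Current Search", "description": "A search for up to date..."}
            ],
            "hint": "Use the Current Search tool to help find information on any topic!",
        }
    ]
}
```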
backend/danswer/chat/tools.py (new file, @@ -0,0 +1,43 @@):

```python
from uuid import UUID

from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.semantic_search import retrieve_ranked_documents


def call_tool(
    model_actions: DanswerChatModelOut,
    user_id: UUID | None,
) -> str:
    if model_actions.action.lower() == "current search":
        query = model_actions.action_input

        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
            query,
            user_id=user_id,
            filters=None,
            datastore=get_default_document_index(),
        )
        if not ranked_chunks:
            return "No results found"

        if unranked_chunks:
            ranked_chunks.extend(unranked_chunks)

        filtered_ranked_chunks = [
            chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
        ]

        # get all chunks that fit into the token limit
        usable_chunks = get_usable_chunks(
            chunks=filtered_ranked_chunks,
            token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
        )

        return format_danswer_chunks_for_chat(usable_chunks)

    raise ValueError("Invalid tool choice by LLM")
```
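A sketch (not in the commit) of the shape of a tool invocation; it needs a running document index and search stack behind it, so it is illustrative only:

```python
from danswer.chat.tools import call_tool
from danswer.direct_qa.interfaces import DanswerChatModelOut

model_out = DanswerChatModelOut(
    model_raw='{"action": "Current Search", "action_input": "vacation policy"}',
    action="Current Search",
    action_input="vacation policy",
)
# Returns the retrieved chunks rendered via format_danswer_chunks_for_chat,
# or "No results found"; any action other than "current search" raises ValueError.
context_str = call_tool(model_out, user_id=None)
```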
backend/danswer/configs/app_configs.py:

```diff
@@ -144,9 +144,12 @@ NUM_RERANKED_RESULTS = 15
 NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
     os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
 )
+NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
+    os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (512 * 3)
+)
 # 1 edit per 2 characters, currently unused due to fuzzy match being too slow
 QUOTE_ALLOWED_ERROR_PERCENT = 0.05
-QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "10")  # 10 seconds
+QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60")  # 60 seconds
 # Include additional document/chunk metadata in prompt to GenerativeAI
 INCLUDE_METADATA = False
 HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
```
```diff
@@ -182,6 +185,7 @@ CROSS_ENCODER_PORT = 9000
 #####
 # Miscellaneous
 #####
+PERSONAS_YAML = "./danswer/chat/personas.yaml"
 DYNAMIC_CONFIG_STORE = os.environ.get(
     "DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
 )
```
backend/danswer/configs/constants.py:

````diff
@@ -25,6 +25,20 @@ BOOST = "boost"
 SCORE = "score"
 DEFAULT_BOOST = 0
+
+# Prompt building constants:
+GENERAL_SEP_PAT = "\n-----\n"
+CODE_BLOCK_PAT = "\n```\n{}\n```\n"
+DOC_SEP_PAT = "---NEW DOCUMENT---"
+DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
+QUESTION_PAT = "Query:"
+THOUGHT_PAT = "Thought:"
+ANSWER_PAT = "Answer:"
+FINAL_ANSWER_PAT = "Final Answer:"
+UNCERTAINTY_PAT = "?"
+QUOTE_PAT = "Quote:"
+QUOTES_PAT_PLURAL = "Quotes:"
+INVALID_PAT = "Invalid:"


 class DocumentSource(str, Enum):
     SLACK = "slack"
````
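Centralizing these prompt patterns in `danswer/configs/constants.py` lets the new chat prompts and the existing direct-QA prompts share a single set of separators; the later hunks in `qa_block.py` and `qa_prompts.py` switch their imports over accordingly.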
backend/danswer/db/chat.py:

```diff
@@ -12,6 +12,7 @@ from danswer.configs.app_configs import HARD_DELETE_CHATS
 from danswer.configs.constants import MessageType
 from danswer.db.models import ChatMessage
 from danswer.db.models import ChatSession
+from danswer.db.models import Persona


 def fetch_chat_sessions_by_user(
```
```diff
@@ -245,3 +246,47 @@ def set_latest_chat_message(
     )

     db_session.commit()
+
+
+def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
+    stmt = select(Persona).where(Persona.id == persona_id)
+    result = db_session.execute(stmt)
+    persona = result.scalar_one_or_none()
+
+    if persona is None:
+        raise ValueError(f"Persona with ID {persona_id} does not exist")
+
+    return persona
+
+
+def create_persona(
+    persona_id: int | None,
+    name: str,
+    system_text: str | None,
+    tools_text: str | None,
+    hint_text: str | None,
+    default_persona: bool,
+    db_session: Session,
+) -> Persona:
+    persona = db_session.query(Persona).filter_by(id=persona_id).first()
+
+    if persona:
+        persona.name = name
+        persona.system_text = system_text
+        persona.tools_text = tools_text
+        persona.hint_text = hint_text
+        persona.default_persona = default_persona
+    else:
+        persona = Persona(
+            id=persona_id,
+            name=name,
+            system_text=system_text,
+            tools_text=tools_text,
+            hint_text=hint_text,
+            default_persona=default_persona,
+        )
+        db_session.add(persona)
+
+    db_session.commit()
+
+    return persona
```
backend/danswer/db/models.py:

```diff
@@ -333,6 +333,7 @@ class ChatSession(Base):
     user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
     description: Mapped[str] = mapped_column(Text)
     deleted: Mapped[bool] = mapped_column(Boolean, default=False)
+    # The following texts help build up the model's ability to use the context effectively
     time_updated: Mapped[datetime.datetime] = mapped_column(
         DateTime(timezone=True),
         server_default=func.now(),
```
```diff
@@ -348,6 +349,19 @@ class ChatSession(Base):
     )


+class Persona(Base):
+    # TODO introduce user and group ownership for personas
+    __tablename__ = "persona"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str] = mapped_column(String)
+    system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
+    tools_text: Mapped[str | None] = mapped_column(Text, nullable=True)
+    hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
+    default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
+    # If it's updated and no longer latest (should no longer be shown), it is also considered deleted
+    deleted: Mapped[bool] = mapped_column(Boolean, default=False)
+
+
 class ChatMessage(Base):
     __tablename__ = "chat_message"
```
```diff
@@ -362,8 +376,12 @@ class ChatMessage(Base):
     latest: Mapped[bool] = mapped_column(Boolean, default=True)
     message: Mapped[str] = mapped_column(Text)
     message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
+    persona_id: Mapped[int | None] = mapped_column(
+        ForeignKey("persona.id"), nullable=True
+    )
     time_sent: Mapped[datetime.datetime] = mapped_column(
         DateTime(timezone=True), server_default=func.now()
     )

     chat_session: Mapped[ChatSession] = relationship("ChatSession")
+    persona: Mapped[Persona | None] = relationship("Persona")
```
backend/danswer/direct_qa/interfaces.py:

```diff
@@ -10,6 +10,13 @@ class DanswerAnswer:
     answer: str | None


+@dataclass
+class DanswerChatModelOut:
+    model_raw: str
+    action: str
+    action_input: str
+
+
 @dataclass
 class DanswerAnswerPiece:
     """A small piece of a complete answer. Used for streaming back answers."""
```
backend/danswer/direct_qa/qa_block.py:

```diff
@@ -10,19 +10,18 @@ from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage

 from danswer.chunking.models import InferenceChunk
+from danswer.configs.constants import CODE_BLOCK_PAT
+from danswer.configs.constants import GENERAL_SEP_PAT
+from danswer.configs.constants import QUESTION_PAT
+from danswer.configs.constants import THOUGHT_PAT
+from danswer.configs.constants import UNCERTAINTY_PAT
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
 from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
 from danswer.direct_qa.qa_prompts import EMPTY_SAMPLE_JSON
-from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
 from danswer.direct_qa.qa_prompts import JsonChatProcessor
-from danswer.direct_qa.qa_prompts import QUESTION_PAT
-from danswer.direct_qa.qa_prompts import SAMPLE_JSON_RESPONSE
-from danswer.direct_qa.qa_prompts import THOUGHT_PAT
-from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
 from danswer.direct_qa.qa_utils import process_answer
 from danswer.direct_qa.qa_utils import process_model_tokens
```
```diff
@@ -193,7 +192,7 @@ class JsonChatQAUnshackledHandler(QAHandler):
                 "should be in JSON format and contain an answer and (optionally) quotes that help support the answer. "
                 "Your responses should be informative, detailed, and consider all possibilities and edge cases. "
                 f"If you don't know the answer, respond with '{complete_answer_not_found_response}'\n"
-                f"Sample response:\n\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
+                f"Sample response:\n\n{json.dumps(EMPTY_SAMPLE_JSON)}"
             )
         )
     )
```
backend/danswer/direct_qa/qa_prompts.py:

````diff
@@ -2,23 +2,17 @@ import abc
 import json

 from danswer.chunking.models import InferenceChunk
+from danswer.configs.constants import ANSWER_PAT
+from danswer.configs.constants import DOC_CONTENT_START_PAT
+from danswer.configs.constants import DOC_SEP_PAT
 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import GENERAL_SEP_PAT
+from danswer.configs.constants import QUESTION_PAT
+from danswer.configs.constants import QUOTE_PAT
+from danswer.configs.constants import UNCERTAINTY_PAT
 from danswer.connectors.factory import identify_connector_class


-GENERAL_SEP_PAT = "\n-----\n"
-CODE_BLOCK_PAT = "\n```\n{}\n```\n"
-DOC_SEP_PAT = "---NEW DOCUMENT---"
-DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
-QUESTION_PAT = "Query:"
-THOUGHT_PAT = "Thought:"
-ANSWER_PAT = "Answer:"
-FINAL_ANSWER_PAT = "Final Answer:"
-UNCERTAINTY_PAT = "?"
-QUOTE_PAT = "Quote:"
-QUOTES_PAT_PLURAL = "Quotes:"
-INVALID_PAT = "Invalid:"
-
 BASE_PROMPT = (
     "Answer the query based on provided documents and quote relevant sections. "
     "Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
````
```diff
@@ -26,16 +20,6 @@ BASE_PROMPT = (
     "The quotes must be EXACT substrings from the documents."
 )

-SAMPLE_QUESTION = "Where is the Eiffel Tower?"
-
-SAMPLE_JSON_RESPONSE = {
-    "answer": "The Eiffel Tower is located in Paris, France.",
-    "quotes": [
-        "The Eiffel Tower is an iconic symbol of Paris",
-        "located on the Champ de Mars in France.",
-    ],
-}
-
 EMPTY_SAMPLE_JSON = {
     "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
     "quotes": [
```
```diff
@@ -44,16 +28,6 @@ EMPTY_SAMPLE_JSON = {
     ],
 }

-ANSWER_NOT_FOUND_JSON = '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
-
-SAMPLE_RESPONSE_COT = (
-    "Let's think step by step. The user is asking for the "
-    "location of the Eiffel Tower. The first document describes the Eiffel Tower "
-    "as being an iconic symbol of Paris and that it is located on the Champ de Mars. "
-    "Since the Champ de Mars is in Paris, we know that the Eiffel Tower is in Paris."
-    f"\n\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
-)
-

 def _append_acknowledge_doc_messages(
     current_messages: list[dict[str, str]], new_chunk_content: str
```
```diff
@@ -152,7 +126,7 @@ class JsonProcessor(NonChatPromptProcessor):
         question: str, chunks: list[InferenceChunk], include_metadata: bool = False
     ) -> str:
         prompt = (
-            BASE_PROMPT + f" Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
+            BASE_PROMPT + f" Sample response:\n{json.dumps(EMPTY_SAMPLE_JSON)}\n\n"
             f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
         )

```
```diff
@@ -203,7 +177,7 @@ class JsonChatProcessor(ChatPromptProcessor):
             f"{complete_answer_not_found_response}\n"
             "If the query requires aggregating the number of documents, respond with "
             '{"answer": "Aggregations not supported", "quotes": []}\n'
-            f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
+            f"Sample response:\n{json.dumps(EMPTY_SAMPLE_JSON)}"
         )
         messages = [{"role": "system", "content": intro_msg}]
         for chunk in chunks:
```
```diff
@@ -221,65 +195,6 @@ class JsonChatProcessor(ChatPromptProcessor):
         return messages


-class JsonCoTChatProcessor(ChatPromptProcessor):
-    """Pros: improves performance slightly over the regular JsonChatProcessor.
-    Cons: Much slower.
-    """
-
-    @property
-    def specifies_json_output(self) -> bool:
-        return True
-
-    @staticmethod
-    def fill_prompt(
-        question: str,
-        chunks: list[InferenceChunk],
-        include_metadata: bool = True,
-    ) -> list[dict[str, str]]:
-        metadata_prompt_section = (
-            "with metadata and contents " if include_metadata else ""
-        )
-        intro_msg = (
-            f"You are a Question Answering assistant that answers queries "
-            f"based on the provided documents.\n"
-            f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
-        )
-
-        complete_answer_not_found_response = (
-            '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
-        )
-        task_msg = (
-            "Now answer the user query based on documents above and quote relevant sections.\n"
-            "When answering, you should think step by step, and verbalize your thought process.\n"
-            "Then respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
-            "All quotes MUST be EXACT substrings from provided documents.\n"
-            "Your responses should be informative, detailed, and consider all possibilities and edge cases.\n"
-            "You MUST prioritize information from provided documents over internal knowledge.\n"
-            "If the query cannot be answered based on the documents, respond with "
-            f"{complete_answer_not_found_response}\n"
-            "If the query requires aggregating the number of documents, respond with "
-            '{"answer": "Aggregations not supported", "quotes": []}\n'
-            f"Sample response:\n\n{SAMPLE_RESPONSE_COT}"
-        )
-        messages = [{"role": "system", "content": intro_msg}]
-
-        for chunk in chunks:
-            full_context = ""
-            if include_metadata:
-                full_context = _add_metadata_section(
-                    full_context, chunk, prepend_tab=False, include_sep=False
-                )
-            full_context += chunk.content
-            messages = _append_acknowledge_doc_messages(messages, full_context)
-        messages.append({"role": "user", "content": task_msg})
-
-        messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n\n"})
-
-        messages.append({"role": "user", "content": "Let's think step by step."})
-
-        return messages
-
-
 class WeakModelFreeformProcessor(NonChatPromptProcessor):
     """Avoid using this one if the model is capable of using another prompt
     Intended for models that can't follow complex instructions or have short context windows
```
```diff
@@ -366,117 +281,3 @@ class FreeformProcessor(NonChatPromptProcessor):
         prompt += f"{QUESTION_PAT}\n{question}\n"
         prompt += f"{ANSWER_PAT}\n"
         return prompt
-
-
-class FreeformChatProcessor(ChatPromptProcessor):
-    @property
-    def specifies_json_output(self) -> bool:
-        return False
-
-    @staticmethod
-    def fill_prompt(
-        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
-    ) -> list[dict[str, str]]:
-        sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
-        role_msg = (
-            f"You are a Question Answering assistant that answers queries based on provided documents. "
-            f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
-            f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
-            f"Answer the question directly and concisely. "
-            f"Each quote should be a single continuous segment from a document. "
-            f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
-            f"An example of quote sections may look like:\n{sample_quote}"
-        )
-
-        messages = [
-            {"role": "system", "content": role_msg},
-        ]
-        for chunk in chunks:
-            messages = _append_acknowledge_doc_messages(messages, chunk.content)
-
-        messages.append(
-            {
-                "role": "user",
-                "content": f"Please now answer the following query based on the previously provided "
-                f"documents and quote the relevant sections of the documents\n{question}",
-            },
-        )
-
-        return messages
-
-
-class JsonCOTProcessor(NonChatPromptProcessor):
-    """Chain of Thought allows model to explain out its reasoning to handle harder tests.
-    This prompt type works however has higher token cost (more expensive) and is slower.
-    Consider this one if users ask questions that require logical reasoning."""
-
-    @property
-    def specifies_json_output(self) -> bool:
-        return True
-
-    @staticmethod
-    def fill_prompt(
-        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
-    ) -> str:
-        prompt = (
-            f"Answer the query based on provided documents and quote relevant sections. "
-            f'Respond with a freeform reasoning section followed by "Final Answer:" with a '
-            f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n"
-            f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
-            f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
-        )
-
-        for chunk in chunks:
-            prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}"
-
-        prompt += "\n\n---\n\n"
-        prompt += f"{QUESTION_PAT}\n{question}\n"
-        prompt += "Reasoning:\n"
-        return prompt
-
-
-class JsonReflexionProcessor(NonChatPromptProcessor):
-    """Reflexion prompting to attempt to have model evaluate its own answer.
-    This one seems largely useless when only given a single example
-    Model seems to take the one example of answering "Yes" and just does that too."""
-
-    @property
-    def specifies_json_output(self) -> bool:
-        return True
-
-    @staticmethod
-    def fill_prompt(
-        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
-    ) -> str:
-        reflexion_str = "Does this fully answer the user query?"
-        prompt = (
-            BASE_PROMPT
-            + f'After each generated json, ask "{reflexion_str}" and respond Yes or No. '
-            f"If No, generate a better json response to the query.\n"
-            f"Sample question and response:\n"
-            f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n"
-            f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n"
-            f"{reflexion_str} Yes\n\n"
-            f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
-        )
-
-        for chunk in chunks:
-            prompt += f"\n---NEW CONTEXT DOCUMENT---\n{chunk.content}"
-
-        prompt += "\n\n---\n\n"
-        prompt += f"{QUESTION_PAT}\n{question}\n"
-        return prompt
-
-
-def get_json_chat_reflexion_msg() -> dict[str, str]:
-    """With the models tried (curent as of Jul 2023), this has not been very useful.
-    Have not seen any answers improved based on this.
-    For models like gpt-3.5-turbo, it will often answer something like:
-    'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source
-    document. No revision is needed.'"""
-    reflexion_content = (
-        "Is the assistant response a valid json that fully answer the user query? "
-        "If the response needs to be fixed or if an improvement is possible, provide a revised json. "
-        "Otherwise, respond with the same json."
-    )
-    return {"role": "system", "content": reflexion_content}
```
backend/danswer/llm/utils.py:

```diff
@@ -13,6 +13,25 @@ from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage

 from danswer.configs.app_configs import LOG_LEVEL
+from danswer.configs.constants import MessageType
+from danswer.db.models import ChatMessage
+
+
+def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage:
+    if (
+        msg.message_type == MessageType.SYSTEM
+        or msg.message_type == MessageType.DANSWER
+    ):
+        # TODO save at least the Danswer responses to postgres
+        raise ValueError(
+            "System and Danswer messages are not currently part of history"
+        )
+    if msg.message_type == MessageType.ASSISTANT:
+        return AIMessage(content=msg.message)
+    if msg.message_type == MessageType.USER:
+        return HumanMessage(content=msg.message)
+
+    raise ValueError(f"New message type {msg.message_type} not handled")


 def dict_based_prompt_to_langchain_prompt(
```
backend/danswer/main.py:

```diff
@@ -12,6 +12,7 @@ from danswer.auth.schemas import UserUpdate
 from danswer.auth.users import auth_backend
 from danswer.auth.users import fastapi_users
 from danswer.auth.users import oauth_client
+from danswer.chat.personas import load_personas_from_yaml
 from danswer.configs.app_configs import APP_HOST
 from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import DISABLE_AUTH
```
```diff
@@ -191,6 +192,9 @@ def get_application() -> FastAPI:
         logger.info("Verifying public credential exists.")
         create_initial_public_credential()

+        logger.info("Loading default Chat Personas")
+        load_personas_from_yaml()
+
         logger.info("Verifying Document Index(s) is/are available.")
         get_default_document_index().ensure_indices_exist()
```
A secondary LLM flow module (answer validation, judging by the THOUGHT_PAT/INVALID_PAT imports):

```diff
@@ -1,10 +1,10 @@
+from danswer.configs.constants import ANSWER_PAT
+from danswer.configs.constants import CODE_BLOCK_PAT
+from danswer.configs.constants import GENERAL_SEP_PAT
+from danswer.configs.constants import INVALID_PAT
+from danswer.configs.constants import QUESTION_PAT
+from danswer.configs.constants import THOUGHT_PAT
 from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
-from danswer.direct_qa.qa_prompts import ANSWER_PAT
-from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
-from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
-from danswer.direct_qa.qa_prompts import INVALID_PAT
-from danswer.direct_qa.qa_prompts import QUESTION_PAT
-from danswer.direct_qa.qa_prompts import THOUGHT_PAT
 from danswer.llm.build import get_default_llm
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_function_time
```
The query-validation flow (it builds QueryValidationResponse):

```diff
@@ -2,10 +2,10 @@ import re
 from collections.abc import Iterator
 from dataclasses import asdict

+from danswer.configs.constants import CODE_BLOCK_PAT
+from danswer.configs.constants import GENERAL_SEP_PAT
 from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
-from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
-from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
 from danswer.llm.build import get_default_llm
 from danswer.server.models import QueryValidationResponse
 from danswer.server.utils import get_json_line
```
The chat API router (likely backend/danswer/server/chat_backend.py):

```diff
@@ -16,6 +16,7 @@ from danswer.db.chat import fetch_chat_message
 from danswer.db.chat import fetch_chat_messages_by_session
 from danswer.db.chat import fetch_chat_session_by_id
 from danswer.db.chat import fetch_chat_sessions_by_user
+from danswer.db.chat import fetch_persona_by_id
 from danswer.db.chat import set_latest_chat_message
 from danswer.db.chat import update_chat_session
 from danswer.db.chat import verify_parent_exists
```
```diff
@@ -29,8 +30,9 @@ from danswer.server.models import ChatMessageIdentifier
 from danswer.server.models import ChatRenameRequest
 from danswer.server.models import ChatSessionDetailResponse
 from danswer.server.models import ChatSessionIdsResponse
-from danswer.server.models import CreateChatID
-from danswer.server.models import CreateChatRequest
+from danswer.server.models import CreateChatMessageRequest
+from danswer.server.models import CreateChatSessionID
+from danswer.server.models import RegenerateMessageRequest
 from danswer.server.models import RenameChatSessionResponse
 from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
```
```diff
@@ -108,14 +110,16 @@ def get_chat_session_messages(
 def create_new_chat_session(
     user: User | None = Depends(current_user),
     db_session: Session = Depends(get_session),
-) -> CreateChatID:
+) -> CreateChatSessionID:
     user_id = user.id if user is not None else None

     new_chat_session = create_chat_session(
-        "", user_id, db_session  # Leave the naming till later to prevent delay
+        "",
+        user_id,
+        db_session,  # Leave the naming till later to prevent delay
     )

-    return CreateChatID(chat_session_id=new_chat_session.id)
+    return CreateChatSessionID(chat_session_id=new_chat_session.id)


 @router.put("/rename-chat-session")
```
```diff
@@ -182,9 +186,19 @@ def _create_chat_chain(
     return mainline_messages


+def _return_one_if_any(str_1: str | None, str_2: str | None) -> str | None:
+    if str_1 is not None and str_2 is not None:
+        raise ValueError("Conflicting values, can only set one")
+    if str_1 is not None:
+        return str_1
+    if str_2 is not None:
+        return str_2
+    return None
+
+
 @router.post("/send-message")
 def handle_new_chat_message(
-    chat_message: CreateChatRequest,
+    chat_message: CreateChatMessageRequest,
     user: User | None = Depends(current_user),
     db_session: Session = Depends(get_session),
 ) -> StreamingResponse:
```
```diff
@@ -198,6 +212,11 @@ def handle_new_chat_message(
     user_id = user.id if user is not None else None

     chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
+    persona = (
+        fetch_persona_by_id(chat_message.persona_id, db_session)
+        if chat_message.persona_id is not None
+        else None
+    )

     if chat_session.deleted:
         raise ValueError("Cannot send messages to a deleted chat session")
```
```diff
@@ -248,7 +267,9 @@ def handle_new_chat_message(

     @log_generator_function_time()
     def stream_chat_tokens() -> Iterator[str]:
-        tokens = llm_chat_answer(mainline_messages)
+        tokens = llm_chat_answer(
+            messages=mainline_messages, persona=persona, user_id=user_id
+        )
         llm_output = ""
         for token in tokens:
             llm_output += token
```
```diff
@@ -268,7 +289,7 @@ def handle_new_chat_message(

 @router.post("/regenerate-from-parent")
 def regenerate_message_given_parent(
-    parent_message: ChatMessageIdentifier,
+    parent_message: RegenerateMessageRequest,
     user: User | None = Depends(current_user),
     db_session: Session = Depends(get_session),
 ) -> StreamingResponse:
```
```diff
@@ -288,6 +309,11 @@ def regenerate_message_given_parent(
     )

     chat_session = chat_message.chat_session
+    persona = (
+        fetch_persona_by_id(parent_message.persona_id, db_session)
+        if parent_message.persona_id is not None
+        else None
+    )

     if chat_session.deleted:
         raise ValueError("Chat session has been deleted")
```
```diff
@@ -317,7 +343,9 @@ def regenerate_message_given_parent(

     @log_generator_function_time()
     def stream_regenerate_tokens() -> Iterator[str]:
-        tokens = llm_chat_answer(mainline_messages)
+        tokens = llm_chat_answer(
+            messages=mainline_messages, persona=persona, user_id=user_id
+        )
         llm_output = ""
         for token in tokens:
             llm_output += token
```
backend/danswer/server/models.py:

```diff
@@ -134,7 +134,7 @@ class SearchDoc(BaseModel):
     match_highlights: list[str]


-class CreateChatID(BaseModel):
+class CreateChatSessionID(BaseModel):
     chat_session_id: int

```
```diff
@@ -159,11 +159,12 @@ class SearchFeedbackRequest(BaseModel):
     search_feedback: SearchFeedbackType


-class CreateChatRequest(BaseModel):
+class CreateChatMessageRequest(BaseModel):
     chat_session_id: int
     message_number: int
     parent_edit_number: int | None
     message: str
+    persona_id: int | None


 class ChatMessageIdentifier(BaseModel):
```
```diff
@@ -172,6 +173,10 @@ class ChatMessageIdentifier(BaseModel):
     edit_number: int


+class RegenerateMessageRequest(ChatMessageIdentifier):
+    persona_id: int | None
+
+
 class ChatRenameRequest(BaseModel):
     chat_session_id: int
     name: str | None
```
backend/danswer/utils/text_processing.py:

````diff
@@ -1,13 +1,29 @@
+import json
 import re

 import bs4
 from bs4 import BeautifulSoup


+def has_unescaped_quote(s: str) -> bool:
+    pattern = r'(?<!\\)"'
+    return bool(re.search(pattern, s))
+
+
 def escape_newlines(s: str) -> str:
     return re.sub(r"(?<!\\)\n", "\\\\n", s)


+def extract_embedded_json(s: str) -> dict:
+    first_brace_index = s.find("{")
+    last_brace_index = s.rfind("}")
+
+    if first_brace_index == -1 or last_brace_index == -1:
+        raise ValueError("No valid json found")
+
+    return json.loads(s[first_brace_index : last_brace_index + 1])
+
+
 def clean_up_code_blocks(model_out_raw: str) -> str:
     return model_out_raw.strip().strip("```").strip()
````
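The two new helpers are small enough to pin down with examples (not in the commit):

```python
from danswer.utils.text_processing import extract_embedded_json, has_unescaped_quote

has_unescaped_quote('no quotes at all')   # False
has_unescaped_quote('trailing "')         # True  (quote not preceded by a backslash)
has_unescaped_quote('escaped \\" quote')  # False

# Pulls the outermost {...} span out of surrounding text and parses it.
extract_embedded_json('noise {"action": "Final Answer"} more noise')
# -> {'action': 'Final Answer'}
```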
The backend requirements file:

```diff
@@ -46,7 +46,7 @@ safetensors==0.3.1
 sentence-transformers==2.2.2
 slack-sdk==3.20.2
 SQLAlchemy[mypy]==2.0.12
-tensorflow==2.12.0
+tensorflow==2.13.0
 tiktoken==0.4.0
 transformers==4.30.1
 typesense==0.15.1
```
backend/scripts/simulate_chat_frontend.py (new file, @@ -0,0 +1,88 @@):

```python
# This file is purely for development use, not included in any builds
# Use this to test the chat feature with and without context.
# With context refers to being able to call out to Danswer and other tools (currently no other tools)
# Without context refers to only knowing the chat's own history with no additional information
# This script does not allow for branching logic that is supported by the backend APIs
# This script also does not allow for editing/regeneration of user/model messages
# Have Danswer API server running to use this.
import argparse
import json

import requests

from danswer.configs.app_configs import APP_PORT

LOCAL_CHAT_ENDPOINT = f"http://127.0.0.1:{APP_PORT}/chat/"


def create_new_session() -> int:
    response = requests.post(LOCAL_CHAT_ENDPOINT + "create-chat-session")
    response.raise_for_status()
    new_session_id = response.json()["chat_session_id"]
    return new_session_id


def send_chat_message(
    message: str,
    chat_session_id: int,
    message_number: int,
    parent_edit_number: int | None,
    persona_id: int | None,
) -> None:
    data = {
        "message": message,
        "chat_session_id": chat_session_id,
        "message_number": message_number,
        "parent_edit_number": parent_edit_number,
        "persona_id": persona_id,
    }

    with requests.post(
        LOCAL_CHAT_ENDPOINT + "send-message", json=data, stream=True
    ) as r:
        for json_response in r.iter_lines():
            response_text = json.loads(json_response.decode())
            new_token = response_text.get("answer_piece")
            print(new_token, end="", flush=True)
        print()


def run_chat(contextual: bool) -> None:
    try:
        new_session_id = create_new_session()
    except requests.exceptions.ConnectionError:
        print(
            "Looks like you haven't started the Danswer Backend server, please run the FastAPI server"
        )
        exit()

    persona_id = 1 if contextual else None

    message_num = 0
    parent_edit = None
    while True:
        new_message = input(
            "\n\n----------------------------------\n"
            "Please provide a new chat message:\n> "
        )

        send_chat_message(
            new_message, new_session_id, message_num, parent_edit, persona_id
        )

        message_num += 2  # 1 for User message, 1 for AI response
        parent_edit = 0  # Since no edits, the parent message is always the first edit of that message number


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--contextual",
        action="store_true",
        help="If this flag is set, the chat is able to call tools.",
    )
    args = parser.parse_args()

    contextual = args.contextual
    run_chat(contextual)
```
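With the API server up, the script would presumably be run from the backend directory as `python scripts/simulate_chat_frontend.py` for contextless chat, or with `-c`/`--contextual` to use persona 1 and exercise the Current Search tool flow end to end.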