Mirror of https://github.com/danswer-ai/danswer.git
Chat with Context Backend (#441)
parent a16ce56f6b
commit e549d2bb4a
@@ -0,0 +1,40 @@
"""Chat Context Addition

Revision ID: 8e26726b7683
Revises: 5809c0787398
Create Date: 2023-09-13 18:34:31.327944

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "8e26726b7683"
down_revision = "5809c0787398"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "persona",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("system_text", sa.Text(), nullable=True),
        sa.Column("tools_text", sa.Text(), nullable=True),
        sa.Column("hint_text", sa.Text(), nullable=True),
        sa.Column("default_persona", sa.Boolean(), nullable=False),
        sa.Column("deleted", sa.Boolean(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    op.add_column("chat_message", sa.Column("persona_id", sa.Integer(), nullable=True))
    op.create_foreign_key(
        "fk_chat_message_persona_id", "chat_message", "persona", ["persona_id"], ["id"]
    )


def downgrade() -> None:
    op.drop_constraint("fk_chat_message_persona_id", "chat_message", type_="foreignkey")
    op.drop_column("chat_message", "persona_id")
    op.drop_table("persona")
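For local testing, a revision like this can be exercised through Alembic's Python API as well as the usual `alembic upgrade head` CLI. A minimal sketch, assuming the backend's `alembic.ini` (with its DB connection string) is in the working directory:

```python
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumes Danswer's Alembic config is available here
command.upgrade(cfg, "8e26726b7683")      # create the persona table and chat_message FK
# command.downgrade(cfg, "5809c0787398")  # roll back: drop the FK, column, and table
```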
@@ -1,27 +1,155 @@
from collections.abc import Iterator
from uuid import UUID

from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.configs.constants import MessageType
from danswer.chat.chat_prompts import form_tool_followup_text
from danswer.chat.chat_prompts import form_user_prompt_text
from danswer.chat.tools import call_tool
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.llm.build import get_default_llm
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.text_processing import has_unescaped_quote


-def llm_chat_answer(previous_messages: list[ChatMessage]) -> Iterator[str]:
-    prompt: list[BaseMessage] = []
-    for msg in previous_messages:
-        content = msg.message
-        if msg.message_type == MessageType.SYSTEM:
-            prompt.append(SystemMessage(content=content))
-        if msg.message_type == MessageType.ASSISTANT:
-            prompt.append(AIMessage(content=content))
-        if (
-            msg.message_type == MessageType.USER
-            or msg.message_type == MessageType.DANSWER  # consider using FunctionMessage
-        ):
-            prompt.append(HumanMessage(content=content))


def _parse_embedded_json_streamed_response(
    tokens: Iterator[str],
) -> Iterator[DanswerAnswerPiece | DanswerChatModelOut]:
    final_answer = False
    just_start_stream = False
    model_output = ""
    hold = ""
    finding_end = 0
    for token in tokens:
        model_output += token
        hold += token

        if (
            final_answer is False
            and '"action":"finalanswer",' in model_output.lower().replace(" ", "")
        ):
            final_answer = True

        if final_answer and '"actioninput":"' in model_output.lower().replace(
            " ", ""
        ).replace("_", ""):
            if not just_start_stream:
                just_start_stream = True
                hold = ""

            if has_unescaped_quote(hold):
                finding_end += 1
                hold = hold[: hold.find('"')]

            if finding_end <= 1:
                if finding_end == 1:
                    finding_end += 1

                yield DanswerAnswerPiece(answer_piece=hold)
                hold = ""

    model_final = extract_embedded_json(model_output)
    if "action" not in model_final or "action_input" not in model_final:
        raise ValueError("Model did not provide all required action values")

    yield DanswerChatModelOut(
        model_raw=model_output,
        action=model_final["action"],
        action_input=model_final["action_input"],
    )
    return
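To make the streaming logic above easier to follow, here is a small self-contained sketch of the same idea (not the production code) with plain strings in place of the Danswer dataclasses: it waits until the accumulating output contains the `"action_input": "` field of a Final Answer blob, then streams the field's value token by token until the closing unescaped quote.

```python
import re
from collections.abc import Iterator


def stream_action_input(tokens: Iterator[str]) -> Iterator[str]:
    """Toy version of the incremental parse above: stream out the action_input value."""
    model_output = ""
    hold = ""
    streaming = False
    for token in tokens:
        model_output += token
        hold += token
        compact = model_output.lower().replace(" ", "").replace("_", "")
        if not streaming and '"actioninput":"' in compact:
            streaming = True  # the value starts with the next token
            hold = ""
            continue
        if streaming:
            if re.search(r'(?<!\\)"', hold):  # an unescaped quote closes the value
                yield hold[: hold.find('"')]
                return
            yield hold
            hold = ""


fake_tokens = iter(
    ['{"action": "Final Answer", ', '"action_input": "', "Paris is", " in France.", '"}']
)
print("".join(stream_action_input(fake_tokens)))  # -> Paris is in France.
```

The `finding_end` counter in the real function plays the same role as the early `return` here: once the closing quote appears, one final truncated piece is emitted and streaming stops.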
def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]:
    prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages]

    return get_default_llm().stream(prompt)


def llm_contextual_chat_answer(
    messages: list[ChatMessage],
    persona: Persona,
    user_id: UUID | None,
) -> Iterator[str]:
    system_text = persona.system_text
    tool_text = persona.tools_text
    hint_text = persona.hint_text

    last_message = messages[-1]
    previous_messages = messages[:-1]

    user_text = form_user_prompt_text(
        query=last_message.message,
        tool_text=tool_text,
        hint_text=hint_text,
    )

    prompt: list[BaseMessage] = []

    if system_text:
        prompt.append(SystemMessage(content=system_text))

    prompt.extend(
        [translate_danswer_msg_to_langchain(msg) for msg in previous_messages]
    )

    prompt.append(HumanMessage(content=user_text))

    tokens = get_default_llm().stream(prompt)

    final_result: DanswerChatModelOut | None = None
    final_answer_streamed = False
    for result in _parse_embedded_json_streamed_response(tokens):
        if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
            yield result.answer_piece
            final_answer_streamed = True

        if isinstance(result, DanswerChatModelOut):
            final_result = result
            break

    if final_answer_streamed:
        return

    if final_result is None:
        raise RuntimeError("Model output finished without final output parsing.")

    tool_result_str = call_tool(final_result, user_id=user_id)

    prompt.append(AIMessage(content=final_result.model_raw))
    prompt.append(
        HumanMessage(
            content=form_tool_followup_text(tool_result_str, hint_text=hint_text)
        )
    )

    tokens = get_default_llm().stream(prompt)

    for result in _parse_embedded_json_streamed_response(tokens):
        if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
            yield result.answer_piece
            final_answer_streamed = True

    if final_answer_streamed is False:
        raise RuntimeError("LLM failed to produce a Final Answer")


def llm_chat_answer(
    messages: list[ChatMessage], persona: Persona | None, user_id: UUID | None
) -> Iterator[str]:
    # TODO how to handle the model giving gibberish or failing on a particular message
    # TODO how to handle the model failing to choose the right tool
    # TODO how to handle the model giving the wrong format
    if persona is None:
        return llm_contextless_chat_answer(messages)

    return llm_contextual_chat_answer(
        messages=messages, persona=persona, user_id=user_id
    )
backend/danswer/chat/chat_prompts.py (new file, 125 lines)
@@ -0,0 +1,125 @@
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import CODE_BLOCK_PAT


TOOL_TEMPLATE = """TOOLS
------
Assistant can ask the user to use tools to look up information that may be helpful in answering the users \
original question. The tools the human can use are:

{}

RESPONSE FORMAT INSTRUCTIONS
----------------------------

When responding to me, please output a response in one of two formats:

**Option 1:**
Use this if you want the human to use a tool.
Markdown code snippet formatted in the following schema:

```json
{{
    "action": string, \\ The action to take. Must be one of {}
    "action_input": string \\ The input to the action
}}
```

**Option #2:**
Use this if you want to respond directly to the human. Markdown code snippet formatted in the following schema:

```json
{{
    "action": "Final Answer",
    "action_input": string \\ You should put what you want to return to use here
}}
```
"""

TOOL_LESS_PROMPT = """
Respond with a markdown code snippet in the following schema:

```json
{{
    "action": "Final Answer",
    "action_input": string \\ You should put what you want to return to use here
}}
```
"""

USER_INPUT = """
USER'S INPUT
--------------------
Here is the user's input \
(remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):

{}
"""

TOOL_FOLLOWUP = """
TOOL RESPONSE:
---------------------
{}

USER'S INPUT
--------------------
Okay, so what is the response to my last comment? If using information obtained from the tools you must \
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
{}
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
"""


def form_user_prompt_text(
    query: str,
    tool_text: str | None,
    hint_text: str | None,
    user_input_prompt: str = USER_INPUT,
    tool_less_prompt: str = TOOL_LESS_PROMPT,
) -> str:
    user_prompt = tool_text or tool_less_prompt

    user_prompt += user_input_prompt.format(query)

    if hint_text:
        if user_prompt[-1] != "\n":
            user_prompt += "\n"
        user_prompt += "Hint: " + hint_text

    return user_prompt


def form_tool_section_text(
    tools: list[dict[str, str]], template: str = TOOL_TEMPLATE
) -> str | None:
    if not tools:
        return None

    tools_intro = []
    for tool in tools:
        description_formatted = tool["description"].replace("\n", " ")
        tools_intro.append(f"> {tool['name']}: {description_formatted}")

    tools_intro_text = "\n".join(tools_intro)
    tool_names_text = ", ".join([tool["name"] for tool in tools])

    return template.format(tools_intro_text, tool_names_text)


def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
    return "\n".join(
        f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(chunk.content)}"
        for ind, chunk in enumerate(chunks, start=1)
    )


def form_tool_followup_text(
    tool_output: str,
    hint_text: str | None,
    tool_followup_prompt: str = TOOL_FOLLOWUP,
    ignore_hint: bool = False,
) -> str:
    if not ignore_hint and hint_text:
        hint_text_spaced = f"\n{hint_text}\n"
        return tool_followup_prompt.format(tool_output, hint_text_spaced)

    return tool_followup_prompt.format(tool_output, "")
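To illustrate how these templates compose, a short sketch (the tool entry here is a made-up example shaped like the ones in `personas.yaml`): `form_tool_section_text` fills `TOOL_TEMPLATE` with the tool names and descriptions, and `form_user_prompt_text` appends the user's query and optional hint.

```python
from danswer.chat.chat_prompts import form_tool_section_text
from danswer.chat.chat_prompts import form_user_prompt_text

# Hypothetical tool definition, for illustration only
tools = [{"name": "Current Search", "description": "A search for up to date information."}]

tool_text = form_tool_section_text(tools)  # TOOLS section plus response format instructions
prompt_text = form_user_prompt_text(
    query="Where are the Q3 planning docs?",
    tool_text=tool_text,
    hint_text="Use the Current Search tool!",
)
print(prompt_text)  # TOOLS block, then USER'S INPUT block, then "Hint: ..."
```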
backend/danswer/chat/personas.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import yaml
from sqlalchemy.orm import Session

from danswer.chat.chat_prompts import form_tool_section_text
from danswer.configs.app_configs import PERSONAS_YAML
from danswer.db.chat import create_persona
from danswer.db.engine import get_sqlalchemy_engine


def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
    with open(personas_yaml, "r") as file:
        data = yaml.safe_load(file)

    all_personas = data.get("personas", [])
    with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
        for persona in all_personas:
            tools = form_tool_section_text(persona["tools"])
            create_persona(
                persona_id=persona["id"],
                name=persona["name"],
                system_text=persona["system"],
                tools_text=tools,
                hint_text=persona["hint"],
                default_persona=True,
                db_session=db_session,
            )
backend/danswer/chat/personas.yaml (new file, 14 lines)
@@ -0,0 +1,14 @@
personas:
  - id: 1
    name: "Danswer"
    system: |
      You are a question answering system that is constantly learning and improving.
      You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
      You have access to a search tool and should always use it anytime it might be useful.
      Your responses are as INFORMATIVE and DETAILED as possible.
    tools:
      - name: "Current Search"
        description: |
          A search for up to date and proprietary information.
          This tool can handle natural language questions so it is better to ask very precise questions over keywords.
    hint: "Use the Current Search tool to help find information on any topic!"
backend/danswer/chat/tools.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from uuid import UUID

from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.semantic_search import retrieve_ranked_documents


def call_tool(
    model_actions: DanswerChatModelOut,
    user_id: UUID | None,
) -> str:
    if model_actions.action.lower() == "current search":
        query = model_actions.action_input

        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
            query,
            user_id=user_id,
            filters=None,
            datastore=get_default_document_index(),
        )
        if not ranked_chunks:
            return "No results found"

        if unranked_chunks:
            ranked_chunks.extend(unranked_chunks)

        filtered_ranked_chunks = [
            chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
        ]

        # get all chunks that fit into the token limit
        usable_chunks = get_usable_chunks(
            chunks=filtered_ranked_chunks,
            token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
        )

        return format_danswer_chunks_for_chat(usable_chunks)

    raise ValueError("Invalid tool choice by LLM")
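A sketch of the dispatcher in use, with an illustrative `DanswerChatModelOut` of the kind the streamed-JSON parser yields (actually running it assumes a reachable document index); any action other than "Current Search" raises.

```python
from danswer.chat.tools import call_tool
from danswer.direct_qa.interfaces import DanswerChatModelOut

model_out = DanswerChatModelOut(
    model_raw='{"action": "Current Search", "action_input": "vacation policy"}',
    action="Current Search",
    action_input="vacation policy",
)
# Retrieves, filters, and token-truncates chunks, then formats them for the followup prompt
tool_result = call_tool(model_out, user_id=None)
```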
@@ -144,9 +144,12 @@ NUM_RERANKED_RESULTS = 15
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
    os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
)
+NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
+    os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (512 * 3)
+)
# 1 edit per 2 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
-QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "10")  # 10 seconds
+QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60")  # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
@@ -182,6 +185,7 @@ CROSS_ENCODER_PORT = 9000
#####
# Miscellaneous
#####
+PERSONAS_YAML = "./danswer/chat/personas.yaml"
DYNAMIC_CONFIG_STORE = os.environ.get(
    "DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
)
@@ -25,6 +25,20 @@ BOOST = "boost"
SCORE = "score"
DEFAULT_BOOST = 0

+# Prompt building constants:
+GENERAL_SEP_PAT = "\n-----\n"
+CODE_BLOCK_PAT = "\n```\n{}\n```\n"
+DOC_SEP_PAT = "---NEW DOCUMENT---"
+DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
+QUESTION_PAT = "Query:"
+THOUGHT_PAT = "Thought:"
+ANSWER_PAT = "Answer:"
+FINAL_ANSWER_PAT = "Final Answer:"
+UNCERTAINTY_PAT = "?"
+QUOTE_PAT = "Quote:"
+QUOTES_PAT_PLURAL = "Quotes:"
+INVALID_PAT = "Invalid:"


class DocumentSource(str, Enum):
    SLACK = "slack"
@@ -12,6 +12,7 @@ from danswer.configs.app_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
+from danswer.db.models import Persona


def fetch_chat_sessions_by_user(
@@ -245,3 +246,47 @@ def set_latest_chat_message(
    )

    db_session.commit()
+
+
+def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
+    stmt = select(Persona).where(Persona.id == persona_id)
+    result = db_session.execute(stmt)
+    persona = result.scalar_one_or_none()
+
+    if persona is None:
+        raise ValueError(f"Persona with ID {persona_id} does not exist")
+
+    return persona
+
+
+def create_persona(
+    persona_id: int | None,
+    name: str,
+    system_text: str | None,
+    tools_text: str | None,
+    hint_text: str | None,
+    default_persona: bool,
+    db_session: Session,
+) -> Persona:
+    persona = db_session.query(Persona).filter_by(id=persona_id).first()
+
+    if persona:
+        persona.name = name
+        persona.system_text = system_text
+        persona.tools_text = tools_text
+        persona.hint_text = hint_text
+        persona.default_persona = default_persona
+    else:
+        persona = Persona(
+            id=persona_id,
+            name=name,
+            system_text=system_text,
+            tools_text=tools_text,
+            hint_text=hint_text,
+            default_persona=default_persona,
+        )
+        db_session.add(persona)
+
+    db_session.commit()
+
+    return persona
@@ -333,6 +333,7 @@ class ChatSession(Base):
    user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
    description: Mapped[str] = mapped_column(Text)
    deleted: Mapped[bool] = mapped_column(Boolean, default=False)
+    # The following texts help build up the model's ability to use the context effectively
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
@@ -348,6 +349,19 @@ class ChatSession(Base):
    )


+class Persona(Base):
+    # TODO introduce user and group ownership for personas
+    __tablename__ = "persona"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str] = mapped_column(String)
+    system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
+    tools_text: Mapped[str | None] = mapped_column(Text, nullable=True)
+    hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
+    default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
+    # If it's updated and no longer latest (should no longer be shown), it is also considered deleted
+    deleted: Mapped[bool] = mapped_column(Boolean, default=False)


class ChatMessage(Base):
    __tablename__ = "chat_message"

@@ -362,8 +376,12 @@ class ChatMessage(Base):
    latest: Mapped[bool] = mapped_column(Boolean, default=True)
    message: Mapped[str] = mapped_column(Text)
    message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
+    persona_id: Mapped[int | None] = mapped_column(
+        ForeignKey("persona.id"), nullable=True
+    )
    time_sent: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    chat_session: Mapped[ChatSession] = relationship("ChatSession")
+    persona: Mapped[Persona | None] = relationship("Persona")
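Reading the new relationship back looks like this (a sketch, assuming a reachable database and at least one stored message with a persona set): the message's persona resolves through the new foreign key.

```python
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import ChatMessage

with Session(get_sqlalchemy_engine()) as db_session:
    msg = db_session.scalars(select(ChatMessage).limit(1)).first()
    if msg is not None and msg.persona is not None:
        print(msg.persona.name, msg.persona.hint_text)
```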
@@ -10,6 +10,13 @@ class DanswerAnswer:
    answer: str | None


+@dataclass
+class DanswerChatModelOut:
+    model_raw: str
+    action: str
+    action_input: str
+
+
@dataclass
class DanswerAnswerPiece:
    """A small piece of a complete answer. Used for streaming back answers."""
@@ -10,19 +10,18 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.chunking.models import InferenceChunk
+from danswer.configs.constants import CODE_BLOCK_PAT
+from danswer.configs.constants import GENERAL_SEP_PAT
+from danswer.configs.constants import QUESTION_PAT
+from danswer.configs.constants import THOUGHT_PAT
+from danswer.configs.constants import UNCERTAINTY_PAT
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
+from danswer.direct_qa.qa_prompts import EMPTY_SAMPLE_JSON
-from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
from danswer.direct_qa.qa_prompts import JsonChatProcessor
-from danswer.direct_qa.qa_prompts import QUESTION_PAT
-from danswer.direct_qa.qa_prompts import SAMPLE_JSON_RESPONSE
-from danswer.direct_qa.qa_prompts import THOUGHT_PAT
-from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
@@ -193,7 +192,7 @@ class JsonChatQAUnshackledHandler(QAHandler):
                    "should be in JSON format and contain an answer and (optionally) quotes that help support the answer. "
                    "Your responses should be informative, detailed, and consider all possibilities and edge cases. "
                    f"If you don't know the answer, respond with '{complete_answer_not_found_response}'\n"
-                    f"Sample response:\n\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
+                    f"Sample response:\n\n{json.dumps(EMPTY_SAMPLE_JSON)}"
                )
            )
        )
@@ -2,23 +2,17 @@ import abc
import json

from danswer.chunking.models import InferenceChunk
+from danswer.configs.constants import ANSWER_PAT
+from danswer.configs.constants import DOC_CONTENT_START_PAT
+from danswer.configs.constants import DOC_SEP_PAT
from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import GENERAL_SEP_PAT
+from danswer.configs.constants import QUESTION_PAT
+from danswer.configs.constants import QUOTE_PAT
+from danswer.configs.constants import UNCERTAINTY_PAT
from danswer.connectors.factory import identify_connector_class


-GENERAL_SEP_PAT = "\n-----\n"
-CODE_BLOCK_PAT = "\n```\n{}\n```\n"
-DOC_SEP_PAT = "---NEW DOCUMENT---"
-DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
-QUESTION_PAT = "Query:"
-THOUGHT_PAT = "Thought:"
-ANSWER_PAT = "Answer:"
-FINAL_ANSWER_PAT = "Final Answer:"
-UNCERTAINTY_PAT = "?"
-QUOTE_PAT = "Quote:"
-QUOTES_PAT_PLURAL = "Quotes:"
-INVALID_PAT = "Invalid:"

BASE_PROMPT = (
    "Answer the query based on provided documents and quote relevant sections. "
    "Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
@@ -26,16 +20,6 @@ BASE_PROMPT = (
    "The quotes must be EXACT substrings from the documents."
)

-SAMPLE_QUESTION = "Where is the Eiffel Tower?"
-
-SAMPLE_JSON_RESPONSE = {
-    "answer": "The Eiffel Tower is located in Paris, France.",
-    "quotes": [
-        "The Eiffel Tower is an iconic symbol of Paris",
-        "located on the Champ de Mars in France.",
-    ],
-}
-
EMPTY_SAMPLE_JSON = {
    "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
    "quotes": [
@@ -44,16 +28,6 @@ EMPTY_SAMPLE_JSON = {
    ],
}

ANSWER_NOT_FOUND_JSON = '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
-
-SAMPLE_RESPONSE_COT = (
-    "Let's think step by step. The user is asking for the "
-    "location of the Eiffel Tower. The first document describes the Eiffel Tower "
-    "as being an iconic symbol of Paris and that it is located on the Champ de Mars. "
-    "Since the Champ de Mars is in Paris, we know that the Eiffel Tower is in Paris."
-    f"\n\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
-)


def _append_acknowledge_doc_messages(
    current_messages: list[dict[str, str]], new_chunk_content: str
@@ -152,7 +126,7 @@ class JsonProcessor(NonChatPromptProcessor):
        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
    ) -> str:
        prompt = (
-            BASE_PROMPT + f" Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
+            BASE_PROMPT + f" Sample response:\n{json.dumps(EMPTY_SAMPLE_JSON)}\n\n"
            f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
        )
@@ -203,7 +177,7 @@ class JsonChatProcessor(ChatPromptProcessor):
            f"{complete_answer_not_found_response}\n"
            "If the query requires aggregating the number of documents, respond with "
            '{"answer": "Aggregations not supported", "quotes": []}\n'
-            f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
+            f"Sample response:\n{json.dumps(EMPTY_SAMPLE_JSON)}"
        )
        messages = [{"role": "system", "content": intro_msg}]
        for chunk in chunks:
@@ -221,65 +195,6 @@ class JsonChatProcessor(ChatPromptProcessor):
        return messages


-class JsonCoTChatProcessor(ChatPromptProcessor):
-    """Pros: improves performance slightly over the regular JsonChatProcessor.
-    Cons: Much slower.
-    """
-
-    @property
-    def specifies_json_output(self) -> bool:
-        return True
-
-    @staticmethod
-    def fill_prompt(
-        question: str,
-        chunks: list[InferenceChunk],
-        include_metadata: bool = True,
-    ) -> list[dict[str, str]]:
-        metadata_prompt_section = (
-            "with metadata and contents " if include_metadata else ""
-        )
-        intro_msg = (
-            f"You are a Question Answering assistant that answers queries "
-            f"based on the provided documents.\n"
-            f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
-        )
-
-        complete_answer_not_found_response = (
-            '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
-        )
-        task_msg = (
-            "Now answer the user query based on documents above and quote relevant sections.\n"
-            "When answering, you should think step by step, and verbalize your thought process.\n"
-            "Then respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
-            "All quotes MUST be EXACT substrings from provided documents.\n"
-            "Your responses should be informative, detailed, and consider all possibilities and edge cases.\n"
-            "You MUST prioritize information from provided documents over internal knowledge.\n"
-            "If the query cannot be answered based on the documents, respond with "
-            f"{complete_answer_not_found_response}\n"
-            "If the query requires aggregating the number of documents, respond with "
-            '{"answer": "Aggregations not supported", "quotes": []}\n'
-            f"Sample response:\n\n{SAMPLE_RESPONSE_COT}"
-        )
-        messages = [{"role": "system", "content": intro_msg}]
-
-        for chunk in chunks:
-            full_context = ""
-            if include_metadata:
-                full_context = _add_metadata_section(
-                    full_context, chunk, prepend_tab=False, include_sep=False
-                )
-            full_context += chunk.content
-            messages = _append_acknowledge_doc_messages(messages, full_context)
-        messages.append({"role": "user", "content": task_msg})
-
-        messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n\n"})
-
-        messages.append({"role": "user", "content": "Let's think step by step."})
-
-        return messages


class WeakModelFreeformProcessor(NonChatPromptProcessor):
    """Avoid using this one if the model is capable of using another prompt
    Intended for models that can't follow complex instructions or have short context windows
@@ -366,117 +281,3 @@ class FreeformProcessor(NonChatPromptProcessor):
        prompt += f"{QUESTION_PAT}\n{question}\n"
        prompt += f"{ANSWER_PAT}\n"
        return prompt
-
-
-class FreeformChatProcessor(ChatPromptProcessor):
-    @property
-    def specifies_json_output(self) -> bool:
-        return False
-
-    @staticmethod
-    def fill_prompt(
-        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
-    ) -> list[dict[str, str]]:
-        sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
-        role_msg = (
-            f"You are a Question Answering assistant that answers queries based on provided documents. "
-            f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
-            f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
-            f"Answer the question directly and concisely. "
-            f"Each quote should be a single continuous segment from a document. "
-            f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
-            f"An example of quote sections may look like:\n{sample_quote}"
-        )
-
-        messages = [
-            {"role": "system", "content": role_msg},
-        ]
-        for chunk in chunks:
-            messages = _append_acknowledge_doc_messages(messages, chunk.content)
-
-        messages.append(
-            {
-                "role": "user",
-                "content": f"Please now answer the following query based on the previously provided "
-                f"documents and quote the relevant sections of the documents\n{question}",
-            },
-        )
-
-        return messages
-
-
-class JsonCOTProcessor(NonChatPromptProcessor):
-    """Chain of Thought allows model to explain out its reasoning to handle harder tests.
-    This prompt type works however has higher token cost (more expensive) and is slower.
-    Consider this one if users ask questions that require logical reasoning."""
-
-    @property
-    def specifies_json_output(self) -> bool:
-        return True
-
-    @staticmethod
-    def fill_prompt(
-        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
-    ) -> str:
-        prompt = (
-            f"Answer the query based on provided documents and quote relevant sections. "
-            f'Respond with a freeform reasoning section followed by "Final Answer:" with a '
-            f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n"
-            f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
-            f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
-        )
-
-        for chunk in chunks:
-            prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}"
-
-        prompt += "\n\n---\n\n"
-        prompt += f"{QUESTION_PAT}\n{question}\n"
-        prompt += "Reasoning:\n"
-        return prompt
-
-
-class JsonReflexionProcessor(NonChatPromptProcessor):
-    """Reflexion prompting to attempt to have model evaluate its own answer.
-    This one seems largely useless when only given a single example
-    Model seems to take the one example of answering "Yes" and just does that too."""
-
-    @property
-    def specifies_json_output(self) -> bool:
-        return True
-
-    @staticmethod
-    def fill_prompt(
-        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
-    ) -> str:
-        reflexion_str = "Does this fully answer the user query?"
-        prompt = (
-            BASE_PROMPT
-            + f'After each generated json, ask "{reflexion_str}" and respond Yes or No. '
-            f"If No, generate a better json response to the query.\n"
-            f"Sample question and response:\n"
-            f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n"
-            f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n"
-            f"{reflexion_str} Yes\n\n"
-            f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
-        )
-
-        for chunk in chunks:
-            prompt += f"\n---NEW CONTEXT DOCUMENT---\n{chunk.content}"
-
-        prompt += "\n\n---\n\n"
-        prompt += f"{QUESTION_PAT}\n{question}\n"
-        return prompt
-
-
-def get_json_chat_reflexion_msg() -> dict[str, str]:
-    """With the models tried (current as of Jul 2023), this has not been very useful.
-    Have not seen any answers improved based on this.
-    For models like gpt-3.5-turbo, it will often answer something like:
-    'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source
-    document. No revision is needed.'"""
-    reflexion_content = (
-        "Is the assistant response a valid json that fully answer the user query? "
-        "If the response needs to be fixed or if an improvement is possible, provide a revised json. "
-        "Otherwise, respond with the same json."
-    )
-    return {"role": "system", "content": reflexion_content}
@@ -13,6 +13,25 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.configs.app_configs import LOG_LEVEL
+from danswer.configs.constants import MessageType
+from danswer.db.models import ChatMessage
+
+
+def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage:
+    if (
+        msg.message_type == MessageType.SYSTEM
+        or msg.message_type == MessageType.DANSWER
+    ):
+        # TODO save at least the Danswer responses to postgres
+        raise ValueError(
+            "System and Danswer messages are not currently part of history"
+        )
+    if msg.message_type == MessageType.ASSISTANT:
+        return AIMessage(content=msg.message)
+    if msg.message_type == MessageType.USER:
+        return HumanMessage(content=msg.message)
+
+    raise ValueError(f"New message type {msg.message_type} not handled")


def dict_based_prompt_to_langchain_prompt(
@@ -12,6 +12,7 @@ from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.auth.users import oauth_client
+from danswer.chat.personas import load_personas_from_yaml
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import DISABLE_AUTH
@@ -191,6 +192,9 @@ def get_application() -> FastAPI:
        logger.info("Verifying public credential exists.")
        create_initial_public_credential()

+        logger.info("Loading default Chat Personas")
+        load_personas_from_yaml()
+
        logger.info("Verifying Document Index(s) is/are available.")
        get_default_document_index().ensure_indices_exist()
@@ -1,10 +1,10 @@
+from danswer.configs.constants import ANSWER_PAT
+from danswer.configs.constants import CODE_BLOCK_PAT
+from danswer.configs.constants import GENERAL_SEP_PAT
+from danswer.configs.constants import INVALID_PAT
+from danswer.configs.constants import QUESTION_PAT
+from danswer.configs.constants import THOUGHT_PAT
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
-from danswer.direct_qa.qa_prompts import ANSWER_PAT
-from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
-from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
-from danswer.direct_qa.qa_prompts import INVALID_PAT
-from danswer.direct_qa.qa_prompts import QUESTION_PAT
-from danswer.direct_qa.qa_prompts import THOUGHT_PAT
from danswer.llm.build import get_default_llm
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -2,10 +2,10 @@ import re
from collections.abc import Iterator
from dataclasses import asdict

+from danswer.configs.constants import CODE_BLOCK_PAT
+from danswer.configs.constants import GENERAL_SEP_PAT
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
-from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
-from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
from danswer.llm.build import get_default_llm
from danswer.server.models import QueryValidationResponse
from danswer.server.utils import get_json_line
@@ -16,6 +16,7 @@ from danswer.db.chat import fetch_chat_message
from danswer.db.chat import fetch_chat_messages_by_session
from danswer.db.chat import fetch_chat_session_by_id
from danswer.db.chat import fetch_chat_sessions_by_user
+from danswer.db.chat import fetch_persona_by_id
from danswer.db.chat import set_latest_chat_message
from danswer.db.chat import update_chat_session
from danswer.db.chat import verify_parent_exists
@@ -29,8 +30,9 @@ from danswer.server.models import ChatMessageIdentifier
from danswer.server.models import ChatRenameRequest
from danswer.server.models import ChatSessionDetailResponse
from danswer.server.models import ChatSessionIdsResponse
-from danswer.server.models import CreateChatID
-from danswer.server.models import CreateChatRequest
+from danswer.server.models import CreateChatMessageRequest
+from danswer.server.models import CreateChatSessionID
+from danswer.server.models import RegenerateMessageRequest
from danswer.server.models import RenameChatSessionResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
@@ -108,14 +110,16 @@ def get_chat_session_messages(
def create_new_chat_session(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
-) -> CreateChatID:
+) -> CreateChatSessionID:
    user_id = user.id if user is not None else None

-    new_chat_session = create_chat_session(
-        "", user_id, db_session  # Leave the naming till later to prevent delay
-    )
+    new_chat_session = create_chat_session(
+        "",
+        user_id,
+        db_session,  # Leave the naming till later to prevent delay
+    )

-    return CreateChatID(chat_session_id=new_chat_session.id)
+    return CreateChatSessionID(chat_session_id=new_chat_session.id)


@router.put("/rename-chat-session")
@@ -182,9 +186,19 @@ def _create_chat_chain(
    return mainline_messages


+def _return_one_if_any(str_1: str | None, str_2: str | None) -> str | None:
+    if str_1 is not None and str_2 is not None:
+        raise ValueError("Conflicting values, can only set one")
+    if str_1 is not None:
+        return str_1
+    if str_2 is not None:
+        return str_2
+    return None
+
+
@router.post("/send-message")
def handle_new_chat_message(
-    chat_message: CreateChatRequest,
+    chat_message: CreateChatMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
@@ -198,6 +212,11 @@ def handle_new_chat_message(
    user_id = user.id if user is not None else None

    chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
+    persona = (
+        fetch_persona_by_id(chat_message.persona_id, db_session)
+        if chat_message.persona_id is not None
+        else None
+    )

    if chat_session.deleted:
        raise ValueError("Cannot send messages to a deleted chat session")
@@ -248,7 +267,9 @@ def handle_new_chat_message(
    @log_generator_function_time()
    def stream_chat_tokens() -> Iterator[str]:
-        tokens = llm_chat_answer(mainline_messages)
+        tokens = llm_chat_answer(
+            messages=mainline_messages, persona=persona, user_id=user_id
+        )
        llm_output = ""
        for token in tokens:
            llm_output += token
@@ -268,7 +289,7 @@ def handle_new_chat_message(
@router.post("/regenerate-from-parent")
def regenerate_message_given_parent(
-    parent_message: ChatMessageIdentifier,
+    parent_message: RegenerateMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
@@ -288,6 +309,11 @@ def regenerate_message_given_parent(
    )

    chat_session = chat_message.chat_session
+    persona = (
+        fetch_persona_by_id(parent_message.persona_id, db_session)
+        if parent_message.persona_id is not None
+        else None
+    )

    if chat_session.deleted:
        raise ValueError("Chat session has been deleted")
@@ -317,7 +343,9 @@ def regenerate_message_given_parent(
    @log_generator_function_time()
    def stream_regenerate_tokens() -> Iterator[str]:
-        tokens = llm_chat_answer(mainline_messages)
+        tokens = llm_chat_answer(
+            messages=mainline_messages, persona=persona, user_id=user_id
+        )
        llm_output = ""
        for token in tokens:
            llm_output += token
@@ -134,7 +134,7 @@ class SearchDoc(BaseModel):
    match_highlights: list[str]


-class CreateChatID(BaseModel):
+class CreateChatSessionID(BaseModel):
    chat_session_id: int

@@ -159,11 +159,12 @@ class SearchFeedbackRequest(BaseModel):
    search_feedback: SearchFeedbackType


-class CreateChatRequest(BaseModel):
+class CreateChatMessageRequest(BaseModel):
    chat_session_id: int
    message_number: int
    parent_edit_number: int | None
    message: str
+    persona_id: int | None


class ChatMessageIdentifier(BaseModel):
@@ -172,6 +173,10 @@ class ChatMessageIdentifier(BaseModel):
    edit_number: int


+class RegenerateMessageRequest(ChatMessageIdentifier):
+    persona_id: int | None
+
+
class ChatRenameRequest(BaseModel):
    chat_session_id: int
    name: str | None
@@ -1,13 +1,29 @@
+import json
import re

import bs4
from bs4 import BeautifulSoup


+def has_unescaped_quote(s: str) -> bool:
+    pattern = r'(?<!\\)"'
+    return bool(re.search(pattern, s))
+
+
def escape_newlines(s: str) -> str:
    return re.sub(r"(?<!\\)\n", "\\\\n", s)


+def extract_embedded_json(s: str) -> dict:
+    first_brace_index = s.find("{")
+    last_brace_index = s.rfind("}")
+
+    if first_brace_index == -1 or last_brace_index == -1:
+        raise ValueError("No valid json found")
+
+    return json.loads(s[first_brace_index : last_brace_index + 1])
+
+
def clean_up_code_blocks(model_out_raw: str) -> str:
    return model_out_raw.strip().strip("```").strip()
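A quick demonstration of the two new helpers (run from the backend directory so `danswer` is importable):

```python
from danswer.utils.text_processing import extract_embedded_json, has_unescaped_quote

print(has_unescaped_quote("no closing quote yet"))           # False
print(has_unescaped_quote('ends the string" here'))          # True
print(has_unescaped_quote('an escaped \\" does not count'))  # False

raw = 'Some preamble {"action": "Final Answer", "action_input": "Paris"} trailing text'
print(extract_embedded_json(raw))  # {'action': 'Final Answer', 'action_input': 'Paris'}
```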
@@ -46,7 +46,7 @@ safetensors==0.3.1
sentence-transformers==2.2.2
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.12
-tensorflow==2.12.0
+tensorflow==2.13.0
tiktoken==0.4.0
transformers==4.30.1
typesense==0.15.1
backend/scripts/simulate_chat_frontend.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# This file is purely for development use, not included in any builds
# Use this to test the chat feature with and without context.
# With context refers to being able to call out to Danswer and other tools (currently no other tools)
# Without context refers to only knowing the chat's own history with no additional information
# This script does not allow for the branching logic that is supported by the backend APIs
# This script also does not allow for editing/regeneration of user/model messages
# Have the Danswer API server running to use this.
import argparse
import json

import requests

from danswer.configs.app_configs import APP_PORT

LOCAL_CHAT_ENDPOINT = f"http://127.0.0.1:{APP_PORT}/chat/"


def create_new_session() -> int:
    response = requests.post(LOCAL_CHAT_ENDPOINT + "create-chat-session")
    response.raise_for_status()
    new_session_id = response.json()["chat_session_id"]
    return new_session_id


def send_chat_message(
    message: str,
    chat_session_id: int,
    message_number: int,
    parent_edit_number: int | None,
    persona_id: int | None,
) -> None:
    data = {
        "message": message,
        "chat_session_id": chat_session_id,
        "message_number": message_number,
        "parent_edit_number": parent_edit_number,
        "persona_id": persona_id,
    }

    with requests.post(
        LOCAL_CHAT_ENDPOINT + "send-message", json=data, stream=True
    ) as r:
        for json_response in r.iter_lines():
            response_text = json.loads(json_response.decode())
            new_token = response_text.get("answer_piece")
            print(new_token, end="", flush=True)
        print()


def run_chat(contextual: bool) -> None:
    try:
        new_session_id = create_new_session()
    except requests.exceptions.ConnectionError:
        print(
            "Looks like you haven't started the Danswer Backend server, please run the FastAPI server"
        )
        exit()

    persona_id = 1 if contextual else None

    message_num = 0
    parent_edit = None
    while True:
        new_message = input(
            "\n\n----------------------------------\n"
            "Please provide a new chat message:\n> "
        )

        send_chat_message(
            new_message, new_session_id, message_num, parent_edit, persona_id
        )

        message_num += 2  # 1 for User message, 1 for AI response
        parent_edit = 0  # Since no edits, the parent message is always the first edit of that message number


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--contextual",
        action="store_true",
        help="If this flag is set, the chat is able to call tools.",
    )
    args = parser.parse_args()

    contextual = args.contextual
    run_chat(contextual)