Chat Backend API edge cases handled (#472)

This commit is contained in:
Yuhong Sun
2023-09-21 20:24:47 -07:00
committed by GitHub
parent b416c85f0f
commit 5cc17d39f0
10 changed files with 269 additions and 85 deletions

View File

@@ -0,0 +1,26 @@
"""Count Chat Tokens
Revision ID: 767f1c2a00eb
Revises: dba7f71618f5
Create Date: 2023-09-21 10:03:21.509899
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "767f1c2a00eb"
down_revision = "dba7f71618f5"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message", sa.Column("token_count", sa.Integer(), nullable=False)
)
def downgrade() -> None:
op.drop_column("chat_message", "token_count")
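For reference, a minimal sketch (not part of this commit) of applying and rolling back this revision through Alembic's Python API; the usual "alembic upgrade head" / "alembic downgrade dba7f71618f5" CLI calls are equivalent, and the alembic.ini path is an assumption:

from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")  # path to the project's Alembic config (assumed)
command.upgrade(alembic_cfg, "767f1c2a00eb")    # adds chat_message.token_count
command.downgrade(alembic_cfg, "dba7f71618f5")  # drops it again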

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
from collections.abc import Iterator
from uuid import UUID
@@ -14,6 +15,7 @@ from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.chat.tools import call_tool
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
from danswer.datastores.document_index import get_default_document_index
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
@@ -22,6 +24,7 @@ from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.llm.build import get_default_llm
from danswer.llm.llm import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.utils.logger import setup_logger
@@ -31,6 +34,9 @@ from danswer.utils.text_processing import has_unescaped_quote
logger = setup_logger()
LLM_CHAT_FAILURE_MSG = "The large-language-model failed to generate a valid response."
def _parse_embedded_json_streamed_response(
tokens: Iterator[str],
) -> Iterator[DanswerAnswerPiece | DanswerChatModelOut]:
@@ -81,6 +87,24 @@ def _parse_embedded_json_streamed_response(
return
def _find_last_index(
lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
) -> int:
"""From the back, find the index of the last element to include
before the list exceeds the maximum"""
running_sum = 0
last_ind = 0
for i in range(len(lst) - 1, -1, -1):
running_sum += lst[i]
if running_sum > max_prompt_tokens:
last_ind = i + 1
break
if last_ind >= len(lst):
raise ValueError("Last message alone is too large!")
return last_ind
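# Editor's illustration (not part of the diff): how _find_last_index behaves with
# made-up token counts and a 450-token budget. Walking from the back, 300 + 100 = 400
# still fits, but adding 200 would overflow, so the earliest element kept is index 2.
example_token_counts = [50, 200, 100, 300]
assert _find_last_index(example_token_counts, max_prompt_tokens=450) == 2
# The elements kept are example_token_counts[2:] -> [100, 300]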
def danswer_chat_retrieval(
query_message: ChatMessage,
history: list[ChatMessage],
@@ -119,16 +143,78 @@ def danswer_chat_retrieval(
return format_danswer_chunks_for_chat(usable_chunks)
def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]:
prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages]
return get_default_llm().stream(prompt)
def _drop_messages_history_overflow(
system_msg: BaseMessage | None,
system_token_count: int,
history_msgs: list[BaseMessage],
history_token_counts: list[int],
final_msg: BaseMessage,
final_msg_token_count: int,
) -> list[BaseMessage]:
"""As message history grows, messages need to be dropped starting from the furthest in the past.
The System message should be kept if at all possible, and the latest user input, which is inserted
into the prompt template, must be included."""
if len(history_msgs) != len(history_token_counts):
# This should never happen
raise ValueError("Need exactly 1 token count per message for tracking overflow")
prompt: list[BaseMessage] = []
# Start dropping from the history if necessary
all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
ind_prev_msg_start = _find_last_index(all_tokens)
if system_msg and ind_prev_msg_start <= len(history_msgs):
prompt.append(system_msg)
prompt.extend(history_msgs[ind_prev_msg_start:])
prompt.append(final_msg)
return prompt
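# Editor's illustration (not part of the diff): trimming a conversation with
# _drop_messages_history_overflow, assuming the default 3000-token GEN_AI_MAX_INPUT_TOKENS.
# Hypothetical counts: system (100) + final user message (400) + the two most recent
# history messages (800 + 900) fit, but the oldest history message (1200) would
# overflow, so only that oldest message is dropped.
# (SystemMessage / HumanMessage / AIMessage are the langchain message classes used above.)
example_trimmed_prompt = _drop_messages_history_overflow(
    system_msg=SystemMessage(content="You are a helpful assistant."),
    system_token_count=100,
    history_msgs=[
        HumanMessage(content="old question"),
        AIMessage(content="old answer"),
        HumanMessage(content="recent question"),
    ],
    history_token_counts=[1200, 800, 900],
    final_msg=HumanMessage(content="latest question"),
    final_msg_token_count=400,
)
# Result: [system message, "old answer", "recent question", "latest question"]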
def llm_contextless_chat_answer(
messages: list[ChatMessage],
tokenizer: Callable | None = None,
system_text: str | None = None,
) -> Iterator[str]:
try:
prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
if system_text:
tokenizer = tokenizer or get_default_llm_tokenizer()
system_tokens = len(tokenizer(system_text))
system_msg = SystemMessage(content=system_text)
message_tokens = [msg.token_count for msg in messages] + [system_tokens]
else:
message_tokens = [msg.token_count for msg in messages]
last_msg_ind = _find_last_index(message_tokens)
remaining_user_msgs = prompt_msgs[last_msg_ind:]
if not remaining_user_msgs:
raise ValueError("Last user message is too long!")
if system_text:
all_msgs = [system_msg] + remaining_user_msgs
else:
all_msgs = remaining_user_msgs
return get_default_llm().stream(all_msgs)
except Exception as e:
logger.error(f"LLM failed to produce valid chat message, error: {e}")
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
def llm_contextual_chat_answer(
messages: list[ChatMessage],
persona: Persona,
user_id: UUID | None,
tokenizer: Callable,
) -> Iterator[str]:
retrieval_enabled = persona.retrieval_enabled
system_text = persona.system_text
@@ -136,93 +222,142 @@ def llm_contextual_chat_answer(
hint_text = persona.hint_text
last_message = messages[-1]
if not last_message.message:
raise ValueError("User chat message is empty.")
previous_messages = messages[:-1]
previous_msgs_as_basemessage = [
translate_danswer_msg_to_langchain(msg) for msg in previous_messages
]
user_text = form_user_prompt_text(
query=last_message.message,
tool_text=tool_text,
hint_text=hint_text,
)
# Failure reasons include:
# - Invalid LLM output, wrong format or wrong/missing keys
# - No "Final Answer" from model after tool calling
# - LLM times out or is otherwise unavailable
# - Calling invalid tool or tool call fails
# - Last message has more tokens than model is set to accept
# - Missing user input
try:
if not last_message.message:
raise ValueError("User chat message is empty.")
prompt: list[BaseMessage] = []
if system_text:
prompt.append(SystemMessage(content=system_text))
prompt.extend(
[translate_danswer_msg_to_langchain(msg) for msg in previous_messages]
)
prompt.append(HumanMessage(content=user_text))
llm = get_default_llm()
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
final_result: DanswerChatModelOut | None = None
final_answer_streamed = False
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
yield result.answer_piece
final_answer_streamed = True
if isinstance(result, DanswerChatModelOut):
final_result = result
break
if final_answer_streamed:
return
if final_result is None:
raise RuntimeError("Model output finished without final output parsing.")
if retrieval_enabled and final_result.action.lower() == DANSWER_TOOL_NAME.lower():
tool_result_str = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
user_id=user_id,
# Build the prompt using the last user message
user_text = form_user_prompt_text(
query=last_message.message,
tool_text=tool_text,
hint_text=hint_text,
)
else:
tool_result_str = call_tool(final_result, user_id=user_id)
last_user_msg = HumanMessage(content=user_text)
prompt.append(AIMessage(content=final_result.model_raw))
prompt.append(
HumanMessage(
content=form_tool_followup_text(
tool_output=tool_result_str,
query=last_message.message,
hint_text=hint_text,
# Count tokens once to reuse
previous_msg_token_counts = [msg.token_count for msg in previous_messages]
system_tokens = len(tokenizer(system_text)) if system_text else 0
last_user_msg_tokens = len(tokenizer(user_text))
prompt = _drop_messages_history_overflow(
system_msg=SystemMessage(content=system_text) if system_text else None,
system_token_count=system_tokens,
history_msgs=previous_msgs_as_basemessage,
history_token_counts=previous_msg_token_counts,
final_msg=last_user_msg,
final_msg_token_count=last_user_msg_tokens,
)
llm = get_default_llm()
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
final_result: DanswerChatModelOut | None = None
final_answer_streamed = False
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
yield result.answer_piece
final_answer_streamed = True
if isinstance(result, DanswerChatModelOut):
final_result = result
break
if final_answer_streamed:
return
if final_result is None:
raise RuntimeError("Model output finished without final output parsing.")
if (
retrieval_enabled
and final_result.action.lower() == DANSWER_TOOL_NAME.lower()
):
tool_result_str = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
user_id=user_id,
)
else:
tool_result_str = call_tool(final_result, user_id=user_id)
# The AI's tool calling message
tool_call_msg_text = final_result.model_raw
tool_call_msg_token_count = len(tokenizer(tool_call_msg_text))
# Create the new message to use the results of the tool call
tool_followup_text = form_tool_followup_text(
tool_output=tool_result_str,
query=last_message.message,
hint_text=hint_text,
)
)
tool_followup_msg = HumanMessage(content=tool_followup_text)
tool_followup_tokens = len(tokenizer(tool_followup_text))
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
# Drop messages if the prompt overflows. The drop order is: previous messages in the history first,
# then the last user prompt and the intermediate messages generated in this turn,
# then the system message, and finally the tool follow-up message that was generated last.
follow_up_prompt = _drop_messages_history_overflow(
system_msg=SystemMessage(content=system_text) if system_text else None,
system_token_count=system_tokens,
history_msgs=previous_msgs_as_basemessage
+ [last_user_msg, AIMessage(content=tool_call_msg_text)],
history_token_counts=previous_msg_token_counts
+ [last_user_msg_tokens, tool_call_msg_token_count],
final_msg=tool_followup_msg,
final_msg_token_count=tool_followup_tokens,
)
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
yield result.answer_piece
final_answer_streamed = True
# Good Debug/Breakpoint
tokens = llm.stream(follow_up_prompt)
if final_answer_streamed is False:
raise RuntimeError("LLM failed to produce a Final Answer")
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
yield result.answer_piece
final_answer_streamed = True
if final_answer_streamed is False:
raise RuntimeError("LLM did not to produce a Final Answer after tool call")
except Exception as e:
logger.error(f"LLM failed to produce valid chat message, error: {e}")
yield LLM_CHAT_FAILURE_MSG
def llm_chat_answer(
messages: list[ChatMessage], persona: Persona | None, user_id: UUID | None
messages: list[ChatMessage],
persona: Persona | None,
user_id: UUID | None,
tokenizer: Callable,
) -> Iterator[str]:
# TODO how to handle the model giving gibberish or failing on a particular message
# TODO how to handle the model failing to choose the right tool
# TODO how to handle the model giving the wrong format
# Common error cases to keep in mind:
# - User asks question about something long ago, due to context limit, the message is dropped
# - Tool use gives wrong/irrelevant results, model gets confused by the noise
# - Model is too weak of an LLM, fails to follow instructions
# - Bad persona design leads to confusing instructions to the model
# - Bad configurations, too small token limit, mismatched tokenizer to LLM, etc.
if persona is None:
return llm_contextless_chat_answer(messages)
elif persona.retrieval_enabled is False and persona.tools_text is None:
return llm_contextless_chat_answer(
messages, tokenizer, system_text=persona.system_text
)
return llm_contextual_chat_answer(
messages=messages, persona=persona, user_id=user_id
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
)
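For illustration (not part of this commit), a rough sketch of how a caller consumes the streamed answer, mirroring the chat backend handlers further down; chat_history, persona, and user_id are assumed to come from the request context:

tokens = llm_chat_answer(
    messages=chat_history,  # list[ChatMessage] loaded from the DB
    persona=persona,        # may be None, which selects the contextless path
    user_id=user_id,
    tokenizer=get_default_llm_tokenizer(),
)
answer = ""
for token in tokens:
    answer += token  # stream each piece to the client as it arrives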

View File

@@ -81,6 +81,7 @@ USER'S INPUT
--------------------
Okay, so what is the response to my last comment? If using information obtained from the tools you must \
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
If the tool response is not useful, ignore it completely.
{optional_reminder}{hint}
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
"""
@@ -100,7 +101,7 @@ def form_user_prompt_text(
if hint_text:
if user_prompt[-1] != "\n":
user_prompt += "\n"
user_prompt += "Hint: " + hint_text
user_prompt += "\nHint: " + hint_text
return user_prompt.strip()
@@ -145,12 +146,12 @@ def form_tool_followup_text(
) -> str:
# If the query is multi-line, it likely confuses the model more than it helps
if "\n" not in query:
optional_reminder = f"As a reminder, my query was: {query}\n"
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
else:
optional_reminder = ""
if not ignore_hint and hint_text:
hint_text_spaced = f"{hint_text}\n"
hint_text_spaced = f"\nHint: {hint_text}\n"
else:
hint_text_spaced = ""

View File

@@ -10,5 +10,4 @@ personas:
# Each added tool needs to have a "name" and "description"
tools: []
# Short tip to pass near the end of the prompt to emphasize some requirement
# Such as "Remember to be informative!"
hint: ""
hint: "Try to be as informative as possible!"

View File

@@ -76,8 +76,11 @@ GEN_AI_MODEL_VERSION = os.environ.get(
GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
# Set this to be enough for an answer + quotes
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "1024"))
# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
# This next restriction is currently only used for chat, to expire old messages as needed
GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)
# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
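For illustration (not part of this commit), the switch from a get() default to "or" matters when a variable is set but empty, which would otherwise crash the int()/float() casts:

import os

os.environ["GEN_AI_MAX_INPUT_TOKENS"] = ""  # hypothetical: set but empty
# int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS", "3000"))  # would raise ValueError
limit = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)  # falls back to 3000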

View File

@@ -186,6 +186,7 @@ def create_new_chat_message(
chat_session_id: int,
message_number: int,
message: str,
token_count: int,
parent_edit_number: int | None,
message_type: MessageType,
db_session: Session,
@@ -211,6 +212,7 @@ def create_new_chat_message(
parent_edit_number=parent_edit_number,
edit_number=new_edit_number,
message=message,
token_count=token_count,
message_type=message_type,
)

View File

@@ -379,6 +379,7 @@ class ChatMessage(Base):
) # null if first message
latest: Mapped[bool] = mapped_column(Boolean, default=True)
message: Mapped[str] = mapped_column(Text)
token_count: Mapped[int] = mapped_column(Integer)
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True

View File

@@ -7,6 +7,7 @@ from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -45,5 +46,6 @@ def get_default_llm(
endpoint=GEN_AI_ENDPOINT,
model_host_type=GEN_AI_HOST_TYPE,
max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
temperature=GEN_AI_TEMPERATURE,
**kwargs,
)

View File

@@ -3,6 +3,7 @@ from typing import Any
from langchain.chat_models.openai import ChatOpenAI
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose
@@ -14,6 +15,7 @@ class OpenAIGPT(LangChainChatLLM):
max_output_tokens: int,
timeout: int,
model_version: str,
temperature: float = GEN_AI_TEMPERATURE,
*args: list[Any],
**kwargs: dict[str, Any]
):
@@ -26,10 +28,9 @@ class OpenAIGPT(LangChainChatLLM):
model=model_version,
openai_api_key=api_key,
max_tokens=max_output_tokens,
temperature=0,
temperature=temperature,
request_timeout=timeout,
model_kwargs={
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
},

View File

@@ -24,6 +24,7 @@ from danswer.db.engine import get_session
from danswer.db.models import ChatMessage
from danswer.db.models import User
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.secondary_llm_flows.chat_helpers import get_new_chat_name
from danswer.server.models import ChatMessageDetail
from danswer.server.models import ChatMessageIdentifier
@@ -211,6 +212,8 @@ def handle_new_chat_message(
parent_edit_number = chat_message.parent_edit_number
user_id = user.id if user is not None else None
llm_tokenizer = get_default_llm_tokenizer()
chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
persona = (
fetch_persona_by_id(chat_message.persona_id, db_session)
@@ -250,6 +253,7 @@ def handle_new_chat_message(
message_number=message_number,
parent_edit_number=parent_edit_number,
message=message_content,
token_count=len(llm_tokenizer(message_content)),
message_type=MessageType.USER,
db_session=db_session,
)
@@ -268,7 +272,10 @@ def handle_new_chat_message(
@log_generator_function_time()
def stream_chat_tokens() -> Iterator[str]:
tokens = llm_chat_answer(
messages=mainline_messages, persona=persona, user_id=user_id
messages=mainline_messages,
persona=persona,
user_id=user_id,
tokenizer=llm_tokenizer,
)
llm_output = ""
for token in tokens:
@@ -280,6 +287,7 @@ def handle_new_chat_message(
message_number=message_number + 1,
parent_edit_number=new_message.edit_number,
message=llm_output,
token_count=len(llm_tokenizer(llm_output)),
message_type=MessageType.ASSISTANT,
db_session=db_session,
)
@@ -301,6 +309,8 @@ def regenerate_message_given_parent(
edit_number = parent_message.edit_number
user_id = user.id if user is not None else None
llm_tokenizer = get_default_llm_tokenizer()
chat_message = fetch_chat_message(
chat_session_id=chat_session_id,
message_number=message_number,
@@ -344,7 +354,10 @@ def regenerate_message_given_parent(
@log_generator_function_time()
def stream_regenerate_tokens() -> Iterator[str]:
tokens = llm_chat_answer(
messages=mainline_messages, persona=persona, user_id=user_id
messages=mainline_messages,
persona=persona,
user_id=user_id,
tokenizer=llm_tokenizer,
)
llm_output = ""
for token in tokens:
@@ -356,6 +369,7 @@ def regenerate_message_given_parent(
message_number=message_number + 1,
parent_edit_number=edit_number,
message=llm_output,
token_count=len(llm_tokenizer(llm_output)),
message_type=MessageType.ASSISTANT,
db_session=db_session,
)
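As a closing illustration (not part of this commit), the persisted token_count values are just the length of the tokenized message, assuming the default tokenizer returns a sequence of token ids as it is used above:

llm_tokenizer = get_default_llm_tokenizer()
message_content = "What changed in the chat backend?"  # hypothetical message
token_count = len(llm_tokenizer(message_content))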