Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-09-19 12:03:54 +02:00
Chat Backend API edge cases handled (#472)
backend/alembic/versions/767f1c2a00eb_count_chat_tokens.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+"""Count Chat Tokens
+
+Revision ID: 767f1c2a00eb
+Revises: dba7f71618f5
+Create Date: 2023-09-21 10:03:21.509899
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = "767f1c2a00eb"
+down_revision = "dba7f71618f5"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "chat_message", sa.Column("token_count", sa.Integer(), nullable=False)
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("chat_message", "token_count")
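The revision adds a non-nullable token_count column to chat_message. A minimal sketch of applying it programmatically instead of via the alembic CLI (assumes the backend's alembic.ini is in the working directory; not part of the commit):

# Editor's illustration, not part of the commit: apply this revision from Python.
# Assumes an alembic.ini next to the backend's alembic/ directory.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "767f1c2a00eb")  # equivalent to running `alembic upgrade head` here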
@@ -1,3 +1,4 @@
+from collections.abc import Callable
from collections.abc import Iterator
from uuid import UUID
@@ -14,6 +15,7 @@ from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.chat.tools import call_tool
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.constants import IGNORE_FOR_QA
+from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
from danswer.datastores.document_index import get_default_document_index
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
@@ -22,6 +24,7 @@ from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.llm.build import get_default_llm
from danswer.llm.llm import LLM
+from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.utils.logger import setup_logger
@@ -31,6 +34,9 @@ from danswer.utils.text_processing import has_unescaped_quote
logger = setup_logger()


+LLM_CHAT_FAILURE_MSG = "The large-language-model failed to generate a valid response."
+
+
def _parse_embedded_json_streamed_response(
    tokens: Iterator[str],
) -> Iterator[DanswerAnswerPiece | DanswerChatModelOut]:
@@ -81,6 +87,24 @@ def _parse_embedded_json_streamed_response(
    return


+def _find_last_index(
+    lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
+) -> int:
+    """From the back, find the index of the last element to include
+    before the list exceeds the maximum"""
+    running_sum = 0
+
+    last_ind = 0
+    for i in range(len(lst) - 1, -1, -1):
+        running_sum += lst[i]
+        if running_sum > max_prompt_tokens:
+            last_ind = i + 1
+            break
+    if last_ind >= len(lst):
+        raise ValueError("Last message alone is too large!")
+    return last_ind
+
+
def danswer_chat_retrieval(
    query_message: ChatMessage,
    history: list[ChatMessage],
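The new _find_last_index helper walks per-message token counts from the newest message backwards and returns the index of the oldest message that still fits under max_prompt_tokens. A quick illustration with made-up numbers (not part of the commit):

# Editor's illustration, not part of the commit.
# Five messages, oldest first, with an 8-token budget. Summing from the back:
# 3 + 4 = 7 still fits, adding the next older message (5) would exceed 8,
# so the returned index is 3 and only lst[3:] would be kept.
token_counts = [2, 6, 5, 4, 3]
assert _find_last_index(token_counts, max_prompt_tokens=8) == 3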
@@ -119,16 +143,78 @@ def danswer_chat_retrieval(
    return format_danswer_chunks_for_chat(usable_chunks)


-def llm_contextless_chat_answer(messages: list[ChatMessage]) -> Iterator[str]:
-    prompt = [translate_danswer_msg_to_langchain(msg) for msg in messages]
-    return get_default_llm().stream(prompt)
+def _drop_messages_history_overflow(
+    system_msg: BaseMessage | None,
+    system_token_count: int,
+    history_msgs: list[BaseMessage],
+    history_token_counts: list[int],
+    final_msg: BaseMessage,
+    final_msg_token_count: int,
+) -> list[BaseMessage]:
+    """As message history grows, messages need to be dropped starting from the furthest in the past.
+    The System message should be kept if at all possible and the latest user input which is inserted in the
+    prompt template must be included"""
+    if len(history_msgs) != len(history_token_counts):
+        # This should never happen
+        raise ValueError("Need exactly 1 token count per message for tracking overflow")
+
+    prompt: list[BaseMessage] = []
+
+    # Start dropping from the history if necessary
+    all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
+    ind_prev_msg_start = _find_last_index(all_tokens)
+
+    if system_msg and ind_prev_msg_start <= len(history_msgs):
+        prompt.append(system_msg)
+
+    prompt.extend(history_msgs[ind_prev_msg_start:])
+
+    prompt.append(final_msg)
+
+    return prompt
+
+
+def llm_contextless_chat_answer(
+    messages: list[ChatMessage],
+    tokenizer: Callable | None = None,
+    system_text: str | None = None,
+) -> Iterator[str]:
+    try:
+        prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
+
+        if system_text:
+            tokenizer = tokenizer or get_default_llm_tokenizer()
+            system_tokens = len(tokenizer(system_text))
+            system_msg = SystemMessage(content=system_text)
+
+            message_tokens = [msg.token_count for msg in messages] + [system_tokens]
+        else:
+            message_tokens = [msg.token_count for msg in messages]
+
+        last_msg_ind = _find_last_index(message_tokens)
+
+        remaining_user_msgs = prompt_msgs[last_msg_ind:]
+        if not remaining_user_msgs:
+            raise ValueError("Last user message is too long!")
+
+        if system_text:
+            all_msgs = [system_msg] + remaining_user_msgs
+        else:
+            all_msgs = remaining_user_msgs
+
+        return get_default_llm().stream(all_msgs)
+
+    except Exception as e:
+        logger.error(f"LLM failed to produce valid chat message, error: {e}")
+        return (msg for msg in [LLM_CHAT_FAILURE_MSG])  # needs to be an Iterator


def llm_contextual_chat_answer(
    messages: list[ChatMessage],
    persona: Persona,
    user_id: UUID | None,
+    tokenizer: Callable,
) -> Iterator[str]:
    retrieval_enabled = persona.retrieval_enabled
    system_text = persona.system_text
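_drop_messages_history_overflow assembles the final prompt as: optional system message, as much recent history as fits, then the mandatory final message, with the budget coming from GEN_AI_MAX_INPUT_TOKENS via _find_last_index. A rough usage sketch with invented messages and counts (not part of the commit; assumes the langchain message classes this module already relies on):

# Editor's illustration, not part of the commit.
from langchain.schema import HumanMessage, SystemMessage

history = [
    HumanMessage(content="hi"),
    HumanMessage(content="tell me about X"),
]
prompt = _drop_messages_history_overflow(
    system_msg=SystemMessage(content="Be helpful."),
    system_token_count=4,
    history_msgs=history,
    history_token_counts=[2, 6],
    final_msg=HumanMessage(content="Summarize the above."),
    final_msg_token_count=5,
)
# With the default GEN_AI_MAX_INPUT_TOKENS nothing is dropped here; once the
# history outgrows the budget, the oldest entries are the first to go.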
@@ -136,93 +222,142 @@ def llm_contextual_chat_answer(
    hint_text = persona.hint_text

    last_message = messages[-1]

-    if not last_message.message:
-        raise ValueError("User chat message is empty.")

    previous_messages = messages[:-1]
+    previous_msgs_as_basemessage = [
+        translate_danswer_msg_to_langchain(msg) for msg in previous_messages
+    ]

-    user_text = form_user_prompt_text(
-        query=last_message.message,
-        tool_text=tool_text,
-        hint_text=hint_text,
-    )
-
-    prompt: list[BaseMessage] = []
-
-    if system_text:
-        prompt.append(SystemMessage(content=system_text))
-
-    prompt.extend(
-        [translate_danswer_msg_to_langchain(msg) for msg in previous_messages]
-    )
-
-    prompt.append(HumanMessage(content=user_text))
-
-    llm = get_default_llm()
-
-    # Good Debug/Breakpoint
-    tokens = llm.stream(prompt)
-
-    final_result: DanswerChatModelOut | None = None
-    final_answer_streamed = False
-    for result in _parse_embedded_json_streamed_response(tokens):
-        if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
-            yield result.answer_piece
-            final_answer_streamed = True
-
-        if isinstance(result, DanswerChatModelOut):
-            final_result = result
-            break
-
-    if final_answer_streamed:
-        return
-
-    if final_result is None:
-        raise RuntimeError("Model output finished without final output parsing.")
-
-    if retrieval_enabled and final_result.action.lower() == DANSWER_TOOL_NAME.lower():
-        tool_result_str = danswer_chat_retrieval(
-            query_message=last_message,
-            history=previous_messages,
-            llm=llm,
-            user_id=user_id,
-        )
-    else:
-        tool_result_str = call_tool(final_result, user_id=user_id)
-
-    prompt.append(AIMessage(content=final_result.model_raw))
-    prompt.append(
-        HumanMessage(
-            content=form_tool_followup_text(
-                tool_output=tool_result_str,
-                query=last_message.message,
-                hint_text=hint_text,
-            )
-        )
-    )
-
-    # Good Debug/Breakpoint
-    tokens = llm.stream(prompt)
-
-    for result in _parse_embedded_json_streamed_response(tokens):
-        if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
-            yield result.answer_piece
-            final_answer_streamed = True
-
-    if final_answer_streamed is False:
-        raise RuntimeError("LLM failed to produce a Final Answer")
+    # Failure reasons include:
+    # - Invalid LLM output, wrong format or wrong/missing keys
+    # - No "Final Answer" from model after tool calling
+    # - LLM times out or is otherwise unavailable
+    # - Calling invalid tool or tool call fails
+    # - Last message has more tokens than model is set to accept
+    # - Missing user input
+    try:
+        if not last_message.message:
+            raise ValueError("User chat message is empty.")
+
+        # Build the prompt using the last user message
+        user_text = form_user_prompt_text(
+            query=last_message.message,
+            tool_text=tool_text,
+            hint_text=hint_text,
+        )
+        last_user_msg = HumanMessage(content=user_text)
+
+        # Count tokens once to reuse
+        previous_msg_token_counts = [msg.token_count for msg in previous_messages]
+        system_tokens = len(tokenizer(system_text)) if system_text else 0
+        last_user_msg_tokens = len(tokenizer(user_text))
+
+        prompt = _drop_messages_history_overflow(
+            system_msg=SystemMessage(content=system_text) if system_text else None,
+            system_token_count=system_tokens,
+            history_msgs=previous_msgs_as_basemessage,
+            history_token_counts=previous_msg_token_counts,
+            final_msg=last_user_msg,
+            final_msg_token_count=last_user_msg_tokens,
+        )
+
+        llm = get_default_llm()
+
+        # Good Debug/Breakpoint
+        tokens = llm.stream(prompt)
+
+        final_result: DanswerChatModelOut | None = None
+        final_answer_streamed = False
+
+        for result in _parse_embedded_json_streamed_response(tokens):
+            if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
+                yield result.answer_piece
+                final_answer_streamed = True
+
+            if isinstance(result, DanswerChatModelOut):
+                final_result = result
+                break
+
+        if final_answer_streamed:
+            return
+
+        if final_result is None:
+            raise RuntimeError("Model output finished without final output parsing.")
+
+        if (
+            retrieval_enabled
+            and final_result.action.lower() == DANSWER_TOOL_NAME.lower()
+        ):
+            tool_result_str = danswer_chat_retrieval(
+                query_message=last_message,
+                history=previous_messages,
+                llm=llm,
+                user_id=user_id,
+            )
+        else:
+            tool_result_str = call_tool(final_result, user_id=user_id)
+
+        # The AI's tool calling message
+        tool_call_msg_text = final_result.model_raw
+        tool_call_msg_token_count = len(tokenizer(tool_call_msg_text))
+
+        # Create the new message to use the results of the tool call
+        tool_followup_text = form_tool_followup_text(
+            tool_output=tool_result_str,
+            query=last_message.message,
+            hint_text=hint_text,
+        )
+        tool_followup_msg = HumanMessage(content=tool_followup_text)
+        tool_followup_tokens = len(tokenizer(tool_followup_text))
+
+        # Drop previous messages, the drop order goes: previous messages in the history,
+        # the last user prompt and generated intermediate messages from this recent prompt,
+        # the system message, then finally the tool message that was the last thing generated
+        follow_up_prompt = _drop_messages_history_overflow(
+            system_msg=SystemMessage(content=system_text) if system_text else None,
+            system_token_count=system_tokens,
+            history_msgs=previous_msgs_as_basemessage
+            + [last_user_msg, AIMessage(content=tool_call_msg_text)],
+            history_token_counts=previous_msg_token_counts
+            + [last_user_msg_tokens, tool_call_msg_token_count],
+            final_msg=tool_followup_msg,
+            final_msg_token_count=tool_followup_tokens,
+        )
+
+        # Good Debug/Breakpoint
+        tokens = llm.stream(follow_up_prompt)
+
+        for result in _parse_embedded_json_streamed_response(tokens):
+            if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
+                yield result.answer_piece
+                final_answer_streamed = True
+
+        if final_answer_streamed is False:
+            raise RuntimeError("LLM did not to produce a Final Answer after tool call")
+    except Exception as e:
+        logger.error(f"LLM failed to produce valid chat message, error: {e}")
+        yield LLM_CHAT_FAILURE_MSG


def llm_chat_answer(
-    messages: list[ChatMessage], persona: Persona | None, user_id: UUID | None
+    messages: list[ChatMessage],
+    persona: Persona | None,
+    user_id: UUID | None,
+    tokenizer: Callable,
) -> Iterator[str]:
-    # TODO how to handle model giving jibberish or fail on a particular message
-    # TODO how to handle model failing to choose the right tool
-    # TODO how to handle model gives wrong format
+    # Common error cases to keep in mind:
+    # - User asks question about something long ago, due to context limit, the message is dropped
+    # - Tool use gives wrong/irrelevant results, model gets confused by the noise
+    # - Model is too weak of an LLM, fails to follow instructions
+    # - Bad persona design leads to confusing instructions to the model
+    # - Bad configurations, too small token limit, mismatched tokenizer to LLM, etc.
    if persona is None:
        return llm_contextless_chat_answer(messages)

+    elif persona.retrieval_enabled is False and persona.tools_text is None:
+        return llm_contextless_chat_answer(
+            messages, tokenizer, system_text=persona.system_text
+        )
+
    return llm_contextual_chat_answer(
-        messages=messages, persona=persona, user_id=user_id
+        messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
    )
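Throughout these functions the tokenizer argument is only ever used as len(tokenizer(text)), so any callable returning a sequence of token ids will do. A sketch of a compatible callable (not part of the commit; assumes the tiktoken package, whereas the real default comes from get_default_llm_tokenizer):

# Editor's illustration, not part of the commit.
import tiktoken

_encoder = tiktoken.get_encoding("cl100k_base")

def tokenizer(text: str) -> list[int]:
    # Returns token ids so that len(tokenizer(text)) is the token count.
    return _encoder.encode(text)

assert len(tokenizer("hello world")) >= 2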
@@ -81,6 +81,7 @@ USER'S INPUT
--------------------
Okay, so what is the response to my last comment? If using information obtained from the tools you must \
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
+If the tool response is not useful, ignore it completely.
{optional_reminder}{hint}
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
"""
@@ -100,7 +101,7 @@ def form_user_prompt_text(
    if hint_text:
        if user_prompt[-1] != "\n":
            user_prompt += "\n"
-        user_prompt += "Hint: " + hint_text
+        user_prompt += "\nHint: " + hint_text

    return user_prompt.strip()
@@ -145,12 +146,12 @@ def form_tool_followup_text(
) -> str:
    # If multi-line query, it likely confuses the model more than helps
    if "\n" not in query:
-        optional_reminder = f"As a reminder, my query was: {query}\n"
+        optional_reminder = f"\nAs a reminder, my query was: {query}\n"
    else:
        optional_reminder = ""

    if not ignore_hint and hint_text:
-        hint_text_spaced = f"{hint_text}\n"
+        hint_text_spaced = f"\nHint: {hint_text}\n"
    else:
        hint_text_spaced = ""
@@ -10,5 +10,4 @@ personas:
    # Each added tool needs to have a "name" and "description"
    tools: []
    # Short tip to pass near the end of the prompt to emphasize some requirement
-    # Such as "Remember to be informative!"
-    hint: ""
+    hint: "Try to be as informative as possible!"
@@ -76,8 +76,11 @@ GEN_AI_MODEL_VERSION = os.environ.get(
GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)

-# Set this to be enough for an answer + quotes
-GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "1024"))
+# Set this to be enough for an answer + quotes. Also used for Chat
+GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
+# This next restriction is only used for chat ATM, used to expire old messages as needed
+GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
+GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)

# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
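Note the switched fallback style: os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024 also treats an exported-but-empty variable as unset, so int() never sees an empty string. A small demonstration (variable name invented for the demo; not part of the commit):

# Editor's illustration, not part of the commit.
import os

os.environ["DEMO_MAX_TOKENS"] = ""  # exported but left blank, e.g. by a compose file

try:
    int(os.environ.get("DEMO_MAX_TOKENS", "1024"))  # old pattern: int("") raises
except ValueError:
    pass

assert int(os.environ.get("DEMO_MAX_TOKENS") or 1024) == 1024  # new pattern falls back cleanly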
@@ -186,6 +186,7 @@ def create_new_chat_message(
    chat_session_id: int,
    message_number: int,
    message: str,
+    token_count: int,
    parent_edit_number: int | None,
    message_type: MessageType,
    db_session: Session,
@@ -211,6 +212,7 @@ def create_new_chat_message(
        parent_edit_number=parent_edit_number,
        edit_number=new_edit_number,
        message=message,
+        token_count=token_count,
        message_type=message_type,
    )
@@ -379,6 +379,7 @@ class ChatMessage(Base):
    )  # null if first message
    latest: Mapped[bool] = mapped_column(Boolean, default=True)
    message: Mapped[str] = mapped_column(Text)
+    token_count: Mapped[int] = mapped_column(Integer)
    message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
    persona_id: Mapped[int | None] = mapped_column(
        ForeignKey("persona.id"), nullable=True
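With token_count persisted on ChatMessage, per-session totals can be computed in SQL instead of re-tokenizing. A hypothetical aggregate query (not part of the commit; assumes a ChatMessage.chat_session_id column, as suggested by the chat DB helpers above):

# Editor's illustration, not part of the commit.
from sqlalchemy import func, select

total_tokens = db_session.scalar(
    select(func.coalesce(func.sum(ChatMessage.token_count), 0)).where(
        ChatMessage.chat_session_id == chat_session_id
    )
)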
@@ -7,6 +7,7 @@ from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -45,5 +46,6 @@ def get_default_llm(
        endpoint=GEN_AI_ENDPOINT,
        model_host_type=GEN_AI_HOST_TYPE,
        max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
+        temperature=GEN_AI_TEMPERATURE,
        **kwargs,
    )
@@ -3,6 +3,7 @@ from typing import Any
from langchain.chat_models.openai import ChatOpenAI

+from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose
@@ -14,6 +15,7 @@ class OpenAIGPT(LangChainChatLLM):
        max_output_tokens: int,
        timeout: int,
        model_version: str,
+        temperature: float = GEN_AI_TEMPERATURE,
        *args: list[Any],
        **kwargs: dict[str, Any]
    ):
@@ -26,10 +28,9 @@ class OpenAIGPT(LangChainChatLLM):
            model=model_version,
            openai_api_key=api_key,
            max_tokens=max_output_tokens,
-            temperature=0,
+            temperature=temperature,
            request_timeout=timeout,
            model_kwargs={
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
            },
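Since get_default_llm forwards **kwargs into the model constructor and OpenAIGPT now accepts temperature (defaulting to GEN_AI_TEMPERATURE), a per-call override should also be possible. Hypothetical usage, not part of the commit:

# Editor's illustration, not part of the commit.
llm = get_default_llm(temperature=0.7)  # falls back to GEN_AI_TEMPERATURE when omitted
tokens = llm.stream(prompt)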
@@ -24,6 +24,7 @@ from danswer.db.engine import get_session
from danswer.db.models import ChatMessage
from danswer.db.models import User
from danswer.direct_qa.interfaces import DanswerAnswerPiece
+from danswer.llm.utils import get_default_llm_tokenizer
from danswer.secondary_llm_flows.chat_helpers import get_new_chat_name
from danswer.server.models import ChatMessageDetail
from danswer.server.models import ChatMessageIdentifier
@@ -211,6 +212,8 @@ def handle_new_chat_message(
    parent_edit_number = chat_message.parent_edit_number
    user_id = user.id if user is not None else None

+    llm_tokenizer = get_default_llm_tokenizer()
+
    chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
    persona = (
        fetch_persona_by_id(chat_message.persona_id, db_session)
@@ -250,6 +253,7 @@ def handle_new_chat_message(
        message_number=message_number,
        parent_edit_number=parent_edit_number,
        message=message_content,
+        token_count=len(llm_tokenizer(message_content)),
        message_type=MessageType.USER,
        db_session=db_session,
    )
@@ -268,7 +272,10 @@ def handle_new_chat_message(
    @log_generator_function_time()
    def stream_chat_tokens() -> Iterator[str]:
        tokens = llm_chat_answer(
-            messages=mainline_messages, persona=persona, user_id=user_id
+            messages=mainline_messages,
+            persona=persona,
+            user_id=user_id,
+            tokenizer=llm_tokenizer,
        )
        llm_output = ""
        for token in tokens:
@@ -280,6 +287,7 @@ def handle_new_chat_message(
            message_number=message_number + 1,
            parent_edit_number=new_message.edit_number,
            message=llm_output,
+            token_count=len(llm_tokenizer(llm_output)),
            message_type=MessageType.ASSISTANT,
            db_session=db_session,
        )
@@ -301,6 +309,8 @@ def regenerate_message_given_parent(
    edit_number = parent_message.edit_number
    user_id = user.id if user is not None else None

+    llm_tokenizer = get_default_llm_tokenizer()
+
    chat_message = fetch_chat_message(
        chat_session_id=chat_session_id,
        message_number=message_number,
@@ -344,7 +354,10 @@ def regenerate_message_given_parent(
    @log_generator_function_time()
    def stream_regenerate_tokens() -> Iterator[str]:
        tokens = llm_chat_answer(
-            messages=mainline_messages, persona=persona, user_id=user_id
+            messages=mainline_messages,
+            persona=persona,
+            user_id=user_id,
+            tokenizer=llm_tokenizer,
        )
        llm_output = ""
        for token in tokens:
@@ -356,6 +369,7 @@ def regenerate_message_given_parent(
            message_number=message_number + 1,
            parent_edit_number=edit_number,
            message=llm_output,
+            token_count=len(llm_tokenizer(llm_output)),
            message_type=MessageType.ASSISTANT,
            db_session=db_session,
        )