Add stop generating functionality (#2100)

* functional types + sidebar

* remove commits

* remove logs

* functional rework of temporary user/assistant ID

* robustify switching

* remove logs

* typing

* robustify frontend handling

* cleaner loop + data persistence

* migrate to streaming response

* formatting

* add new loading state to prevent collisions

* add `ChatState` for more robust handling

* remove logs

* robustify typing

* unnecessary list removed

* robustify

* remove log

* remove false comment

* slightly more robust chat state

* update utility + copy

* improve clarity + new SSE handling utility function

* remove comments

* clearer

* add back stack trace detail

* cleaner messages

* clean final message handling

* tiny formatting (remove newline)

* add synchronous wrapper to avoid hampering main event loop

* update typing

* include logs

* slightly more specific logs

* add `critical` error just in case
This commit is contained in:
pablodanswer 2024-08-18 15:15:55 -07:00 committed by GitHub
parent 8a7bc4e411
commit 12fccfeffd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 547 additions and 234 deletions

View File

@ -39,6 +39,7 @@ def create_chat_chain(
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
all_chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session_id,
user_id=None,

View File

@ -76,6 +76,11 @@ class CitationInfo(BaseModel):
document_id: str
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None

View File

@ -11,6 +11,7 @@ from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
@ -27,6 +28,7 @@ from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_db_search_doc_by_id
from danswer.db.chat import get_doc_query_identifiers_from_model
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
@ -241,6 +243,7 @@ ChatPacket = (
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| MessageResponseIDInfo
)
ChatPacketStream = Iterator[ChatPacket]
@ -256,9 +259,9 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
# user message (e.g. this can only be used for the chat-seeding flow).
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@ -449,7 +452,18 @@ def stream_chat_message_objects(
),
max_window_percentage=max_document_percentage,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
reserved_assistant_message_id=reserved_message_id,
)
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
create_new_chat_message,
@ -582,6 +596,7 @@ def stream_chat_message_objects(
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=AnswerStyleConfig(
@ -615,6 +630,7 @@ def stream_chat_message_objects(
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
@ -690,6 +706,7 @@ def stream_chat_message_objects(
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
error_msg = str(e)
logger.exception(f"Failed to process chat message: {error_msg}")
@ -717,6 +734,7 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@ -737,6 +755,8 @@ def stream_chat_message_objects(
if tool_result
else [],
)
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
@ -745,7 +765,8 @@ def stream_chat_message_objects(
yield msg_detail_response
except Exception as e:
logger.exception(e)
error_msg = str(e)
logger.exception(error_msg)
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
@ -757,6 +778,7 @@ def stream_chat_message(
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
@ -765,6 +787,7 @@ def stream_chat_message(
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.dict())

View File

@ -393,6 +393,34 @@ def get_or_create_root_message(
return new_root_message
def reserve_message_id(
    db_session: Session,
    chat_session_id: int,
    parent_message: int,
    message_type: MessageType,
) -> int:
    """Pre-allocate a database ID for a message that has not been generated yet.

    Inserts a placeholder ChatMessage (empty text, zero token count) and
    flushes the session — without committing — so the database assigns a
    primary key. The caller can send this ID to the frontend immediately and
    fill in the real content once generation finishes.
    """
    placeholder = ChatMessage(
        chat_session_id=chat_session_id,
        parent_message=parent_message,
        latest_child_message=None,
        message="",
        token_count=0,
        message_type=message_type,
    )
    db_session.add(placeholder)
    # Flush so the INSERT runs and the autogenerated PK is populated
    db_session.flush()
    return placeholder.id
def create_new_chat_message(
chat_session_id: int,
parent_message: ChatMessage,
@ -410,29 +438,51 @@ def create_new_chat_message(
citations: dict[int, int] | None = None,
tool_calls: list[ToolCall] | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
) -> ChatMessage:
new_chat_message = ChatMessage(
chat_session_id=chat_session_id,
parent_message=parent_message.id,
latest_child_message=None,
message=message,
rephrased_query=rephrased_query,
prompt_id=prompt_id,
token_count=token_count,
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
)
if reserved_message_id is not None:
# Edit existing message
existing_message = db_session.query(ChatMessage).get(reserved_message_id)
if existing_message is None:
raise ValueError(f"No message found with id {reserved_message_id}")
existing_message.chat_session_id = chat_session_id
existing_message.parent_message = parent_message.id
existing_message.message = message
existing_message.rephrased_query = rephrased_query
existing_message.prompt_id = prompt_id
existing_message.token_count = token_count
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
new_chat_message = existing_message
else:
# Create new message
new_chat_message = ChatMessage(
chat_session_id=chat_session_id,
parent_message=parent_message.id,
latest_child_message=None,
message=message,
rephrased_query=rephrased_query,
prompt_id=prompt_id,
token_count=token_count,
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
)
db_session.add(new_chat_message)
# SQL Alchemy will propagate this to update the reference_docs' foreign keys
if reference_docs:
new_chat_message.search_docs = reference_docs
db_session.add(new_chat_message)
# Flush the session to get an ID for the new chat message
db_session.flush()

View File

@ -1,3 +1,4 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from uuid import uuid4
@ -115,6 +116,7 @@ class Answer:
# Returns the full document sections text from the search tool
return_contexts: bool = False,
skip_gen_ai_answer_generation: bool = False,
is_connected: Callable[[], bool] | None = None,
) -> None:
if single_message_history and message_history:
raise ValueError(
@ -122,6 +124,7 @@ class Answer:
)
self.question = question
self.is_connected: Callable[[], bool] | None = is_connected
self.latest_query_files = latest_query_files or []
self.file_id_to_file = {file.file_id: file for file in (files or [])}
@ -153,6 +156,7 @@ class Answer:
self._return_contexts = return_contexts
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
@ -235,6 +239,8 @@ class Answer:
tool_call_chunk += message # type: ignore
else:
if message.content:
if self.is_cancelled:
return
yield cast(str, message.content)
if not tool_call_chunk:
@ -292,12 +298,15 @@ class Answer:
yield tool_runner.tool_final_result()
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
yield from message_generator_to_string_generator(
for token in message_generator_to_string_generator(
self.llm.stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
)
):
if self.is_cancelled:
return
yield token
return
@ -378,9 +387,13 @@ class Answer:
)
)
prompt = prompt_builder.build()
yield from message_generator_to_string_generator(
for token in message_generator_to_string_generator(
self.llm.stream(prompt=prompt)
)
):
if self.is_cancelled:
return
yield token
return
tool, tool_args = chosen_tool_and_args
@ -434,7 +447,12 @@ class Answer:
yield final
prompt = prompt_builder.build()
yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt))
for token in message_generator_to_string_generator(
self.llm.stream(prompt=prompt)
):
if self.is_cancelled:
return
yield token
@property
def processed_streamed_output(self) -> AnswerStream:
@ -537,3 +555,15 @@ class Answer:
citations.append(packet)
return citations
@property
def is_cancelled(self) -> bool:
    """Whether the client has disconnected and streaming should stop.

    Cancellation is latched: once observed, a later reconnect does not
    un-cancel an in-flight answer.
    """
    if self._is_cancelled:
        return True

    if self.is_connected is not None:
        # Invoke the connectivity check exactly once per access. The original
        # called it twice, so the debug log and the stored flag could disagree
        # if the connection state changed between the two calls.
        connected = self.is_connected()
        if not connected:
            logger.debug("Answer stream has been cancelled")
        self._is_cancelled = not connected

    return self._is_cancelled

View File

@ -1,5 +1,8 @@
import asyncio
import io
import uuid
from collections.abc import Callable
from collections.abc import Generator
from fastapi import APIRouter
from fastapi import Depends
@ -207,8 +210,6 @@ def rename_chat_session(
chat_session_id = rename_req.chat_session_id
user_id = user.id if user is not None else None
logger.info(f"Received rename request for chat session: {chat_session_id}")
if name:
update_chat_session(
db_session=db_session,
@ -271,19 +272,39 @@ def delete_chat_session_by_id(
delete_chat_session(user_id, session_id, db_session)
async def is_disconnected(request: Request) -> Callable[[], bool]:
    """Return a synchronous callable reporting whether the client has
    disconnected from `request`.

    The returned function is safe to call from a worker thread: it schedules
    `request.is_disconnected()` on the main event loop and waits briefly for
    the result, so the check never blocks the event loop itself. Any failure
    to determine the state is treated as "disconnected" so generation stops
    rather than running unchecked.
    """
    import concurrent.futures

    main_loop = asyncio.get_event_loop()

    def is_disconnected_sync() -> bool:
        future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
        try:
            return not future.result(timeout=0.01)
        except concurrent.futures.TimeoutError:
            # `future` is a concurrent.futures.Future, so a timeout raises
            # concurrent.futures.TimeoutError — NOT asyncio.TimeoutError
            # (they are only aliases of each other on Python 3.11+). Catching
            # the asyncio variant here would make this branch dead on older
            # interpreters.
            logger.error("Asyncio timed out")
            return True
        except Exception as e:
            error_msg = str(e)
            logger.critical(
                f"An unexpected error occurred with the disconnect check coroutine: {error_msg}"
            )
            return True

    return is_disconnected_sync
@router.post("/send-message")
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_user),
_: None = Depends(check_token_rate_limits),
is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
) -> StreamingResponse:
"""This endpoint is both used for all the following purposes:
- Sending a new message in the session
- Regenerating a message in the session (just send the same one again)
- Editing a message (similar to regenerating but sending a different message)
- Kicking off a seeded chat session (set `use_existing_user_message`)
To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
have already been set as latest"""
logger.debug(f"Received new chat message: {chat_message_req.message}")
@ -295,15 +316,26 @@ def handle_new_chat_message(
):
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
packets = stream_chat_message(
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=get_litellm_additional_request_headers(
request.headers
),
)
return StreamingResponse(packets, media_type="application/json")
import json
def stream_generator() -> Generator[str, None, None]:
try:
for packet in stream_chat_message(
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=get_litellm_additional_request_headers(
request.headers
),
is_connected=is_disconnected_func,
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception(f"Error in chat message streaming: {e}")
yield json.dumps({"error": str(e)})
return StreamingResponse(stream_generator(), media_type="text/event-stream")
@router.put("/set-message-as-latest")

View File

@ -12,6 +12,7 @@ import {
FileDescriptor,
ImageGenerationDisplay,
Message,
MessageResponseIDInfo,
RetrievalType,
StreamingError,
ToolCallMetadata,
@ -50,7 +51,7 @@ import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";
import { useDocumentSelection } from "./useDocumentSelection";
import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks";
import { computeAvailableFilters } from "@/lib/filters";
import { FeedbackType } from "./types";
import { ChatState, FeedbackType } from "./types";
import { DocumentSidebar } from "./documentSidebar/DocumentSidebar";
import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader";
import { FeedbackModal } from "./modal/FeedbackModal";
@ -211,6 +212,27 @@ export function ChatPage({
}
}, [liveAssistant]);
// Abort the in-flight generation request. If the last assistant message has a
// tool call that never produced a result, strip it so the UI does not render
// a permanently "running" tool.
const stopGeneration = () => {
  abortController?.abort();

  const lastMessage = messageHistory[messageHistory.length - 1];
  const firstToolCall =
    lastMessage && lastMessage.type === "assistant"
      ? lastMessage.toolCalls[0]
      : undefined;

  if (!firstToolCall || firstToolCall.tool_result !== undefined) {
    return;
  }

  const updatedMessageMap = new Map(completeMessageDetail.messageMap);
  updatedMessageMap.set(lastMessage.messageId, {
    ...lastMessage,
    toolCalls: [],
  });
  setCompleteMessageDetail({
    sessionId: completeMessageDetail.sessionId,
    messageMap: updatedMessageMap,
  });
};
// this is for "@"ing assistants
// this is used to track which assistant is being used to generate the current message
@ -413,6 +435,7 @@ export function ChatPage({
);
messages[0].parentMessageId = systemMessageId;
}
messages.forEach((message) => {
const idToReplace = replacementsMap?.get(message.messageId);
if (idToReplace) {
@ -428,7 +451,6 @@ export function ChatPage({
}
newCompleteMessageMap.set(message.messageId, message);
});
// if specified, make these new message the latest of the current message chain
if (makeLatestChildMessage) {
const currentMessageChain = buildLatestMessageChain(
@ -452,7 +474,8 @@ export function ChatPage({
const messageHistory = buildLatestMessageChain(
completeMessageDetail.messageMap
);
const [isStreaming, setIsStreaming] = useState(false);
const [submittedMessage, setSubmittedMessage] = useState("");
const [chatState, setChatState] = useState<ChatState>("input");
const [abortController, setAbortController] =
useState<AbortController | null>(null);
@ -663,13 +686,11 @@ export function ChatPage({
params: any
) {
try {
for await (const packetBunch of sendMessage(params)) {
for await (const packet of sendMessage(params)) {
if (params.signal?.aborted) {
throw new Error("AbortError");
}
for (const packet of packetBunch) {
stack.push(packet);
}
stack.push(packet);
}
} catch (error: unknown) {
if (error instanceof Error) {
@ -709,7 +730,7 @@ export function ChatPage({
isSeededChat?: boolean;
alternativeAssistantOverride?: Persona | null;
} = {}) => {
if (isStreaming) {
if (chatState != "input") {
setPopup({
message: "Please wait for the response to complete",
type: "error",
@ -718,6 +739,7 @@ export function ChatPage({
return;
}
setChatState("loading");
const controller = new AbortController();
setAbortController(controller);
@ -757,13 +779,15 @@ export function ChatPage({
"Failed to re-send message - please refresh the page and try again.",
type: "error",
});
setChatState("input");
return;
}
let currMessage = messageToResend ? messageToResend.message : message;
if (messageOverride) {
currMessage = messageOverride;
}
setSubmittedMessage(currMessage);
const currMessageHistory =
messageToResendIndex !== null
? messageHistory.slice(0, messageToResendIndex)
@ -775,39 +799,6 @@ export function ChatPage({
: null) ||
(messageMap.size === 1 ? Array.from(messageMap.values())[0] : null);
// if we're resending, set the parent's child to null
// we will use tempMessages until the regenerated message is complete
const messageUpdates: Message[] = [
{
messageId: TEMP_USER_MESSAGE_ID,
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
parentMessageId: parentMessage?.messageId || null,
},
];
if (parentMessage) {
messageUpdates.push({
...parentMessage,
childrenMessageIds: (parentMessage.childrenMessageIds || []).concat([
TEMP_USER_MESSAGE_ID,
]),
latestChildMessageId: TEMP_USER_MESSAGE_ID,
});
}
const { messageMap: frozenMessageMap, sessionId: frozenSessionId } =
upsertToCompleteMessageMap({
messages: messageUpdates,
chatSessionId: currChatSessionId,
});
// on initial message send, we insert a dummy system message
// set this as the parent here if no parent is set
if (!parentMessage && frozenMessageMap.size === 2) {
parentMessage = frozenMessageMap.get(SYSTEM_MESSAGE_ID) || null;
}
const currentAssistantId = alternativeAssistantOverride
? alternativeAssistantOverride.id
: alternativeAssistant
@ -815,8 +806,8 @@ export function ChatPage({
: liveAssistant.id;
resetInputBar();
let messageUpdates: Message[] | null = null;
setIsStreaming(true);
let answer = "";
let query: string | null = null;
let retrievalType: RetrievalType =
@ -831,6 +822,13 @@ export function ChatPage({
let finalMessage: BackendMessage | null = null;
let toolCalls: ToolCallMetadata[] = [];
let initialFetchDetails: null | {
user_message_id: number;
assistant_message_id: number;
frozenMessageMap: Map<number, Message>;
frozenSessionId: number | null;
} = null;
try {
const lastSuccessfulMessageId =
getLastSuccessfulMessageId(currMessageHistory);
@ -838,7 +836,6 @@ export function ChatPage({
const stack = new CurrentMessageFIFO();
updateCurrentMessageFIFO(stack, {
signal: controller.signal, // Add this line
message: currMessage,
alternateAssistantId: currentAssistantId,
fileDescriptors: currentMessageFiles,
@ -875,20 +872,6 @@ export function ChatPage({
useExistingUserMessage: isSeededChat,
});
const updateFn = (messages: Message[]) => {
const replacementsMap = finalMessage
? new Map([
[messages[0].messageId, TEMP_USER_MESSAGE_ID],
[messages[1].messageId, TEMP_ASSISTANT_MESSAGE_ID],
] as [number, number][])
: null;
upsertToCompleteMessageMap({
messages: messages,
replacementsMap: replacementsMap,
completeMessageMapOverride: frozenMessageMap,
chatSessionId: frozenSessionId!,
});
};
const delay = (ms: number) => {
return new Promise((resolve) => setTimeout(resolve, ms));
};
@ -899,8 +882,71 @@ export function ChatPage({
if (!stack.isEmpty()) {
const packet = stack.nextPacket();
console.log(packet);
if (packet) {
if (!packet) {
continue;
}
if (!initialFetchDetails) {
if (!Object.hasOwn(packet, "user_message_id")) {
console.error(
"First packet should contain message response info "
);
continue;
}
const messageResponseIDInfo = packet as MessageResponseIDInfo;
const user_message_id = messageResponseIDInfo.user_message_id!;
const assistant_message_id =
messageResponseIDInfo.reserved_assistant_message_id;
// we will use tempMessages until the regenerated message is complete
messageUpdates = [
{
messageId: user_message_id,
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
parentMessageId: parentMessage?.messageId || null,
},
];
if (parentMessage) {
messageUpdates.push({
...parentMessage,
childrenMessageIds: (
parentMessage.childrenMessageIds || []
).concat([user_message_id]),
latestChildMessageId: user_message_id,
});
}
const {
messageMap: currentFrozenMessageMap,
sessionId: currentFrozenSessionId,
} = upsertToCompleteMessageMap({
messages: messageUpdates,
chatSessionId: currChatSessionId,
});
const frozenMessageMap = currentFrozenMessageMap;
const frozenSessionId = currentFrozenSessionId;
initialFetchDetails = {
frozenMessageMap,
frozenSessionId,
assistant_message_id,
user_message_id,
};
} else {
const { user_message_id, frozenMessageMap, frozenSessionId } =
initialFetchDetails;
setChatState((chatState) => {
if (chatState == "loading") {
return "streaming";
}
return chatState;
});
if (Object.hasOwn(packet, "answer_piece")) {
answer += (packet as AnswerPiecePacket).answer_piece;
} else if (Object.hasOwn(packet, "top_documents")) {
@ -910,7 +956,7 @@ export function ChatPage({
if (documents && documents.length > 0) {
// point to the latest message (we don't know the messageId yet, which is why
// we have to use -1)
setSelectedMessageForDocDisplay(TEMP_USER_MESSAGE_ID);
setSelectedMessageForDocDisplay(user_message_id);
}
} else if (Object.hasOwn(packet, "tool_name")) {
toolCalls = [
@ -920,6 +966,14 @@ export function ChatPage({
tool_result: (packet as ToolCallMetadata).tool_result,
},
];
if (
!toolCalls[0].tool_result ||
toolCalls[0].tool_result == undefined
) {
setChatState("toolBuilding");
} else {
setChatState("streaming");
}
} else if (Object.hasOwn(packet, "file_ids")) {
aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
(fileId) => {
@ -936,23 +990,34 @@ export function ChatPage({
finalMessage = packet as BackendMessage;
}
const newUserMessageId =
finalMessage?.parent_message || TEMP_USER_MESSAGE_ID;
const newAssistantMessageId =
finalMessage?.message_id || TEMP_ASSISTANT_MESSAGE_ID;
// on initial message send, we insert a dummy system message
// set this as the parent here if no parent is set
parentMessage =
parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
const updateFn = (messages: Message[]) => {
const replacementsMap = null;
upsertToCompleteMessageMap({
messages: messages,
replacementsMap: replacementsMap,
completeMessageMapOverride: frozenMessageMap,
chatSessionId: frozenSessionId!,
});
};
updateFn([
{
messageId: newUserMessageId,
messageId: initialFetchDetails.user_message_id!,
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
parentMessageId: parentMessage?.messageId || null,
childrenMessageIds: [newAssistantMessageId],
latestChildMessageId: newAssistantMessageId,
parentMessageId: error ? null : lastSuccessfulMessageId,
childrenMessageIds: [initialFetchDetails.assistant_message_id!],
latestChildMessageId: initialFetchDetails.assistant_message_id,
},
{
messageId: newAssistantMessageId,
messageId: initialFetchDetails.assistant_message_id!,
message: error || answer,
type: error ? "error" : "assistant",
retrievalType,
@ -962,7 +1027,7 @@ export function ChatPage({
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
parentMessageId: newUserMessageId,
parentMessageId: initialFetchDetails.user_message_id,
alternateAssistantID: alternativeAssistant?.id,
stackTrace: stackTrace,
},
@ -975,7 +1040,8 @@ export function ChatPage({
upsertToCompleteMessageMap({
messages: [
{
messageId: TEMP_USER_MESSAGE_ID,
messageId:
initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
message: currMessage,
type: "user",
files: currentMessageFiles,
@ -983,24 +1049,28 @@ export function ChatPage({
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
{
messageId: TEMP_ASSISTANT_MESSAGE_ID,
messageId:
initialFetchDetails?.assistant_message_id ||
TEMP_ASSISTANT_MESSAGE_ID,
message: errorMsg,
type: "error",
files: aiMessageImages || [],
toolCalls: [],
parentMessageId: TEMP_USER_MESSAGE_ID,
parentMessageId:
initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
},
],
completeMessageMapOverride: frozenMessageMap,
completeMessageMapOverride: completeMessageDetail.messageMap,
});
}
setIsStreaming(false);
setChatState("input");
if (isNewSession) {
if (finalMessage) {
setSelectedMessageForDocDisplay(finalMessage.message_id);
}
if (!searchParamBasedChatSessionName) {
await new Promise((resolve) => setTimeout(resolve, 200));
await nameChatSession(currChatSessionId, currMessage);
}
@ -1060,8 +1130,8 @@ export function ChatPage({
const onAssistantChange = (assistant: Persona | null) => {
if (assistant && assistant.id !== liveAssistant.id) {
// Abort the ongoing stream if it exists
if (abortController && isStreaming) {
abortController.abort();
if (chatState != "input") {
stopGeneration();
resetInputBar();
}
@ -1163,7 +1233,7 @@ export function ChatPage({
});
useScrollonStream({
isStreaming,
chatState,
scrollableDivRef,
scrollDist,
endDivRef,
@ -1334,6 +1404,7 @@ export function ChatPage({
>
<div className="w-full relative">
<HistorySidebar
stopGenerating={stopGeneration}
reset={() => setMessage("")}
page="chat"
ref={innerSidebarElementRef}
@ -1407,7 +1478,7 @@ export function ChatPage({
{messageHistory.length === 0 &&
!isFetchingChatMessages &&
!isStreaming && (
chatState == "input" && (
<ChatIntro
availableSources={finalAvailableSources}
selectedPersona={liveAssistant}
@ -1431,6 +1502,7 @@ export function ChatPage({
return (
<div key={messageReactComponentKey}>
<HumanMessage
stopGenerating={stopGeneration}
content={message.message}
files={message.files}
messageId={message.messageId}
@ -1483,9 +1555,7 @@ export function ChatPage({
(selectedMessageForDocDisplay !== null &&
selectedMessageForDocDisplay ===
message.messageId) ||
(selectedMessageForDocDisplay ===
TEMP_USER_MESSAGE_ID &&
i === messageHistory.length - 1);
i === messageHistory.length - 1;
const previousMessage =
i !== 0 ? messageHistory[i - 1] : null;
@ -1534,7 +1604,8 @@ export function ChatPage({
}
isComplete={
i !== messageHistory.length - 1 ||
!isStreaming
(chatState != "streaming" &&
chatState != "toolBuilding")
}
hasDocs={
(message.documents &&
@ -1542,7 +1613,7 @@ export function ChatPage({
}
handleFeedback={
i === messageHistory.length - 1 &&
isStreaming
chatState != "input"
? undefined
: (feedbackType) =>
setCurrentFeedback([
@ -1552,7 +1623,7 @@ export function ChatPage({
}
handleSearchQueryEdit={
i === messageHistory.length - 1 &&
!isStreaming
chatState == "input"
? (newQuery) => {
if (!previousMessage) {
setPopup({
@ -1659,34 +1730,39 @@ export function ChatPage({
);
}
})}
{isStreaming &&
messageHistory.length > 0 &&
messageHistory[messageHistory.length - 1].type ===
{chatState == "loading" &&
messageHistory[messageHistory.length - 1]?.type !=
"user" && (
<div
key={`${messageHistory.length}-${chatSessionIdRef.current}`}
>
<AIMessage
currentPersona={liveAssistant}
alternativeAssistant={
alternativeGeneratingAssistant ??
alternativeAssistant
}
messageId={null}
personaName={liveAssistant.name}
content={
<div
key={"Generating"}
className="mr-auto relative inline-block"
>
<span className="text-sm loading-text">
Thinking...
</span>
</div>
}
/>
</div>
<HumanMessage
messageId={-1}
content={submittedMessage}
/>
)}
{chatState == "loading" && (
<div
key={`${messageHistory.length}-${chatSessionIdRef.current}`}
>
<AIMessage
currentPersona={liveAssistant}
alternativeAssistant={
alternativeGeneratingAssistant ??
alternativeAssistant
}
messageId={null}
personaName={liveAssistant.name}
content={
<div
key={"Generating"}
className="mr-auto relative inline-block"
>
<span className="text-sm loading-text">
Thinking...
</span>
</div>
}
/>
</div>
)}
{currentPersona &&
currentPersona.starter_messages &&
@ -1748,6 +1824,8 @@ export function ChatPage({
)}
<ChatInputBar
chatState={chatState}
stopGenerating={stopGeneration}
openModelSettings={() => setSettingsToggled(true)}
inputPrompts={userInputPrompts}
showDocs={() => setDocumentSelection(true)}
@ -1762,7 +1840,6 @@ export function ChatPage({
message={message}
setMessage={setMessage}
onSubmit={onSubmit}
isStreaming={isStreaming}
filterManager={filterManager}
llmOverrideManager={llmOverrideManager}
files={currentMessageFiles}

View File

@ -21,6 +21,7 @@ import {
CpuIconSkeleton,
FileIcon,
SendIcon,
StopGeneratingIcon,
} from "@/components/icons/icons";
import { IconType } from "react-icons";
import Popup from "../../../components/popup/Popup";
@ -31,6 +32,9 @@ import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { Hoverable } from "@/components/Hoverable";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { StopCircle } from "@phosphor-icons/react/dist/ssr";
import { Square } from "@phosphor-icons/react";
import { ChatState } from "../types";
const MAX_INPUT_HEIGHT = 200;
export function ChatInputBar({
@ -39,10 +43,11 @@ export function ChatInputBar({
selectedDocuments,
message,
setMessage,
stopGenerating,
onSubmit,
isStreaming,
filterManager,
llmOverrideManager,
chatState,
// assistants
selectedAssistant,
@ -59,6 +64,8 @@ export function ChatInputBar({
inputPrompts,
}: {
openModelSettings: () => void;
chatState: ChatState;
stopGenerating: () => void;
showDocs: () => void;
selectedDocuments: DanswerDocument[];
assistantOptions: Persona[];
@ -68,7 +75,6 @@ export function ChatInputBar({
message: string;
setMessage: (message: string) => void;
onSubmit: () => void;
isStreaming: boolean;
filterManager: FilterManager;
llmOverrideManager: LlmOverrideManager;
selectedAssistant: Persona;
@ -597,24 +603,38 @@ export function ChatInputBar({
}}
/>
</div>
<div className="absolute bottom-2.5 mobile:right-4 desktop:right-10">
<div
className="cursor-pointer"
onClick={() => {
if (message) {
onSubmit();
}
}}
>
<SendIcon
size={28}
className={`text-emphasis text-white p-1 rounded-full ${
message && !isStreaming
? "bg-background-800"
: "bg-[#D7D7D7]"
}`}
/>
</div>
{chatState == "streaming" ||
chatState == "toolBuilding" ||
chatState == "loading" ? (
<button
className={`cursor-pointer ${chatState != "streaming" ? "bg-background-400" : "bg-background-800"} h-[28px] w-[28px] rounded-full`}
onClick={stopGenerating}
disabled={chatState != "streaming"}
>
<StopGeneratingIcon
size={10}
className={`text-emphasis m-auto text-white flex-none
}`}
/>
</button>
) : (
<button
className="cursor-pointer"
onClick={() => {
if (message) {
onSubmit();
}
}}
disabled={chatState != "input"}
>
<SendIcon
size={28}
className={`text-emphasis text-white p-1 rounded-full ${chatState == "input" && message ? "bg-background-800" : "bg-background-400"} `}
/>
</button>
)}
</div>
</div>
</div>

View File

@ -118,6 +118,11 @@ export interface BackendMessage {
alternate_assistant_id?: number | null;
}
// Mirror of the backend `MessageResponseIDInfo` model: emitted in the
// send-message stream so the client can map the streamed answer to
// persisted message rows.
export interface MessageResponseIDInfo {
  // ID of the stored user message; null presumably when no new user message
  // was created (backend model allows None) — TODO confirm with backend.
  user_message_id: number | null;
  // ID reserved up-front for the assistant message that will hold the
  // streamed answer.
  reserved_assistant_message_id: number;
}
export interface DocumentsResponse {
top_documents: DanswerDocument[];
rephrased_query: string | null;

View File

@ -3,8 +3,8 @@ import {
DanswerDocument,
Filters,
} from "@/lib/search/interfaces";
import { handleStream } from "@/lib/search/streamingUtils";
import { FeedbackType } from "./types";
import { handleSSEStream, handleStream } from "@/lib/search/streamingUtils";
import { ChatState, FeedbackType } from "./types";
import {
Dispatch,
MutableRefObject,
@ -20,6 +20,7 @@ import {
FileDescriptor,
ImageGenerationDisplay,
Message,
MessageResponseIDInfo,
RetrievalType,
StreamingError,
ToolCallMetadata,
@ -109,7 +110,8 @@ export type PacketType =
| AnswerPiecePacket
| DocumentsResponse
| ImageGenerationDisplay
| StreamingError;
| StreamingError
| MessageResponseIDInfo;
export async function* sendMessage({
message,
@ -127,6 +129,7 @@ export async function* sendMessage({
systemPromptOverride,
useExistingUserMessage,
alternateAssistantId,
signal,
}: {
message: string;
fileDescriptors: FileDescriptor[];
@ -137,70 +140,69 @@ export async function* sendMessage({
selectedDocumentIds: number[] | null;
queryOverride?: string;
forceSearch?: boolean;
// LLM overrides
modelProvider?: string;
modelVersion?: string;
temperature?: number;
// prompt overrides
systemPromptOverride?: string;
// if specified, will use the existing latest user message
// and will ignore the specified `message`
useExistingUserMessage?: boolean;
alternateAssistantId?: number;
}) {
signal?: AbortSignal;
}): AsyncGenerator<PacketType, void, unknown> {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
const sendMessageResponse = await fetch("/api/chat/send-message", {
const body = JSON.stringify({
alternate_assistant_id: alternateAssistantId,
chat_session_id: chatSessionId,
parent_message_id: parentMessageId,
message: message,
prompt_id: promptId,
search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
file_descriptors: fileDescriptors,
retrieval_options: !documentsAreSelected
? {
run_search:
promptId === null ||
promptId === undefined ||
queryOverride ||
forceSearch
? "always"
: "auto",
real_time: true,
filters: filters,
}
: null,
query_override: queryOverride,
prompt_override: systemPromptOverride
? {
system_prompt: systemPromptOverride,
}
: null,
llm_override:
temperature || modelVersion
? {
temperature,
model_provider: modelProvider,
model_version: modelVersion,
}
: null,
use_existing_user_message: useExistingUserMessage,
});
const response = await fetch(`/api/chat/send-message`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
alternate_assistant_id: alternateAssistantId,
chat_session_id: chatSessionId,
parent_message_id: parentMessageId,
message: message,
prompt_id: promptId,
search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
file_descriptors: fileDescriptors,
retrieval_options: !documentsAreSelected
? {
run_search:
promptId === null ||
promptId === undefined ||
queryOverride ||
forceSearch
? "always"
: "auto",
real_time: true,
filters: filters,
}
: null,
query_override: queryOverride,
prompt_override: systemPromptOverride
? {
system_prompt: systemPromptOverride,
}
: null,
llm_override:
temperature || modelVersion
? {
temperature,
model_provider: modelProvider,
model_version: modelVersion,
}
: null,
use_existing_user_message: useExistingUserMessage,
}),
body,
signal,
});
if (!sendMessageResponse.ok) {
const errorJson = await sendMessageResponse.json();
const errorMsg = errorJson.message || errorJson.detail || "";
throw Error(`Failed to send message - ${errorMsg}`);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
yield* handleStream<PacketType>(sendMessageResponse);
yield* handleSSEStream<PacketType>(response);
}
export async function nameChatSession(chatSessionId: number, message: string) {
@ -635,14 +637,14 @@ export async function uploadFilesForChat(
}
export async function useScrollonStream({
isStreaming,
chatState,
scrollableDivRef,
scrollDist,
endDivRef,
distance,
debounce,
}: {
isStreaming: boolean;
chatState: ChatState;
scrollableDivRef: RefObject<HTMLDivElement>;
scrollDist: MutableRefObject<number>;
endDivRef: RefObject<HTMLDivElement>;
@ -656,7 +658,7 @@ export async function useScrollonStream({
const previousScroll = useRef<number>(0);
useEffect(() => {
if (isStreaming && scrollableDivRef && scrollableDivRef.current) {
if (chatState != "input" && scrollableDivRef && scrollableDivRef.current) {
let newHeight: number = scrollableDivRef.current?.scrollTop!;
const heightDifference = newHeight - previousScroll.current;
previousScroll.current = newHeight;
@ -712,7 +714,7 @@ export async function useScrollonStream({
// scroll on end of stream if within distance
useEffect(() => {
if (scrollableDivRef?.current && !isStreaming) {
if (scrollableDivRef?.current && chatState == "input") {
if (scrollDist.current < distance - 50) {
scrollableDivRef?.current?.scrollBy({
left: 0,
@ -721,5 +723,5 @@ export async function useScrollonStream({
});
}
}
}, [isStreaming]);
}, [chatState]);
}

View File

@ -255,7 +255,6 @@ export const AIMessage = ({
size="small"
assistant={alternativeAssistant || currentPersona}
/>
<div className="w-full">
<div className="max-w-message-max break-words">
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) &&
@ -623,6 +622,7 @@ export const HumanMessage = ({
onEdit,
onMessageSelection,
shared,
stopGenerating = () => null,
}: {
shared?: boolean;
content: string;
@ -631,6 +631,7 @@ export const HumanMessage = ({
otherMessagesCanSwitchTo?: number[];
onEdit?: (editedContent: string) => void;
onMessageSelection?: (messageId: number) => void;
stopGenerating?: () => void;
}) => {
const textareaRef = useRef<HTMLTextAreaElement>(null);
@ -677,7 +678,6 @@ export const HumanMessage = ({
<div className="xl:ml-8">
<div className="flex flex-col mr-4">
<FileDisplay alignBubble files={files || []} />
<div className="flex justify-end">
<div className="w-full ml-8 flex w-full max-w-message-max break-words">
{isEditing ? (
@ -857,16 +857,18 @@ export const HumanMessage = ({
<MessageSwitcher
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() =>
handlePrevious={() => {
stopGenerating();
onMessageSelection(
otherMessagesCanSwitchTo[currentMessageInd - 1]
)
}
handleNext={() =>
);
}}
handleNext={() => {
stopGenerating();
onMessageSelection(
otherMessagesCanSwitchTo[currentMessageInd + 1]
)
}
);
}}
/>
</div>
)}

View File

@ -33,6 +33,7 @@ export function ChatSessionDisplay({
isSelected,
skipGradient,
closeSidebar,
stopGenerating = () => null,
showShareModal,
showDeleteModal,
}: {
@ -43,6 +44,7 @@ export function ChatSessionDisplay({
// if not set, the gradient will still be applied and cause weirdness
skipGradient?: boolean;
closeSidebar?: () => void;
stopGenerating?: () => void;
showShareModal?: (chatSession: ChatSession) => void;
showDeleteModal?: (chatSession: ChatSession) => void;
}) {
@ -99,6 +101,7 @@ export function ChatSessionDisplay({
className="flex my-1 group relative"
key={chatSession.id}
onClick={() => {
stopGenerating();
if (settings?.isMobile && closeSidebar) {
closeSidebar();
}

View File

@ -40,6 +40,7 @@ interface HistorySidebarProps {
reset?: () => void;
showShareModal?: (chatSession: ChatSession) => void;
showDeleteModal?: (chatSession: ChatSession) => void;
stopGenerating?: () => void;
}
export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
@ -54,6 +55,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
openedFolders,
toggleSidebar,
removeToggle,
stopGenerating = () => null,
showShareModal,
showDeleteModal,
},
@ -179,6 +181,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
)}
<div className="border-b border-border pb-4 mx-3" />
<PagesTab
stopGenerating={stopGenerating}
newFolderId={newFolderId}
showDeleteModal={showDeleteModal}
showShareModal={showShareModal}

View File

@ -17,10 +17,12 @@ export function PagesTab({
folders,
openedFolders,
closeSidebar,
stopGenerating,
newFolderId,
showShareModal,
showDeleteModal,
}: {
stopGenerating: () => void;
page: pageType;
existingChats?: ChatSession[];
currentChatId?: number;
@ -124,6 +126,7 @@ export function PagesTab({
return (
<div key={`${chat.id}-${chat.name}`}>
<ChatSessionDisplay
stopGenerating={stopGenerating}
showDeleteModal={showDeleteModal}
showShareModal={showShareModal}
closeSidebar={closeSidebar}

View File

@ -1 +1,2 @@
export type FeedbackType = "like" | "dislike";
// Lifecycle of a chat turn as used by the input bar: "input" = idle and
// accepting text (send enabled), "loading" = request in flight before any
// packet arrives, "streaming" = answer tokens arriving (stop enabled),
// "toolBuilding" = presumably a tool call is being assembled — confirm
// against the packet handling in the chat page.
export type ChatState = "input" | "loading" | "streaming" | "toolBuilding";

View File

@ -1763,6 +1763,29 @@ export const FilledLikeIcon = ({
);
};
// Filled rounded square used as the "stop generating" control.
// Rendered size is driven by the inline style; the width/height attributes
// are kept as a fallback for non-CSS consumers.
export const StopGeneratingIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => {
  return (
    <svg
      style={{ width: `${size}px`, height: `${size}px` }}
      className={`w-[${size}px] h-[${size}px] ` + className}
      xmlns="http://www.w3.org/2000/svg"
      width="200"
      height="200"
      viewBox="0 0 14 14"
    >
      <path
        fill="currentColor"
        // JSX requires camelCase SVG attributes: `fill-rule`/`clip-rule`
        // are invalid DOM properties in React and trigger warnings.
        fillRule="evenodd"
        d="M1.5 0A1.5 1.5 0 0 0 0 1.5v11A1.5 1.5 0 0 0 1.5 14h11a1.5 1.5 0 0 0 1.5-1.5v-11A1.5 1.5 0 0 0 12.5 0z"
        clipRule="evenodd"
      />
    </svg>
  );
};
export const LikeFeedbackIcon = ({
size = 16,
className = defaultTailwindCSS,

View File

@ -1,3 +1,5 @@
import { PacketType } from "@/app/chat/lib";
type NonEmptyObject = { [k: string]: any };
const processSingleChunk = <T extends NonEmptyObject>(
@ -75,3 +77,33 @@ export async function* handleStream<T extends NonEmptyObject>(
yield await Promise.resolve(completedChunks);
}
}
/**
 * Consume a streaming chat response, yielding one parsed packet per
 * newline-delimited JSON object. (Despite the "SSE" name, the backend emits
 * bare newline-separated JSON lines, not `data:`-prefixed SSE frames.)
 *
 * Partial data is buffered between reads so that a JSON object split across
 * chunk boundaries — or a multi-byte UTF-8 character split mid-sequence —
 * is still decoded and parsed correctly. Unparseable lines are logged and
 * skipped rather than aborting the stream.
 *
 * @param streamingResponse fetch Response whose body is the packet stream
 * @throws Error if the response has no readable body
 */
export async function* handleSSEStream<T extends PacketType>(
  streamingResponse: Response
): AsyncGenerator<T, void, unknown> {
  const reader = streamingResponse.body?.getReader();
  if (!reader) {
    throw new Error("Unable to process chunk");
  }

  const decoder = new TextDecoder();
  // Trailing partial line carried over from the previous read.
  let buffer = "";

  const parseLine = function* (line: string): Generator<T, void, unknown> {
    if (line.trim() === "") {
      return;
    }
    try {
      yield JSON.parse(line) as T;
    } catch (error) {
      console.error("Error parsing SSE data:", error);
    }
  };

  while (true) {
    const { done, value } = await reader.read();
    if (done) {
      break;
    }

    // stream: true keeps multi-byte characters intact across read boundaries.
    buffer += decoder.decode(value, { stream: true });

    const lines = buffer.split("\n");
    // Last element is "" if the chunk ended on a newline, otherwise an
    // incomplete line — hold it back until more data arrives.
    buffer = lines.pop() ?? "";
    for (const line of lines) {
      yield* parseLine(line);
    }
  }

  // Flush any final packet that was not newline-terminated.
  buffer += decoder.decode();
  yield* parseLine(buffer);
}

View File

@ -90,6 +90,7 @@ module.exports = {
"background-200": "#e5e5e5", // neutral-200
"background-300": "#d4d4d4", // neutral-300
"background-400": "#a3a3a3", // neutral-400
"background-600": "#525252", // neutral-800
"background-500": "#737373", // neutral-400
"background-600": "#525252", // neutral-400
"background-700": "#404040", // neutral-400