From 12fccfeffd122314f2be5fea108fa1cad95a2e4b Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Sun, 18 Aug 2024 15:15:55 -0700 Subject: [PATCH] Add `stop generating` functionality (#2100) * functional types + sidebar * remove commits * remove logs * functional rework of temporary user/assistant ID * robustify switching * remove logs * typing * robustify frontend handling * cleaner loop + data persistence * migrate to streaming response * formatting * add new loading state to prevent collisions * add `ChatState` for more robust handling * remove logs * robustify typing * unnecessary list removed * robustify * remove log * remove false comment * slightly more robust chat state * update utility + copy * improve clarity + new SSE handling utility function * remove comments * clearer * add back stack trace detail * cleaner messages * clean final message handling * tiny formatting (remove newline) * add synchronous wrapper to avoid hampering main event loop * update typing * include logs * slightly more specific logs * add `critical` error just in case --- backend/danswer/chat/chat_utils.py | 1 + backend/danswer/chat/models.py | 5 + backend/danswer/chat/process_message.py | 29 +- backend/danswer/db/chat.py | 84 ++++- backend/danswer/llm/answering/answer.py | 40 ++- .../server/query_and_chat/chat_backend.py | 56 +++- web/src/app/chat/ChatPage.tsx | 305 +++++++++++------- web/src/app/chat/input/ChatInputBar.tsx | 58 ++-- web/src/app/chat/interfaces.ts | 5 + web/src/app/chat/lib.tsx | 114 +++---- web/src/app/chat/message/Messages.tsx | 18 +- .../sessionSidebar/ChatSessionDisplay.tsx | 3 + .../chat/sessionSidebar/HistorySidebar.tsx | 3 + web/src/app/chat/sessionSidebar/PagesTab.tsx | 3 + web/src/app/chat/types.ts | 1 + web/src/components/icons/icons.tsx | 23 ++ web/src/lib/search/streamingUtils.ts | 32 ++ web/tailwind-themes/tailwind.config.js | 1 + 18 files changed, 547 insertions(+), 234 deletions(-) diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index a6c25c2b0..ed9c3c6cb 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -39,6 +39,7 @@ def create_chat_chain( ) -> tuple[ChatMessage, list[ChatMessage]]: """Build the linear chain of messages without including the root message""" mainline_messages: list[ChatMessage] = [] + all_chat_messages = get_chat_messages_by_session( chat_session_id=chat_session_id, user_id=None, diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 2902efe89..d1da783b6 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -76,6 +76,11 @@ class CitationInfo(BaseModel): document_id: str +class MessageResponseIDInfo(BaseModel): + user_message_id: int | None + reserved_assistant_message_id: int + + class StreamingError(BaseModel): error: str stack_trace: str | None = None diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 564d51b76..98f2b29d2 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -11,6 +11,7 @@ from danswer.chat.models import CustomToolResponse from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import ImageGenerationDisplay from danswer.chat.models import LLMRelevanceFilterResponse +from danswer.chat.models import MessageResponseIDInfo from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.configs.chat_configs import BING_API_KEY @@ -27,6 +28,7 @@ from danswer.db.chat import get_chat_session_by_id from danswer.db.chat import get_db_search_doc_by_id from danswer.db.chat import get_doc_query_identifiers_from_model from danswer.db.chat import get_or_create_root_message +from danswer.db.chat import reserve_message_id from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.embedding_model import get_current_db_embedding_model @@ -241,6 +243,7 @@ ChatPacket = ( | CitationInfo | ImageGenerationDisplay | CustomToolResponse + | MessageResponseIDInfo ) ChatPacketStream = Iterator[ChatPacket] @@ -256,9 +259,9 @@ def stream_chat_message_objects( max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE, # if specified, uses the last user message and does not create a new user message based # on the `new_msg_req.message`. Currently, requires a state where the last message is a - # user message (e.g. this can only be used for the chat-seeding flow). use_existing_user_message: bool = False, litellm_additional_headers: dict[str, str] | None = None, + is_connected: Callable[[], bool] | None = None, ) -> ChatPacketStream: """Streams in order: 1. [conditional] Retrieved documents if a search needs to be run @@ -449,7 +452,18 @@ def stream_chat_message_objects( ), max_window_percentage=max_document_percentage, ) - + reserved_message_id = reserve_message_id( + db_session=db_session, + chat_session_id=chat_session_id, + parent_message=user_message.id + if user_message is not None + else parent_message.id, + message_type=MessageType.ASSISTANT, + ) + yield MessageResponseIDInfo( + user_message_id=user_message.id if user_message else None, + reserved_assistant_message_id=reserved_message_id, + ) # Cannot determine these without the LLM step or breaking out early partial_response = partial( create_new_chat_message, @@ -582,6 +596,7 @@ def stream_chat_message_objects( # LLM prompt building, response capturing, etc. answer = Answer( + is_connected=is_connected, question=final_msg.message, latest_query_files=latest_query_files, answer_style_config=AnswerStyleConfig( @@ -615,6 +630,7 @@ def stream_chat_message_objects( ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images dropped_indices = None tool_result = None + for packet in answer.processed_streamed_output: if isinstance(packet, ToolResponse): if packet.id == SEARCH_RESPONSE_SUMMARY_ID: @@ -690,6 +706,7 @@ def stream_chat_message_objects( if isinstance(packet, ToolCallFinalResult): tool_result = packet yield cast(ChatPacket, packet) + logger.debug("Reached end of stream") except Exception as e: error_msg = str(e) logger.exception(f"Failed to process chat message: {error_msg}") @@ -717,6 +734,7 @@ def stream_chat_message_objects( tool_name_to_tool_id[tool.name] = tool_id gen_ai_response_message = partial_response( + reserved_message_id=reserved_message_id, message=answer.llm_answer, rephrased_query=( qa_docs_response.rephrased_query if qa_docs_response else None @@ -737,6 +755,8 @@ def stream_chat_message_objects( if tool_result else [], ) + + logger.debug("Committing messages") db_session.commit() # actually save user / assistant message msg_detail_response = translate_db_message_to_chat_message_detail( @@ -745,7 +765,8 @@ def stream_chat_message_objects( yield msg_detail_response except Exception as e: - logger.exception(e) + error_msg = str(e) + logger.exception(error_msg) # Frontend will erase whatever answer and show this instead yield StreamingError(error="Failed to parse LLM output") @@ -757,6 +778,7 @@ def stream_chat_message( user: User | None, use_existing_user_message: bool = False, litellm_additional_headers: dict[str, str] | None = None, + is_connected: Callable[[], bool] | None = None, ) -> Iterator[str]: with get_session_context_manager() as db_session: objects = stream_chat_message_objects( @@ -765,6 +787,7 @@ def stream_chat_message( db_session=db_session, use_existing_user_message=use_existing_user_message, litellm_additional_headers=litellm_additional_headers, + is_connected=is_connected, ) for obj in objects: yield get_json_line(obj.dict()) diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 2ec04b96a..301c48103 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -393,6 +393,34 @@ def get_or_create_root_message( return new_root_message +def reserve_message_id( + db_session: Session, + chat_session_id: int, + parent_message: int, + message_type: MessageType, +) -> int: + # Create an empty chat message + empty_message = ChatMessage( + chat_session_id=chat_session_id, + parent_message=parent_message, + latest_child_message=None, + message="", + token_count=0, + message_type=message_type, + ) + + # Add the empty message to the session + db_session.add(empty_message) + + # Flush the session to get an ID for the new chat message + db_session.flush() + + # Get the ID of the newly created message + new_id = empty_message.id + + return new_id + + def create_new_chat_message( chat_session_id: int, parent_message: ChatMessage, @@ -410,29 +438,51 @@ def create_new_chat_message( citations: dict[int, int] | None = None, tool_calls: list[ToolCall] | None = None, commit: bool = True, + reserved_message_id: int | None = None, ) -> ChatMessage: - new_chat_message = ChatMessage( - chat_session_id=chat_session_id, - parent_message=parent_message.id, - latest_child_message=None, - message=message, - rephrased_query=rephrased_query, - prompt_id=prompt_id, - token_count=token_count, - message_type=message_type, - citations=citations, - files=files, - tool_calls=tool_calls if tool_calls else [], - error=error, - alternate_assistant_id=alternate_assistant_id, - ) + if reserved_message_id is not None: + # Edit existing message + existing_message = db_session.query(ChatMessage).get(reserved_message_id) + if existing_message is None: + raise ValueError(f"No message found with id {reserved_message_id}") + + existing_message.chat_session_id = chat_session_id + existing_message.parent_message = parent_message.id + existing_message.message = message + existing_message.rephrased_query = rephrased_query + existing_message.prompt_id = prompt_id + existing_message.token_count = token_count + existing_message.message_type = message_type + existing_message.citations = citations + existing_message.files = files + existing_message.tool_calls = tool_calls if tool_calls else [] + existing_message.error = error + existing_message.alternate_assistant_id = alternate_assistant_id + + new_chat_message = existing_message + else: + # Create new message + new_chat_message = ChatMessage( + chat_session_id=chat_session_id, + parent_message=parent_message.id, + latest_child_message=None, + message=message, + rephrased_query=rephrased_query, + prompt_id=prompt_id, + token_count=token_count, + message_type=message_type, + citations=citations, + files=files, + tool_calls=tool_calls if tool_calls else [], + error=error, + alternate_assistant_id=alternate_assistant_id, + ) + db_session.add(new_chat_message) # SQL Alchemy will propagate this to update the reference_docs' foreign keys if reference_docs: new_chat_message.search_docs = reference_docs - db_session.add(new_chat_message) - # Flush the session to get an ID for the new chat message db_session.flush() diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index da5ccc4e9..136f18fa6 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from collections.abc import Iterator from typing import cast from uuid import uuid4 @@ -115,6 +116,7 @@ class Answer: # Returns the full document sections text from the search tool return_contexts: bool = False, skip_gen_ai_answer_generation: bool = False, + is_connected: Callable[[], bool] | None = None, ) -> None: if single_message_history and message_history: raise ValueError( @@ -122,6 +124,7 @@ class Answer: ) self.question = question + self.is_connected: Callable[[], bool] | None = is_connected self.latest_query_files = latest_query_files or [] self.file_id_to_file = {file.file_id: file for file in (files or [])} @@ -153,6 +156,7 @@ class Answer: self._return_contexts = return_contexts self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation + self._is_cancelled = False def _update_prompt_builder_for_search_tool( self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc] @@ -235,6 +239,8 @@ class Answer: tool_call_chunk += message # type: ignore else: if message.content: + if self.is_cancelled: + return yield cast(str, message.content) if not tool_call_chunk: @@ -292,12 +298,15 @@ class Answer: yield tool_runner.tool_final_result() prompt = prompt_builder.build(tool_call_summary=tool_call_summary) - yield from message_generator_to_string_generator( + for token in message_generator_to_string_generator( self.llm.stream( prompt=prompt, tools=[tool.tool_definition() for tool in self.tools], ) - ) + ): + if self.is_cancelled: + return + yield token return @@ -378,9 +387,13 @@ class Answer: ) ) prompt = prompt_builder.build() - yield from message_generator_to_string_generator( + for token in message_generator_to_string_generator( self.llm.stream(prompt=prompt) - ) + ): + if self.is_cancelled: + return + yield token + return tool, tool_args = chosen_tool_and_args @@ -434,7 +447,12 @@ class Answer: yield final prompt = prompt_builder.build() - yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt)) + for token in message_generator_to_string_generator( + self.llm.stream(prompt=prompt) + ): + if self.is_cancelled: + return + yield token @property def processed_streamed_output(self) -> AnswerStream: @@ -537,3 +555,15 @@ class Answer: citations.append(packet) return citations + + @property + def is_cancelled(self) -> bool: + if self._is_cancelled: + return True + + if self.is_connected is not None: + if not self.is_connected(): + logger.debug("Answer stream has been cancelled") + self._is_cancelled = not self.is_connected() + + return self._is_cancelled diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 01d5bf072..a37758336 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -1,5 +1,8 @@ +import asyncio import io import uuid +from collections.abc import Callable +from collections.abc import Generator from fastapi import APIRouter from fastapi import Depends @@ -207,8 +210,6 @@ def rename_chat_session( chat_session_id = rename_req.chat_session_id user_id = user.id if user is not None else None - logger.info(f"Received rename request for chat session: {chat_session_id}") - if name: update_chat_session( db_session=db_session, @@ -271,19 +272,39 @@ def delete_chat_session_by_id( delete_chat_session(user_id, session_id, db_session) +async def is_disconnected(request: Request) -> Callable[[], bool]: + main_loop = asyncio.get_event_loop() + + def is_disconnected_sync() -> bool: + future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop) + try: + return not future.result(timeout=0.01) + except asyncio.TimeoutError: + logger.error("Asyncio timed out") + return True + except Exception as e: + error_msg = str(e) + logger.critical( + f"An unexpected error occured with the disconnect check coroutine: {error_msg}" + ) + return True + + return is_disconnected_sync + + @router.post("/send-message") def handle_new_chat_message( chat_message_req: CreateChatMessageRequest, request: Request, user: User | None = Depends(current_user), _: None = Depends(check_token_rate_limits), + is_disconnected_func: Callable[[], bool] = Depends(is_disconnected), ) -> StreamingResponse: """This endpoint is both used for all the following purposes: - Sending a new message in the session - Regenerating a message in the session (just send the same one again) - Editing a message (similar to regenerating but sending a different message) - Kicking off a seeded chat session (set `use_existing_user_message`) - To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path have already been set as latest""" logger.debug(f"Received new chat message: {chat_message_req.message}") @@ -295,15 +316,26 @@ def handle_new_chat_message( ): raise HTTPException(status_code=400, detail="Empty chat message is invalid") - packets = stream_chat_message( - new_msg_req=chat_message_req, - user=user, - use_existing_user_message=chat_message_req.use_existing_user_message, - litellm_additional_headers=get_litellm_additional_request_headers( - request.headers - ), - ) - return StreamingResponse(packets, media_type="application/json") + import json + + def stream_generator() -> Generator[str, None, None]: + try: + for packet in stream_chat_message( + new_msg_req=chat_message_req, + user=user, + use_existing_user_message=chat_message_req.use_existing_user_message, + litellm_additional_headers=get_litellm_additional_request_headers( + request.headers + ), + is_connected=is_disconnected_func, + ): + yield json.dumps(packet) if isinstance(packet, dict) else packet + + except Exception as e: + logger.exception(f"Error in chat message streaming: {e}") + yield json.dumps({"error": str(e)}) + + return StreamingResponse(stream_generator(), media_type="text/event-stream") @router.put("/set-message-as-latest") diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 0d2bb5e32..a28c18912 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -12,6 +12,7 @@ import { FileDescriptor, ImageGenerationDisplay, Message, + MessageResponseIDInfo, RetrievalType, StreamingError, ToolCallMetadata, @@ -50,7 +51,7 @@ import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams"; import { useDocumentSelection } from "./useDocumentSelection"; import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks"; import { computeAvailableFilters } from "@/lib/filters"; -import { FeedbackType } from "./types"; +import { ChatState, FeedbackType } from "./types"; import { DocumentSidebar } from "./documentSidebar/DocumentSidebar"; import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader"; import { FeedbackModal } from "./modal/FeedbackModal"; @@ -211,6 +212,27 @@ export function ChatPage({ } }, [liveAssistant]); + const stopGeneration = () => { + if (abortController) { + abortController.abort(); + } + const lastMessage = messageHistory[messageHistory.length - 1]; + if ( + lastMessage && + lastMessage.type === "assistant" && + lastMessage.toolCalls[0] && + lastMessage.toolCalls[0].tool_result === undefined + ) { + const newCompleteMessageMap = new Map(completeMessageDetail.messageMap); + const updatedMessage = { ...lastMessage, toolCalls: [] }; + newCompleteMessageMap.set(lastMessage.messageId, updatedMessage); + setCompleteMessageDetail({ + sessionId: completeMessageDetail.sessionId, + messageMap: newCompleteMessageMap, + }); + } + }; + // this is for "@"ing assistants // this is used to track which assistant is being used to generate the current message @@ -413,6 +435,7 @@ export function ChatPage({ ); messages[0].parentMessageId = systemMessageId; } + messages.forEach((message) => { const idToReplace = replacementsMap?.get(message.messageId); if (idToReplace) { @@ -428,7 +451,6 @@ export function ChatPage({ } newCompleteMessageMap.set(message.messageId, message); }); - // if specified, make these new message the latest of the current message chain if (makeLatestChildMessage) { const currentMessageChain = buildLatestMessageChain( @@ -452,7 +474,8 @@ export function ChatPage({ const messageHistory = buildLatestMessageChain( completeMessageDetail.messageMap ); - const [isStreaming, setIsStreaming] = useState(false); + const [submittedMessage, setSubmittedMessage] = useState(""); + const [chatState, setChatState] = useState("input"); const [abortController, setAbortController] = useState(null); @@ -663,13 +686,11 @@ export function ChatPage({ params: any ) { try { - for await (const packetBunch of sendMessage(params)) { + for await (const packet of sendMessage(params)) { if (params.signal?.aborted) { throw new Error("AbortError"); } - for (const packet of packetBunch) { - stack.push(packet); - } + stack.push(packet); } } catch (error: unknown) { if (error instanceof Error) { @@ -709,7 +730,7 @@ export function ChatPage({ isSeededChat?: boolean; alternativeAssistantOverride?: Persona | null; } = {}) => { - if (isStreaming) { + if (chatState != "input") { setPopup({ message: "Please wait for the response to complete", type: "error", @@ -718,6 +739,7 @@ export function ChatPage({ return; } + setChatState("loading"); const controller = new AbortController(); setAbortController(controller); @@ -757,13 +779,15 @@ export function ChatPage({ "Failed to re-send message - please refresh the page and try again.", type: "error", }); + setChatState("input"); return; } - let currMessage = messageToResend ? messageToResend.message : message; if (messageOverride) { currMessage = messageOverride; } + + setSubmittedMessage(currMessage); const currMessageHistory = messageToResendIndex !== null ? messageHistory.slice(0, messageToResendIndex) @@ -775,39 +799,6 @@ export function ChatPage({ : null) || (messageMap.size === 1 ? Array.from(messageMap.values())[0] : null); - // if we're resending, set the parent's child to null - // we will use tempMessages until the regenerated message is complete - const messageUpdates: Message[] = [ - { - messageId: TEMP_USER_MESSAGE_ID, - message: currMessage, - type: "user", - files: currentMessageFiles, - toolCalls: [], - parentMessageId: parentMessage?.messageId || null, - }, - ]; - if (parentMessage) { - messageUpdates.push({ - ...parentMessage, - childrenMessageIds: (parentMessage.childrenMessageIds || []).concat([ - TEMP_USER_MESSAGE_ID, - ]), - latestChildMessageId: TEMP_USER_MESSAGE_ID, - }); - } - const { messageMap: frozenMessageMap, sessionId: frozenSessionId } = - upsertToCompleteMessageMap({ - messages: messageUpdates, - chatSessionId: currChatSessionId, - }); - - // on initial message send, we insert a dummy system message - // set this as the parent here if no parent is set - if (!parentMessage && frozenMessageMap.size === 2) { - parentMessage = frozenMessageMap.get(SYSTEM_MESSAGE_ID) || null; - } - const currentAssistantId = alternativeAssistantOverride ? alternativeAssistantOverride.id : alternativeAssistant @@ -815,8 +806,8 @@ export function ChatPage({ : liveAssistant.id; resetInputBar(); + let messageUpdates: Message[] | null = null; - setIsStreaming(true); let answer = ""; let query: string | null = null; let retrievalType: RetrievalType = @@ -831,6 +822,13 @@ export function ChatPage({ let finalMessage: BackendMessage | null = null; let toolCalls: ToolCallMetadata[] = []; + let initialFetchDetails: null | { + user_message_id: number; + assistant_message_id: number; + frozenMessageMap: Map; + frozenSessionId: number | null; + } = null; + try { const lastSuccessfulMessageId = getLastSuccessfulMessageId(currMessageHistory); @@ -838,7 +836,6 @@ export function ChatPage({ const stack = new CurrentMessageFIFO(); updateCurrentMessageFIFO(stack, { signal: controller.signal, // Add this line - message: currMessage, alternateAssistantId: currentAssistantId, fileDescriptors: currentMessageFiles, @@ -875,20 +872,6 @@ export function ChatPage({ useExistingUserMessage: isSeededChat, }); - const updateFn = (messages: Message[]) => { - const replacementsMap = finalMessage - ? new Map([ - [messages[0].messageId, TEMP_USER_MESSAGE_ID], - [messages[1].messageId, TEMP_ASSISTANT_MESSAGE_ID], - ] as [number, number][]) - : null; - upsertToCompleteMessageMap({ - messages: messages, - replacementsMap: replacementsMap, - completeMessageMapOverride: frozenMessageMap, - chatSessionId: frozenSessionId!, - }); - }; const delay = (ms: number) => { return new Promise((resolve) => setTimeout(resolve, ms)); }; @@ -899,8 +882,71 @@ export function ChatPage({ if (!stack.isEmpty()) { const packet = stack.nextPacket(); - console.log(packet); - if (packet) { + if (!packet) { + continue; + } + + if (!initialFetchDetails) { + if (!Object.hasOwn(packet, "user_message_id")) { + console.error( + "First packet should contain message response info " + ); + continue; + } + + const messageResponseIDInfo = packet as MessageResponseIDInfo; + + const user_message_id = messageResponseIDInfo.user_message_id!; + const assistant_message_id = + messageResponseIDInfo.reserved_assistant_message_id; + + // we will use tempMessages until the regenerated message is complete + messageUpdates = [ + { + messageId: user_message_id, + message: currMessage, + type: "user", + files: currentMessageFiles, + toolCalls: [], + parentMessageId: parentMessage?.messageId || null, + }, + ]; + if (parentMessage) { + messageUpdates.push({ + ...parentMessage, + childrenMessageIds: ( + parentMessage.childrenMessageIds || [] + ).concat([user_message_id]), + latestChildMessageId: user_message_id, + }); + } + + const { + messageMap: currentFrozenMessageMap, + sessionId: currentFrozenSessionId, + } = upsertToCompleteMessageMap({ + messages: messageUpdates, + chatSessionId: currChatSessionId, + }); + + const frozenMessageMap = currentFrozenMessageMap; + const frozenSessionId = currentFrozenSessionId; + initialFetchDetails = { + frozenMessageMap, + frozenSessionId, + assistant_message_id, + user_message_id, + }; + } else { + const { user_message_id, frozenMessageMap, frozenSessionId } = + initialFetchDetails; + setChatState((chatState) => { + if (chatState == "loading") { + return "streaming"; + } + return chatState; + }); + if (Object.hasOwn(packet, "answer_piece")) { answer += (packet as AnswerPiecePacket).answer_piece; } else if (Object.hasOwn(packet, "top_documents")) { @@ -910,7 +956,7 @@ export function ChatPage({ if (documents && documents.length > 0) { // point to the latest message (we don't know the messageId yet, which is why // we have to use -1) - setSelectedMessageForDocDisplay(TEMP_USER_MESSAGE_ID); + setSelectedMessageForDocDisplay(user_message_id); } } else if (Object.hasOwn(packet, "tool_name")) { toolCalls = [ @@ -920,6 +966,14 @@ export function ChatPage({ tool_result: (packet as ToolCallMetadata).tool_result, }, ]; + if ( + !toolCalls[0].tool_result || + toolCalls[0].tool_result == undefined + ) { + setChatState("toolBuilding"); + } else { + setChatState("streaming"); + } } else if (Object.hasOwn(packet, "file_ids")) { aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map( (fileId) => { @@ -936,23 +990,34 @@ export function ChatPage({ finalMessage = packet as BackendMessage; } - const newUserMessageId = - finalMessage?.parent_message || TEMP_USER_MESSAGE_ID; - const newAssistantMessageId = - finalMessage?.message_id || TEMP_ASSISTANT_MESSAGE_ID; + // on initial message send, we insert a dummy system message + // set this as the parent here if no parent is set + parentMessage = + parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!; + + const updateFn = (messages: Message[]) => { + const replacementsMap = null; + upsertToCompleteMessageMap({ + messages: messages, + replacementsMap: replacementsMap, + completeMessageMapOverride: frozenMessageMap, + chatSessionId: frozenSessionId!, + }); + }; + updateFn([ { - messageId: newUserMessageId, + messageId: initialFetchDetails.user_message_id!, message: currMessage, type: "user", files: currentMessageFiles, toolCalls: [], - parentMessageId: parentMessage?.messageId || null, - childrenMessageIds: [newAssistantMessageId], - latestChildMessageId: newAssistantMessageId, + parentMessageId: error ? null : lastSuccessfulMessageId, + childrenMessageIds: [initialFetchDetails.assistant_message_id!], + latestChildMessageId: initialFetchDetails.assistant_message_id, }, { - messageId: newAssistantMessageId, + messageId: initialFetchDetails.assistant_message_id!, message: error || answer, type: error ? "error" : "assistant", retrievalType, @@ -962,7 +1027,7 @@ export function ChatPage({ citations: finalMessage?.citations || {}, files: finalMessage?.files || aiMessageImages || [], toolCalls: finalMessage?.tool_calls || toolCalls, - parentMessageId: newUserMessageId, + parentMessageId: initialFetchDetails.user_message_id, alternateAssistantID: alternativeAssistant?.id, stackTrace: stackTrace, }, @@ -975,7 +1040,8 @@ export function ChatPage({ upsertToCompleteMessageMap({ messages: [ { - messageId: TEMP_USER_MESSAGE_ID, + messageId: + initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID, message: currMessage, type: "user", files: currentMessageFiles, @@ -983,24 +1049,28 @@ export function ChatPage({ parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID, }, { - messageId: TEMP_ASSISTANT_MESSAGE_ID, + messageId: + initialFetchDetails?.assistant_message_id || + TEMP_ASSISTANT_MESSAGE_ID, message: errorMsg, type: "error", files: aiMessageImages || [], toolCalls: [], - parentMessageId: TEMP_USER_MESSAGE_ID, + parentMessageId: + initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID, }, ], - completeMessageMapOverride: frozenMessageMap, + completeMessageMapOverride: completeMessageDetail.messageMap, }); } - - setIsStreaming(false); + setChatState("input"); if (isNewSession) { if (finalMessage) { setSelectedMessageForDocDisplay(finalMessage.message_id); } + if (!searchParamBasedChatSessionName) { + await new Promise((resolve) => setTimeout(resolve, 200)); await nameChatSession(currChatSessionId, currMessage); } @@ -1060,8 +1130,8 @@ export function ChatPage({ const onAssistantChange = (assistant: Persona | null) => { if (assistant && assistant.id !== liveAssistant.id) { // Abort the ongoing stream if it exists - if (abortController && isStreaming) { - abortController.abort(); + if (chatState != "input") { + stopGeneration(); resetInputBar(); } @@ -1163,7 +1233,7 @@ export function ChatPage({ }); useScrollonStream({ - isStreaming, + chatState, scrollableDivRef, scrollDist, endDivRef, @@ -1334,6 +1404,7 @@ export function ChatPage({ >
setMessage("")} page="chat" ref={innerSidebarElementRef} @@ -1407,7 +1478,7 @@ export function ChatPage({ {messageHistory.length === 0 && !isFetchingChatMessages && - !isStreaming && ( + chatState == "input" && ( setCurrentFeedback([ @@ -1552,7 +1623,7 @@ export function ChatPage({ } handleSearchQueryEdit={ i === messageHistory.length - 1 && - !isStreaming + chatState == "input" ? (newQuery) => { if (!previousMessage) { setPopup({ @@ -1659,34 +1730,39 @@ export function ChatPage({ ); } })} - {isStreaming && - messageHistory.length > 0 && - messageHistory[messageHistory.length - 1].type === + {chatState == "loading" && + messageHistory[messageHistory.length - 1]?.type != "user" && ( -
- - - Thinking... - -
- } - /> -
+ )} + {chatState == "loading" && ( +
+ + + Thinking... + +
+ } + /> + + )} {currentPersona && currentPersona.starter_messages && @@ -1748,6 +1824,8 @@ export function ChatPage({ )} setSettingsToggled(true)} inputPrompts={userInputPrompts} showDocs={() => setDocumentSelection(true)} @@ -1762,7 +1840,6 @@ export function ChatPage({ message={message} setMessage={setMessage} onSubmit={onSubmit} - isStreaming={isStreaming} filterManager={filterManager} llmOverrideManager={llmOverrideManager} files={currentMessageFiles} diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index 4204b9c5d..b8500eb86 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -21,6 +21,7 @@ import { CpuIconSkeleton, FileIcon, SendIcon, + StopGeneratingIcon, } from "@/components/icons/icons"; import { IconType } from "react-icons"; import Popup from "../../../components/popup/Popup"; @@ -31,6 +32,9 @@ import { AssistantIcon } from "@/components/assistants/AssistantIcon"; import { Tooltip } from "@/components/tooltip/Tooltip"; import { Hoverable } from "@/components/Hoverable"; import { SettingsContext } from "@/components/settings/SettingsProvider"; +import { StopCircle } from "@phosphor-icons/react/dist/ssr"; +import { Square } from "@phosphor-icons/react"; +import { ChatState } from "../types"; const MAX_INPUT_HEIGHT = 200; export function ChatInputBar({ @@ -39,10 +43,11 @@ export function ChatInputBar({ selectedDocuments, message, setMessage, + stopGenerating, onSubmit, - isStreaming, filterManager, llmOverrideManager, + chatState, // assistants selectedAssistant, @@ -59,6 +64,8 @@ export function ChatInputBar({ inputPrompts, }: { openModelSettings: () => void; + chatState: ChatState; + stopGenerating: () => void; showDocs: () => void; selectedDocuments: DanswerDocument[]; assistantOptions: Persona[]; @@ -68,7 +75,6 @@ export function ChatInputBar({ message: string; setMessage: (message: string) => void; onSubmit: () => void; - isStreaming: boolean; filterManager: FilterManager; llmOverrideManager: LlmOverrideManager; selectedAssistant: Persona; @@ -597,24 +603,38 @@ export function ChatInputBar({ }} /> +
-
{ - if (message) { - onSubmit(); - } - }} - > - -
+ {chatState == "streaming" || + chatState == "toolBuilding" || + chatState == "loading" ? ( + + ) : ( + + )}
diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index 366bc1ec7..0778ae354 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -118,6 +118,11 @@ export interface BackendMessage { alternate_assistant_id?: number | null; } +export interface MessageResponseIDInfo { + user_message_id: number | null; + reserved_assistant_message_id: number; +} + export interface DocumentsResponse { top_documents: DanswerDocument[]; rephrased_query: string | null; diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index 7eb16e5c1..43bba4f89 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -3,8 +3,8 @@ import { DanswerDocument, Filters, } from "@/lib/search/interfaces"; -import { handleStream } from "@/lib/search/streamingUtils"; -import { FeedbackType } from "./types"; +import { handleSSEStream, handleStream } from "@/lib/search/streamingUtils"; +import { ChatState, FeedbackType } from "./types"; import { Dispatch, MutableRefObject, @@ -20,6 +20,7 @@ import { FileDescriptor, ImageGenerationDisplay, Message, + MessageResponseIDInfo, RetrievalType, StreamingError, ToolCallMetadata, @@ -109,7 +110,8 @@ export type PacketType = | AnswerPiecePacket | DocumentsResponse | ImageGenerationDisplay - | StreamingError; + | StreamingError + | MessageResponseIDInfo; export async function* sendMessage({ message, @@ -127,6 +129,7 @@ export async function* sendMessage({ systemPromptOverride, useExistingUserMessage, alternateAssistantId, + signal, }: { message: string; fileDescriptors: FileDescriptor[]; @@ -137,70 +140,69 @@ export async function* sendMessage({ selectedDocumentIds: number[] | null; queryOverride?: string; forceSearch?: boolean; - // LLM overrides modelProvider?: string; modelVersion?: string; temperature?: number; - // prompt overrides systemPromptOverride?: string; - // if specified, will use the existing latest user message - // and will ignore the specified `message` useExistingUserMessage?: boolean; alternateAssistantId?: number; -}) { + signal?: AbortSignal; +}): AsyncGenerator { const documentsAreSelected = selectedDocumentIds && selectedDocumentIds.length > 0; - const sendMessageResponse = await fetch("/api/chat/send-message", { + const body = JSON.stringify({ + alternate_assistant_id: alternateAssistantId, + chat_session_id: chatSessionId, + parent_message_id: parentMessageId, + message: message, + prompt_id: promptId, + search_doc_ids: documentsAreSelected ? selectedDocumentIds : null, + file_descriptors: fileDescriptors, + retrieval_options: !documentsAreSelected + ? { + run_search: + promptId === null || + promptId === undefined || + queryOverride || + forceSearch + ? "always" + : "auto", + real_time: true, + filters: filters, + } + : null, + query_override: queryOverride, + prompt_override: systemPromptOverride + ? { + system_prompt: systemPromptOverride, + } + : null, + llm_override: + temperature || modelVersion + ? { + temperature, + model_provider: modelProvider, + model_version: modelVersion, + } + : null, + use_existing_user_message: useExistingUserMessage, + }); + + const response = await fetch(`/api/chat/send-message`, { method: "POST", headers: { "Content-Type": "application/json", }, - body: JSON.stringify({ - alternate_assistant_id: alternateAssistantId, - chat_session_id: chatSessionId, - parent_message_id: parentMessageId, - message: message, - prompt_id: promptId, - search_doc_ids: documentsAreSelected ? selectedDocumentIds : null, - file_descriptors: fileDescriptors, - retrieval_options: !documentsAreSelected - ? { - run_search: - promptId === null || - promptId === undefined || - queryOverride || - forceSearch - ? "always" - : "auto", - real_time: true, - filters: filters, - } - : null, - query_override: queryOverride, - prompt_override: systemPromptOverride - ? { - system_prompt: systemPromptOverride, - } - : null, - llm_override: - temperature || modelVersion - ? { - temperature, - model_provider: modelProvider, - model_version: modelVersion, - } - : null, - use_existing_user_message: useExistingUserMessage, - }), + body, + signal, }); - if (!sendMessageResponse.ok) { - const errorJson = await sendMessageResponse.json(); - const errorMsg = errorJson.message || errorJson.detail || ""; - throw Error(`Failed to send message - ${errorMsg}`); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); } - yield* handleStream(sendMessageResponse); + yield* handleSSEStream(response); } export async function nameChatSession(chatSessionId: number, message: string) { @@ -635,14 +637,14 @@ export async function uploadFilesForChat( } export async function useScrollonStream({ - isStreaming, + chatState, scrollableDivRef, scrollDist, endDivRef, distance, debounce, }: { - isStreaming: boolean; + chatState: ChatState; scrollableDivRef: RefObject; scrollDist: MutableRefObject; endDivRef: RefObject; @@ -656,7 +658,7 @@ export async function useScrollonStream({ const previousScroll = useRef(0); useEffect(() => { - if (isStreaming && scrollableDivRef && scrollableDivRef.current) { + if (chatState != "input" && scrollableDivRef && scrollableDivRef.current) { let newHeight: number = scrollableDivRef.current?.scrollTop!; const heightDifference = newHeight - previousScroll.current; previousScroll.current = newHeight; @@ -712,7 +714,7 @@ export async function useScrollonStream({ // scroll on end of stream if within distance useEffect(() => { - if (scrollableDivRef?.current && !isStreaming) { + if (scrollableDivRef?.current && chatState == "input") { if (scrollDist.current < distance - 50) { scrollableDivRef?.current?.scrollBy({ left: 0, @@ -721,5 +723,5 @@ export async function useScrollonStream({ }); } } - }, [isStreaming]); + }, [chatState]); } diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 945bfaf0f..2074ea5d4 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -255,7 +255,6 @@ export const AIMessage = ({ size="small" assistant={alternativeAssistant || currentPersona} /> -
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && @@ -623,6 +622,7 @@ export const HumanMessage = ({ onEdit, onMessageSelection, shared, + stopGenerating = () => null, }: { shared?: boolean; content: string; @@ -631,6 +631,7 @@ export const HumanMessage = ({ otherMessagesCanSwitchTo?: number[]; onEdit?: (editedContent: string) => void; onMessageSelection?: (messageId: number) => void; + stopGenerating?: () => void; }) => { const textareaRef = useRef(null); @@ -677,7 +678,6 @@ export const HumanMessage = ({
-
{isEditing ? ( @@ -857,16 +857,18 @@ export const HumanMessage = ({ + handlePrevious={() => { + stopGenerating(); onMessageSelection( otherMessagesCanSwitchTo[currentMessageInd - 1] - ) - } - handleNext={() => + ); + }} + handleNext={() => { + stopGenerating(); onMessageSelection( otherMessagesCanSwitchTo[currentMessageInd + 1] - ) - } + ); + }} />
)} diff --git a/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx b/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx index 8844d9ff4..47848e26b 100644 --- a/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx +++ b/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx @@ -33,6 +33,7 @@ export function ChatSessionDisplay({ isSelected, skipGradient, closeSidebar, + stopGenerating = () => null, showShareModal, showDeleteModal, }: { @@ -43,6 +44,7 @@ export function ChatSessionDisplay({ // if not set, the gradient will still be applied and cause weirdness skipGradient?: boolean; closeSidebar?: () => void; + stopGenerating?: () => void; showShareModal?: (chatSession: ChatSession) => void; showDeleteModal?: (chatSession: ChatSession) => void; }) { @@ -99,6 +101,7 @@ export function ChatSessionDisplay({ className="flex my-1 group relative" key={chatSession.id} onClick={() => { + stopGenerating(); if (settings?.isMobile && closeSidebar) { closeSidebar(); } diff --git a/web/src/app/chat/sessionSidebar/HistorySidebar.tsx b/web/src/app/chat/sessionSidebar/HistorySidebar.tsx index 28bb18531..8e08aaf37 100644 --- a/web/src/app/chat/sessionSidebar/HistorySidebar.tsx +++ b/web/src/app/chat/sessionSidebar/HistorySidebar.tsx @@ -40,6 +40,7 @@ interface HistorySidebarProps { reset?: () => void; showShareModal?: (chatSession: ChatSession) => void; showDeleteModal?: (chatSession: ChatSession) => void; + stopGenerating?: () => void; } export const HistorySidebar = forwardRef( @@ -54,6 +55,7 @@ export const HistorySidebar = forwardRef( openedFolders, toggleSidebar, removeToggle, + stopGenerating = () => null, showShareModal, showDeleteModal, }, @@ -179,6 +181,7 @@ export const HistorySidebar = forwardRef( )}
void; page: pageType; existingChats?: ChatSession[]; currentChatId?: number; @@ -124,6 +126,7 @@ export function PagesTab({ return (
{ + return ( + + + + ); +}; + export const LikeFeedbackIcon = ({ size = 16, className = defaultTailwindCSS, diff --git a/web/src/lib/search/streamingUtils.ts b/web/src/lib/search/streamingUtils.ts index 312d8d28b..44ac7aac1 100644 --- a/web/src/lib/search/streamingUtils.ts +++ b/web/src/lib/search/streamingUtils.ts @@ -1,3 +1,5 @@ +import { PacketType } from "@/app/chat/lib"; + type NonEmptyObject = { [k: string]: any }; const processSingleChunk = ( @@ -75,3 +77,33 @@ export async function* handleStream( yield await Promise.resolve(completedChunks); } } + +export async function* handleSSEStream( + streamingResponse: Response +): AsyncGenerator { + const reader = streamingResponse.body?.getReader(); + const decoder = new TextDecoder(); + + while (true) { + const rawChunk = await reader?.read(); + if (!rawChunk) { + throw new Error("Unable to process chunk"); + } + const { done, value } = rawChunk; + if (done) { + break; + } + + const chunk = decoder.decode(value); + const lines = chunk.split("\n").filter((line) => line.trim() !== ""); + + for (const line of lines) { + try { + const data = JSON.parse(line) as T; + yield data; + } catch (error) { + console.error("Error parsing SSE data:", error); + } + } + } +} diff --git a/web/tailwind-themes/tailwind.config.js b/web/tailwind-themes/tailwind.config.js index 90be59590..208143f35 100644 --- a/web/tailwind-themes/tailwind.config.js +++ b/web/tailwind-themes/tailwind.config.js @@ -90,6 +90,7 @@ module.exports = { "background-200": "#e5e5e5", // neutral-200 "background-300": "#d4d4d4", // neutral-300 "background-400": "#a3a3a3", // neutral-400 + "background-600": "#525252", // neutral-800 "background-500": "#737373", // neutral-400 "background-600": "#525252", // neutral-400 "background-700": "#404040", // neutral-400