diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 207825e92..af9ad5cb6 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -43,13 +43,20 @@ import { uploadFilesForChat, useScrollonStream, } from "./lib"; -import { useContext, useEffect, useRef, useState } from "react"; +import { + Dispatch, + SetStateAction, + useContext, + useEffect, + useRef, + useState, +} from "react"; import { usePopup } from "@/components/admin/connectors/Popup"; import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams"; import { useDocumentSelection } from "./useDocumentSelection"; import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks"; import { computeAvailableFilters } from "@/lib/filters"; -import { ChatState, FeedbackType } from "./types"; +import { ChatState, FeedbackType, RegenerationState } from "./types"; import { DocumentSidebar } from "./documentSidebar/DocumentSidebar"; import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader"; import { FeedbackModal } from "./modal/FeedbackModal"; @@ -84,6 +91,7 @@ import { SetDefaultModelModal } from "./modal/SetDefaultModelModal"; import { DeleteChatModal } from "./modal/DeleteChatModal"; import { MinimalMarkdown } from "@/components/chat_search/MinimalMarkdown"; import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; + import { SEARCH_TOOL_NAME } from "./tools/constants"; import { useUser } from "@/components/user/UserProvider"; @@ -212,10 +220,18 @@ export function ChatPage({ } }, [liveAssistant]); - const stopGeneration = () => { - if (abortController) { - abortController.abort(); + const stopGenerating = () => { + const currentSession = currentSessionId(); + const controller = abortControllers.get(currentSession); + if (controller) { + controller.abort(); + setAbortControllers((prev) => { + const newControllers = new Map(prev); + newControllers.delete(currentSession); + return newControllers; + }); } + const lastMessage = messageHistory[messageHistory.length - 1]; if ( lastMessage && @@ -223,16 +239,16 @@ export function ChatPage({ lastMessage.toolCalls[0] && lastMessage.toolCalls[0].tool_result === undefined ) { - const newCompleteMessageMap = new Map(completeMessageDetail.messageMap); + const newCompleteMessageMap = new Map( + currentMessageMap(completeMessageDetail) + ); const updatedMessage = { ...lastMessage, toolCalls: [] }; newCompleteMessageMap.set(lastMessage.messageId, updatedMessage); - setCompleteMessageDetail({ - sessionId: completeMessageDetail.sessionId, - messageMap: newCompleteMessageMap, - }); + updateCompleteMessageDetail(currentSession, newCompleteMessageMap); } - }; + updateChatState("input", currentSession); + }; // this is for "@"ing assistants // this is used to track which assistant is being used to generate the current message @@ -308,10 +324,7 @@ export function ChatPage({ } else { setSelectedAssistant(undefined); } - setCompleteMessageDetail({ - sessionId: null, - messageMap: new Map(), - }); + updateCompleteMessageDetail(null, new Map()); setChatSessionSharedStatus(ChatSessionSharedStatus.Private); // if we're supposed to submit on initial load, then do that here @@ -341,13 +354,11 @@ export function ChatPage({ // This corresponds to a "renaming" of chat, which occurs after first message // stream if ( - messageHistory[messageHistory.length - 1]?.type !== "error" || - loadedSessionId != null + (messageHistory[messageHistory.length - 1]?.type !== "error" || + loadedSessionId != null) && + !currentChatAnswering() ) { - setCompleteMessageDetail({ - sessionId: chatSession.chat_session_id, - messageMap: newMessageMap, - }); + updateCompleteMessageDetail(chatSession.chat_session_id, newMessageMap); const latestMessageId = newMessageHistory[newMessageHistory.length - 1]?.messageId; @@ -394,10 +405,31 @@ export function ChatPage({ searchParams.get(SEARCH_PARAM_NAMES.USER_MESSAGE) || "" ); - const [completeMessageDetail, setCompleteMessageDetail] = useState<{ - sessionId: number | null; - messageMap: Map; - }>({ sessionId: null, messageMap: new Map() }); + const [completeMessageDetail, setCompleteMessageDetail] = useState< + Map> + >(new Map()); + + const updateCompleteMessageDetail = ( + sessionId: number | null, + messageMap: Map + ) => { + setCompleteMessageDetail((prevState) => { + const newState = new Map(prevState); + newState.set(sessionId, messageMap); + return newState; + }); + }; + + const currentMessageMap = ( + messageDetail: Map> + ) => { + return ( + messageDetail.get(chatSessionIdRef.current) || new Map() + ); + }; + const currentSessionId = (): number => { + return chatSessionIdRef.current!; + }; const upsertToCompleteMessageMap = ({ messages, @@ -416,7 +448,7 @@ export function ChatPage({ }) => { // deep copy const frozenCompleteMessageMap = - completeMessageMapOverride || completeMessageDetail.messageMap; + completeMessageMapOverride || currentMessageMap(completeMessageDetail); const newCompleteMessageMap = structuredClone(frozenCompleteMessageMap); if (newCompleteMessageMap.size === 0) { @@ -466,30 +498,134 @@ export function ChatPage({ )!.latestChildMessageId = messages[0].messageId; } } + const newCompleteMessageDetail = { - sessionId: chatSessionId || completeMessageDetail.sessionId, + sessionId: chatSessionId || currentSessionId(), messageMap: newCompleteMessageMap, }; - setCompleteMessageDetail(newCompleteMessageDetail); + + updateCompleteMessageDetail( + chatSessionId || currentSessionId(), + newCompleteMessageMap + ); return newCompleteMessageDetail; }; const messageHistory = buildLatestMessageChain( - completeMessageDetail.messageMap + currentMessageMap(completeMessageDetail) ); const [submittedMessage, setSubmittedMessage] = useState(""); - const [chatState, setChatState] = useState("input"); - interface RegenerationState { - regenerating: boolean; - finalMessageIndex: number; - } - const [regenerationState, setRegenerationState] = - useState(null); + const [chatState, setChatState] = useState>( + new Map([[chatSessionIdRef.current, "input"]]) + ); - const [abortController, setAbortController] = - useState(null); + const [scrollHeight, setScrollHeight] = useState>( + new Map([[chatSessionIdRef.current, 0]]) + ); + const currentScrollHeight = () => { + return scrollHeight.get(currentSessionId()); + }; + + const retrieveCurrentScrollHeight = (): number | null => { + return scrollHeight.get(currentSessionId()) || null; + }; + + const [regenerationState, setRegenerationState] = useState< + Map + >(new Map([[null, null]])); + + const [abortControllers, setAbortControllers] = useState< + Map + >(new Map()); + + // Updates "null" session values to new session id for + // regeneration, chat, and abort controller state, messagehistory + const updateStatesWithNewSessionId = (newSessionId: number) => { + const updateState = ( + setState: Dispatch>>, + defaultValue?: any + ) => { + setState((prevState) => { + const newState = new Map(prevState); + const existingState = newState.get(null); + if (existingState !== undefined) { + newState.set(newSessionId, existingState); + newState.delete(null); + } else if (defaultValue !== undefined) { + newState.set(newSessionId, defaultValue); + } + return newState; + }); + }; + + updateState(setRegenerationState); + updateState(setChatState); + updateState(setAbortControllers); + + // Update completeMessageDetail + setCompleteMessageDetail((prevState) => { + const newState = new Map(prevState); + const existingMessages = newState.get(null); + if (existingMessages) { + newState.set(newSessionId, existingMessages); + newState.delete(null); + } + return newState; + }); + + // Update chatSessionIdRef + chatSessionIdRef.current = newSessionId; + }; + + const updateChatState = (newState: ChatState, sessionId?: number | null) => { + setChatState((prevState) => { + const newChatState = new Map(prevState); + newChatState.set( + sessionId !== undefined ? sessionId : currentSessionId(), + newState + ); + return newChatState; + }); + }; + + const currentChatState = (): ChatState => { + return chatState.get(currentSessionId()) || "input"; + }; + + const currentChatAnswering = () => { + return ( + currentChatState() == "toolBuilding" || + currentChatState() == "streaming" || + currentChatState() == "loading" + ); + }; + + const updateRegenerationState = ( + newState: RegenerationState | null, + sessionId?: number | null + ) => { + setRegenerationState((prevState) => { + const newRegenerationState = new Map(prevState); + newRegenerationState.set( + sessionId !== undefined ? sessionId : currentSessionId(), + newState + ); + return newRegenerationState; + }); + }; + + const resetRegenerationState = (sessionId?: number | null) => { + updateRegenerationState(null, sessionId); + }; + + const currentRegenerationState = (): RegenerationState | null => { + return regenerationState.get(currentSessionId()) || null; + }; + + const currentSessionChatState = currentChatState(); + const currentSessionRegenerationState = currentRegenerationState(); // uploaded files const [currentMessageFiles, setCurrentMessageFiles] = useState< @@ -746,7 +882,9 @@ export function ChatPage({ modelOverRide?: LlmOverride; regenerationRequest?: RegenerationRequest | null; } = {}) => { - if (chatState != "input") { + let frozenSessionId = currentSessionId(); + + if (currentChatState() != "input") { setPopup({ message: "Please wait for the response to complete", type: "error", @@ -754,16 +892,13 @@ export function ChatPage({ return; } - setRegenerationState( + updateRegenerationState( regenerationRequest ? { regenerating: true, finalMessageIndex: messageIdToResend || 0 } : null ); - setChatState("loading"); - - const controller = new AbortController(); - setAbortController(controller); + updateChatState("loading"); setAlternativeGeneratingAssistant(alternativeAssistantOverride); clientScrollToBottom(); @@ -780,13 +915,21 @@ export function ChatPage({ } else { currChatSessionId = chatSessionIdRef.current as number; } - chatSessionIdRef.current = currChatSessionId; + frozenSessionId = currChatSessionId; + + updateStatesWithNewSessionId(currChatSessionId); + + const controller = new AbortController(); + + setAbortControllers((prev) => + new Map(prev).set(currChatSessionId, controller) + ); const messageToResend = messageHistory.find( (message) => message.messageId === messageIdToResend ); - const messageMap = completeMessageDetail.messageMap; + const messageMap = currentMessageMap(completeMessageDetail); const messageToResendParent = messageToResend?.parentMessageId !== null && messageToResend?.parentMessageId !== undefined @@ -802,8 +945,8 @@ export function ChatPage({ "Failed to re-send message - please refresh the page and try again.", type: "error", }); - setRegenerationState(null); - setChatState("input"); + resetRegenerationState(currentSessionId()); + updateChatState("input", frozenSessionId); return; } let currMessage = messageToResend ? messageToResend.message : message; @@ -851,11 +994,12 @@ export function ChatPage({ user_message_id: number; assistant_message_id: number; frozenMessageMap: Map; - frozenSessionId: number | null; } = null; try { - const mapKeys = Array.from(completeMessageDetail.messageMap.keys()); + const mapKeys = Array.from( + currentMessageMap(completeMessageDetail).keys() + ); const systemMessage = Math.min(...mapKeys); const lastSuccessfulMessageId = @@ -956,32 +1100,31 @@ export function ChatPage({ }); } - const { - messageMap: currentFrozenMessageMap, - sessionId: currentFrozenSessionId, - } = upsertToCompleteMessageMap({ - messages: messageUpdates, - chatSessionId: currChatSessionId, - }); + const { messageMap: currentFrozenMessageMap } = + upsertToCompleteMessageMap({ + messages: messageUpdates, + chatSessionId: currChatSessionId, + }); const frozenMessageMap = currentFrozenMessageMap; - const frozenSessionId = currentFrozenSessionId; initialFetchDetails = { frozenMessageMap, - frozenSessionId, assistant_message_id, user_message_id, }; - setRegenerationState(null); + resetRegenerationState(); } else { - const { user_message_id, frozenMessageMap, frozenSessionId } = - initialFetchDetails; - setChatState((chatState) => { - if (chatState == "loading") { - return "streaming"; + const { user_message_id, frozenMessageMap } = initialFetchDetails; + + setChatState((prevState) => { + if (prevState.get(chatSessionIdRef.current!) === "loading") { + return new Map(prevState).set( + chatSessionIdRef.current!, + "streaming" + ); } - return chatState; + return prevState; }); if (Object.hasOwn(packet, "answer_piece")) { @@ -1006,9 +1149,9 @@ export function ChatPage({ !toolCalls[0].tool_result || toolCalls[0].tool_result == undefined ) { - setChatState("toolBuilding"); + updateChatState("toolBuilding", frozenSessionId); } else { - setChatState("streaming"); + updateChatState("streaming", frozenSessionId); } // This will be consolidated in upcoming tool calls udpate, @@ -1123,11 +1266,12 @@ export function ChatPage({ initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID, }, ], - completeMessageMapOverride: completeMessageDetail.messageMap, + completeMessageMapOverride: currentMessageMap(completeMessageDetail), }); } - setRegenerationState(null); - setChatState("input"); + resetRegenerationState(currentSessionId()); + + updateChatState("input"); if (isNewSession) { if (finalMessage) { setSelectedMessageForDocDisplay(finalMessage.message_id); @@ -1194,8 +1338,8 @@ export function ChatPage({ const onAssistantChange = (assistant: Persona | null) => { if (assistant && assistant.id !== liveAssistant.id) { // Abort the ongoing stream if it exists - if (chatState != "input") { - stopGeneration(); + if (currentSessionChatState != "input") { + stopGenerating(); resetInputBar(); } @@ -1308,7 +1452,7 @@ export function ChatPage({ }); useScrollonStream({ - chatState, + chatState: currentSessionChatState, scrollableDivRef, scrollDist, endDivRef, @@ -1496,7 +1640,7 @@ export function ChatPage({
setMessage("")} page="chat" ref={innerSidebarElementRef} @@ -1570,7 +1714,7 @@ export function ChatPage({ {messageHistory.length === 0 && !isFetchingChatMessages && - chatState == "input" && ( + currentSessionChatState == "input" && ( {messageHistory.map((message, i) => { - const messageMap = - completeMessageDetail.messageMap; - const messageReactComponentKey = `${i}-${completeMessageDetail.sessionId}`; + const messageMap = currentMessageMap( + completeMessageDetail + ); + const messageReactComponentKey = `${i}-${currentSessionId()}`; const parentMessage = message.parentMessageId ? messageMap.get(message.parentMessageId) : null; if ( - regenerationState && - regenerationState.regenerating && - message.messageId > - regenerationState.finalMessageIndex + currentSessionRegenerationState?.regenerating && + message.messageId >= + currentSessionRegenerationState?.finalMessageIndex! ) { return <>; } @@ -1603,7 +1747,7 @@ export function ChatPage({ return (
- regenerationState.finalMessageIndex - 1 + currentSessionRegenerationState?.regenerating && + currentSessionChatState == "loading" && + message.messageId == messageHistory.length - 1 ) { return <>; } @@ -1703,11 +1844,12 @@ export function ChatPage({ newCompleteMessageMap.get( message.parentMessageId! )!.latestChildMessageId = messageId; - setCompleteMessageDetail({ - sessionId: - completeMessageDetail.sessionId, - messageMap: newCompleteMessageMap, - }); + + updateCompleteMessageDetail( + currentSessionId(), + newCompleteMessageMap + ); + setSelectedMessageForDocDisplay( messageId ); @@ -1742,7 +1884,10 @@ export function ChatPage({ } isComplete={ i !== messageHistory.length - 1 || - chatState == "input" + (currentSessionChatState != + "streaming" && + currentSessionChatState != + "toolBuilding") } hasDocs={ (message.documents && @@ -1750,7 +1895,7 @@ export function ChatPage({ } handleFeedback={ i === messageHistory.length - 1 && - chatState != "input" + currentSessionChatState != "input" ? undefined : (feedbackType) => setCurrentFeedback([ @@ -1760,7 +1905,7 @@ export function ChatPage({ } handleSearchQueryEdit={ i === messageHistory.length - 1 && - chatState == "input" + currentSessionChatState == "input" ? (newQuery) => { if (!previousMessage) { setPopup({ @@ -1770,7 +1915,6 @@ export function ChatPage({ }); return; } - if ( previousMessage.messageId === null @@ -1867,21 +2011,24 @@ export function ChatPage({ ); } })} - {chatState == "loading" && - !regenerationState?.regenerating && + + {currentSessionChatState == "loading" && + !currentSessionRegenerationState?.regenerating && messageHistory[messageHistory.length - 1]?.type != "user" && ( )} - {chatState == "loading" && ( + {currentSessionChatState == "loading" && (
)} - setSettingsToggled(true)} inputPrompts={userInputPrompts} showDocs={() => setDocumentSelection(true)} diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index 58a155284..3f9e37da3 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -32,9 +32,8 @@ import { AssistantIcon } from "@/components/assistants/AssistantIcon"; import { Tooltip } from "@/components/tooltip/Tooltip"; import { Hoverable } from "@/components/Hoverable"; import { SettingsContext } from "@/components/settings/SettingsProvider"; -import { StopCircle } from "@phosphor-icons/react/dist/ssr"; -import { Square } from "@phosphor-icons/react"; import { ChatState } from "../types"; + const MAX_INPUT_HEIGHT = 200; export function ChatInputBar({ diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 4307879cf..8a2b4b354 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -231,14 +231,9 @@ export const AIMessage = ({ } return content; }; - content = trimIncompleteCodeSection(content); } - const danswerSearchToolEnabledForPersona = currentPersona.tools.some( - (tool) => tool.in_code_tool_id === SEARCH_TOOL_NAME - ); - let filteredDocs: FilteredDanswerDocument[] = []; if (docs) { @@ -760,24 +755,24 @@ export const HumanMessage = ({