From 9310a8edc203b0e5446ddb6e355c480852302c18 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 27 Jun 2024 16:40:23 -0700 Subject: [PATCH] Feature/scroll (#1694) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --------- Co-authored-by: “Pablo <“pablo@danswer.ai”> --- web/src/app/chat/ChatPage.tsx | 593 ++++++++++++++++++++++------------ web/src/app/chat/lib.tsx | 133 ++++++-- 2 files changed, 485 insertions(+), 241 deletions(-) diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 7ae62f6fd..f93574437 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -26,9 +26,9 @@ import { getCitedDocumentsFromMessage, getHumanAndAIMessageFromMessageNumber, getLastSuccessfulMessageId, - handleAutoScroll, handleChatFeedback, nameChatSession, + PacketType, personaIncludesRetrieval, processRawChatHistory, removeMessage, @@ -37,6 +37,7 @@ import { updateModelOverrideForChatSession, updateParentChildren, uploadFilesForChat, + useScrollonStream, } from "./lib"; import { useContext, useEffect, useRef, useState } from "react"; import { usePopup } from "@/components/admin/connectors/Popup"; @@ -50,7 +51,7 @@ import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoade import { FeedbackModal } from "./modal/FeedbackModal"; import { ShareChatSessionModal } from "./modal/ShareChatSessionModal"; import { ChatPersonaSelector } from "./ChatPersonaSelector"; -import { FiShare2 } from "react-icons/fi"; +import { FiArrowDown, FiShare2 } from "react-icons/fi"; import { ChatIntro } from "./ChatIntro"; import { AIMessage, HumanMessage } from "./message/Messages"; import { ThreeDots } from "react-loader-spinner"; @@ -77,6 +78,7 @@ import { TbLayoutSidebarRightExpand } from "react-icons/tb"; import { SIDEBAR_WIDTH_CONST } from "@/lib/constants"; import ResizableSection from "@/components/resizable/ResizableSection"; +import { Button } from "@tremor/react"; const MAX_INPUT_HEIGHT = 200; const TEMP_USER_MESSAGE_ID = -1; @@ -426,49 +428,111 @@ export function ChatPage({ availableDocumentSets, }); + const [currentFeedback, setCurrentFeedback] = useState< + [FeedbackType, number] | null + >(null); + + const [sharingModalVisible, setSharingModalVisible] = + useState(false); + // state for cancelling streaming const [isCancelled, setIsCancelled] = useState(false); - const isCancelledRef = useRef(isCancelled); + const [aboveHorizon, setAboveHorizon] = useState(false); + + const scrollableDivRef = useRef(null); + const lastMessageRef = useRef(null); + const inputRef = useRef(null); + const endDivRef = useRef(null); + const endPaddingRef = useRef(null); + + const previousHeight = useRef( + inputRef.current?.getBoundingClientRect().height! + ); + const scrollDist = useRef(0); + + const updateScrollTracking = () => { + const scrollDistance = + endDivRef?.current?.getBoundingClientRect()?.top! - + inputRef?.current?.getBoundingClientRect()?.top!; + scrollDist.current = scrollDistance; + setAboveHorizon(scrollDist.current > 500); + }; + + scrollableDivRef?.current?.addEventListener("scroll", updateScrollTracking); + + const handleInputResize = () => { + setTimeout(() => { + if (inputRef.current && lastMessageRef.current) { + let newHeight: number = + inputRef.current?.getBoundingClientRect().height!; + const heightDifference = newHeight - previousHeight.current; + if ( + previousHeight.current && + heightDifference != 0 && + endPaddingRef.current && + scrollableDivRef && + scrollableDivRef.current + ) { + endPaddingRef.current.style.transition = "height 0.3s ease-out"; + endPaddingRef.current.style.height = `${Math.max(newHeight - 50, 0)}px`; + + scrollableDivRef?.current.scrollBy({ + left: 0, + top: Math.max(heightDifference, 0), + behavior: "smooth", + }); + } + previousHeight.current = newHeight; + } + }, 100); + }; + + const clientScrollToBottom = (fast?: boolean) => { + setTimeout( + () => { + endDivRef.current?.scrollIntoView({ behavior: "smooth" }); + setHasPerformedInitialScroll(true); + }, + fast ? 50 : 500 + ); + }; + + const isCancelledRef = useRef(isCancelled); // scroll is cancelled useEffect(() => { isCancelledRef.current = isCancelled; }, [isCancelled]); - const [currentFeedback, setCurrentFeedback] = useState< - [FeedbackType, number] | null - >(null); - const [sharingModalVisible, setSharingModalVisible] = - useState(false); + const distance = 500; // distance that should "engage" the scroll + const debounce = 100; // time for debouncing - // auto scroll as message comes out - const scrollableDivRef = useRef(null); - const endDivRef = useRef(null); - useEffect(() => { - if (isStreaming || !message) { - handleAutoScroll(endDivRef, scrollableDivRef); - } + useScrollonStream({ + isStreaming, + scrollableDivRef, + scrollDist, + endDivRef, + distance, + debounce, }); - // scroll to bottom initially const [hasPerformedInitialScroll, setHasPerformedInitialScroll] = useState(false); + + // on new page useEffect(() => { - endDivRef.current?.scrollIntoView(); - setHasPerformedInitialScroll(true); - }, [isFetchingChatMessages]); + clientScrollToBottom(); + }, [chatSessionId]); // handle re-sizing of the text area const textAreaRef = useRef(null); useEffect(() => { - const textarea = textAreaRef.current; - if (textarea) { - textarea.style.height = "0px"; - textarea.style.height = `${Math.min( - textarea.scrollHeight, - MAX_INPUT_HEIGHT - )}px`; - } + handleInputResize(); }, [message]); + // tracks scrolling + useEffect(() => { + updateScrollTracking(); + }, [messageHistory]); + // used for resizing of the document sidebar const masterFlexboxRef = useRef(null); const [maxDocumentSidebarWidth, setMaxDocumentSidebarWidth] = useState< @@ -502,6 +566,53 @@ export function ChatPage({ documentSidebarInitialWidth = Math.min(700, maxDocumentSidebarWidth); } + class CurrentMessageFIFO { + private stack: PacketType[] = []; + isComplete: boolean = false; + error: string | null = null; + + push(packetBunch: PacketType) { + this.stack.push(packetBunch); + } + + nextPacket(): PacketType | undefined { + return this.stack.shift(); + } + + isEmpty(): boolean { + return this.stack.length === 0; + } + } + async function updateCurrentMessageFIFO( + stack: CurrentMessageFIFO, + params: any + ) { + try { + for await (const packetBunch of sendMessage(params)) { + for (const packet of packetBunch) { + stack.push(packet); + } + + if (isCancelledRef.current) { + setIsCancelled(false); + break; + } + } + } catch (error) { + stack.error = String(error); + } finally { + stack.isComplete = true; + } + } + + const resetInputBar = () => { + setMessage(""); + setCurrentMessageFiles([]); + if (endPaddingRef.current) { + endPaddingRef.current.style.height = `95px`; + } + }; + const onSubmit = async ({ messageIdToResend, messageOverride, @@ -515,6 +626,7 @@ export function ChatPage({ forceSearch?: boolean; isSeededChat?: boolean; } = {}) => { + clientScrollToBottom(); let currChatSessionId: number; let isNewSession = chatSessionId === null; const searchParamBasedChatSessionName = @@ -596,9 +708,7 @@ export function ChatPage({ if (!parentMessage && frozenCompleteMessageMap.size === 2) { parentMessage = frozenCompleteMessageMap.get(SYSTEM_MESSAGE_ID) || null; } - setMessage(""); - setCurrentMessageFiles([]); - + resetInputBar(); setIsStreaming(true); let answer = ""; let query: string | null = null; @@ -614,7 +724,9 @@ export function ChatPage({ try { const lastSuccessfulMessageId = getLastSuccessfulMessageId(currMessageHistory); - for await (const packetBunch of sendMessage({ + + const stack = new CurrentMessageFIFO(); + updateCurrentMessageFIFO(stack, { message: currMessage, fileDescriptors: currentMessageFiles, parentMessageId: lastSuccessfulMessageId, @@ -647,86 +759,108 @@ export function ChatPage({ systemPromptOverride: searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined, useExistingUserMessage: isSeededChat, - })) { - for (const packet of packetBunch) { - if (Object.hasOwn(packet, "answer_piece")) { - answer += (packet as AnswerPiecePacket).answer_piece; - } else if (Object.hasOwn(packet, "top_documents")) { - documents = (packet as DocumentsResponse).top_documents; - query = (packet as DocumentsResponse).rephrased_query; - retrievalType = RetrievalType.Search; - if (documents && documents.length > 0) { - // point to the latest message (we don't know the messageId yet, which is why - // we have to use -1) - setSelectedMessageForDocDisplay(TEMP_USER_MESSAGE_ID); - } - } else if (Object.hasOwn(packet, "file_ids")) { - aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map( - (fileId) => { - return { - id: fileId, - type: ChatFileType.IMAGE, - }; + }); + const updateFn = (messages: Message[]) => { + const replacementsMap = finalMessage + ? new Map([ + [messages[0].messageId, TEMP_USER_MESSAGE_ID], + [messages[1].messageId, TEMP_ASSISTANT_MESSAGE_ID], + ] as [number, number][]) + : null; + upsertToCompleteMessageMap({ + messages: messages, + replacementsMap: replacementsMap, + completeMessageMapOverride: frozenCompleteMessageMap, + }); + }; + const delay = (ms: number) => { + return new Promise((resolve) => setTimeout(resolve, ms)); + }; + + await delay(50); + while (!stack.isComplete || !stack.isEmpty()) { + await delay(2); + + if (!stack.isEmpty()) { + const packet = stack.nextPacket(); + + if (packet) { + if (Object.hasOwn(packet, "answer_piece")) { + answer += (packet as AnswerPiecePacket).answer_piece; + } else if (Object.hasOwn(packet, "top_documents")) { + documents = (packet as DocumentsResponse).top_documents; + query = (packet as DocumentsResponse).rephrased_query; + retrievalType = RetrievalType.Search; + if (documents && documents.length > 0) { + // point to the latest message (we don't know the messageId yet, which is why + // we have to use -1) + setSelectedMessageForDocDisplay(TEMP_USER_MESSAGE_ID); } - ); - } else if (Object.hasOwn(packet, "tool_name")) { - toolCalls = [ + } else if (Object.hasOwn(packet, "tool_name")) { + toolCalls = [ + { + tool_name: (packet as ToolCallMetadata).tool_name, + tool_args: (packet as ToolCallMetadata).tool_args, + tool_result: (packet as ToolCallMetadata).tool_result, + }, + ]; + } else if (Object.hasOwn(packet, "file_ids")) { + aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map( + (fileId) => { + return { + id: fileId, + type: ChatFileType.IMAGE, + }; + } + ); + } else if (Object.hasOwn(packet, "tool_name")) { + toolCalls = [ + { + tool_name: (packet as ToolCallMetadata).tool_name, + tool_args: (packet as ToolCallMetadata).tool_args, + tool_result: (packet as ToolCallMetadata).tool_result, + }, + ]; + } else if (Object.hasOwn(packet, "error")) { + error = (packet as StreamingError).error; + } else if (Object.hasOwn(packet, "message_id")) { + finalMessage = packet as BackendMessage; + } + + const newUserMessageId = + finalMessage?.parent_message || TEMP_USER_MESSAGE_ID; + const newAssistantMessageId = + finalMessage?.message_id || TEMP_ASSISTANT_MESSAGE_ID; + updateFn([ { - tool_name: (packet as ToolCallMetadata).tool_name, - tool_args: (packet as ToolCallMetadata).tool_args, - tool_result: (packet as ToolCallMetadata).tool_result, + messageId: newUserMessageId, + message: currMessage, + type: "user", + files: currentMessageFiles, + toolCalls: [], + parentMessageId: parentMessage?.messageId || null, + childrenMessageIds: [newAssistantMessageId], + latestChildMessageId: newAssistantMessageId, }, - ]; - } else if (Object.hasOwn(packet, "error")) { - error = (packet as StreamingError).error; - } else if (Object.hasOwn(packet, "message_id")) { - finalMessage = packet as BackendMessage; + { + messageId: newAssistantMessageId, + message: error || answer, + type: error ? "error" : "assistant", + retrievalType, + query: finalMessage?.rephrased_query || query, + documents: + finalMessage?.context_docs?.top_documents || documents, + citations: finalMessage?.citations || {}, + files: finalMessage?.files || aiMessageImages || [], + toolCalls: finalMessage?.tool_calls || toolCalls, + parentMessageId: newUserMessageId, + }, + ]); + } + if (isCancelledRef.current) { + setIsCancelled(false); + break; } - } - const updateFn = (messages: Message[]) => { - const replacementsMap = finalMessage - ? new Map([ - [messages[0].messageId, TEMP_USER_MESSAGE_ID], - [messages[1].messageId, TEMP_ASSISTANT_MESSAGE_ID], - ] as [number, number][]) - : null; - upsertToCompleteMessageMap({ - messages: messages, - replacementsMap: replacementsMap, - completeMessageMapOverride: frozenCompleteMessageMap, - }); - }; - const newUserMessageId = - finalMessage?.parent_message || TEMP_USER_MESSAGE_ID; - const newAssistantMessageId = - finalMessage?.message_id || TEMP_ASSISTANT_MESSAGE_ID; - updateFn([ - { - messageId: newUserMessageId, - message: currMessage, - type: "user", - files: currentMessageFiles, - toolCalls: [], - parentMessageId: parentMessage?.messageId || null, - childrenMessageIds: [newAssistantMessageId], - latestChildMessageId: newAssistantMessageId, - }, - { - messageId: newAssistantMessageId, - message: error || answer, - type: error ? "error" : "assistant", - retrievalType, - query: finalMessage?.rephrased_query || query, - documents: finalMessage?.context_docs?.top_documents || documents, - citations: finalMessage?.citations || {}, - files: finalMessage?.files || aiMessageImages || [], - toolCalls: finalMessage?.tool_calls || toolCalls, - parentMessageId: newUserMessageId, - }, - ]); - if (isCancelledRef.current) { - setIsCancelled(false); - break; } } } catch (e: any) { @@ -817,7 +951,6 @@ export function ChatPage({ if (persona && persona.id !== livePersona.id) { // remove uploaded files setCurrentMessageFiles([]); - setSelectedPersona(persona); textAreaRef.current?.focus(); router.push(buildChatUrl(searchParams, null, persona.id)); @@ -1051,7 +1184,7 @@ export function ChatPage({ ? completeMessageMap.get(message.parentMessageId) : null; return ( -
+
0) === true - } - handleFeedback={ - i === messageHistory.length - 1 && isStreaming - ? undefined - : (feedbackType) => - setCurrentFeedback([ - feedbackType, - message.messageId as number, - ]) - } - handleSearchQueryEdit={ - i === messageHistory.length - 1 && - !isStreaming - ? (newQuery) => { - if (!previousMessage) { - setPopup({ - type: "error", - message: - "Cannot edit query of first message - please refresh the page and try again.", - }); - return; - } + > + 0) === true + } + handleFeedback={ + i === messageHistory.length - 1 && + isStreaming + ? undefined + : (feedbackType) => + setCurrentFeedback([ + feedbackType, + message.messageId as number, + ]) + } + handleSearchQueryEdit={ + i === messageHistory.length - 1 && + !isStreaming + ? (newQuery) => { + if (!previousMessage) { + setPopup({ + type: "error", + message: + "Cannot edit query of first message - please refresh the page and try again.", + }); + return; + } - if ( - previousMessage.messageId === null - ) { - setPopup({ - type: "error", - message: - "Cannot edit query of a pending message - please wait a few seconds and try again.", + if ( + previousMessage.messageId === null + ) { + setPopup({ + type: "error", + message: + "Cannot edit query of a pending message - please wait a few seconds and try again.", + }); + return; + } + onSubmit({ + messageIdToResend: + previousMessage.messageId, + queryOverride: newQuery, }); - return; } - onSubmit({ - messageIdToResend: - previousMessage.messageId, - queryOverride: newQuery, - }); - } - : undefined - } - isCurrentlyShowingRetrieved={isShowingRetrieved} - handleShowRetrieved={(messageNumber) => { - if (isShowingRetrieved) { - setSelectedMessageForDocDisplay(null); - } else { - if (messageNumber !== null) { - setSelectedMessageForDocDisplay( - messageNumber - ); + : undefined + } + isCurrentlyShowingRetrieved={ + isShowingRetrieved + } + handleShowRetrieved={(messageNumber) => { + if (isShowingRetrieved) { + setSelectedMessageForDocDisplay(null); } else { - setSelectedMessageForDocDisplay(-1); + if (messageNumber !== null) { + setSelectedMessageForDocDisplay( + messageNumber + ); + } else { + setSelectedMessageForDocDisplay(-1); + } } - } - }} - handleForceSearch={() => { - if ( - previousMessage && - previousMessage.messageId - ) { - onSubmit({ - messageIdToResend: - previousMessage.messageId, - forceSearch: true, - }); - } else { - setPopup({ - type: "error", - message: - "Failed to force search - please refresh the page and try again.", - }); - } - }} - retrievalDisabled={retrievalDisabled} - /> + }} + handleForceSearch={() => { + if ( + previousMessage && + previousMessage.messageId + ) { + onSubmit({ + messageIdToResend: + previousMessage.messageId, + forceSearch: true, + }); + } else { + setPopup({ + type: "error", + message: + "Failed to force search - please refresh the page and try again.", + }); + } + }} + retrievalDisabled={retrievalDisabled} + /> +
); } else { return ( -
+
0 && messageHistory[messageHistory.length - 1].type === "user" && ( -
+
+
+
{livePersona && livePersona.starter_messages && @@ -1255,22 +1404,25 @@ export function ChatPage({ !isFetchingChatMessages && (
{livePersona.starter_messages.map( (starterMessage, i) => ( -
+
@@ -1289,8 +1441,21 @@ export function ChatPage({
-
-
+
+
+ {aboveHorizon && ( +
+ +
+ )} (sendMessageResponse); + yield* handleStream(sendMessageResponse); } export async function nameChatSession(chatSessionId: number, message: string) { @@ -250,24 +258,6 @@ export async function* simulateLLMResponse(input: string, delay: number = 30) { } } -export function handleAutoScroll( - endRef: RefObject, - scrollableRef: RefObject, - buffer: number = 300 -) { - // Auto-scrolls if the user is within `buffer` of the bottom of the scrollableRef - if (endRef && endRef.current && scrollableRef && scrollableRef.current) { - if ( - scrollableRef.current.scrollHeight - - scrollableRef.current.scrollTop - - buffer <= - scrollableRef.current.clientHeight - ) { - endRef.current.scrollIntoView({ behavior: "smooth" }); - } - } -} - export function getHumanAndAIMessageFromMessageNumber( messageHistory: Message[], messageId: number @@ -565,3 +555,92 @@ export async function uploadFilesForChat( return [responseJson.files as FileDescriptor[], null]; } + +export async function useScrollonStream({ + isStreaming, + scrollableDivRef, + scrollDist, + endDivRef, + distance, + debounce, +}: { + isStreaming: boolean; + scrollableDivRef: RefObject; + scrollDist: MutableRefObject; + endDivRef: RefObject; + distance: number; + debounce: number; +}) { + const preventScrollInterference = useRef(false); + const preventScroll = useRef(false); + const blockActionRef = useRef(false); + const previousScroll = useRef(0); + + useEffect(() => { + if (isStreaming && scrollableDivRef && scrollableDivRef.current) { + let newHeight: number = scrollableDivRef.current?.scrollTop!; + const heightDifference = newHeight - previousScroll.current; + previousScroll.current = newHeight; + + // Prevent streaming scroll + if (heightDifference < 0 && !preventScroll.current) { + scrollableDivRef.current.style.scrollBehavior = "auto"; + scrollableDivRef.current.scrollTop = scrollableDivRef.current.scrollTop; + scrollableDivRef.current.style.scrollBehavior = "smooth"; + preventScrollInterference.current = true; + preventScroll.current = true; + + setTimeout(() => { + preventScrollInterference.current = false; + }, 2000); + setTimeout(() => { + preventScroll.current = false; + }, 10000); + } + + // Ensure can scroll if scroll down + else if (!preventScrollInterference.current) { + preventScroll.current = false; + } + if ( + scrollDist.current < distance && + !blockActionRef.current && + !blockActionRef.current && + !preventScroll.current && + endDivRef && + endDivRef.current + ) { + // catch up if necessary! + const scrollAmount = scrollDist.current + 10000; + if (scrollDist.current > 140) { + endDivRef.current.scrollIntoView(); + } else { + blockActionRef.current = true; + + scrollableDivRef?.current?.scrollBy({ + left: 0, + top: Math.max(0, scrollAmount), + behavior: "smooth", + }); + + setTimeout(() => { + blockActionRef.current = false; + }, debounce); + } + } + } + }); + + // scroll on end of stream if within distance + useEffect(() => { + if (scrollableDivRef?.current && !isStreaming) { + if (scrollDist.current < distance) { + scrollableDivRef?.current?.scrollBy({ + left: 0, + top: Math.max(scrollDist.current + 600, 0), + behavior: "smooth", + }); + } + } + }, [isStreaming]); +}