diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index a08425bd9..e6ca00d2a 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -101,6 +101,8 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; import { SEARCH_TOOL_NAME } from "./tools/constants"; import { useUser } from "@/components/user/UserProvider"; import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; +import { Button } from "@tremor/react"; +import dynamic from "next/dynamic"; const TEMP_USER_MESSAGE_ID = -1; const TEMP_ASSISTANT_MESSAGE_ID = -2; @@ -133,7 +135,6 @@ export function ChatPage({ } = useChatContext(); const [showApiKeyModal, setShowApiKeyModal] = useState(true); - const { user, refreshUser, isLoadingUser } = useUser(); // chat session @@ -248,13 +249,13 @@ export function ChatPage({ if ( lastMessage && lastMessage.type === "assistant" && - lastMessage.toolCalls[0] && - lastMessage.toolCalls[0].tool_result === undefined + lastMessage.toolCall && + lastMessage.toolCall.tool_result === undefined ) { const newCompleteMessageMap = new Map( currentMessageMap(completeMessageDetail) ); - const updatedMessage = { ...lastMessage, toolCalls: [] }; + const updatedMessage = { ...lastMessage, toolCall: null }; newCompleteMessageMap.set(lastMessage.messageId, updatedMessage); updateCompleteMessageDetail(currentSession, newCompleteMessageMap); } @@ -483,7 +484,7 @@ export function ChatPage({ message: "", type: "system", files: [], - toolCalls: [], + toolCall: null, parentMessageId: null, childrenMessageIds: [firstMessageId], latestChildMessageId: firstMessageId, @@ -510,6 +511,7 @@ export function ChatPage({ } newCompleteMessageMap.set(message.messageId, message); }); + // if specified, make these new message the latest of the current message chain if (makeLatestChildMessage) { const currentMessageChain = buildLatestMessageChain( @@ -1044,8 +1046,6 @@ export function ChatPage({ resetInputBar(); let messageUpdates: Message[] | null = null; - let answer = ""; - let stopReason: StreamStopReason | null = null; let query: string | null = null; let retrievalType: RetrievalType = @@ -1058,12 +1058,14 @@ export function ChatPage({ let stackTrace: string | null = null; let finalMessage: BackendMessage | null = null; - let toolCalls: ToolCallMetadata[] = []; + let toolCall: ToolCallMetadata | null = null; let initialFetchDetails: null | { user_message_id: number; assistant_message_id: number; frozenMessageMap: Map; + initialDynamicParentMessage: Message; + initialDynamicAssistantMessage: Message; } = null; try { @@ -1122,7 +1124,16 @@ export function ChatPage({ return new Promise((resolve) => setTimeout(resolve, ms)); }; + let updateFn = (messages: Message[]) => { + return upsertToCompleteMessageMap({ + messages: messages, + chatSessionId: currChatSessionId, + }); + }; + await delay(50); + let dynamicParentMessage: Message | null = null; + let dynamicAssistantMessage: Message | null = null; while (!stack.isComplete || !stack.isEmpty()) { await delay(0.5); @@ -1156,12 +1167,12 @@ export function ChatPage({ messageUpdates = [ { messageId: regenerationRequest - ? regenerationRequest?.parentMessage?.messageId! + ? regenerationRequest?.messageId : user_message_id, message: currMessage, type: "user", files: currentMessageFiles, - toolCalls: [], + toolCall: null, parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID, }, ]; @@ -1176,22 +1187,109 @@ export function ChatPage({ }); } - const { messageMap: currentFrozenMessageMap } = + let { messageMap: currentFrozenMessageMap } = upsertToCompleteMessageMap({ messages: messageUpdates, chatSessionId: currChatSessionId, }); - const frozenMessageMap = currentFrozenMessageMap; + let frozenMessageMap = currentFrozenMessageMap; + regenerationRequest?.parentMessage; + let initialDynamicParentMessage: Message = regenerationRequest + ? regenerationRequest?.parentMessage + : { + messageId: user_message_id!, + message: "", + type: "user", + files: currentMessageFiles, + toolCall: null, + parentMessageId: error ? null : lastSuccessfulMessageId, + childrenMessageIds: [assistant_message_id!], + latestChildMessageId: -100, + }; + + let initialDynamicAssistantMessage: Message = { + messageId: assistant_message_id!, + message: "", + type: "assistant", + retrievalType, + query: finalMessage?.rephrased_query || query, + documents: finalMessage?.context_docs?.top_documents || documents, + citations: finalMessage?.citations || {}, + files: finalMessage?.files || aiMessageImages || [], + toolCall: finalMessage?.tool_call || toolCall, + parentMessageId: regenerationRequest + ? regenerationRequest?.parentMessage?.messageId! + : user_message_id, + alternateAssistantID: alternativeAssistant?.id, + stackTrace: stackTrace, + overridden_model: finalMessage?.overridden_model, + stopReason: stopReason, + }; + initialFetchDetails = { frozenMessageMap, assistant_message_id, user_message_id, + initialDynamicParentMessage, + initialDynamicAssistantMessage, }; resetRegenerationState(); } else { - const { user_message_id, frozenMessageMap } = initialFetchDetails; + let { + initialDynamicParentMessage, + initialDynamicAssistantMessage, + user_message_id, + frozenMessageMap, + } = initialFetchDetails; + + if ( + dynamicParentMessage === null && + dynamicAssistantMessage === null + ) { + dynamicParentMessage = initialDynamicParentMessage; + dynamicAssistantMessage = initialDynamicAssistantMessage; + + dynamicParentMessage.message = currMessage; + } + + if (!dynamicAssistantMessage || !dynamicParentMessage) { + return; + } + + if (Object.hasOwn(packet, "user_message_id")) { + let newParentMessageId = dynamicParentMessage.messageId; + const messageResponseIDInfo = packet as MessageResponseIDInfo; + + for (const key in dynamicAssistantMessage) { + (dynamicParentMessage as Record)[key] = ( + dynamicAssistantMessage as Record + )[key]; + } + + dynamicParentMessage.parentMessageId = newParentMessageId; + dynamicParentMessage.latestChildMessageId = + messageResponseIDInfo.reserved_assistant_message_id; + dynamicParentMessage.childrenMessageIds = [ + messageResponseIDInfo.reserved_assistant_message_id, + ]; + + dynamicParentMessage.messageId = + messageResponseIDInfo.user_message_id!; + dynamicAssistantMessage = { + messageId: messageResponseIDInfo.reserved_assistant_message_id, + type: "assistant", + message: "", + documents: [], + retrievalType: undefined, + toolCall: null, + files: [], + parentMessageId: dynamicParentMessage.messageId, + childrenMessageIds: [], + latestChildMessageId: null, + }; + } setChatState((prevState) => { if (prevState.get(chatSessionIdRef.current!) === "loading") { @@ -1204,37 +1302,37 @@ export function ChatPage({ }); if (Object.hasOwn(packet, "answer_piece")) { - answer += (packet as AnswerPiecePacket).answer_piece; + dynamicAssistantMessage.message += ( + packet as AnswerPiecePacket + ).answer_piece; } else if (Object.hasOwn(packet, "top_documents")) { - documents = (packet as DocumentsResponse).top_documents; + dynamicAssistantMessage.documents = ( + packet as DocumentsResponse + ).top_documents; + dynamicAssistantMessage.retrievalType = RetrievalType.Search; retrievalType = RetrievalType.Search; - if (documents && documents.length > 0) { - // point to the latest message (we don't know the messageId yet, which is why - // we have to use -1) - setSelectedMessageForDocDisplay(user_message_id); - } } else if (Object.hasOwn(packet, "tool_name")) { - toolCalls = [ - { - tool_name: (packet as ToolCallMetadata).tool_name, - tool_args: (packet as ToolCallMetadata).tool_args, - tool_result: (packet as ToolCallMetadata).tool_result, - }, - ]; + dynamicAssistantMessage.toolCall = { + tool_name: (packet as ToolCallMetadata).tool_name, + tool_args: (packet as ToolCallMetadata).tool_args, + tool_result: (packet as ToolCallMetadata).tool_result, + }; if ( - !toolCalls[0].tool_result || - toolCalls[0].tool_result == undefined + dynamicAssistantMessage.toolCall.tool_name === SEARCH_TOOL_NAME + ) { + dynamicAssistantMessage.query = + dynamicAssistantMessage.toolCall.tool_args.query; + } + + if ( + !dynamicAssistantMessage.toolCall || + !dynamicAssistantMessage.toolCall.tool_result || + dynamicAssistantMessage.toolCall.tool_result == undefined ) { updateChatState("toolBuilding", frozenSessionId); } else { updateChatState("streaming", frozenSessionId); } - - // This will be consolidated in upcoming tool calls udpate, - // but for now, we need to set query as early as possible - if (toolCalls[0].tool_name == SEARCH_TOOL_NAME) { - query = toolCalls[0].tool_args["query"]; - } } else if (Object.hasOwn(packet, "file_ids")) { aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map( (fileId) => { @@ -1244,82 +1342,54 @@ export function ChatPage({ }; } ); + dynamicAssistantMessage.files = aiMessageImages; } else if (Object.hasOwn(packet, "error")) { error = (packet as StreamingError).error; - stackTrace = (packet as StreamingError).stack_trace; + dynamicAssistantMessage.stackTrace = ( + packet as StreamingError + ).stack_trace; } else if (Object.hasOwn(packet, "message_id")) { finalMessage = packet as BackendMessage; + dynamicAssistantMessage = { + ...dynamicAssistantMessage, + ...finalMessage, + }; } else if (Object.hasOwn(packet, "stop_reason")) { const stop_reason = (packet as StreamStopInfo).stop_reason; + if (stop_reason === StreamStopReason.CONTEXT_LENGTH) { updateCanContinue(true, frozenSessionId); } } + if (!Object.hasOwn(packet, "stop_reason")) { + updateFn = (messages: Message[]) => { + const replacementsMap = regenerationRequest + ? new Map([ + [ + regenerationRequest?.parentMessage?.messageId, + regenerationRequest?.parentMessage?.messageId, + ], + [ + dynamicParentMessage?.messageId, + dynamicAssistantMessage?.messageId, + ], + ] as [number, number][]) + : null; - // on initial message send, we insert a dummy system message - // set this as the parent here if no parent is set - parentMessage = - parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!; + return upsertToCompleteMessageMap({ + messages: messages, + replacementsMap: replacementsMap, + completeMessageMapOverride: frozenMessageMap, + chatSessionId: frozenSessionId!, + }); + }; - const updateFn = (messages: Message[]) => { - const replacementsMap = regenerationRequest - ? new Map([ - [ - regenerationRequest?.parentMessage?.messageId, - regenerationRequest?.parentMessage?.messageId, - ], - [ - regenerationRequest?.messageId, - initialFetchDetails?.assistant_message_id, - ], - ] as [number, number][]) - : null; - - return upsertToCompleteMessageMap({ - messages: messages, - replacementsMap: replacementsMap, - completeMessageMapOverride: frozenMessageMap, - chatSessionId: frozenSessionId!, - }); - }; - - updateFn([ - { - messageId: regenerationRequest - ? regenerationRequest?.parentMessage?.messageId! - : initialFetchDetails.user_message_id!, - message: currMessage, - type: "user", - files: currentMessageFiles, - toolCalls: [], - parentMessageId: error ? null : lastSuccessfulMessageId, - childrenMessageIds: [ - ...(regenerationRequest?.parentMessage?.childrenMessageIds || - []), - initialFetchDetails.assistant_message_id!, - ], - latestChildMessageId: initialFetchDetails.assistant_message_id, - }, - { - messageId: initialFetchDetails.assistant_message_id!, - message: error || answer, - type: error ? "error" : "assistant", - retrievalType, - query: finalMessage?.rephrased_query || query, - documents: - finalMessage?.context_docs?.top_documents || documents, - citations: finalMessage?.citations || {}, - files: finalMessage?.files || aiMessageImages || [], - toolCalls: finalMessage?.tool_calls || toolCalls, - parentMessageId: regenerationRequest - ? regenerationRequest?.parentMessage?.messageId! - : initialFetchDetails.user_message_id, - alternateAssistantID: alternativeAssistant?.id, - stackTrace: stackTrace, - overridden_model: finalMessage?.overridden_model, - stopReason: stopReason, - }, - ]); + let { messageMap } = updateFn([ + dynamicParentMessage, + dynamicAssistantMessage, + ]); + frozenMessageMap = messageMap; + } } } } @@ -1333,7 +1403,7 @@ export function ChatPage({ message: currMessage, type: "user", files: currentMessageFiles, - toolCalls: [], + toolCall: null, parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID, }, { @@ -1343,7 +1413,7 @@ export function ChatPage({ message: errorMsg, type: "error", files: aiMessageImages || [], - toolCalls: [], + toolCall: null, parentMessageId: initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID, }, @@ -1962,9 +2032,8 @@ export function ChatPage({ completeMessageDetail ); const messageReactComponentKey = `${i}-${currentSessionId()}`; - const parentMessage = message.parentMessageId - ? messageMap.get(message.parentMessageId) - : null; + const parentMessage = + i > 1 ? messageHistory[i - 1] : null; if (message.type === "user") { if ( (currentSessionChatState == "loading" && @@ -2055,6 +2124,25 @@ export function ChatPage({ ) { return <>; } + const mostRecentNonAIParent = messageHistory + .slice(0, i) + .reverse() + .find((msg) => msg.type !== "assistant"); + + const hasChildMessage = + message.latestChildMessageId !== null && + message.latestChildMessageId !== undefined; + const childMessage = hasChildMessage + ? messageMap.get( + message.latestChildMessageId! + ) + : null; + + const hasParentAI = + parentMessage?.type == "assistant"; + const hasChildAI = + childMessage?.type == "assistant"; + return (
{ if (!previousMessage) { @@ -2231,7 +2318,6 @@ export function ChatPage({ {message.message} @@ -2279,7 +2365,6 @@ export function ChatPage({ alternativeAssistant } messageId={null} - personaName={liveAssistant.name} content={
{loadingError} diff --git a/web/src/components/chat_display/CsvDisplay.tsx b/web/src/components/chat_display/CsvDisplay.tsx index dcc0a5211..8e1fd6e88 100644 --- a/web/src/components/chat_display/CsvDisplay.tsx +++ b/web/src/components/chat_display/CsvDisplay.tsx @@ -86,7 +86,7 @@ export const CsvSection = ({ const fileId = csvFileDescriptor.id; useEffect(() => { fetchCSV(fileId); - }, [fileId]); + }, []); const fetchCSV = async (id: string) => { setIsLoading(true); @@ -124,6 +124,7 @@ export const CsvSection = ({ setFadeIn(false); } }, [isLoading]); + console.log("rerendering"); const downloadFile = () => { if (!fileId) return; @@ -150,7 +151,7 @@ export const CsvSection = ({ return (
@@ -161,7 +162,7 @@ export const CsvSection = ({ @@ -189,17 +190,19 @@ export const CsvSection = ({ {isLoading ? ( -
-
-
-
-
-
+
+
+
+
+
+
-
-
+
+
-
+