From 0ee1bb24003598b9d59f4b0dc2550276d50e4b0d Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 14 May 2024 17:26:59 -0700 Subject: [PATCH] Chat history editing --- backend/danswer/tools/utils.py | 2 +- web/src/app/chat/ChatPage.tsx | 230 +++++++++++++--- web/src/app/chat/interfaces.ts | 8 +- web/src/app/chat/lib.tsx | 180 ++++++++++--- web/src/app/chat/message/Messages.tsx | 255 +++++++++++++++--- web/src/app/chat/message/SearchSummary.tsx | 28 +- .../app/chat/modal/ShareChatSessionModal.tsx | 1 - .../shared/[chatId]/SharedChatDisplay.tsx | 10 +- web/src/components/CopyButton.tsx | 7 +- web/src/components/Hoverable.tsx | 18 ++ web/tailwind.config.js | 3 +- 11 files changed, 612 insertions(+), 130 deletions(-) create mode 100644 web/src/components/Hoverable.tsx diff --git a/backend/danswer/tools/utils.py b/backend/danswer/tools/utils.py index f4e69a921664..831021cdab3a 100644 --- a/backend/danswer/tools/utils.py +++ b/backend/danswer/tools/utils.py @@ -7,7 +7,7 @@ from danswer.llm.utils import get_default_llm_tokenizer from danswer.tools.tool import Tool -OPEN_AI_TOOL_CALLING_MODELS = {"gpt-3.5-turbo", "gpt-4-turbo", "gpt-4", "gpt-4o"} +OPEN_AI_TOOL_CALLING_MODELS = {"gpt-3.5-turbo", "gpt-4-turbo", "gpt-4"} def explicit_tool_calling_supported(model_provider: str, model_name: str) -> bool: diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 9eb61da778c4..b92038bedcc8 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -23,6 +23,7 @@ import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; import { Settings } from "../admin/settings/interfaces"; import { buildChatUrl, + buildLatestMessageChain, createChatSession, getCitedDocumentsFromMessage, getHumanAndAIMessageFromMessageNumber, @@ -32,7 +33,10 @@ import { nameChatSession, personaIncludesRetrieval, processRawChatHistory, + removeMessage, sendMessage, + setMessageAsLatest, + updateParentChildren, uploadFilesForChat, } from "./lib"; import { useContext, useEffect, useRef, useState } from "react"; @@ -67,6 +71,9 @@ import { InputBarPreviewImage } from "./images/InputBarPreviewImage"; import { Folder } from "./folders/interfaces"; const MAX_INPUT_HEIGHT = 200; +const TEMP_USER_MESSAGE_ID = -1; +const TEMP_ASSISTANT_MESSAGE_ID = -2; +const SYSTEM_MESSAGE_ID = -3; export function ChatPage({ user, @@ -160,7 +167,7 @@ export function ChatPage({ } else { setSelectedPersona(undefined); } - setMessageHistory([]); + setCompleteMessageMap(new Map()); setChatSessionSharedStatus(ChatSessionSharedStatus.Private); // if we're supposed to submit on initial load, then do that here @@ -186,10 +193,11 @@ export function ChatPage({ ) ); - const newMessageHistory = processRawChatHistory(chatSession.messages); + const newCompleteMessageMap = processRawChatHistory(chatSession.messages); + const newMessageHistory = buildLatestMessageChain(newCompleteMessageMap); // if the last message is an error, don't overwrite it if (messageHistory[messageHistory.length - 1]?.type !== "error") { - setMessageHistory(newMessageHistory); + setCompleteMessageMap(newCompleteMessageMap); const latestMessageId = newMessageHistory[newMessageHistory.length - 1]?.messageId; @@ -231,7 +239,77 @@ export function ChatPage({ const [message, setMessage] = useState( searchParams.get(SEARCH_PARAM_NAMES.USER_MESSAGE) || "" ); - const [messageHistory, setMessageHistory] = useState([]); + const [completeMessageMap, setCompleteMessageMap] = useState< + Map + >(new Map()); + const upsertToCompleteMessageMap = ({ + messages, + completeMessageMapOverride, + replacementsMap = null, + makeLatestChildMessage = false, + }: { + messages: Message[]; + // if calling this function repeatedly with short delay, stay may not update in time + // and result in weird behavipr + completeMessageMapOverride?: Map | null; + replacementsMap?: Map | null; + makeLatestChildMessage?: boolean; + }) => { + // deep copy + const frozenCompleteMessageMap = + completeMessageMapOverride || completeMessageMap; + const newCompleteMessageMap = structuredClone(frozenCompleteMessageMap); + if (newCompleteMessageMap.size === 0) { + const systemMessageId = messages[0].parentMessageId || SYSTEM_MESSAGE_ID; + const firstMessageId = messages[0].messageId; + const dummySystemMessage: Message = { + messageId: systemMessageId, + message: "", + type: "system", + files: [], + parentMessageId: null, + childrenMessageIds: [firstMessageId], + latestChildMessageId: firstMessageId, + }; + newCompleteMessageMap.set( + dummySystemMessage.messageId, + dummySystemMessage + ); + messages[0].parentMessageId = systemMessageId; + } + + messages.forEach((message) => { + const idToReplace = replacementsMap?.get(message.messageId); + if (idToReplace) { + removeMessage(idToReplace, newCompleteMessageMap); + } + + // update childrenMessageIds for the parent + if ( + !newCompleteMessageMap.has(message.messageId) && + message.parentMessageId !== null + ) { + updateParentChildren(message, newCompleteMessageMap, true); + } + newCompleteMessageMap.set(message.messageId, message); + }); + + // if specified, make these new message the latest of the current message chain + if (makeLatestChildMessage) { + const currentMessageChain = buildLatestMessageChain( + frozenCompleteMessageMap + ); + const latestMessage = currentMessageChain[currentMessageChain.length - 1]; + if (latestMessage) { + newCompleteMessageMap.get( + latestMessage.messageId + )!.latestChildMessageId = messages[0].messageId; + } + } + setCompleteMessageMap(newCompleteMessageMap); + return newCompleteMessageMap; + }; + const messageHistory = buildLatestMessageChain(completeMessageMap); const [currentTool, setCurrentTool] = useState(null); const [isStreaming, setIsStreaming] = useState(false); @@ -415,6 +493,11 @@ export function ChatPage({ const messageToResend = messageHistory.find( (message) => message.messageId === messageIdToResend ); + const messageToResendParent = + messageToResend?.parentMessageId !== null && + messageToResend?.parentMessageId !== undefined + ? completeMessageMap.get(messageToResend.parentMessageId) + : null; const messageToResendIndex = messageToResend ? messageHistory.indexOf(messageToResend) : null; @@ -435,19 +518,39 @@ export function ChatPage({ messageToResendIndex !== null ? messageHistory.slice(0, messageToResendIndex) : messageHistory; + let parentMessage = + messageToResendParent || + (currMessageHistory.length > 0 + ? currMessageHistory[currMessageHistory.length - 1] + : null); const currFiles = currentMessageFileIds.map((id) => ({ id, type: "image", })) as FileDescriptor[]; - setMessageHistory([ - ...currMessageHistory, + + // if we're resending, set the parent's child to null + // we will use tempMessages until the regenerated message is complete + const messageUpdates: Message[] = [ { - messageId: 0, + messageId: TEMP_USER_MESSAGE_ID, message: currMessage, type: "user", files: currFiles, + parentMessageId: parentMessage?.messageId || null, }, - ]); + ]; + if (parentMessage) { + messageUpdates.push({ + ...parentMessage, + childrenMessageIds: (parentMessage.childrenMessageIds || []).concat([ + TEMP_USER_MESSAGE_ID, + ]), + latestChildMessageId: TEMP_USER_MESSAGE_ID, + }); + } + const frozenCompleteMessageMap = upsertToCompleteMessageMap({ + messages: messageUpdates, + }); setMessage(""); setCurrentMessageFileIds([]); @@ -504,7 +607,7 @@ export function ChatPage({ if (documents && documents.length > 0) { // point to the latest message (we don't know the messageId yet, which is why // we have to use -1) - setSelectedMessageForDocDisplay(-1); + setSelectedMessageForDocDisplay(TEMP_USER_MESSAGE_ID); } } else if (Object.hasOwn(packet, "file_ids")) { aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map( @@ -523,16 +626,35 @@ export function ChatPage({ finalMessage = packet as BackendMessage; } } - setMessageHistory([ - ...currMessageHistory, + const updateFn = (messages: Message[]) => { + const replacementsMap = finalMessage + ? new Map([ + [messages[0].messageId, TEMP_USER_MESSAGE_ID], + [messages[1].messageId, TEMP_ASSISTANT_MESSAGE_ID], + ] as [number, number][]) + : null; + upsertToCompleteMessageMap({ + messages: messages, + replacementsMap: replacementsMap, + completeMessageMapOverride: frozenCompleteMessageMap, + }); + }; + const newUserMessageId = + finalMessage?.parent_message || TEMP_USER_MESSAGE_ID; + const newAssistantMessageId = + finalMessage?.message_id || TEMP_ASSISTANT_MESSAGE_ID; + updateFn([ { - messageId: finalMessage?.parent_message || null, + messageId: newUserMessageId, message: currMessage, type: "user", files: currFiles, + parentMessageId: parentMessage?.messageId || null, + childrenMessageIds: [newAssistantMessageId], + latestChildMessageId: newAssistantMessageId, }, { - messageId: finalMessage?.message_id || null, + messageId: newAssistantMessageId, message: error || answer, type: error ? "error" : "assistant", retrievalType, @@ -540,6 +662,7 @@ export function ChatPage({ documents: finalMessage?.context_docs?.top_documents || documents, citations: finalMessage?.citations || {}, files: finalMessage?.files || aiMessageImages || [], + parentMessageId: newUserMessageId, }, ]); if (isCancelledRef.current) { @@ -549,21 +672,25 @@ export function ChatPage({ } } catch (e: any) { const errorMsg = e.message; - setMessageHistory([ - ...currMessageHistory, - { - messageId: null, - message: currMessage, - type: "user", - files: currFiles, - }, - { - messageId: null, - message: errorMsg, - type: "error", - files: aiMessageImages || [], - }, - ]); + upsertToCompleteMessageMap({ + messages: [ + { + messageId: TEMP_USER_MESSAGE_ID, + message: currMessage, + type: "user", + files: currFiles, + parentMessageId: null, + }, + { + messageId: TEMP_ASSISTANT_MESSAGE_ID, + message: errorMsg, + type: "error", + files: aiMessageImages || [], + parentMessageId: null, + }, + ], + completeMessageMapOverride: frozenCompleteMessageMap, + }); } setIsStreaming(false); if (isNewSession) { @@ -787,11 +914,53 @@ export function ChatPage({ > {messageHistory.map((message, i) => { if (message.type === "user") { + const parentMessage = message.parentMessageId + ? completeMessageMap.get(message.parentMessageId) + : null; return (
{ + const parentMessageId = + message.parentMessageId!; + const parentMessage = + completeMessageMap.get(parentMessageId)!; + upsertToCompleteMessageMap({ + messages: [ + { + ...parentMessage, + latestChildMessageId: null, + }, + ], + }); + onSubmit({ + messageIdToResend: + message.messageId || undefined, + messageOverride: editedContent, + }); + }} + onMessageSelection={(messageId) => { + const newCompleteMessageMap = new Map( + completeMessageMap + ); + newCompleteMessageMap.get( + message.parentMessageId! + )!.latestChildMessageId = messageId; + setCompleteMessageMap( + newCompleteMessageMap + ); + setSelectedMessageForDocDisplay(messageId); + + // set message as latest so we can edit this message + // and so it sticks around on page reload + setMessageAsLatest(messageId); + }} />
); @@ -800,7 +969,8 @@ export function ChatPage({ (selectedMessageForDocDisplay !== null && selectedMessageForDocDisplay === message.messageId) || - (selectedMessageForDocDisplay === -1 && + (selectedMessageForDocDisplay === + TEMP_USER_MESSAGE_ID && i === messageHistory.length - 1); const previousMessage = i !== 0 ? messageHistory[i - 1] : null; @@ -921,7 +1091,7 @@ export function ChatPage({ })} {isStreaming && - messageHistory.length && + messageHistory.length > 0 && messageHistory[messageHistory.length - 1].type === "user" && (
diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index 5cc391bc577d..22f685fdd9cb 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -35,14 +35,18 @@ export interface ChatSession { } export interface Message { - messageId: number | null; + messageId: number; message: string; - type: "user" | "assistant" | "error"; + type: "user" | "assistant" | "system" | "error"; retrievalType?: RetrievalType; query?: string | null; documents?: DanswerDocument[] | null; citations?: CitationMap; files: FileDescriptor[]; + // for rebuilding the message tree + parentMessageId: number | null; + childrenMessageIds?: number[]; + latestChildMessageId?: number | null; } export interface BackendChatSession { diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index a63901c1ec76..9428bc45a7b8 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -154,6 +154,19 @@ export async function nameChatSession(chatSessionId: number, message: string) { return response; } +export async function setMessageAsLatest(messageId: number) { + const response = await fetch("/api/chat/set-message-as-latest", { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + message_id: messageId, + }), + }); + return response; +} + export async function handleChatFeedback( messageId: number, feedback: FeedbackType, @@ -332,64 +345,143 @@ export function getLastSuccessfulMessageId(messageHistory: Message[]) { return lastSuccessfulMessage ? lastSuccessfulMessage?.messageId : null; } -export function processRawChatHistory(rawMessages: BackendMessage[]) { - const messageMap: Map = new Map( - rawMessages.map((message) => [message.message_id, message]) +export function processRawChatHistory( + rawMessages: BackendMessage[] +): Map { + const messages: Map = new Map(); + const parentMessageChildrenMap: Map = new Map(); + + rawMessages.forEach((messageInfo) => { + const hasContextDocs = + (messageInfo?.context_docs?.top_documents || []).length > 0; + let retrievalType; + if (hasContextDocs) { + if (messageInfo.rephrased_query) { + retrievalType = RetrievalType.Search; + } else { + retrievalType = RetrievalType.SelectedDocs; + } + } else { + retrievalType = RetrievalType.None; + } + + const message: Message = { + messageId: messageInfo.message_id, + message: messageInfo.message, + type: messageInfo.message_type as "user" | "assistant", + files: messageInfo.files, + // only include these fields if this is an assistant message so that + // this is identical to what is computed at streaming time + ...(messageInfo.message_type === "assistant" + ? { + retrievalType: retrievalType, + query: messageInfo.rephrased_query, + documents: messageInfo?.context_docs?.top_documents || [], + citations: messageInfo?.citations || {}, + } + : {}), + parentMessageId: messageInfo.parent_message, + childrenMessageIds: [], + latestChildMessageId: messageInfo.latest_child_message, + }; + + messages.set(messageInfo.message_id, message); + + if (messageInfo.parent_message !== null) { + if (!parentMessageChildrenMap.has(messageInfo.parent_message)) { + parentMessageChildrenMap.set(messageInfo.parent_message, []); + } + parentMessageChildrenMap + .get(messageInfo.parent_message)! + .push(messageInfo.message_id); + } + }); + + // Populate childrenMessageIds for each message + parentMessageChildrenMap.forEach((childrenIds, parentId) => { + childrenIds.sort((a, b) => a - b); + const parentMesage = messages.get(parentId); + if (parentMesage) { + parentMesage.childrenMessageIds = childrenIds; + } + }); + + return messages; +} + +export function buildLatestMessageChain( + messageMap: Map, + additionalMessagesOnMainline: Message[] = [] +): Message[] { + const rootMessage = Array.from(messageMap.values()).find( + (message) => message.parentMessageId === null ); - const rootMessage = rawMessages.find( - (message) => message.parent_message === null - ); - - const finalMessageList: BackendMessage[] = []; + let finalMessageList: Message[] = []; if (rootMessage) { - let currMessage: BackendMessage | null = rootMessage; + let currMessage: Message | null = rootMessage; while (currMessage) { finalMessageList.push(currMessage); - const childMessageNumber = currMessage.latest_child_message; + const childMessageNumber = currMessage.latestChildMessageId; if (childMessageNumber && messageMap.has(childMessageNumber)) { - currMessage = messageMap.get(childMessageNumber) as BackendMessage; + currMessage = messageMap.get(childMessageNumber) as Message; } else { currMessage = null; } } } - const messages: Message[] = finalMessageList - .filter((messageInfo) => messageInfo.message_type !== "system") - .map((messageInfo) => { - const hasContextDocs = - (messageInfo?.context_docs?.top_documents || []).length > 0; - let retrievalType; - if (hasContextDocs) { - if (messageInfo.rephrased_query) { - retrievalType = RetrievalType.Search; - } else { - retrievalType = RetrievalType.SelectedDocs; - } - } else { - retrievalType = RetrievalType.None; - } + // remove system message + if (finalMessageList.length > 0 && finalMessageList[0].type === "system") { + finalMessageList = finalMessageList.slice(1); + } + return finalMessageList.concat(additionalMessagesOnMainline); +} - return { - messageId: messageInfo.message_id, - message: messageInfo.message, - type: messageInfo.message_type as "user" | "assistant", - files: messageInfo.files, - // only include these fields if this is an assistant message so that - // this is identical to what is computed at streaming time - ...(messageInfo.message_type === "assistant" - ? { - retrievalType: retrievalType, - query: messageInfo.rephrased_query, - documents: messageInfo?.context_docs?.top_documents || [], - citations: messageInfo?.citations || {}, - } - : {}), - }; - }); +export function updateParentChildren( + message: Message, + completeMessageMap: Map, + setAsLatestChild: boolean = false +) { + // NOTE: updates the `completeMessageMap` in place + const parentMessage = message.parentMessageId + ? completeMessageMap.get(message.parentMessageId) + : null; + if (parentMessage) { + if (setAsLatestChild) { + parentMessage.latestChildMessageId = message.messageId; + } - return messages; + const parentChildMessages = parentMessage.childrenMessageIds || []; + if (!parentChildMessages.includes(message.messageId)) { + parentChildMessages.push(message.messageId); + } + parentMessage.childrenMessageIds = parentChildMessages; + } +} + +export function removeMessage( + messageId: number, + completeMessageMap: Map +) { + const messageToRemove = completeMessageMap.get(messageId); + if (!messageToRemove) { + return; + } + + const parentMessage = messageToRemove.parentMessageId + ? completeMessageMap.get(messageToRemove.parentMessageId) + : null; + if (parentMessage) { + if (parentMessage.latestChildMessageId === messageId) { + parentMessage.latestChildMessageId = null; + } + const currChildMessage = parentMessage.childrenMessageIds || []; + const newChildMessage = currChildMessage.filter((id) => id !== messageId); + parentMessage.childrenMessageIds = newChildMessage; + } + + completeMessageMap.delete(messageId); } export function personaIncludesRetrieval(selectedPersona: Persona) { diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 6e8bca90a414..61e5f0ded712 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -1,15 +1,15 @@ import { - FiCheck, - FiCopy, FiCpu, FiImage, FiThumbsDown, FiThumbsUp, - FiTool, FiUser, + FiEdit2, + FiChevronRight, + FiChevronLeft, } from "react-icons/fi"; import { FeedbackType } from "../types"; -import { useState } from "react"; +import { useEffect, useRef, useState } from "react"; import ReactMarkdown from "react-markdown"; import { DanswerDocument } from "@/lib/search/interfaces"; import { SearchSummary, ShowHideDocsButton } from "./SearchSummary"; @@ -20,25 +20,11 @@ import remarkGfm from "remark-gfm"; import { CopyButton } from "@/components/CopyButton"; import { FileDescriptor } from "../interfaces"; import { InMessageImage } from "../images/InMessageImage"; -import { - IMAGE_GENERATION_TOOL_NAME, - SEARCH_TOOL_NAME, -} from "../tools/constants"; +import { IMAGE_GENERATION_TOOL_NAME } from "../tools/constants"; import { ToolRunningAnimation } from "../tools/ToolRunningAnimation"; +import { Hoverable } from "@/components/Hoverable"; -export const Hoverable: React.FC<{ - children: JSX.Element; - onClick?: () => void; -}> = ({ children, onClick }) => { - return ( -
- {children} -
- ); -}; +const ICON_SIZE = 15; export const AIMessage = ({ messageId, @@ -236,14 +222,16 @@ export const AIMessage = ({ )}
{handleFeedback && ( -
+
- handleFeedback("like")}> - - - - handleFeedback("dislike")} /> - + handleFeedback("like")} + /> + handleFeedback("dislike")} + />
)}
@@ -252,15 +240,88 @@ export const AIMessage = ({ ); }; +function MessageSwitcher({ + currentPage, + totalPages, + handlePrevious, + handleNext, +}: { + currentPage: number; + totalPages: number; + handlePrevious: () => void; + handleNext: () => void; +}) { + return ( +
+ + + {currentPage} / {totalPages} + + +
+ ); +} + export const HumanMessage = ({ content, files, + messageId, + otherMessagesCanSwitchTo, + onEdit, + onMessageSelection, }: { - content: string | JSX.Element; + content: string; files?: FileDescriptor[]; + messageId?: number | null; + otherMessagesCanSwitchTo?: number[]; + onEdit?: (editedContent: string) => void; + onMessageSelection?: (messageId: number) => void; }) => { + const textareaRef = useRef(null); + + const [isHovered, setIsHovered] = useState(false); + const [isEditing, setIsEditing] = useState(false); + const [editedContent, setEditedContent] = useState(content); + + useEffect(() => { + if (!isEditing) { + setEditedContent(content); + } + }, [content]); + + useEffect(() => { + if (textareaRef.current) { + // Focus the textarea + textareaRef.current.focus(); + // Move the cursor to the end of the text + textareaRef.current.selectionStart = textareaRef.current.value.length; + textareaRef.current.selectionEnd = textareaRef.current.value.length; + } + }, [isEditing]); + + const handleEditSubmit = () => { + if (editedContent.trim() !== content.trim()) { + onEdit?.(editedContent); + } + setIsEditing(false); + }; + + const currentMessageInd = messageId + ? otherMessagesCanSwitchTo?.indexOf(messageId) + : undefined; + return ( -
+
setIsHovered(true)} + onMouseLeave={() => setIsHovered(false)} + >
@@ -284,7 +345,102 @@ export const HumanMessage = ({
)} - {typeof content === "string" ? ( + {isEditing ? ( +
+
+