From 197b62aed1c3579d91eda33029a3b8a826a360ff Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 22 Aug 2024 12:06:44 -0700 Subject: [PATCH] Regenerate (branch of stop) (#2157) * add regenerate * functional once again post rebase but quite ugly * validated + cleaner UI * more robust implementation for first messages * squash * remove parameter * proper margin * clarify for future programmers * remove some logs * self nit pick - smoother ux * more self-nits * stroke line cap * rebase --- ...5_added_alternate_model_to_chat_message.py | 28 +++ backend/danswer/chat/chat_utils.py | 9 +- backend/danswer/chat/process_message.py | 16 +- backend/danswer/db/chat.py | 4 + backend/danswer/db/models.py | 1 + backend/danswer/llm/utils.py | 8 +- .../danswer/server/query_and_chat/models.py | 5 + web/src/app/chat/ChatPage.tsx | 159 +++++++++++++-- web/src/app/chat/RegenerateOption.tsx | 184 ++++++++++++++++++ web/src/app/chat/interfaces.ts | 2 + web/src/app/chat/lib.tsx | 7 +- web/src/app/chat/message/Messages.tsx | 114 +++++++++-- web/src/components/Dropdown.tsx | 18 +- web/src/components/Hoverable.tsx | 16 +- web/src/components/icons/icons.tsx | 95 ++++++++- 15 files changed, 605 insertions(+), 61 deletions(-) create mode 100644 backend/alembic/versions/ee3f4b47fad5_added_alternate_model_to_chat_message.py create mode 100644 web/src/app/chat/RegenerateOption.tsx diff --git a/backend/alembic/versions/ee3f4b47fad5_added_alternate_model_to_chat_message.py b/backend/alembic/versions/ee3f4b47fad5_added_alternate_model_to_chat_message.py new file mode 100644 index 000000000000..c4f94310da2e --- /dev/null +++ b/backend/alembic/versions/ee3f4b47fad5_added_alternate_model_to_chat_message.py @@ -0,0 +1,28 @@ +"""Added alternate model to chat message + +Revision ID: ee3f4b47fad5 +Revises: 2d2304e27d8c +Create Date: 2024-08-12 00:11:50.915845 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "ee3f4b47fad5" +down_revision = "2d2304e27d8c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "chat_message", + sa.Column("overridden_model", sa.String(length=255), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("chat_message", "overridden_model") diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index ed9c3c6cb78f..b1e4132779bf 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -36,6 +36,8 @@ def create_chat_chain( chat_session_id: int, db_session: Session, prefetch_tool_calls: bool = True, + # Optional id at which we finish processing + stop_at_message_id: int | None = None, ) -> tuple[ChatMessage, list[ChatMessage]]: """Build the linear chain of messages without including the root message""" mainline_messages: list[ChatMessage] = [] @@ -61,7 +63,12 @@ def create_chat_chain( current_message: ChatMessage | None = root_message while current_message is not None: child_msg = current_message.latest_child_message - if not child_msg: + + # Break if at the end of the chain + # or have reached the `final_id` of the submitted message + if not child_msg or ( + stop_at_message_id and current_message.id == stop_at_message_id + ): break current_message = id_to_msg.get(child_msg) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index c8153bd156d5..d7571ad468b7 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -351,7 +351,15 @@ def stream_chat_message_objects( parent_message = root_message user_message = None - if not use_existing_user_message: + + if new_msg_req.regenerate: + final_msg, history_msgs = create_chat_chain( + stop_at_message_id=parent_id, + chat_session_id=chat_session_id, + db_session=db_session, + ) + + elif not use_existing_user_message: # Create new message at the right place in the tree and update the parent's child pointer # Don't commit yet until we verify the chat message chain user_message = create_new_chat_message( @@ -470,12 +478,18 @@ def stream_chat_message_objects( user_message_id=user_message.id if user_message else None, reserved_assistant_message_id=reserved_message_id, ) + + overridden_model = ( + new_msg_req.llm_override.model_version if new_msg_req.llm_override else None + ) + # Cannot determine these without the LLM step or breaking out early partial_response = partial( create_new_chat_message, chat_session_id=chat_session_id, parent_message=final_msg, prompt_id=prompt_id, + overridden_model=overridden_model, # message=, # rephrased_query=, # token_count=, diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 06ece1e922f3..3cb991dd43b0 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -443,6 +443,7 @@ def create_new_chat_message( tool_calls: list[ToolCall] | None = None, commit: bool = True, reserved_message_id: int | None = None, + overridden_model: str | None = None, ) -> ChatMessage: if reserved_message_id is not None: # Edit existing message @@ -462,6 +463,7 @@ def create_new_chat_message( existing_message.tool_calls = tool_calls if tool_calls else [] existing_message.error = error existing_message.alternate_assistant_id = alternate_assistant_id + existing_message.overridden_model = overridden_model new_chat_message = existing_message else: @@ -480,6 +482,7 @@ def create_new_chat_message( tool_calls=tool_calls if tool_calls else [], error=error, alternate_assistant_id=alternate_assistant_id, + overridden_model=overridden_model, ) db_session.add(new_chat_message) @@ -719,6 +722,7 @@ def translate_db_message_to_chat_message_detail( for tool_call in chat_message.tool_calls ], alternate_assistant_id=chat_message.alternate_assistant_id, + overridden_model=chat_message.overridden_model, ) return chat_msg_detail diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 025974bb015e..b3d85418feac 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -897,6 +897,7 @@ class ChatMessage(Base): Integer, ForeignKey("persona.id"), nullable=True ) + overridden_model: Mapped[str | None] = mapped_column(String, nullable=True) parent_message: Mapped[int | None] = mapped_column(Integer, nullable=True) latest_child_message: Mapped[int | None] = mapped_column(Integer, nullable=True) message: Mapped[str] = mapped_column(Text) diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 4172a1f5e611..db777c8e27bd 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -339,11 +339,9 @@ def get_llm_max_tokens( return GEN_AI_MAX_TOKENS try: - model_obj = ( - model_map.get(f"{model_provider}/{model_name}") - or model_map.get(model_name) - or model_map[model_name.split("/")[1]] - ) + model_obj = model_map.get(f"{model_provider}/{model_name}") + if not model_obj: + model_obj = model_map[model_name] if "max_input_tokens" in model_obj: return model_obj["max_input_tokens"] diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 65128e65eab3..8a2c89af54fd 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -98,7 +98,11 @@ class CreateChatMessageRequest(ChunkContext): # will disable Query Rewording if specified query_override: str | None = None + # enables additional handling to ensure that we regenerate with a given user message ID + regenerate: bool | None = None + # allows the caller to override the Persona / Prompt + # these do not persist in the chat thread details llm_override: LLMOverride | None = None prompt_override: PromptOverride | None = None @@ -179,6 +183,7 @@ class ChatMessageDetail(BaseModel): message_type: MessageType time_sent: datetime alternate_assistant_id: str | None + overridden_model: str | None # Dict mapping citation number to db_doc_id chat_session_id: int | None = None citations: dict[int, int] | None diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index dbc361512a5d..2b90baabbf56 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -390,10 +390,12 @@ export function ChatPage({ const [message, setMessage] = useState( searchParams.get(SEARCH_PARAM_NAMES.USER_MESSAGE) || "" ); + const [completeMessageDetail, setCompleteMessageDetail] = useState<{ sessionId: number | null; messageMap: Map; }>({ sessionId: null, messageMap: new Map() }); + const upsertToCompleteMessageMap = ({ messages, completeMessageMapOverride, @@ -413,6 +415,7 @@ export function ChatPage({ const frozenCompleteMessageMap = completeMessageMapOverride || completeMessageDetail.messageMap; const newCompleteMessageMap = structuredClone(frozenCompleteMessageMap); + if (newCompleteMessageMap.size === 0) { const systemMessageId = messages[0].parentMessageId || SYSTEM_MESSAGE_ID; const firstMessageId = messages[0].messageId; @@ -471,8 +474,17 @@ export function ChatPage({ const messageHistory = buildLatestMessageChain( completeMessageDetail.messageMap ); + const [submittedMessage, setSubmittedMessage] = useState(""); const [chatState, setChatState] = useState("input"); + interface RegenerationState { + regenerating: boolean; + finalMessageIndex: number; + } + + const [regenerationState, setRegenerationState] = + useState(null); + const [abortController, setAbortController] = useState(null); @@ -719,6 +731,8 @@ export function ChatPage({ forceSearch, isSeededChat, alternativeAssistantOverride = null, + modelOverRide, + regenerationRequest, }: { messageIdToResend?: number; messageOverride?: string; @@ -726,6 +740,8 @@ export function ChatPage({ forceSearch?: boolean; isSeededChat?: boolean; alternativeAssistantOverride?: Persona | null; + modelOverRide?: LlmOverride; + regenerationRequest?: RegenerationRequest | null; } = {}) => { if (chatState != "input") { setPopup({ @@ -735,8 +751,14 @@ export function ChatPage({ return; } + setRegenerationState( + regenerationRequest + ? { regenerating: true, finalMessageIndex: messageIdToResend || 0 } + : null + ); setChatState("loading"); + const controller = new AbortController(); setAbortController(controller); @@ -770,12 +792,14 @@ export function ChatPage({ const messageToResendIndex = messageToResend ? messageHistory.indexOf(messageToResend) : null; + if (!messageToResend && messageIdToResend !== undefined) { setPopup({ message: "Failed to re-send message - please refresh the page and try again.", type: "error", }); + setRegenerationState(null); setChatState("input"); return; } @@ -789,6 +813,7 @@ export function ChatPage({ messageToResendIndex !== null ? messageHistory.slice(0, messageToResendIndex) : messageHistory; + let parentMessage = messageToResendParent || (currMessageHistory.length > 0 @@ -827,8 +852,11 @@ export function ChatPage({ } = null; try { + const mapKeys = Array.from(completeMessageDetail.messageMap.keys()); + const systemMessage = Math.min(...mapKeys); + const lastSuccessfulMessageId = - getLastSuccessfulMessageId(currMessageHistory); + getLastSuccessfulMessageId(currMessageHistory) || systemMessage; const stack = new CurrentMessageFIFO(); updateCurrentMessageFIFO(stack, { @@ -836,7 +864,9 @@ export function ChatPage({ message: currMessage, alternateAssistantId: currentAssistantId, fileDescriptors: currentMessageFiles, - parentMessageId: lastSuccessfulMessageId, + parentMessageId: + regenerationRequest?.parentMessage.messageId || + lastSuccessfulMessageId, chatSessionId: currChatSessionId, promptId: liveAssistant?.prompts[0]?.id || 0, filters: buildFilters( @@ -853,12 +883,14 @@ export function ChatPage({ .map((document) => document.db_doc_id as number), queryOverride, forceSearch, - + regenerate: regenerationRequest !== undefined, modelProvider: + modelOverRide?.name || llmOverrideManager.llmOverride.name || llmOverrideManager.globalDefault.name || undefined, modelVersion: + modelOverRide?.modelName || llmOverrideManager.llmOverride.modelName || searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) || llmOverrideManager.globalDefault.modelName || @@ -900,15 +932,18 @@ export function ChatPage({ // we will use tempMessages until the regenerated message is complete messageUpdates = [ { - messageId: user_message_id, + messageId: regenerationRequest + ? regenerationRequest?.parentMessage?.messageId! + : user_message_id, message: currMessage, type: "user", files: currentMessageFiles, toolCalls: [], - parentMessageId: parentMessage?.messageId || null, + parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID, }, ]; - if (parentMessage) { + + if (parentMessage && !regenerationRequest) { messageUpdates.push({ ...parentMessage, childrenMessageIds: ( @@ -934,6 +969,8 @@ export function ChatPage({ assistant_message_id, user_message_id, }; + + setRegenerationState(null); } else { const { user_message_id, frozenMessageMap, frozenSessionId } = initialFetchDetails; @@ -993,8 +1030,20 @@ export function ChatPage({ parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!; const updateFn = (messages: Message[]) => { - const replacementsMap = null; - upsertToCompleteMessageMap({ + const replacementsMap = regenerationRequest + ? new Map([ + [ + regenerationRequest?.parentMessage?.messageId, + regenerationRequest?.parentMessage?.messageId, + ], + [ + regenerationRequest?.messageId, + initialFetchDetails?.assistant_message_id, + ], + ] as [number, number][]) + : null; + + return upsertToCompleteMessageMap({ messages: messages, replacementsMap: replacementsMap, completeMessageMapOverride: frozenMessageMap, @@ -1004,13 +1053,19 @@ export function ChatPage({ updateFn([ { - messageId: initialFetchDetails.user_message_id!, + messageId: regenerationRequest + ? regenerationRequest?.parentMessage?.messageId! + : initialFetchDetails.user_message_id!, message: currMessage, type: "user", files: currentMessageFiles, toolCalls: [], parentMessageId: error ? null : lastSuccessfulMessageId, - childrenMessageIds: [initialFetchDetails.assistant_message_id!], + childrenMessageIds: [ + ...(regenerationRequest?.parentMessage?.childrenMessageIds || + []), + initialFetchDetails.assistant_message_id!, + ], latestChildMessageId: initialFetchDetails.assistant_message_id, }, { @@ -1024,9 +1079,12 @@ export function ChatPage({ citations: finalMessage?.citations || {}, files: finalMessage?.files || aiMessageImages || [], toolCalls: finalMessage?.tool_calls || toolCalls, - parentMessageId: initialFetchDetails.user_message_id, + parentMessageId: regenerationRequest + ? regenerationRequest?.parentMessage?.messageId! + : initialFetchDetails.user_message_id, alternateAssistantID: alternativeAssistant?.id, stackTrace: stackTrace, + overridden_model: finalMessage?.overridden_model, }, ]); } @@ -1060,6 +1118,7 @@ export function ChatPage({ completeMessageMapOverride: completeMessageDetail.messageMap, }); } + setRegenerationState(null); setChatState("input"); if (isNewSession) { if (finalMessage) { @@ -1309,6 +1368,22 @@ export function ChatPage({ }; const secondsUntilExpiration = getSecondsUntilExpiration(user); + interface RegenerationRequest { + messageId: number; + parentMessage: Message; + } + + function createRegenerator(regenerationRequest: RegenerationRequest) { + // Returns new function that only needs `modelOverRide` to be specified when called + return async function (modelOverRide: LlmOverride) { + return await onSubmit({ + modelOverRide, + messageIdToResend: regenerationRequest.parentMessage.messageId, + regenerationRequest, + }); + }; + } + return ( <> @@ -1494,7 +1569,7 @@ export function ChatPage({ )}
+ regenerationState.finalMessageIndex + ) { + return <>; + } + if (message.type === "user") { - const parentMessage = message.parentMessageId - ? messageMap.get(message.parentMessageId) - : null; return (
{ const parentMessageId = message.parentMessageId!; @@ -1536,6 +1617,9 @@ export function ChatPage({ messageOverride: editedContent, }); }} + otherMessagesCanSwitchTo={ + parentMessage?.childrenMessageIds || [] + } onMessageSelection={(messageId) => { const newCompleteMessageMap = new Map( messageMap @@ -1576,6 +1660,15 @@ export function ChatPage({ ) : null; + if ( + regenerationState && + regenerationState.regenerating && + // chatState == "loading" && + message.messageId > + regenerationState.finalMessageIndex - 1 + ) { + return <>; + } return (
{ + const newCompleteMessageMap = new Map( + messageMap + ); + newCompleteMessageMap.get( + message.parentMessageId! + )!.latestChildMessageId = messageId; + setCompleteMessageDetail({ + sessionId: + completeMessageDetail.sessionId, + messageMap: newCompleteMessageMap, + }); + setSelectedMessageForDocDisplay( + messageId + ); + // set message as latest so we can edit this message + // and so it sticks around on page reload + setMessageAsLatest(messageId); + }} isActive={messageHistory.length - 1 == i} selectedDocuments={selectedDocuments} toggleDocumentSelection={ @@ -1598,6 +1718,7 @@ export function ChatPage({ } messageId={message.messageId} content={message.message} + // content={message.message} files={message.files} query={ messageHistory[i]?.query || undefined @@ -1739,6 +1860,7 @@ export function ChatPage({ } })} {chatState == "loading" && + !regenerationState?.regenerating && messageHistory[messageHistory.length - 1]?.type != "user" && ( )} + {chatState == "loading" && (
void; + includeDefault?: boolean; + side?: "top" | "right" | "bottom" | "left"; + maxHeight?: string; +}) { + const [isOpen, setIsOpen] = useState(false); + + const Dropdown = ( +
+

+ Pick a model +

+ {options.map((option, ind) => { + const isSelected = option.value === selected; + return ( + onSelect(option.value)} + isSelected={isSelected} + /> + ); + })} +
+ ); + + return ( + setIsOpen(open)} + content={ +
setIsOpen(!isOpen)}> + {!alternate ? ( + + ) : ( + + )} +
+ } + popover={Dropdown} + align="start" + side={side} + sideOffset={5} + triggerMaxWidth + /> + ); +} + +export default function RegenerateOption({ + selectedAssistant, + regenerate, + overriddenModel, + onHoverChange, +}: { + selectedAssistant: Persona; + regenerate: (modelOverRide: LlmOverride) => Promise; + overriddenModel?: string; + onHoverChange: (isHovered: boolean) => void; +}) { + const llmOverrideManager = useLlmOverride(); + + const { llmProviders } = useChatContext(); + const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null); + + const llmOptionsByProvider: { + [provider: string]: { name: string; value: string }[]; + } = {}; + const uniqueModelNames = new Set(); + + llmProviders.forEach((llmProvider) => { + if (!llmOptionsByProvider[llmProvider.provider]) { + llmOptionsByProvider[llmProvider.provider] = []; + } + + (llmProvider.display_model_names || llmProvider.model_names).forEach( + (modelName) => { + if (!uniqueModelNames.has(modelName)) { + uniqueModelNames.add(modelName); + llmOptionsByProvider[llmProvider.provider].push({ + name: modelName, + value: structureValue( + llmProvider.name, + llmProvider.provider, + modelName + ), + }); + } + } + ); + }); + + const llmOptions = Object.entries(llmOptionsByProvider).flatMap( + ([provider, options]) => [...options] + ); + + const currentModelName = + llmOverrideManager?.llmOverride.modelName || + (selectedAssistant + ? selectedAssistant.llm_model_version_override || llmName + : llmName); + + return ( +
onHoverChange(true)} + onMouseLeave={() => onHoverChange(false)} + > + { + const { name, provider, modelName } = destructureValue( + value as string + ); + regenerate({ + name: name, + provider: provider, + modelName: modelName, + }); + }} + /> +
+ ); +} diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index 0778ae354e94..b4ba2e97475b 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -88,6 +88,7 @@ export interface Message { latestChildMessageId?: number | null; alternateAssistantID?: number | null; stackTrace?: string | null; + overridden_model?: string; } export interface BackendChatSession { @@ -116,6 +117,7 @@ export interface BackendMessage { files: FileDescriptor[]; tool_calls: ToolCallFinalResult[]; alternate_assistant_id?: number | null; + overridden_model?: string; } export interface MessageResponseIDInfo { diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index 43bba4f8952d..b17b94b7ec76 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -114,6 +114,7 @@ export type PacketType = | MessageResponseIDInfo; export async function* sendMessage({ + regenerate, message, fileDescriptors, parentMessageId, @@ -131,6 +132,7 @@ export async function* sendMessage({ alternateAssistantId, signal, }: { + regenerate: boolean; message: string; fileDescriptors: FileDescriptor[]; parentMessageId: number | null; @@ -159,6 +161,7 @@ export async function* sendMessage({ prompt_id: promptId, search_doc_ids: documentsAreSelected ? selectedDocumentIds : null, file_descriptors: fileDescriptors, + regenerate, retrieval_options: !documentsAreSelected ? { run_search: @@ -386,13 +389,12 @@ export function getLastSuccessfulMessageId(messageHistory: Message[]) { .reverse() .find( (message) => - message.type === "assistant" && + (message.type === "assistant" || message.type === "system") && message.messageId !== -1 && message.messageId !== null ); return lastSuccessfulMessage ? lastSuccessfulMessage?.messageId : null; } - export function processRawChatHistory( rawMessages: BackendMessage[] ): Map { @@ -435,6 +437,7 @@ export function processRawChatHistory( parentMessageId: messageInfo.parent_message, childrenMessageIds: [], latestChildMessageId: messageInfo.latest_child_message, + overridden_model: messageInfo.overridden_model, }; messages.set(messageInfo.message_id, message); diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 36eb149e328b..f467b7c69c92 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -45,9 +45,12 @@ import { Persona } from "@/app/admin/assistants/interfaces"; import { AssistantIcon } from "@/components/assistants/AssistantIcon"; import { Citation } from "@/components/search/results/Citation"; import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay"; + import { - DislikeFeedbackIcon, - LikeFeedbackIcon, + ThumbsUpIcon, + ThumbsDownIcon, + LikeFeedback, + DislikeFeedback, } from "@/components/icons/icons"; import { CustomTooltip, @@ -59,6 +62,8 @@ import { useMouseTracking } from "./hooks"; import { InternetSearchIcon } from "@/components/InternetSearchIcon"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import GeneratingImageDisplay from "../tools/GeneratingImageDisplay"; +import RegenerateOption from "../RegenerateOption"; +import { LlmOverride } from "@/lib/hooks"; import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; const TOOLS_WITH_CUSTOM_HANDLING = [ @@ -110,6 +115,8 @@ function FileDisplay({ } export const AIMessage = ({ + regenerate, + overriddenModel, shared, isActive, toggleDocumentSelection, @@ -132,9 +139,13 @@ export const AIMessage = ({ handleForceSearch, retrievalDisabled, currentPersona, + otherMessagesCanSwitchTo, + onMessageSelection, }: { shared?: boolean; isActive?: boolean; + otherMessagesCanSwitchTo?: number[]; + onMessageSelection?: (messageId: number) => void; selectedDocuments?: DanswerDocument[] | null; toggleDocumentSelection?: () => void; docs?: DanswerDocument[] | null; @@ -155,6 +166,8 @@ export const AIMessage = ({ handleSearchQueryEdit?: (query: string) => void; handleForceSearch?: () => void; retrievalDisabled?: boolean; + overriddenModel?: string; + regenerate?: (modelOverRide: LlmOverride) => Promise; }) => { const toolCallGenerating = toolCall && !toolCall.tool_result; const processContent = (content: string | JSX.Element) => { @@ -183,6 +196,7 @@ export const AIMessage = ({ }; const finalContent = processContent(content as string); + const [isRegenerateHovered, setIsRegenerateHovered] = useState(false); const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking(); const settings = useContext(SettingsContext); @@ -240,10 +254,19 @@ export const AIMessage = ({ }); } + const currentMessageInd = messageId + ? otherMessagesCanSwitchTo?.indexOf(messageId) + : undefined; const uniqueSources: ValidSources[] = Array.from( new Set((docs || []).map((doc) => doc.source_type)) ).slice(0, 3); + const includeMessageSwitcher = + currentMessageInd !== undefined && + onMessageSelection && + otherMessagesCanSwitchTo && + otherMessagesCanSwitchTo.length > 1; + return (
+
+ {includeMessageSwitcher && ( +
+ { + onMessageSelection( + otherMessagesCanSwitchTo[ + currentMessageInd - 1 + ] + ); + }} + handleNext={() => { + onMessageSelection( + otherMessagesCanSwitchTo[ + currentMessageInd + 1 + ] + ); + }} + /> +
+ )} +
} + icon={} onClick={() => handleFeedback("like")} /> } + icon={} onClick={() => handleFeedback("dislike")} /> + {regenerate && ( + + )}
) : (
+
+ {includeMessageSwitcher && ( +
+ { + onMessageSelection( + otherMessagesCanSwitchTo[ + currentMessageInd - 1 + ] + ); + }} + handleNext={() => { + onMessageSelection( + otherMessagesCanSwitchTo[ + currentMessageInd + 1 + ] + ); + }} + /> +
+ )} +
} + icon={} onClick={() => handleFeedback("like")} /> } + icon={} onClick={() => handleFeedback("dislike")} /> + {regenerate && ( + + )}
))} diff --git a/web/src/components/Dropdown.tsx b/web/src/components/Dropdown.tsx index 2d4d7325b13f..79fe03083110 100644 --- a/web/src/components/Dropdown.tsx +++ b/web/src/components/Dropdown.tsx @@ -320,15 +320,15 @@ export const DefaultDropdown = forwardRef( const Content = (

{selectedOption?.name || diff --git a/web/src/components/Hoverable.tsx b/web/src/components/Hoverable.tsx index 04b3d22f7047..75162b69324c 100644 --- a/web/src/components/Hoverable.tsx +++ b/web/src/components/Hoverable.tsx @@ -1,4 +1,3 @@ -import { IconProps } from "@tremor/react"; import { IconType } from "react-icons"; const ICON_SIZE = 15; @@ -7,13 +6,22 @@ export const Hoverable: React.FC<{ icon: IconType; onClick?: () => void; size?: number; -}> = ({ icon, onClick, size = ICON_SIZE }) => { + active?: boolean; + hoverText?: string; +}> = ({ icon: Icon, active, hoverText, onClick, size = ICON_SIZE }) => { return (

- {icon({ size: size, className: "my-auto" })} +
+ + {hoverText && ( +
+ {hoverText} +
+ )} +
); }; diff --git a/web/src/components/icons/icons.tsx b/web/src/components/icons/icons.tsx index 2863aadcfcfe..91e67c025773 100644 --- a/web/src/components/icons/icons.tsx +++ b/web/src/components/icons/icons.tsx @@ -755,6 +755,85 @@ export const ChevronIcon = ({ ); }; +export const StarFeedback = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ( + + + + ); +}; + +export const DislikeFeedback = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ( + + + + + + + ); +}; + +export const LikeFeedback = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ( + + + + + + + ); +}; + export const CheckmarkIcon = ({ size = 16, className = defaultTailwindCSS, @@ -2523,8 +2602,8 @@ export const SwapIcon = ({ @@ -2550,8 +2629,8 @@ export const ClosedBookIcon = ({ @@ -2574,8 +2653,8 @@ export const PinIcon = ({ @@ -2599,8 +2678,8 @@ export const TwoRightArrowIcons = ({